Published on

미니프로젝트 part 3

Authors
  • avatar
    Name
    Inhwan Cho
    Twitter
  • 이번 미니 프로젝트의 목표는
  1. 구글 이미지를 웹크롤링을 이용하여 이미지 저장 후
  2. 이미지 분류 모델(Swin Transformer모델을 이용한 파인 튜닝)
  3. 모델 불러오기 + 구글 이미지를 분류 & 삭제(+정렬)
import torchvision
import torch
from PIL import Image, ImageFilter
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.transforms as transforms
import cv2
import glob
import math
from einops import rearrange #차원 관리 모듈
import timm # 파인튜닝모듈
from tqdm.notebook import tqdm

# configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_epochs = 10
lr = 0.001
batch_size = 32
num_workers = 4 # ipykenel에서는 주석처리를 해야됩니다(멀티프로세싱 오류)

# 로컬 환경
class_names = os.listdir('./data/train/')
class_names = sorted(class_names)
class_len = len(class_names)

# Dataset 클래스 만들기
class Sports_Dataset(Dataset):
    def __init__(self, data_name):
    #data_name will be set 'train' or 'valid'(the folder names)
        self.dataname = data_name
        #train파일 경로리스트 생성     
        self.img_path = []
        #경로가 .jpg 확장자인 파일들을 img_path에 리스트화 시켜줌 
        for name in class_names:
            self.img_path.append(glob.glob(f'./data/{data_name}/{name}/*jpg'))
        #2차원 리스트 -> 1차원 리스트
        self.img_path = sum(self.img_path, [])
        #train파일의 labels 생성
        self.labels = []
        for path in self.img_path:
            self.labels.append(class_names.index(path.split('/')[3]))
        
        #텐서타입으로 변경하는 변수 생성
        self.img_transpose = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        img = Image.open(self.img_path[index])

        # swin_base_patch4_window7_224모델이 224,224로 트레이닝 된 모델이라서
        # 이미지 보간법을 이용하여 (224,224)사이즈로 변경
        if img.size != (224,224):
            img = img.resize((224,224),Image.Resampling.BILINEAR)
        
        # Data augmentation with PIL when only 'train set'
        if self.dataname == 'train':
            if random.uniform(0,1) < 0.3 or img.getbands() == 'L':
                img = img.convert('L').convert('RGB')
            
            # Random crop with size (64,64) from 30%
            if random.uniform(0,1) < 0.3 :
                img = img.resize((224+64,224+64), Image.Resampling.BILINEAR)
                x = random.randrange(0,64)
                y = random.randrange(0,64)
                img = img.crop((x,y,x+224, y+224))
                
                
            # Random Gaussian blur from 20%
            if random.uniform(0,1) < 0.2:
                img = img.filter(ImageFilter.GaussianBlur(random.uniform(0.5,1.2)))
            
            # Random flip from 30%
            if random.uniform(0,1) < 0.3:
                img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        
        # 왜인지 모르겠지만 구글이미지 중 컬러인데 1채널이 존재해서 무조건 3채널로 바꿔줘야합니다.
        else :    
            img = img.convert('RGB')
        
        lbl = self.labels[index]
        lbl = torch.tensor(lbl)
        img = self.img_transpose(img)
        
        return img, lbl
        #Sports_Dataset[0] == img, Sports_Dataset[1] == lbl(데이터 사용 방식)

    def __len__(self):
        return len(self.img_path)
    
# 모델 로드

model = timm.create_model('swin_base_patch4_window7_224',pretrained=True)
model.head = nn.Sequential(
    nn.Linear(1024,512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512,class_len)
)
model = model.to(device)
model.eval()
model.load_state_dict(torch.load('./nets/SwinTransformer (1).ckpt', map_location=device))

test_dataset = Sports_Dataset('test')
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
#불러온 모델(.ckpt파일)의 성능 검사
preds = []
gts = []

score = 0
for i, (images, labels) in enumerate(test_loader):
    images = images.to(device)
    labels = labels.to(device)

    g_labels = model(images)
    
    pred = torch.max(g_labels, 1)[1][0].item()
    a = torch.max(g_labels, 1)
    preds.append(pred)
    gt = labels[0].item()
    gts.append(gt)
    
    score += int(pred == gt)

avg = score / len(test_dataset)
print('Accuracy: {:.4f}\n'.format(avg))

Accuracy: 0.9840

test_dataset = Sports_Dataset('base_test')
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=1,
                                           shuffle=False)
len(test_loader) 
ori_len = len(test_dataset)
_, ax = plt.subplots(2, 4, figsize=(16,10))

for i in range(8):
    data = test_dataset.__getitem__(np.random.choice(range(test_dataset.__len__())))
    label = data[1]
    
    image = data[0].cpu().detach().numpy().transpose(1, 2, 0) * 255
    image = image.astype(np.uint32)
    
    label = data[1]
    
    idx = torch.max(model(data[0].unsqueeze(0).to(device)), 1)[1][0].item()
    
    ax[i//4][i-(i//4)*4].imshow(image)
    ax[i//4][i-(i//4)*4].set_title('Predict: {}\nGT: {}'.format(class_names[idx], class_names[label]))
  • 여기서 원본과 예상이 다른 파일들을 삭제할 겁니다.

output_24_0

folder_name = 'base_test'
target_folder = 'baseball'
test_dataset = Sports_Dataset(folder_name)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=1,
                                           shuffle=False)
# 원래 : baseball, 예측 : chuckwagon racing 13
ori_len = len(test_dataset)

for i in range(ori_len):
    try:
        data = test_dataset.__getitem__(i)
        label = data[1]
        
        idx = torch.max(model(data[0].unsqueeze(0).to(device)), 1)[1][0].item()
        print(f'원래 사진 : {class_names[label]}, 예측 : {class_names[idx]}', i)
        if class_names[label] != class_names[idx]:
            path = f'data/{folder_name}/{target_folder}/{class_names[label]}_{i}.jpg'
            os.remove(path)
            print('\033[31m'+f'removed file {path}'+'\033[30m')
        # if i = :
            # break
    except:
        print('\033[31m'+f'경로 {path}{class_names[label]}_{i}.jpg 사진이 없습니다'+'\033[30m')
#파일 이름 초기화 및 정렬 !!!!!!!(1회만 실행)!!!!!!!!
folder_name = 'base_test'
target_folder = 'baseball'

#폴더 경로
folder_path = 'data/base_test/baseball'
file_names = os.listdir(folder_path)

#파일 이름 초기화 후 재정렬(1회만 실행해야함 여러번 실행 시 중복된 이름이 삭제됨)
for i, name in enumerate(file_names):
    newname = f'{folder_path}/{target_folder}_{i}.jpg'
    print(newname)
    src = os.path.join(folder_path, name)
    os.rename(src, newname)