import torchvision
import torch
from PIL import Image, ImageFilter
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchvision.transforms as transforms
import cv2
import glob
import math
from einops import rearrange
import timm
from tqdm.notebook import tqdm
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Training hyper-parameters shared by the cells below.
num_epochs, lr, batch_size, num_workers = 10, 0.001, 32, 4
import zipfile
# Extract the dataset archive. The context manager closes the archive handle
# once extraction finishes.
with zipfile.ZipFile('/content/drive/MyDrive/lesson_data/archive.zip') as zip_file:
    zip_file.extractall(path='/content/data')
# NOTE(review): extraction targets /content/data while everything below reads
# ./data; these only coincide when the working directory is /content.
# One class per sub-directory of the train split, sorted so that label indices
# are stable. Non-directory entries (e.g. stray metadata files) are skipped.
# Otherwise they would become phantom classes and shift every label index.
class_names = sorted(
    entry for entry in os.listdir('./data/train/')
    if os.path.isdir(os.path.join('./data/train/', entry))
)
class_len = len(class_names)
class Sports_Dataset(Dataset):
    """Image dataset read from ``./data/<data_name>/<class_name>/*jpg``.

    Each sample's label is the index of its class folder in the module-level
    ``class_names`` list. Every image is converted to RGB and resized to
    224x224. For the ``'train'`` split, random grayscale, resize-and-crop,
    Gaussian blur and horizontal flip augmentations are applied as well.

    Args:
        data_name: split folder under ``./data`` (e.g. 'train', 'valid', 'test').
    """

    def __init__(self, data_name):
        self.dataname = data_name
        self.img_path = []
        self.labels = []
        # Collect paths and labels together. The label is simply the index of
        # the class folder being globbed, so it does not depend on how many
        # components the path has. Sorting makes the order deterministic.
        for label, name in enumerate(class_names):
            paths = sorted(glob.glob(f'./data/{data_name}/{name}/*jpg'))
            self.img_path.extend(paths)
            self.labels.extend([label] * len(paths))
        self.img_transpose = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _augment(img):
        """Apply the random training-time augmentations to a 224x224 RGB image."""
        # Random grayscale, kept 3-channel so the tensor shape does not change.
        if random.uniform(0, 1) < 0.3:
            img = img.convert('L').convert('RGB')
        # Random zoom: upscale to 288x288, then take a random 224x224 crop.
        if random.uniform(0, 1) < 0.3:
            img = img.resize((224 + 64, 224 + 64), Image.Resampling.BILINEAR)
            x = random.randrange(0, 64)
            y = random.randrange(0, 64)
            img = img.crop((x, y, x + 224, y + 224))
        if random.uniform(0, 1) < 0.2:
            img = img.filter(ImageFilter.GaussianBlur(random.uniform(0.5, 1.2)))
        if random.uniform(0, 1) < 0.3:
            img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        return img

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``."""
        # Convert every image to RGB before anything else. A mode check via
        # ``getbands()`` cannot work because it returns a tuple such as ('L',),
        # so grayscale/RGBA/CMYK files would otherwise reach ToTensor with the
        # wrong channel count and break batching. convert() returns a new,
        # fully loaded image, so the source file can be closed immediately.
        with Image.open(self.img_path[index]) as src:
            img = src.convert('RGB')
        if img.size != (224, 224):
            img = img.resize((224, 224), Image.Resampling.BILINEAR)
        if self.dataname == 'train':
            img = self._augment(img)
        lbl = torch.tensor(self.labels[index])
        img = self.img_transpose(img)
        return img, lbl

    def __len__(self):
        return len(self.img_path)
train_dataset = Sports_Dataset('train')
print(len(train_dataset))

# Show 8 random (augmented) training samples in a 2x4 grid, titled by class.
_, ax = plt.subplots(2, 4, figsize=(16, 10))
for k in range(8):
    img_tensor, label = train_dataset[random.choice(range(len(train_dataset)))]
    # CHW float tensor in [0, 1] -> HWC integer array for imshow.
    image = (img_tensor.cpu().detach().numpy().transpose(1, 2, 0) * 255).astype(np.uint32)
    row, col = divmod(k, 4)
    ax[row][col].imshow(image)
    ax[row][col].set_title(class_names[label])
# Pretrained Swin-B backbone with a fresh MLP classifier for our classes.
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
new_head = nn.Sequential(nn.Linear(model.num_features, 512),
                         nn.ReLU(),
                         nn.Dropout(0.3),
                         nn.Linear(512, class_len))
# Depending on the timm version, Swin's ``head`` is either a plain Linear layer
# or a ClassifierHead that does global pooling and then applies ``head.fc``.
# In the latter case, replacing the whole head would skip pooling, so only the
# final projection is swapped.
if hasattr(model.head, 'fc'):
    model.head.fc = new_head
else:
    model.head = new_head
# Only the head is optimized below. Freezing the backbone explicitly avoids
# computing and storing gradients that would never be applied.
model.requires_grad_(False)
model.head.requires_grad_(True)
model = model.to(device)

criterion = timm.loss.LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = torch.optim.AdamW(model.head.parameters(), lr=lr)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dataset = Sports_Dataset('valid')
# batch_size=1: evaluation code indexes element [0] of each batch.
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
def update_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: a torch optimizer whose ``param_groups`` are dicts.
        lr: the new learning rate applied to all groups.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
# Train the classifier head. After every epoch, score the validation split and
# keep the best weights in nets/SwinTransformer.ckpt.
model.train()
total_step = len(train_loader)
curr_lr = lr
# Start below any achievable accuracy so the first epoch always writes a
# checkpoint; the test cell loads nets/SwinTransformer.ckpt unconditionally.
best_score = -1.0
for epoch in range(num_epochs):
    total_loss = 0
    for i, (images, labels) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        labels = labels.to(device)
        g_labels = model(images)
        loss = criterion(g_labels, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f'{batch_size*(i+1)} / {len(train_dataset)}')

    # Validation: eval() disables dropout, and no_grad() avoids building
    # autograd graphs for inference.
    model.eval()
    score = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            pred_ids = model(images).argmax(dim=1)
            score += (pred_ids == labels).sum().item()
    print(f'Epoch : {epoch+1}, Loss : {total_loss/total_step}')
    avg = score / len(val_dataset)
    print(f'Accuracy : {avg :.2f}\n')
    model.train()

    if best_score < avg:
        best_score = avg
        os.makedirs('./nets', exist_ok=True)
        torch.save(model.state_dict(), 'nets/SwinTransformer.ckpt')

    # Decay the learning rate by 0.8 every 2 epochs, compounding each time.
    if (epoch + 1) % 2 == 0:
        curr_lr *= 0.8
        update_lr(optimizer, curr_lr)
# Reload the best checkpoint and measure accuracy on the test split.
# preds/gts are kept (in dataset order) for the per-class report below.
model.eval()
model.load_state_dict(torch.load('nets/SwinTransformer.ckpt', map_location=device))
test_dataset = Sports_Dataset('test')
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=1,
                                          shuffle=False)
preds = []
gts = []
score = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        batch_preds = model(images).argmax(dim=1).tolist()
        batch_gts = labels.tolist()
        preds.extend(batch_preds)
        gts.extend(batch_gts)
        score += sum(int(p == g) for p, g in zip(batch_preds, batch_gts))
# Normalise by the TEST set size (previously divided by the validation size).
avg = score / len(test_dataset)
print('Accuracy: {:.4f}\n'.format(avg))
# Show 8 random test images titled with the model's prediction and the truth.
test_dataset = Sports_Dataset('test')
_, ax = plt.subplots(2, 4, figsize=(16, 10))
with torch.no_grad():
    for i in range(8):
        img_tensor, label = test_dataset[random.randrange(len(test_dataset))]
        # CHW float tensor in [0, 1] -> HWC uint8 array, the dtype imshow expects.
        image = (img_tensor.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        idx = model(img_tensor.unsqueeze(0).to(device)).argmax(dim=1)[0].item()
        row, col = divmod(i, 4)
        ax[row][col].imshow(image)
        ax[row][col].set_title('Predict: {}\nGT: {}'.format(class_names[idx], class_names[label]))
# Per-class test accuracy from gts/preds: green when every test image of a
# class was classified correctly, red otherwise. Counts are tallied per label,
# so any number of test images per class, in any order, is handled.
total_per_class = [0] * class_len
correct_per_class = [0] * class_len
for gt, pred in zip(gts, preds):
    total_per_class[gt] += 1
    correct_per_class[gt] += int(gt == pred)
ansi_reset = '\033[0m'  # restore the default colour after each line
for i in range(class_len):
    correct, total = correct_per_class[i], total_per_class[i]
    colour = '\033[92m' if correct == total else '\033[91m'
    print(colour + '{}: {} / {}'.format(class_names[i], correct, total) + ansi_reset)