CVAD

전이학습(Transfer learning)이란? - 2 본문

AI 개념정리

전이학습(Transfer learning)이란? - 2

_DK_Kim 2024. 3. 24. 18:18
이전 포스팅 내용이 궁금하시다면 아래의 링크를 참고하시면 감사하겠습니다!
   - 전이학습(Transfer learning)이란? - 1

 

저번 포스팅에서는 전이 학습의 개념과 원리, 적용 방법론에 대해서 알아보았다. 이번에는 실제 모델을 사용해서 전이 학습을 적용해 보자.

 

두 가지 방법론 중 이번에 구현할 것은 feature extractor로 사용하는 방법이다. 사전 학습된 VGG16 모델을 사용해서 간단한 classifcation task를 해결하도록 classifier 부분만 새로 교체하여 학습할 것이다.

 

이번 포스팅의 절차는 아래와 같다.

 

  1. 데이터셋 준비
  2. Classifier 제거 후 학습
  3. 결과 확인

1. 데이터셋 준비

데이터셋은 'Dogs vs Cats'을 사용할 것이다. 이 데이터셋은 아래의 Kaggle 링크를 통해 다운로드할 수 있다.

 

https://www.kaggle.com/competitions/dogs-vs-cats/data

 

Dogs vs. Cats | Kaggle

 

www.kaggle.com

 

여기서 train 데이터를 다운로드하여서 확인해 보면, 25000장의 개와 고양이 사진을 확인할 수 있다.

 

 

그리고 파일명에 개 사진인지 고양이 사진인지 표시되어 있다.

이제 각 class 이미지 수를 알아보자. 

 

import os

data_path = '../../Data/DogsandCats/train'
img_list = os.listdir(data_path)

cat = 0
dog = 0

for img in img_list:
    if img[:3] == 'cat':
        cat +=1
    elif img[:3] == 'dog':
        dog +=1

print(f'cat : {cat} || dog : {dog}')
========================================================================================
cat : 12500 || dog : 12500

 

위의 코드를 실행하면 결과를 알 수 있다. 개와 고양이가 딱 절반이다.

나는 train 파일의 이미지들을 두 클래스 균일하도록 나누어서 검증 데이터셋으로 분할하였다.

 

전체 데이터의 80%만 학습으로 사용하고, 나머지는 검증용으로 valid 폴더를 생성해서 저장해 보자.

 

import shutil

dogs_list = [imgs for imgs in img_list if imgs[:3] == 'dog']
cats_list = [imgs for imgs in img_list if imgs[:3] == 'cat']

valid_path = '../../Data/DogsandCats/valid'

if not os.path.exists(valid_path):
    os.makedirs(valid_path)

num_imgs = int(len(dogs_list)*0.2)
for i in range(num_imgs):
    shutil.move(os.path.join(data_path, dogs_list[i]), os.path.join(valid_path, dogs_list[i]))
    shutil.move(os.path.join(data_path, cats_list[i]), os.path.join(valid_path, cats_list[i]))

 

그러면 개와 고양이 사진이 각각 2500장씩, 총 5000장의 검증 데이터셋을 만들 수 있다. 이제, 데이터셋 파일을 정의해 보자.

 

import os
from torch.utils.data import Dataset
from torchvision.io import read_image

data_path = '../../Data/DogsandCats'

class DogsandCats(Dataset):
    def __init__(self, mode, transform=None):
        '''
        mode = 'train' or 'valid' | call train dataset or validation dataset
        transform : None(default)
        '''

        self.mode = mode
        if self.mode == 'train':
            self.img_path = data_path + '/train'
        else:
            self.img_path = data_path + '/valid'
        
        self.img_list = os.listdir(self.img_path)
        self.transform = transform

    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, idx):
        name = self.img_list[idx][:3]
        if name == 'dog':
            lbl = 0
        elif name == 'cat':
            lbl = 1
        img = read_image(os.path.join(self.img_path, self.img_list[idx]))/255.0

        if self.transform:
            img = self.transform(img)

        return img, lbl

 

개 사진일 경우 label값을 0으로 주었고, 반대의 경우는 1을 주었다. 그 외의 dataset 구성 코드는 다른 예제 포스팅에서 다뤘던 형식을 사용하였다.


2. Classifier를 제거 후 학습

먼저 우리가 확인해 볼 것은 classifier 부분을 제거한 후 사전학습된 VGG16 모델의 출력층만 바꿔서 볼 것이다.

모델 구현 코드를 작성해 보자.

 

import torch
import torchvision.models as models
import torch.nn as nn
from torchinfo import summary

class VGG16_new_classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

        for param in self.model.parameters():
            param.requires_grad = False
        
        self.classifier = nn.Sequential(
            nn.Linear(25088, 4096), 
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 2)  
        )

        # Replace the classifier in the pre-trained model with the new one
        self.model.classifier = self.classifier

    def forward(self, x):
        x = self.model(x)
        return x

model = VGG16_new_classifier().to(config.device)
# summary(model, (1,3,224,224))

 

위의 코드를 잠시 살펴보면, torchvision에서 제공하는 vgg16 모델을 불러온 뒤 classifier를 제외한 나머지 layer의 가중치를 고정시켜 준다.

 

그리고, 원래 classifier를 학습되지 않은 새로운 classifier로 바꿔주면 된다.

 

이제 학습에 필요한 코드들을 작성해 보자

 

import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
import os

from torch.utils.data import DataLoader
from Dataset import DogsandCats
from torchvision.transforms import v2 as T
from tqdm.auto import tqdm


def my_transforms():
    transforms = T.Compose([
        T.Resize([224, 224], interpolation=T.InterpolationMode.NEAREST_EXACT),
        T.Normalize(mean=[.4885, .4553, .4172], std=[.2293, .2249, .2252]),
        T.ToDtype(torch.float, scale=True),
        T.ToPureTensor(),
    ]
    )
    return transforms

ds_train = DogsandCats(mode='train', transform=my_transforms())
ds_valid = DogsandCats(mode='valid', transform=my_transforms())

dl_train = DataLoader(ds_train, batch_size=64, shuffle=True, num_workers=4)
dl_valid = DataLoader(ds_valid, batch_size=32, shuffle=False)

class config():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_name = 'VGG16_new_classifier'
    max_epoch = 5

 

전처리는 위와 같이 이미지 크기를 $224\times224$로 맞춰주고 정규화만 적용해 주었다. 데이터 증강의 경우 별도로 진행하지 않았는데, 예측 클래스 수가 적고, 학습 데이터셋의 크기가 충분히 크기 때문이다.

 

이미 대부분의 layer가 학습되어 있으니, 5 epoch만 학습을 해보자. 아마 충분히 학습될 것이다.

 

오차 함수로는 Corssentrpy를 사용해 주고, 최적화기법은 Adam을 사용하였다.

 

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=.1)

log_path = f'Model/{config.model_name}/log'
best_model_path = log_path + '/best_model'

# Check log & best model path
if os.path.exists(best_model_path):
    pass
else:
    os.makedirs(best_model_path)


class engine():
    def model_train(model, data, optimizer, loss_fn):
        imgs, lbls = data['image'], data['label']
        imgs, lbls = imgs.to(config.device), lbls.to(config.device)

        model.train()

        preds = model(imgs)
        loss = loss_fn(preds, lbls)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        return loss.item()
    
    @ torch.no_grad()
    def model_valid(mdoel, data, loss_fn):
        imgs, lbls = data['image'], data['label']
        imgs, lbls = imgs.to(config.device), lbls.to(config.device)

        model.eval()

        preds = model(imgs)
        loss = loss_fn(preds, lbls)

        return loss.item()


def run():
    best_val_loss = 1e9
    
    for epoch in range(config.max_epoch):

        for _, data in tqdm(enumerate(dl_train), total=len(dl_train)):
            train_loss = engine.model_train(model, data, optimizer, loss_fn)

        for _, data in tqdm(enumerate(dl_valid), total=len(dl_valid)):
            valid_loss = engine.model_valid(model, data, loss_fn)


        torch.save(model.state_dict(), os.path.join(log_path, f'epoch_{epoch}.pth'))

        if valid_loss < best_val_loss:
            torch.save(model.state_dict(), os.path.join(best_model_path, 'best_model.pth'))

        print(f'[Epoch : {epoch}/{config.max_epoch}] Train loss : {train_loss} || Validation loss : {valid_loss}')
        print()
run()

3. 결과 확인

아래는 5 epoch 학습의 결과다.

100%|██████████| 313/313 [01:11<00:00,  4.35it/s]
100%|██████████| 157/157 [01:21<00:00,  1.92it/s]
[Epoch : 0/5] Train loss : 0.14762777090072632 || Validation loss : 0.0013951573055237532

100%|██████████| 313/313 [01:10<00:00,  4.42it/s]
100%|██████████| 157/157 [00:17<00:00,  8.89it/s]
[Epoch : 1/5] Train loss : 0.015938522294163704 || Validation loss : 0.0007364111370407045

100%|██████████| 313/313 [01:10<00:00,  4.42it/s]
100%|██████████| 157/157 [00:17<00:00,  8.81it/s]
[Epoch : 2/5] Train loss : 0.05040908232331276 || Validation loss : 0.002780304057523608

100%|██████████| 313/313 [01:10<00:00,  4.42it/s]
100%|██████████| 157/157 [00:17<00:00,  8.87it/s]
[Epoch : 3/5] Train loss : 0.06553690880537033 || Validation loss : 0.0857008844614029

100%|██████████| 313/313 [01:10<00:00,  4.42it/s]
100%|██████████| 157/157 [00:17<00:00,  8.89it/s]
[Epoch : 4/5] Train loss : 0.2567245662212372 || Validation loss : 0.013202370144426823

 

이미 대부분 학습되어 있는 모델이라 그런지, 2번째 epoch에 best에 도달하고 그 이후에는 아마 over fitting이 난 것 같다.

 

가장 validation loss가 낮은 모델로 학습 결과를 살펴보자.

 

from torchvision.io import read_image

test_img_1 = '../../Data/DogsandCats/valid/cat.0.jpg'
test_img_2 = '../../Data/DogsandCats/valid/dog.7.jpg'

img1 = read_image(test_img_1)
img2 = read_image(test_img_2)
imgs = [img1, img2]

get_transforms = my_transforms()
model = VGG16_new_classifier().to(config.device)
best_state = torch.load(os.path.join(best_model_path, 'best_model.pth'))
model.load_state_dict(best_state)

for img in imgs:
    img = get_transforms(img/255.0).unsqueeze(dim=0)
    with torch.no_grad():
        model.eval()
        img = img.to(config.device)
        pred = model(img)

        real_label = torch.argmax(pred).item()

    if real_label == 1:
        print('Cat')
    elif real_label == 0:
        print('Dog')
========================================================================================
Cat
Dog

 

처음 예상처럼 예측이 잘 이루어지는 것을 확인할 수 있다.


이번 포스팅에서 우리는 전이 학습의 2가지 방법론 중 feature extractor로 사용하는 예제를 다루었다. 위의 예제처럼 유사한 task에 대해 진행할 수도 있지만, 완전히 다른 task에서 활용할 수 있다. Backbone으로 사용하는 것이 이러한 경우에 속한다.

 

만약에 Backbone으로 사용하는 경우에 대해 더 알고 싶다면, VOC PASCAL segmentation 예제의 두 모델코드를 참고하면 좋을 것 같다. 다음 포스팅에서는 나머지 한 개의 방법론 예제를 다루겠다.

 

포스팅에 대한 질문이나 잘못된 사항 댓글 달아주시면 감사하겠습니다! :)

 

728x90