Segmentation

[VOC PASCAL 2012] Semantic segmentation 하기 - 3

_DK_Kim 2024. 3. 3. 00:24
이전 포스팅 내용이 궁금하시다면 아래의 링크를 참고하시면 감사하겠습니다!
   - [VOC PASCAL 2012] Semantic segmentation 하기 - 1
   - [VOC PASCAL 2012] Semantic segmentation 하기 - 2

 

이번에는 학습 코드를 작성해보자. 모델은 지난 포스팅에 구현한 FCN과 U-Net을 사용할 것이다.


1. 라이브러리 불러오기 및 학습 configuration 작성

 

학습에 필요한 라이브러리 및 학습 configuration 값을 설정해보자.

 

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import wandb
import torch.optim as opt
import torch.nn.functional as F
import random
import os

from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from Dataset import _VOCdataset
from utils import Random_processing
from torchinfo import summary

from Model.FCN import O_FCN8s
from Model.UNet import UNet


random.seed(123)
torch.manual_seed(123)

class config:
    # model_name = 'O_FCN8s'
    model_name = 'UNet_vgg16_bn'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_epoch = 25
    batch_size = 16

 

코드에서 보이는 것처럼, 25 epoch 동안 학습을 진행할 것이고 학습과 검증 모두 16개의 mini batch를 사용해서 진행할 것이다.

그 후, 아래의 mean값과 std값을 사용하여 정규화된 데이터셋을 사용하여 학습 및 검증 data loader를 만들어주자.

 

mean = [0.4567, 0.4431, 0.4086]
std = [0.2680, 0.2649, 0.2806]

ds_train = _VOCdataset(mode='train', transform=Random_processing(mean, std))
ds_valid = _VOCdataset(mode='valid', transform=Random_processing(mean, std))

dl_train = DataLoader(ds_train, batch_size=config.batch_size, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=config.batch_size, shuffle=False)

 

다음으로, wandb에 Jupyter notebook을 연결해주자.

wandb.login()
WANDB_CONFIG = {'_wandb_kernel' : 'neuracort'}
run = wandb.init(
    project='PASCAL VOC 2012 Semantic Segmentation - Model performance',
    config=WANDB_CONFIG
)
wandb.run.name = 'UNet with Vgg16 backbone'

 

마지막으로 학습에 사용할 모델을 지정해준다.

 

# model = O_FCN8s(n_class=22).to(config.device)
model = UNet().to(config.device)

2. Loss 함수 정의하기

 

학습에 사용할 loss 함수를 정의해준다. 이번 학습에는 cross entropy와 dice loss를 결합하여 사용할 것이다.

두 loss 함수에 대한 설명은 아래의 링크를 참고하면 좋을 것 같다.

 

https://cvad.tistory.com/30

 

오차 함수의 종류와 특징

Task를 잘 수행할 있도록 모델을 학습시킬 때 고려해야할 점에는 무엇이 있을까? 모델의 구조나 데이터셋 등 다양한 요소가 있겠지만, 오차 함수도 모델 학습 결과에 큰 영향을 미친다. 오늘은 이

cvad.tistory.com

 

 

내가 구현하고자 하는 loss 함수를 수식으로 나타내면 아래와 같다.

$$  L = \lambda_{ce}\times L_{ce}+(1 - \lambda_{ce})\times L_{dice} $$

 

여기서, $lambda_{ce}$는 0~1 사이 값을 갖는 가중치로 저 값에 따라 두 loss 함수의 비중이 달라지는 형태다.

이번 학습에서는 두 loss 함수의 비율에 따른 학습 성능도 비교해 볼 예정이기 때문에 위와 같이 함수를 정의했다.

구현된 코드는 아래와 같다.

 

class CombinedLoss(nn.Module):
    def __init__(self, weight=0.5, smooth=1e-6):
        super(CombinedLoss, self).__init__()
        self.weight = weight
        self.smooth = smooth
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        
    def dice_loss(self, inputs, targets):
        inputs = torch.softmax(inputs, dim=1)
        targets = F.one_hot(targets.long(), num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()
        
        intersection = (inputs * targets).sum((2, 3))
        dice = (2. * intersection + self.smooth) / (inputs.sum((2, 3)) + targets.sum((2, 3)) + self.smooth)
        
        return 1 - dice.mean()
    
    def forward(self, inputs, targets):
        ce_loss = self.cross_entropy_loss(inputs, targets.squeeze().long())
        dl_loss = self.dice_loss(inputs, targets.squeeze().long())
        return self.weight * ce_loss + (1 - self.weight) * dl_loss


loss_fn = CombinedLoss(weight=0.5)
optimizer = opt.Adam(model.parameters(), lr=1e-3)

 

optimizer는 무난하게 adam을 사용하였다.


3. 학습 코드 작성 및 진행

 

이제 학습을 진행하는 코드를 작성해보자.

 

class engine():
    def model_train(model, data, optimizer, loss_func):
        imgs, lbls = data['image'], data['label']
        model.train()

        imgs = imgs.to(config.device)
        lbls = lbls.to(config.device)

        preds = model(imgs)
        loss = loss_func(preds, lbls)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        return loss.item()
    
    @torch.no_grad()
    def model_valid(model, data, loss_func):
        imgs, lbls = data['image'], data['label']
        model.eval()

        imgs = imgs.to(config.device)
        lbls = lbls.to(config.device)

        preds = model(imgs)
        loss= loss_func(preds, lbls)

        return loss.item()

 

 

이전에 U-Net 구현 모델과 동일한 구조의 코드다.

 

log_path = f'Model/{config.model_name}/log'
best_model_path = log_path + '/best_model'

# Check log & best model path
if os.path.exists(best_model_path):
    pass
else:
    os.makedirs(best_model_path)

# Set test image
test_sample = next(iter(dl_valid))
test_img, test_lbl = test_sample['image'], test_sample['label']

# Save the view image and mask image
view_img = test_img[0].squeeze().permute(1,2,0).detach().cpu().numpy()
view_lbl = test_lbl[0].squeeze().detach().cpu().numpy()


# Reverse the normalization
view_img = (view_img*np.array(std) + np.array(mean))*255.0
view_img = view_img.astype(np.int16)

test_img = test_img.to(config.device)

 

마찬가지로, 매 epoch마다 학습 결과를 저장할 path를 만들어주고, wandb에 시각화할 이미지와 마스크 이미지를 정의해주었다.

 

def run():
    best_val_loss = 1e10
    
    for epoch in range(config.max_epoch):
        print('-----------------------------')
        print(f'    Epoch : {epoch}')
        print('-----------------------------')


        for _, data in tqdm(enumerate(dl_train), total=len(dl_train)):
            train_loss= engine.model_train(model, data, optimizer, loss_fn)

        for _, data in tqdm(enumerate(dl_valid), total=len(dl_valid)):
            valid_loss= engine.model_valid(model, data, loss_fn)

        torch.save(model.state_dict(), os.path.join(log_path, f'epoch_{epoch}.pth'))

        if valid_loss < best_val_loss:
            torch.save(model.state_dict(), os.path.join(best_model_path, 'best_model.pth'))
            best_val_loss = valid_loss

        with torch.no_grad():
            model.eval()
            
            view_preds = model(test_img)
            view_preds = torch.argmax(view_preds, dim=1)
            view_pred = view_preds[0].detach().cpu().numpy()
            view_pred = view_pred.astype(np.int16)

        wandb.log(
            {
                'epoch' : epoch,
                'train_loss' : train_loss,
                'valid_loss' : valid_loss,
                'Color Image' : [wandb.Image(view_img, caption='Color image', mode="RGB")],
                'Ground Truth' : [wandb.Image(view_lbl, caption='Ground truth', mode="L")],
                'Model prediction' : [wandb.Image(view_pred, caption='Model Prediction', mode="L")]
            }   
        )
        print()

run()

 

이제 이 코드를 실행해서 모델을 학습시켜주면 된다.

아쉽게도 wandb의 문제가 있는 것인지 wandb에 이미지가 업로드되지 않아서 epoch마다 모델의 성능 변화를 관찰할 순 없었다...


4. Metric

 

학습이 완료되면, best model을 불러와서 mIoU를 분석해보자. 코드는 다음과 같다.

 

import torch

log_path = f'Model/{config.model_name}/log'
best_model_path = log_path + '/best_model'

# model = O_FCN8s(n_class=22).to(config.device)
model = UNet().to(config.device)

best_model_weight = torch.load(os.path.join(best_model_path, 'best_model.pth'))
model.load_state_dict(best_model_weight)

def fast_hist(pred, label, n_class):
    mask = (label >= 0) & (label < n_class)
    hist = torch.bincount(
        n_class * label[mask].view(-1).long() + pred[mask].view(-1),
        minlength=n_class ** 2).reshape(n_class, n_class)
    return hist

def compute_IoU_per_class(hist):
    intersection = torch.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    IoU = intersection / union
    return IoU

model.eval()
n_class = 22
total_hist = torch.zeros((n_class, n_class), device=config.device)  

with torch.no_grad():
    for data in dl_valid:
        img, lbl = data['image'], data['label']
        img = img.to(config.device)
        lbl = lbl.to(config.device).squeeze(1)  
        preds = model(img)
        preds = torch.argmax(preds, dim=1)  
        
        for pred, label in zip(preds, lbl):
            total_hist += fast_hist(pred, label, n_class).to(config.device)  
IoUs = compute_IoU_per_class(total_hist)

for i, IoU in enumerate(IoUs):
    print(f"Class {i} IoU: {IoU.item(): .4f}")

 

이렇게 학습 코드 구현에 대한 포스팅을 마무리하겠다. 다음 포스팅은 학습 결과를 분석해보겠다.

 

포스팅에 대한 질문이나 잘못된 사항있다면 댓글 남겨주시면 감사하겠습니다.

728x90