[VOC PASCAL 2012] Semantic segmentation 하기 - 3
이전 포스팅 내용이 궁금하시다면 아래의 링크를 참고하시면 감사하겠습니다!
- [VOC PASCAL 2012] Semantic segmentation 하기 - 1
- [VOC PASCAL 2012] Semantic segmentation 하기 - 2
이번에는 학습 코드를 작성해보자. 모델은 지난 포스팅에 구현한 FCN과 U-Net을 사용할 것이다.
1. 라이브러리 불러오기 및 학습 configuration 작성
학습에 필요한 라이브러리 및 학습 configuration 값을 설정해보자.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import wandb
import torch.optim as opt
import torch.nn.functional as F
import random
import os
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from Dataset import _VOCdataset
from utils import Random_processing
from torchinfo import summary
from Model.FCN import O_FCN8s
from Model.UNet import UNet
random.seed(123)
torch.manual_seed(123)
class config:
# model_name = 'O_FCN8s'
model_name = 'UNet_vgg16_bn'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_epoch = 25
batch_size = 16
코드에서 보이는 것처럼, 25 epoch 동안 학습을 진행할 것이고 학습과 검증 모두 16개의 mini batch를 사용해서 진행할 것이다.
그 후, 아래의 mean값과 std값을 사용하여 정규화된 데이터셋을 사용하여 학습 및 검증 data loader를 만들어주자.
mean = [0.4567, 0.4431, 0.4086]
std = [0.2680, 0.2649, 0.2806]
ds_train = _VOCdataset(mode='train', transform=Random_processing(mean, std))
ds_valid = _VOCdataset(mode='valid', transform=Random_processing(mean, std))
dl_train = DataLoader(ds_train, batch_size=config.batch_size, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=config.batch_size, shuffle=False)
다음으로, wandb에 Jupyter notebook을 연결해주자.
wandb.login()
WANDB_CONFIG = {'_wandb_kernel' : 'neuracort'}
run = wandb.init(
project='PASCAL VOC 2012 Semantic Segmentation - Model performance',
config=WANDB_CONFIG
)
wandb.run.name = 'UNet with Vgg16 backbone'
마지막으로 학습에 사용할 모델을 지정해준다.
# model = O_FCN8s(n_class=22).to(config.device)
model = UNet().to(config.device)
2. Loss 함수 정의하기
학습에 사용할 loss 함수를 정의해준다. 이번 학습에는 cross entropy와 dice loss를 결합하여 사용할 것이다.
두 loss 함수에 대한 설명은 아래의 링크를 참고하면 좋을 것 같다.
오차 함수의 종류와 특징
Task를 잘 수행할 있도록 모델을 학습시킬 때 고려해야할 점에는 무엇이 있을까? 모델의 구조나 데이터셋 등 다양한 요소가 있겠지만, 오차 함수도 모델 학습 결과에 큰 영향을 미친다. 오늘은 이
cvad.tistory.com
내가 구현하고자 하는 loss 함수를 수식으로 나타내면 아래와 같다.
$$ L = \lambda_{ce}\times L_{ce}+(1 - \lambda_{ce})\times L_{dice} $$
여기서, $lambda_{ce}$는 0~1 사이 값을 갖는 가중치로 저 값에 따라 두 loss 함수의 비중이 달라지는 형태다.
이번 학습에서는 두 loss 함수의 비율에 따른 학습 성능도 비교해 볼 예정이기 때문에 위와 같이 함수를 정의했다.
구현된 코드는 아래와 같다.
class CombinedLoss(nn.Module):
def __init__(self, weight=0.5, smooth=1e-6):
super(CombinedLoss, self).__init__()
self.weight = weight
self.smooth = smooth
self.cross_entropy_loss = nn.CrossEntropyLoss()
def dice_loss(self, inputs, targets):
inputs = torch.softmax(inputs, dim=1)
targets = F.one_hot(targets.long(), num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()
intersection = (inputs * targets).sum((2, 3))
dice = (2. * intersection + self.smooth) / (inputs.sum((2, 3)) + targets.sum((2, 3)) + self.smooth)
return 1 - dice.mean()
def forward(self, inputs, targets):
ce_loss = self.cross_entropy_loss(inputs, targets.squeeze().long())
dl_loss = self.dice_loss(inputs, targets.squeeze().long())
return self.weight * ce_loss + (1 - self.weight) * dl_loss
loss_fn = CombinedLoss(weight=0.5)
optimizer = opt.Adam(model.parameters(), lr=1e-3)
optimizer는 무난하게 adam을 사용하였다.
3. 학습 코드 작성 및 진행
이제 학습을 진행하는 코드를 작성해보자.
class engine():
def model_train(model, data, optimizer, loss_func):
imgs, lbls = data['image'], data['label']
model.train()
imgs = imgs.to(config.device)
lbls = lbls.to(config.device)
preds = model(imgs)
loss = loss_func(preds, lbls)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def model_valid(model, data, loss_func):
imgs, lbls = data['image'], data['label']
model.eval()
imgs = imgs.to(config.device)
lbls = lbls.to(config.device)
preds = model(imgs)
loss= loss_func(preds, lbls)
return loss.item()
이전에 U-Net 구현 모델과 동일한 구조의 코드다.
log_path = f'Model/{config.model_name}/log'
best_model_path = log_path + '/best_model'
# Check log & best model path
if os.path.exists(best_model_path):
pass
else:
os.makedirs(best_model_path)
# Set test image
test_sample = next(iter(dl_valid))
test_img, test_lbl = test_sample['image'], test_sample['label']
# Save the view image and mask image
view_img = test_img[0].squeeze().permute(1,2,0).detach().cpu().numpy()
view_lbl = test_lbl[0].squeeze().detach().cpu().numpy()
# Reverse the normalization
view_img = (view_img*np.array(std) + np.array(mean))*255.0
view_img = view_img.astype(np.int16)
test_img = test_img.to(config.device)
마찬가지로, 매 epoch마다 학습 결과를 저장할 path를 만들어주고, wandb에 시각화할 이미지와 마스크 이미지를 정의해주었다.
def run():
best_val_loss = 1e10
for epoch in range(config.max_epoch):
print('-----------------------------')
print(f' Epoch : {epoch}')
print('-----------------------------')
for _, data in tqdm(enumerate(dl_train), total=len(dl_train)):
train_loss= engine.model_train(model, data, optimizer, loss_fn)
for _, data in tqdm(enumerate(dl_valid), total=len(dl_valid)):
valid_loss= engine.model_valid(model, data, loss_fn)
torch.save(model.state_dict(), os.path.join(log_path, f'epoch_{epoch}.pth'))
if valid_loss < best_val_loss:
torch.save(model.state_dict(), os.path.join(best_model_path, 'best_model.pth'))
best_val_loss = valid_loss
with torch.no_grad():
model.eval()
view_preds = model(test_img)
view_preds = torch.argmax(view_preds, dim=1)
view_pred = view_preds[0].detach().cpu().numpy()
view_pred = view_pred.astype(np.int16)
wandb.log(
{
'epoch' : epoch,
'train_loss' : train_loss,
'valid_loss' : valid_loss,
'Color Image' : [wandb.Image(view_img, caption='Color image', mode="RGB")],
'Ground Truth' : [wandb.Image(view_lbl, caption='Ground truth', mode="L")],
'Model prediction' : [wandb.Image(view_pred, caption='Model Prediction', mode="L")]
}
)
print()
run()
이제 이 코드를 실행해서 모델을 학습시켜주면 된다.
아쉽게도 wandb의 문제가 있는 것인지 wandb에 이미지가 업로드되지 않아서 epoch마다 모델의 성능 변화를 관찰할 순 없었다...
4. Metric
학습이 완료되면, best model을 불러와서 mIoU를 분석해보자. 코드는 다음과 같다.
import torch
log_path = f'Model/{config.model_name}/log'
best_model_path = log_path + '/best_model'
# model = O_FCN8s(n_class=22).to(config.device)
model = UNet().to(config.device)
best_model_weight = torch.load(os.path.join(best_model_path, 'best_model.pth'))
model.load_state_dict(best_model_weight)
def fast_hist(pred, label, n_class):
mask = (label >= 0) & (label < n_class)
hist = torch.bincount(
n_class * label[mask].view(-1).long() + pred[mask].view(-1),
minlength=n_class ** 2).reshape(n_class, n_class)
return hist
def compute_IoU_per_class(hist):
intersection = torch.diag(hist)
union = hist.sum(1) + hist.sum(0) - intersection
IoU = intersection / union
return IoU
model.eval()
n_class = 22
total_hist = torch.zeros((n_class, n_class), device=config.device)
with torch.no_grad():
for data in dl_valid:
img, lbl = data['image'], data['label']
img = img.to(config.device)
lbl = lbl.to(config.device).squeeze(1)
preds = model(img)
preds = torch.argmax(preds, dim=1)
for pred, label in zip(preds, lbl):
total_hist += fast_hist(pred, label, n_class).to(config.device)
IoUs = compute_IoU_per_class(total_hist)
for i, IoU in enumerate(IoUs):
print(f"Class {i} IoU: {IoU.item(): .4f}")
이렇게 학습 코드 구현에 대한 포스팅을 마무리하겠다. 다음 포스팅은 학습 결과를 분석해보겠다.
포스팅에 대한 질문이나 잘못된 사항있다면 댓글 남겨주시면 감사하겠습니다.