CVAD
[ISBI 2012 segmentation] U-Net 모델 구현해보기-3 본문
이전 포스팅 내용이 궁금하시다면 아래의 링크를 참고하시면 감사하겠습니다.
- [ISBI 2012 segmentation] U-Net 모델 구현해보기-1
- [ISBI 2012 segmentation] U-Net 모델 구현해보기-2
오늘은 U-Net 모델 구현 포스팅의 마지막이다. 오늘은 WandB를 사용한 모델 학습 코드를 작성해보려고 한다.
시작하기 앞서서, WandB는 머신러닝 툴로서, 모델의 학습상태를 보여주기 위한 Dashboard를 제공한다.
그냥 loss, accuracy 등의 파라미터부터 이미지, 동영상 등의 시각화까지 도와주는 모듈이라고 생각하면 편하다.
WandB에 대해서는 차후 따로 포스팅하겠다.
그리고 지금까지 포스팅된 U-Net 모델 예제 코드는 아래의 깃허브 주소에서 받을 수 있다.
https://github.com/KDB0814/ISBI-semantic-segmentation
4. Model training
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import wandb
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from Dataset import ISBI
from utils import Random_processing
from Model.Vanila_UNet import VanilaUNet
from torchsummary import summary
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
ds_train = ISBI(mode='train', transform=Random_processing())
ds_valid = ISBI(mode='valid', transform=Random_processing())
dl_train = DataLoader(ds_train, batch_size=1, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=1 ,shuffle=False)
위 코드에서 seed 값을 설정해주는 이유는 Random processing에서 정의된 augmentation 들을 image와 mask에
동일하게 처리하기 위해서 설정한 것이다.
class config:
Device = 'cuda' if torch.cuda.is_available() else 'cpu'
Max_epoch = 20
train_batch = 1
valid_batch = 1
model_name = 'Vanila_UNet'
lr = 1e-3
학습에 사용할 파라미터들을 간단히 정리해준다. 나의 경우 이런 방식으로 train에 사용하는 모델과 학습 파라미터를 간편하게 확인할 수 있어서 항상 이런 식으로 정의한다.
wandb.login()
WANDB_CONFIG = {'_wandb_kernel' : 'neuracort'}
run = wandb.init(
project='ISBI 2012 Semantic segmentation',
config=WANDB_CONFIG
)
wandb.run.name = 'Vanila-UNet'
wandb에 연결해주고, project 및 모델 이름을 입력해준다.
model = VanilaUNet(in_channels=3, out_channels=2).to(config.Device)
# summary(model, input_size=(3, 440, 440))
이제 모델을 불러와준다. 우리는 컬러 이미지(3채널)을 통해 2가지 클래스를 예측하므로, in과 out channel 값을 위와 같이 설정해준다.
ce_loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr = config.lr, momentum=0.99)
def loss_function(preds, targets):
targets = targets.squeeze(0).long()
loss = ce_loss(preds, targets)
acc = (torch.max(preds, 1)[1] == targets).float().mean()
return loss, acc
class engine():
def model_train(model, data, optimizer, loss_func):
imgs, lbls = data['image'], data['label']
model.train()
imgs = imgs.to(config.Device)
lbls = lbls.to(config.Device)
preds = model(imgs)
loss, acc = loss_func(preds, lbls)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item(), acc.item()
@torch.no_grad()
def model_valid(model, data, loss_func):
imgs, lbls = data['image'], data['label']
model.eval()
imgs = imgs.to(config.Device)
lbls = lbls.to(config.Device)
preds = model(imgs)
loss, acc = loss_func(preds, lbls)
return loss.item(), acc.item()
loss는 Cross entropy를 사용해준다. 엄밀히 말하자면, 논문에서처럼 가중치 함수를 곱해줘야하지만 해당 부분은 구현에 어려움이 있어서 생략했다. 그리고 정확도를 수치화하기 위해 pixel-wise accuracy를 사용했다.
log_path = f'Model/{config.model_name}/log'
best_model_path = log_path + '/best_model'
# Check log & best model path
if os.path.exists(best_model_path):
pass
else:
os.makedirs(best_model_path)
# Set test image
test_sample = next(iter(dl_valid))
test_img, test_lbl = test_sample['image'], test_sample['label']
view_img = test_img.squeeze().permute(1,2,0).numpy()*255.0
test_img = test_img.to(config.Device)
view_lbl = test_lbl.squeeze(0)
view_lbl = view_lbl.permute(1,2,0).numpy()
매 Epoch마다 학습된 모델을 저장할 path를 만들어준다. 그리고, WandB에 시각화할 이미지와 ground truth 마스크를 지정해준다.
이때, WandB의 Image 데이터는 numpy에서 인식하는 방식과 동일하므로 Height, Width, Channel 순서의 dimension이
되도록 각 데이터를 변환해준다.
def run():
best_val_acc = 0.0
for epoch in range(config.Max_epoch):
print('-----------------------------')
print(f' Epoch : {epoch}')
print('-----------------------------')
for _, data in tqdm(enumerate(dl_train), total=len(dl_train)):
train_loss, train_acc = engine.model_train(model, data, optimizer, loss_function)
for _, data in tqdm(enumerate(dl_valid), total=len(dl_valid)):
valid_loss, valid_acc = engine.model_valid(model, data, loss_function)
torch.save(model.state_dict(), os.path.join(log_path, f'epoch_{epoch}.pth'))
if valid_acc >= best_val_acc:
torch.save(model.state_dict(), os.path.join(best_model_path, 'best_model.pth'))
with torch.no_grad():
model.eval()
view_preds = model(test_img)
view_preds = torch.argmax(view_preds, dim=1)
view_pred = view_preds.permute(1,2,0).detach().cpu().numpy()
wandb.log(
{
'epoch' : epoch,
'train_loss' : train_loss,
'train_acc' : train_acc,
'valid_loss' : valid_loss,
'valid_acc' : valid_acc,
'Color Image' : [wandb.Image(view_img, caption='Color image')],
'Ground Truth' : [wandb.Image(view_lbl, caption='Ground truth')],
'Model prediction' : [wandb.Image(view_pred, caption='Model Prediction')]
}
)
print()
run()
-----------------------------
Epoch : 0
-----------------------------
100%|██████████| 96/96 [00:08<00:00, 10.70it/s]
100%|██████████| 24/24 [00:00<00:00, 26.63it/s]
-----------------------------
Epoch : 1
-----------------------------
100%|██████████| 96/96 [00:07<00:00, 12.70it/s]
100%|██████████| 24/24 [00:01<00:00, 22.43it/s]
-----------------------------
Epoch : 2
-----------------------------
100%|██████████| 96/96 [00:09<00:00, 10.20it/s]
100%|██████████| 24/24 [00:01<00:00, 18.38it/s]
-----------------------------
Epoch : 3
-----------------------------
100%|██████████| 96/96 [00:08<00:00, 11.36it/s]
100%|██████████| 24/24 [00:00<00:00, 27.89it/s]
-----------------------------
Epoch : 4
-----------------------------
100%|██████████| 96/96 [00:07<00:00, 13.06it/s]
100%|██████████| 24/24 [00:00<00:00, 29.28it/s]
-----------------------------
Epoch : 5
-----------------------------
100%|██████████| 96/96 [00:07<00:00, 12.87it/s]
100%|██████████| 24/24 [00:00<00:00, 28.58it/s]
-----------------------------
위의 코드를 실행하면 아래와 같이 진행바와 함께 학습이 진행되는 것을 확인할 수 있다.
5. Results
학습 결과는 WandB에서 확인할 수 있다.

간단하게 20 epoch만 돌려서 성능을 확인해보았다. train과 valid loss가 하강하고 있고, accuracy 모양이 이상하지만 그래도 우상향이다. 결과적으로 학습 과정은 잘 진행된 것을 알 수 있다.
아래는 best model로 이미지를 예측한 결과다.

Input tile의 빨간색 부분이 실제 예측 범위이고, 나머지는 Overlap-tile을 통해 반사된 영역이다. 아쉽게도 Groudn truth처럼 완벽한 경계인식이 이루어지지는 않았다. 또한, 원본 이미지의 미토콘드리아(? 맞는지는 모르겠다.)까지 세포벽으로 인식하여 경계로서 예측하였다.
7. Conclusion
학습 결과에 대한 고찰을 해보자.
먼저, 생각보다 높지않은 정확도의 가장 큰 원인은 경계선에 대한 가중치 함수가 없기 때문일 것이다.
가중치 함수가 일종의 가이드 라인으로서 작용하여 경계선에 대한 인식이 잘 이루어질 수록 좋은 loss를 반환하도록 보정 했어야하지만, 이 부분이 없기 때문에 정확도가 높지 않았던 것으로 추측된다.
몇 가지 더 고려해보자면, 데이터셋의 부족과 짧은 epoch를 생각해볼 수 있다.
학습 데이터셋은 총 96장의 이미지, 마스크 쌍이다. 이는 상당히 작은 데이터셋으로 많은 augmentation이 적용되야할
필요가 있다. 하지만, 내가 작성한 코드의 경우 극적인 augmentation이 일어난 만큼의 변화를 주지 않았고, 이 상황에서
epoch 마저 크지 않으니, 매 epoch당 비슷한 이미지를 가지고 학습했을 확률이 높다. 아마 이 상태에서 무작정 epoch만
높이면 over fitting이 날 것 같기도 하다.
그러므로, 학습을 개선할 수 있는 방법을 정리하자면 아래와 같다.
- Weight function 구현
- Dataset 추가 혹은 증강 및 epoch 수 조절
만약 시간이 난다면, weight function도 구현해보겠다.
마지막으로, 내가 구현한 U-Net 모델을 보면, 일반적으로 다른 사람들이 구현한 모델과 다르다는 것을 알 수 있다.
결론만 말하자면, 다른 사람이 구현한 모델과 내가 구현한 모델을 다른 것이 맞다.
엄밀히 말하면, 대부분의 인터넷에 있는 모델들은 원본 논문에서 제안한 모델을 조금 더 사용하기 쉽게 변형한 모델을 사용하는 것이다.
내가 구현한 모델은 U-Net 논문에서 나와있듯, 저자들은 1) 패딩을 사용하지 않았고, 2) U-Net모델은 입력 이미지와 출력 결과 이미지의 크기가 서로 다르다. ( U-Net 모델 리뷰 참고 : https://cvad.tistory.com/10 )
하지만, 다른 사람들이 구현한 모델들은 대부분 1)패딩이 설정되어있고, 2)입·출력 이미의 크기가 동일할 것이다.
사실, 이러한 방식이 기존의 CV 모델들에서 사용하는 방법이고 사용자에게 있어서는 더 편한 방법이다.
일반적으로 우리는 입력된 이미지 전체가 segmentation 되길 바라지, 일부분만 적용하는 걸 바라진 않을 것이다.
그리고, 테두리에 대한 정보를 살리고싶다면 padding을 추가하는 편한 방법이 있다. (완벽하게 보존되진 않을테지만)
논문에서 제안한 방식도 결국, 대칭성을 갖는 dataset에서는 유의미하지만, CityScape 같은 dataset에는 적용하기 어려운 augmentation이다.
때문에, 이러한 사항들을 고려하여 padding을 추가하고, U-Net의 architecture를 사용하기 위해 수정한 것이다.
아마, augmentation 중 mirroing 빼고 patch에 수정된 U-Net 모델로 학습하면 더 잘 될수도 있을 것이다.
(옛날에 해봤을 때는, 더 잘 되었던 것으로 기억한다.)
정리하자면, 내가 구현한 모델과 다른 사람들이 구현한 모델 모두 틀리지 않았다.
이렇게 U-Net 모델을 통해 ISBI 2012 구현을 마무리했다.
궁금하거나 질문은 댓글로 달아주시면 감사하겠습니다.
'Segmentation' 카테고리의 다른 글
| [VOC PASCAL 2012] Semantic segmentation 하기 - 3 (0) | 2024.03.03 |
|---|---|
| [VOC PASCAL 2012] Semantic segmentation 하기 - 2 (0) | 2024.03.02 |
| [VOC PASCAL 2012] Semantic segmentation 하기 - 1 (0) | 2024.02.27 |
| [ISBI 2012 segmentation] U-Net 모델로 구현해보기-2 (0) | 2024.01.04 |
| [ISBI 2012 segmentation] U-Net 모델로 구현해보기-1 (0) | 2023.12.22 |