CVAD
[ISBI 2012 segmentation] U-Net 모델로 구현해보기-2 본문
이전 포스팅 내용이 궁금하시다면 아래의 링크 참고하시면 감사하겠습니다!
- [ISBI 2012 segmentation] U-Net 모델 구현해보기-1
지난 번 U-Net모델 구현하기-1 포스팅에 이어 2번째 포스팅이다. 앞서 얘기한대로 오늘은 Dataset에 대한 정의와 Data augmentation, Model까지 작성하겠다.
2. Dataset & Data augmentation
아래는 데이터셋을 정의하는 코드다.
import re
import os
import cv2
import torch
from torch.utils.data import Dataset
def sorter(text):
num = re.findall(r'\d+', text)
return int(num[0]) if num else 0
class ISBI(Dataset):
def __init__(self, mode, transform=None):
'''
ISBI 2012 dataset
- mode : 'train' or 'valid'
- transform : Data augmentation (default : None)
'''
self.mode = mode
self.transform = transform
self.img_path = f'../../Data/ISBI/{self.mode}/image/aug'
self.lbl_path = f'../../Data/ISBI/{self.mode}/label/aug'
self.img_list = sorted(os.listdir(self.img_path), key=sorter)
self.lbl_list = sorted(os.listdir(self.lbl_path), key=sorter)
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
img = cv2.imread(os.path.join(self.img_path, self.img_list[idx]), cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
lbl = cv2.imread(os.path.join(self.lbl_path, self.lbl_list[idx]), cv2.IMREAD_GRAYSCALE)
# Make numpy to tensor
# H W C => C H W
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img)
lbl = torch.from_numpy(lbl)
lbl = lbl.unsqueeze(0)
sample = {'image': img, 'label': lbl}
if self.transform:
sample = self.transform(sample)
return sample
간단하게 살펴보면, opencv(이하 cv2)를 사용하여 이미지와 마스크 이미지를 불러오고 이를 pytorch를 사용하여 tensor형태로 변환해준다. 여기서 중요한건 cv2에서 다루는 이미지 데이터의 shape과 pytorch에서 다루는 데이터의 shape이 서로 다르다는 점인데 주석에 나와있듯이, cv2는 numpy기반으로 이미지를 읽을 때 Height, Width, Channel 순으로 인식하지만, pytorch는 Channel, Height, Width 순으로 인식한다. 항상 이 점을 염두해야한다.
sorter함수는 이미지와 마스크를 내가 정의한 이름대로 잘 불러오기 위해 key값에 입력할 함수다.
나의 경우 image_{index}.jpg or .png 형태로 이름을 저장하는 습관이 있는데, 이를 index값에 따라 잘 배열하려면 저 key 함수가 꼭 필요하다.
이제 dataset을 정의했으니, data augmentation을 적용해보자. 코드는 아래와 같다.
import torchvision.transforms as T
import random
import numpy as np
import torch
from torchvision.transforms import InterpolationMode
def apply_mirroring(image, mirroring_size=94):
img = img.permute(1, 2, 0)
mirrored_image = np.pad(image, pad_width=((mirroring_size, mirroring_size),
(mirroring_size, mirroring_size),
(0, 0)), mode='reflect')
mirrored_image = mirrored_image.transpose(2, 0, 1)
return mirrored_image
class Random_processing(object):
def __init__(self):
self.h_flip = T.RandomHorizontalFlip()
self.shift = T.RandomAffine(translate=(.001, .001), degrees=0)
self.gray_value = T.ColorJitter(brightness=0.2 + 0.1*random.random())
self.elastic = T.ElasticTransform(interpolation=InterpolationMode.NEAREST)
def __call__(self, samples):
img, lbl = samples['image'], samples['label']
if random.random() > 0.5:
img, lbl = self.shift(img), self.shift(lbl)
if random.random() > 0.5:
img, lbl = self.h_flip(img), self.shift(lbl)
if random.random() > 0.5:
img = self.gray_value(img)
if random.random() > 0.5:
img, lbl = self.elastic(img), self.elastic(lbl)
img = apply_mirroring(img)
img = torch.from_numpy(img)
img = img/255.0
lbl = lbl/255.0
sample = {'image': img, 'label': lbl}
return sample
코드를 살펴보면, 논문에서 언급한 augmentation 방법 4가지가 각각 h_flip, shift, gray_value, elasitc으로 구현된 것을 확인할 수 있다. mirroring_size는 이전 포스팅에서 분할한 patch_size 256과 U-Net의 최종 출력 segmentation map을 고려하여 결정하였다.
위 코드에서 gray_value는 마스크 이미지(lbl)에 적용하지 않는데, 이는 마스크에서 정의된 class 값을 유지하기 위해서다.
근데, 아마 코드 하단부를 보면 lbl를 255로 나누는 것에 대해 의문을 가질 수도 있을 것이다. 이는 모델의 출력을 2차원으로 만들어 줄 것이기 때문에, 마스크 class를 [0, 255]에서 [0, 1]로 mapping한 것이다. img를 255로 나눈것은 당연히 이미지의 데이터 값을 0~1사이로 sacling 하기 위함이다.
이제 Data에 augmentation을 하여 불러오는 전체 과정은 다 끝났다. 남은 것은 모델을 정의하고, 이를 학습시키는 과정만 남았다.
3. U-Net Model
U-Net 모델의 구조는 앞서 논문 리뷰에서 살펴보았으니, 설명은 생략하겠다. 코드는 아래와 같다.
import torch
import torch.nn as nn
class VanilaUNet(nn.Module):
def __init__(self, in_channels=3, out_channels=2):
super(VanilaUNet, self).__init__()
'''
Vanila U-Net model
- in_channels = Chanel of Image
- out_channels = Number of class
'''
self.in_channels = in_channels
self.out_channels = out_channels
self.enc1 = self.conv_block(self.in_channels, 64)
self.enc2 = self.conv_block(64, 128)
self.enc3 = self.conv_block(128, 256)
self.enc4 = self.conv_block(256, 512)
self.pool = nn.MaxPool2d(2, 2)
self.bottel_neck = self.conv_block(512, 1024)
self.upconv1 = self.upconv_block(128, 64)
self.upconv2 = self.upconv_block(256, 128)
self.upconv3 = self.upconv_block(512, 256)
self.upconv4 = self.upconv_block(1024, 512)
self.dec4 = self.conv_block(1024,512)
self.dec3 = self.conv_block(512,256)
self.dec2 = self.conv_block(256, 128)
self.dec1 = self.conv_block(128, 64)
self.head = nn.Conv2d(64, self.out_channels, kernel_size=1)
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
nn.ReLU()
)
def upconv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
nn.ReLU()
)
def copy_and_crop(self, input_feature, target_feature):
height, width = input_feature.shape[2:]
target_height, target_width = target_feature.shape[2:]
start_h = int((height - target_height) / 2)
start_w = int((width - target_width) / 2)
cropped_map = input_feature[:, :, start_h:start_h+target_height, start_w : start_w+target_width]
return cropped_map
def forward(self, x):
# 1st down-sampling
conv1 = self.enc1(x)
down1 = self.pool(conv1)
# 2nd down-sampling
conv2 = self.enc2(down1)
down2 = self.pool(conv2)
# 3rd down-sampling
conv3 = self.enc3(down2)
down3 = self.pool(conv3)
# 4th down-sampling
conv4 = self.enc4(down3)
down4 = self.pool(conv4)
# Bottle-neck
bottle = self.bottel_neck(down4)
# 4th up-sampling
up4 = self.upconv4(bottle)
concat4 = torch.concat([self.copy_and_crop(conv4, up4), up4], 1)
deconv4 = self.dec4(concat4)
# 3rd up-sampling
up3 = self.upconv3(deconv4)
concat3 = torch.concat([self.copy_and_crop(conv3, up3), up3], 1)
deconv3 = self.dec3(concat3)
# 2nd up-sampling
up2 = self.upconv2(deconv3)
concat2 = torch.concat([self.copy_and_crop(conv2, up2), up2], 1)
deconv2 = self.dec2(concat2)
# 1st up-sampling
up1 = self.upconv1(deconv2)
concat1 = torch.concat([self.copy_and_crop(conv1, up1), up1], 1)
deconv1 = self.dec1(concat1)
# out
seg_map = self.head(deconv1)
out = self.copy_and_crop(seg_map, torch.randn(1, 2, 256, 256))
return out
사실, 이 모델의 구조는 출력 부분에 억지스러운 부분이 있다. 바로 segmentation_map의 크기를 256X256으로 crop
하는 부분인데 segmentation_map의 크기가 260X260이고 input tile에 위치한 patch의 크기가 256X256이였기 때문에 이를 맞춰주기 위한 약간의 보정이었다.
원래 논문에서처럼 572X572 크기의 input tile을 388X388 image patch를 사용하여 생성한다. 즉, 내 코드에서 patch의 크기를 388X388로 자르도록 수정하고, mirroring_size의 크기를 92로 수정하여 데이터를 사용한다면 seg_map을 바로 출력으로 사용할 수 있다.
논문 리뷰에서도 말했지만, 이런 경우 1 이미지에 대해 1개의 patch만을 사용하는 것인데 이러한 방법은 비효율적인 것
같았다. 따라서, 한 이미지에서 여러 patch를 잘라내다보니 크기에서 발생하는 문제였고, 이를 보정하기 위해 위와 같은 방식을 사용하였다.
이제 모델의 정의까지 끝냈으니, 다음 포스팅부터는 학습 방법에 대해 알아보고, 결과를 확인해보겠다.
'Segmentation' 카테고리의 다른 글
[VOC PASCAL 2012] Semantic segmentation 하기 - 3 (0) | 2024.03.03 |
---|---|
[VOC PASCAL 2012] Semantic segmentation 하기 - 2 (0) | 2024.03.02 |
[VOC PASCAL 2012] Semantic segmentation 하기 - 1 (0) | 2024.02.27 |
[ISBI 2012 segmentation] U-Net 모델 구현해보기-3 (0) | 2024.01.07 |
[ISBI 2012 segmentation] U-Net 모델로 구현해보기-1 (0) | 2023.12.22 |