In [1]:
import logging
import os

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import KFold
from tensorboardX import SummaryWriter
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, RandomResizedCrop, RandomRotation, RandomHorizontalFlip, ToTensor, \
    Resize, RandomAffine, ColorJitter

from dataset import MaskDataset, get_img_files
from loss import dice_loss
from nets.MobileNetV2_unet import MobileNetV2_unet
from trainer import Trainer

np.random.seed(1)
torch.backends.cudnn.deterministic = True
torch.manual_seed(1)


Out[1]:
<torch._C.Generator at 0x7f997f613750>

In [2]:
N_CV = 5
BATCH_SIZE = 8
LR = 1e-4

N_EPOCHS = 100
IMG_SIZE = 224
RANDOM_STATE = 1

EXPERIMENT = 'train_unet'
OUT_DIR = 'outputs/{}'.format(EXPERIMENT)

In [3]:
def get_data_loaders(train_files, val_files):
    train_transform = Compose([
        ColorJitter(0.3, 0.3, 0.3, 0.3),
        RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.2)),
        RandomAffine(10.),
        RandomRotation(13.),
        RandomHorizontalFlip(),
        ToTensor(),
    ])
    val_transform = Compose([
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
    ])

    train_loader = DataLoader(MaskDataset(train_files, train_transform),
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=4)
    val_loader = DataLoader(MaskDataset(val_files, val_transform),
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=4)

    return train_loader, val_loader


def save_best_model(cv, model, df_hist):
    if df_hist['val_loss'].tail(1).iloc[0] <= df_hist['val_loss'].min():
        torch.save(model.state_dict(), '{}/{}-best.pth'.format(OUT_DIR, cv))


def write_on_board(writer, df_hist):
    row = df_hist.tail(1).iloc[0]

    writer.add_scalars('{}/loss'.format(EXPERIMENT), {
        'train': row.train_loss,
        'val': row.val_loss,
    }, row.epoch)


def log_hist(df_hist):
    last = df_hist.tail(1)
    best = df_hist.sort_values('val_loss').head(1)
    summary = pd.concat((last, best)).reset_index(drop=True)
    summary['name'] = ['Last', 'Best']
    logger.debug(summary[['name', 'epoch', 'train_loss', 'val_loss']])
    logger.debug('')


def run_cv():
    image_files = get_img_files()
    kf = KFold(n_splits=N_CV, random_state=RANDOM_STATE, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for n, (train_idx, val_idx) in enumerate(kf.split(image_files)):
        train_files = image_files[train_idx]
        val_files = image_files[val_idx]

        writer = SummaryWriter()

        def on_after_epoch(m, df_hist):
            save_best_model(n, m, df_hist)
            write_on_board(writer, df_hist)
            log_hist(df_hist)

        criterion = dice_loss(scale=2)
        data_loaders = get_data_loaders(train_files, val_files)
        trainer = Trainer(data_loaders, criterion, device, on_after_epoch)

        model = MobileNetV2_unet()
        model.to(device)
        optimizer = Adam(model.parameters(), lr=LR)

        hist = trainer.train(model, optimizer, num_epochs=N_EPOCHS)
        hist.to_csv('{}/{}-hist.csv'.format(OUT_DIR, n), index=False)

        writer.close()

        break

In [4]:
if not os.path.exists(OUT_DIR):
    os.makedirs(OUT_DIR)

logger = logging.getLogger("logger")
logger.setLevel(logging.DEBUG)
if not logger.hasHandlers():
    logger.addHandler(logging.FileHandler(filename="outputs/{}.log".format(EXPERIMENT)))

run_cv()

In [ ]: