Train a 3D nodule detector on the LUNA16 dataset


In [1]:
# Directory holding the nodule HDF5 splits (nodules-train.h5, nodules-validate.h5, ...).
INPUT_DIR = '../../input/nodules/'
# Run output (out.log, TF logs, trained model); wiped and recreated by the cell below.
OUTPUT_DIR = '../../output/lung-cancer/03/'
# Per-sample input shape for the 3D CNN; matches X.shape[1:] logged during training.
# NOTE(review): presumably (depth, height, width, channels) - confirm axis order.
IMAGE_DIMS = (50, 50, 50, 1)

In [3]:
%matplotlib inline
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import sklearn
import os
import glob

from modules.logging import logger
import modules.utils as utils
from modules.utils import Timer
import modules.logging
import modules.cnn as cnn
import modules.ctscan as ctscan

Training

Prepare output directory


In [4]:
log_file = OUTPUT_DIR + 'out.log'
# recreate=True: each run starts from a fresh output directory.
utils.mkdirs(OUTPUT_DIR, recreate=True)
# Route logger output to OUTPUT_DIR/out.log as well (messages still show in the
# notebook, as the captured output below shows).
modules.logging.setup_file_logger(log_file)
logger.info('Dir {} created'.format(OUTPUT_DIR))


2017-03-25 14:13:46,454 INFO Dir ../../output/lung-cancer/03/ created

Prepare CNN model


In [ ]:
logger.info('Prepare CNN for training')
# Build the 3D network graph for samples shaped IMAGE_DIMS ...
network = cnn.net_nodule3d_swethasubramanian(IMAGE_DIMS)
# ... and wrap it in a trainable model rooted at OUTPUT_DIR.
# NOTE(review): model_file=None presumably means "start from fresh weights,
# do not restore a saved model" - confirm in modules.cnn.
model = cnn.prepare_cnn_model(
    network,
    OUTPUT_DIR,
    model_file=None,
)


2017-03-25 14:13:47,855 INFO Prepare CNN for training
2017-03-25 14:13:48,023 INFO Prepare CNN
2017-03-25 14:13:48,025 INFO Preparing output dir
2017-03-25 14:13:48,026 INFO Initializing network...

Train model


In [ ]:
train_path = INPUT_DIR + 'nodules-train.h5'
validate_path = INPUT_DIR + 'nodules-validate.h5'

# Hyper-parameters for model.fit, kept together so they are easy to tune.
fit_params = dict(
    shuffle=True,
    batch_size=96,
    n_epoch=100,
    show_metric=True,
    snapshot_epoch=True,
    run_id='nodule_classifier',
)

# X/Y are h5py dataset objects backed by the open files, so both files must
# stay open for the whole training run.
with h5py.File(train_path, 'r') as train_file:
    X = train_file['X']
    Y = train_file['Y']
    for label, data in (('X', X), ('Y', Y)):
        logger.info(label + ' shape ' + str(data.shape))

    with h5py.File(validate_path, 'r') as validate_file:
        X_validate = validate_file['X']
        Y_validate = validate_file['Y']
        for label, data in (('X_validate', X_validate), ('Y_validate', Y_validate)):
            logger.info(label + ' shape ' + str(data.shape))

        logger.info('Starting CNN training...')
        model.fit(X, Y, validation_set=(X_validate, Y_validate), **fit_params)

model_file = OUTPUT_DIR + "nodule-classifier.tfl"
model.save(model_file)
logger.info("Network trained and saved as nodule-classifier.tfl!")


2017-03-25 14:13:51,483 INFO X shape (6616, 50, 50, 50, 1)
2017-03-25 14:13:51,485 INFO Y shape (6616, 2)
2017-03-25 14:13:51,488 INFO X_validate shape (1248, 50, 50, 50, 1)
2017-03-25 14:13:51,490 INFO Y_validate shape (1248, 2)
2017-03-25 14:13:51,491 INFO Starting CNN training...
---------------------------------
Run id: nodule_classifier
Log directory: ../../output/lung-cancer/03/tf-logs/
INFO:tensorflow:Summary name Accuracy/ (raw) is illegal; using Accuracy/__raw_ instead.
2017-03-25 14:13:51,898 INFO Summary name Accuracy/ (raw) is illegal; using Accuracy/__raw_ instead.
---------------------------------
Training samples: 6616
Validation samples: 1248
--

Evaluate results on the test dataset


In [ ]:
# Evaluate the trained model on the held-out test split.
# The test file sits next to nodules-train.h5 / nodules-validate.h5 in INPUT_DIR.
# It was previously read from OUTPUT_DIR, which is recreated at the start of this
# notebook (utils.mkdirs(..., recreate=True)) and so cannot be the source of the
# test data.
# NOTE(review): `evaluate_dataset` is not imported or defined in any visible cell;
# presumably it comes from modules.cnn - confirm and import it explicitly, otherwise
# this cell raises NameError on a fresh kernel.
test_path = INPUT_DIR + 'nodules-test.h5'
logger.info('Evaluate dataset')
evaluate_dataset(test_path, model, batch_size=12, confusion_matrix=True)