In [1]:
#This script takes raw DICOM patient images, applies a set of transformations to each 3D scan and saves the results to one HDF5 matrix per shard. Step 4 will merge all shards.
#Saved values are in HU (Hounsfield Unit) scale
IMAGE_DIMS = (312, 212, 312, 1)
NR_SHARDS = 700
RANDOM_SEED = 0.1
SAVE_IMAGES = True
#Patient DICOM images folder
INPUT_FOLDER = '../../../input/stage1_images/'
LABELS_FILE = '../../../input/stage1_labels.csv'
BASE_OUTPUT_FOLDER = '../../../output/kaggle-bowl/step3/'
In [2]:
import csv
import os
import sys
import h5py
import pandas as pd
import numpy as np # linear algebra
from numpy import ndarray
from random import shuffle
from multiprocessing import Pool
import multiprocessing as mp
import modules.logging
from modules.logging import logger
import modules.lungprepare as lungprepare
import modules.utils as utils
from modules.utils import Timer
In [3]:
def get_patient_ids(shard_id, input_dir):
    shard_patients = []
    with open(LABELS_FILE) as f:
        for row in csv.DictReader(f):
            p = row['id']
            #each patient is assigned to exactly one shard by hashing its hex id
            if(int(p, 16) % NR_SHARDS == (shard_id - 1)):
                shard_patients.append(p)
    logger.info('found {} patients for shard {}'.format(len(shard_patients), shard_id))
    #deterministic shuffle: the legacy 'random' argument is a zero-arg function returning the constant RANDOM_SEED
    shuffle(shard_patients, lambda: RANDOM_SEED)
    return shard_patients
    # force ids
    # return ['0c37613214faddf8701ca41e6d43f56e', '0a0c32c9e08cc2ea76a71649de56be6d', '0a38e7597ca26f9374f8ea2770ba870d']
    # return ['16377fe7caf072d882f234dbbff9ef6c']
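# Illustrative sanity check of the sharding rule above (a sketch, not part of the
# pipeline): every hex patient id maps to exactly one shard via int(id, 16) % NR_SHARDS.
# The id below is just an example taken from the commented-out forced ids above.
_example_id = '0a0c32c9e08cc2ea76a71649de56be6d'
print('example: patient {} belongs to shard {}'.format(_example_id, int(_example_id, 16) % NR_SHARDS + 1))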
In [4]:
def patient_label(input_dir, patient_id):
    labels = pd.read_csv(LABELS_FILE, index_col=0)
    #'at' accessor replaces the deprecated DataFrame.get_value
    value = labels.at[patient_id, 'cancer']
    #one-hot encoding: [1,0] = no cancer, [0,1] = cancer
    label = np.array([0, 1])
    if(value == 0): label = np.array([1, 0])
    return label
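# Illustrative usage (assumes the example patient id exists in stage1_labels.csv):
# a 'cancer' value of 1 yields array([0, 1]), a value of 0 yields array([1, 0]).
# print(patient_label(INPUT_FOLDER, '0a0c32c9e08cc2ea76a71649de56be6d'))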
In [5]:
def start_processing(input_dir, shard_id, max_patients, image_dims, base_output_dir):
    t = Timer('shard', debug=False)
    output_dir = base_output_dir + str(shard_id) + '/'
    logger.info('Processing patients. shard_id=' + str(shard_id) + ' max_patients=' + str(max_patients) + ' input_dir=' + input_dir + ' output_dir=' + output_dir)

    #check if this shard was already processed
    file_done = output_dir + 'done'
    if(os.path.isfile(file_done)):
        logger.warning('Shard ' + str(shard_id) + ' already processed. Skipping it.')
        return 'shard ' + str(shard_id) + ': SKIPPED'

    logger.info('Gathering patient ids for this shard')
    patient_ids = get_patient_ids(shard_id, input_dir)
    total_patients = len(patient_ids)
    dataset_name = 'data-centered-rotated'

    logger.info('Preparing output dir')
    utils.mkdirs(output_dir, dirs=['images'], recreate=True)
    modules.logging.setup_file_logger(base_output_dir + 'out.log')

    logger.info('Creating datasets')
    dataset_file = utils.dataset_path(output_dir, dataset_name, image_dims)
    with h5py.File(dataset_file, 'w') as h5f:
        x_ds = h5f.create_dataset('X', (total_patients, image_dims[0], image_dims[1], image_dims[2], image_dims[3]), chunks=(1, image_dims[0], image_dims[1], image_dims[2], image_dims[3]), dtype='f')
        y_ds = h5f.create_dataset('Y', (total_patients, 2), dtype='f')

        logger.info('Starting to process each patient (count={})'.format(len(patient_ids)))
        count = 0
        record_row = 0
        for patient_id in patient_ids:
            if(count > (max_patients - 1)):
                break
            #separate per-patient timer so the shard timer above keeps measuring the whole shard
            pt = Timer('>>> PATIENT PROCESSING ' + patient_id + ' (count=' + str(count) + '; output_dir=' + output_dir + ')')
            patient_pixels = lungprepare.process_patient_images(input_dir + patient_id, image_dims, output_dir, patient_id)
            #'is not None' avoids the ambiguous element-wise comparison of a numpy array to None
            if(patient_pixels is not None):
                if(not np.any(patient_pixels)):
                    logger.error('Patient pixels returned with zero values patient_id=' + patient_id)
                logger.info('Recording patient pixels to output dataset count=' + str(count))
                x_ds[record_row] = patient_pixels
                label = patient_label(input_dir, patient_id)
                y_ds[record_row] = label
                record_row = record_row + 1
            else:
                logger.warning('Patient lung not found. Skipping.')
            pt.stop()
            count = count + 1

    if(not utils.validate_dataset(output_dir, dataset_name, image_dims, save_dir=output_dir + 'images/')):
        logger.error('Validation ERROR on shard ' + str(shard_id))
        return 'shard ' + str(shard_id) + ': ERROR ' + str(t.elapsed()*1000) + 'ms'

    logger.info('Marking shard as processed')
    with open(file_done, 'w') as f:
        f.write('OK')
    return 'shard ' + str(shard_id) + ': OK ' + str(t.elapsed()*1000) + 'ms'
In [6]:
logger.info('==== PROCESSING SHARDS IN PARALLEL ====')
from random import randint
from time import sleep

def process_shard(shard_id):
    try:
        #stagger worker start-up so the shards don't all hit the disk at the same instant
        sleep(randint(0, 20))
        return start_processing(INPUT_FOLDER, shard_id, 999, IMAGE_DIMS, BASE_OUTPUT_FOLDER)
    except BaseException as e:
        logger.warning('Exception while processing shard ' + str(shard_id) + ': ' + str(e))
        return 'shard ' + str(shard_id) + ' exception: ' + str(e)

#mp.set_start_method('spawn')
n_processes = mp.cpu_count()
#n_processes = 1
logger.info('Using ' + str(n_processes) + ' parallel tasks')

with Pool(n_processes) as p:
    shards = list(range(1, NR_SHARDS + 1))
    shuffle(shards)
    # shards = [23]
    #imap_unordered yields each shard's result as soon as it finishes
    #http://stackoverflow.com/questions/26520781/multiprocessing-pool-whats-the-difference-between-map-async-and-imap
    for i in p.imap_unordered(process_shard, shards):
        print(i)

logger.info('==== ALL DONE ====')
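# Illustrative read-back of one finished shard (a sketch, assuming shard 1 was
# processed above and that utils.dataset_path() resolves the same HDF5 file name
# used inside start_processing):
_shard_dir = BASE_OUTPUT_FOLDER + '1/'
_ds_file = utils.dataset_path(_shard_dir, 'data-centered-rotated', IMAGE_DIMS)
with h5py.File(_ds_file, 'r') as _h5f:
    print('X shape:', _h5f['X'].shape)  # (n_patients, 312, 212, 312, 1), HU values
    print('Y shape:', _h5f['Y'].shape)  # (n_patients, 2), one-hot cancer labels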