In [1]:
import os
import subprocess
# NDSB_DIR = '/afs/ee.cooper.edu/user/t/a/tam8/documents/ndsb2015'
NDSB_DIR = '/media/raid_arr/data/ndsb/config'
# TRAIN_SCRIPT = os.path.join(NDSB_DIR, 'train_pl.sh')
# RESUME_SCRIPT = os.path.join(NDSB_DIR, 'resume_training_pl.sh')
SOLVER = os.path.join(NDSB_DIR, 'solver_pl.prototxt')
NET = os.path.join(NDSB_DIR, 'train_val_pl.prototxt')
CAFFE = '/afs/ee.cooper.edu/user/t/a/tam8/documents/caffe/build/tools/caffe'
MODELS_DIR = '/media/raid_arr/data/ndsb/models'
snapshot_prefix = 'pl_iter_'
MAX_ITER = 100000    # global max (not per step)
STEP = 1000

In [2]:
def iter_to_epoch(ii):
    """
    Convert the iteration # to an approximate epoch
    ii: iteration
    """
    k = 30336.0/384    # Iterations per epoch (# training / batch size)
    return ii/k

def w_schedule(t):
    """
    Weight of the PL as a fn of iter
    t: epoch
    """
    wf = 3.0
    t1 = 100
    t2 = 600
    if t < t1:
        w = 0.0
    elif t > t2:
        w = wf
    else:
        w = wf*(float(t)-t1)/(t2-t1)
    return w

def write_weight_to_net(w, f_path=NET):
    """
    Could regex or import net,
    brute force much easier
    """
    with open(f_path, 'r') as f:
    # read a list of lines into data
        f_data = f.readlines()

    # Change the line with the PL loss weight
    line_n = 549    # The line number we're going to replace
    new_line = '  loss_weight: ' + str(w) + '\n'
    f_data[line_n] = new_line

    # and write everything back
    with open(f_path, 'w') as f:
        f.writelines(f_data)
    return f_path

def write_max_iter_to_solver(max_iter, f_path=SOLVER):
    with open(f_path, 'r') as f:
    # read a list of lines into data
        f_data = f.readlines()

    # Change the line with the PL loss weight
    line_n = 8    # The line number we're going to replace
    new_line = 'max_iter: ' + str(max_iter) + '\n'
    f_data[line_n] = new_line

    # and write everything back
    with open(f_path, 'w') as f:
        f.writelines(f_data)
    return f_path
    
    
    

snap_name = lambda n_iter: os.path.join(MODELS_DIR,
            snapshot_prefix + str(n_iter) + '.solverstate')

call_start = lambda sol=SOLVER: subprocess.call(
            [CAFFE, 'train', '--solver=' + sol])
call_resume = lambda snap, sol=SOLVER: subprocess.call(
            [CAFFE, 'train', '--solver=' + sol, '--snapshot=' + snap])

In [4]:
last_saved_iter = max({int(os.path.splitext(f)[0].rsplit('_', 1)[1]) for f in next(os.walk(MODELS_DIR))[2]}.union{0})

for ii in range(last_saved_iter, MAX_ITER+1, STEP):
    e = iter_to_epoch(ii)
    w = w_schedule(e)
#     write_weight_to_net(w)
#     write_max_iter_to_solver(ii + STEP)
    print 'ITER:\t', ii, '\nWEIGHT:\t', w
#     if ii == 0:
#         call_start()
#     else:
#         call_resume(snap_name(ii))


ITER:	0 
WEIGHT:	0.0
ITER:	1000 
WEIGHT:	0.0
ITER:	2000 
WEIGHT:	0.0
ITER:	3000 
WEIGHT:	0.0
ITER:	4000 
WEIGHT:	0.0
ITER:	5000 
WEIGHT:	0.0
ITER:	6000 
WEIGHT:	0.0
ITER:	7000 
WEIGHT:	0.0
ITER:	8000 
WEIGHT:	0.00759493670886
ITER:	9000 
WEIGHT:	0.0835443037975
ITER:	10000 
WEIGHT:	0.159493670886
ITER:	11000 
WEIGHT:	0.235443037975
ITER:	12000 
WEIGHT:	0.311392405063
ITER:	13000 
WEIGHT:	0.387341772152
ITER:	14000 
WEIGHT:	0.463291139241
ITER:	15000 
WEIGHT:	0.539240506329
ITER:	16000 
WEIGHT:	0.615189873418
ITER:	17000 
WEIGHT:	0.691139240506
ITER:	18000 
WEIGHT:	0.767088607595
ITER:	19000 
WEIGHT:	0.843037974684
ITER:	20000 
WEIGHT:	0.918987341772
ITER:	21000 
WEIGHT:	0.994936708861
ITER:	22000 
WEIGHT:	1.07088607595
ITER:	23000 
WEIGHT:	1.14683544304
ITER:	24000 
WEIGHT:	1.22278481013
ITER:	25000 
WEIGHT:	1.29873417722
ITER:	26000 
WEIGHT:	1.3746835443
ITER:	27000 
WEIGHT:	1.45063291139
ITER:	28000 
WEIGHT:	1.52658227848
ITER:	29000 
WEIGHT:	1.60253164557
ITER:	30000 
WEIGHT:	1.67848101266
ITER:	31000 
WEIGHT:	1.75443037975
ITER:	32000 
WEIGHT:	1.83037974684
ITER:	33000 
WEIGHT:	1.90632911392
ITER:	34000 
WEIGHT:	1.98227848101
ITER:	35000 
WEIGHT:	2.0582278481
ITER:	36000 
WEIGHT:	2.13417721519
ITER:	37000 
WEIGHT:	2.21012658228
ITER:	38000 
WEIGHT:	2.28607594937
ITER:	39000 
WEIGHT:	2.36202531646
ITER:	40000 
WEIGHT:	2.43797468354
ITER:	41000 
WEIGHT:	2.51392405063
ITER:	42000 
WEIGHT:	2.58987341772
ITER:	43000 
WEIGHT:	2.66582278481
ITER:	44000 
WEIGHT:	2.7417721519
ITER:	45000 
WEIGHT:	2.81772151899
ITER:	46000 
WEIGHT:	2.89367088608
ITER:	47000 
WEIGHT:	2.96962025316
ITER:	48000 
WEIGHT:	3.0
ITER:	49000 
WEIGHT:	3.0
ITER:	50000 
WEIGHT:	3.0
ITER:	51000 
WEIGHT:	3.0
ITER:	52000 
WEIGHT:	3.0
ITER:	53000 
WEIGHT:	3.0
ITER:	54000 
WEIGHT:	3.0
ITER:	55000 
WEIGHT:	3.0
ITER:	56000 
WEIGHT:	3.0
ITER:	57000 
WEIGHT:	3.0
ITER:	58000 
WEIGHT:	3.0
ITER:	59000 
WEIGHT:	3.0
ITER:	60000 
WEIGHT:	3.0
ITER:	61000 
WEIGHT:	3.0
ITER:	62000 
WEIGHT:	3.0
ITER:	63000 
WEIGHT:	3.0
ITER:	64000 
WEIGHT:	3.0
ITER:	65000 
WEIGHT:	3.0
ITER:	66000 
WEIGHT:	3.0
ITER:	67000 
WEIGHT:	3.0
ITER:	68000 
WEIGHT:	3.0
ITER:	69000 
WEIGHT:	3.0
ITER:	70000 
WEIGHT:	3.0
ITER:	71000 
WEIGHT:	3.0
ITER:	72000 
WEIGHT:	3.0
ITER:	73000 
WEIGHT:	3.0
ITER:	74000 
WEIGHT:	3.0
ITER:	75000 
WEIGHT:	3.0
ITER:	76000 
WEIGHT:	3.0
ITER:	77000 
WEIGHT:	3.0
ITER:	78000 
WEIGHT:	3.0
ITER:	79000 
WEIGHT:	3.0
ITER:	80000 
WEIGHT:	3.0
ITER:	81000 
WEIGHT:	3.0
ITER:	82000 
WEIGHT:	3.0
ITER:	83000 
WEIGHT:	3.0
ITER:	84000 
WEIGHT:	3.0
ITER:	85000 
WEIGHT:	3.0
ITER:	86000 
WEIGHT:	3.0
ITER:	87000 
WEIGHT:	3.0
ITER:	88000 
WEIGHT:	3.0
ITER:	89000 
WEIGHT:	3.0
ITER:	90000 
WEIGHT:	3.0
ITER:	91000 
WEIGHT:	3.0
ITER:	92000 
WEIGHT:	3.0
ITER:	93000 
WEIGHT:	3.0
ITER:	94000 
WEIGHT:	3.0
ITER:	95000 
WEIGHT:	3.0
ITER:	96000 
WEIGHT:	3.0
ITER:	97000 
WEIGHT:	3.0
ITER:	98000 
WEIGHT:	3.0
ITER:	99000 
WEIGHT:	3.0
ITER:	100000 
WEIGHT:	3.0

In [19]:
qq = max({int(os.path.splitext(f)[0].rsplit('_', 1)[1]) for f in next(os.walk(MODELS_DIR))[2]})
range(qq, MAX_ITER+1, STEP)


Out[19]:
[1000,
 2000,
 3000,
 4000,
 5000,
 6000,
 7000,
 8000,
 9000,
 10000,
 11000,
 12000,
 13000,
 14000,
 15000,
 16000,
 17000,
 18000,
 19000,
 20000,
 21000,
 22000,
 23000,
 24000,
 25000,
 26000,
 27000,
 28000,
 29000,
 30000,
 31000,
 32000,
 33000,
 34000,
 35000,
 36000,
 37000,
 38000,
 39000,
 40000,
 41000,
 42000,
 43000,
 44000,
 45000,
 46000,
 47000,
 48000,
 49000,
 50000,
 51000,
 52000,
 53000,
 54000,
 55000,
 56000,
 57000,
 58000,
 59000,
 60000,
 61000,
 62000,
 63000,
 64000,
 65000,
 66000,
 67000,
 68000,
 69000,
 70000,
 71000,
 72000,
 73000,
 74000,
 75000,
 76000,
 77000,
 78000,
 79000,
 80000,
 81000,
 82000,
 83000,
 84000,
 85000,
 86000,
 87000,
 88000,
 89000,
 90000,
 91000,
 92000,
 93000,
 94000,
 95000,
 96000,
 97000,
 98000,
 99000,
 100000]