In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import GradientNet
from myargs import Args

Configurations


In [2]:
# Hyper-parameters / experiment configuration.
args = Args()
args.test_scene = 'alley_1'   # Sintel scene held out from training for testing
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 0
# growth_rate = (4*(2**(args.gpu_num)))
transition_scale=(2*(2**(args.gpu_num+1)))
growth_rate = 32
args.display_interval = 50
args.display_curindex = 0

# Pick the dataset root depending on which machine we are running on.
# NOTE(review): platform.dist() is deprecated (removed in Python 3.8);
# consider platform.system() plus the `distro` package going forward.
# Call system()/dist() once and reuse the results instead of re-querying.
system_ = platform.system()
dist_tuple = platform.dist()
system_dist, system_version, _ = dist_tuple
if system_ == "Darwin": 
    # Local macOS debugging box: external drive, no pretrained weights.
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif dist_tuple ==  ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif dist_tuple == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

# GPUs are only available on the Linux machines.
use_gpu = (system_ == 'Linux')
if use_gpu:
    torch.cuda.set_device(args.gpu_num)
    

print(dist_tuple)


('debian', 'jessie/sid', '')

My DataLoader


In [3]:
# Datasets: training uses random crops, testing uses a deterministic
# center crop; both convert PIL images to CHW float tensors.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([
    transforms.CenterCrop((args.image_h, args.image_w)),
    transforms.ToTensor(),
])

train_dataset = MyImageFolder(
    args.train_dir, 'train', train_transform,
    random_crop=True, img_extentions=args.img_extentions,
    test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(
    args.train_dir, 'test', test_transform,
    random_crop=False, img_extentions=args.img_extentions,
    test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)

# Batch size 1 with shuffling (note: the test loader is shuffled too,
# matching the original setup).
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First Convolution: 32M -> 16M -> 8M
    • every transition: 8M -> 4M -> 2M (downsample 1/2, except the last block)

In [4]:
# Build the backbone (optionally with ImageNet-pretrained weights) and
# freeze every parameter: only the decoder network trains on top of it.
densenet = getattr(models, args.arch)(pretrained=args.pretrained)

for p in densenet.parameters():
    p.requires_grad = False

if use_gpu:
    densenet.cuda()

In [5]:
ss = 6  # epochs per warm-up stage (one stage per output scale)

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
# Epoch at which each of the 6 output scales becomes active:
# the 5 downsampled scales switch on at ss*4..ss*0, the merged
# full-resolution output last at ss*5.
args.training_thresholds = [ss*4,ss*3,ss*2,ss*1,ss*0,ss*5]
args.power = 0.5  # exponent of the poly learning-rate schedule



# pretrained = PreTrainedModel(densenet)
# if use_gpu: 
#     pretrained.cuda()


net = GradientNet(densenet=densenet, growth_rate=growth_rate, transition_scale=transition_scale)
if use_gpu:
    net.cuda()

def _make_losses(n=6):
    """Build `n` independent MSE criteria (moved to GPU when available).

    The original `[nn.MSELoss()] * 6` put SIX references to the SAME
    module in the list; MSELoss is stateless so it happened to work,
    but independent instances are the safe idiom.
    """
    losses = [nn.MSELoss() for _ in range(n)]
    if use_gpu:
        losses = [l.cuda() for l in losses]
    return losses

mse_losses = _make_losses()   # one criterion per output scale (training)
test_losses = _make_losses()  # one criterion per output scale (testing)

In [6]:
# training loop
# TensorBoard writer; the comment suffix tags the run with its transition scale.
writer = SummaryWriter(comment='-transition_scale_{}'.format(transition_scale))

# Optimize only the unfrozen (decoder) parameters; the DenseNet backbone
# had requires_grad=False set earlier and is excluded by this filter.
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    for param_group in optimizer.param_groups:
        if reset_lr != None:
            param_group['lr'] = reset_lr
            continue
        if epoch != 0: 
            # linear
#             param_group['lr'] *= (end-epoch) / (end-beg)
#             poly base_lr (1 - iter/max_iter) ^ (power)
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
            if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        print('lr', param_group['lr'])
        
# def findLargerInd(target, arr):
#     res = list(filter(lambda x: x>target, arr))
#     print('res',res)
#     if len(res) == 0: return -1
#     return res[0]

# Main loop. Each epoch: (1) train pass over all active scales,
# (2) periodic snapshot, (3) test pass with net still in train mode,
# (4) test pass in eval mode, (5) per-scale TensorBoard logging.
for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

    # While scales are still being phased in, cycle a short poly schedule of
    # length ss for each stage; afterwards decay over the remaining epochs.
    if epoch < args.training_thresholds[-1]: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)

    run_losses = [0] * len(args.training_thresholds)       # per-scale loss sums for this epoch
    run_cnts   = [0.00001] * len(args.training_thresholds) # tiny epsilon avoids /0 for inactive scales
    # Restart the schedule from base_lr whenever a new scale switches on.
    if (epoch in args.training_thresholds) == True: 
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
#         if  ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # CHW float tensor -> HWC, RGB->BGR, [0,1]->[0,255] for cv2.imwrite.
        im = input_img[0,:,:,:].numpy(); im = im.transpose(1,2,0); im = im[:,:,::-1]*255

        # NOTE(review): presumably a leak check — the held-out test scene
        # should never appear in the training loader; confirm intent.
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
#         run_losses = [0] * len(mse_losses)
#         run_cnts = [0.00001] * len(mse_losses)
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)

        optimizer.zero_grad()
#         pretrained.train(); ft_pretreained = pretrained(input_img)
        ft_predict = net(input_img)  # list of 6 predictions, one per scale
        for i, threshold in enumerate(args.training_thresholds):
#             threshold = args.training_thresholds[i]
            # Only train scales whose phase-in epoch has been reached.
            if epoch >= threshold:
                # i == 5 is the merged full-resolution output (no downsampling);
                # the others are downsampled by a factor of 2**(i+1).
                if i == 5: s = 1
                else: s = (2**(i+1))
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                # NOTE(review): cv2.resize takes dsize=(width, height); passing
                # (h//s, w//s) is only correct here because h == w == 256.
                gt = cv2.resize(gt, (h//s, w//s))
#                 gt = cv2.resize(gt, (h,w))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                loss_data = loss.data.cpu().numpy()  # pre-0.4 PyTorch scalar extraction
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                ma_ = ft_predict[i].max().cpu().data.numpy()
                mi_ = ft_predict[i].min().cpu().data.numpy()
                #print('mi', mi_, 'ma', ma_)
#                 writer.add_scalars('{}th train predict'.format(i), {'max': ma_, 'min': mi_}, global_step=args.display_curindex)
#                 run_cnts[i] += 1
                run_losses[i] += loss.data.cpu().numpy()[0]
                # retain_graph: later scales backprop through the same shared
                # graph, so it must survive this backward call; gradients
                # accumulate across scales until optimizer.step().
                loss.backward(retain_graph=True)
                run_cnts[i] += 1
#                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
#                 print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])
        optimizer.step()
        args.display_curindex += 1

    """ every epoch """
#     loss_output = 'ind: ' + str(args.display_curindex)
    loss_output = ''



    # Report the epoch's mean loss per scale (16M..1M, then merged).
    for i,v in enumerate(run_losses):
        if i == len(run_losses)-1: 
            loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
            continue
        loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
    print(loss_output)
    # save at every epoch
    # (actually every 10th epoch — the modulo below overrides the comment above)
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))

    # test 
    # First test pass: net is STILL IN TRAIN MODE (batch-norm/dropout behave
    # as in training) — logged separately as "test_trainphase".
    test_losses_trainphase = [0] * len(args.training_thresholds)
    test_cnts_trainphase   = [0.00001] * len(args.training_thresholds)   
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)

#         pretrained.train(); ft_pretreained = pretrained(input_img)
        ft_test = net(input_img)

        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
#             gt = cv2.resize(gt, (h,w))

            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            # NOTE(review): this reuses mse_losses, not the test_losses list
            # built earlier — harmless since MSELoss is stateless, but confirm.
            loss = mse_losses[i](ft_test[i], gt)

            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_train-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)


    # Second test pass: proper evaluation mode.
    # (This rebinds the name test_losses from the criterion list to a list of
    # floats — the criteria are not used again, but the shadowing is fragile.)
    net.eval()
    test_losses = [0] * len(args.training_thresholds)
    test_cnts   = [0.00001] * len(args.training_thresholds)   
    for ind, data in enumerate(test_loader, 0):
#         if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)

#         pretrained.eval(); ft_pretreained = pretrained(input_img)
        ft_test = net(input_img)

        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
#             gt = cv2.resize(gt, (h,w))

            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)

            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_test-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)

    # Log epoch means (train / test-in-train-mode / test) per output scale.
    writer.add_scalars('16M loss', {
        'train 16M ': np.array([run_losses[0]/ run_cnts[0]]),
        'test_trainphase 16M ': np.array([test_losses_trainphase[0]/ test_cnts_trainphase[0]]),
        'test 16M ': np.array([test_losses[0]/ test_cnts[0]])
    }, global_step=epoch)  
    writer.add_scalars('8M loss', {
        'train 8M ': np.array([run_losses[1]/ run_cnts[1]]),
        'test_trainphase 8M ': np.array([test_losses_trainphase[1]/ test_cnts_trainphase[1]]),
        'test 8M ': np.array([test_losses[1]/ test_cnts[1]])
    }, global_step=epoch) 
    writer.add_scalars('4M loss', {
        'train 4M ': np.array([run_losses[2]/ run_cnts[2]]),
        'test_trainphase 4M ': np.array([test_losses_trainphase[2]/ test_cnts_trainphase[2]]),
        'test 4M ': np.array([test_losses[2]/ test_cnts[2]])
    }, global_step=epoch) 
    writer.add_scalars('2M loss', {
        'train 2M ': np.array([run_losses[3]/ run_cnts[3]]),
        'test_trainphase 2M ': np.array([test_losses_trainphase[3]/ test_cnts_trainphase[3]]),
        'test 2M ': np.array([test_losses[3]/ test_cnts[3]])
    }, global_step=epoch) 
    writer.add_scalars('1M loss', {
        'train 1M ': np.array([run_losses[4]/ run_cnts[4]]),
        'test_trainphase 1M ': np.array([test_losses_trainphase[4]/ test_cnts_trainphase[4]]),
        'test 1M ': np.array([test_losses[4]/ test_cnts[4]])
    }, global_step=epoch) 
    writer.add_scalars('merged loss', {
        'train merged ': np.array([run_losses[5]/ run_cnts[5]]),
        'test_trainphase merged ': np.array([test_losses_trainphase[5]/ test_cnts_trainphase[5]]),
        'test merged ': np.array([test_losses[5]/ test_cnts[5]])
    }, global_step=epoch)


epoch: 0 [2017-11-23 15:43:27]
lr 0.05
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.056346 merged: 0.000000
epoch: 1 [2017-11-23 15:45:35]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.051293 merged: 0.000000
epoch: 2 [2017-11-23 15:47:41]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.050426 merged: 0.000000
epoch: 3 [2017-11-23 15:49:48]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.045310 merged: 0.000000
epoch: 4 [2017-11-23 15:51:54]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.041864 merged: 0.000000
epoch: 5 [2017-11-23 15:54:01]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.038081 merged: 0.000000
epoch: 6 [2017-11-23 15:56:07]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.050632  1M: 0.044893 merged: 0.000000
epoch: 7 [2017-11-23 15:58:36]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.041830  1M: 0.037297 merged: 0.000000
epoch: 8 [2017-11-23 16:01:10]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.038328  1M: 0.036307 merged: 0.000000
epoch: 9 [2017-11-23 16:03:48]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.030455  1M: 0.031241 merged: 0.000000
epoch: 10 [2017-11-23 16:06:18]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.023524  1M: 0.028136 merged: 0.000000
epoch: 11 [2017-11-23 16:08:50]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.020230  1M: 0.028566 merged: 0.000000
epoch: 12 [2017-11-23 16:11:21]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.048583  2M: 0.027805  1M: 0.031723 merged: 0.000000
epoch: 13 [2017-11-23 16:14:11]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.039270  2M: 0.023627  1M: 0.029941 merged: 0.000000
epoch: 14 [2017-11-23 16:16:55]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.028857  2M: 0.020459  1M: 0.027377 merged: 0.000000
epoch: 15 [2017-11-23 16:19:40]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.023017  2M: 0.018355  1M: 0.025514 merged: 0.000000
epoch: 16 [2017-11-23 16:22:27]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.019011  2M: 0.016497  1M: 0.023245 merged: 0.000000
epoch: 17 [2017-11-23 16:25:16]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.018064  2M: 0.016422  1M: 0.022323 merged: 0.000000
epoch: 18 [2017-11-23 16:28:10]
lr 1e-08
 16M: 0.000000  8M: 0.047219  4M: 0.025861  2M: 0.018504  1M: 0.026694 merged: 0.000000
epoch: 19 [2017-11-23 16:31:10]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.037033  4M: 0.021127  2M: 0.017448  1M: 0.024489 merged: 0.000000
epoch: 20 [2017-11-23 16:34:14]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.026921  4M: 0.017209  2M: 0.016117  1M: 0.024114 merged: 0.000000
epoch: 21 [2017-11-23 16:37:11]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.022049  4M: 0.015750  2M: 0.014885  1M: 0.022948 merged: 0.000000
epoch: 22 [2017-11-23 16:40:12]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.018257  4M: 0.014509  2M: 0.014577  1M: 0.022100 merged: 0.000000
epoch: 23 [2017-11-23 16:43:15]
lr 1e-08
 16M: 0.000000  8M: 0.018828  4M: 0.014502  2M: 0.014097  1M: 0.023216 merged: 0.000000
epoch: 24 [2017-11-23 16:46:13]
lr 1e-08
 16M: 0.043308  8M: 0.021372  4M: 0.016752  2M: 0.015713  1M: 0.024243 merged: 0.000000
epoch: 25 [2017-11-23 16:49:25]
lr 0.044721359549995794
 16M: 0.032101  8M: 0.018288  4M: 0.016559  2M: 0.014735  1M: 0.023415 merged: 0.000000
epoch: 26 [2017-11-23 16:52:37]
lr 0.038729833462074176
 16M: 0.023015  8M: 0.015692  4M: 0.014008  2M: 0.014034  1M: 0.022839 merged: 0.000000
epoch: 27 [2017-11-23 16:55:50]
lr 0.0316227766016838
 16M: 0.020160  8M: 0.014811  4M: 0.012931  2M: 0.013113  1M: 0.021341 merged: 0.000000
epoch: 28 [2017-11-23 16:58:59]
lr 0.022360679774997897
 16M: 0.017321  8M: 0.012983  4M: 0.012009  2M: 0.012536  1M: 0.021214 merged: 0.000000
epoch: 29 [2017-11-23 17:02:10]
lr 1e-08
 16M: 0.016378  8M: 0.012169  4M: 0.011484  2M: 0.011850  1M: 0.019770 merged: 0.000000
epoch: 30 [2017-11-23 17:05:22]
lr 0.05
 16M: 0.020712  8M: 0.017511  4M: 0.014794  2M: 0.014449  1M: 0.022852 merged: 0.034070
epoch: 31 [2017-11-23 17:10:11]
lr 0.04971830761761256
 16M: 0.018833  8M: 0.021438  4M: 0.014525  2M: 0.014964  1M: 0.022585 merged: 0.023343
epoch: 32 [2017-11-23 17:14:56]
lr 0.04943501011144937
 16M: 0.018299  8M: 0.019179  4M: 0.014672  2M: 0.015636  1M: 0.022657 merged: 0.019402
epoch: 33 [2017-11-23 17:19:44]
lr 0.04915007972606608
 16M: 0.018964  8M: 0.018934  4M: 0.014052  2M: 0.014680  1M: 0.022131 merged: 0.018770
epoch: 34 [2017-11-23 17:24:30]
lr 0.04886348789677424
 16M: 0.015564  8M: 0.015990  4M: 0.012950  2M: 0.015276  1M: 0.021739 merged: 0.015841
epoch: 35 [2017-11-23 17:29:18]
lr 0.04857520521621862
 16M: 0.014355  8M: 0.014931  4M: 0.012063  2M: 0.013249  1M: 0.021396 merged: 0.014666
epoch: 36 [2017-11-23 17:34:02]
lr 0.04828520139915856
 16M: 0.013881  8M: 0.014458  4M: 0.012122  2M: 0.013282  1M: 0.021773 merged: 0.013567
epoch: 37 [2017-11-23 17:38:49]
lr 0.047993445245333805
 16M: 0.013910  8M: 0.013776  4M: 0.011785  2M: 0.012927  1M: 0.021087 merged: 0.013424
epoch: 38 [2017-11-23 17:43:35]
lr 0.0476999046002862
 16M: 0.012601  8M: 0.012800  4M: 0.011083  2M: 0.012741  1M: 0.020237 merged: 0.012008
epoch: 39 [2017-11-23 17:48:22]
lr 0.04740454631399772
 16M: 0.012403  8M: 0.012313  4M: 0.010798  2M: 0.012458  1M: 0.020177 merged: 0.011509
epoch: 40 [2017-11-23 17:53:06]
lr 0.04710733619719444
 16M: 0.012579  8M: 0.012985  4M: 0.011277  2M: 0.012844  1M: 0.020181 merged: 0.011930
epoch: 41 [2017-11-23 17:57:54]
lr 0.04680823897515326
 16M: 0.012168  8M: 0.012263  4M: 0.010703  2M: 0.012642  1M: 0.019654 merged: 0.011600
epoch: 42 [2017-11-23 18:02:40]
lr 0.04650721823883479
 16M: 0.011331  8M: 0.011296  4M: 0.010785  2M: 0.012355  1M: 0.019567 merged: 0.010806
epoch: 43 [2017-11-23 18:07:28]
lr 0.046204236393150765
 16M: 0.011144  8M: 0.011178  4M: 0.010477  2M: 0.011938  1M: 0.019129 merged: 0.010475
epoch: 44 [2017-11-23 18:12:14]
lr 0.045899254602157845
 16M: 0.010733  8M: 0.011444  4M: 0.010141  2M: 0.011775  1M: 0.018698 merged: 0.010337
epoch: 45 [2017-11-23 18:17:00]
lr 0.04559223273095164
 16M: 0.010339  8M: 0.010823  4M: 0.010085  2M: 0.011720  1M: 0.019023 merged: 0.009960
epoch: 46 [2017-11-23 18:21:47]
lr 0.045283129284014914
 16M: 0.010032  8M: 0.010160  4M: 0.009612  2M: 0.011109  1M: 0.017546 merged: 0.009373
epoch: 47 [2017-11-23 18:26:32]
lr 0.04497190133975169
 16M: 0.009949  8M: 0.010506  4M: 0.009255  2M: 0.011230  1M: 0.017993 merged: 0.009497
epoch: 48 [2017-11-23 18:31:18]
lr 0.04465850448091506
 16M: 0.009700  8M: 0.009943  4M: 0.009163  2M: 0.011064  1M: 0.018080 merged: 0.008995
epoch: 49 [2017-11-23 18:36:05]
lr 0.044342892720609255
 16M: 0.009673  8M: 0.009777  4M: 0.009617  2M: 0.011063  1M: 0.017719 merged: 0.009065
epoch: 50 [2017-11-23 18:40:52]
lr 0.044025018423517
 16M: 0.009225  8M: 0.009487  4M: 0.009153  2M: 0.010903  1M: 0.017787 merged: 0.008607
epoch: 51 [2017-11-23 18:45:39]
lr 0.04370483222197017
 16M: 0.009039  8M: 0.009426  4M: 0.009004  2M: 0.010862  1M: 0.017729 merged: 0.008547
epoch: 52 [2017-11-23 18:50:26]
lr 0.043382282926444894
 16M: 0.009082  8M: 0.009899  4M: 0.009054  2M: 0.010722  1M: 0.017208 merged: 0.008906
epoch: 53 [2017-11-23 18:55:15]
lr 0.04305731743002185
 16M: 0.009030  8M: 0.009506  4M: 0.008725  2M: 0.010574  1M: 0.017572 merged: 0.008462
epoch: 54 [2017-11-23 19:00:02]
lr 0.04272988060630656
 16M: 0.008931  8M: 0.009345  4M: 0.008965  2M: 0.010361  1M: 0.016909 merged: 0.008256
epoch: 55 [2017-11-23 19:04:46]
lr 0.04239991520025441
 16M: 0.008547  8M: 0.008831  4M: 0.008768  2M: 0.010429  1M: 0.016975 merged: 0.008028
epoch: 56 [2017-11-23 19:09:33]
lr 0.0420673617112877
 16M: 0.008343  8M: 0.008646  4M: 0.008209  2M: 0.010143  1M: 0.016541 merged: 0.007812
epoch: 57 [2017-11-23 19:14:41]
lr 0.041732158268029534
 16M: 0.008478  8M: 0.008490  4M: 0.008173  2M: 0.010049  1M: 0.016816 merged: 0.007553
epoch: 58 [2017-11-23 19:19:49]
lr 0.041394240493907074
 16M: 0.008058  8M: 0.008470  4M: 0.008058  2M: 0.009947  1M: 0.016273 merged: 0.007711
epoch: 59 [2017-11-23 19:24:54]
lr 0.041053541362798006
 16M: 0.008373  8M: 0.008548  4M: 0.008487  2M: 0.009992  1M: 0.016794 merged: 0.007629
epoch: 60 [2017-11-23 19:30:03]
lr 0.04070999104380296
 16M: 0.008072  8M: 0.008511  4M: 0.007858  2M: 0.009792  1M: 0.016283 merged: 0.007470
epoch: 61 [2017-11-23 19:35:09]
lr 0.04036351673412598
 16M: 0.007866  8M: 0.008163  4M: 0.007915  2M: 0.009663  1M: 0.015977 merged: 0.007333
epoch: 62 [2017-11-23 19:40:16]
lr 0.04001404247893005
 16M: 0.007819  8M: 0.008225  4M: 0.007804  2M: 0.009861  1M: 0.016271 merged: 0.007266
epoch: 63 [2017-11-23 19:45:22]
lr 0.03966148897690515
 16M: 0.007896  8M: 0.008366  4M: 0.007824  2M: 0.009503  1M: 0.015478 merged: 0.007410
epoch: 64 [2017-11-23 19:50:27]
lr 0.03930577337013889
 16M: 0.007712  8M: 0.008371  4M: 0.007770  2M: 0.009594  1M: 0.015745 merged: 0.007434
epoch: 65 [2017-11-23 19:55:34]
lr 0.038946809016712394
 16M: 0.007428  8M: 0.008067  4M: 0.007659  2M: 0.009492  1M: 0.015648 merged: 0.007070
epoch: 66 [2017-11-23 20:00:40]
lr 0.03858450524425343
 16M: 0.007367  8M: 0.008032  4M: 0.007657  2M: 0.009327  1M: 0.015771 merged: 0.006872
epoch: 67 [2017-11-23 20:05:46]
lr 0.03821876708246056
 16M: 0.007583  8M: 0.007865  4M: 0.007479  2M: 0.009455  1M: 0.015713 merged: 0.006866
epoch: 68 [2017-11-23 20:10:52]
lr 0.03784949497236286
 16M: 0.007466  8M: 0.007822  4M: 0.007341  2M: 0.009389  1M: 0.015337 merged: 0.006894
epoch: 69 [2017-11-23 20:15:58]
lr 0.03747658444979307
 16M: 0.007214  8M: 0.007623  4M: 0.007273  2M: 0.009217  1M: 0.015262 merged: 0.006580
epoch: 70 [2017-11-23 20:21:05]
lr 0.0370999258002226
 16M: 0.007037  8M: 0.007456  4M: 0.007144  2M: 0.009135  1M: 0.015221 merged: 0.006504
epoch: 71 [2017-11-23 20:26:12]
lr 0.03671940368172628
 16M: 0.006854  8M: 0.007292  4M: 0.007170  2M: 0.009012  1M: 0.015303 merged: 0.006437
epoch: 72 [2017-11-23 20:31:18]
lr 0.03633489671240478
 16M: 0.007258  8M: 0.007527  4M: 0.007229  2M: 0.009095  1M: 0.015068 merged: 0.006618
epoch: 73 [2017-11-23 20:36:24]
lr 0.03594627701808178
 16M: 0.008107  8M: 0.008554  4M: 0.007305  2M: 0.009270  1M: 0.015527 merged: 0.007379
epoch: 74 [2017-11-23 20:41:30]
lr 0.035553409735498295
 16M: 0.007033  8M: 0.007549  4M: 0.007137  2M: 0.008912  1M: 0.015021 merged: 0.006474
epoch: 75 [2017-11-23 20:46:37]
lr 0.03515615246553262
 16M: 0.006999  8M: 0.007591  4M: 0.007315  2M: 0.009197  1M: 0.015321 merged: 0.006532
epoch: 76 [2017-11-23 20:51:43]
lr 0.03475435467016077
 16M: 0.006839  8M: 0.007251  4M: 0.007050  2M: 0.008883  1M: 0.014751 merged: 0.006258
epoch: 77 [2017-11-23 20:56:48]
lr 0.034347857005916346
 16M: 0.006744  8M: 0.007096  4M: 0.006830  2M: 0.008780  1M: 0.014740 merged: 0.006218
epoch: 78 [2017-11-23 21:01:55]
lr 0.0339364905854808
 16M: 0.006885  8M: 0.007222  4M: 0.007020  2M: 0.009008  1M: 0.014697 merged: 0.006322
epoch: 79 [2017-11-23 21:07:01]
lr 0.03352007615769955
 16M: 0.006489  8M: 0.007115  4M: 0.006807  2M: 0.008750  1M: 0.014395 merged: 0.006058
epoch: 80 [2017-11-23 21:12:09]
lr 0.03309842319473132
 16M: 0.006518  8M: 0.006836  4M: 0.006683  2M: 0.008762  1M: 0.014484 merged: 0.005994
epoch: 81 [2017-11-23 21:17:16]
lr 0.03267132887314317
 16M: 0.006507  8M: 0.006882  4M: 0.006600  2M: 0.008623  1M: 0.014283 merged: 0.006047
epoch: 82 [2017-11-23 21:22:22]
lr 0.03223857693349118
 16M: 0.006306  8M: 0.006723  4M: 0.006653  2M: 0.008636  1M: 0.014633 merged: 0.005859
epoch: 83 [2017-11-23 21:27:30]
lr 0.0317999364001908
 16M: 0.006392  8M: 0.006860  4M: 0.006721  2M: 0.008683  1M: 0.014600 merged: 0.005998
epoch: 84 [2017-11-23 21:32:36]
lr 0.031355160140170396
 16M: 0.006405  8M: 0.006696  4M: 0.006655  2M: 0.008488  1M: 0.014195 merged: 0.005918
epoch: 85 [2017-11-23 21:37:43]
lr 0.03090398323477543
 16M: 0.006368  8M: 0.006697  4M: 0.006681  2M: 0.008445  1M: 0.013947 merged: 0.005796
epoch: 86 [2017-11-23 21:42:49]
lr 0.030446121134470178
 16M: 0.006295  8M: 0.006605  4M: 0.006404  2M: 0.008354  1M: 0.014162 merged: 0.005783
epoch: 87 [2017-11-23 21:47:53]
lr 0.02998126755983446
 16M: 0.006207  8M: 0.006506  4M: 0.006345  2M: 0.008221  1M: 0.013727 merged: 0.005626
epoch: 88 [2017-11-23 21:52:58]
lr 0.029509092104873926
 16M: 0.006099  8M: 0.006441  4M: 0.006334  2M: 0.008339  1M: 0.014071 merged: 0.005646
epoch: 89 [2017-11-23 21:58:04]
lr 0.029029237489356888
 16M: 0.005960  8M: 0.006362  4M: 0.006332  2M: 0.008183  1M: 0.014013 merged: 0.005490
epoch: 90 [2017-11-23 22:03:10]
lr 0.028541316395237167
 16M: 0.005999  8M: 0.006451  4M: 0.006389  2M: 0.008351  1M: 0.014029 merged: 0.005612
epoch: 91 [2017-11-23 22:08:19]
lr 0.028044907807525134
 16M: 0.005936  8M: 0.006438  4M: 0.006431  2M: 0.008344  1M: 0.014033 merged: 0.005556
epoch: 92 [2017-11-23 22:13:28]
lr 0.027539552761294706
 16M: 0.005903  8M: 0.006445  4M: 0.006401  2M: 0.008335  1M: 0.014263 merged: 0.005548
epoch: 93 [2017-11-23 22:18:40]
lr 0.027024749372597065
 16M: 0.005874  8M: 0.006271  4M: 0.006121  2M: 0.007937  1M: 0.013390 merged: 0.005401
epoch: 94 [2017-11-23 22:23:48]
lr 0.026499947000159004
 16M: 0.005765  8M: 0.006220  4M: 0.006133  2M: 0.008010  1M: 0.013389 merged: 0.005352
epoch: 95 [2017-11-23 22:28:56]
lr 0.02596453934447493
 16M: 0.006001  8M: 0.006404  4M: 0.006227  2M: 0.008148  1M: 0.013933 merged: 0.005541
epoch: 96 [2017-11-23 22:34:02]
lr 0.025417856237895775
 16M: 0.005724  8M: 0.006247  4M: 0.006229  2M: 0.008153  1M: 0.013784 merged: 0.005386
epoch: 97 [2017-11-23 22:39:10]
lr 0.02485915380880628
 16M: 0.005780  8M: 0.006110  4M: 0.006004  2M: 0.007789  1M: 0.013503 merged: 0.005329
epoch: 98 [2017-11-23 22:44:16]
lr 0.02428760260810931
 16M: 0.005802  8M: 0.006179  4M: 0.006196  2M: 0.008103  1M: 0.014025 merged: 0.005329
epoch: 99 [2017-11-23 22:49:23]
lr 0.02370227315699886
 16M: 0.005709  8M: 0.006183  4M: 0.006323  2M: 0.008038  1M: 0.013852 merged: 0.005339
epoch: 100 [2017-11-23 22:54:31]
lr 0.023102118196575382
 16M: 0.005715  8M: 0.006163  4M: 0.006140  2M: 0.007930  1M: 0.013620 merged: 0.005362
epoch: 101 [2017-11-23 22:59:37]
lr 0.022485950669875843
 16M: 0.005607  8M: 0.006093  4M: 0.006034  2M: 0.007832  1M: 0.013403 merged: 0.005258
epoch: 102 [2017-11-23 23:04:45]
lr 0.021852416110985085
 16M: 0.005567  8M: 0.006053  4M: 0.005931  2M: 0.007823  1M: 0.013184 merged: 0.005200
epoch: 103 [2017-11-23 23:09:51]
lr 0.021199957600127203
 16M: 0.005631  8M: 0.006185  4M: 0.006130  2M: 0.007918  1M: 0.013372 merged: 0.005345
epoch: 104 [2017-11-23 23:14:56]
lr 0.020526770681399003
 16M: 0.005535  8M: 0.006017  4M: 0.005927  2M: 0.007670  1M: 0.013209 merged: 0.005089
epoch: 105 [2017-11-23 23:20:02]
lr 0.019830744488452574
 16M: 0.005500  8M: 0.005961  4M: 0.006035  2M: 0.007870  1M: 0.013455 merged: 0.005092
epoch: 106 [2017-11-23 23:25:09]
lr 0.01910938354123028
 16M: 0.005426  8M: 0.005825  4M: 0.005942  2M: 0.007640  1M: 0.012962 merged: 0.004988
epoch: 107 [2017-11-23 23:30:16]
lr 0.01835970184086314
 16M: 0.005393  8M: 0.005897  4M: 0.005958  2M: 0.007774  1M: 0.013309 merged: 0.005016
epoch: 108 [2017-11-23 23:35:22]
lr 0.01757807623276631
 16M: 0.005299  8M: 0.005843  4M: 0.005938  2M: 0.007797  1M: 0.013363 merged: 0.004982
epoch: 109 [2017-11-23 23:40:29]
lr 0.016760038078849775
 16M: 0.005373  8M: 0.005822  4M: 0.005889  2M: 0.007723  1M: 0.013032 merged: 0.004993
epoch: 110 [2017-11-23 23:45:36]
lr 0.0158999682000954
 16M: 0.005318  8M: 0.005841  4M: 0.005854  2M: 0.007715  1M: 0.013302 merged: 0.005018
epoch: 111 [2017-11-23 23:50:43]
lr 0.01499063377991723
 16M: 0.005307  8M: 0.005818  4M: 0.005876  2M: 0.007740  1M: 0.013145 merged: 0.004936
epoch: 112 [2017-11-23 23:55:52]
lr 0.014022453903762567
 16M: 0.005386  8M: 0.005804  4M: 0.005816  2M: 0.007696  1M: 0.013155 merged: 0.004952
epoch: 113 [2017-11-24 00:00:59]
lr 0.012982269672237465
 16M: 0.005233  8M: 0.005731  4M: 0.005816  2M: 0.007640  1M: 0.013208 merged: 0.004849
epoch: 114 [2017-11-24 00:06:06]
lr 0.01185113657849943
 16M: 0.005159  8M: 0.005689  4M: 0.005753  2M: 0.007622  1M: 0.013367 merged: 0.004836
epoch: 115 [2017-11-24 00:11:14]
lr 0.010599978800063602
 16M: 0.005193  8M: 0.005688  4M: 0.005773  2M: 0.007629  1M: 0.013239 merged: 0.004779
epoch: 116 [2017-11-24 00:16:21]
lr 0.00917985092043157
 16M: 0.005108  8M: 0.005619  4M: 0.005707  2M: 0.007514  1M: 0.012882 merged: 0.004766
epoch: 117 [2017-11-24 00:21:18]
lr 0.007495316889958615
 16M: 0.005047  8M: 0.005515  4M: 0.005605  2M: 0.007311  1M: 0.012730 merged: 0.004687
epoch: 118 [2017-11-24 00:26:12]
lr 0.005299989400031801
 16M: 0.005102  8M: 0.005702  4M: 0.005807  2M: 0.007743  1M: 0.013352 merged: 0.004809
epoch: 119 [2017-11-24 00:31:07]
lr 1e-08
 16M: 0.005044  8M: 0.005587  4M: 0.005622  2M: 0.007457  1M: 0.012889 merged: 0.004721

Visualize Graph


In [7]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """Produce a Graphviz ``Digraph`` of a PyTorch autograd graph.

    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function.

    Args:
        var: output Variable whose ``grad_fn`` is walked.
        params: optional dict of (name, Variable) used to label nodes
            that require grad (TODO: make optional).

    Returns:
        graphviz.Digraph in SVG format (rendering still requires the
        Graphviz ``dot`` executable on PATH).
    """
    if params is not None:
        # Python 3 fix: dict.values() is a non-indexable view, so
        # params.values()[0] raises TypeError; sample one value instead.
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        # e.g. torch.Size([1, 3, 256, 256]) -> '(1, 3, 256, 256)'
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        # Depth-first walk over the autograd graph: one node per
        # Function/Tensor/Variable, one edge per dependency.
        if var not in seen:
            if torch.is_tensor(var):
                # Tensor saved for backward.
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                # AccumulateGrad-style leaf: label with parameter name if known.
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                # Intermediate autograd Function node.
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    add_nodes(var.grad_fn)
    return dot

In [8]:
# Build a dummy input matching the configured image size (1, 3, 256, 256)
# and trace the network's autograd graph for visualization.
x = Variable(torch.zeros(1,3,256,256))
# `net` is presumably the GradientNet instance built earlier in the
# notebook (not visible in this chunk) — requires a CUDA device.
y = net(x.cuda())
# y[-1] suggests the network returns multiple outputs; presumably the
# last is the merged prediction (matches the "merged:" loss in the logs)
# — TODO confirm against GradientNet.forward.
g = make_dot(y[-1])

In [9]:
# Render the graph to SVG. NOTE(review): this needs the Graphviz `dot`
# executable on PATH (e.g. `apt-get install graphviz`); the Python
# graphviz package alone is not enough — see the ExecutableNotFound
# traceback below.
g.render('net-transition_scale_{}'.format(transition_scale))


---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/graphviz/backend.py in render(engine, format, filepath, quiet)
    123         try:
--> 124             subprocess.check_call(args, startupinfo=STARTUPINFO, stderr=stderr)
    125         except OSError as e:

~/anaconda3/lib/python3.6/subprocess.py in check_call(*popenargs, **kwargs)
    285     """
--> 286     retcode = call(*popenargs, **kwargs)
    287     if retcode:

~/anaconda3/lib/python3.6/subprocess.py in call(timeout, *popenargs, **kwargs)
    266     """
--> 267     with Popen(*popenargs, **kwargs) as p:
    268         try:

~/anaconda3/lib/python3.6/subprocess.py in __init__(self, args, bufsize, executable, stdin, stdout, stderr, preexec_fn, close_fds, shell, cwd, env, universal_newlines, startupinfo, creationflags, restore_signals, start_new_session, pass_fds, encoding, errors)
    708                                 errread, errwrite,
--> 709                                 restore_signals, start_new_session)
    710         except:

~/anaconda3/lib/python3.6/subprocess.py in _execute_child(self, args, executable, preexec_fn, close_fds, pass_fds, cwd, env, startupinfo, creationflags, shell, p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite, restore_signals, start_new_session)
   1343                             err_msg += ': ' + repr(err_filename)
-> 1344                     raise child_exception_type(errno_num, err_msg, err_filename)
   1345                 raise child_exception_type(err_msg)

FileNotFoundError: [Errno 2] No such file or directory: 'dot': 'dot'

During handling of the above exception, another exception occurred:

ExecutableNotFound                        Traceback (most recent call last)
<ipython-input-9-81727d6bc75d> in <module>()
----> 1 g.render('net-transition_scale_{}'.format(transition_scale))

~/anaconda3/lib/python3.6/site-packages/graphviz/files.py in render(self, filename, directory, view, cleanup)
    173         filepath = self.save(filename, directory)
    174 
--> 175         rendered = backend.render(self._engine, self._format, filepath)
    176 
    177         if cleanup:

~/anaconda3/lib/python3.6/site-packages/graphviz/backend.py in render(engine, format, filepath, quiet)
    125         except OSError as e:
    126             if e.errno == errno.ENOENT:
--> 127                 raise ExecutableNotFound(args)
    128             else:  # pragma: no cover
    129                 raise

ExecutableNotFound: failed to execute ['dot', '-Tsvg', '-O', 'net-transition_scale_4'], make sure the Graphviz executables are on your systems' PATH

In [ ]: