In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import GradientNet
from myargs import Args

Configurations


In [2]:
# Hyper-parameters for this run.
args = Args()
args.test_scene = 'alley_1'
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 1
# growth_rate = (4*(2**(args.gpu_num)))
transition_scale=(2*(2**(args.gpu_num+1)))
growth_rate = 32
args.display_interval = 50
args.display_curindex = 0

# Pick the dataset root per machine.  Query the platform once and reuse the
# tuple instead of calling platform.dist() repeatedly.
# NOTE(review): platform.dist() is deprecated since Python 3.5 and removed in
# 3.8; on newer interpreters this cell needs another detection mechanism.
system_ = platform.system()
dist_info = platform.dist()
system_dist, system_version, _ = dist_info
if system_ == "Darwin":
    # macOS development box: no pretrained weights available offline.
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif dist_info == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif dist_info == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

# Only the Linux machines have CUDA devices.
use_gpu = platform.system() == 'Linux'
if use_gpu:
    torch.cuda.set_device(args.gpu_num)

print(dist_info)


('debian', 'jessie/sid', '')

My DataLoader


In [3]:
# Datasets: random crops for training, deterministic center crops for testing.
train_dataset = MyImageFolder(args.train_dir, 'train',
                       transforms.Compose(
        [transforms.ToTensor()]
    ), random_crop=True, 
    img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test', 
                       transforms.Compose(
        [transforms.CenterCrop((args.image_h, args.image_w)),
         transforms.ToTensor()]
    ), random_crop=False,
    img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)

# batch_size=1 throughout.  Shuffle only the training set: the original code
# passed shuffle=True positionally for the test loader too, which made the
# snapshot image written at ind == 0 a different sample every epoch.
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First Convolution: 32M -> 16M -> 8M
    • every transition: 8M -> 4M -> 2M (downsample 1/2, except the last block)

In [4]:
# Backbone: look up the constructor named by args.arch on torchvision.models
# and freeze every parameter, so only the decoder (GradientNet) is trained.
densenet = getattr(models, args.arch)(pretrained=args.pretrained)

for p in densenet.parameters():
    p.requires_grad = False

if use_gpu:
    densenet.cuda()

In [5]:
ss = 6  # epochs per warm-up cycle for each newly activated output scale

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
# Epoch at which each output scale starts training: [16M, 8M, 4M, 2M, 1M, merged].
args.training_thresholds = [ss*4,ss*3,ss*2,ss*1,ss*0,ss*5]
args.power = 0.5  # exponent of the polynomial learning-rate decay



# pretrained = PreTrainedModel(densenet)
# if use_gpu: 
#     pretrained.cuda()


net = GradientNet(densenet=densenet, growth_rate=growth_rate, transition_scale=transition_scale)
if use_gpu:
    net.cuda()

def _make_losses(n=6):
    """One distinct nn.MSELoss per output scale, moved to GPU when available.

    A comprehension is used on purpose: ``[nn.MSELoss()] * n`` would alias a
    single shared module n times (harmless for stateless MSELoss, but a trap).
    """
    losses = [nn.MSELoss() for _ in range(n)]
    if use_gpu:
        losses = [l.cuda() for l in losses]
    return losses

mse_losses = _make_losses()
test_losses = _make_losses()

In [6]:
# Training setup: TensorBoard writer plus SGD over trainable parameters only
# (the frozen densenet backbone is excluded by the requires_grad filter).
writer = SummaryWriter(comment='-transition_scale_{}'.format(transition_scale))

trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    for param_group in optimizer.param_groups:
        if reset_lr != None:
            param_group['lr'] = reset_lr
            continue
        if epoch != 0: 
            # linear
#             param_group['lr'] *= (end-epoch) / (end-beg)
#             poly base_lr (1 - iter/max_iter) ^ (power)
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
            if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        print('lr', param_group['lr'])
        
# def findLargerInd(target, arr):
#     res = list(filter(lambda x: x>target, arr))
#     print('res',res)
#     if len(res) == 0: return -1
#     return res[0]

# Main training loop.  Each epoch: (1) set the learning rate, (2) one pass
# over the training set back-propagating an MSE loss for every output scale
# whose start-epoch threshold has passed, (3) snapshot weights every 10
# epochs, (4) evaluate on the test set twice — once with the net still in
# train() mode, once in eval() mode — and (5) log per-scale losses.
for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    
    # Before the last threshold the schedule restarts every `ss` epochs (one
    # warm-up cycle per newly activated scale); afterwards it decays once
    # over the remaining epochs.
    if epoch < args.training_thresholds[-1]: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)
    
    # Per-scale loss accumulators; the tiny epsilon in the counters avoids a
    # divide-by-zero for scales that have not started training yet.
    run_losses = [0] * len(args.training_thresholds)
    run_cnts   = [0.00001] * len(args.training_thresholds)
    if (epoch in args.training_thresholds) == True: 
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
#         if  ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # CHW tensor -> HWC numpy, RGB -> BGR, [0,1] -> [0,255] for cv2.imwrite.
        im = input_img[0,:,:,:].numpy(); im = im.transpose(1,2,0); im = im[:,:,::-1]*255
        
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
#         run_losses = [0] * len(mse_losses)
#         run_cnts = [0.00001] * len(mse_losses)
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)

        optimizer.zero_grad()
#         pretrained.train(); ft_pretreained = pretrained(input_img)
        # Forward pass: one prediction per output scale (indices 0..5).
        ft_predict = net(input_img)
        for i, threshold in enumerate(args.training_thresholds):
#             threshold = args.training_thresholds[i]
            # Train scale i only once its start epoch has been reached.
            if epoch >= threshold:
#             if epoch >= 0:
                # Downsample factor for scale i: 2^(i+1), except the merged
                # full-resolution output at i == 5.
                if i == 5: s = 1
                else: s = (2**(i+1))
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                # NOTE(review): cv2.resize takes (width, height); (h//s, w//s)
                # is only correct because the crops are square (256x256).
                gt = cv2.resize(gt, (h//s, w//s))
#                 gt = cv2.resize(gt, (h,w))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                # Pre-0.4 PyTorch API: loss is a 1-element Variable.
                loss_data = loss.data.cpu().numpy()
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                ma_ = ft_predict[i].max().cpu().data.numpy()
                mi_ = ft_predict[i].min().cpu().data.numpy()
                #print('mi', mi_, 'ma', ma_)
#                 writer.add_scalars('{}th train predict'.format(i), {'max': ma_, 'min': mi_}, global_step=args.display_curindex)
#                 run_cnts[i] += 1
                run_losses[i] += loss.data.cpu().numpy()[0]
                # retain_graph=True because backward() runs once per active
                # scale over the shared trunk; gradients accumulate until the
                # single optimizer.step() below.
                loss.backward(retain_graph=True)
                run_cnts[i] += 1
#                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
#                 print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])
        optimizer.step()
        args.display_curindex += 1

    """ every epoch """
#     loss_output = 'ind: ' + str(args.display_curindex)
    loss_output = ''
    
    
    
    # Print average per-scale training loss; scale i is labeled 2^(4-i)M
    # (16M..1M) and the last entry is the merged full-resolution output.
    for i,v in enumerate(run_losses):
        if i == len(run_losses)-1: 
            loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
            continue
        loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
    print(loss_output)
    # save at every epoch
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))
    
    # test 
    # First evaluation pass over the test set: note the net is still in
    # train() mode here (compare with the eval() pass below).
    # NOTE(review): losses reuse mse_losses; the test_losses modules created
    # in the previous cell are never used.
    test_losses_trainphase = [0] * len(args.training_thresholds)
    test_cnts_trainphase   = [0.00001] * len(args.training_thresholds)   
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        
#         pretrained.train(); ft_pretreained = pretrained(input_img)
        ft_test = net(input_img)
            
        for i,v in enumerate(ft_test):
            # Skip scales that have not started training yet.
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            # Write the first test sample's prediction for visual inspection.
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_train-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)

    
    # Second evaluation pass, now with the net switched to inference mode.
    net.eval()
    # NOTE(review): this rebinds test_losses from the list of nn.MSELoss
    # modules (previous cell) to a list of float accumulators.
    test_losses = [0] * len(args.training_thresholds)
    test_cnts   = [0.00001] * len(args.training_thresholds)   
    for ind, data in enumerate(test_loader, 0):
#         if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
            
#         pretrained.eval(); ft_pretreained = pretrained(input_img)
        ft_test = net(input_img)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_test-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)
    
    # Log the epoch-average train / test(train-mode) / test(eval-mode) loss
    # for each output scale to TensorBoard.
    writer.add_scalars('16M loss', {
        'train 16M ': np.array([run_losses[0]/ run_cnts[0]]),
        'test_trainphase 16M ': np.array([test_losses_trainphase[0]/ test_cnts_trainphase[0]]),
        'test 16M ': np.array([test_losses[0]/ test_cnts[0]])
    }, global_step=epoch)  
    writer.add_scalars('8M loss', {
        'train 8M ': np.array([run_losses[1]/ run_cnts[1]]),
        'test_trainphase 8M ': np.array([test_losses_trainphase[1]/ test_cnts_trainphase[1]]),
        'test 8M ': np.array([test_losses[1]/ test_cnts[1]])
    }, global_step=epoch) 
    writer.add_scalars('4M loss', {
        'train 4M ': np.array([run_losses[2]/ run_cnts[2]]),
        'test_trainphase 4M ': np.array([test_losses_trainphase[2]/ test_cnts_trainphase[2]]),
        'test 4M ': np.array([test_losses[2]/ test_cnts[2]])
    }, global_step=epoch) 
    writer.add_scalars('2M loss', {
        'train 2M ': np.array([run_losses[3]/ run_cnts[3]]),
        'test_trainphase 2M ': np.array([test_losses_trainphase[3]/ test_cnts_trainphase[3]]),
        'test 2M ': np.array([test_losses[3]/ test_cnts[3]])
    }, global_step=epoch) 
    writer.add_scalars('1M loss', {
        'train 1M ': np.array([run_losses[4]/ run_cnts[4]]),
        'test_trainphase 1M ': np.array([test_losses_trainphase[4]/ test_cnts_trainphase[4]]),
        'test 1M ': np.array([test_losses[4]/ test_cnts[4]])
    }, global_step=epoch) 
    writer.add_scalars('merged loss', {
        'train merged ': np.array([run_losses[5]/ run_cnts[5]]),
        'test_trainphase merged ': np.array([test_losses_trainphase[5]/ test_cnts_trainphase[5]]),
        'test merged ': np.array([test_losses[5]/ test_cnts[5]])
    }, global_step=epoch)


epoch: 0 [2017-11-23 15:43:45]
lr 0.05
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.056058 merged: 0.000000
epoch: 1 [2017-11-23 15:45:47]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.051880 merged: 0.000000
epoch: 2 [2017-11-23 15:47:50]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.046272 merged: 0.000000
epoch: 3 [2017-11-23 15:49:52]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.046587 merged: 0.000000
epoch: 4 [2017-11-23 15:51:54]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.042295 merged: 0.000000
epoch: 5 [2017-11-23 15:53:56]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.042449 merged: 0.000000
epoch: 6 [2017-11-23 15:55:58]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.053034  1M: 0.045342 merged: 0.000000
epoch: 7 [2017-11-23 15:58:20]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.044935  1M: 0.040145 merged: 0.000000
epoch: 8 [2017-11-23 16:00:48]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.040160  1M: 0.035718 merged: 0.000000
epoch: 9 [2017-11-23 16:03:19]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.032963  1M: 0.032910 merged: 0.000000
epoch: 10 [2017-11-23 16:05:46]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.024377  1M: 0.027546 merged: 0.000000
epoch: 11 [2017-11-23 16:08:14]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.022725  1M: 0.026383 merged: 0.000000
epoch: 12 [2017-11-23 16:10:39]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.048872  2M: 0.030331  1M: 0.033938 merged: 0.000000
epoch: 13 [2017-11-23 16:13:23]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.038449  2M: 0.025062  1M: 0.031121 merged: 0.000000
epoch: 14 [2017-11-23 16:16:04]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.032484  2M: 0.022485  1M: 0.028771 merged: 0.000000
epoch: 15 [2017-11-23 16:18:43]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.024987  2M: 0.019907  1M: 0.026140 merged: 0.000000
epoch: 16 [2017-11-23 16:21:25]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.019580  2M: 0.018084  1M: 0.024596 merged: 0.000000
epoch: 17 [2017-11-23 16:24:08]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.017865  2M: 0.017634  1M: 0.024014 merged: 0.000000
epoch: 18 [2017-11-23 16:26:53]
lr 1e-08
 16M: 0.000000  8M: 0.046292  4M: 0.021528  2M: 0.019886  1M: 0.027491 merged: 0.000000
epoch: 19 [2017-11-23 16:29:53]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.034180  4M: 0.019712  2M: 0.018680  1M: 0.025950 merged: 0.000000
epoch: 20 [2017-11-23 16:32:47]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.027425  4M: 0.017702  2M: 0.017545  1M: 0.024946 merged: 0.000000
epoch: 21 [2017-11-23 16:35:40]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.021148  4M: 0.014970  2M: 0.016295  1M: 0.023539 merged: 0.000000
epoch: 22 [2017-11-23 16:38:34]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.017393  4M: 0.013624  2M: 0.015425  1M: 0.021782 merged: 0.000000
epoch: 23 [2017-11-23 16:41:30]
lr 1e-08
 16M: 0.000000  8M: 0.016563  4M: 0.013344  2M: 0.015648  1M: 0.022171 merged: 0.000000
epoch: 24 [2017-11-23 16:44:22]
lr 1e-08
 16M: 0.045881  8M: 0.021760  4M: 0.015769  2M: 0.017259  1M: 0.024196 merged: 0.000000
epoch: 25 [2017-11-23 16:47:27]
lr 0.044721359549995794
 16M: 0.034483  8M: 0.018814  4M: 0.015001  2M: 0.018045  1M: 0.023449 merged: 0.000000
epoch: 26 [2017-11-23 16:50:31]
lr 0.038729833462074176
 16M: 0.026231  8M: 0.015341  4M: 0.013157  2M: 0.015237  1M: 0.022502 merged: 0.000000
epoch: 27 [2017-11-23 16:53:34]
lr 0.0316227766016838
 16M: 0.021491  8M: 0.014331  4M: 0.012021  2M: 0.014612  1M: 0.021633 merged: 0.000000
epoch: 28 [2017-11-23 16:56:37]
lr 0.022360679774997897
 16M: 0.018616  8M: 0.012797  4M: 0.011214  2M: 0.014085  1M: 0.020879 merged: 0.000000
epoch: 29 [2017-11-23 16:59:42]
lr 1e-08
 16M: 0.017217  8M: 0.012291  4M: 0.011436  2M: 0.013808  1M: 0.021218 merged: 0.000000
epoch: 30 [2017-11-23 17:02:48]
lr 0.05
 16M: 0.022097  8M: 0.018941  4M: 0.014382  2M: 0.015738  1M: 0.023189 merged: 0.032934
epoch: 31 [2017-11-23 17:07:23]
lr 0.04971830761761256
 16M: 0.022289  8M: 0.020632  4M: 0.018121  2M: 0.016201  1M: 0.022858 merged: 0.024452
epoch: 32 [2017-11-23 17:11:59]
lr 0.04943501011144937
 16M: 0.019645  8M: 0.017884  4M: 0.014344  2M: 0.015799  1M: 0.022267 merged: 0.019653
epoch: 33 [2017-11-23 17:16:33]
lr 0.04915007972606608
 16M: 0.018215  8M: 0.017073  4M: 0.013958  2M: 0.015396  1M: 0.021301 merged: 0.018116
epoch: 34 [2017-11-23 17:21:05]
lr 0.04886348789677424
 16M: 0.017232  8M: 0.015670  4M: 0.013122  2M: 0.014528  1M: 0.020686 merged: 0.016650
epoch: 35 [2017-11-23 17:25:40]
lr 0.04857520521621862
 16M: 0.015870  8M: 0.014668  4M: 0.012268  2M: 0.014308  1M: 0.020752 merged: 0.014634
epoch: 36 [2017-11-23 17:30:13]
lr 0.04828520139915856
 16M: 0.015364  8M: 0.013588  4M: 0.011864  2M: 0.013911  1M: 0.020976 merged: 0.014143
epoch: 37 [2017-11-23 17:34:47]
lr 0.047993445245333805
 16M: 0.014991  8M: 0.013339  4M: 0.011456  2M: 0.013406  1M: 0.020021 merged: 0.013529
epoch: 38 [2017-11-23 17:39:22]
lr 0.0476999046002862
 16M: 0.013803  8M: 0.012546  4M: 0.010923  2M: 0.013094  1M: 0.020204 merged: 0.012384
epoch: 39 [2017-11-23 17:43:55]
lr 0.04740454631399772
 16M: 0.013543  8M: 0.012350  4M: 0.010598  2M: 0.012525  1M: 0.019222 merged: 0.012050
epoch: 40 [2017-11-23 17:48:30]
lr 0.04710733619719444
 16M: 0.012430  8M: 0.011563  4M: 0.010619  2M: 0.012437  1M: 0.019094 merged: 0.011689
epoch: 41 [2017-11-23 17:53:03]
lr 0.04680823897515326
 16M: 0.012874  8M: 0.011752  4M: 0.010253  2M: 0.012691  1M: 0.019309 merged: 0.011487
epoch: 42 [2017-11-23 17:57:38]
lr 0.04650721823883479
 16M: 0.012451  8M: 0.011617  4M: 0.010098  2M: 0.012232  1M: 0.018393 merged: 0.011269
epoch: 43 [2017-11-23 18:02:13]
lr 0.046204236393150765
 16M: 0.011622  8M: 0.010918  4M: 0.009663  2M: 0.011823  1M: 0.018678 merged: 0.010595
epoch: 44 [2017-11-23 18:06:47]
lr 0.045899254602157845
 16M: 0.010951  8M: 0.010943  4M: 0.009333  2M: 0.011714  1M: 0.018126 merged: 0.010387
epoch: 45 [2017-11-23 18:11:20]
lr 0.04559223273095164
 16M: 0.010823  8M: 0.010502  4M: 0.009166  2M: 0.011789  1M: 0.018141 merged: 0.009930
epoch: 46 [2017-11-23 18:15:54]
lr 0.045283129284014914
 16M: 0.011014  8M: 0.010371  4M: 0.009545  2M: 0.011898  1M: 0.018270 merged: 0.009982
epoch: 47 [2017-11-23 18:20:27]
lr 0.04497190133975169
 16M: 0.010043  8M: 0.009864  4M: 0.008781  2M: 0.011564  1M: 0.017495 merged: 0.009286
epoch: 48 [2017-11-23 18:25:02]
lr 0.04465850448091506
 16M: 0.010538  8M: 0.009809  4M: 0.008735  2M: 0.011423  1M: 0.017854 merged: 0.009408
epoch: 49 [2017-11-23 18:29:36]
lr 0.044342892720609255
 16M: 0.010068  8M: 0.009730  4M: 0.008560  2M: 0.011121  1M: 0.017976 merged: 0.009110
epoch: 50 [2017-11-23 18:34:11]
lr 0.044025018423517
 16M: 0.010119  8M: 0.009812  4M: 0.008488  2M: 0.010892  1M: 0.016973 merged: 0.009070
epoch: 51 [2017-11-23 18:38:47]
lr 0.04370483222197017
 16M: 0.009827  8M: 0.009495  4M: 0.008716  2M: 0.010951  1M: 0.017957 merged: 0.009002
epoch: 52 [2017-11-23 18:43:21]
lr 0.043382282926444894
 16M: 0.009668  8M: 0.009283  4M: 0.008417  2M: 0.011161  1M: 0.017944 merged: 0.008610
epoch: 53 [2017-11-23 18:47:56]
lr 0.04305731743002185
 16M: 0.009544  8M: 0.009255  4M: 0.008324  2M: 0.010754  1M: 0.017366 merged: 0.008403
epoch: 54 [2017-11-23 18:52:32]
lr 0.04272988060630656
 16M: 0.009358  8M: 0.009132  4M: 0.008380  2M: 0.010600  1M: 0.016970 merged: 0.008224
epoch: 55 [2017-11-23 18:57:08]
lr 0.04239991520025441
 16M: 0.009780  8M: 0.011067  4M: 0.008294  2M: 0.010311  1M: 0.016412 merged: 0.008909
epoch: 56 [2017-11-23 19:01:44]
lr 0.0420673617112877
 16M: 0.008849  8M: 0.009170  4M: 0.008148  2M: 0.010545  1M: 0.016315 merged: 0.008069
epoch: 57 [2017-11-23 19:06:19]
lr 0.041732158268029534
 16M: 0.008856  8M: 0.009199  4M: 0.007760  2M: 0.009906  1M: 0.016076 merged: 0.008094
epoch: 58 [2017-11-23 19:11:00]
lr 0.041394240493907074
 16M: 0.008529  8M: 0.008850  4M: 0.007889  2M: 0.010139  1M: 0.016184 merged: 0.007746
epoch: 59 [2017-11-23 19:15:57]
lr 0.041053541362798006
 16M: 0.008719  8M: 0.008468  4M: 0.007833  2M: 0.010246  1M: 0.016454 merged: 0.007697
epoch: 60 [2017-11-23 19:20:53]
lr 0.04070999104380296
 16M: 0.008880  8M: 0.008808  4M: 0.007895  2M: 0.010165  1M: 0.016242 merged: 0.007795
epoch: 61 [2017-11-23 19:25:48]
lr 0.04036351673412598
 16M: 0.008262  8M: 0.008258  4M: 0.007371  2M: 0.009833  1M: 0.016173 merged: 0.007369
epoch: 62 [2017-11-23 19:30:43]
lr 0.04001404247893005
 16M: 0.008681  8M: 0.008419  4M: 0.007615  2M: 0.010120  1M: 0.016428 merged: 0.007602
epoch: 63 [2017-11-23 19:35:39]
lr 0.03966148897690515
 16M: 0.008221  8M: 0.008158  4M: 0.007489  2M: 0.009934  1M: 0.015799 merged: 0.007220
epoch: 64 [2017-11-23 19:40:35]
lr 0.03930577337013889
 16M: 0.008182  8M: 0.008792  4M: 0.007424  2M: 0.009824  1M: 0.015303 merged: 0.007441
epoch: 65 [2017-11-23 19:45:32]
lr 0.038946809016712394
 16M: 0.008063  8M: 0.008258  4M: 0.007287  2M: 0.009602  1M: 0.015479 merged: 0.007170
epoch: 66 [2017-11-23 19:50:28]
lr 0.03858450524425343
 16M: 0.008022  8M: 0.007912  4M: 0.007119  2M: 0.009510  1M: 0.015310 merged: 0.006878
epoch: 67 [2017-11-23 19:55:24]
lr 0.03821876708246056
 16M: 0.007770  8M: 0.007996  4M: 0.007095  2M: 0.009557  1M: 0.015357 merged: 0.007024
epoch: 68 [2017-11-23 20:00:19]
lr 0.03784949497236286
 16M: 0.007826  8M: 0.007876  4M: 0.007076  2M: 0.009326  1M: 0.014975 merged: 0.006975
epoch: 69 [2017-11-23 20:05:16]
lr 0.03747658444979307
 16M: 0.007566  8M: 0.007526  4M: 0.007020  2M: 0.009401  1M: 0.015297 merged: 0.006595
epoch: 70 [2017-11-23 20:10:12]
lr 0.0370999258002226
 16M: 0.007621  8M: 0.007536  4M: 0.007027  2M: 0.009430  1M: 0.015218 merged: 0.006706
epoch: 71 [2017-11-23 20:15:08]
lr 0.03671940368172628
 16M: 0.007776  8M: 0.007595  4M: 0.006903  2M: 0.009268  1M: 0.015030 merged: 0.006729
epoch: 72 [2017-11-23 20:20:03]
lr 0.03633489671240478
 16M: 0.007749  8M: 0.007426  4M: 0.006859  2M: 0.009370  1M: 0.015102 merged: 0.006604
epoch: 73 [2017-11-23 20:24:59]
lr 0.03594627701808178
 16M: 0.007699  8M: 0.007535  4M: 0.006974  2M: 0.009440  1M: 0.015315 merged: 0.006599
epoch: 74 [2017-11-23 20:29:54]
lr 0.035553409735498295
 16M: 0.007477  8M: 0.007529  4M: 0.006864  2M: 0.009251  1M: 0.014894 merged: 0.006628
epoch: 75 [2017-11-23 20:34:48]
lr 0.03515615246553262
 16M: 0.007347  8M: 0.007411  4M: 0.006834  2M: 0.009583  1M: 0.015112 merged: 0.006569
epoch: 76 [2017-11-23 20:39:44]
lr 0.03475435467016077
 16M: 0.007415  8M: 0.007307  4M: 0.006775  2M: 0.009308  1M: 0.015063 merged: 0.006323
epoch: 77 [2017-11-23 20:44:38]
lr 0.034347857005916346
 16M: 0.007269  8M: 0.007333  4M: 0.006848  2M: 0.009251  1M: 0.015026 merged: 0.006603
epoch: 78 [2017-11-23 20:49:34]
lr 0.0339364905854808
 16M: 0.007258  8M: 0.007220  4M: 0.006666  2M: 0.008956  1M: 0.014404 merged: 0.006366
epoch: 79 [2017-11-23 20:54:30]
lr 0.03352007615769955
 16M: 0.007157  8M: 0.007034  4M: 0.006645  2M: 0.009114  1M: 0.015016 merged: 0.006249
epoch: 80 [2017-11-23 20:59:25]
lr 0.03309842319473132
 16M: 0.006996  8M: 0.006987  4M: 0.006657  2M: 0.008944  1M: 0.014693 merged: 0.006211
epoch: 81 [2017-11-23 21:04:21]
lr 0.03267132887314317
 16M: 0.006944  8M: 0.006910  4M: 0.006512  2M: 0.009066  1M: 0.014588 merged: 0.006041
epoch: 82 [2017-11-23 21:09:17]
lr 0.03223857693349118
 16M: 0.007132  8M: 0.007678  4M: 0.006554  2M: 0.008824  1M: 0.014452 merged: 0.006443
epoch: 83 [2017-11-23 21:14:12]
lr 0.0317999364001908
 16M: 0.007021  8M: 0.006917  4M: 0.006357  2M: 0.008724  1M: 0.014207 merged: 0.006078
epoch: 84 [2017-11-23 21:19:08]
lr 0.031355160140170396
 16M: 0.007615  8M: 0.009117  4M: 0.006732  2M: 0.008819  1M: 0.014348 merged: 0.007271
epoch: 85 [2017-11-23 21:24:06]
lr 0.03090398323477543
 16M: 0.006850  8M: 0.008041  4M: 0.006478  2M: 0.008665  1M: 0.014156 merged: 0.006324
epoch: 86 [2017-11-23 21:29:02]
lr 0.030446121134470178
 16M: 0.006553  8M: 0.007107  4M: 0.006280  2M: 0.008504  1M: 0.013893 merged: 0.005896
epoch: 87 [2017-11-23 21:33:58]
lr 0.02998126755983446
 16M: 0.006618  8M: 0.007023  4M: 0.006275  2M: 0.008690  1M: 0.014162 merged: 0.005854
epoch: 88 [2017-11-23 21:38:54]
lr 0.029509092104873926
 16M: 0.006465  8M: 0.006821  4M: 0.006190  2M: 0.008495  1M: 0.013992 merged: 0.005843
epoch: 89 [2017-11-23 21:43:50]
lr 0.029029237489356888
 16M: 0.006480  8M: 0.006698  4M: 0.006218  2M: 0.008656  1M: 0.014200 merged: 0.005727
epoch: 90 [2017-11-23 21:48:46]
lr 0.028541316395237167
 16M: 0.006511  8M: 0.006705  4M: 0.006334  2M: 0.008621  1M: 0.013974 merged: 0.005821
epoch: 91 [2017-11-23 21:53:42]
lr 0.028044907807525134
 16M: 0.006338  8M: 0.006506  4M: 0.006075  2M: 0.008419  1M: 0.014020 merged: 0.005649
epoch: 92 [2017-11-23 21:58:38]
lr 0.027539552761294706
 16M: 0.006290  8M: 0.006515  4M: 0.006009  2M: 0.008214  1M: 0.013530 merged: 0.005602
epoch: 93 [2017-11-23 22:03:34]
lr 0.027024749372597065
 16M: 0.006430  8M: 0.006689  4M: 0.006257  2M: 0.008718  1M: 0.014264 merged: 0.005734
epoch: 94 [2017-11-23 22:08:32]
lr 0.026499947000159004
 16M: 0.007290  8M: 0.006412  4M: 0.006084  2M: 0.008425  1M: 0.013873 merged: 0.005649
epoch: 95 [2017-11-23 22:13:28]
lr 0.02596453934447493
 16M: 0.006615  8M: 0.006373  4M: 0.005954  2M: 0.008327  1M: 0.013809 merged: 0.005577
epoch: 96 [2017-11-23 22:18:29]
lr 0.025417856237895775
 16M: 0.006593  8M: 0.006396  4M: 0.006064  2M: 0.008390  1M: 0.013943 merged: 0.005517
epoch: 97 [2017-11-23 22:23:28]
lr 0.02485915380880628
 16M: 0.006191  8M: 0.006201  4M: 0.005834  2M: 0.008179  1M: 0.013636 merged: 0.005394
epoch: 98 [2017-11-23 22:28:23]
lr 0.02428760260810931
 16M: 0.006156  8M: 0.006197  4M: 0.005894  2M: 0.008188  1M: 0.013539 merged: 0.005435
epoch: 99 [2017-11-23 22:33:18]
lr 0.02370227315699886
 16M: 0.006187  8M: 0.006256  4M: 0.005862  2M: 0.008141  1M: 0.013511 merged: 0.005425
epoch: 100 [2017-11-23 22:38:16]
lr 0.023102118196575382
 16M: 0.006319  8M: 0.006354  4M: 0.005982  2M: 0.008369  1M: 0.013979 merged: 0.005475
epoch: 101 [2017-11-23 22:43:12]
lr 0.022485950669875843
 16M: 0.006019  8M: 0.006228  4M: 0.006043  2M: 0.008405  1M: 0.013908 merged: 0.005424
epoch: 102 [2017-11-23 22:48:09]
lr 0.021852416110985085
 16M: 0.006035  8M: 0.006153  4M: 0.005832  2M: 0.008057  1M: 0.013431 merged: 0.005345
epoch: 103 [2017-11-23 22:53:05]
lr 0.021199957600127203
 16M: 0.005944  8M: 0.006157  4M: 0.005877  2M: 0.008068  1M: 0.013269 merged: 0.005300
epoch: 104 [2017-11-23 22:58:02]
lr 0.020526770681399003
 16M: 0.005930  8M: 0.006145  4M: 0.005850  2M: 0.008271  1M: 0.013689 merged: 0.005274
epoch: 105 [2017-11-23 23:02:59]
lr 0.019830744488452574
 16M: 0.005954  8M: 0.006074  4M: 0.005836  2M: 0.008162  1M: 0.013469 merged: 0.005273
epoch: 106 [2017-11-23 23:07:56]
lr 0.01910938354123028
 16M: 0.005944  8M: 0.006123  4M: 0.005794  2M: 0.007980  1M: 0.013311 merged: 0.005316
epoch: 107 [2017-11-23 23:12:53]
lr 0.01835970184086314
 16M: 0.005751  8M: 0.005941  4M: 0.005659  2M: 0.007866  1M: 0.013044 merged: 0.005059
epoch: 108 [2017-11-23 23:17:48]
lr 0.01757807623276631
 16M: 0.005843  8M: 0.006016  4M: 0.005743  2M: 0.007966  1M: 0.012806 merged: 0.005192
epoch: 109 [2017-11-23 23:22:44]
lr 0.016760038078849775
 16M: 0.005730  8M: 0.005968  4M: 0.005735  2M: 0.007918  1M: 0.013251 merged: 0.005191
epoch: 110 [2017-11-23 23:27:40]
lr 0.0158999682000954
 16M: 0.005729  8M: 0.005969  4M: 0.005744  2M: 0.008161  1M: 0.013464 merged: 0.005133
epoch: 111 [2017-11-23 23:32:37]
lr 0.01499063377991723
 16M: 0.005682  8M: 0.005859  4M: 0.005627  2M: 0.007856  1M: 0.013040 merged: 0.005063
epoch: 112 [2017-11-23 23:37:31]
lr 0.014022453903762567
 16M: 0.005602  8M: 0.005801  4M: 0.005664  2M: 0.007948  1M: 0.013200 merged: 0.004981
epoch: 113 [2017-11-23 23:42:27]
lr 0.012982269672237465
 16M: 0.005556  8M: 0.005808  4M: 0.005578  2M: 0.007843  1M: 0.012976 merged: 0.004978
epoch: 114 [2017-11-23 23:47:24]
lr 0.01185113657849943
 16M: 0.005578  8M: 0.005833  4M: 0.005669  2M: 0.008013  1M: 0.013291 merged: 0.004990
epoch: 115 [2017-11-23 23:52:21]
lr 0.010599978800063602
 16M: 0.005560  8M: 0.005740  4M: 0.005586  2M: 0.007908  1M: 0.013183 merged: 0.004920
epoch: 116 [2017-11-23 23:57:20]
lr 0.00917985092043157
 16M: 0.005369  8M: 0.005619  4M: 0.005451  2M: 0.007679  1M: 0.012650 merged: 0.004842
epoch: 117 [2017-11-24 00:02:18]
lr 0.007495316889958615
 16M: 0.005507  8M: 0.005826  4M: 0.005677  2M: 0.007835  1M: 0.013130 merged: 0.004962
epoch: 118 [2017-11-24 00:07:15]
lr 0.005299989400031801
 16M: 0.005497  8M: 0.005714  4M: 0.005579  2M: 0.007766  1M: 0.012775 merged: 0.004856
epoch: 119 [2017-11-24 00:12:13]
lr 1e-08
 16M: 0.005383  8M: 0.005634  4M: 0.005551  2M: 0.007808  1M: 0.012815 merged: 0.004814

Visualize Graph


In [7]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    Returns:
        graphviz.Digraph of the autograd graph rooted at var.grad_fn.
    """
    if params is not None:
        # BUGFIX: dict.values() is a view in Python 3 and does not support
        # indexing (`params.values()[0]` raises TypeError). Use an iterator
        # to peek at the first value instead.
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()  # guards against revisiting shared graph nodes (diamond dependencies)

    def size_to_str(size):
        # Render a tensor size as e.g. "(1, 3, 256, 256)".
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        # Recursively walk the autograd graph, adding one dot node per
        # Function/Variable/Tensor and edges from inputs to outputs.
        if var not in seen:
            if torch.is_tensor(var):
                # Raw tensor saved for backward — orange.
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                # AccumulateGrad leaf holding a parameter Variable — blue,
                # labelled with its name when `params` was supplied.
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                # Intermediate autograd Function — labelled by class name.
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [8]:
# Build the autograd graph for visualization: run a dummy zero input
# through `net` (defined in an earlier cell) and trace the last output.
x = Variable(torch.zeros(1,3,256,256))
y = net(x.cuda())  # NOTE(review): assumes `net` lives on the GPU selected earlier — confirm
g = make_dot(y[-1])  # y is indexable; presumably the network returns multiple outputs and we trace the final one

In [9]:
g.render('net-transition_scale_{}'.format(transition_scale))


---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/graphviz/backend.py in render(engine, format, filepath, quiet)
    123         try:
--> 124             subprocess.check_call(args, startupinfo=STARTUPINFO, stderr=stderr)
    125         except OSError as e:

~/anaconda3/lib/python3.6/subprocess.py in check_call(*popenargs, **kwargs)
    285     """
--> 286     retcode = call(*popenargs, **kwargs)
    287     if retcode:

~/anaconda3/lib/python3.6/subprocess.py in call(timeout, *popenargs, **kwargs)
    266     """
--> 267     with Popen(*popenargs, **kwargs) as p:
    268         try:

~/anaconda3/lib/python3.6/subprocess.py in __init__(self, args, bufsize, executable, stdin, stdout, stderr, preexec_fn, close_fds, shell, cwd, env, universal_newlines, startupinfo, creationflags, restore_signals, start_new_session, pass_fds, encoding, errors)
    708                                 errread, errwrite,
--> 709                                 restore_signals, start_new_session)
    710         except:

~/anaconda3/lib/python3.6/subprocess.py in _execute_child(self, args, executable, preexec_fn, close_fds, pass_fds, cwd, env, startupinfo, creationflags, shell, p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite, restore_signals, start_new_session)
   1343                             err_msg += ': ' + repr(err_filename)
-> 1344                     raise child_exception_type(errno_num, err_msg, err_filename)
   1345                 raise child_exception_type(err_msg)

FileNotFoundError: [Errno 2] No such file or directory: 'dot': 'dot'

During handling of the above exception, another exception occurred:

ExecutableNotFound                        Traceback (most recent call last)
<ipython-input-9-81727d6bc75d> in <module>()
----> 1 g.render('net-transition_scale_{}'.format(transition_scale))

~/anaconda3/lib/python3.6/site-packages/graphviz/files.py in render(self, filename, directory, view, cleanup)
    173         filepath = self.save(filename, directory)
    174 
--> 175         rendered = backend.render(self._engine, self._format, filepath)
    176 
    177         if cleanup:

~/anaconda3/lib/python3.6/site-packages/graphviz/backend.py in render(engine, format, filepath, quiet)
    125         except OSError as e:
    126             if e.errno == errno.ENOENT:
--> 127                 raise ExecutableNotFound(args)
    128             else:  # pragma: no cover
    129                 raise

ExecutableNotFound: failed to execute ['dot', '-Tsvg', '-O', 'net-transition_scale_8'], make sure the Graphviz executables are on your systems' PATH

In [ ]: