In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import PreTrainedModel, GradientNet
from myargs import Args

Configurations


In [2]:
args = Args()
args.test_scene = 'alley_1'
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 2
growth_rate = (4*(2**(args.gpu_num)))
args.display_interval = 50
args.display_curindex = 0

system_ = platform.system()
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif platform.dist() == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif platform.dist() == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

if platform.system() == 'Linux': use_gpu = True
else: use_gpu = False
if use_gpu:
    torch.cuda.set_device(args.gpu_num)
    

print(platform.dist())


('debian', 'jessie/sid', '')

My DataLoader


In [3]:
train_dataset = MyImageFolder(args.train_dir, 'train',
                              transforms.Compose([transforms.ToTensor()]),
                              random_crop=True,
                              img_extentions=args.img_extentions,
                              test_scene=args.test_scene,
                              image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test',
                             transforms.Compose([transforms.CenterCrop((args.image_h, args.image_w)),
                                                 transforms.ToTensor()]),
                             random_crop=False,
                             img_extentions=args.img_extentions,
                             test_scene=args.test_scene,
                             image_h=args.image_h, image_w=args.image_w)

train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)
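
A quick peek at one batch (a sketch, not part of the training pipeline): MyImageFolder yields the input image, the albedo and shading ground truths, the scene name, and the image path.

In [ ]:
# Sketch: inspect the contents of a single training batch.
input_img, gt_albedo, gt_shading, test_scene, img_path = next(iter(train_loader))
print(input_img.size(), gt_albedo.size(), gt_shading.size())
print(test_scene[0], img_path[0])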

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First convolution: 32M -> 16M -> 8M
    • Every transition: 8M -> 4M -> 2M (downsamples by 1/2, except after the last block); see the sanity check below
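
A minimal sanity check (sketch): trace a dummy input through torchvision's DenseNet-121 feature extractor and print the spatial size after each stage, confirming the halving at each transition noted above.

In [ ]:
# Sketch: probe DenseNet-121 feature-map sizes with a throwaway, unpretrained copy.
probe = models.densenet121(pretrained=False)
x = Variable(torch.zeros(1, 3, 256, 256))
for name, layer in probe.features.named_children():
    x = layer(x)
    print('{:16s} -> {}'.format(name, tuple(x.size())))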

In [4]:
densenet = models.__dict__[args.arch](pretrained=args.pretrained)

# freeze the ImageNet-pretrained backbone; only GradientNet is trained
for param in densenet.parameters():
    param.requires_grad = False

if use_gpu: densenet.cuda()
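
# Sanity check (sketch): the frozen backbone should expose no trainable parameters.
assert all(not p.requires_grad for p in densenet.parameters())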

In [5]:
ss = 6

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
args.training_thresholds = [ss*4,ss*3,ss*2,ss*1,ss*0,ss*5]
args.power = 0.5



pretrained = PreTrainedModel(densenet)
if use_gpu: 
    pretrained.cuda()


net = GradientNet(growth_rate=growth_rate)
if use_gpu:
    net.cuda()

# note: [module] * 6 repeats one MSELoss object six times, which is harmless
# since MSELoss is stateless; this test_losses list is later shadowed by the
# per-epoch accumulator lists in the training loop
if use_gpu:
    mse_losses = [nn.MSELoss().cuda()] * 6
    test_losses = [nn.MSELoss().cuda()] * 6
else:
    mse_losses = [nn.MSELoss()] * 6
    test_losses = [nn.MSELoss()] * 6

In [6]:
# training loop
writer = SummaryWriter(comment='-growth_rate_{}'.format(growth_rate))

parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Polynomial decay: lr = base_lr * ((end - epoch) / (end - beg)) ** power,
    clamped below at 1e-8. A given reset_lr overrides the schedule."""
    for param_group in optimizer.param_groups:
        if reset_lr is not None:
            param_group['lr'] = reset_lr
            continue
        if epoch != 0:
            # poly: base_lr * (1 - iter/max_iter) ** power
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
            if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        print('lr', param_group['lr'])
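# Worked example of the poly schedule (assuming base_lr=0.05, power=0.5 and
# the six-epoch cycle beg=0, end=5 used below):
#   epoch 1: 0.05 * (4/5)**0.5 ≈ 0.044721
#   epoch 2: 0.05 * (3/5)**0.5 ≈ 0.038730
#   epoch 5: 0.05 * (0/5)**0.5 = 0, clamped to 1e-8
# These match the `lr ...` lines in the output below.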
        

for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    
    if epoch < args.training_thresholds[-1]: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)
    
    run_losses = [0] * len(args.training_thresholds)
    run_cnts   = [0.00001] * len(args.training_thresholds)
    if epoch in args.training_thresholds:
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
#         if  ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # keep a BGR copy of the input, scaled to 0-255, for the snapshot dumps
        im = input_img[0].numpy().transpose(1, 2, 0)[:, :, ::-1] * 255
        
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
#         run_losses = [0] * len(mse_losses)
#         run_cnts = [0.00001] * len(mse_losses)
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)

        optimizer.zero_grad()
        pretrained.train()
        ft_pretrained = pretrained(input_img)
        ft_predict = net(ft_pretrained)
        for i, threshold in enumerate(args.training_thresholds):
            if epoch >= threshold:
                # head 5 is the merged full-resolution output; head i
                # otherwise predicts at 1/2**(i+1) resolution
                if i == 5: s = 1
                else: s = (2**(i+1))
                # resize the albedo ground truth to this head's scale;
                # cv2.resize expects (width, height) -- safe here since h == w
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                gt = cv2.resize(gt, (h//s, w//s))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                loss_data = loss.data.cpu().numpy()
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                ma_ = ft_predict[i].max().cpu().data.numpy()
                mi_ = ft_predict[i].min().cpu().data.numpy()
                #print('mi', mi_, 'ma', ma_)
#                 writer.add_scalars('{}th train predict'.format(i), {'max': ma_, 'min': mi_}, global_step=args.display_curindex)
#                 run_cnts[i] += 1
                run_losses[i] += loss.data.cpu().numpy()[0]
                loss.backward(retain_graph=True)
                run_cnts[i] += 1
#                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
#                 print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])
        optimizer.step()
        args.display_curindex += 1

    """ every epoch """
#     loss_output = 'ind: ' + str(args.display_curindex)
    loss_output = ''
    
    
    
    for i,v in enumerate(run_losses):
        if i == len(run_losses)-1: 
            loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
            continue
        loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
    print(loss_output)
    # snapshot every 10 epochs
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))
    
    # test pass 1: run the test set with the feature extractor still in
    # train mode (training-phase statistics)
    test_losses_trainphase = [0] * len(args.training_thresholds)
    test_cnts_trainphase   = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        
        pretrained.train()
        ft_pretrained = pretrained(input_img)
        ft_test = net(ft_pretrained)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_train-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)

    
    # test pass 2: everything in eval mode
    net.eval()
    test_losses = [0] * len(args.training_thresholds)
    test_cnts   = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
#         if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
            
        pretrained.eval()
        ft_pretrained = pretrained(input_img)
        ft_test = net(ft_pretrained)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_test-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)
    
    scale_names = ['16M', '8M', '4M', '2M', '1M', 'merged']
    for i, name in enumerate(scale_names):
        writer.add_scalars('{} loss'.format(name), {
            'train {} '.format(name): np.array([run_losses[i] / run_cnts[i]]),
            'test_trainphase {} '.format(name): np.array([test_losses_trainphase[i] / test_cnts_trainphase[i]]),
            'test {} '.format(name): np.array([test_losses[i] / test_cnts[i]])
        }, global_step=epoch)


epoch: 0 [2017-11-21 17:40:12]
lr 0.05
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.055675 merged: 0.000000
epoch: 1 [2017-11-21 17:42:09]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.051149 merged: 0.000000
epoch: 2 [2017-11-21 17:44:05]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.047816 merged: 0.000000
epoch: 3 [2017-11-21 17:45:58]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.045797 merged: 0.000000
epoch: 4 [2017-11-21 17:47:52]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.043717 merged: 0.000000
epoch: 5 [2017-11-21 17:49:46]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.042538 merged: 0.000000
epoch: 6 [2017-11-21 17:51:40]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.051502  1M: 0.046703 merged: 0.000000
epoch: 7 [2017-11-21 17:53:51]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.043417  1M: 0.041166 merged: 0.000000
epoch: 8 [2017-11-21 17:56:06]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.040211  1M: 0.036353 merged: 0.000000
epoch: 9 [2017-11-21 17:58:19]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.034678  1M: 0.033623 merged: 0.000000
epoch: 10 [2017-11-21 18:00:34]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.026206  1M: 0.028478 merged: 0.000000
epoch: 11 [2017-11-21 18:02:49]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.024831  1M: 0.028140 merged: 0.000000
epoch: 12 [2017-11-21 18:05:03]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.047170  2M: 0.029928  1M: 0.032065 merged: 0.000000
epoch: 13 [2017-11-21 18:07:36]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.038411  2M: 0.026114  1M: 0.030081 merged: 0.000000
epoch: 14 [2017-11-21 18:10:06]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.026960  2M: 0.021676  1M: 0.026853 merged: 0.000000
epoch: 15 [2017-11-21 18:12:39]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.022625  2M: 0.019653  1M: 0.025259 merged: 0.000000
epoch: 16 [2017-11-21 18:15:08]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.018582  2M: 0.018368  1M: 0.024535 merged: 0.000000
epoch: 17 [2017-11-21 18:17:39]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.016654  2M: 0.018087  1M: 0.022356 merged: 0.000000
epoch: 18 [2017-11-21 18:20:09]
lr 1e-08
 16M: 0.000000  8M: 0.046222  4M: 0.021843  2M: 0.022111  1M: 0.027005 merged: 0.000000
epoch: 19 [2017-11-21 18:22:53]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.035752  4M: 0.017887  2M: 0.019086  1M: 0.026210 merged: 0.000000
epoch: 20 [2017-11-21 18:25:39]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.026925  4M: 0.017360  2M: 0.018221  1M: 0.025583 merged: 0.000000
epoch: 21 [2017-11-21 18:28:26]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.020912  4M: 0.014752  2M: 0.016614  1M: 0.023896 merged: 0.000000
epoch: 22 [2017-11-21 18:31:09]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.018442  4M: 0.013032  2M: 0.015421  1M: 0.021267 merged: 0.000000
epoch: 23 [2017-11-21 18:33:56]
lr 1e-08
 16M: 0.000000  8M: 0.016457  4M: 0.012792  2M: 0.015016  1M: 0.022399 merged: 0.000000
epoch: 24 [2017-11-21 18:36:43]
lr 1e-08
 16M: 0.045104  8M: 0.021023  4M: 0.014820  2M: 0.017318  1M: 0.024788 merged: 0.000000
epoch: 25 [2017-11-21 18:40:23]
lr 0.044721359549995794
 16M: 0.032688  8M: 0.017705  4M: 0.013232  2M: 0.016442  1M: 0.023403 merged: 0.000000
epoch: 26 [2017-11-21 18:44:04]
lr 0.038729833462074176
 16M: 0.023367  8M: 0.015331  4M: 0.012391  2M: 0.015305  1M: 0.022479 merged: 0.000000
epoch: 27 [2017-11-21 18:47:44]
lr 0.0316227766016838
 16M: 0.019320  8M: 0.014150  4M: 0.011535  2M: 0.014656  1M: 0.021483 merged: 0.000000
epoch: 28 [2017-11-21 18:51:25]
lr 0.022360679774997897
 16M: 0.016134  8M: 0.012808  4M: 0.010884  2M: 0.013975  1M: 0.020820 merged: 0.000000
epoch: 29 [2017-11-21 18:55:08]
lr 1e-08
 16M: 0.014639  8M: 0.012304  4M: 0.010830  2M: 0.013701  1M: 0.020554 merged: 0.000000
epoch: 30 [2017-11-21 18:58:48]
lr 0.05
 16M: 0.019807  8M: 0.015587  4M: 0.013311  2M: 0.015262  1M: 0.022798 merged: 0.033898
epoch: 31 [2017-11-21 19:04:30]
lr 0.04971830761761256
 16M: 0.021134  8M: 0.018020  4M: 0.017512  2M: 0.016947  1M: 0.023818 merged: 0.025660
epoch: 32 [2017-11-21 19:10:12]
lr 0.04943501011144937
 16M: 0.018237  8M: 0.015457  4M: 0.015591  2M: 0.016087  1M: 0.023724 merged: 0.019927
epoch: 33 [2017-11-21 19:15:51]
lr 0.04915007972606608
 16M: 0.015477  8M: 0.013243  4M: 0.014120  2M: 0.015379  1M: 0.022727 merged: 0.017099
epoch: 34 [2017-11-21 19:21:29]
lr 0.04886348789677424
 16M: 0.014848  8M: 0.012905  4M: 0.013251  2M: 0.014775  1M: 0.021883 merged: 0.015844
epoch: 35 [2017-11-21 19:27:08]
lr 0.04857520521621862
 16M: 0.013462  8M: 0.012393  4M: 0.013204  2M: 0.014309  1M: 0.022112 merged: 0.014391
epoch: 36 [2017-11-21 19:32:50]
lr 0.04828520139915856
 16M: 0.013200  8M: 0.012010  4M: 0.012737  2M: 0.014412  1M: 0.021432 merged: 0.013891
epoch: 37 [2017-11-21 19:38:28]
lr 0.047993445245333805
 16M: 0.012608  8M: 0.011597  4M: 0.011942  2M: 0.014167  1M: 0.021241 merged: 0.013088
epoch: 38 [2017-11-21 19:44:13]
lr 0.0476999046002862
 16M: 0.011706  8M: 0.011229  4M: 0.011677  2M: 0.013875  1M: 0.020891 merged: 0.012449
epoch: 39 [2017-11-21 19:50:00]
lr 0.04740454631399772
 16M: 0.011965  8M: 0.011265  4M: 0.011734  2M: 0.013606  1M: 0.020734 merged: 0.012459
epoch: 40 [2017-11-21 19:55:46]
lr 0.04710733619719444
 16M: 0.011725  8M: 0.011536  4M: 0.011360  2M: 0.013133  1M: 0.020671 merged: 0.012590
epoch: 41 [2017-11-21 20:01:29]
lr 0.04680823897515326
 16M: 0.011022  8M: 0.010751  4M: 0.011096  2M: 0.013028  1M: 0.019886 merged: 0.011521
epoch: 42 [2017-11-21 20:07:13]
lr 0.04650721823883479
 16M: 0.010033  8M: 0.009896  4M: 0.010530  2M: 0.012490  1M: 0.019323 merged: 0.010513
epoch: 43 [2017-11-21 20:12:58]
lr 0.046204236393150765
 16M: 0.009784  8M: 0.009687  4M: 0.009982  2M: 0.012326  1M: 0.018133 merged: 0.010449
epoch: 44 [2017-11-21 20:18:42]
lr 0.045899254602157845
 16M: 0.010013  8M: 0.009571  4M: 0.010441  2M: 0.012549  1M: 0.019095 merged: 0.010364
epoch: 45 [2017-11-21 20:24:25]
lr 0.04559223273095164
 16M: 0.010054  8M: 0.009621  4M: 0.010124  2M: 0.012044  1M: 0.018718 merged: 0.010388
epoch: 46 [2017-11-21 20:30:09]
lr 0.045283129284014914
 16M: 0.009151  8M: 0.009214  4M: 0.009559  2M: 0.011636  1M: 0.018684 merged: 0.009585
epoch: 47 [2017-11-21 20:35:53]
lr 0.04497190133975169
 16M: 0.009445  8M: 0.009208  4M: 0.009806  2M: 0.012031  1M: 0.018623 merged: 0.009707
epoch: 48 [2017-11-21 20:41:40]
lr 0.04465850448091506
 16M: 0.008768  8M: 0.009005  4M: 0.009610  2M: 0.011840  1M: 0.018543 merged: 0.009370
epoch: 49 [2017-11-21 20:47:22]
lr 0.044342892720609255
 16M: 0.008494  8M: 0.008759  4M: 0.009211  2M: 0.011316  1M: 0.017586 merged: 0.008974
epoch: 50 [2017-11-21 20:53:04]
lr 0.044025018423517
 16M: 0.008457  8M: 0.008857  4M: 0.009342  2M: 0.011512  1M: 0.018494 merged: 0.008887
epoch: 51 [2017-11-21 20:58:47]
lr 0.04370483222197017
 16M: 0.008515  8M: 0.008677  4M: 0.009156  2M: 0.011278  1M: 0.017398 merged: 0.008982
epoch: 52 [2017-11-21 21:04:32]
lr 0.043382282926444894
 16M: 0.008377  8M: 0.008620  4M: 0.009288  2M: 0.011308  1M: 0.017929 merged: 0.008843
epoch: 53 [2017-11-21 21:10:18]
lr 0.04305731743002185
 16M: 0.007963  8M: 0.008280  4M: 0.008792  2M: 0.010947  1M: 0.017299 merged: 0.008353
epoch: 54 [2017-11-21 21:16:02]
lr 0.04272988060630656
 16M: 0.007722  8M: 0.008210  4M: 0.008816  2M: 0.011144  1M: 0.017131 merged: 0.008084
epoch: 55 [2017-11-21 21:21:46]
lr 0.04239991520025441
 16M: 0.007985  8M: 0.008471  4M: 0.008935  2M: 0.010864  1M: 0.016959 merged: 0.008336
epoch: 56 [2017-11-21 21:27:30]
lr 0.0420673617112877
 16M: 0.008147  8M: 0.008269  4M: 0.008704  2M: 0.010730  1M: 0.017090 merged: 0.008324
epoch: 57 [2017-11-21 21:33:14]
lr 0.041732158268029534
 16M: 0.007632  8M: 0.008085  4M: 0.008720  2M: 0.010716  1M: 0.016842 merged: 0.007925
epoch: 58 [2017-11-21 21:39:03]
lr 0.041394240493907074
 16M: 0.007454  8M: 0.007843  4M: 0.008522  2M: 0.010578  1M: 0.016308 merged: 0.007664
epoch: 59 [2017-11-21 21:44:46]
lr 0.041053541362798006
 16M: 0.007308  8M: 0.007783  4M: 0.008319  2M: 0.010497  1M: 0.016683 merged: 0.007510
epoch: 60 [2017-11-21 21:50:32]
lr 0.04070999104380296
 16M: 0.007049  8M: 0.007623  4M: 0.008295  2M: 0.010479  1M: 0.016176 merged: 0.007418
epoch: 61 [2017-11-21 21:56:16]
lr 0.04036351673412598
 16M: 0.007351  8M: 0.007668  4M: 0.008283  2M: 0.010396  1M: 0.016145 merged: 0.007445
epoch: 62 [2017-11-21 22:02:02]
lr 0.04001404247893005
 16M: 0.007179  8M: 0.007600  4M: 0.008291  2M: 0.010378  1M: 0.016245 merged: 0.007516
epoch: 63 [2017-11-21 22:07:46]
lr 0.03966148897690515
 16M: 0.006966  8M: 0.007583  4M: 0.008215  2M: 0.010700  1M: 0.016246 merged: 0.007262
epoch: 64 [2017-11-21 22:13:30]
lr 0.03930577337013889
 16M: 0.006774  8M: 0.007584  4M: 0.008147  2M: 0.010445  1M: 0.016849 merged: 0.006990
epoch: 65 [2017-11-21 22:19:12]
lr 0.038946809016712394
 16M: 0.007006  8M: 0.007454  4M: 0.008043  2M: 0.010467  1M: 0.016380 merged: 0.007242
epoch: 66 [2017-11-21 22:24:56]
lr 0.03858450524425343
 16M: 0.007057  8M: 0.007431  4M: 0.008170  2M: 0.009982  1M: 0.015781 merged: 0.007221
epoch: 67 [2017-11-21 22:30:45]
lr 0.03821876708246056
 16M: 0.006661  8M: 0.007241  4M: 0.007813  2M: 0.010130  1M: 0.015858 merged: 0.006905
epoch: 68 [2017-11-21 22:36:28]
lr 0.03784949497236286
 16M: 0.006431  8M: 0.007237  4M: 0.007769  2M: 0.009974  1M: 0.015775 merged: 0.006685
epoch: 69 [2017-11-21 22:42:12]
lr 0.03747658444979307
 16M: 0.006657  8M: 0.007392  4M: 0.007897  2M: 0.010174  1M: 0.016333 merged: 0.006816
epoch: 70 [2017-11-21 22:47:58]
lr 0.0370999258002226
 16M: 0.006597  8M: 0.007161  4M: 0.007815  2M: 0.010078  1M: 0.015936 merged: 0.006898
epoch: 71 [2017-11-21 22:53:46]
lr 0.03671940368172628
 16M: 0.006318  8M: 0.007090  4M: 0.007576  2M: 0.009832  1M: 0.015603 merged: 0.006552
epoch: 72 [2017-11-21 22:59:32]
lr 0.03633489671240478
 16M: 0.006316  8M: 0.006962  4M: 0.007516  2M: 0.009769  1M: 0.015557 merged: 0.006486
epoch: 73 [2017-11-21 23:05:15]
lr 0.03594627701808178
 16M: 0.006533  8M: 0.007319  4M: 0.007838  2M: 0.009896  1M: 0.015926 merged: 0.006699
epoch: 74 [2017-11-21 23:10:59]
lr 0.035553409735498295
 16M: 0.006393  8M: 0.007198  4M: 0.007772  2M: 0.009962  1M: 0.015648 merged: 0.006607
epoch: 75 [2017-11-21 23:16:43]
lr 0.03515615246553262
 16M: 0.006381  8M: 0.007145  4M: 0.007612  2M: 0.009718  1M: 0.015365 merged: 0.006494
epoch: 76 [2017-11-21 23:22:28]
lr 0.03475435467016077
 16M: 0.006009  8M: 0.006765  4M: 0.007495  2M: 0.009635  1M: 0.015615 merged: 0.006269
epoch: 77 [2017-11-21 23:28:11]
lr 0.034347857005916346
 16M: 0.006128  8M: 0.006925  4M: 0.007472  2M: 0.009682  1M: 0.015352 merged: 0.006359
epoch: 78 [2017-11-21 23:33:58]
lr 0.0339364905854808
 16M: 0.005944  8M: 0.006694  4M: 0.007350  2M: 0.009448  1M: 0.015038 merged: 0.006161
epoch: 79 [2017-11-21 23:39:45]
lr 0.03352007615769955
 16M: 0.005885  8M: 0.006789  4M: 0.007466  2M: 0.009664  1M: 0.015218 merged: 0.006187
epoch: 80 [2017-11-21 23:45:32]
lr 0.03309842319473132
 16M: 0.006021  8M: 0.006783  4M: 0.007357  2M: 0.009482  1M: 0.014993 merged: 0.006183
epoch: 81 [2017-11-21 23:51:16]
lr 0.03267132887314317
 16M: 0.005921  8M: 0.006649  4M: 0.007229  2M: 0.009389  1M: 0.014879 merged: 0.006204
epoch: 82 [2017-11-21 23:57:02]
lr 0.03223857693349118
 16M: 0.005743  8M: 0.006582  4M: 0.007194  2M: 0.009255  1M: 0.014664 merged: 0.006055
epoch: 83 [2017-11-22 00:02:47]
lr 0.0317999364001908
 16M: 0.005821  8M: 0.006689  4M: 0.007204  2M: 0.009409  1M: 0.014965 merged: 0.006059
epoch: 84 [2017-11-22 00:08:34]
lr 0.031355160140170396
 16M: 0.005750  8M: 0.006684  4M: 0.007259  2M: 0.009521  1M: 0.015370 merged: 0.006010
epoch: 85 [2017-11-22 00:14:18]
lr 0.03090398323477543
 16M: 0.005569  8M: 0.006551  4M: 0.007056  2M: 0.009180  1M: 0.014743 merged: 0.005916
epoch: 86 [2017-11-22 00:20:05]
lr 0.030446121134470178
 16M: 0.005582  8M: 0.006473  4M: 0.007113  2M: 0.009176  1M: 0.014678 merged: 0.005861
epoch: 87 [2017-11-22 00:25:48]
lr 0.02998126755983446
 16M: 0.005762  8M: 0.006469  4M: 0.006984  2M: 0.008990  1M: 0.014518 merged: 0.005888
epoch: 88 [2017-11-22 00:31:36]
lr 0.029509092104873926
 16M: 0.005489  8M: 0.006294  4M: 0.006910  2M: 0.009044  1M: 0.014433 merged: 0.005716
epoch: 89 [2017-11-22 00:37:21]
lr 0.029029237489356888
 16M: 0.005282  8M: 0.006197  4M: 0.006795  2M: 0.008963  1M: 0.014579 merged: 0.005575
epoch: 90 [2017-11-22 00:43:07]
lr 0.028541316395237167
 16M: 0.005464  8M: 0.006474  4M: 0.007055  2M: 0.009253  1M: 0.014843 merged: 0.005855
epoch: 91 [2017-11-22 00:48:54]
lr 0.028044907807525134
 16M: 0.005385  8M: 0.006297  4M: 0.006805  2M: 0.009037  1M: 0.014503 merged: 0.005600
epoch: 92 [2017-11-22 00:54:38]
lr 0.027539552761294706
 16M: 0.005684  8M: 0.006399  4M: 0.006901  2M: 0.009110  1M: 0.014584 merged: 0.005837
epoch: 93 [2017-11-22 01:00:22]
lr 0.027024749372597065
 16M: 0.005467  8M: 0.006305  4M: 0.007033  2M: 0.009145  1M: 0.014756 merged: 0.005657
epoch: 94 [2017-11-22 01:06:06]
lr 0.026499947000159004
 16M: 0.005517  8M: 0.006287  4M: 0.006867  2M: 0.008947  1M: 0.014254 merged: 0.005683
epoch: 95 [2017-11-22 01:11:57]
lr 0.02596453934447493
 16M: 0.005177  8M: 0.006076  4M: 0.006589  2M: 0.008753  1M: 0.014494 merged: 0.005417
epoch: 96 [2017-11-22 01:17:52]
lr 0.025417856237895775
 16M: 0.005282  8M: 0.006247  4M: 0.006707  2M: 0.008844  1M: 0.014395 merged: 0.005531
epoch: 97 [2017-11-22 01:23:43]
lr 0.02485915380880628
 16M: 0.005147  8M: 0.006218  4M: 0.006845  2M: 0.009052  1M: 0.014635 merged: 0.005523
epoch: 98 [2017-11-22 01:29:36]
lr 0.02428760260810931
 16M: 0.005105  8M: 0.006128  4M: 0.006697  2M: 0.008925  1M: 0.014142 merged: 0.005416
epoch: 99 [2017-11-22 01:35:24]
lr 0.02370227315699886
 16M: 0.005174  8M: 0.006096  4M: 0.006761  2M: 0.008997  1M: 0.014309 merged: 0.005376
epoch: 100 [2017-11-22 01:41:15]
lr 0.023102118196575382
 16M: 0.005067  8M: 0.006044  4M: 0.006547  2M: 0.008632  1M: 0.013752 merged: 0.005407
epoch: 101 [2017-11-22 01:47:04]
lr 0.022485950669875843
 16M: 0.005128  8M: 0.006056  4M: 0.006765  2M: 0.008905  1M: 0.014418 merged: 0.005422
epoch: 102 [2017-11-22 01:52:54]
lr 0.021852416110985085
 16M: 0.005126  8M: 0.006084  4M: 0.006592  2M: 0.008800  1M: 0.014171 merged: 0.005339
epoch: 103 [2017-11-22 01:59:02]
lr 0.021199957600127203
 16M: 0.005144  8M: 0.006116  4M: 0.006709  2M: 0.008842  1M: 0.014105 merged: 0.005425
epoch: 104 [2017-11-22 02:04:51]
lr 0.020526770681399003
 16M: 0.005073  8M: 0.005996  4M: 0.006618  2M: 0.008708  1M: 0.014205 merged: 0.005315
epoch: 105 [2017-11-22 02:10:57]
lr 0.019830744488452574
 16M: 0.004969  8M: 0.005906  4M: 0.006528  2M: 0.008657  1M: 0.014137 merged: 0.005227
epoch: 106 [2017-11-22 02:17:05]
lr 0.01910938354123028
 16M: 0.004951  8M: 0.005829  4M: 0.006476  2M: 0.008515  1M: 0.013639 merged: 0.005239
epoch: 107 [2017-11-22 02:23:07]
lr 0.01835970184086314
 16M: 0.004884  8M: 0.005783  4M: 0.006412  2M: 0.008512  1M: 0.013776 merged: 0.005177
epoch: 108 [2017-11-22 02:29:08]
lr 0.01757807623276631
 16M: 0.004855  8M: 0.005858  4M: 0.006514  2M: 0.008737  1M: 0.014217 merged: 0.005126
epoch: 109 [2017-11-22 02:35:04]
lr 0.016760038078849775
 16M: 0.004887  8M: 0.005871  4M: 0.006476  2M: 0.008715  1M: 0.014171 merged: 0.005161
epoch: 110 [2017-11-22 02:40:59]
lr 0.0158999682000954
 16M: 0.004715  8M: 0.005772  4M: 0.006399  2M: 0.008401  1M: 0.013414 merged: 0.004991
epoch: 111 [2017-11-22 02:46:57]
lr 0.01499063377991723
 16M: 0.004855  8M: 0.005868  4M: 0.006481  2M: 0.008698  1M: 0.014188 merged: 0.005124
epoch: 112 [2017-11-22 02:52:46]
lr 0.014022453903762567
 16M: 0.004684  8M: 0.005681  4M: 0.006341  2M: 0.008415  1M: 0.013682 merged: 0.004967
epoch: 113 [2017-11-22 02:58:43]
lr 0.012982269672237465
 16M: 0.004655  8M: 0.005723  4M: 0.006278  2M: 0.008432  1M: 0.013551 merged: 0.004959
epoch: 114 [2017-11-22 03:04:43]
lr 0.01185113657849943
 16M: 0.004686  8M: 0.005703  4M: 0.006300  2M: 0.008411  1M: 0.013412 merged: 0.004986
epoch: 115 [2017-11-22 03:10:35]
lr 0.010599978800063602
 16M: 0.004697  8M: 0.005776  4M: 0.006379  2M: 0.008561  1M: 0.014098 merged: 0.004990
epoch: 116 [2017-11-22 03:16:27]
lr 0.00917985092043157
 16M: 0.004669  8M: 0.005690  4M: 0.006247  2M: 0.008355  1M: 0.013432 merged: 0.004938
epoch: 117 [2017-11-22 03:22:29]
lr 0.007495316889958615
 16M: 0.004579  8M: 0.005637  4M: 0.006233  2M: 0.008283  1M: 0.013203 merged: 0.004886
epoch: 118 [2017-11-22 03:28:34]
lr 0.005299989400031801
 16M: 0.004632  8M: 0.005648  4M: 0.006276  2M: 0.008437  1M: 0.013367 merged: 0.004893
epoch: 119 [2017-11-22 03:34:22]
lr 1e-08
 16M: 0.004531  8M: 0.005638  4M: 0.006309  2M: 0.008412  1M: 0.013612 merged: 0.004876
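
Resuming from a snapshot (a sketch; the keys mirror the torch.save call above, and the path and epoch index here are illustrative):

In [ ]:
# Sketch: restore a snapshot written during training.
checkpoint = torch.load('snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, 119))
net.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1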

Visualize Graph


In [ ]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """Produces a Graphviz representation of the PyTorch autograd graph.
    Blue nodes are Variables that require grad; orange are Tensors saved
    for backward in torch.autograd.Function.
    Args:
        var: output Variable
        params: optional dict of (name, Variable) used to label the nodes
            that require grad
    """
    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [ ]:
x = Variable(torch.zeros(1, 3, 256, 256))
# route through the frozen feature extractor first, as in the training loop
y = net(pretrained(x.cuda()))
g = make_dot(y[-1])

In [ ]:
g.render('net')
