In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import PreTrainedModel, GradientNet
from myargs import Args

Configurations


In [2]:
# Experiment configuration. These early values are mostly overridden again in
# the cell below (In [5]) before training actually starts.
args = Args()
args.test_scene = 'alley_1'          # scene held out from training for testing
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 1
growth_rate = (4*(2**(args.gpu_num)))  # DenseNet-style growth rate tied to GPU id
args.display_interval = 50
args.display_curindex = 0

# Resolve the dataset directory from the host platform. platform.dist() is
# called once and the tuple reused (the original re-ran it per branch).
# NOTE(review): platform.dist() was removed in Python 3.8 — this notebook
# targets an older interpreter; confirm before upgrading.
system_ = platform.system()
dist_info = platform.dist()
system_dist, system_version, _ = dist_info
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False          # no pretrained weights on the laptop setup
elif dist_info == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif dist_info == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

# GPUs are assumed available exactly when running on Linux.
use_gpu = platform.system() == 'Linux'
if use_gpu:
    torch.cuda.set_device(args.gpu_num)


print(dist_info)


('debian', 'jessie/sid', '')

My DataLoader


In [3]:
# Sintel intrinsic-image datasets: the train split random-crops, the test
# split center-crops to the same (image_h, image_w); both yield tensors.
_dataset_kwargs = dict(img_extentions=args.img_extentions,
                       test_scene=args.test_scene,
                       image_h=args.image_h,
                       image_w=args.image_w)

train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([
    transforms.CenterCrop((args.image_h, args.image_w)),
    transforms.ToTensor(),
])

train_dataset = MyImageFolder(args.train_dir, 'train', train_transform,
                              random_crop=True, **_dataset_kwargs)
test_dataset = MyImageFolder(args.train_dir, 'test', test_transform,
                             random_crop=False, **_dataset_kwargs)

# Batch size 1 with shuffling on both loaders (matching the original setup).
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First Convolution: 32M -> 16M -> 8M
    • every transition: 8M -> 4M -> 2M (downsample 1/2, except the last block)

In [4]:
# Instantiate the backbone architecture by name and freeze every weight:
# it serves purely as a fixed feature extractor for GradientNet.
densenet = getattr(models, args.arch)(pretrained=args.pretrained)

for param in densenet.parameters():
    param.requires_grad = False

if use_gpu:
    densenet.cuda()

In [5]:
# Curriculum step length: each output scale gets `ss` epochs of poly-decayed
# training before the next scale unlocks.
ss = 6

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
# Unlock epochs per scale index 0..5; scale 4 (threshold 0) trains first,
# the merged output (index 5) last.
args.training_thresholds = [ss*4,ss*3,ss*2,ss*1,ss*0,ss*5]
args.power = 0.5  # exponent of the poly learning-rate schedule



# Frozen feature extractor wrapper around the DenseNet backbone.
pretrained = PreTrainedModel(densenet)
if use_gpu:
    pretrained.cuda()


# The trainable multi-scale prediction network.
net = GradientNet(growth_rate=growth_rate)
if use_gpu:
    net.cuda()

# One independent MSE criterion per output scale. The original used
# [nn.MSELoss().cuda()] * 6, which puts six references to the SAME module
# in the list; comprehensions create distinct instances.
if use_gpu:
    mse_losses = [nn.MSELoss().cuda() for _ in range(6)]
    test_losses = [nn.MSELoss().cuda() for _ in range(6)]
else:
    mse_losses = [nn.MSELoss() for _ in range(6)]
    test_losses = [nn.MSELoss() for _ in range(6)]

In [6]:
# Training setup: TensorBoard writer plus SGD over only the trainable
# (non-frozen) parameters of GradientNet.
writer = SummaryWriter(comment='-growth_rate_{}'.format(growth_rate))

trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Apply a poly learning-rate schedule to every param group.

    lr = args.base_lr * ((end - epoch) / (end - beg)) ** args.power,
    floored at 1e-8. At epoch == 0 the current lr is left untouched.

    Args:
        optimizer: the optimizer whose param_groups are updated in place.
        epoch: current position within the [beg, end] decay window.
        beg, end: first and last epoch of the decay window.
        reset_lr: if given, overrides the schedule and sets lr to this value
            (used when a new output scale unlocks). beg/end are ignored then.
    """
    for param_group in optimizer.param_groups:
        if reset_lr is not None:  # identity check is the idiomatic None test
            param_group['lr'] = reset_lr
            continue
        if epoch != 0: 
            # poly decay: base_lr * (1 - iter/max_iter) ** power
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
            if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        print('lr', param_group['lr'])
        
# def findLargerInd(target, arr):
#     res = list(filter(lambda x: x>target, arr))
#     print('res',res)
#     if len(res) == 0: return -1
#     return res[0]

# Main curriculum-training loop. GradientNet emits six predictions: indices
# 0..4 are supervised at downsample factor 2**(i+1), index 5 is the merged
# full-resolution output. Scale i contributes to the loss only once
# epoch >= args.training_thresholds[i].
for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    
    # LR schedule: while scales are still being unlocked (epoch below the last
    # threshold) restart a short poly decay every `ss` epochs; afterwards run
    # one long poly decay until the final epoch.
    if epoch < args.training_thresholds[-1]: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)
    
    # Per-scale running loss sums and iteration counts; counts start at a tiny
    # epsilon so the per-epoch report never divides by zero for inactive scales.
    run_losses = [0] * len(args.training_thresholds)
    run_cnts   = [0.00001] * len(args.training_thresholds)
    # When a new scale unlocks this epoch, reset the LR back to base_lr.
    if (epoch in args.training_thresholds) == True: 
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
#         if  ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # BGR copy of the input, scaled to [0,255], for cv2.imwrite snapshots.
        im = input_img[0,:,:,:].numpy(); im = im.transpose(1,2,0); im = im[:,:,::-1]*255
        
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
#         run_losses = [0] * len(mse_losses)
#         run_cnts = [0.00001] * len(mse_losses)
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)

        optimizer.zero_grad()
        # Frozen DenseNet features feed the trainable GradientNet.
        pretrained.train(); ft_pretreained = pretrained(input_img)
        ft_predict = net(ft_pretreained)
        for i, threshold in enumerate(args.training_thresholds):
#             threshold = args.training_thresholds[i]
            if epoch >= threshold:
#             if epoch >= 0:
                # Downsample factor for scale i; index 5 is the merged
                # full-resolution output (no downsampling).
                if i == 5: s = 1
                else: s = (2**(i+1))
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                # NOTE(review): cv2.resize expects dsize=(width, height);
                # (h//s, w//s) is only correct because image_h == image_w here.
                gt = cv2.resize(gt, (h//s, w//s))
#                 gt = cv2.resize(gt, (h,w))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                loss_data = loss.data.cpu().numpy()
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                ma_ = ft_predict[i].max().cpu().data.numpy()
                mi_ = ft_predict[i].min().cpu().data.numpy()
                #print('mi', mi_, 'ma', ma_)
#                 writer.add_scalars('{}th train predict'.format(i), {'max': ma_, 'min': mi_}, global_step=args.display_curindex)
#                 run_cnts[i] += 1
                run_losses[i] += loss.data.cpu().numpy()[0]
                # Gradients from each active scale accumulate before the single
                # optimizer.step() below; retain_graph keeps the shared graph
                # alive for the next scale's backward pass.
                loss.backward(retain_graph=True)
                run_cnts[i] += 1
#                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
#                 print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])
        optimizer.step()
        args.display_curindex += 1

    """ every epoch """
#     loss_output = 'ind: ' + str(args.display_curindex)
    # Build the one-line per-epoch loss report (16M/8M/4M/2M/1M + merged).
    loss_output = ''
    
    
    
    for i,v in enumerate(run_losses):
        if i == len(run_losses)-1: 
            loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
            continue
        loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
    print(loss_output)
    # save at every epoch
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))
    
    # test 
    # Evaluation pass 1: the net is still in train() mode ("trainphase" curves).
    test_losses_trainphase = [0] * len(args.training_thresholds)
    test_cnts_trainphase   = [0.00001] * len(args.training_thresholds)   
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        
        # NOTE(review): pretrained.train() during an evaluation pass —
        # presumably intentional to mirror training-phase statistics; confirm.
        pretrained.train(); ft_pretreained = pretrained(input_img)
        ft_test = net(ft_pretreained)
            
        for i,v in enumerate(ft_test):
            # Skip scales that have not been unlocked yet.
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_train-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)

    
    # Evaluation pass 2: proper eval() mode.
    # NOTE(review): this list shadows the nn.MSELoss modules also named
    # `test_losses` created in cell In [5]; those criterion objects are unused.
    net.eval()
    test_losses = [0] * len(args.training_thresholds)
    test_cnts   = [0.00001] * len(args.training_thresholds)   
    for ind, data in enumerate(test_loader, 0):
#         if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
            
        pretrained.eval(); ft_pretreained = pretrained(input_img)
        ft_test = net(ft_pretreained)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (h//s, w//s))
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_test-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)
    
    # Per-scale train/test curves to TensorBoard, one chart per scale.
    writer.add_scalars('16M loss', {
        'train 16M ': np.array([run_losses[0]/ run_cnts[0]]),
        'test_trainphase 16M ': np.array([test_losses_trainphase[0]/ test_cnts_trainphase[0]]),
        'test 16M ': np.array([test_losses[0]/ test_cnts[0]])
    }, global_step=epoch)  
    writer.add_scalars('8M loss', {
        'train 8M ': np.array([run_losses[1]/ run_cnts[1]]),
        'test_trainphase 8M ': np.array([test_losses_trainphase[1]/ test_cnts_trainphase[1]]),
        'test 8M ': np.array([test_losses[1]/ test_cnts[1]])
    }, global_step=epoch) 
    writer.add_scalars('4M loss', {
        'train 4M ': np.array([run_losses[2]/ run_cnts[2]]),
        'test_trainphase 4M ': np.array([test_losses_trainphase[2]/ test_cnts_trainphase[2]]),
        'test 4M ': np.array([test_losses[2]/ test_cnts[2]])
    }, global_step=epoch) 
    writer.add_scalars('2M loss', {
        'train 2M ': np.array([run_losses[3]/ run_cnts[3]]),
        'test_trainphase 2M ': np.array([test_losses_trainphase[3]/ test_cnts_trainphase[3]]),
        'test 2M ': np.array([test_losses[3]/ test_cnts[3]])
    }, global_step=epoch) 
    writer.add_scalars('1M loss', {
        'train 1M ': np.array([run_losses[4]/ run_cnts[4]]),
        'test_trainphase 1M ': np.array([test_losses_trainphase[4]/ test_cnts_trainphase[4]]),
        'test 1M ': np.array([test_losses[4]/ test_cnts[4]])
    }, global_step=epoch) 
    writer.add_scalars('merged loss', {
        'train merged ': np.array([run_losses[5]/ run_cnts[5]]),
        'test_trainphase merged ': np.array([test_losses_trainphase[5]/ test_cnts_trainphase[5]]),
        'test merged ': np.array([test_losses[5]/ test_cnts[5]])
    }, global_step=epoch)


epoch: 0 [2017-11-21 17:40:04]
lr 0.05
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.055995 merged: 0.000000
epoch: 1 [2017-11-21 17:41:57]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.049208 merged: 0.000000
epoch: 2 [2017-11-21 17:43:50]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.047502 merged: 0.000000
epoch: 3 [2017-11-21 17:45:41]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.045861 merged: 0.000000
epoch: 4 [2017-11-21 17:47:33]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.042942 merged: 0.000000
epoch: 5 [2017-11-21 17:49:26]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.041904 merged: 0.000000
epoch: 6 [2017-11-21 17:51:19]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.048487  1M: 0.045408 merged: 0.000000
epoch: 7 [2017-11-21 17:53:31]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.044846  1M: 0.044687 merged: 0.000000
epoch: 8 [2017-11-21 17:55:43]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.039257  1M: 0.039106 merged: 0.000000
epoch: 9 [2017-11-21 17:57:56]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.030131  1M: 0.034028 merged: 0.000000
epoch: 10 [2017-11-21 18:00:10]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.024124  1M: 0.030004 merged: 0.000000
epoch: 11 [2017-11-21 18:02:22]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.022335  1M: 0.028380 merged: 0.000000
epoch: 12 [2017-11-21 18:04:34]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.048452  2M: 0.029604  1M: 0.034685 merged: 0.000000
epoch: 13 [2017-11-21 18:07:05]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.038831  2M: 0.024342  1M: 0.030010 merged: 0.000000
epoch: 14 [2017-11-21 18:09:35]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.029331  2M: 0.021477  1M: 0.028652 merged: 0.000000
epoch: 15 [2017-11-21 18:12:05]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.022728  2M: 0.020328  1M: 0.026492 merged: 0.000000
epoch: 16 [2017-11-21 18:14:32]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.019621  2M: 0.018477  1M: 0.024788 merged: 0.000000
epoch: 17 [2017-11-21 18:17:00]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.019611  2M: 0.018279  1M: 0.023365 merged: 0.000000
epoch: 18 [2017-11-21 18:19:28]
lr 1e-08
 16M: 0.000000  8M: 0.047451  4M: 0.020998  2M: 0.020092  1M: 0.026975 merged: 0.000000
epoch: 19 [2017-11-21 18:22:06]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.033968  4M: 0.019049  2M: 0.018897  1M: 0.026792 merged: 0.000000
epoch: 20 [2017-11-21 18:24:44]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.024432  4M: 0.016376  2M: 0.017460  1M: 0.024029 merged: 0.000000
epoch: 21 [2017-11-21 18:27:22]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.020982  4M: 0.015712  2M: 0.017136  1M: 0.024195 merged: 0.000000
epoch: 22 [2017-11-21 18:30:01]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.018209  4M: 0.014467  2M: 0.016216  1M: 0.022361 merged: 0.000000
epoch: 23 [2017-11-21 18:32:42]
lr 1e-08
 16M: 0.000000  8M: 0.017454  4M: 0.014652  2M: 0.016416  1M: 0.022751 merged: 0.000000
epoch: 24 [2017-11-21 18:35:19]
lr 1e-08
 16M: 0.043795  8M: 0.019385  4M: 0.015555  2M: 0.016618  1M: 0.023509 merged: 0.000000
epoch: 25 [2017-11-21 18:38:30]
lr 0.044721359549995794
 16M: 0.031538  8M: 0.017654  4M: 0.014909  2M: 0.016924  1M: 0.023567 merged: 0.000000
epoch: 26 [2017-11-21 18:41:44]
lr 0.038729833462074176
 16M: 0.024005  8M: 0.015513  4M: 0.013299  2M: 0.015760  1M: 0.022426 merged: 0.000000
epoch: 27 [2017-11-21 18:44:53]
lr 0.0316227766016838
 16M: 0.020152  8M: 0.014826  4M: 0.012924  2M: 0.015241  1M: 0.021839 merged: 0.000000
epoch: 28 [2017-11-21 18:48:04]
lr 0.022360679774997897
 16M: 0.016833  8M: 0.013471  4M: 0.012111  2M: 0.014342  1M: 0.020711 merged: 0.000000
epoch: 29 [2017-11-21 18:51:17]
lr 1e-08
 16M: 0.014879  8M: 0.012797  4M: 0.012123  2M: 0.014540  1M: 0.021152 merged: 0.000000
epoch: 30 [2017-11-21 18:54:28]
lr 0.05
 16M: 0.020367  8M: 0.016169  4M: 0.014902  2M: 0.016826  1M: 0.023762 merged: 0.035464
epoch: 31 [2017-11-21 18:59:00]
lr 0.04971830761761256
 16M: 0.020922  8M: 0.016355  4M: 0.016300  2M: 0.016628  1M: 0.024189 merged: 0.023990
epoch: 32 [2017-11-21 19:03:28]
lr 0.04943501011144937
 16M: 0.018079  8M: 0.016710  4M: 0.015471  2M: 0.015907  1M: 0.023090 merged: 0.020878
epoch: 33 [2017-11-21 19:08:06]
lr 0.04915007972606608
 16M: 0.019073  8M: 0.015616  4M: 0.015997  2M: 0.016010  1M: 0.023063 merged: 0.020340
epoch: 34 [2017-11-21 19:12:43]
lr 0.04886348789677424
 16M: 0.015314  8M: 0.014485  4M: 0.014335  2M: 0.015202  1M: 0.022275 merged: 0.016799
epoch: 35 [2017-11-21 19:17:15]
lr 0.04857520521621862
 16M: 0.014447  8M: 0.013997  4M: 0.014213  2M: 0.015324  1M: 0.022012 merged: 0.016109
epoch: 36 [2017-11-21 19:21:47]
lr 0.04828520139915856
 16M: 0.013456  8M: 0.013151  4M: 0.013592  2M: 0.015116  1M: 0.022733 merged: 0.014945
epoch: 37 [2017-11-21 19:26:20]
lr 0.047993445245333805
 16M: 0.012919  8M: 0.012985  4M: 0.013323  2M: 0.014697  1M: 0.022284 merged: 0.014327
epoch: 38 [2017-11-21 19:30:56]
lr 0.0476999046002862
 16M: 0.012300  8M: 0.012422  4M: 0.012745  2M: 0.014676  1M: 0.021050 merged: 0.013770
epoch: 39 [2017-11-21 19:35:36]
lr 0.04740454631399772
 16M: 0.012560  8M: 0.012160  4M: 0.012594  2M: 0.014134  1M: 0.021295 merged: 0.013435
epoch: 40 [2017-11-21 19:40:02]
lr 0.04710733619719444
 16M: 0.012160  8M: 0.012578  4M: 0.012764  2M: 0.014719  1M: 0.022176 merged: 0.013355
epoch: 41 [2017-11-21 19:44:41]
lr 0.04680823897515326
 16M: 0.011780  8M: 0.012191  4M: 0.012009  2M: 0.014142  1M: 0.021355 merged: 0.013186
epoch: 42 [2017-11-21 19:49:21]
lr 0.04650721823883479
 16M: 0.011209  8M: 0.011837  4M: 0.011685  2M: 0.013770  1M: 0.020569 merged: 0.012530
epoch: 43 [2017-11-21 19:53:58]
lr 0.046204236393150765
 16M: 0.011181  8M: 0.011671  4M: 0.011691  2M: 0.013755  1M: 0.021060 merged: 0.012229
epoch: 44 [2017-11-21 19:58:38]
lr 0.045899254602157845
 16M: 0.010676  8M: 0.011008  4M: 0.011100  2M: 0.013100  1M: 0.020258 merged: 0.011819
epoch: 45 [2017-11-21 20:03:16]
lr 0.04559223273095164
 16M: 0.010824  8M: 0.011069  4M: 0.011148  2M: 0.013515  1M: 0.019915 merged: 0.011758
epoch: 46 [2017-11-21 20:07:55]
lr 0.045283129284014914
 16M: 0.010232  8M: 0.010956  4M: 0.011264  2M: 0.013122  1M: 0.020118 merged: 0.011426
epoch: 47 [2017-11-21 20:12:35]
lr 0.04497190133975169
 16M: 0.009937  8M: 0.010563  4M: 0.010629  2M: 0.012737  1M: 0.019574 merged: 0.010968
epoch: 48 [2017-11-21 20:17:15]
lr 0.04465850448091506
 16M: 0.009511  8M: 0.010174  4M: 0.010408  2M: 0.012554  1M: 0.019134 merged: 0.010544
epoch: 49 [2017-11-21 20:21:55]
lr 0.044342892720609255
 16M: 0.009822  8M: 0.010423  4M: 0.010898  2M: 0.012749  1M: 0.019048 merged: 0.010929
epoch: 50 [2017-11-21 20:26:36]
lr 0.044025018423517
 16M: 0.009423  8M: 0.010081  4M: 0.010239  2M: 0.012439  1M: 0.018874 merged: 0.010301
epoch: 51 [2017-11-21 20:31:14]
lr 0.04370483222197017
 16M: 0.009370  8M: 0.010180  4M: 0.010687  2M: 0.012647  1M: 0.019020 merged: 0.010385
epoch: 52 [2017-11-21 20:35:55]
lr 0.043382282926444894
 16M: 0.009378  8M: 0.010039  4M: 0.010509  2M: 0.012487  1M: 0.018315 merged: 0.010281
epoch: 53 [2017-11-21 20:40:33]
lr 0.04305731743002185
 16M: 0.009116  8M: 0.009893  4M: 0.010284  2M: 0.012288  1M: 0.018706 merged: 0.010153
epoch: 54 [2017-11-21 20:45:12]
lr 0.04272988060630656
 16M: 0.008778  8M: 0.009522  4M: 0.009843  2M: 0.011868  1M: 0.018379 merged: 0.009746
epoch: 55 [2017-11-21 20:49:51]
lr 0.04239991520025441
 16M: 0.008698  8M: 0.009464  4M: 0.009695  2M: 0.011860  1M: 0.017860 merged: 0.009616
epoch: 56 [2017-11-21 20:54:27]
lr 0.0420673617112877
 16M: 0.008595  8M: 0.009620  4M: 0.009876  2M: 0.011942  1M: 0.018114 merged: 0.009515
epoch: 57 [2017-11-21 20:59:08]
lr 0.041732158268029534
 16M: 0.008562  8M: 0.009618  4M: 0.009906  2M: 0.011831  1M: 0.017615 merged: 0.009451
epoch: 58 [2017-11-21 21:03:46]
lr 0.041394240493907074
 16M: 0.008132  8M: 0.009243  4M: 0.009540  2M: 0.011491  1M: 0.017138 merged: 0.009114
epoch: 59 [2017-11-21 21:08:24]
lr 0.041053541362798006
 16M: 0.008658  8M: 0.009830  4M: 0.010032  2M: 0.011929  1M: 0.017485 merged: 0.009530
epoch: 60 [2017-11-21 21:13:05]
lr 0.04070999104380296
 16M: 0.008116  8M: 0.009349  4M: 0.009548  2M: 0.011531  1M: 0.017049 merged: 0.008917
epoch: 61 [2017-11-21 21:17:43]
lr 0.04036351673412598
 16M: 0.008053  8M: 0.009119  4M: 0.009494  2M: 0.011436  1M: 0.017350 merged: 0.009002
epoch: 62 [2017-11-21 21:22:27]
lr 0.04001404247893005
 16M: 0.008018  8M: 0.009149  4M: 0.009547  2M: 0.011417  1M: 0.017619 merged: 0.008744
epoch: 63 [2017-11-21 21:27:06]
lr 0.03966148897690515
 16M: 0.007811  8M: 0.008922  4M: 0.009133  2M: 0.011309  1M: 0.017150 merged: 0.008391
epoch: 64 [2017-11-21 21:31:46]
lr 0.03930577337013889
 16M: 0.007552  8M: 0.008802  4M: 0.009011  2M: 0.010948  1M: 0.016830 merged: 0.008419
epoch: 65 [2017-11-21 21:36:23]
lr 0.038946809016712394
 16M: 0.007681  8M: 0.009039  4M: 0.009109  2M: 0.011232  1M: 0.016828 merged: 0.008506
epoch: 66 [2017-11-21 21:41:06]
lr 0.03858450524425343
 16M: 0.007883  8M: 0.009017  4M: 0.009297  2M: 0.011578  1M: 0.017765 merged: 0.008560
epoch: 67 [2017-11-21 21:45:45]
lr 0.03821876708246056
 16M: 0.007678  8M: 0.008862  4M: 0.009133  2M: 0.011373  1M: 0.017198 merged: 0.008393
epoch: 68 [2017-11-21 21:50:25]
lr 0.03784949497236286
 16M: 0.007425  8M: 0.008540  4M: 0.008909  2M: 0.011024  1M: 0.016697 merged: 0.008099
epoch: 69 [2017-11-21 21:55:03]
lr 0.03747658444979307
 16M: 0.007353  8M: 0.008662  4M: 0.009101  2M: 0.011229  1M: 0.016583 merged: 0.008099
epoch: 70 [2017-11-21 21:59:43]
lr 0.0370999258002226
 16M: 0.007404  8M: 0.008656  4M: 0.009127  2M: 0.011348  1M: 0.017087 merged: 0.008082
epoch: 71 [2017-11-21 22:04:23]
lr 0.03671940368172628
 16M: 0.007379  8M: 0.008850  4M: 0.009163  2M: 0.011231  1M: 0.016854 merged: 0.008012
epoch: 72 [2017-11-21 22:09:02]
lr 0.03633489671240478
 16M: 0.007133  8M: 0.008404  4M: 0.009158  2M: 0.011203  1M: 0.016930 merged: 0.007868
epoch: 73 [2017-11-21 22:13:41]
lr 0.03594627701808178
 16M: 0.006851  8M: 0.008095  4M: 0.008435  2M: 0.010474  1M: 0.015930 merged: 0.007480
epoch: 74 [2017-11-21 22:18:21]
lr 0.035553409735498295
 16M: 0.007025  8M: 0.008375  4M: 0.008695  2M: 0.010765  1M: 0.016427 merged: 0.007671
epoch: 75 [2017-11-21 22:22:59]
lr 0.03515615246553262
 16M: 0.007068  8M: 0.008327  4M: 0.008917  2M: 0.010802  1M: 0.016393 merged: 0.007686
epoch: 76 [2017-11-21 22:27:38]
lr 0.03475435467016077
 16M: 0.006979  8M: 0.008222  4M: 0.008562  2M: 0.010369  1M: 0.016270 merged: 0.007562
epoch: 77 [2017-11-21 22:32:14]
lr 0.034347857005916346
 16M: 0.006771  8M: 0.008192  4M: 0.008595  2M: 0.010678  1M: 0.016294 merged: 0.007535
epoch: 78 [2017-11-21 22:36:54]
lr 0.0339364905854808
 16M: 0.006910  8M: 0.008207  4M: 0.008588  2M: 0.010660  1M: 0.016237 merged: 0.007476
epoch: 79 [2017-11-21 22:41:30]
lr 0.03352007615769955
 16M: 0.006855  8M: 0.008021  4M: 0.008550  2M: 0.010652  1M: 0.016013 merged: 0.007367
epoch: 80 [2017-11-21 22:46:10]
lr 0.03309842319473132
 16M: 0.006872  8M: 0.008018  4M: 0.008469  2M: 0.010544  1M: 0.016150 merged: 0.007444
epoch: 81 [2017-11-21 22:50:47]
lr 0.03267132887314317
 16M: 0.006795  8M: 0.007987  4M: 0.008419  2M: 0.010524  1M: 0.015931 merged: 0.007229
epoch: 82 [2017-11-21 22:55:27]
lr 0.03223857693349118
 16M: 0.006724  8M: 0.007937  4M: 0.008335  2M: 0.010357  1M: 0.016114 merged: 0.007186
epoch: 83 [2017-11-21 23:00:05]
lr 0.0317999364001908
 16M: 0.006579  8M: 0.007766  4M: 0.008259  2M: 0.010181  1M: 0.015681 merged: 0.007140
epoch: 84 [2017-11-21 23:04:46]
lr 0.031355160140170396
 16M: 0.006695  8M: 0.007963  4M: 0.008512  2M: 0.010491  1M: 0.015962 merged: 0.007212
epoch: 85 [2017-11-21 23:09:25]
lr 0.03090398323477543
 16M: 0.006528  8M: 0.007860  4M: 0.008247  2M: 0.010128  1M: 0.015707 merged: 0.006984
epoch: 86 [2017-11-21 23:14:06]
lr 0.030446121134470178
 16M: 0.006522  8M: 0.007888  4M: 0.008336  2M: 0.010463  1M: 0.015935 merged: 0.007030
epoch: 87 [2017-11-21 23:18:43]
lr 0.02998126755983446
 16M: 0.006458  8M: 0.007861  4M: 0.008274  2M: 0.010236  1M: 0.015679 merged: 0.007031
epoch: 88 [2017-11-21 23:23:21]
lr 0.029509092104873926
 16M: 0.006396  8M: 0.007775  4M: 0.008074  2M: 0.010104  1M: 0.015398 merged: 0.006918
epoch: 89 [2017-11-21 23:28:00]
lr 0.029029237489356888
 16M: 0.006454  8M: 0.007690  4M: 0.008157  2M: 0.010113  1M: 0.015606 merged: 0.006845
epoch: 90 [2017-11-21 23:32:38]
lr 0.028541316395237167
 16M: 0.006321  8M: 0.007520  4M: 0.008120  2M: 0.009958  1M: 0.015134 merged: 0.006812
epoch: 91 [2017-11-21 23:37:20]
lr 0.028044907807525134
 16M: 0.006357  8M: 0.007666  4M: 0.008267  2M: 0.010184  1M: 0.015639 merged: 0.006907
epoch: 92 [2017-11-21 23:42:00]
lr 0.027539552761294706
 16M: 0.006152  8M: 0.007438  4M: 0.007977  2M: 0.009926  1M: 0.015432 merged: 0.006640
epoch: 93 [2017-11-21 23:46:41]
lr 0.027024749372597065
 16M: 0.006301  8M: 0.007520  4M: 0.008076  2M: 0.010019  1M: 0.015418 merged: 0.006712
epoch: 94 [2017-11-21 23:51:18]
lr 0.026499947000159004
 16M: 0.006146  8M: 0.007474  4M: 0.007977  2M: 0.009927  1M: 0.015171 merged: 0.006580
epoch: 95 [2017-11-21 23:55:55]
lr 0.02596453934447493
 16M: 0.006099  8M: 0.007440  4M: 0.008044  2M: 0.010029  1M: 0.015382 merged: 0.006594
epoch: 96 [2017-11-22 00:00:35]
lr 0.025417856237895775
 16M: 0.006109  8M: 0.007476  4M: 0.008056  2M: 0.009982  1M: 0.015486 merged: 0.006579
epoch: 97 [2017-11-22 00:05:13]
lr 0.02485915380880628
 16M: 0.006020  8M: 0.007349  4M: 0.007912  2M: 0.009797  1M: 0.015213 merged: 0.006568
epoch: 98 [2017-11-22 00:09:53]
lr 0.02428760260810931
 16M: 0.005929  8M: 0.007248  4M: 0.007668  2M: 0.009375  1M: 0.014535 merged: 0.006404
epoch: 99 [2017-11-22 00:14:30]
lr 0.02370227315699886
 16M: 0.006016  8M: 0.007357  4M: 0.007963  2M: 0.010013  1M: 0.015286 merged: 0.006467
epoch: 100 [2017-11-22 00:19:09]
lr 0.023102118196575382
 16M: 0.005904  8M: 0.007227  4M: 0.007808  2M: 0.009704  1M: 0.014684 merged: 0.006322
epoch: 101 [2017-11-22 00:23:45]
lr 0.022485950669875843
 16M: 0.005937  8M: 0.007199  4M: 0.007791  2M: 0.009800  1M: 0.014820 merged: 0.006439
epoch: 102 [2017-11-22 00:28:23]
lr 0.021852416110985085
 16M: 0.005990  8M: 0.007184  4M: 0.007706  2M: 0.009673  1M: 0.015073 merged: 0.006259
epoch: 103 [2017-11-22 00:33:01]
lr 0.021199957600127203
 16M: 0.005933  8M: 0.007350  4M: 0.007885  2M: 0.009868  1M: 0.014922 merged: 0.006460
epoch: 104 [2017-11-22 00:37:38]
lr 0.020526770681399003
 16M: 0.005778  8M: 0.007074  4M: 0.007645  2M: 0.009695  1M: 0.014899 merged: 0.006190
epoch: 105 [2017-11-22 00:42:18]
lr 0.019830744488452574
 16M: 0.005837  8M: 0.007091  4M: 0.007598  2M: 0.009707  1M: 0.014686 merged: 0.006235
epoch: 106 [2017-11-22 00:46:56]
lr 0.01910938354123028
 16M: 0.005763  8M: 0.007094  4M: 0.007551  2M: 0.009564  1M: 0.014670 merged: 0.006199
epoch: 107 [2017-11-22 00:51:33]
lr 0.01835970184086314
 16M: 0.005781  8M: 0.007085  4M: 0.007592  2M: 0.009541  1M: 0.014583 merged: 0.006194
epoch: 108 [2017-11-22 00:56:11]
lr 0.01757807623276631
 16M: 0.005616  8M: 0.007006  4M: 0.007477  2M: 0.009344  1M: 0.014410 merged: 0.006046
epoch: 109 [2017-11-22 01:00:51]
lr 0.016760038078849775
 16M: 0.005742  8M: 0.007224  4M: 0.007863  2M: 0.009841  1M: 0.014801 merged: 0.006254
epoch: 110 [2017-11-22 01:05:31]
lr 0.0158999682000954
 16M: 0.005707  8M: 0.007007  4M: 0.007674  2M: 0.009679  1M: 0.014936 merged: 0.006065
epoch: 111 [2017-11-22 01:10:15]
lr 0.01499063377991723
 16M: 0.005694  8M: 0.007067  4M: 0.007600  2M: 0.009624  1M: 0.014604 merged: 0.006068
epoch: 112 [2017-11-22 01:15:00]
lr 0.014022453903762567
 16M: 0.005659  8M: 0.007024  4M: 0.007680  2M: 0.009784  1M: 0.014936 merged: 0.006125
epoch: 113 [2017-11-22 01:19:47]
lr 0.012982269672237465
 16M: 0.005747  8M: 0.007480  4M: 0.007547  2M: 0.009502  1M: 0.014569 merged: 0.006181
epoch: 114 [2017-11-22 01:24:33]
lr 0.01185113657849943
 16M: 0.005554  8M: 0.007107  4M: 0.007574  2M: 0.009640  1M: 0.014577 merged: 0.006061
epoch: 115 [2017-11-22 01:29:26]
lr 0.010599978800063602
 16M: 0.005524  8M: 0.006950  4M: 0.007436  2M: 0.009481  1M: 0.014535 merged: 0.005935
epoch: 116 [2017-11-22 01:34:12]
lr 0.00917985092043157
 16M: 0.005487  8M: 0.006940  4M: 0.007450  2M: 0.009431  1M: 0.014525 merged: 0.005896
epoch: 117 [2017-11-22 01:38:50]
lr 0.007495316889958615
 16M: 0.005439  8M: 0.006842  4M: 0.007492  2M: 0.009540  1M: 0.014767 merged: 0.005811
epoch: 118 [2017-11-22 01:43:29]
lr 0.005299989400031801
 16M: 0.005485  8M: 0.006830  4M: 0.007374  2M: 0.009378  1M: 0.014383 merged: 0.005881
epoch: 119 [2017-11-22 01:48:18]
lr 1e-08
 16M: 0.005347  8M: 0.006710  4M: 0.007235  2M: 0.009123  1M: 0.014188 merged: 0.005748

Visualize Graph


In [ ]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph

    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function

    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    Returns:
        graphviz.Digraph of the graph rooted at var.grad_fn (SVG format).
    """
    if params is not None:
        # dict.values() is a non-indexable view on Python 3; next(iter(...))
        # grabs an arbitrary value portably (the original values()[0] raises
        # TypeError on Python 3).
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        # e.g. torch.Size([1, 3, 256, 256]) -> '(1, 3, 256, 256)'
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        # Depth-first walk over the autograd graph, de-duplicated via `seen`.
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                # Leaf node: the underlying Variable lives in .variable.
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [ ]:
# Trace the network once with a dummy all-zero input to build its autograd
# graph, then visualize the final (merged) output.
dummy_input = Variable(torch.zeros(1,3,256,256))
net_outputs = net(dummy_input.cuda())
g = make_dot(net_outputs[-1])

In [ ]:
g.render('net')

In [ ]: