In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel import PreTrainedModel, GradientNet
from myargs import Args

Configurations


In [2]:
args = Args()
args.test_scene = 'alley_1'
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 0
growth_rate = (4*(2**(args.gpu_num)))  # tie growth rate to the GPU id so each GPU sweeps a different width: 4, 8, 16, ...
args.display_interval = 50
args.display_curindex = 0

# Pick dataset path, pretrained flag and GPU usage per machine.
system_ = platform.system()
dist_ = platform.dist()
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif dist_ == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif dist_ == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

use_gpu = (system_ == 'Linux')
if use_gpu:
    torch.cuda.set_device(args.gpu_num)
    

print(platform.dist())


('debian', 'jessie/sid', '')

My DataLoader


In [3]:
train_dataset = MyImageFolder(args.train_dir, 'train',
                       transforms.Compose(
        [transforms.ToTensor()]
    ), random_crop=True, 
    img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test', 
                       transforms.Compose(
        [transforms.CenterCrop((args.image_h, args.image_w)),
         transforms.ToTensor()]
    ), random_crop=False,
    img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)

train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)
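
A quick sanity check of the loader (this assumes MyImageFolder returns the same 5-tuple that the training loop below unpacks):


In [ ]:
# Peek at one batch; the image tensors should be (1, 3, args.image_h, args.image_w).
input_img, gt_albedo, gt_shading, test_scene, img_path = next(iter(train_loader))
print(input_img.size(), gt_albedo.size(), gt_shading.size(), test_scene[0], img_path[0])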

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First convolution + pooling: 32M -> 16M -> 8M
    • Every transition layer halves the spatial size: 8M -> 4M -> 2M (the last dense block has no trailing transition; see the shape check below)
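
A quick way to verify these sizes is to push a dummy input through the backbone and print the shape after each stage (a minimal sketch, assuming torchvision's DenseNet layout):


In [ ]:
dnet = models.densenet121(pretrained=False)
x = Variable(torch.zeros(1, 3, 256, 256))
# Each transition layer should halve the height/width.
for name, module in dnet.features.named_children():
    x = module(x)
    print('{:16s} -> {}'.format(name, tuple(x.size())))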

In [4]:
densenet = models.__dict__[args.arch](pretrained=args.pretrained)

for param in densenet.parameters():
    param.requires_grad = False

if use_gpu: densenet.cuda()
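
A one-line check that the backbone is actually frozen:


In [ ]:
# every backbone parameter should now have requires_grad == False
print(all(not p.requires_grad for p in densenet.parameters()))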

In [5]:
ss = 6  # stage length: a new (coarser) scale joins training every ss epochs

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
# Per-scale start epochs, ordered [16M, 8M, 4M, 2M, 1M, merged]: the 1M head
# trains from epoch 0, 2M from ss, ..., 16M from 4*ss, and the merged
# full-resolution output from 5*ss.
args.training_thresholds = [ss*4, ss*3, ss*2, ss*1, ss*0, ss*5]
args.power = 0.5



pretrained = PreTrainedModel(densenet)
if use_gpu:
    pretrained.cuda()

net = GradientNet(growth_rate=growth_rate)
if use_gpu:
    net.cuda()

# One MSE criterion per output scale (independent instances rather than a
# six-way aliased list); the two evaluation loops below reuse these criteria.
mse_losses = [nn.MSELoss() for _ in range(6)]
if use_gpu:
    mse_losses = [criterion.cuda() for criterion in mse_losses]

In [6]:
# training loop
writer = SummaryWriter(comment='-growth_rate_{}'.format(growth_rate))

parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Poly decay: lr = base_lr * ((end - epoch) / (end - beg)) ** power,
    clamped below at 1e-8; reset_lr overrides the schedule when given."""
    for param_group in optimizer.param_groups:
        if reset_lr is not None:
            param_group['lr'] = reset_lr
            continue
        if epoch != 0:
            # poly: base_lr * (1 - iter/max_iter) ** power
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
            if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        print('lr', param_group['lr'])
        
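# For reference, one ss=6 warm-up cycle of this schedule (beg=0, end=5,
# power=0.5) gives lr = 0.05, 0.0447, 0.0387, 0.0316, 0.0224, 1e-08,
# matching the `lr` lines in the training log below.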

for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    
    # Warm-up phase: while new scales are still being unlocked, restart the
    # poly schedule every ss epochs; afterwards decay over the remaining epochs.
    if epoch < args.training_thresholds[-1]: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)
    
    run_losses = [0] * len(args.training_thresholds)
    run_cnts   = [0.00001] * len(args.training_thresholds)  # epsilon avoids division by zero for scales not yet training
    if epoch in args.training_thresholds:
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
#         if  ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # CHW tensor -> HWC numpy, RGB -> BGR, [0,1] -> [0,255] for cv2
        im = input_img[0].numpy().transpose(1, 2, 0)[:, :, ::-1] * 255
        
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')  # debug: the held-out test scene should not show up in training batches
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
#         run_losses = [0] * len(mse_losses)
#         run_cnts = [0.00001] * len(mse_losses)
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), im)

        optimizer.zero_grad()
        pretrained.train(); ft_pretrained = pretrained(input_img)
        ft_predict = net(ft_pretrained)
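        # Staged multi-scale supervision: scale i contributes to the loss only
        # once epoch >= training_thresholds[i]; i == 5 is the merged
        # full-resolution output (s = 1), otherwise the target is downsampled
        # by s = 2**(i+1).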
        for i, threshold in enumerate(args.training_thresholds):
#             threshold = args.training_thresholds[i]
            if epoch >= threshold:
#             if epoch >= 0:
                if i == 5: s = 1
                else: s = (2**(i+1))
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                gt = cv2.resize(gt, (w//s, h//s))  # cv2.resize takes (width, height)
#                 gt = cv2.resize(gt, (h,w))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot{}/gt-{}-{}.png'.format(args.gpu_num, epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                loss_data = loss.data.cpu().numpy()
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                ma_ = ft_predict[i].max().cpu().data.numpy()
                mi_ = ft_predict[i].min().cpu().data.numpy()
                #print('mi', mi_, 'ma', ma_)
                run_losses[i] += loss_data[0]
                # retain_graph: the per-scale losses all backprop through the shared backbone graph
                loss.backward(retain_graph=True)
                run_cnts[i] += 1
#                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
#                 print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot{}/train-{}-{}.png'.format(args.gpu_num, epoch, i), im[:,:,::-1])
        optimizer.step()
        args.display_curindex += 1

    """ every epoch """
#     loss_output = 'ind: ' + str(args.display_curindex)
    loss_output = ''
    
    
    
    for i,v in enumerate(run_losses):
        if i == len(run_losses)-1: 
            loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
            continue
        loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
    print(loss_output)
    # snapshot every 10 epochs
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot{}/snapshot-{}.pth.tar'.format(args.gpu_num, epoch))
    
    # test pass with the model left in train mode (same BN/dropout behavior as training)
    test_losses_trainphase = [0] * len(args.training_thresholds)
    test_cnts_trainphase   = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        
        pretrained.train(); ft_pretrained = pretrained(input_img)
        ft_test = net(ft_pretrained)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (w//s, h//s))  # (width, height)
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_train-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)

    
    # proper evaluation pass with the network in eval mode
    net.eval()
    test_losses = [0] * len(args.training_thresholds)
    test_cnts   = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
#         if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
            
        pretrained.eval(); ft_pretrained = pretrained(input_img)
        ft_test = net(ft_pretrained)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (w//s, h//s))  # (width, height)
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot{}/test-phase_test-{}-{}.png'.format(args.gpu_num, epoch, i), v[:,:,::-1]*255)
    
    # One scalar group per output scale: train vs. train-mode test vs. eval-mode test.
    for i, name in enumerate(['16M', '8M', '4M', '2M', '1M', 'merged']):
        writer.add_scalars('{} loss'.format(name), {
            'train {} '.format(name): np.array([run_losses[i] / run_cnts[i]]),
            'test_trainphase {} '.format(name): np.array([test_losses_trainphase[i] / test_cnts_trainphase[i]]),
            'test {} '.format(name): np.array([test_losses[i] / test_cnts[i]])
        }, global_step=epoch)


epoch: 0 [2017-11-21 17:39:56]
lr 0.05
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.054454 merged: 0.000000
epoch: 1 [2017-11-21 17:41:45]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.048894 merged: 0.000000
epoch: 2 [2017-11-21 17:43:34]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.048793 merged: 0.000000
epoch: 3 [2017-11-21 17:45:21]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.045243 merged: 0.000000
epoch: 4 [2017-11-21 17:47:09]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.043500 merged: 0.000000
epoch: 5 [2017-11-21 17:48:56]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.043367 merged: 0.000000
epoch: 6 [2017-11-21 17:50:44]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.049237  1M: 0.045619 merged: 0.000000
epoch: 7 [2017-11-21 17:52:53]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.041367  1M: 0.041764 merged: 0.000000
epoch: 8 [2017-11-21 17:55:00]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.041623  1M: 0.040791 merged: 0.000000
epoch: 9 [2017-11-21 17:57:08]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.037810  1M: 0.036591 merged: 0.000000
epoch: 10 [2017-11-21 17:59:18]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.031751  1M: 0.032871 merged: 0.000000
epoch: 11 [2017-11-21 18:01:26]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.027010  1M: 0.029272 merged: 0.000000
epoch: 12 [2017-11-21 18:03:38]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.046819  2M: 0.030780  1M: 0.035406 merged: 0.000000
epoch: 13 [2017-11-21 18:06:03]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.034769  2M: 0.025618  1M: 0.032694 merged: 0.000000
epoch: 14 [2017-11-21 18:08:30]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.024534  2M: 0.022047  1M: 0.029620 merged: 0.000000
epoch: 15 [2017-11-21 18:10:54]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.021939  2M: 0.020324  1M: 0.027073 merged: 0.000000
epoch: 16 [2017-11-21 18:13:17]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.019103  2M: 0.019104  1M: 0.025697 merged: 0.000000
epoch: 17 [2017-11-21 18:15:41]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.018113  2M: 0.018522  1M: 0.024086 merged: 0.000000
epoch: 18 [2017-11-21 18:18:07]
lr 1e-08
 16M: 0.000000  8M: 0.045386  4M: 0.020586  2M: 0.020072  1M: 0.027422 merged: 0.000000
epoch: 19 [2017-11-21 18:20:40]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.033413  4M: 0.019252  2M: 0.019442  1M: 0.026363 merged: 0.000000
epoch: 20 [2017-11-21 18:23:13]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.024713  4M: 0.017110  2M: 0.018045  1M: 0.025249 merged: 0.000000
epoch: 21 [2017-11-21 18:25:47]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.020762  4M: 0.016142  2M: 0.017200  1M: 0.023749 merged: 0.000000
epoch: 22 [2017-11-21 18:28:22]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.018734  4M: 0.015070  2M: 0.016654  1M: 0.022885 merged: 0.000000
epoch: 23 [2017-11-21 18:31:02]
lr 1e-08
 16M: 0.000000  8M: 0.017484  4M: 0.014402  2M: 0.016054  1M: 0.021745 merged: 0.000000
epoch: 24 [2017-11-21 18:33:37]
lr 1e-08
 16M: 0.045360  8M: 0.021110  4M: 0.015954  2M: 0.017320  1M: 0.024896 merged: 0.000000
epoch: 25 [2017-11-21 18:36:38]
lr 0.044721359549995794
 16M: 0.032872  8M: 0.018717  4M: 0.015619  2M: 0.017037  1M: 0.024430 merged: 0.000000
epoch: 26 [2017-11-21 18:39:35]
lr 0.038729833462074176
 16M: 0.025604  8M: 0.017148  4M: 0.014860  2M: 0.016310  1M: 0.023249 merged: 0.000000
epoch: 27 [2017-11-21 18:42:32]
lr 0.0316227766016838
 16M: 0.020868  8M: 0.016170  4M: 0.014002  2M: 0.015723  1M: 0.022338 merged: 0.000000
epoch: 28 [2017-11-21 18:45:28]
lr 0.022360679774997897
 16M: 0.018587  8M: 0.015008  4M: 0.013719  2M: 0.015476  1M: 0.021973 merged: 0.000000
epoch: 29 [2017-11-21 18:48:26]
lr 1e-08
 16M: 0.016492  8M: 0.014369  4M: 0.012966  2M: 0.015083  1M: 0.021118 merged: 0.000000
epoch: 30 [2017-11-21 18:51:21]
lr 0.05
 16M: 0.022124  8M: 0.018213  4M: 0.015385  2M: 0.017013  1M: 0.023356 merged: 0.038542
epoch: 31 [2017-11-21 18:55:27]
lr 0.04971830761761256
 16M: 0.021922  8M: 0.018095  4M: 0.015973  2M: 0.017245  1M: 0.023638 merged: 0.027641
epoch: 32 [2017-11-21 18:59:35]
lr 0.04943501011144937
 16M: 0.019416  8M: 0.016438  4M: 0.014864  2M: 0.016753  1M: 0.023488 merged: 0.020799
epoch: 33 [2017-11-21 19:03:39]
lr 0.04915007972606608
 16M: 0.019216  8M: 0.016630  4M: 0.014155  2M: 0.016864  1M: 0.022710 merged: 0.019291
epoch: 34 [2017-11-21 19:07:51]
lr 0.04886348789677424
 16M: 0.018492  8M: 0.016301  4M: 0.014606  2M: 0.017065  1M: 0.023636 merged: 0.018161
epoch: 35 [2017-11-21 19:12:01]
lr 0.04857520521621862
 16M: 0.017290  8M: 0.016718  4M: 0.014497  2M: 0.016622  1M: 0.022670 merged: 0.017816
epoch: 36 [2017-11-21 19:16:09]
lr 0.04828520139915856
 16M: 0.016038  8M: 0.015055  4M: 0.013223  2M: 0.015692  1M: 0.021332 merged: 0.016487
epoch: 37 [2017-11-21 19:20:20]
lr 0.047993445245333805
 16M: 0.015453  8M: 0.014725  4M: 0.013470  2M: 0.015870  1M: 0.021881 merged: 0.015857
epoch: 38 [2017-11-21 19:24:23]
lr 0.0476999046002862
 16M: 0.014672  8M: 0.014155  4M: 0.012799  2M: 0.015388  1M: 0.021366 merged: 0.015122
epoch: 39 [2017-11-21 19:28:29]
lr 0.04740454631399772
 16M: 0.014860  8M: 0.014153  4M: 0.012944  2M: 0.015251  1M: 0.021855 merged: 0.015261
epoch: 40 [2017-11-21 19:32:36]
lr 0.04710733619719444
 16M: 0.013886  8M: 0.013687  4M: 0.012527  2M: 0.014773  1M: 0.020819 merged: 0.014336
epoch: 41 [2017-11-21 19:36:43]
lr 0.04680823897515326
 16M: 0.013763  8M: 0.013289  4M: 0.012293  2M: 0.014534  1M: 0.020972 merged: 0.013735
epoch: 42 [2017-11-21 19:40:49]
lr 0.04650721823883479
 16M: 0.013005  8M: 0.013008  4M: 0.012193  2M: 0.014513  1M: 0.020582 merged: 0.013305
epoch: 43 [2017-11-21 19:45:00]
lr 0.046204236393150765
 16M: 0.012825  8M: 0.012773  4M: 0.012005  2M: 0.014245  1M: 0.020763 merged: 0.013200
epoch: 44 [2017-11-21 19:49:09]
lr 0.045899254602157845
 16M: 0.012724  8M: 0.012916  4M: 0.011939  2M: 0.014419  1M: 0.020391 merged: 0.013172
epoch: 45 [2017-11-21 19:53:24]
lr 0.04559223273095164
 16M: 0.012146  8M: 0.012518  4M: 0.011897  2M: 0.014000  1M: 0.020469 merged: 0.012681
epoch: 46 [2017-11-21 19:57:37]
lr 0.045283129284014914
 16M: 0.012282  8M: 0.012238  4M: 0.011748  2M: 0.014089  1M: 0.020646 merged: 0.012719
epoch: 47 [2017-11-21 20:01:44]
lr 0.04497190133975169
 16M: 0.011412  8M: 0.011768  4M: 0.011041  2M: 0.013204  1M: 0.019238 merged: 0.011930
epoch: 48 [2017-11-21 20:05:53]
lr 0.04465850448091506
 16M: 0.011412  8M: 0.011924  4M: 0.011181  2M: 0.013568  1M: 0.019814 merged: 0.012095
epoch: 49 [2017-11-21 20:10:03]
lr 0.044342892720609255
 16M: 0.011306  8M: 0.011555  4M: 0.010983  2M: 0.013075  1M: 0.019263 merged: 0.011791
epoch: 50 [2017-11-21 20:14:18]
lr 0.044025018423517
 16M: 0.011448  8M: 0.011774  4M: 0.011016  2M: 0.013374  1M: 0.019346 merged: 0.011891
epoch: 51 [2017-11-21 20:18:26]
lr 0.04370483222197017
 16M: 0.010843  8M: 0.011562  4M: 0.011008  2M: 0.013061  1M: 0.019172 merged: 0.011636
epoch: 52 [2017-11-21 20:22:37]
lr 0.043382282926444894
 16M: 0.010411  8M: 0.011265  4M: 0.010755  2M: 0.012859  1M: 0.018840 merged: 0.011217
epoch: 53 [2017-11-21 20:26:50]
lr 0.04305731743002185
 16M: 0.010283  8M: 0.010835  4M: 0.010492  2M: 0.012711  1M: 0.018540 merged: 0.010985
epoch: 54 [2017-11-21 20:31:01]
lr 0.04272988060630656
 16M: 0.010657  8M: 0.011286  4M: 0.010590  2M: 0.012953  1M: 0.018833 merged: 0.011315
epoch: 55 [2017-11-21 20:35:12]
lr 0.04239991520025441
 16M: 0.010793  8M: 0.011334  4M: 0.010820  2M: 0.012701  1M: 0.018087 merged: 0.011319
epoch: 56 [2017-11-21 20:39:23]
lr 0.0420673617112877
 16M: 0.010005  8M: 0.010686  4M: 0.010380  2M: 0.012595  1M: 0.018433 merged: 0.010548
epoch: 57 [2017-11-21 20:43:36]
lr 0.041732158268029534
 16M: 0.009874  8M: 0.010603  4M: 0.010121  2M: 0.012324  1M: 0.018000 merged: 0.010492
epoch: 58 [2017-11-21 20:47:45]
lr 0.041394240493907074
 16M: 0.009932  8M: 0.010648  4M: 0.010375  2M: 0.012547  1M: 0.018573 merged: 0.010424
epoch: 59 [2017-11-21 20:51:57]
lr 0.041053541362798006
 16M: 0.009747  8M: 0.010781  4M: 0.010537  2M: 0.012542  1M: 0.018476 merged: 0.010672
epoch: 60 [2017-11-21 20:56:09]
lr 0.04070999104380296
 16M: 0.009928  8M: 0.010386  4M: 0.010189  2M: 0.012221  1M: 0.017839 merged: 0.010272
epoch: 61 [2017-11-21 21:00:15]
lr 0.04036351673412598
 16M: 0.009476  8M: 0.010711  4M: 0.010084  2M: 0.012480  1M: 0.018517 merged: 0.010437
epoch: 62 [2017-11-21 21:04:24]
lr 0.04001404247893005
 16M: 0.009693  8M: 0.010509  4M: 0.010043  2M: 0.012105  1M: 0.017715 merged: 0.010278
epoch: 63 [2017-11-21 21:08:34]
lr 0.03966148897690515
 16M: 0.009763  8M: 0.010764  4M: 0.010162  2M: 0.012311  1M: 0.017481 merged: 0.010747
epoch: 64 [2017-11-21 21:12:44]
lr 0.03930577337013889
 16M: 0.009434  8M: 0.010562  4M: 0.010012  2M: 0.012309  1M: 0.017693 merged: 0.010198
epoch: 65 [2017-11-21 21:16:57]
lr 0.038946809016712394
 16M: 0.009232  8M: 0.010286  4M: 0.009878  2M: 0.012294  1M: 0.017615 merged: 0.009975
epoch: 66 [2017-11-21 21:21:08]
lr 0.03858450524425343
 16M: 0.009056  8M: 0.010161  4M: 0.009853  2M: 0.011986  1M: 0.017395 merged: 0.009788
epoch: 67 [2017-11-21 21:25:21]
lr 0.03821876708246056
 16M: 0.009225  8M: 0.010112  4M: 0.009872  2M: 0.012016  1M: 0.017246 merged: 0.009890
epoch: 68 [2017-11-21 21:29:33]
lr 0.03784949497236286
 16M: 0.008895  8M: 0.010096  4M: 0.009709  2M: 0.012036  1M: 0.017819 merged: 0.009856
epoch: 69 [2017-11-21 21:33:44]
lr 0.03747658444979307
 16M: 0.008942  8M: 0.010018  4M: 0.009573  2M: 0.011817  1M: 0.017296 merged: 0.009809
epoch: 70 [2017-11-21 21:37:54]
lr 0.0370999258002226
 16M: 0.008769  8M: 0.009861  4M: 0.009766  2M: 0.011811  1M: 0.017191 merged: 0.009529
epoch: 71 [2017-11-21 21:42:04]
lr 0.03671940368172628
 16M: 0.008703  8M: 0.009811  4M: 0.009457  2M: 0.011644  1M: 0.017166 merged: 0.009435
epoch: 72 [2017-11-21 21:46:15]
lr 0.03633489671240478
 16M: 0.008753  8M: 0.009786  4M: 0.009539  2M: 0.011710  1M: 0.017429 merged: 0.009353
epoch: 73 [2017-11-21 21:50:25]
lr 0.03594627701808178
 16M: 0.008480  8M: 0.009684  4M: 0.009510  2M: 0.011665  1M: 0.017339 merged: 0.009260
epoch: 74 [2017-11-21 21:54:36]
lr 0.035553409735498295
 16M: 0.008514  8M: 0.009708  4M: 0.009408  2M: 0.011617  1M: 0.017133 merged: 0.009273
epoch: 75 [2017-11-21 21:58:50]
lr 0.03515615246553262
 16M: 0.008355  8M: 0.009308  4M: 0.009088  2M: 0.011312  1M: 0.016694 merged: 0.009009
epoch: 76 [2017-11-21 22:03:03]
lr 0.03475435467016077
 16M: 0.008242  8M: 0.009282  4M: 0.009121  2M: 0.011377  1M: 0.016863 merged: 0.009023
epoch: 77 [2017-11-21 22:07:19]
lr 0.034347857005916346
 16M: 0.008333  8M: 0.009456  4M: 0.009206  2M: 0.011625  1M: 0.017102 merged: 0.009150
epoch: 78 [2017-11-21 22:11:30]
lr 0.0339364905854808
 16M: 0.008502  8M: 0.009617  4M: 0.009393  2M: 0.011415  1M: 0.016939 merged: 0.009198
epoch: 79 [2017-11-21 22:15:47]
lr 0.03352007615769955
 16M: 0.008494  8M: 0.009544  4M: 0.009475  2M: 0.011619  1M: 0.017156 merged: 0.009121
epoch: 80 [2017-11-21 22:20:00]
lr 0.03309842319473132
 16M: 0.008335  8M: 0.009482  4M: 0.009239  2M: 0.011446  1M: 0.017225 merged: 0.009011
epoch: 81 [2017-11-21 22:24:10]
lr 0.03267132887314317
 16M: 0.008037  8M: 0.009344  4M: 0.009153  2M: 0.011348  1M: 0.016760 merged: 0.008769
epoch: 82 [2017-11-21 22:28:26]
lr 0.03223857693349118
 16M: 0.007989  8M: 0.009211  4M: 0.008925  2M: 0.011155  1M: 0.016687 merged: 0.008720
epoch: 83 [2017-11-21 22:32:39]
lr 0.0317999364001908
 16M: 0.008008  8M: 0.009418  4M: 0.009031  2M: 0.011291  1M: 0.016423 merged: 0.008920
epoch: 84 [2017-11-21 22:36:51]
lr 0.031355160140170396
 16M: 0.007725  8M: 0.008968  4M: 0.008778  2M: 0.010801  1M: 0.016089 merged: 0.008481
epoch: 85 [2017-11-21 22:41:06]
lr 0.03090398323477543
 16M: 0.007984  8M: 0.009002  4M: 0.008879  2M: 0.011153  1M: 0.016734 merged: 0.008551
epoch: 86 [2017-11-21 22:45:18]
lr 0.030446121134470178
 16M: 0.008011  8M: 0.008771  4M: 0.008635  2M: 0.010599  1M: 0.015854 merged: 0.008216
epoch: 87 [2017-11-21 22:49:28]
lr 0.02998126755983446
 16M: 0.008170  8M: 0.009216  4M: 0.008915  2M: 0.011068  1M: 0.016227 merged: 0.008790
epoch: 88 [2017-11-21 22:53:41]
lr 0.029509092104873926
 16M: 0.007762  8M: 0.008953  4M: 0.008797  2M: 0.011028  1M: 0.016331 merged: 0.008339
epoch: 89 [2017-11-21 22:57:57]
lr 0.029029237489356888
 16M: 0.007813  8M: 0.009186  4M: 0.008925  2M: 0.011188  1M: 0.016662 merged: 0.008522
epoch: 90 [2017-11-21 23:02:08]
lr 0.028541316395237167
 16M: 0.007715  8M: 0.008936  4M: 0.008888  2M: 0.010957  1M: 0.016022 merged: 0.008468
epoch: 91 [2017-11-21 23:06:21]
lr 0.028044907807525134
 16M: 0.007793  8M: 0.009079  4M: 0.008950  2M: 0.010997  1M: 0.016134 merged: 0.008383
epoch: 92 [2017-11-21 23:10:32]
lr 0.027539552761294706
 16M: 0.007734  8M: 0.008939  4M: 0.008749  2M: 0.010815  1M: 0.016025 merged: 0.008517
epoch: 93 [2017-11-21 23:14:45]
lr 0.027024749372597065
 16M: 0.007460  8M: 0.008890  4M: 0.008734  2M: 0.010819  1M: 0.016300 merged: 0.008165
epoch: 94 [2017-11-21 23:18:56]
lr 0.026499947000159004
 16M: 0.007500  8M: 0.008894  4M: 0.008692  2M: 0.010884  1M: 0.016010 merged: 0.008234
epoch: 95 [2017-11-21 23:23:07]
lr 0.02596453934447493
 16M: 0.007420  8M: 0.008701  4M: 0.008521  2M: 0.010699  1M: 0.015669 merged: 0.007946
epoch: 96 [2017-11-21 23:27:18]
lr 0.025417856237895775
 16M: 0.007512  8M: 0.008792  4M: 0.008609  2M: 0.010945  1M: 0.016531 merged: 0.008253
epoch: 97 [2017-11-21 23:31:26]
lr 0.02485915380880628
 16M: 0.007326  8M: 0.008657  4M: 0.008503  2M: 0.010574  1M: 0.015878 merged: 0.007975
epoch: 98 [2017-11-21 23:35:41]
lr 0.02428760260810931
 16M: 0.007221  8M: 0.008499  4M: 0.008522  2M: 0.010832  1M: 0.016202 merged: 0.007786
epoch: 99 [2017-11-21 23:39:55]
lr 0.02370227315699886
 16M: 0.007390  8M: 0.008626  4M: 0.008515  2M: 0.010824  1M: 0.016181 merged: 0.007892
epoch: 100 [2017-11-21 23:44:08]
lr 0.023102118196575382
 16M: 0.006962  8M: 0.008353  4M: 0.008424  2M: 0.010583  1M: 0.015693 merged: 0.007624
epoch: 101 [2017-11-21 23:48:18]
lr 0.022485950669875843
 16M: 0.007051  8M: 0.008357  4M: 0.008263  2M: 0.010582  1M: 0.015868 merged: 0.007641
epoch: 102 [2017-11-21 23:52:31]
lr 0.021852416110985085
 16M: 0.007210  8M: 0.008282  4M: 0.008306  2M: 0.010398  1M: 0.015483 merged: 0.007600
epoch: 103 [2017-11-21 23:56:46]
lr 0.021199957600127203
 16M: 0.007179  8M: 0.008351  4M: 0.008358  2M: 0.010595  1M: 0.015657 merged: 0.007810
epoch: 104 [2017-11-22 00:00:59]
lr 0.020526770681399003
 16M: 0.007121  8M: 0.008295  4M: 0.008283  2M: 0.010439  1M: 0.015633 merged: 0.007604
epoch: 105 [2017-11-22 00:05:11]
lr 0.019830744488452574
 16M: 0.006943  8M: 0.008263  4M: 0.008209  2M: 0.010498  1M: 0.015532 merged: 0.007467
epoch: 106 [2017-11-22 00:09:24]
lr 0.01910938354123028
 16M: 0.006987  8M: 0.008187  4M: 0.008122  2M: 0.010184  1M: 0.015366 merged: 0.007487
epoch: 107 [2017-11-22 00:13:36]
lr 0.01835970184086314
 16M: 0.007025  8M: 0.008279  4M: 0.008280  2M: 0.010439  1M: 0.015539 merged: 0.007589
epoch: 108 [2017-11-22 00:17:46]
lr 0.01757807623276631
 16M: 0.006764  8M: 0.008133  4M: 0.008215  2M: 0.010525  1M: 0.015705 merged: 0.007387
epoch: 109 [2017-11-22 00:22:00]
lr 0.016760038078849775
 16M: 0.007118  8M: 0.008378  4M: 0.008286  2M: 0.010343  1M: 0.015723 merged: 0.007476
epoch: 110 [2017-11-22 00:26:12]
lr 0.0158999682000954
 16M: 0.007036  8M: 0.008359  4M: 0.008363  2M: 0.010724  1M: 0.015657 merged: 0.007645
epoch: 111 [2017-11-22 00:30:24]
lr 0.01499063377991723
 16M: 0.006935  8M: 0.008100  4M: 0.007990  2M: 0.010061  1M: 0.014913 merged: 0.007371
epoch: 112 [2017-11-22 00:34:34]
lr 0.014022453903762567
 16M: 0.006779  8M: 0.008175  4M: 0.008221  2M: 0.010435  1M: 0.015662 merged: 0.007369
epoch: 113 [2017-11-22 00:38:47]
lr 0.012982269672237465
 16M: 0.006684  8M: 0.008037  4M: 0.007953  2M: 0.009953  1M: 0.014647 merged: 0.007171
epoch: 114 [2017-11-22 00:43:00]
lr 0.01185113657849943
 16M: 0.006644  8M: 0.007908  4M: 0.007853  2M: 0.009909  1M: 0.014916 merged: 0.007126
epoch: 115 [2017-11-22 00:47:13]
lr 0.010599978800063602
 16M: 0.006485  8M: 0.007908  4M: 0.007938  2M: 0.010096  1M: 0.015258 merged: 0.007044
epoch: 116 [2017-11-22 00:51:26]
lr 0.00917985092043157
 16M: 0.006728  8M: 0.008152  4M: 0.008177  2M: 0.010354  1M: 0.015436 merged: 0.007253
epoch: 117 [2017-11-22 00:55:43]
lr 0.007495316889958615
 16M: 0.006650  8M: 0.007991  4M: 0.007988  2M: 0.010128  1M: 0.014857 merged: 0.007189
epoch: 118 [2017-11-22 00:59:57]
lr 0.005299989400031801
 16M: 0.006587  8M: 0.007870  4M: 0.007945  2M: 0.010126  1M: 0.015081 merged: 0.007081
epoch: 119 [2017-11-22 01:04:10]
lr 1e-08
 16M: 0.006572  8M: 0.007956  4M: 0.008137  2M: 0.010420  1M: 0.015518 merged: 0.007040
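
Resuming from a snapshot (a minimal sketch; the filename/epoch below are assumptions, point it at whichever snapshot-*.pth.tar the loop actually wrote):


In [ ]:
# Restore the model/optimizer state saved by the training loop above.
checkpoint = torch.load('snapshot{}/snapshot-119.pth.tar'.format(args.gpu_num))
net.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1  # resume at the following epoch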

Visualize Graph


In [ ]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: optional dict of (name, Variable) used to label nodes
            that require grad
    """
    if params is not None:
        # dict.values() is not indexable in Python 3; take any value via iter()
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [ ]:
x = Variable(torch.zeros(1, 3, 256, 256))
# net consumes backbone features (as in the training loop above), so run the
# pretrained feature extractor first.
y = net(pretrained(x.cuda()))
g = make_dot(y[-1])

In [ ]:
g.render('net')  # writes net.svg (format='svg' was set in make_dot)

In [ ]: