In [5]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefolder import MyImageFolder
from mymodel_gpu2 import PreTrainedModel, GradientNet
from myargs import Args

Configurations


In [6]:
args = Args()
args.test_scene = 'alley_1'
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 2
args.display_interval = 50
args.display_curindex = 0

system_ = platform.system()
system_dist = platform.dist()  # (distname, version, id)
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif system_dist == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif system_dist == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

use_gpu = (system_ == 'Linux')
if use_gpu:
    torch.cuda.set_device(args.gpu_num)
    

print(platform.dist())


('debian', 'jessie/sid', '')

My DataLoader


In [7]:
train_dataset = MyImageFolder(args.train_dir, 'train',
                       transforms.Compose(
        [transforms.ToTensor()]
    ), random_crop=True, 
    img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test', 
                       transforms.Compose(
        [transforms.CenterCrop((args.image_h, args.image_w)),
         transforms.ToTensor()]
    ), random_crop=False,
    img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)

train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)
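
A quick sanity check before training: pull a single batch and look at what MyImageFolder yields. This is a sketch, assuming each batch unpacks the same way the training loop below does (input image, albedo and shading ground truth, scene name, image path); the printed shapes should each be 1 x 3 x 256 x 256 after the crop.

input_img, gt_albedo, gt_shading, test_scene, img_path = next(iter(train_loader))
print(input_img.size(), gt_albedo.size(), gt_shading.size())
print(test_scene[0], img_path[0])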

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First convolution: 32M -> 16M -> 8M
    • Every transition: 8M -> 4M -> 2M (each halves the resolution, except after the last block); the sketch below traces these sizes
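
The sketch below (an illustration, not part of the original run) traces a 256x256 input through torchvision's densenet121.features to confirm the halving listed above; it assumes the stock torchvision DenseNet layout (conv0/pool0 followed by alternating dense blocks and transitions).

probe = models.densenet121(pretrained=False)
x = Variable(torch.zeros(1, 3, 256, 256), volatile=True)  # inference only (pre-0.4 API)
for name, layer in probe.features.named_children():
    x = layer(x)
    print(name, tuple(x.size()[2:]))  # spatial size after each stage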

In [8]:
densenet = models.__dict__[args.arch](pretrained=args.pretrained)

for param in densenet.parameters():
    param.requires_grad = False

if use_gpu: densenet.cuda()
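
As a sanity check (a sketch using the standard nn.Module API), confirm the backbone is fully frozen before building the trainable head:

print(sum(p.data.numel() for p in densenet.parameters() if p.requires_grad))  # expected: 0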

In [9]:
ss = 6

args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
args.training_thresholds = [ss*4,ss*3,ss*2,ss*1,ss*0,ss*5]
args.power = 0.5



pretrained = PreTrainedModel(densenet)
if use_gpu: 
    pretrained.cuda()
    
net = GradientNet(pretrained)
if use_gpu: 
    net.cuda()

if use_gpu:
    # one criterion per output scale; MSELoss is stateless, so the repeated references are harmless
    mse_losses = [nn.MSELoss().cuda()] * 6
    test_losses = [nn.MSELoss().cuda()] * 6
else:
    mse_losses = [nn.MSELoss()] * 6
    test_losses = [nn.MSELoss()] * 6

In [10]:
# training loop
writer = SummaryWriter()
writer.add_text('training', 'reduce training time')

# optimize only parameters that still require grad (the DenseNet backbone was frozen above)
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Poly decay: lr = base_lr * ((end - epoch) / (end - beg)) ** power, floored at 1e-8.
    If reset_lr is given, the learning rate is simply reset to that value instead."""
    for param_group in optimizer.param_groups:
        if reset_lr is not None:
            param_group['lr'] = reset_lr
            continue
        if epoch != 0: 
            # linear decay (unused):
#             param_group['lr'] *= (end-epoch) / (end-beg)
            # poly: base_lr * (1 - iter/max_iter) ** power
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
            if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        print('lr', param_group['lr'])
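# Worked example of the poly schedule above (these values match the 'lr' lines in
# the output below): with base_lr=0.05, power=0.5 and a 6-epoch cycle (beg=0, end=5),
#   epoch 0 -> 0.05, 1 -> 0.0447, 2 -> 0.0387, 3 -> 0.0316, 4 -> 0.0224,
#   epoch 5 -> 0.0, which is floored to 1e-8.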
        
# def findLargerInd(target, arr):
#     res = list(filter(lambda x: x>target, arr))
#     print('res',res)
#     if len(res) == 0: return -1
#     return res[0]

for epoch in range(args.epoches):
    net.train()
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    
    if epoch < args.training_thresholds[-1]: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)
    
    run_losses = [0] * len(args.training_thresholds)
    run_cnts   = [0.00001] * len(args.training_thresholds)
    if epoch in args.training_thresholds: 
        adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
#         if  ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        # CHW -> HWC, RGB -> BGR, scale to [0, 255] for cv2.imwrite
        im = input_img[0, :, :, :].numpy()
        im = im.transpose(1, 2, 0)
        im = im[:, :, ::-1] * 255
        
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
#         run_losses = [0] * len(mse_losses)
#         run_cnts = [0.00001] * len(mse_losses)
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot2/input.png', im)

        optimizer.zero_grad()
        ft_predict = net(input_img)
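        # ft_predict holds six outputs: indices 0-4 at 1/2 .. 1/32 of the input
        # resolution (logged below as 16M .. 1M) and index 5 the full-resolution
        # merged prediction. With ss=6, training_thresholds = [24, 18, 12, 6, 0, 30],
        # so the coarsest scale trains from epoch 0, a finer scale is switched on
        # every ss epochs, and the merged output starts last (epoch 30).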
        for i, threshold in enumerate(args.training_thresholds):
#             threshold = args.training_thresholds[i]
            if epoch >= threshold:
#             if epoch >= 0:
                if i == 5: s = 1
                else: s = (2**(i+1))
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                gt = cv2.resize(gt, (w//s, h//s))  # cv2.resize expects (width, height); h == w here
#                 gt = cv2.resize(gt, (h,w))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot2/gt-{}-{}.png'.format(epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                loss_data = loss.data.cpu().numpy()
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                ma_ = ft_predict[i].max().cpu().data.numpy()
                mi_ = ft_predict[i].min().cpu().data.numpy()
                #print('mi', mi_, 'ma', ma_)
#                 writer.add_scalars('{}th train predict'.format(i), {'max': ma_, 'min': mi_}, global_step=args.display_curindex)
#                 run_cnts[i] += 1
                run_losses[i] += loss.data.cpu().numpy()[0]
                loss.backward(retain_graph=True)
                run_cnts[i] += 1
#                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
#                 print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot2/train-{}-{}.png'.format(epoch, i), im[:,:,::-1])
        optimizer.step()
        args.display_curindex += 1

    """ every epoch """
#     loss_output = 'ind: ' + str(args.display_curindex)
    loss_output = ''
    
    
    
    for i,v in enumerate(run_losses):
        if i == len(run_losses)-1: 
            loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
            continue
        loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
    print(loss_output)
    # save a snapshot every 10 epochs
    if (epoch+1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'args' : args,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'snapshot2/snapshot-{}.pth.tar'.format(epoch))
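        # To resume from a saved snapshot later (a sketch, not executed in this run):
        #   ckpt = torch.load('snapshot2/snapshot-<epoch>.pth.tar')
        #   net.load_state_dict(ckpt['state_dict'])
        #   optimizer.load_state_dict(ckpt['optimizer'])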
    
    # test: evaluate on the held-out test scene while the network is still in train() mode
    test_losses_trainphase = [0] * len(args.training_thresholds)
    test_cnts_trainphase   = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        ft_test = net(input_img)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (w//s, h//s))  # (width, height); h == w here
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot2/test-phase_train-{}-{}.png'.format(epoch, i), v[:,:,::-1]*255)

    
    # evaluate again with the network switched to eval() mode
    net.eval()
    test_losses = [0] * len(args.training_thresholds)
    test_cnts   = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
#         if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        ft_test = net(input_img)
            
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (w//s, h//s))  # (width, height); h == w here
#             gt = cv2.resize(gt, (h,w))
            
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()

            loss = mse_losses[i](ft_test[i], gt)
            
            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot2/test-phase_test-{}-{}.png'.format(epoch, i), v[:,:,::-1]*255)
    
    scale_names = ['16M', '8M', '4M', '2M', '1M', 'merged']
    for i, name in enumerate(scale_names):
        writer.add_scalars('{} loss'.format(name), {
            'train {} '.format(name): np.array([run_losses[i] / run_cnts[i]]),
            'test_trainphase {} '.format(name): np.array([test_losses_trainphase[i] / test_cnts_trainphase[i]]),
            'test {} '.format(name): np.array([test_losses[i] / test_cnts[i]])
        }, global_step=epoch)


epoch: 0 [2017-11-19 14:02:25]
lr 0.05
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.053698 merged: 0.000000
epoch: 1 [2017-11-19 14:03:48]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.049098 merged: 0.000000
epoch: 2 [2017-11-19 14:05:10]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.047099 merged: 0.000000
epoch: 3 [2017-11-19 14:06:31]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.045852 merged: 0.000000
epoch: 4 [2017-11-19 14:07:53]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.043947 merged: 0.000000
epoch: 5 [2017-11-19 14:09:15]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.000000  1M: 0.040625 merged: 0.000000
epoch: 6 [2017-11-19 14:10:36]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.048388  1M: 0.045034 merged: 0.000000
epoch: 7 [2017-11-19 14:12:07]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.044077  1M: 0.043985 merged: 0.000000
epoch: 8 [2017-11-19 14:13:39]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.040370  1M: 0.039522 merged: 0.000000
epoch: 9 [2017-11-19 14:15:13]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.034961  1M: 0.035874 merged: 0.000000
epoch: 10 [2017-11-19 14:16:45]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.027672  1M: 0.030514 merged: 0.000000
epoch: 11 [2017-11-19 14:18:18]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.000000  2M: 0.024153  1M: 0.028282 merged: 0.000000
epoch: 12 [2017-11-19 14:19:50]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.048549  2M: 0.030359  1M: 0.033572 merged: 0.000000
epoch: 13 [2017-11-19 14:21:33]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.000000  4M: 0.040085  2M: 0.026624  1M: 0.032146 merged: 0.000000
epoch: 14 [2017-11-19 14:23:16]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.000000  4M: 0.028554  2M: 0.023017  1M: 0.028174 merged: 0.000000
epoch: 15 [2017-11-19 14:24:59]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.000000  4M: 0.022348  2M: 0.020591  1M: 0.027038 merged: 0.000000
epoch: 16 [2017-11-19 14:26:38]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.000000  4M: 0.019428  2M: 0.018925  1M: 0.025718 merged: 0.000000
epoch: 17 [2017-11-19 14:28:21]
lr 1e-08
 16M: 0.000000  8M: 0.000000  4M: 0.018818  2M: 0.019551  1M: 0.024946 merged: 0.000000
epoch: 18 [2017-11-19 14:30:03]
lr 1e-08
 16M: 0.000000  8M: 0.044768  4M: 0.022065  2M: 0.020874  1M: 0.027948 merged: 0.000000
epoch: 19 [2017-11-19 14:31:51]
lr 0.044721359549995794
 16M: 0.000000  8M: 0.034654  4M: 0.019827  2M: 0.020563  1M: 0.026426 merged: 0.000000
epoch: 20 [2017-11-19 14:33:42]
lr 0.038729833462074176
 16M: 0.000000  8M: 0.026419  4M: 0.018928  2M: 0.019381  1M: 0.026495 merged: 0.000000
epoch: 21 [2017-11-19 14:35:36]
lr 0.0316227766016838
 16M: 0.000000  8M: 0.021048  4M: 0.016542  2M: 0.017964  1M: 0.024340 merged: 0.000000
epoch: 22 [2017-11-19 14:37:28]
lr 0.022360679774997897
 16M: 0.000000  8M: 0.017708  4M: 0.015085  2M: 0.017072  1M: 0.023438 merged: 0.000000
epoch: 23 [2017-11-19 14:39:22]
lr 1e-08
 16M: 0.000000  8M: 0.018498  4M: 0.015039  2M: 0.016562  1M: 0.022912 merged: 0.000000
epoch: 24 [2017-11-19 14:41:14]
lr 1e-08
 16M: 0.043711  8M: 0.019333  4M: 0.016371  2M: 0.018143  1M: 0.024678 merged: 0.000000
epoch: 25 [2017-11-19 14:43:23]
lr 0.044721359549995794
 16M: 0.032719  8M: 0.018697  4M: 0.015621  2M: 0.017714  1M: 0.024652 merged: 0.000000
epoch: 26 [2017-11-19 14:45:35]
lr 0.038729833462074176
 16M: 0.023544  8M: 0.015938  4M: 0.014629  2M: 0.016777  1M: 0.023396 merged: 0.000000
epoch: 27 [2017-11-19 14:47:45]
lr 0.0316227766016838
 16M: 0.019812  8M: 0.014990  4M: 0.013943  2M: 0.016474  1M: 0.023183 merged: 0.000000
epoch: 28 [2017-11-19 14:49:56]
lr 0.022360679774997897
 16M: 0.017507  8M: 0.013819  4M: 0.013274  2M: 0.015694  1M: 0.021792 merged: 0.000000
epoch: 29 [2017-11-19 14:52:09]
lr 1e-08
 16M: 0.016289  8M: 0.013107  4M: 0.013536  2M: 0.016221  1M: 0.022879 merged: 0.000000
epoch: 30 [2017-11-19 14:54:17]
lr 0.05
 16M: 0.022719  8M: 0.016274  4M: 0.014745  2M: 0.017533  1M: 0.023519 merged: 0.036252
epoch: 31 [2017-11-19 14:57:42]
lr 0.04971830761761256
 16M: 0.021270  8M: 0.018308  4M: 0.016725  2M: 0.017940  1M: 0.023363 merged: 0.025193
epoch: 32 [2017-11-19 15:01:01]
lr 0.04943501011144937
 16M: 0.019715  8M: 0.017205  4M: 0.015934  2M: 0.018395  1M: 0.022814 merged: 0.020712
epoch: 33 [2017-11-19 15:04:23]
lr 0.04915007972606608
 16M: 0.018539  8M: 0.017524  4M: 0.014964  2M: 0.018506  1M: 0.023932 merged: 0.019226
epoch: 34 [2017-11-19 15:07:48]
lr 0.04886348789677424
 16M: 0.017740  8M: 0.016407  4M: 0.014609  2M: 0.018155  1M: 0.023205 merged: 0.017715
epoch: 35 [2017-11-19 15:11:08]
lr 0.04857520521621862
 16M: 0.015536  8M: 0.015027  4M: 0.014143  2M: 0.017352  1M: 0.022689 merged: 0.015889
epoch: 36 [2017-11-19 15:14:22]
lr 0.04828520139915856
 16M: 0.014843  8M: 0.013759  4M: 0.013520  2M: 0.016368  1M: 0.022217 merged: 0.015148
epoch: 37 [2017-11-19 15:17:47]
lr 0.047993445245333805
 16M: 0.014486  8M: 0.013551  4M: 0.013405  2M: 0.016112  1M: 0.022671 merged: 0.014400
epoch: 38 [2017-11-19 15:21:09]
lr 0.0476999046002862
 16M: 0.014854  8M: 0.014854  4M: 0.013813  2M: 0.016176  1M: 0.021649 merged: 0.014732
epoch: 39 [2017-11-19 15:24:35]
lr 0.04740454631399772
 16M: 0.013269  8M: 0.013321  4M: 0.013184  2M: 0.016212  1M: 0.021504 merged: 0.013580
epoch: 40 [2017-11-19 15:28:11]
lr 0.04710733619719444
 16M: 0.012947  8M: 0.012827  4M: 0.012850  2M: 0.015847  1M: 0.022193 merged: 0.012965
epoch: 41 [2017-11-19 15:31:35]
lr 0.04680823897515326
 16M: 0.012462  8M: 0.012205  4M: 0.012345  2M: 0.015217  1M: 0.020995 merged: 0.012532
epoch: 42 [2017-11-19 15:35:09]
lr 0.04650721823883479
 16M: 0.012697  8M: 0.012210  4M: 0.012299  2M: 0.015247  1M: 0.021297 merged: 0.012465
epoch: 43 [2017-11-19 15:38:37]
lr 0.046204236393150765
 16M: 0.011966  8M: 0.011831  4M: 0.011867  2M: 0.014899  1M: 0.020445 merged: 0.011556
epoch: 44 [2017-11-19 15:42:01]
lr 0.045899254602157845
 16M: 0.011948  8M: 0.012290  4M: 0.012508  2M: 0.014647  1M: 0.020463 merged: 0.012135
epoch: 45 [2017-11-19 15:45:28]
lr 0.04559223273095164
 16M: 0.011076  8M: 0.011239  4M: 0.011536  2M: 0.014089  1M: 0.020718 merged: 0.011034
epoch: 46 [2017-11-19 15:49:02]
lr 0.045283129284014914
 16M: 0.011146  8M: 0.011562  4M: 0.011826  2M: 0.014383  1M: 0.020694 merged: 0.011201
epoch: 47 [2017-11-19 15:52:29]
lr 0.04497190133975169
 16M: 0.010871  8M: 0.011062  4M: 0.011495  2M: 0.013977  1M: 0.019510 merged: 0.010787
epoch: 48 [2017-11-19 15:55:53]
lr 0.04465850448091506
 16M: 0.010564  8M: 0.011199  4M: 0.011474  2M: 0.013800  1M: 0.019565 merged: 0.010816
epoch: 49 [2017-11-19 15:59:19]
lr 0.044342892720609255
 16M: 0.010017  8M: 0.010419  4M: 0.011153  2M: 0.013595  1M: 0.019576 merged: 0.010157
epoch: 50 [2017-11-19 16:02:36]
lr 0.044025018423517
 16M: 0.010545  8M: 0.010669  4M: 0.011050  2M: 0.013873  1M: 0.019319 merged: 0.010411
epoch: 51 [2017-11-19 16:06:01]
lr 0.04370483222197017
 16M: 0.009676  8M: 0.010225  4M: 0.010763  2M: 0.013342  1M: 0.019092 merged: 0.009784
epoch: 52 [2017-11-19 16:09:25]
lr 0.043382282926444894
 16M: 0.009651  8M: 0.010280  4M: 0.010657  2M: 0.013450  1M: 0.019637 merged: 0.009799
epoch: 53 [2017-11-19 16:12:46]
lr 0.04305731743002185
 16M: 0.009928  8M: 0.010322  4M: 0.010931  2M: 0.013374  1M: 0.019327 merged: 0.009906
epoch: 54 [2017-11-19 16:16:17]
lr 0.04272988060630656
Process Process-163:
Traceback (most recent call last):
  File "/home/lwp/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/lwp/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/lwp/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
    r = index_queue.get()
  File "/home/lwp/anaconda3/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/lwp/anaconda3/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/lwp/anaconda3/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/lwp/anaconda3/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-10-af2f6ef86bb2> in <module>()
     86 #                 run_cnts[i] += 1
     87                 run_losses[i] += loss.data.cpu().numpy()[0]
---> 88                 loss.backward(retain_graph=True)
     89                 run_cnts[i] += 1
     90 #                 print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())

~/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py in backward(self, gradient, retain_graph, create_graph, retain_variables)
    154                 Variable.
    155         """
--> 156         torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
    157 
    158     def register_hook(self, hook):

~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(variables, grad_variables, retain_graph, create_graph, retain_variables)
     96 
     97     Variable._execution_engine.run_backward(
---> 98         variables, grad_variables, retain_graph)
     99 
    100 

KeyboardInterrupt: 

Visualize Graph


In [9]:
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """Produces a Graphviz representation of the PyTorch autograd graph.
    Blue nodes are Variables that require grad; orange nodes are Tensors
    saved for backward in torch.autograd.Function.
    Args:
        var: output Variable
        params: optional dict of (name, Variable) used to label nodes that
            require grad
    """
    if params is not None:
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

In [10]:
x = Variable(torch.zeros(1,3,256,256))
y = net(x.cuda())
g = make_dot(y[-1])

In [11]:
g.render('net')


Out[11]:
'net.svg'
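
make_dot can also label the parameter nodes: passing the model's named parameters (a sketch, assuming the same y as above) fills in the blue nodes' names.

g = make_dot(y[-1], params=dict(net.named_parameters()))
g.render('net_named')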

In [ ]: