In [ ]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
# from torch import functional as F
import torch.nn.functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefoldereccv import MyImageFolder
from mymodel import GradientNet
from myargs import Args
from myutils import MyUtils

Configurations


In [ ]:
# ---------------------------------------------------------------------------
# Experiment configuration: hyper-parameters, dataset locations, and
# host-dependent settings.  Everything hangs off the project's `Args`
# container so later cells can read/override it.
# ---------------------------------------------------------------------------
myutils = MyUtils()

args = Args()
args.arch = "densenet121"            # torchvision model name, resolved via models.__dict__
args.epoches = 500                   # [sic] "epochs" -- attribute name is part of the Args contract
args.epoches_unary_threshold = 0
args.image_h = 256                   # training crop size (overridden to 32x32 on macOS below)
args.image_w = 256
args.img_extentions = ["png"]        # [sic] "extensions" -- keyword expected by MyImageFolder
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1                     # NOTE(review): overwritten to 0.01 in a later cell
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True


# growth_rate = (4*(2**(args.gpu_num)))
transition_scale=2
pretrained_scale=4
growth_rate = 32                     # DenseNet-121 growth rate

#######
# Sintel scenes held out for evaluation.
args.test_scene = ['alley_1', 'bamboo_1', 'bandage_1', 'cave_2', 'market_2', 'market_6', 'shaman_2', 'sleeping_1', 'temple_2']
gradient=False                       # whether the net works in the gradient domain
args.gpu_num = 0
#######

writer_comment = 'eccv_albedo'       # suffix for the TensorBoard run directory


offset = 0.
# NOTE(review): `offset` is never referenced again in this notebook -- confirm
# whether it is still needed (display code adds 0.5 inline instead).
if gradient == True: offset = 0.5

args.display_interval = 50
args.display_curindex = 0

# Host detection decides dataset path, pretraining, and GPU usage.
# NOTE(review): platform.dist() was deprecated in Python 3.5 and removed in
# Python 3.8 -- this cell only runs on older interpreters; consider replacing.
system_ = platform.system()
system_dist, system_version, _ = platform.dist()  # NOTE(review): unpacked values are unused below
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
    args.image_w, args.image_h = 32, 32           # tiny crops for local CPU debugging
elif platform.dist() ==  ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/albertxavier/dataset/sintel2'
    args.pretrained = True
elif platform.dist() == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True

if platform.system() == 'Linux': use_gpu = True
else: use_gpu = False

if use_gpu:
    torch.cuda.set_device(args.gpu_num)
    

print(platform.dist())

My DataLoader


In [ ]:
# Datasets: random crops for training; a deterministic center crop for test.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([
    transforms.CenterCrop((args.image_h, args.image_w)),
    transforms.ToTensor(),
])

train_dataset = MyImageFolder(
    args.train_dir, 'train', train_transform,
    random_crop=True,
    img_extentions=args.img_extentions, test_scene=args.test_scene,
    image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(
    args.train_dir, 'test', test_transform,
    random_crop=False,
    img_extentions=args.img_extentions, test_scene=args.test_scene,
    image_h=args.image_h, image_w=args.image_w)

# Batch size 1 for both loaders (the original passed these positionally:
# dataset, batch_size, shuffle).
# NOTE(review): shuffle=True on the *test* loader is unusual -- confirm intended.
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)

Load Pretrained Model

Definition

  • DenseNet-121: num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16)
    • First Convolution: 32M -> 16M -> 8M
    • every transition: 8M -> 4M -> 2M (downsample 1/2, except the last block)

In [ ]:
# Backbone: instantiate the torchvision model named in args.arch (optionally
# with ImageNet weights) and freeze every parameter so only the new
# GradientNet layers will receive gradients.
densenet = models.__dict__[args.arch](pretrained=args.pretrained)

for frozen_param in densenet.parameters():
    frozen_param.requires_grad = False

if use_gpu:
    densenet.cuda()

In [ ]:
# Override earlier configuration for this training run.
args.display_curindex = 0
args.base_lr = 0.01
args.display_interval = 20
args.momentum = 0.9
args.epoches = int(60*4)
#args.training_thresholds = 240//4
args.power = 0.5              # exponent of the polynomial LR decay


# Build the full model around the frozen DenseNet backbone.
net = GradientNet(densenet=densenet, growth_rate=growth_rate,
                  transition_scale=transition_scale, pretrained_scale=pretrained_scale,
                  gradient=gradient)
if use_gpu:
    net.cuda()

# Loss modules (Module.cuda() returns the module itself, so this is
# equivalent to constructing them on the GPU directly).
mse_loss = nn.MSELoss()
mse_crf_loss = nn.MSELoss()
if use_gpu:
    mse_loss = mse_loss.cuda()
    mse_crf_loss = mse_crf_loss.cuda()

# Optimize only the trainable (non-frozen) parameters.
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)

In [ ]:
def generate_y(predict_unary, predict_dx, predict_dy, gt, predict_alpha, predict_beta, max_iter=100, eps=1.e-4, use_gpu=True, volatile=False):
    """Produce a CRF-refined estimate ``y`` by fixed-point iteration.

    Starts from a detached copy of ``predict_unary`` and repeatedly applies
    one smoothing/update step (``generate_y_``) until the mean-squared change
    between iterations falls below ``eps`` or ``max_iter`` steps are reached.

    NOTE(review): ``gt`` and ``volatile`` are accepted but never used in this
    body -- presumably kept for signature symmetry with ``crf_loss``; confirm.
    Shapes are assumed (not verifiable from this cell): unary/dx/dy are
    (N, 3, H, W); alpha is (N, 1, H, W); beta is (N, 4, H, W) -- TODO confirm.
    """
    def generate_y_(last_y, predict_unary, predict_dx, predict_dy, gt, predict_alpha, predict_beta, use_gpu=True):
        # One update step of the fixed-point iteration.
        def prepare_fileters(direction='up'):
            # [sic] "filters".  Builds a 3x3x3x3 kernel that, under conv2d
            # with padding=1, copies each pixel's neighbour in `direction`,
            # keeping the 3 channels independent (filters[i,i,...]).
            filters = torch.Tensor(torch.zeros(3,3,3,3))
            if direction == 'up':
                for i in range(3): filters[i,i,0,1] = 1.
            elif direction == 'down':
                for i in range(3): filters[i,i,2,1] = 1.
            elif direction == 'left':
                for i in range(3): filters[i,i,1,0] = 1.
            else:
                for i in range(3): filters[i,i,1,2] = 1.
            filters = Variable(filters)
            if use_gpu == True: filters = filters.cuda()
            return filters

        f_up = prepare_fileters(direction='up')
        f_down = prepare_fileters(direction='down')
        f_left = prepare_fileters(direction='left')
        f_right = prepare_fileters(direction='right')

        # Neighbour values of the current estimate.
        last_y_up = F.conv2d(last_y, f_up, padding=1)
        last_y_down = F.conv2d(last_y, f_down, padding=1)
        last_y_left = F.conv2d(last_y, f_left, padding=1)
        last_y_right = F.conv2d(last_y, f_right, padding=1)
        
        # Predicted gradients seen from each neighbour direction.
        t_up = F.conv2d(predict_dy, f_up, padding=1)
        t_down = -predict_dy
        t_left = F.conv2d(predict_dx, f_left, padding=1)
        t_right = -predict_dx
        
        # Per-direction pairwise weights (channel dim kept, shape (N,1,H,W)).
        beta_up = predict_beta[:,0:1,:,:]
        beta_down = predict_beta[:,1:2,:,:]
        beta_left = predict_beta[:,2:3,:,:]
        beta_right = predict_beta[:,3:4,:,:]
        
        sum_beta = beta_up + beta_down + beta_left + beta_right
        constant = predict_alpha + sum_beta
        #print('constant', constant)
        
        # y = (predict_alpha * predict_unary + \
        #     beta_up * (last_y_up + t_up) + \
        #     beta_down * (last_y_down + t_down) + \
        #      beta_left * (last_y_left + t_left) + beta_right * (last_y_right + t_right))/constant

        # NOTE(review): the commented formula above is the weighted-average
        # update (divide by `constant`, beta-weight the neighbour values).
        # The active code instead adds the raw neighbour values and divides
        # by the literal 5., leaving `constant` unused -- confirm which
        # variant is intended.
        y = predict_alpha * predict_unary
        y = y + last_y_up    + beta_up    * t_up
        y = y + last_y_down  + beta_down  * t_down
        y = y + last_y_left  + beta_left  * t_left
        y = y + last_y_right + beta_right * t_right
        y = y / 5.
        return y
    
    # Clone inputs so the iteration cannot alias the caller's tensors.
    predict_unary = predict_unary.clone()
    predict_dx = predict_dx.clone()
    predict_dy = predict_dy.clone()
    
    #y = Variable(predict_unary.data.cpu().clone()+torch.rand(predict_unary.size())/10., volatile=True).cuda()
    # Initial estimate: the unary prediction, detached from the graph
    # (wrapped via .data so gradients do not flow through the iteration start).
    y = Variable(predict_unary.data.clone())
        
    if use_gpu == True: y = y.cuda()
    iters = 0
    while 1:
        last_y = y.clone()
        y = generate_y_(y, predict_unary, predict_dx, predict_dy, gt, predict_alpha, predict_beta, use_gpu=use_gpu)
        cur_loss = myutils.mse_loss_scalar(y, last_y)
        if cur_loss <= eps: 
            #print('cur loss', cur_loss)
            #print('cur iter', iters)
            #print('y min', y.min(), 'max', y.max())
            break
        if iters >= max_iter: 
            #print('@break at max iter', cur_loss)
            break
        iters += 1
        #print('y min', y.min(), 'max', y.max())
        # NOTE(review): this unconditional break makes the loop run exactly
        # one iteration, regardless of eps/max_iter -- it looks like a
        # debugging leftover; confirm before removing.
        break
    return y

In [ ]:
def crf_loss(y, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta, volatile=False):
    """Per-pixel CRF energy terms for an estimate ``y``.

    Builds five squared-residual maps:
      * J1 -- unary term: ``y`` should match the alpha-weighted unary output;
      * J2..J5 -- pairwise terms: the difference between ``y`` and each of its
        four neighbours should match the (beta-weighted) predicted gradients
        ``predict_dx`` / ``predict_dy``.

    Returns the five maps concatenated along the channel axis (dim 1).

    NOTE(review): ``gt_albedo`` is unused; presumably kept for signature
    symmetry with ``generate_y`` -- confirm.
    NOTE(review): beta channels are sliced as ``[:, k, :, :]`` here (rank-3),
    whereas ``generate_y`` keeps the channel dim (``[:, k:k+1, :, :]``); with
    batch size 1 broadcasting makes both forms equivalent -- confirm.
    Reads the notebook-level ``use_gpu`` global.
    """
    def prepare_fileters(direction='up'):
        # [sic] "filters".  3x3x3x3 kernel that, under conv2d with padding=1,
        # fetches each pixel's neighbour in `direction`, channels independent.
        filters = torch.Tensor(torch.zeros(3,3,3,3))
        if direction == 'up':
            for i in range(3): filters[i,i,0,1] = 1.
        elif direction == 'down':
            for i in range(3): filters[i,i,2,1] = 1.
        elif direction == 'left':
            for i in range(3): filters[i,i,1,0] = 1.
        else:
            for i in range(3): filters[i,i,1,2] = 1.
        filters = Variable(filters)
        if use_gpu == True: filters = filters.cuda()
        return filters

    # Clone the network outputs used in every term (clone keeps autograd
    # history; it only guards against aliasing/in-place edits elsewhere).
    predict_unary = predict_unary.clone()
    predict_dx = predict_dx.clone()
    predict_dy = predict_dy.clone()

    if volatile == True:
        predict_alpha = predict_alpha.clone()
        predict_beta = predict_beta.clone()

    f_up = prepare_fileters(direction='up')
    f_down = prepare_fileters(direction='down')
    f_left = prepare_fileters(direction='left')
    f_right = prepare_fileters(direction='right')

    beta_up = predict_beta[:,0,:,:]
    beta_down = predict_beta[:,1,:,:]
    beta_left = predict_beta[:,2,:,:]
    beta_right = predict_beta[:,3,:,:]

    # (Removed dead code from the original: an inner `filter_gen` helper and
    # the `f_dx` / `f_dy` kernels it produced were never used.)

    J1 = (y - predict_alpha * predict_unary)**2
    J2 = (((y - F.conv2d(y, f_up, padding=1) - beta_up * F.conv2d(predict_dy, f_up, padding=1))**2))
    J3 = (((y - F.conv2d(y, f_down, padding=1) + beta_down * predict_dy)**2))
    J4 = (((y - F.conv2d(y, f_left, padding=1) - beta_left * F.conv2d(predict_dx, f_left, padding=1))**2))
    J5 = (((y - F.conv2d(y, f_right, padding=1) + beta_right * predict_dx)**2))
    J = torch.cat([J1,J2,J3,J4,J5],1) 
    
    return J

In [ ]:
def train_eval_model_per_epoch(epoch, net, args, train_loader, test_loader, phase='train'):
    """Run one full epoch of training or evaluation.

    Relies on notebook-level globals: ``optimizer``, ``writer``, ``myutils``,
    ``mse_loss``, ``use_gpu``, and the ``generate_y`` helper.

    Args:
        epoch: zero-based epoch index (drives the LR schedule, logging and
            snapshot filenames).
        net: the GradientNet model; its forward output is sliced as
            unary=0:3, dx=3:6, dy=6:9, alpha=9:10, beta=9:13 below.
        args: experiment configuration (Args instance); ``display_curindex``
            is mutated as a side effect.
        train_loader / test_loader: DataLoaders yielding
            (input_img, gt_albedo, gt_shading, scene_name, image_path)
            with batch size 1.
        phase: 'train' to optimize; any other value evaluates on test_loader.
    """
    if phase == 'train':
        volatile = False
        net.train()
    else:
        volatile = True
#         net.eval()
        # NOTE(review): evaluation also runs in train() mode, so BatchNorm /
        # Dropout behave as during training -- confirm this is intentional.
        net.train()
    
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

    """adjust learning rate"""
    # Polynomial decay over the whole run (see args.power); also logged below.
    myutils.adjust_learning_rate(optimizer, args, epoch, beg=0, end=args.epoches)
    #if epoch < args.training_thresholds: 
    #    myutils.adjust_learning_rate(optimizer, args, epoch, beg=0, end=args.training_thresholds-1)
    #else:
    #    myutils.adjust_learning_rate(optimizer, args, epoch, beg=args.training_thresholds, end=args.epoches)
    writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)

    """init statics"""
    # Running loss accumulators; run_cnt starts at a tiny epsilon so the
    # averages below never divide by zero on an empty loader.
    run_loss_unary = 0.
    run_loss_dx = 0.
    run_loss_dy = 0.
    run_loss_y = 0.
    run_loss_crf = 0.
    run_cnt   = 0.00001

    """for all training/test data"""
    loader = train_loader if phase == 'train' else test_loader
    
    for ind, data in enumerate(loader, 0):
        """prepare data"""
        input_img, gt_albedo, gt_shading, cur_scene, img_path = data
        # Unwrap the batch-of-1 lists the loader returns for string fields.
        (cur_scene,) = cur_scene
        (img_path,) = img_path
        cur_frame = img_path.split('/')[-1]
        input_img = Variable(input_img, volatile=volatile)
        gt_albedo = Variable(gt_albedo, requires_grad=False)
        # NOTE(review): gt_shading is moved to the GPU but never used in this
        # function -- confirm it is still needed.
        gt_shading = Variable(gt_shading)
        if use_gpu: 
            input_img, gt_albedo, gt_shading = input_img.cuda(), gt_albedo.cuda(), gt_shading.cuda()
        
        """prepare gradient"""
        # Ground-truth image gradients in x / y for the pairwise targets.
        gt_dx = myutils.makeGradientTorch(gt_albedo, direction='x', use_gpu=use_gpu)
        gt_dy = myutils.makeGradientTorch(gt_albedo, direction='y', use_gpu=use_gpu)
        
        if phase == 'train':
            optimizer.zero_grad()
        
        # Forward pass; slice the stacked prediction into its heads.
        predict_all = net(input_img)
        predict_unary = predict_all[:,0:3,:,:]
        predict_dx = predict_all[:,3:6,:,:]
        predict_dy = predict_all[:,6:9,:,:]
        predict_alpha = predict_all[:,9:10,:,:]
        # NOTE(review): the beta slice starts at channel 9, overlapping the
        # alpha channel above -- 10:14 may have been intended; confirm against
        # GradientNet's output layout.
        predict_beta = predict_all[:,9:13,:,:]
        
        #print('alpha', predict_alpha.min(), predict_beta.max())
        #print('beta ', predict_beta.min(), predict_beta.max())
        
        y = None
        crf_loss_y = None
        crf_loss_gt = None
        
        """prepare crf y"""
        # CRF-refined estimate (see generate_y above).
        y = generate_y(predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta, use_gpu=use_gpu, volatile=volatile)
        #y = Variable(y.data.clone(), requires_grad=False).cuda()

        #print('y', y.min(), y.max())


        """prepare crf loss"""
        #crf_loss_y = crf_loss(y, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta)
        # crf_loss_y = crf_loss(predict_dx, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta)
        #crf_loss_gt = crf_loss(gt_albedo, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta, volatile=True)
        # crf_loss_gt = crf_loss(predict_dy, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta)
        #crf_loss_gt = Variable(crf_loss_gt.data.cpu(), requires_grad=False).cuda()
        
        """prepare final gt"""
        # Single MSE over [unary, dx, dy, y] vs [albedo, gt_dx, gt_dy, albedo],
        # concatenated on the channel axis.
        predict_final = None
        gt_final = None
        predict_final = torch.cat([predict_all[:,0:3+6,:,:], y], 1)
        gt_final = torch.cat([gt_albedo, gt_dx, gt_dy, gt_albedo], 1)
            
        
        """compute loss"""
        loss = mse_loss(predict_final, gt_final)
        # c_loss = mse_crf_loss(predict_dx, predict_dy)
        
        # Accumulate per-head scalar losses for reporting only.
        run_loss_unary += myutils.mse_loss_scalar(predict_unary, gt_albedo)
        run_loss_dx += myutils.mse_loss_scalar(predict_dx, gt_dx)
        run_loss_dy += myutils.mse_loss_scalar(predict_dy, gt_dy)
        run_loss_y += myutils.mse_loss_scalar(y, gt_albedo)
        #run_loss_crf += myutils.mse_loss_scalar(crf_loss_y, 0)
        run_cnt += 1

        """backward"""
        if phase == 'train':
            loss.backward()
            # c_loss.backward()
            optimizer.step()
        
        """generate display img"""
        # Convert tensors to BGR uint8-range images for cv2.imwrite; gradient
        # images are shifted by +0.5 so negative gradients stay visible.
        display_im = myutils.tensor2Numpy(input_img)[:,:,::-1]*255
        display_gt_albedo = myutils.tensor2Numpy(gt_albedo)[:,:,::-1]*255
        display_gt_dx = (myutils.tensor2Numpy(gt_dx)[:,:,::-1]+0.5)*255
        display_gt_dy = (myutils.tensor2Numpy(gt_dy)[:,:,::-1]+0.5)*255
        display_unary = myutils.tensor2Numpy(predict_unary)[:,:,::-1]*255
        display_dx = (myutils.tensor2Numpy(predict_dx)[:,:,::-1]+0.5)*255
        display_dy = (myutils.tensor2Numpy(predict_dy)[:,:,::-1]+0.5)*255
        display_y = (myutils.tensor2Numpy(y)[:,:,::-1])*255

        """display"""
        # Training: dump every display_interval-th iteration.
        # Test: dump only one fixed scene/frame so images are comparable
        # across epochs.
        if (phase == 'train' and args.display_curindex % args.display_interval == 0) or \
        (phase == 'test' and cur_scene == 'alley_1' and cur_frame == 'frame_0001.png'):
            # print('display ', phase, img_path, display_im.shape)
            cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), display_im)
            cv2.imwrite('snapshot{}/{}-gt-{}-unary.png'.format(args.gpu_num, phase, epoch), display_gt_albedo) 
            cv2.imwrite('snapshot{}/{}-gt-{}-dx.png'.format(args.gpu_num, phase, epoch), display_gt_dx) 
            cv2.imwrite('snapshot{}/{}-gt-{}-dy.png'.format(args.gpu_num, phase, epoch), display_gt_dy) 
            cv2.imwrite('snapshot{}/{}-rs-{}-unary.png'.format(args.gpu_num, phase, epoch), display_unary)
            cv2.imwrite('snapshot{}/{}-rs-{}-dx.png'.format(args.gpu_num, phase, epoch), display_dx)
            cv2.imwrite('snapshot{}/{}-rs-{}-dy.png'.format(args.gpu_num, phase, epoch), display_dy)
            cv2.imwrite('snapshot{}/{}-rs-{}-y.png'.format(args.gpu_num, phase, epoch), display_y)
        
        args.display_curindex += 1
    
    """output loss"""
    loss_output = ''
    loss_output += '{} loss: '.format(phase)
    loss_output += 'unary: %6f ' % (run_loss_unary/run_cnt)
    loss_output += 'pairwise: %6f ' % ((run_loss_dx+run_loss_dy)/run_cnt)
    #loss_output += 'crf: %6f ' % (run_loss_crf/run_cnt)
    loss_output += 'y: %6f ' % (run_loss_y/run_cnt)
    
    print(loss_output)
    
    """write to tensorboard"""
    writer.add_scalars('loss', {
        '%s unary loss'% (phase): np.array([run_loss_unary/run_cnt]),
        '%s dx loss'% (phase): np.array([run_loss_dx/run_cnt]),
        '%s dy loss'% (phase): np.array([run_loss_dy/run_cnt]),
        '%s pairwise loss'% (phase): np.array([(run_loss_dx+run_loss_dy)/run_cnt])
        #'%s y loss'% (phase): np.array([run_loss_y/run_cnt]),
    }, global_step=epoch)
    
    """save snapshot"""
    if phase == 'train':
        myutils.save_snapshot(epoch, args, net, optimizer)

In [ ]:
"""training loop"""
writer = SummaryWriter(comment='-{}'.format(writer_comment))

for epoch in range(args.epoches):
    phase = 'test' if (epoch+1) % 5 == 0 else 'train'
    train_eval_model_per_epoch(epoch, net, args, train_loader, test_loader, phase=phase)

Visualize Graph


In [ ]:
# x = Variable(torch.zeros(1,3,256,256))
# y = net(x.cuda())
# g = make_dot(y[-1])

In [ ]:
# g.render('net-transition_scale_{}'.format(transition_scale))

In [ ]: