In [ ]:
import os, glob, platform, datetime, random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
# from torch import functional as F
import torch.nn.functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
from PIL import Image
from tensorboardX import SummaryWriter
import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc
from myimagefoldereccv import MyImageFolder
from mymodel import GradientNet
from myargs import Args
from myutils import MyUtils
In [ ]:
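# Experiment configuration. The MPI Sintel scenes listed in args.test_scene are held
# out for evaluation; all remaining scenes are used for training. This run targets the
# albedo component (writer_comment = 'eccv_albedo').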
myutils = MyUtils()
args = Args()
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
# growth_rate = (4*(2**(args.gpu_num)))
transition_scale=2
pretrained_scale=4
growth_rate = 32
#######
args.test_scene = ['alley_1', 'bamboo_1', 'bandage_1', 'cave_2', 'market_2', 'market_6', 'shaman_2', 'sleeping_1', 'temple_2']
gradient=False
args.gpu_num = 0
#######
writer_comment = 'eccv_albedo'
offset = 0.
if gradient: offset = 0.5
args.display_interval = 50
args.display_curindex = 0
system_ = platform.system()
# NOTE: platform.dist() was deprecated and removed in Python 3.8; this notebook assumes an older interpreter.
system_dist = platform.dist()
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
    args.image_w, args.image_h = 32, 32
elif system_dist == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/albertxavier/dataset/sintel2'
    args.pretrained = True
elif system_dist == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True
use_gpu = (system_ == 'Linux')
if use_gpu:
    torch.cuda.set_device(args.gpu_num)
print(system_dist)
In [ ]:
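# Datasets and loaders: random crops of size args.image_h x args.image_w for training,
# a fixed centre crop of the same size for testing.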
train_dataset = MyImageFolder(args.train_dir, 'train',
transforms.Compose(
[transforms.ToTensor()]
), random_crop=True,
img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test',
transforms.Compose(
[transforms.CenterCrop((args.image_h, args.image_w)),
transforms.ToTensor()]
), random_crop=False,
img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)
In [ ]:
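# Pretrained DenseNet backbone; its weights are frozen, so only the parameters added
# by GradientNet are optimized (see the optimizer setup below).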
densenet = models.__dict__[args.arch](pretrained=args.pretrained)
for param in densenet.parameters():
param.requires_grad = False
if use_gpu: densenet.cuda()
In [ ]:
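# Model, losses and optimizer. Only parameters with requires_grad=True are passed to
# SGD, which excludes the frozen DenseNet weights.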
args.display_curindex = 0
args.base_lr = 0.01
args.display_interval = 20
args.momentum = 0.9
args.epoches = int(60*4)
#args.training_thresholds = 240//4
args.power = 0.5
net = GradientNet(densenet=densenet, growth_rate=growth_rate,
transition_scale=transition_scale, pretrained_scale=pretrained_scale,
gradient=gradient)
if use_gpu:
net.cuda()
mse_loss = nn.MSELoss().cuda() if use_gpu else nn.MSELoss()
mse_crf_loss = nn.MSELoss().cuda() if use_gpu else nn.MSELoss()
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.base_lr, momentum=args.momentum)
In [ ]:
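# generate_y performs a fixed-point (mean-field style) refinement: starting from the
# unary prediction, each pixel is updated with a weighted combination (alpha, beta) of
# the unary prediction, the four shifted neighbours of the current estimate, and the
# predicted x/y gradients.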
def generate_y(predict_unary, predict_dx, predict_dy, gt, predict_alpha, predict_beta, max_iter=100, eps=1.e-4, use_gpu=True, volatile=False):
def generate_y_(last_y, predict_unary, predict_dx, predict_dy, gt, predict_alpha, predict_beta, use_gpu=True):
        def prepare_filters(direction='up'):
            # per-channel shift filters: each output channel copies its own input
            # channel from the neighbouring pixel in the given direction
            filters = torch.zeros(3, 3, 3, 3)
            if direction == 'up':
                for i in range(3): filters[i, i, 0, 1] = 1.
            elif direction == 'down':
                for i in range(3): filters[i, i, 2, 1] = 1.
            elif direction == 'left':
                for i in range(3): filters[i, i, 1, 0] = 1.
            else:
                for i in range(3): filters[i, i, 1, 2] = 1.
            filters = Variable(filters)
            if use_gpu: filters = filters.cuda()
            return filters
        f_up = prepare_filters(direction='up')
        f_down = prepare_filters(direction='down')
        f_left = prepare_filters(direction='left')
        f_right = prepare_filters(direction='right')
last_y_up = F.conv2d(last_y, f_up, padding=1)
last_y_down = F.conv2d(last_y, f_down, padding=1)
last_y_left = F.conv2d(last_y, f_left, padding=1)
last_y_right = F.conv2d(last_y, f_right, padding=1)
t_up = F.conv2d(predict_dy, f_up, padding=1)
t_down = -predict_dy
t_left = F.conv2d(predict_dx, f_left, padding=1)
t_right = -predict_dx
beta_up = predict_beta[:,0:1,:,:]
beta_down = predict_beta[:,1:2,:,:]
beta_left = predict_beta[:,2:3,:,:]
beta_right = predict_beta[:,3:4,:,:]
sum_beta = beta_up + beta_down + beta_left + beta_right
constant = predict_alpha + sum_beta
#print('constant', constant)
# y = (predict_alpha * predict_unary + \
# beta_up * (last_y_up + t_up) + \
# beta_down * (last_y_down + t_down) + \
# beta_left * (last_y_left + t_left) + beta_right * (last_y_right + t_right))/constant
y = predict_alpha * predict_unary
y = y + last_y_up + beta_up * t_up
y = y + last_y_down + beta_down * t_down
y = y + last_y_left + beta_left * t_left
y = y + last_y_right + beta_right * t_right
y = y / 5.
return y
predict_unary = predict_unary.clone()
predict_dx = predict_dx.clone()
predict_dy = predict_dy.clone()
#y = Variable(predict_unary.data.cpu().clone()+torch.rand(predict_unary.size())/10., volatile=True).cuda()
y = Variable(predict_unary.data.clone())
    if use_gpu: y = y.cuda()
    iters = 0
    while True:
        last_y = y.clone()
        y = generate_y_(y, predict_unary, predict_dx, predict_dy, gt, predict_alpha, predict_beta, use_gpu=use_gpu)
        cur_loss = myutils.mse_loss_scalar(y, last_y)
        if cur_loss <= eps:
            break
        if iters >= max_iter:
            break
        iters += 1
        # NOTE: this unconditional break stops the refinement after a single update
        # step, so max_iter and eps currently have no effect.
        break
    return y
In [ ]:
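# crf_loss builds the per-pixel CRF energy terms (one unary term and four directional
# pairwise terms) and stacks them along the channel dimension. It is currently only
# referenced from commented-out code in the training function.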
def crf_loss(y, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta, volatile=False):
# return torch.cat([y],1)
    def filter_gen(direction='x'):
        # forward-difference filters: neighbour (x+1 or y+1) minus centre, per channel
        filters = torch.zeros(3, 3, 3, 3)
        if use_gpu: filters = filters.cuda()
        for i in range(3):
            filters[i, i, 1, 1] = -1.
        if direction == 'x':
            for i in range(3):
                filters[i, i, 1, 2] = 1.
        else:
            for i in range(3):
                filters[i, i, 2, 1] = 1.
        filters = Variable(filters)
        return filters
    def prepare_filters(direction='up'):
        # same per-channel shift filters as in generate_y
        filters = torch.zeros(3, 3, 3, 3)
        if direction == 'up':
            for i in range(3): filters[i, i, 0, 1] = 1.
        elif direction == 'down':
            for i in range(3): filters[i, i, 2, 1] = 1.
        elif direction == 'left':
            for i in range(3): filters[i, i, 1, 0] = 1.
        else:
            for i in range(3): filters[i, i, 1, 2] = 1.
        filters = Variable(filters)
        if use_gpu: filters = filters.cuda()
        return filters
predict_unary = predict_unary.clone()
predict_dx = predict_dx.clone()
predict_dy = predict_dy.clone()
    if volatile:
        predict_alpha = predict_alpha.clone()
        predict_beta = predict_beta.clone()
    f_up = prepare_filters(direction='up')
    f_down = prepare_filters(direction='down')
    f_left = prepare_filters(direction='left')
    f_right = prepare_filters(direction='right')
    # keep the channel dimension (0:1 instead of 0) so the beta maps broadcast
    # cleanly against the 4D prediction tensors
    beta_up = predict_beta[:, 0:1, :, :]
    beta_down = predict_beta[:, 1:2, :, :]
    beta_left = predict_beta[:, 2:3, :, :]
    beta_right = predict_beta[:, 3:4, :, :]
f_dx = filter_gen(direction='x')
f_dy = filter_gen(direction='y')
J1 = (y - predict_alpha * predict_unary)**2
J2 = (((y - F.conv2d(y, f_up, padding=1) - beta_up * F.conv2d(predict_dy, f_up, padding=1))**2))
J3 = (((y - F.conv2d(y, f_down, padding=1) + beta_down * predict_dy)**2))
J4 = (((y - F.conv2d(y, f_left, padding=1) - beta_left * F.conv2d(predict_dx, f_left, padding=1))**2))
J5 = (((y - F.conv2d(y, f_right, padding=1) + beta_right * predict_dx)**2))
J = torch.cat([J1,J2,J3,J4,J5],1)
return J
In [ ]:
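# Runs one epoch of training or evaluation: adjusts the learning rate, iterates over
# the corresponding loader, accumulates the unary / gradient / refined-y losses,
# writes preview images and TensorBoard scalars, and (in the training phase) saves a
# snapshot.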
def train_eval_model_per_epoch(epoch, net, args, train_loader, test_loader, phase='train'):
if phase == 'train':
volatile = False
net.train()
    else:
        volatile = True
        # the network stays in train mode during evaluation as well (net.eval() is left commented out)
        # net.eval()
        net.train()
print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
"""adjust learning rate"""
myutils.adjust_learning_rate(optimizer, args, epoch, beg=0, end=args.epoches)
#if epoch < args.training_thresholds:
# myutils.adjust_learning_rate(optimizer, args, epoch, beg=0, end=args.training_thresholds-1)
#else:
# myutils.adjust_learning_rate(optimizer, args, epoch, beg=args.training_thresholds, end=args.epoches)
writer.add_scalar('learning rate', optimizer.param_groups[0]['lr'], global_step=epoch)
"""init statics"""
run_loss_unary = 0.
run_loss_dx = 0.
run_loss_dy = 0.
run_loss_y = 0.
run_loss_crf = 0.
run_cnt = 0.00001
"""for all training/test data"""
loader = train_loader if phase == 'train' else test_loader
for ind, data in enumerate(loader, 0):
"""prepare data"""
input_img, gt_albedo, gt_shading, cur_scene, img_path = data
(cur_scene,) = cur_scene
(img_path,) = img_path
cur_frame = img_path.split('/')[-1]
input_img = Variable(input_img, volatile=volatile)
gt_albedo = Variable(gt_albedo, requires_grad=False)
gt_shading = Variable(gt_shading)
if use_gpu:
input_img, gt_albedo, gt_shading = input_img.cuda(), gt_albedo.cuda(), gt_shading.cuda()
"""prepare gradient"""
gt_dx = myutils.makeGradientTorch(gt_albedo, direction='x', use_gpu=use_gpu)
gt_dy = myutils.makeGradientTorch(gt_albedo, direction='y', use_gpu=use_gpu)
if phase == 'train':
optimizer.zero_grad()
predict_all = net(input_img)
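        # network output layout: channels 0-2 unary albedo, 3-5 dx, 6-8 dy,
        # 9 alpha, 9-12 beta (alpha shares the first beta channel)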
predict_unary = predict_all[:,0:3,:,:]
predict_dx = predict_all[:,3:6,:,:]
predict_dy = predict_all[:,6:9,:,:]
predict_alpha = predict_all[:,9:10,:,:]
predict_beta = predict_all[:,9:13,:,:]
#print('alpha', predict_alpha.min(), predict_beta.max())
#print('beta ', predict_beta.min(), predict_beta.max())
y = None
crf_loss_y = None
crf_loss_gt = None
"""prepare crf y"""
y = generate_y(predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta, use_gpu=use_gpu, volatile=volatile)
#y = Variable(y.data.clone(), requires_grad=False).cuda()
#print('y', y.min(), y.max())
"""prepare crf loss"""
#crf_loss_y = crf_loss(y, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta)
# crf_loss_y = crf_loss(predict_dx, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta)
#crf_loss_gt = crf_loss(gt_albedo, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta, volatile=True)
# crf_loss_gt = crf_loss(predict_dy, predict_unary, predict_dx, predict_dy, gt_albedo, predict_alpha, predict_beta)
#crf_loss_gt = Variable(crf_loss_gt.data.cpu(), requires_grad=False).cuda()
"""prepare final gt"""
predict_final = None
gt_final = None
predict_final = torch.cat([predict_all[:,0:3+6,:,:], y], 1)
gt_final = torch.cat([gt_albedo, gt_dx, gt_dy, gt_albedo], 1)
"""compute loss"""
loss = mse_loss(predict_final, gt_final)
# c_loss = mse_crf_loss(predict_dx, predict_dy)
run_loss_unary += myutils.mse_loss_scalar(predict_unary, gt_albedo)
run_loss_dx += myutils.mse_loss_scalar(predict_dx, gt_dx)
run_loss_dy += myutils.mse_loss_scalar(predict_dy, gt_dy)
run_loss_y += myutils.mse_loss_scalar(y, gt_albedo)
#run_loss_crf += myutils.mse_loss_scalar(crf_loss_y, 0)
run_cnt += 1
"""backward"""
if phase == 'train':
loss.backward()
# c_loss.backward()
optimizer.step()
"""generate display img"""
display_im = myutils.tensor2Numpy(input_img)[:,:,::-1]*255
display_gt_albedo = myutils.tensor2Numpy(gt_albedo)[:,:,::-1]*255
display_gt_dx = (myutils.tensor2Numpy(gt_dx)[:,:,::-1]+0.5)*255
display_gt_dy = (myutils.tensor2Numpy(gt_dy)[:,:,::-1]+0.5)*255
display_unary = myutils.tensor2Numpy(predict_unary)[:,:,::-1]*255
display_dx = (myutils.tensor2Numpy(predict_dx)[:,:,::-1]+0.5)*255
display_dy = (myutils.tensor2Numpy(predict_dy)[:,:,::-1]+0.5)*255
display_y = (myutils.tensor2Numpy(y)[:,:,::-1])*255
"""display"""
if (phase == 'train' and args.display_curindex % args.display_interval == 0) or \
(phase == 'test' and cur_scene == 'alley_1' and cur_frame == 'frame_0001.png'):
# print('display ', phase, img_path, display_im.shape)
cv2.imwrite('snapshot{}/input.png'.format(args.gpu_num), display_im)
cv2.imwrite('snapshot{}/{}-gt-{}-unary.png'.format(args.gpu_num, phase, epoch), display_gt_albedo)
cv2.imwrite('snapshot{}/{}-gt-{}-dx.png'.format(args.gpu_num, phase, epoch), display_gt_dx)
cv2.imwrite('snapshot{}/{}-gt-{}-dy.png'.format(args.gpu_num, phase, epoch), display_gt_dy)
cv2.imwrite('snapshot{}/{}-rs-{}-unary.png'.format(args.gpu_num, phase, epoch), display_unary)
cv2.imwrite('snapshot{}/{}-rs-{}-dx.png'.format(args.gpu_num, phase, epoch), display_dx)
cv2.imwrite('snapshot{}/{}-rs-{}-dy.png'.format(args.gpu_num, phase, epoch), display_dy)
cv2.imwrite('snapshot{}/{}-rs-{}-y.png'.format(args.gpu_num, phase, epoch), display_y)
args.display_curindex += 1
"""output loss"""
loss_output = ''
loss_output += '{} loss: '.format(phase)
loss_output += 'unary: %6f ' % (run_loss_unary/run_cnt)
loss_output += 'pairwise: %6f ' % ((run_loss_dx+run_loss_dy)/run_cnt)
#loss_output += 'crf: %6f ' % (run_loss_crf/run_cnt)
loss_output += 'y: %6f ' % (run_loss_y/run_cnt)
print(loss_output)
"""write to tensorboard"""
writer.add_scalars('loss', {
'%s unary loss'% (phase): np.array([run_loss_unary/run_cnt]),
'%s dx loss'% (phase): np.array([run_loss_dx/run_cnt]),
'%s dy loss'% (phase): np.array([run_loss_dy/run_cnt]),
'%s pairwise loss'% (phase): np.array([(run_loss_dx+run_loss_dy)/run_cnt])
#'%s y loss'% (phase): np.array([run_loss_y/run_cnt]),
}, global_step=epoch)
"""save snapshot"""
if phase == 'train':
myutils.save_snapshot(epoch, args, net, optimizer)
In [ ]:
"""training loop"""
writer = SummaryWriter(comment='-{}'.format(writer_comment))
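# every 5th epoch is run as an evaluation pass on the held-out test scenes instead of a training pass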
for epoch in range(args.epoches):
phase = 'test' if (epoch+1) % 5 == 0 else 'train'
train_eval_model_per_epoch(epoch, net, args, train_loader, test_loader, phase=phase)
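In [ ]:
# A minimal inference sketch (not part of the original training code), assuming the
# cells above have been run: it pulls one batch from test_loader, runs the trained
# net, and saves the unary albedo prediction using the same conventions as
# train_eval_model_per_epoch.
input_img, gt_albedo, gt_shading, cur_scene, img_path = next(iter(test_loader))
input_img = Variable(input_img, volatile=True)
if use_gpu: input_img = input_img.cuda()
predict_all = net(input_img)
predict_unary = predict_all[:, 0:3, :, :]
display_unary = myutils.tensor2Numpy(predict_unary)[:, :, ::-1] * 255
cv2.imwrite('snapshot{}/inference-unary.png'.format(args.gpu_num), display_unary)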
In [ ]:
# x = Variable(torch.zeros(1,3,256,256))
# y = net(x.cuda())
# g = make_dot(y[-1])
In [ ]:
# g.render('net-transition_scale_{}'.format(transition_scale))
In [ ]: