In [1]:
import os, glob, platform, datetime, random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
from PIL import Image
from tensorboardX import SummaryWriter
import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc
from myimagefolder import MyImageFolder
from mymodel import PreTrainedModel, GradientNet16, GradientNet08, GradientNet04, GradientNet02, GradientNet01, GradientNetMerge
from myargs import Args
In [2]:
args = Args()
args.test_scene = 'alley_1'
args.arch = "densenet121"
args.epoches = 500
args.epoches_unary_threshold = 0
args.image_h = 256
args.image_w = 256
args.img_extentions = ["png"]
args.training_thresholds = [250,200,150,50,0,300]
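# per-scale epoch thresholds (order: 16M, 8M, 4M, 2M, 1M, merge); used below to gate per-scale evaluation and to reset the learning rate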
args.base_lr = 1
args.lr = args.base_lr
args.snapshot_interval = 5000
args.debug = True
args.gpu_num = 1
args.display_interval = 50
args.display_curindex = 0
system_ = platform.system()
system_dist, system_version, _ = platform.dist()
if system_ == "Darwin":
    args.train_dir = '/Volumes/Transcend/dataset/sintel2'
    args.pretrained = False
elif platform.dist() == ('debian', 'jessie/sid', ''):
    args.train_dir = '/home/lwp/workspace/sintel2'
    args.pretrained = True
elif platform.dist() == ('debian', 'stretch/sid', ''):
    args.train_dir = '/home/cad/lwp/workspace/dataset/sintel2'
    args.pretrained = True
if platform.system() == 'Linux': use_gpu = True
else: use_gpu = False
if use_gpu:
    torch.cuda.set_device(args.gpu_num)
print(platform.dist())
In [3]:
train_dataset = MyImageFolder(args.train_dir, 'train',
                              transforms.Compose(
                                  [transforms.ToTensor()]
                              ), random_crop=True,
                              img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
test_dataset = MyImageFolder(args.train_dir, 'test',
                             transforms.Compose(
                                 [transforms.CenterCrop((args.image_h, args.image_w)),
                                  transforms.ToTensor()]
                             ), random_crop=False,
                             img_extentions=args.img_extentions, test_scene=args.test_scene, image_h=args.image_h, image_w=args.image_w)
train_loader = data_utils.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_loader = data_utils.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)
In [4]:
densenet = models.__dict__[args.arch](pretrained=args.pretrained)
for param in densenet.parameters():
    param.requires_grad = False
if use_gpu: densenet.cuda()
In [5]:
ss = 6
args.display_curindex = 0
args.base_lr = 0.05
args.display_interval = 20
args.momentum = 0.9
args.epoches = 120
args.training_thresholds = [ss*4,ss*3,ss*2,ss*1,ss*0,ss*5]
args.power = 0.5
pretrained = PreTrainedModel(densenet)
if use_gpu:
    pretrained.cuda()
net16 = GradientNet16()
net08 = GradientNet08()
net04 = GradientNet04()
net02 = GradientNet02()
net01 = GradientNet01()
netmg = GradientNetMerge()
if use_gpu:
    net16.cuda()
    net08.cuda()
    net04.cuda()
    net02.cuda()
    net01.cuda()
    netmg.cuda()
nets = [net16, net08, net04, net02, net01, netmg]
if use_gpu:
    mse_losses = [nn.MSELoss().cuda()] * 6
    test_losses = [nn.MSELoss().cuda()] * 6
else:
    mse_losses = [nn.MSELoss()] * 6
    test_losses = [nn.MSELoss()] * 6
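# note: list multiplication puts six references to the same MSELoss module in each list; MSELoss is stateless, so sharing it across scales is harmless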
In [23]:
# training loop
writer = SummaryWriter()
writer.add_text('training', 'different kernel size for different scale')
parameters = [0]*len(nets)
optimizers = [0]*len(nets)
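# only parameters with requires_grad=True are handed to the optimizers; the DenseNet backbone was frozen above, so just the GradientNet heads are updated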
for i in range(len(nets)):
    parameters[i] = filter(lambda p: p.requires_grad, nets[i].parameters())
    optimizers[i] = optim.SGD(parameters[i], lr=args.base_lr, momentum=args.momentum)
def adjust_learning_rate(optimizer, epoch, beg, end, reset_lr=None):
    """Polynomial decay: lr = base_lr * ((end - epoch) / (end - beg)) ** power, clamped below at 1e-8.
    If reset_lr is given, the learning rate is simply reset to that value."""
    for param_group in optimizer.param_groups:
        if reset_lr is not None:
            param_group['lr'] = reset_lr
            continue
        if epoch != 0:
            param_group['lr'] = args.base_lr * (float(end-epoch)/(end-beg)) ** (args.power)
        if param_group['lr'] < 1.0e-8: param_group['lr'] = 1.0e-8
        # print('lr', param_group['lr'])
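# backbone features are computed once here from whatever `input_img` is currently bound in the session
# (presumably from an earlier run of the training loop below) and are reused for every iteration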
pretrained.train()
ft_pretrained = pretrained(input_img)
pretrained.eval()
ft_pretrained_test_phase = pretrained(input_img)
for epoch in range(args.epoches):
    net16.train(); net08.train(); net04.train(); net02.train(); net01.train(); netmg.train();
    print('epoch: {} [{}]'.format(epoch, datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    if epoch < args.training_thresholds[-1]:
        for optimizer in optimizers: adjust_learning_rate(optimizer, epoch%ss, beg=0, end=ss-1)
    else:
        for optimizer in optimizers: adjust_learning_rate(optimizer, epoch, beg=args.training_thresholds[-1], end=args.epoches-1)
    run_losses = [0] * len(args.training_thresholds)
    run_cnts = [0.00001] * len(args.training_thresholds)
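    # counters start at a small epsilon so the per-scale loss averages below never divide by zero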
    if epoch in args.training_thresholds:
        for optimizer in optimizers: adjust_learning_rate(optimizer, epoch, reset_lr=args.base_lr, beg=-1, end=-1)
    writer.add_scalar('learning rate', optimizers[0].param_groups[0]['lr'], global_step=epoch)
    for ind, data in enumerate(train_loader, 0):
        # if ind == 1 : break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        im = input_img[0,:,:,:].numpy(); im = im.transpose(1,2,0); im = im[:,:,::-1]*255
        if test_scene[0] == 'alley_1':
            print('alley_1 yes')
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda()
            gt_albedo = gt_albedo.cuda()
            gt_shading = gt_shading.cuda()
        if args.display_curindex % args.display_interval == 0:
            cv2.imwrite('snapshot/input.png', im)
        # for optimizer in optimizers: optimizer.zero_grad()
        ft_predict = [0] * len(nets)
        for i, threshold in enumerate(args.training_thresholds):
            # if epoch >= threshold:
            if epoch >= 0:
                optimizers[i].zero_grad()
                if i < 5: ft_predict[i] = nets[i](ft_pretrained)
                else:
                    for net_ind in range(len(nets)-1):
                        ft_predict[net_ind] = nets[net_ind](ft_pretrained)
                    ft_predict[i] = nets[i](ft_predict[0:-1])
                if i == 5: s = 1
                else: s = (2**(i+1))
                gt = gt_albedo.cpu().data.numpy()
                n,c,h,w = gt.shape
                gt = gt[0,:,:,:]
                gt = gt.transpose((1,2,0))
                gt = cv2.resize(gt, (w//s, h//s))  # cv2.resize takes (width, height)
                # gt = cv2.resize(gt, (h,w))
                if args.display_curindex % args.display_interval == 0:
                    cv2.imwrite('snapshot/gt-{}-{}.png'.format(epoch, i), gt[:,:,::-1]*255)
                gt = gt.transpose((2,0,1))
                gt = gt[np.newaxis, :]
                gt = Variable(torch.from_numpy(gt))
                if use_gpu: gt = gt.cuda()
                loss = mse_losses[i](ft_predict[i], gt)
                loss_data = loss.data.cpu().numpy()
                writer.add_scalar('{}th train iters loss'.format(i), loss_data, global_step=args.display_curindex)
                # ma_ = ft_predict[i].max().cpu().data.numpy()
                # mi_ = ft_predict[i].min().cpu().data.numpy()
                # print('mi', mi_, 'ma', ma_)
                # writer.add_scalars('{}th train predict'.format(i), {'max': ma_, 'min': mi_}, global_step=args.display_curindex)
                # run_cnts[i] += 1
                run_losses[i] += loss.data.cpu().numpy()[0]
                loss.backward()
                optimizers[i].step()
                # if i < 5: optimizers[i].step()
                # else: for opt_ind in range(len(optimizers)): optimizers[i].step()
                run_cnts[i] += 1
                # print('i = ', i, '; weig\n', net.upsample01.weight[0,0,0:4,0:4].data.cpu().numpy())
                # print('i = ', i, '; grad\n', net.upsample01.weight.grad[0,0,0:4,0:4].data.cpu().numpy())
                if args.display_curindex % args.display_interval == 0:
                    im = ft_predict[i].cpu().data.numpy()[0].transpose((1,2,0)) * 255
                    cv2.imwrite('snapshot/train-{}-{}.png'.format(epoch, i), im[:,:,::-1])
        args.display_curindex += 1
""" every epoch """
# loss_output = 'ind: ' + str(args.display_curindex)
loss_output = ''
for i,v in enumerate(run_losses):
if i == len(run_losses)-1:
loss_output += ' merged: %6f' % (run_losses[i] / run_cnts[i])
continue
loss_output += ' %2dM: %6f' % ((2**(4-i)), (run_losses[i] / run_cnts[i]))
print(loss_output)
# save at every epoch
if (epoch+1) % 10 == 0:
print('snapshot')
torch.save({
'epoch': epoch,
'args' : args,
'state_dict_16M': nets[0].state_dict(),
'state_dict_08M': nets[1].state_dict(),
'state_dict_04M': nets[2].state_dict(),
'state_dict_02M': nets[3].state_dict(),
'state_dict_01M': nets[4].state_dict(),
'state_dict_merge': nets[5].state_dict(),
'optimizer_16M': optimizers[0].state_dict(),
'optimizer_08M': optimizers[1].state_dict(),
'optimizer_04M': optimizers[2].state_dict(),
'optimizer_02M': optimizers[3].state_dict(),
'optimizer_01M': optimizers[4].state_dict(),
'optimizer_merge': optimizers[5].state_dict()
}, 'snapshot/snapshot-{}.pth.tar'.format(epoch))
# test
if epoch % 5 != 0 or epoch == 0: continue
print('eval net')
test_losses_trainphase = [0] * len(args.training_thresholds)
test_cnts_trainphase = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        # ft_pretrained = pretrained(input_img)
        ft_test = [0]*len(nets)
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i < 5: ft_test[i] = nets[i](ft_pretrained)
            else: ft_test[i] = nets[i](ft_test[0:-1])
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (w//s, h//s))  # cv2.resize takes (width, height)
            # gt = cv2.resize(gt, (h,w))
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()
            loss = mse_losses[i](ft_test[i], gt)
            test_losses_trainphase[i] += loss.data.cpu().numpy()[0]
            test_cnts_trainphase[i] += 1
            v = ft_test[i]
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot/test-phase_train-{}-{}.png'.format(epoch, i), v[:,:,::-1]*255)
    # net.eval()
    net16.eval(); net08.eval(); net04.eval(); net02.eval(); net01.eval(); netmg.eval();
    test_losses = [0] * len(args.training_thresholds)
    test_cnts = [0.00001] * len(args.training_thresholds)
    for ind, data in enumerate(test_loader, 0):
        # if ind == 1: break
        input_img, gt_albedo, gt_shading, test_scene, img_path = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        if use_gpu:
            input_img = input_img.cuda(args.gpu_num)
        # ft_test = net(input_img)
        # ft_pretrained = pretrained(input_img)
        ft_test = [0]*len(nets)
        for i,v in enumerate(ft_test):
            if epoch < args.training_thresholds[i]: continue
            if i < 5: ft_test[i] = nets[i](ft_pretrained_test_phase)
            else: ft_test[i] = nets[i](ft_test[0:-1])
            if i == 5: s = 1
            else: s = (2**(i+1))
            gt = gt_albedo.data.numpy()
            n,c,h,w = gt.shape
            gt = gt[0,:,:,:]
            gt = gt.transpose((1,2,0))
            gt = cv2.resize(gt, (w//s, h//s))  # cv2.resize takes (width, height)
            # gt = cv2.resize(gt, (h,w))
            gt = gt.transpose((2,0,1))
            gt = gt[np.newaxis, :]
            gt = Variable(torch.from_numpy(gt))
            if use_gpu: gt = gt.cuda()
            loss = mse_losses[i](ft_test[i], gt)
            test_losses[i] += loss.data.cpu().numpy()[0]
            test_cnts[i] += 1
            v = ft_test[i]
            v = v[0].cpu().data.numpy()
            v = v.transpose(1,2,0)
            if ind == 0: cv2.imwrite('snapshot/test-phase_test-{}-{}.png'.format(epoch, i), v[:,:,::-1]*255)
    writer.add_scalars('16M loss', {
        'train 16M ': np.array([run_losses[0] / run_cnts[0]]),
        'test_trainphase 16M ': np.array([test_losses_trainphase[0] / test_cnts_trainphase[0]]),
        'test 16M ': np.array([test_losses[0] / test_cnts[0]])
    }, global_step=epoch)
    writer.add_scalars('8M loss', {
        'train 8M ': np.array([run_losses[1] / run_cnts[1]]),
        'test_trainphase 8M ': np.array([test_losses_trainphase[1] / test_cnts_trainphase[1]]),
        'test 8M ': np.array([test_losses[1] / test_cnts[1]])
    }, global_step=epoch)
    writer.add_scalars('4M loss', {
        'train 4M ': np.array([run_losses[2] / run_cnts[2]]),
        'test_trainphase 4M ': np.array([test_losses_trainphase[2] / test_cnts_trainphase[2]]),
        'test 4M ': np.array([test_losses[2] / test_cnts[2]])
    }, global_step=epoch)
    writer.add_scalars('2M loss', {
        'train 2M ': np.array([run_losses[3] / run_cnts[3]]),
        'test_trainphase 2M ': np.array([test_losses_trainphase[3] / test_cnts_trainphase[3]]),
        'test 2M ': np.array([test_losses[3] / test_cnts[3]])
    }, global_step=epoch)
    writer.add_scalars('1M loss', {
        'train 1M ': np.array([run_losses[4] / run_cnts[4]]),
        'test_trainphase 1M ': np.array([test_losses_trainphase[4] / test_cnts_trainphase[4]]),
        'test 1M ': np.array([test_losses[4] / test_cnts[4]])
    }, global_step=epoch)
    writer.add_scalars('merged loss', {
        'train merged ': np.array([run_losses[5] / run_cnts[5]]),
        'test_trainphase merged ': np.array([test_losses_trainphase[5] / test_cnts_trainphase[5]]),
        'test merged ': np.array([test_losses[5] / test_cnts[5]])
    }, global_step=epoch)
In [9]:
from graphviz import Digraph
import torch
from torch.autograd import Variable
def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to nodes that
            require grad (TODO: make optional)
    """
    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="10240,10240"), format='svg')
    seen = set()
    def size_to_str(size):
        return '(' + (', ').join(['%d' % v for v in size]) + ')'
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot
In [10]:
x = Variable(torch.zeros(1,3,256,256))
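# note: `net` below is assumed to come from an earlier session; it is not defined in the cells above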
y = net(x.cuda())
g = make_dot(y[-1])
In [11]:
g.render('net')
Out[11]:
In [ ]: