In [1]:
import os
import glob
import random
import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as opti
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
from PIL import Image
import numpy as np
from numpy.linalg import inv
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
In [2]:
class Args(object):
    """Plain attribute bag holding the notebook's hyper-parameters and paths."""
    pass
args = Args()
args.epoches = 20                  # number of training epochs
args.epoches_unary_threshold = 0   # epochs that train only the unary head (plain MSE) before the CRF loss
args.base_lr = 1e-5                # SGD learning rate
# Root of the Sintel images tree (see make_dataset for the expected layout).
# Set the SINTEL_IMAGES_DIR environment variable instead of editing this path.
args.train_dir = os.environ.get('SINTEL_IMAGES_DIR', '/home/albertxavier/dataset/sintel/images/')
args.arch = "resnet18"
# Only reported in MyImageFolder's "Found 0 images" error; make_dataset itself globs *.png.
args.img_extentions = ["png",'jpg']
args.image_w = 256                 # random-crop width
args.image_h = 256                 # random-crop height
In [3]:
def default_loader(path):
    """Load the image at `path` and return it as an RGB PIL image.

    Image.open is lazy and keeps the file open until the pixels are read.
    Opening the file in a `with` block, and calling convert() inside it, forces
    the data to load and closes the handle deterministically.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
def make_dataset(dir):
    """List (image, albedo, shading) path triplets under a Sintel-style root.

    Input images are found with the glob `dir/clean/*/*.png`. For an image
    `dir/clean/<scene>/<file>`, the matching targets are
    `dir/albedo/<scene>/<file>` and `dir/shading/<scene>/<file>`. The target
    files are not checked for existence.

    Returns:
        A list of (img_path, albedo_path, shading_path) tuples, sorted by image
        path so the order is deterministic. The list is empty if nothing matches.
    """
    image_paths = sorted(glob.glob(os.path.join(dir, 'clean', '*', '*.png')))
    samples = []
    for img_path in image_paths:
        scene_dir, filename = os.path.split(img_path)
        pass_dir, scene = os.path.split(scene_dir)
        base_dir = os.path.dirname(pass_dir)
        albedo_path = os.path.join(base_dir, 'albedo', scene, filename)
        # Previously this was built with 'albedo' as well (copy-paste bug).
        shading_path = os.path.join(base_dir, 'shading', scene, filename)
        samples.append((img_path, albedo_path, shading_path))
    return samples
class MyImageFolder(data_utils.Dataset):
    """Dataset of (image, albedo, shading) triplets listed by make_dataset(root).

    Args:
        root: dataset root passed to make_dataset.
        transform: optional callable applied to each of the three loaded images.
            Random transforms (e.g. RandomCrop) are replayed from the same Python
            and torch RNG state for all three images, so the triplet stays
            pixel-aligned.
        target_transform: stored for API compatibility; not applied.
        loader: callable mapping a path to an image (default: default_loader).

    Raises:
        RuntimeError: if make_dataset finds no images under root.
    """
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(args.img_extentions)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def _transform_aligned(self, images):
        """Apply self.transform to every image using identical random draws."""
        if self.transform is None:
            return list(images)
        # Older torchvision transforms draw from `random`, newer ones from torch.
        # Restore both states before each image so every image gets the same draws.
        py_state = random.getstate()
        torch_state = torch.get_rng_state()
        transformed = []
        for image in images:
            random.setstate(py_state)
            torch.set_rng_state(torch_state)
            transformed.append(self.transform(image))
        return transformed

    def __getitem__(self, index):
        """Return the (img, albedo, shading) triplet at `index`, transformed consistently."""
        paths = self.imgs[index]
        images = [self.loader(path) for path in paths]
        img, albedo, shading = self._transform_aligned(images)
        return img, albedo, shading

    def __len__(self):
        return len(self.imgs)
# Training data: random args.image_h x args.image_w crops converted to tensors,
# iterated one shuffled sample at a time with a single worker process.
crop_and_tensorize = transforms.Compose([
    transforms.RandomCrop((args.image_h, args.image_w)),
    transforms.ToTensor(),
])
dataset = MyImageFolder(args.train_dir, crop_and_tensorize)
train_loader = data_utils.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
In [11]:
"""
FineTuneModel: https://gist.github.com/panovr/2977d9f26866b05583b0c40d88a315bf
ResNet:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
Layer Size:
0: (0:4) 1/4
1: (4:5) 1/4
2: (5:6) 1/8
3: (6:7) 1/16
4: (7:8) 1/32
5: (8:9) 1/32
"""
class CommonModel(nn.Module):
    """Shared trunk that splits a ResNet's children into two consecutive stages.

    unary_2M = children[0:7] and unary_1M = children[7:8]. With torchvision's
    ResNet child order (conv1, bn1, relu, maxpool, layer1..layer4, ...) these are
    the stem through layer3 and then layer4. Per the scale table in this cell,
    that is presumably 1/16 and 1/32 of the input size.

    Args:
        original_model: backbone whose children() are sliced; it is not copied,
            so the trunk shares its modules and weights.
        arch: architecture name; only names starting with 'resnet' are supported.

    Raises:
        ValueError: for a non-ResNet arch. Previously the module was silently
            built without layers and forward() failed later.
    """
    def __init__(self, original_model, arch):
        super(CommonModel, self).__init__()
        if not arch.startswith('resnet'):
            raise ValueError("CommonModel only supports ResNet architectures, got %r" % (arch,))
        children = list(original_model.children())
        self.unary_2M = nn.Sequential(*children[0:7])
        self.unary_1M = nn.Sequential(*children[7:8])

    def forward(self, x):
        """Return (_2M, _1M): features after children[0:7] and then after children[7:8]."""
        _2M = self.unary_2M(x)
        _1M = self.unary_1M(_2M)
        return _2M, _1M
class PotentialModel(nn.Module):
    """Per-pixel potential head built on the two CommonModel feature stages.

    Shape flow, for input sides that are multiples of 32:
        _2M (256 ch, ~1/16) --deconv x4--> 256 ch at 1/4
        _1M (512 ch, ~1/32) --deconv x8--> 512 ch at 1/4
        x   (3 ch, full)    --BasicBlock x2, stride 4--> 128 ch at 1/4
        concat (896 ch) -> log1p -> 256 -> 128 ch -> deconv x4 -> output_channels at full size

    Args:
        arch: torchvision ResNet name. It is used only as a factory for
            `_make_layer`; its weights are never copied.
        output_channels: channels of the output map (Net uses 3 unary, 4 pairwise).
    """
    def __init__(self, arch, output_channels):
        super(PotentialModel, self).__init__()
        # Scale 1: depthwise (grouped) transposed convs upsampling the trunk features to 1/4 scale.
        self.unary_2M_to_8M = nn.ConvTranspose2d(kernel_size=8, stride=4, padding=2, in_channels=256, out_channels=256, groups=256, bias=False)
        self.unary_1M_to_8M = nn.ConvTranspose2d(kernel_size=16, stride=8, padding=4, in_channels=512, out_channels=512, groups=512, bias=False)
        # Use the `arch` argument (previously the global args.arch was used instead).
        # _make_layer builds freshly initialised blocks, so downloading pretrained weights was wasted work.
        layer_factory = models.__dict__[arch](pretrained=False)
        # Scale 2: raw-image branch, 3 -> 128 channels at stride 4.
        layer_factory.inplanes = 3
        self.unary_layer_raw = layer_factory._make_layer(models.resnet.BasicBlock, 128, 2, stride=4)
        # Merged: 128 (raw) + 256 (_2M) + 512 (_1M) = 896 concatenated channels.
        layer_factory.inplanes = 896
        self.unary_layer1 = layer_factory._make_layer(models.resnet.BasicBlock, 256, 2, stride=1)
        layer_factory.inplanes = 256
        self.unary_layer2 = layer_factory._make_layer(models.resnet.BasicBlock, 128, 2, stride=1)
        self.unary_deconv = nn.ConvTranspose2d(kernel_size=8, stride=4, padding=2, in_channels=128, out_channels=output_channels, bias=True)

    def forward(self, x, _2M, _1M):
        """Fuse the raw-image branch with the upsampled trunk features; return the output map."""
        raw = self.unary_layer_raw(x)
        col1 = self.unary_2M_to_8M(_2M)
        col2 = self.unary_1M_to_8M(_1M)
        cat = torch.cat([raw, col1, col2], 1)
        # NOTE(review): the upsampling deconvs can output values < -1, where log1p is NaN -- confirm intended.
        logcat = torch.log1p(cat)
        layer1 = self.unary_layer1(logcat)
        layer2 = self.unary_layer2(layer1)
        return self.unary_deconv(layer2)
class CrfLossModel(nn.Module):
    """Normalised per-pixel CRF energy (currently unused: Net's call to it is commented out).

    E = (unary - gt)^2 + sum_d pairwise[:, d] * (shift_d(gt) - gt)^2
    for d = 0 (up), 1 (right), 2 (down), 3 (left), with shift_d = generateShiftGd.
    Returns E / sum(E), an elementwise map whose entries sum to 1.
    """
    def __init__(self):
        # nn.Module state must be initialised; the old `pass` left the module unusable.
        super(CrfLossModel, self).__init__()

    def forward(self, unary, pairwise, gt):
        E = (unary - gt) ** 2
        for direction in range(4):
            # Broadcast the single weight channel over gt's channels.
            # The old .repeat(gt.size()) tiled every dim and produced the wrong shape.
            weight = pairwise[:, direction:direction + 1, :, :].expand_as(gt)
            E = E + weight * ((generateShiftGd(gt, direction) - gt) ** 2)
        Z = torch.sum(E)
        return E / Z
def generateShiftGd(gt, dir_):
    """Return `gt` shifted so each pixel holds its neighbour in direction `dir_`.

    Args:
        gt: 4-D (N, C, H, W) tensor; the indexing below fixes rank 4.
        dir_: 0 = value of the pixel above, 1 = right, 2 = below, 3 = left.

    Border pixels without a neighbour in that direction keep their own value
    (edge replication). The result is assembled with torch.cat from slices of
    gt, so it keeps gt's type and device and remains differentiable. (The old
    version wrote into a fresh CPU torch.zeros tensor.)

    Raises:
        ValueError: for any other dir_ (previously returned None).
    """
    if dir_ == 0:
        return torch.cat([gt[:, :, :1, :], gt[:, :, :-1, :]], 2)
    if dir_ == 1:
        return torch.cat([gt[:, :, :, 1:], gt[:, :, :, -1:]], 3)
    if dir_ == 2:
        return torch.cat([gt[:, :, 1:, :], gt[:, :, -1:, :]], 2)
    if dir_ == 3:
        return torch.cat([gt[:, :, :, :1], gt[:, :, :, :-1]], 3)
    raise ValueError("dir_ must be 0 (up), 1 (right), 2 (down) or 3 (left), got %r" % (dir_,))
class Net(nn.Module):
    """Shared trunk feeding two potential heads: unary (3 channels) and pairwise (4 channels).

    Args:
        original_model: backbone handed to CommonModel.
        CommonModel, PotentialModel: classes used to build the trunk and the heads.
        CrfLossModel: accepted but unused; the CRF loss is computed outside Net.
    """
    def __init__(self, original_model, CommonModel, PotentialModel, CrfLossModel=None):
        super(Net, self).__init__()
        arch = args.arch
        self.common_net = CommonModel(original_model, arch)
        self.unary_net = PotentialModel(arch, output_channels=3)
        self.pairwise_net = PotentialModel(arch, output_channels=4)

    def forward(self, x, gt):
        """Return (unary, pairwise) predictions for x; gt is currently ignored."""
        trunk_features = self.common_net(x)
        unary = self.unary_net(x, *trunk_features)
        pairwise = self.pairwise_net(x, *trunk_features)
        return unary, pairwise
def cat(a, b, axis=0):
    """Append array b to accumulator a along `axis`; a None accumulator yields b itself."""
    return b if a is None else np.concatenate((a, b), axis=axis)
def generateR(pairwise_):
n,_,h,w = pairwise_.numpy().shape
c = 3
nele = n*c*h*w
planesize = h*w
imgsize = c*h*w
row = np.linspace(0, planesize-1, planesize, dtype=np.uint32)
col_up = (row.copy() - w).clip(0,planesize-1)
col_right = (row.copy() + 1).clip(0,planesize-1)
col_down = (row.copy() + w).clip(0,planesize-1)
col_left = (row.copy() -1).clip(0,planesize-1)
row = row.reshape((1,-1))
col_up = col_up.reshape((1,-1))
col_right = col_right.reshape((1,-1))
col_down = col_down.reshape((1,-1))
col_left = col_left.reshape((1,-1))
tmp_row = None
tmp_col = None
tmp_col_up = None
tmp_col_right = None
tmp_col_down = None
tmp_col_left = None
for i in range(0, 4):
for j in range(0,4):
tmp_row = cat(tmp_row, row + planesize*i)
tmp_col_up = cat(tmp_col_up, col_up + planesize*i)
tmp_col_right = cat(tmp_col_right, col_right + planesize*i)
tmp_col_down = cat(tmp_col_down, col_down + planesize*i)
tmp_col_left = cat(tmp_col_left, col_left + planesize*i)
tmp_col = cat(tmp_col, tmp_col_up)
tmp_col = cat(tmp_col, tmp_col_right)
tmp_col = cat(tmp_col, tmp_col_down)
tmp_col = cat(tmp_col, tmp_col_left)
Rdata = pairwise[tmp_col,0]
ind = np.zeros((2, tmp_row.size))
ind[0,:] = tmp_row
ind[1,:] = tmp_col
return Rdata, ind
def getBestPredictY(unary_, pairwise_):
Rdata_np, ind_np = generateR(pairwise_)
R = torch.sparse(torch.from_numpy(), )
class BestPredictFunc(torch.autograd.Function):
def forward(self, unary_, pairwise_, gt_):
n,c,h,w = gt_.numpy().shape
self.nele = n*c*h*w
self.planesize = h*w
self.imgsize = c*h*w
self.row = np.linspace(0, self.planesize-1, self.planesize, dtype=np.uint32)
tmp = np.linspace(0, self.planesize-1, self.planesize, dtype=np.uint32)
self.col_up = (tmp.copy() - w).clip(0,self.planesize-1)
self.col_right = (tmp.copy() + 1).clip(0,self.planesize-1)
self.col_down = (tmp.copy() + w).clip(0,self.planesize-1)
self.col_left = (tmp.copy() -1).clip(0,self.planesize-1)
self.row = self.row.reshape((1,-1))
self.col_up = self.col_up.reshape((1,-1))
self.col_right = self.col_right.reshape((1,-1))
self.col_down = self.col_down.reshape((1,-1))
self.col_left = self.col_left.reshape((1,-1))
self.unary = unary_.clone().view(-1,1).numpy().astype(np.float32)
self.pairwise = pairwise_.clone().view(-1,1).numpy().astype(np.float32)
self.gt = gt_.clone().view(-1,1).numpy().astype(np.float32)
self.R = lil_matrix((self.imgsize, self.imgsize))
for i in range(0, 4):
tmp_row = np.concatenate(tmp_row,torch.from_numpy(self.row + self.planesize).type(torch.LongTensor)
tmp_col_up = torch.from_numpy(self.col_up + self.planesize).type(torch.LongTensor)
tmp_col_right = torch.from_numpy(self.col_right + self.planesize).type(torch.LongTensor)
tmp_col_down = torch.from_numpy(self.col_down + self.planesize).type(torch.LongTensor)
tmp_col_left = torch.from_numpy(self.col_left + self.planesize).type(torch.LongTensor)
tmp_pw_up = torch.from_numpy(self.pairwise[self.col_up + self.planesize,0]).type(torch.FloatTensor)
tmp_pw_right = torch.from_numpy(self.pairwise[self.col_right + self.planesize,0]).type(torch.FloatTensor)
tmp_pw_down = torch.from_numpy(self.pairwise[self.col_down + self.planesize,0]).type(torch.FloatTensor)
tmp_pw_left = torch.from_numpy(self.pairwise[self.col_left + self.planesize,0]).type(torch.FloatTensor)
self.I = lil_matrix((self.imgsize,self.imgsize))
self.I.setdiag(1)
self.D = lil_matrix((self.imgsize,self.imgsize))
self.D.setdiag(self.R.sum(axis=1))
self.A = self.I + self.D - self.R
predict_albedo = spsolve(self.A.tocsr(), self.unary)
tw = torch.LongTensor(w)
th = torch.LongTensor(h)
self.save_for_backward(th, tw)
return torch.from_numpy(predict_albedo).type(torch.FloatTensor)
def backward(self, grad_output):
th, tw = self.saved_tensors
return torch.zeros(1,3,th[0],tw[0]), torch.zeros(1,4,th[0],tw[0]), torch.zeros(1,3,th[0],tw[0])
class BestPredictModule(nn.Module):
    """nn.Module wrapper that applies a fresh BestPredictFunc on every forward call.

    Args:
        n, c, h, w: expected input size. They are used only to precompute the
            attributes below; forward() does not read them.

    Attributes:
        nele, planesize, imgsize: n*c*h*w, h*w and c*h*w.
        row: (1, h*w) flat pixel indices.
        col_up, col_right, col_down, col_left: (1, h*w) flat index of each pixel's
            neighbour, clamped per axis at the border (border pixels map to themselves).
    """
    def __init__(self, n, c, h, w):
        super(BestPredictModule, self).__init__()
        self.nele = n * c * h * w
        self.planesize = h * w
        self.imgsize = c * h * w
        # Signed 2-D indices: the old uint32 `p - w` / `p - 1` wrapped around,
        # and `p + 1` leaked into the next row.
        ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        self.row = (ys * w + xs).reshape((1, -1))
        self.col_up = (np.clip(ys - 1, 0, h - 1) * w + xs).reshape((1, -1))
        self.col_right = (ys * w + np.clip(xs + 1, 0, w - 1)).reshape((1, -1))
        self.col_down = (np.clip(ys + 1, 0, h - 1) * w + xs).reshape((1, -1))
        self.col_left = (ys * w + np.clip(xs - 1, 0, w - 1)).reshape((1, -1))

    def forward(self, unary_, pairwise_, gt_):
        return BestPredictFunc()(unary_, pairwise_, gt_)
In [12]:
# Build the network from a pretrained backbone, plus the losses and optimizer.
original_model = models.__dict__[args.arch](pretrained=True)
net = Net(original_model=original_model, CommonModel=CommonModel, PotentialModel=PotentialModel)
mse_loss = nn.MSELoss()
mse_loss_best = nn.MSELoss()
# The notebook imports torch.optim as `opti`, so a bare `optim` is a NameError.
# Use the fully qualified module instead.
optimizer = torch.optim.SGD(net.parameters(), lr=args.base_lr)
In [13]:
# Smoke test of the CRF best-prediction path on random inputs (no dataset needed).
n, c, h, w = 1, 3, args.image_h, args.image_w
best_predict_model = BestPredictModule(n, c, h, w)
# requires_grad=True so that loss_best.backward() has a graph to walk.
predict_unary = Variable(torch.rand((n, c, h, w)), requires_grad=True)
# The pairwise map carries 4 direction channels (up/right/down/left), as Net.pairwise_net outputs.
predict_pairwise = Variable(torch.rand((n, 4, h, w)), requires_grad=True)
gt_albedo = torch.rand((n, c, h, w))
best_predict = best_predict_model(predict_unary, predict_pairwise, Variable(gt_albedo))
loss_best = mse_loss_best(best_predict, Variable(gt_albedo))
loss_best.backward()
In [6]:
# Training loop. For the first args.epoches_unary_threshold epochs only the unary head
# is fit to the albedo; afterwards the loss is taken on the CRF best prediction.
n, c, h, w = 1, 3, args.image_h, args.image_w
# Loop-invariant: build the CRF solver module once instead of once per batch.
best_predict_model = BestPredictModule(n, c, h, w)
for epoch in range(args.epoches):
    for input_img, gt_albedo, gt_shading in train_loader:
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        optimizer.zero_grad()
        predict_unary, predict_pairwise = net(input_img, gt_albedo)
        if epoch < args.epoches_unary_threshold:
            loss = mse_loss(predict_unary, gt_albedo)
            print("loss unary =  %s" % (loss,))
        else:
            best_predict = best_predict_model(predict_unary, predict_pairwise, gt_albedo)
            loss = mse_loss_best(best_predict, gt_albedo)
            print("loss best =  %s" % (loss,))
        loss.backward()
        optimizer.step()
In [210]:
# Scratch check: MSELoss averages the squared errors, so MSE([1, 2], [0, 0]) = (1 + 4) / 2 = 2.5.
# torch.sum adds every element of aa, giving 10.
cre = nn.MSELoss()
a = Variable(torch.Tensor([1, 2]))
b = Variable(torch.Tensor([0, 0]))
out = cre(a, b)
aa = torch.Tensor([[1, 2], [3, 4]])
# A bare `out` mid-cell is never displayed, so show both values as the cell's output.
out, torch.sum(aa)
Out[210]:
In [86]:
def main():
    """Build the shared trunk and a unary head from a pretrained backbone, and print both."""
    original_model = models.__dict__[args.arch](pretrained=True)
    common_model = CommonModel(original_model, args.arch)
    # PotentialModel requires output_channels; 3 matches Net's unary head.
    model = PotentialModel(args.arch, output_channels=3)
    print(common_model)
    print(model)
if __name__ == '__main__':
    main()
In [98]:
from graphviz import Digraph
import torch
from torch.autograd import Variable
def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    Returns:
        A graphviz Digraph of everything reachable from var's creating function.
    """
    if params is not None:
        # next(iter(...)) instead of params.values()[0]: dict views are not indexable on Python 3.
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + ', '.join(['%d' % v for v in size]) + ')'

    def add_nodes(var):
        if var in seen:
            return
        if torch.is_tensor(var):
            dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
        elif hasattr(var, 'variable'):
            u = var.variable
            # Leaves missing from params get an empty label instead of raising KeyError.
            name = param_map.get(id(u), '') if params is not None else ''
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(var)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))
        seen.add(var)
        for u in getattr(var, 'next_functions', ()):
            if u[0] is not None:
                dot.edge(str(id(u[0])), str(id(var)))
                add_nodes(u[0])
        for t in getattr(var, 'saved_tensors', ()):
            dot.edge(str(id(t)), str(id(var)))
            add_nodes(t)

    # Newer PyTorch exposes the creating function as `grad_fn`; the old Variable API called it `creator`.
    add_nodes(var.grad_fn if hasattr(var, 'grad_fn') else var.creator)
    return dot
In [110]:
def make_dot2(var):
    """Render the autograd graph that ends at `var` as a graphviz Digraph.

    Variable nodes are light blue and labelled with their size. Every other node
    is labelled with its type name.
    """
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def add_nodes(var):
        if var in seen:
            return
        if isinstance(var, Variable):
            value = '(' + ', '.join(['%d' % v for v in var.size()]) + ')'
            dot.node(str(id(var)), str(value), fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))
        seen.add(var)
        # The old API names parents `previous_functions`; newer releases use `next_functions`.
        parents = getattr(var, 'previous_functions', None)
        if parents is None:
            parents = getattr(var, 'next_functions', ())
        for u in parents:
            # next_functions may contain None placeholders for inputs without gradients.
            if u[0] is None:
                continue
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])

    # Newer PyTorch exposes the creating function as `grad_fn`; the old Variable API called it `creator`.
    add_nodes(var.grad_fn if hasattr(var, 'grad_fn') else var.creator)
    return dot
In [129]:
# Visual check: run random input through the trunk and both heads, then draw the unary graph.
test_original_model = models.__dict__[args.arch](pretrained=True)
test_common_model = CommonModel(test_original_model, args.arch)
# PotentialModel takes only (arch, output_channels). For input sides that are
# multiples of 32, its output has the input's spatial size.
test_unary_model = PotentialModel(args.arch, output_channels=3)
test_pairwise_model = PotentialModel(args.arch, output_channels=4)
test_img = Variable(torch.rand(1, 3, args.image_h, args.image_w))
test_2M, test_1M = test_common_model(test_img)
out_unary = test_unary_model(test_img, test_2M, test_1M)
out_pairwise = test_pairwise_model(test_img, test_2M, test_1M)
make_dot2(out_unary)
Out[129]: