In [5]:
import os
import glob
import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
from PIL import Image
import numpy as np
from numpy.linalg import inv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
In [2]:
class Args(object):
    pass

args = Args()
args.epoches = 20
args.epoches_unary_threshold = 0
args.base_lr = 1e-5
args.train_dir = '/home/albertxavier/dataset/sintel/images/'
args.arch = "resnet18"
args.img_extentions = ["png", 'jpg']
args.image_w = 256
args.image_h = 256
In [3]:
def default_loader(path):
    return Image.open(path).convert('RGB')

def make_dataset(dir):
    images_paths = glob.glob(os.path.join(dir, 'clean', '*', '*.png'))
    paths = []
    for img_path in images_paths:
        # Derive the albedo and shading paths by swapping the 'clean'
        # directory component (third from the end) for the target pass.
        sp = img_path.split('/'); sp[-3] = 'albedo'; albedo_path = os.path.join('/', *sp)
        sp = img_path.split('/'); sp[-3] = 'shading'; shading_path = os.path.join('/', *sp)
        paths.append((img_path, albedo_path, shading_path))
    return paths
class MyImageFolder(data_utils.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(args.img_extentions)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        img_path, albedo_path, shading_path = self.imgs[index]
        img = self.loader(img_path)
        albedo = self.loader(albedo_path)
        shading = self.loader(shading_path)
        if self.transform is not None:
            img = self.transform(img)
            albedo = self.transform(albedo)
            shading = self.transform(shading)
        return img, albedo, shading

    def __len__(self):
        return len(self.imgs)

dataset = MyImageFolder(args.train_dir,
                        transforms.Compose([
                            transforms.RandomCrop((args.image_h, args.image_w)),
                            transforms.ToTensor()
                        ]))
train_loader = data_utils.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
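A quick sanity check on the loader (a minimal sketch; it assumes the Sintel tree at args.train_dir exists locally):
In [ ]:
# Fetch one batch and confirm the crop size; each tensor should be
# (1, 3, 256, 256) given args.image_h = args.image_w = 256.
img, albedo, shading = next(iter(train_loader))
print(img.size(), albedo.size(), shading.size())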
In [ ]:
class CrfLossFunction(torch.autograd.Function):
    def getBestPredictY(self, unary_, pairwise_):
        # Solve the CRF's linear system (I + D - R) y = z for the best
        # prediction y, given unary scores z and pairwise affinities.
        n, c, h, w = unary_.numpy().shape
        imgsize = c * h * w
        R_data_np, ind_np = self.generateR(pairwise_)
        # D is diagonal: each pixel's summed pairwise weight, tiled
        # across the c output channels.
        diagonals = np.tile(np.sum(pairwise_.numpy(), axis=1).reshape(-1), c)
        R_np = csr_matrix((R_data_np, (ind_np[0, :], ind_np[1, :])), shape=(imgsize, imgsize))
        I_np = sparse.eye(imgsize, format="csr")
        D_np = sparse.diags(diagonals)
        A_np = I_np + D_np - R_np
        z_np = unary_.numpy().reshape(-1)
        y_np = spsolve(A_np.tocsr(), z_np)
        y_np_ = y_np.reshape(unary_.numpy().shape)
        return torch.from_numpy(y_np_)

    def generateR(self, pairwise_):
        # Build the data/index arrays of the sparse affinity matrix R.
        # Each pixel p is linked to its four neighbours (up, right,
        # down, left), clipped at the image border, in every channel.
        n, _, h, w = pairwise_.numpy().shape
        c = 3
        planesize = h * w
        pairwise_flat = pairwise_.numpy().reshape(-1)
        # Signed dtype: the "up" and "left" offsets go negative before
        # clipping, so an unsigned index type would wrap around.
        p0 = np.arange(planesize, dtype=np.int64)
        q0_up = (p0 - w).clip(0, planesize - 1)
        q0_right = (p0 + 1).clip(0, planesize - 1)
        q0_down = (p0 + w).clip(0, planesize - 1)
        q0_left = (p0 - 1).clip(0, planesize - 1)
        p = None
        q = None
        Rdata = None
        for ch in range(0, c):
            # One copy of p and of the direction-d weights per neighbour.
            for d in range(0, 4):
                p = cat(p, p0 + planesize * ch)
                Rdata = cat(Rdata, pairwise_flat[p0 + planesize * d])
            q = cat(q, q0_up + planesize * ch)
            q = cat(q, q0_right + planesize * ch)
            q = cat(q, q0_down + planesize * ch)
            q = cat(q, q0_left + planesize * ch)
        ind = np.zeros((2, p.size), dtype=np.int64)
        ind[0, :] = p
        ind[1, :] = q
        return Rdata, ind

    def generateA(self, pairwise_):
        # Left unimplemented in the original notebook.
        pass

    def forward(self, unary_, pairwise_):
        # Unfinished in the original notebook (generateA has no body);
        # forward currently just returns the best prediction.
        y_best = self.getBestPredictY(unary_, pairwise_)
        return y_best
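The heart of getBestPredictY is the sparse solve y* = (I + D - R)^{-1} z. Below is a toy-scale check of that system with SciPy; the size and random affinities are illustrative only, far smaller than the real c*h*w:
In [ ]:
# Build a small random symmetric affinity matrix R, its degree matrix D,
# and verify that spsolve satisfies (I + D - R) y = z.
toy_size = 12
rng = np.random.RandomState(0)
R_toy = csr_matrix(rng.rand(toy_size, toy_size) * (rng.rand(toy_size, toy_size) < 0.2))
R_toy = (R_toy + R_toy.T) * 0.5
D_toy = sparse.diags(np.asarray(R_toy.sum(axis=1)).ravel())
A_toy = sparse.eye(toy_size, format="csr") + D_toy - R_toy
z_toy = rng.rand(toy_size)
y_toy = spsolve(A_toy.tocsr(), z_toy)
print(np.allclose(A_toy.dot(y_toy), z_toy))  # True: the solve is consistent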
In [ ]:
class CommonModel(nn.Module):
    def __init__(self, original_model, arch):
        super(CommonModel, self).__init__()
        if arch.startswith('resnet'):
            # children[0:7]: conv1 .. layer3 (1/16 resolution, 256 ch);
            # children[7:8]: layer4 (1/32 resolution, 512 ch).
            self.unary_2M = nn.Sequential(*list(original_model.children())[0:7])
            self.unary_1M = nn.Sequential(*list(original_model.children())[7:8])

    def forward(self, x):
        _2M = self.unary_2M(x)
        _1M = self.unary_1M(_2M)
        return _2M, _1M

class PotentialModel(nn.Module):
    def __init__(self, arch, output_channels):
        super(PotentialModel, self).__init__()
        # Scale 1: upsample the shared ResNet features back to 1/4 resolution.
        self.unary_2M_to_8M = nn.ConvTranspose2d(kernel_size=8, stride=4, padding=2, in_channels=256, out_channels=256, groups=256, bias=False)
        self.unary_1M_to_8M = nn.ConvTranspose2d(kernel_size=16, stride=8, padding=4, in_channels=512, out_channels=512, groups=512, bias=False)
        tmp_model = models.__dict__[args.arch](pretrained=True)
        # Scale 2: a raw-image branch downsampled to the same 1/4 resolution.
        tmp_model.inplanes = 3
        self.unary_layer_raw = tmp_model._make_layer(models.resnet.BasicBlock, 128, 2, stride=4)
        # Merged: 128 + 256 + 512 = 896 input channels.
        tmp_model.inplanes = 896
        self.unary_layer1 = tmp_model._make_layer(models.resnet.BasicBlock, 256, 2, stride=1)
        tmp_model.inplanes = 256
        self.unary_layer2 = tmp_model._make_layer(models.resnet.BasicBlock, 128, 2, stride=1)
        self.unary_deconv = nn.ConvTranspose2d(kernel_size=8, stride=4, padding=2, in_channels=128, out_channels=output_channels, bias=True)

    def forward(self, x, _2M, _1M):
        raw = self.unary_layer_raw(x)
        col1 = self.unary_2M_to_8M(_2M)
        col2 = self.unary_1M_to_8M(_1M)
        cat = torch.cat([raw, col1, col2], 1)
        logcat = torch.log1p(cat)
        layer1 = self.unary_layer1(logcat)
        layer2 = self.unary_layer2(layer1)
        output = self.unary_deconv(layer2)
        return output
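A shape check for the upsampling path (a sketch with random weights; for a 256x256 input the 1/16 and 1/32 ResNet features should both land on the 64x64 grid that unary_layer_raw produces):
In [ ]:
# ConvTranspose2d(kernel=8, stride=4, pad=2) maps 16x16 -> 64x64, and
# ConvTranspose2d(kernel=16, stride=8, pad=4) maps 8x8 -> 64x64.
_m = models.resnet18(pretrained=False)
_common = CommonModel(_m, 'resnet18')
_pot = PotentialModel('resnet18', output_channels=3)
_x = Variable(torch.rand(1, 3, 256, 256))
_f2M, _f1M = _common(_x)
print(_pot.unary_2M_to_8M(_f2M).size())  # expected (1, 256, 64, 64)
print(_pot.unary_1M_to_8M(_f1M).size())  # expected (1, 512, 64, 64)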
In [11]:
"""
FineTuneModel: https://gist.github.com/panovr/2977d9f26866b05583b0c40d88a315bf
ResNet:
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
Layer Size:
0: (0:4) 1/4
1: (4:5) 1/4
2: (5:6) 1/8
3: (6:7) 1/16
4: (7:8) 1/32
5: (8:9) 1/32
"""
class Net(nn.Module):
    def __init__(self, original_model, CommonModel, PotentialModel, CrfLossModel=None):
        super(Net, self).__init__()
        self.common_net = CommonModel(original_model, args.arch)
        self.unary_net = PotentialModel(args.arch, output_channels=3)
        self.pairwise_net = PotentialModel(args.arch, output_channels=4)
        self.softplus = nn.Softplus()
        # self.crfloss_net = CrfLossModel()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, gt):
        _2M, _1M = self.common_net(x)
        unary_ = self.unary_net(x, _2M, _1M)
        pairwise_ = self.pairwise_net(x, _2M, _1M)
        # Softplus keeps the pairwise affinities non-negative.
        pairwise_ = self.softplus(pairwise_)
        return unary_, pairwise_

def cat(a, b, axis=0):
    # Concatenation that tolerates an empty accumulator on the first call.
    if a is None:
        a = b
    else:
        a = np.concatenate((a, b), axis=axis)
    return a
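The cat() helper lets the accumulation loops in generateR start from None without a special first-iteration case. A tiny usage example:
In [ ]:
# Accumulate three row vectors; the result has shape (3, 2).
_acc = None
for _i in range(3):
    _acc = cat(_acc, np.full((1, 2), _i))
print(_acc.shape)  # (3, 2)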
In [12]:
original_model = models.__dict__[args.arch](pretrained=True)
# print(original_model.state_dict())
net = Net(original_model=original_model, CommonModel=CommonModel, PotentialModel=PotentialModel)
mse_loss = nn.MSELoss()
mse_loss_best = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=args.base_lr)
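A forward/backward smoke test on random data (a sketch; with random weights the loss value is meaningless and can even be NaN because of the log1p in PotentialModel, so this only exercises shapes and the autograd graph):
In [ ]:
_x = Variable(torch.rand(1, 3, args.image_h, args.image_w))
_gt = Variable(torch.rand(1, 3, args.image_h, args.image_w))
_u, _p = net(_x, _gt)        # unary (1,3,H,W), pairwise (1,4,H,W)
_l = mse_loss(_u, _gt)
_l.backward()
print(_u.size(), _p.size())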
In [13]:
n,c,h,w = 1,3,args.image_h,args.image_w
# input_img = Variable(torch.rand((1,3,args.image_h,args.image_w)))
# gt_albedo = Variable(torch.rand((1,3,args.image_h,args.image_w)))
# predict_unary, predict_pairwise = net((input_img), (gt_albedo))
# loss_unary = mse_loss(predict_unary, (gt_albedo))
best_predict_model = BestPredictModule(n,c,h,w)
predict_unary = Variable(torch.rand((1,3,args.image_h,args.image_w)))
predict_pairwise = Variable(torch.rand((1,3,args.image_h,args.image_w)))
gt_albedo = torch.rand((1,3,args.image_h,args.image_w))
best_predict = best_predict_model((predict_unary), (predict_pairwise), Variable(gt_albedo))
loss_best = mse_loss_best(best_predict, Variable(gt_albedo))
loss_best.backward()
In [6]:
for epoch in range(args.epoches):
    run_loss = 0.
    for i, data in enumerate(train_loader, 0):
        input_img, gt_albedo, gt_shading = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
        # input_img = input_img.cuda()
        # gt_albedo = gt_albedo.cuda()
        # gt_shading = gt_shading.cuda()
        optimizer.zero_grad()
        predict_unary, predict_pairwise = net(input_img, gt_albedo)
        if epoch < args.epoches_unary_threshold:
            # Warm-up phase: train the unary branch alone.
            loss_unary = mse_loss(predict_unary, gt_albedo)
            print("loss unary = ", loss_unary)
            loss_unary.backward()
            optimizer.step()
        else:
            # Joint phase: supervise the CRF's best prediction.
            n, c, h, w = 1, 3, args.image_h, args.image_w
            best_predict_model = BestPredictModule(n, c, h, w)
            best_predict = best_predict_model(predict_unary, predict_pairwise, gt_albedo)
            loss_best = mse_loss_best(best_predict, gt_albedo)
            print("loss best = ", loss_best)
            loss_best.backward()
            optimizer.step()
In [210]:
cre = nn.MSELoss()
a = Variable(torch.Tensor([1,2]))
b = Variable(torch.Tensor([0,0]))
out = cre(a,b)
out
aa = torch.Tensor([[1,2],[3,4]])
torch.sum(aa)
In [86]:
def main():
    original_model = models.__dict__[args.arch](pretrained=True)
    common_model = CommonModel(original_model, args.arch)
    # output_channels is required; 3 matches the unary branch.
    model = PotentialModel(args.arch, output_channels=3)
    print(model)

if __name__ == '__main__':
    main()
In [98]:
from graphviz import Digraph
import torch
from torch.autograd import Variable

def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph

    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function

    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    """
    if params is not None:
        assert isinstance(list(params.values())[0], Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + (', ').join(['%d' % v for v in size]) + ')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    add_nodes(var.creator)
    return dot
In [110]:
def make_dot2(var):
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def add_nodes(var):
        if var not in seen:
            if isinstance(var, Variable):
                value = '(' + (', ').join(['%d' % v for v in var.size()]) + ')'
                dot.node(str(id(var)), str(value), fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'previous_functions'):
                for u in var.previous_functions:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])

    add_nodes(var.creator)
    return dot
In [129]:
test_original_model = models.__dict__[args.arch](pretrained=True)
test_common_model = CommonModel(test_original_model, args.arch)
test_unary_model = PotentialModel(args.arch, output_channels=3)
test_pairwise_model = PotentialModel(args.arch, output_channels=4)
test_img = torch.rand(1,3,args.image_h,args.image_w)
test_img = Variable(test_img)
test_2M, test_1M = test_common_model(test_img)
out_unary = test_unary_model(test_img, test_2M, test_1M)
out_pairwise = test_pairwise_model(test_img, test_2M, test_1M)
make_dot2(out_unary)