# Load packages

import os
import glob

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as opti
from torch.autograd import Variable

import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
from PIL import Image

import numpy as np
from numpy.linalg import inv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv


class Args(object):
    """Plain attribute container for the experiment settings assigned below."""


args = Args()
args.epoches = 20  # number of passes the training loop makes over train_loader
args.epoches_unary_threshold = 0  # epochs below this index take the unary-loss branch
args.base_lr = 1e-5  # learning rate handed to the SGD optimizer
args.train_dir = '/home/albertxavier/dataset/sintel/images/'  # dataset root scanned by make_dataset
args.arch = "resnet18"  # key looked up in torchvision.models.__dict__
args.img_extentions = ["png", 'jpg']  # only reported in the "Found 0 images" error message
args.image_w = 256  # crop width
args.image_h = 256  # crop height

# Custom DataLoader

def default_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly so the handle is closed when the function
    returns; ``convert`` forces PIL to decode the pixel data while the handle
    is still open.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def make_dataset(dir):
    """Collect (image, albedo, shading) path triples for every clean frame.

    Frames are discovered with the pattern ``<dir>/clean/<scene>/<frame>.png``.
    The companion paths reuse the scene and frame names under sibling
    directories of ``clean``. Companion files are not checked for existence.

    NOTE(review): the shading directory is assumed to be named ``shading``;
    the previous code substituted ``'albedo'`` for both companions, which made
    the shading path a duplicate of the albedo path.

    Args:
        dir: dataset root that contains the ``clean`` sub-directory.

    Returns:
        A list of ``(img_path, albedo_path, shading_path)`` tuples ordered by
        image path; empty when no frame matches.
    """
    pathes = []
    # sorted() makes the sample order independent of the filesystem's listing order.
    for img_path in sorted(glob.glob(os.path.join(dir, 'clean', '*', '*.png'))):
        scene_dir, frame = os.path.split(img_path)
        scene = os.path.basename(scene_dir)
        albedo_path = os.path.join(dir, 'albedo', scene, frame)
        shading_path = os.path.join(dir, 'shading', scene, frame)
        pathes.append((img_path, albedo_path, shading_path))
    return pathes

class MyImageFolder(data_utils.Dataset):
    """Dataset of (image, albedo, shading) triples listed by ``make_dataset``.

    Args:
        root: dataset root directory passed to ``make_dataset``.
        transform: optional callable applied to each of the three loaded
            images.
        target_transform: stored on the instance; not applied by
            ``__getitem__``.
        loader: callable mapping a file path to a loaded image.
            NOTE(review): this parameter is reconstructed from the
            ``self.loader = loader`` assignment; the default of
            ``default_loader`` is assumed.

    Raises:
        RuntimeError: if ``make_dataset`` finds no images under ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(args.img_extentions))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(img, albedo, shading)`` triple stored at ``index``."""
        img_path, albedo_path, shading_path = self.imgs[index]
        img = self.loader(img_path)
        albedo = self.loader(albedo_path)
        shading = self.loader(shading_path)
        if self.transform is not None:
            # NOTE(review): the transform is invoked once per image, so a
            # random transform (e.g. RandomCrop) is not guaranteed to pick the
            # same region for all three -- confirm whether alignment matters.
            img = self.transform(img)
            albedo = self.transform(albedo)
            shading = self.transform(shading)

        return img, albedo, shading

    def __len__(self):
        """Return the number of triples in the dataset."""
        return len(self.imgs)
# Every sample is randomly cropped to (image_h, image_w).
# NOTE(review): the transform list was left unterminated; Compose and the
# trailing ToTensor() are a reconstruction (a bare list is not callable, and
# the DataLoader needs tensors to build batches) -- confirm against the
# original notebook.
dataset = MyImageFolder(
    args.train_dir,
    transforms.Compose([
        transforms.RandomCrop((args.image_h, args.image_w)),
        transforms.ToTensor(),
    ]))

train_loader = data_utils.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)


# My nn Functions

class CrfLossFunction(torch.autograd.Function):
    """Solves the sparse linear system ``(I + D - R) y = z`` on a pixel grid.

    ``z`` is the flattened unary prediction, ``R`` holds the pairwise weights
    between every pixel and its up/right/down/left neighbour (the same weights
    are reused for every unary channel) and ``D`` is the diagonal matrix of
    the per-pixel sum of the four pairwise weights.

    NOTE(review): only a single sample (batch size 1) is supported, and no
    ``backward`` is defined, so gradients cannot flow through this function.
    """

    @staticmethod
    def _to_numpy(tensor_):
        """Return ``tensor_`` as a numpy array; arrays pass through unchanged."""
        if isinstance(tensor_, np.ndarray):
            return tensor_
        return tensor_.numpy()

    def getBestPredictY(self, unary_, pairwise_):
        """Solve ``A y = z`` for the given potentials.

        Args:
            unary_: tensor/array of shape ``(1, c, h, w)``.
            pairwise_: tensor/array of shape ``(1, 4, h, w)``.

        Returns:
            A tensor with the shape and dtype of ``unary_`` holding the
            solution ``y``.

        Raises:
            ValueError: if the shapes are not as described above.
        """
        unary_np = CrfLossFunction._to_numpy(unary_)
        if unary_np.ndim != 4 or unary_np.shape[0] != 1:
            raise ValueError("unary_ must have shape (1, c, h, w), got %s"
                             % (unary_np.shape,))
        # The system size follows the unary channel count, not the four
        # pairwise direction planes.
        A_np = self.generateA(pairwise_, unary_np.shape[1])
        z_np = unary_np.reshape(-1)
        if A_np.shape[0] != z_np.size:
            raise ValueError("unary_ and pairwise_ describe different image sizes")
        y_np = spsolve(A_np, z_np)
        # spsolve works in double precision; cast back so the result matches
        # the dtype of the unary input.
        y_np_ = np.asarray(y_np).reshape(unary_np.shape).astype(unary_np.dtype)
        return torch.from_numpy(y_np_)

    def generateR(self, pairwise_, channels=3):
        """Build the COO entries of the neighbour-weight matrix ``R``.

        Args:
            pairwise_: tensor/array of shape ``(1, 4, h, w)``; plane ``d``
                holds the weight towards the up, right, down and left
                neighbour for ``d = 0, 1, 2, 3``.
            channels: number of unary channels the weights are replicated for.

        Returns:
            ``(Rdata, ind)`` where ``Rdata`` is a 1-D array of weights and
            ``ind`` is an integer array of shape ``(2, len(Rdata))`` with the
            row indices in ``ind[0]`` and the column indices in ``ind[1]``.

        Raises:
            ValueError: if ``pairwise_`` does not have shape ``(1, 4, h, w)``.
        """
        pairwise_np = CrfLossFunction._to_numpy(pairwise_)
        if pairwise_np.ndim != 4 or pairwise_np.shape[0] != 1 or pairwise_np.shape[1] != 4:
            raise ValueError("pairwise_ must have shape (1, 4, h, w), got %s"
                             % (pairwise_np.shape,))
        _, ndirs, h, w = pairwise_np.shape
        planesize = h * w
        last = planesize - 1

        # Signed indices: an unsigned dtype could wrap around on ``p0 - w``.
        p0 = np.arange(planesize, dtype=np.int64)
        # NOTE(review): neighbours that fall outside the flat index range are
        # clamped to the first/last pixel, and ``p0 +/- 1`` crosses row
        # boundaries; this keeps the original clipping -- confirm it is the
        # intended border handling.
        neighbours = (
            np.clip(p0 - w, 0, last),  # up
            np.clip(p0 + 1, 0, last),  # right
            np.clip(p0 + w, 0, last),  # down
            np.clip(p0 - 1, 0, last),  # left
        )
        weights = pairwise_np.reshape(ndirs, planesize)

        rows, cols, data = [], [], []
        for ch in range(channels):
            offset = planesize * ch
            for d in range(ndirs):
                rows.append(p0 + offset)
                cols.append(neighbours[d] + offset)
                data.append(weights[d])

        Rdata = np.concatenate(data)
        ind = np.vstack((np.concatenate(rows), np.concatenate(cols)))
        return Rdata, ind

    def generateA(self, pairwise_, channels=3):
        """Assemble the sparse system matrix ``A = I + D - R`` in CSR format.

        Args:
            pairwise_: tensor/array of shape ``(1, 4, h, w)``.
            channels: number of unary channels.

        Returns:
            A ``(channels*h*w, channels*h*w)`` scipy sparse matrix.
        """
        pairwise_np = CrfLossFunction._to_numpy(pairwise_)
        R_data_np, ind_np = self.generateR(pairwise_np, channels)
        _, _, h, w = pairwise_np.shape
        imgsize = channels * h * w

        # Duplicate (row, col) pairs produced by the clamping are summed by
        # the CSR constructor.
        R_np = csr_matrix((R_data_np, (ind_np[0], ind_np[1])), shape=(imgsize, imgsize))
        # Per-pixel sum over the four directions, repeated for every channel.
        diagonals = np.tile(pairwise_np.sum(axis=1).reshape(-1), channels)
        D_np = sparse.diags(diagonals, 0, shape=(imgsize, imgsize), format="csr")
        I_np = sparse.eye(imgsize, format="csr")
        return I_np + D_np - R_np

    def forward(self, unary_, pairwise_):
        """Return the solution of the CRF system for the given potentials.

        TODO: the loss computation and ``backward`` are not implemented yet.
        """
        return self.getBestPredictY(unary_, pairwise_)

# My nn Modules

class CommonModel(nn.Module):
    """Shared feature trunk sliced out of a backbone network.

    For architectures whose name starts with ``resnet`` the backbone's
    children ``[0:7]`` become ``unary_2M`` and child ``[7:8]`` becomes
    ``unary_1M``. For any other architecture no sub-module is registered.
    """

    def __init__(self, original_model, arch):
        super(CommonModel, self).__init__()
        if not arch.startswith('resnet'):
            return
        stages = list(original_model.children())
        self.unary_2M = nn.Sequential(*stages[0:7])
        self.unary_1M = nn.Sequential(*stages[7:8])

    def forward(self, x):
        """Return ``(unary_2M(x), unary_1M(unary_2M(x)))``."""
        shallow = self.unary_2M(x)
        deep = self.unary_1M(shallow)
        return shallow, deep

class PotentialModel(nn.Module):
    """Predicts a dense ``output_channels``-plane map from the image and trunk features.

    The raw image branch (128 channels) and the two upsampled trunk feature
    maps (256 and 512 channels) are concatenated along the channel axis
    (896 channels), passed through ``log1p`` and two residual stages, then
    upsampled by a transposed convolution.

    Args:
        arch: key into ``torchvision.models.__dict__``; the constructed model
            is only used as a factory for residual stages.
        output_channels: number of output planes.
    """

    def __init__(self, arch, output_channels):
        super(PotentialModel, self).__init__()
        # Scale 1
        self.unary_2M_to_8M = nn.ConvTranspose2d(kernel_size=8, stride=4, padding=2, in_channels=256, out_channels=256, groups=256, bias=False)
        self.unary_1M_to_8M = nn.ConvTranspose2d(kernel_size=16, stride=8, padding=4, in_channels=512, out_channels=512, groups=512, bias=False)

        # Only ``_make_layer`` is used from this model, and the stages it
        # returns are freshly constructed modules, so the pretrained
        # checkpoint is not needed. The ``arch`` argument is honoured
        # instead of the global ``args.arch``.
        tmp_model = models.__dict__[arch]()
        # Scale 2
        tmp_model.inplanes = 3
        self.unary_layer_raw = tmp_model._make_layer(models.resnet.BasicBlock, 128, 2, stride=4)
        # Merged: 128 (raw) + 256 + 512 input channels
        tmp_model.inplanes = 896
        self.unary_layer1 = tmp_model._make_layer(models.resnet.BasicBlock, 256, 2, stride=1)
        tmp_model.inplanes = 256
        self.unary_layer2 = tmp_model._make_layer(models.resnet.BasicBlock, 128, 2, stride=1)

        self.unary_deconv = nn.ConvTranspose2d(kernel_size=8, stride=4, padding=2, in_channels=128, out_channels=output_channels, bias=True)

    def forward(self, x, _2M, _1M):
        """Return the prediction for image ``x`` and trunk features ``_2M`` / ``_1M``."""
        raw = self.unary_layer_raw(x)
        col1 = self.unary_2M_to_8M(_2M)
        col2 = self.unary_1M_to_8M(_1M)
        # Named ``merged`` so the module-level ``cat`` helper is not shadowed.
        merged = torch.cat([raw, col1, col2], 1)
        # NOTE(review): log1p yields NaN for inputs below -1 and the
        # transposed convolutions are not constrained to be non-negative --
        # confirm the inputs stay in range.
        logcat = torch.log1p(merged)
        layer1 = self.unary_layer1(logcat)
        layer2 = self.unary_layer2(layer1)
        output = self.unary_deconv(layer2)
        return output

# My Network
#
#     FineTuneModel: https://gist.github.com/panovr/2977d9f26866b05583b0c40d88a315bf
#         Layer Size:
#             0: (0:4) 1/4
#             1: (4:5) 1/4
#             2: (5:6) 1/8
#             3: (6:7) 1/16
#             4: (7:8) 1/32
#             5: (8:9) 1/32


class Net(nn.Module):
    """Full network: shared trunk plus unary and pairwise potential heads.

    Args:
        original_model: backbone handed to ``CommonModel``.
        CommonModel: class used to build the shared trunk.
        PotentialModel: class used to build both potential heads.
        CrfLossModel: accepted but not used.
    """

    def __init__(self, original_model, CommonModel, PotentialModel, CrfLossModel=None):
        super(Net, self).__init__()
        self.common_net = CommonModel(original_model, args.arch)
        self.unary_net = PotentialModel(args.arch, output_channels=3)
        self.pairwise_net = PotentialModel(args.arch, output_channels=4)
        self.softplus = nn.Softplus()
        # NOTE(review): the right-hand side of this assignment was truncated
        # to ``nn.``; ``nn.Sigmoid()`` is assumed from the attribute name.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, gt):
        """Return ``(unary, pairwise)`` predictions for image batch ``x``.

        ``gt`` is accepted for call compatibility but not used. The pairwise
        prediction is passed through softplus, which keeps it positive.
        """
        _2M, _1M = self.common_net(x)
        unary_ = self.unary_net(x, _2M, _1M)
        pairwise_ = self.pairwise_net(x, _2M, _1M)
        pairwise_ = self.softplus(pairwise_)
        # The training loop unpacks two values from this call.
        return unary_, pairwise_
def cat(a, b, axis=0):
    """Concatenate ``b`` onto ``a`` along ``axis``, treating ``a is None`` as empty.

    Args:
        a: accumulated array, or ``None`` when nothing has been added yet.
        b: array to append.
        axis: axis passed to ``np.concatenate``.

    Returns:
        ``b`` itself when ``a`` is ``None``, otherwise the concatenation.
    """
    if a is None:
        # Nothing accumulated yet: without this early return ``b`` would be
        # concatenated with itself.
        return b
    return np.concatenate((a, b), axis=axis)

# Define Optimizer and Loss Function

# Backbone whose children are reused by the network's shared trunk.
original_model = models.__dict__[args.arch](pretrained=True)
net = Net(original_model=original_model, CommonModel=CommonModel, PotentialModel=PotentialModel)

# Separate criterion objects for the unary prediction and the CRF prediction.
mse_loss = nn.MSELoss()
mse_loss_best = nn.MSELoss()

# The optimizer module is reached through the ``torch`` package because the
# file imports it under the alias ``opti``; the bare name ``optim`` is not
# defined.
optimizer = torch.optim.SGD(net.parameters(), lr=args.base_lr)

# Training Loop

# Scratch check of the CRF prediction module on random inputs.
n, c, h, w = 1, 3, args.image_h, args.image_w
# input_img = Variable(torch.rand((1,3,args.image_h,args.image_w))) 
# gt_albedo = Variable(torch.rand((1,3,args.image_h,args.image_w))) 

# predict_unary, predict_pairwise = net((input_img), (gt_albedo))
# loss_unary = mse_loss(predict_unary, (gt_albedo))

# NOTE(review): BestPredictModule is not defined in the visible part of the
# file -- confirm where it comes from.
best_predict_model = BestPredictModule(n, c, h, w)
# The random inputs are graph leaves; they must require grad, otherwise
# calling backward() on the loss raises "there are no graph nodes that
# require computing gradients".
predict_unary = Variable(torch.rand((1, 3, args.image_h, args.image_w)), requires_grad=True)
predict_pairwise = Variable(torch.rand((1, 3, args.image_h, args.image_w)), requires_grad=True)
gt_albedo = torch.rand((1, 3, args.image_h, args.image_w))
best_predict = best_predict_model((predict_unary), (predict_pairwise), Variable(gt_albedo))
loss_best = mse_loss_best(best_predict, Variable(gt_albedo))

self.imgsize =  196608
row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (196608, 1)

row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (196608, 1)

row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (196608, 1)

row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (196608, 1)

RuntimeError                              Traceback (most recent call last)
<ipython-input-13-e46dd822f31a> in <module>()
     12 best_predict = best_predict_model((predict_unary), (predict_pairwise), Variable(gt_albedo))
     13 loss_best = mse_loss_best(best_predict, Variable(gt_albedo))
---> 14 loss_best.backward()

/home/albertxavier/anaconda2/lib/python2.7/site-packages/torch/autograd/variable.pyc in backward(self, gradient, retain_variables)
    144                     'or with gradient w.r.t. the variable')
    145             gradient = self.data.new().resize_as_(self.data).fill_(1)
--> 146         self._execution_engine.run_backward((self,), (gradient,), retain_variables)
    148     def register_hook(self, hook):

RuntimeError: there are no graph nodes that require computing gradients

# NOTE(review): no backward pass or optimizer step is visible in this loop,
# and with ``args.epoches_unary_threshold`` set to 0 the branch below is
# never entered -- the cell looks truncated; confirm against the notebook.
for epoch in range(args.epoches): 
    run_loss = 0.
    for i, data in enumerate(train_loader, 0):
        input_img, gt_albedo, gt_shading = data
        input_img = Variable(input_img)
        gt_albedo = Variable(gt_albedo)
        gt_shading = Variable(gt_shading)
#         input_img = input_img.cuda()
#         gt_albedo = gt_albedo.cuda()
#         gt_shading = gt_shading.cuda()
        predict_unary, predict_pairwise = net((input_img), (gt_albedo))
        if epoch < args.epoches_unary_threshold:
            loss_unary = mse_loss(predict_unary, (gt_albedo))
            # Single-argument print(...) with %-formatting behaves the same
            # under Python 2 and Python 3.
            print("loss unary =  %s" % (loss_unary,))
            n, c, h, w = 1, 3, args.image_h, args.image_w
            best_predict_model = BestPredictModule(n, c, h, w)
            best_predict = best_predict_model((predict_unary), (predict_pairwise), (gt_albedo))
            loss_best = mse_loss_best((best_predict), (gt_albedo))
            print("loss best =  %s" % (loss_best,))

self.imgsize =  196608
row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (262144, 1)

row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (262144, 1)

row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (262144, 1)

row size =  (1, 65536)
col_up size =  (1, 65536)
tmp1size =  (1, 65536)
tmp2size =  (1, 65536)
pairwise shape =  (262144, 1)

RuntimeError                              Traceback (most recent call last)
<ipython-input-6-af9bb7a93a28> in <module>()
     20             n,c,h,w = 1,3,args.image_h,args.image_w
     21             best_predict_model = BestPredictModule(n,c,h,w)
---> 22             best_predict = best_predict_model((predict_unary), (predict_pairwise), (gt_albedo))
     23             loss_best = mse_loss_best((best_predict).type(torch.FloatTensor), (gt_albedo))
     24             print "loss best = ", loss_best

/home/albertxavier/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.pyc in __call__(self, *input, **kwargs)
    201     def __call__(self, *input, **kwargs):
--> 202         result = self.forward(*input, **kwargs)
    203         for hook in self._forward_hooks.values():
    204             hook_result = hook(self, input, result)

<ipython-input-4-6ab1b1cad7bf> in forward(self, unary_, pairwise_, gt_)
    188     def forward(self, unary_, pairwise_, gt_):
--> 189         return BestPredictFunc()(unary_, pairwise_, gt_)

RuntimeError: save_for_backward can only save tensors, but argument 0 is of type torch.Size

# Scratch check: mean-squared error between two 2-element vectors, plus a
# small 2x2 tensor.
cre = nn.MSELoss()
a, b = (Variable(torch.Tensor(values)) for values in ([1, 2], [0, 0]))
out = cre(a, b)
aa = torch.Tensor([[1, 2],
                   [3, 4]])


def main():
    """Build the trunk and a potential head, then print the head's layout."""
    original_model = models.__dict__[args.arch](pretrained=True)
    # Constructed but not used further in this function.
    common_model = CommonModel(original_model, args.arch)
    # ``output_channels`` is a required argument of PotentialModel.
    model = PotentialModel(args.arch, output_channels=3)
    print(model)


if __name__ == '__main__':
    main()

PotentialModel (
  (unary_2M_to_8M): ConvTranspose2d(256, 256, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2), groups=256, bias=False)
  (unary_1M_to_8M): ConvTranspose2d(512, 512, kernel_size=(16, 16), stride=(8, 8), padding=(4, 4), groups=512, bias=False)
  (unary_layer_raw): Sequential (
    (0): BasicBlock (
      (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
      (downsample): Sequential (
        (0): Conv2d(3, 128, kernel_size=(1, 1), stride=(4, 4), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (1): BasicBlock (
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (unary_layer1): Sequential (
    (0): BasicBlock (
      (conv1): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (downsample): Sequential (
        (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (1): BasicBlock (
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  (unary_layer2): Sequential (
    (0): BasicBlock (
      (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
      (downsample): Sequential (
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (1): BasicBlock (
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU (inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  (unary_output): ConvTranspose2d(128, 3, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))

# Visualize Graph

from graphviz import Digraph
import torch
from torch.autograd import Variable

def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph

    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function

    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)

    Returns:
        The populated ``graphviz.Digraph``.
    """
    param_map = {}
    if params:
        # next(iter(...)) instead of .values()[0]: dict views cannot be
        # indexed on Python 3.
        assert isinstance(next(iter(params.values())), Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                # Unnamed variables get an empty label instead of a KeyError.
                name = param_map.get(id(u), '')
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            # Mark as visited so shared sub-graphs are emitted only once.
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    # The walk starts from the function that produced ``var``; older PyTorch
    # releases exposed it as ``creator`` rather than ``grad_fn``.
    root = getattr(var, 'grad_fn', None)
    if root is None:
        root = getattr(var, 'creator', None)
    if root is not None:
        add_nodes(root)
    return dot

def make_dot2(var):
    """Produce a Graphviz graph by following ``previous_functions`` links.

    Variables are drawn as light-blue nodes labelled with their size; every
    other node is labelled with its type name.

    Args:
        var: output Variable whose producing function is the root of the walk.

    Returns:
        The populated ``graphviz.Digraph``.
    """
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def add_nodes(var):
        if var not in seen:
            if isinstance(var, Variable):
                value = '('+(', ').join(['%d'% v for v in var.size()])+')'
                dot.node(str(id(var)), str(value), fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            # Mark as visited so shared sub-graphs are emitted only once.
            seen.add(var)
            if hasattr(var, 'previous_functions'):
                for u in var.previous_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])

    # ``creator`` is the attribute that pairs with ``previous_functions``;
    # fall back to ``grad_fn`` when it is absent.
    root = getattr(var, 'creator', None)
    if root is None:
        root = getattr(var, 'grad_fn', None)
    if root is not None:
        add_nodes(root)
    return dot

# Build the trunk and both potential heads, then run them once on a random image.
test_original_model = models.__dict__[args.arch](pretrained=True)
test_common_model = CommonModel(test_original_model, args.arch)
# PotentialModel.__init__ accepts only ``arch`` and ``output_channels``, so
# the output_height / output_width keywords are not passed.
test_unary_model = PotentialModel(args.arch, output_channels=3)
test_pairwise_model = PotentialModel(args.arch, output_channels=4)

test_img = torch.rand(1, 3, args.image_h, args.image_w)
test_img = Variable(test_img)
test_2M, test_1M = test_common_model(test_img)

out_unary = test_unary_model(test_img, test_2M, test_1M)
out_pairwise = test_pairwise_model(test_img, test_2M, test_1M)


