In [ ]:
import os, glob, platform, datetime, random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.data as data_utils
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch import functional as F
# import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import scipy.io as sio
import cv2
from PIL import Image
from tensorboardX import SummaryWriter

import numpy as np
from numpy.linalg import inv as denseinv
from scipy import sparse
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import inv as spinv
import scipy.misc

from myimagefoldereccv import MyImageFolder
from mymodel import GradientNet
from myargs import Args

In [ ]:
def loadimg(path, height=416, width=1024):
    """Load an image file as a (1, 3, height, width) float tensor for inference.

    The image is converted to RGB and passed through ``transforms.ToTensor``.
    Its top-left ``height`` x ``width`` region is copied into a zero tensor.
    Images smaller than that are zero-padded on the bottom/right instead of
    raising a shape-mismatch error.

    Args:
        path: image file to read.
        height: output height in pixels (default 416, the original fixed crop).
        width: output width in pixels (default 1024, the original fixed crop).

    Returns:
        A ``Variable`` wrapping the (1, 3, height, width) tensor, created with
        ``volatile=True``.
    """
    # Context manager so the file handle is released deterministically.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    print(rgb.size)
    im = transforms.ToTensor()(rgb)
    x = torch.zeros(1, 3, height, width)
    # Clamp the copy to the image extent; the remaining area stays zero.
    h = min(height, im.size(1))
    w = min(width, im.size(2))
    x[0, :, :h, :w] = im[:, :h, :w]
    # NOTE(review): `volatile` only has an effect on PyTorch < 0.4. Newer
    # versions ignore it with a warning; use `torch.no_grad()` there instead.
    x = Variable(x, volatile=True)
    return x

In [ ]:
def save_csv(path, para):
    """Write an array's shape and values to ``path`` as one comma-separated line.

    The line holds each dimension size, then every element in row-major (C)
    order. Every field is followed by a comma, so the line ends with a trailing
    comma. For example, a 1x2x1x1 tensor holding [0.5, 1.0] is written as
    ``"1,2,1,1,0.5,1.0,"``.

    Args:
        path: destination file; overwritten if it exists.
        para: a torch tensor/Variable (detached via ``.data`` and copied to the
            host once) or a numpy array. Any rank is accepted; 4-D input gives
            the original ``n,c,h,w,`` header.
    """
    if isinstance(para, np.ndarray):
        values = para
    else:
        # One device->host transfer for the whole tensor, not one per element.
        values = para.data.cpu().numpy()
    header = ''.join(str(dim) + ',' for dim in values.shape)
    body = ''.join(str(v) + ',' for v in values.reshape(-1))
    with open(path, 'w') as f:
        f.write(header + body)

In [ ]:
# Inference cell: run the GradientNet snapshot on every scene folder under
# `dataset_root` and write the predicted maps to
# `res_root/image_split/<scene>/`.

# --- configuration ---------------------------------------------------------
gpu_num = 0
gradient = False                          # not read by this cell
type2 = 'rgb' if not gradient else 'gd'   # not read by this cell
image_slpit = True                        # not read by this cell (name kept as-is)

res_root = './results/images/'
dataset_root = '/home/albertxavier/dataset/sintel2/clean'
snapshot_path = ('/media/albertxavier/data/eccv/graduation-project/pytorch/'
                 'snapshot0/snapshot-238.pth.tar')
# The previous version processed only the first frame of each scene (its frame
# loop ended in `break`). Raise this to process more frames, up to each
# scene's frame count.
max_frames = 1


def result_filenames(type_, ind):
    """Return the output file names for frame ``ind`` of prediction ``type_``.

    The image names carry the ``type_`` prefix, so albedo and shading outputs
    (including their dx/dy maps) never collide. The alpha/beta ``.mat`` names
    carry no type prefix, so a later type overwrites an earlier one's files.
    """
    return {
        'image': '%s_%04d.png' % (type_, ind),
        'dx': '%s_dx_%04d.png' % (type_, ind),
        'dy': '%s_dy_%04d.png' % (type_, ind),
        'alpha': 'alpha_%04d.mat' % ind,
        'beta': 'beta_%04d.mat' % ind,
    }


def split_prediction(merged):
    """Split the network output for the first batch item into H x W x C numpy maps.

    The maps are taken from these channel ranges of ``merged[0]``:
    0:3 -> image, 3:6 -> dx, 6:9 -> dy, 9:10 -> alpha, 10:13 -> beta.
    Assumes ``merged`` is (N, >=13, H, W).

    Returns:
        (image, dx, dy, alpha, beta) as numpy arrays.
    """
    out = merged[0].cpu().data.numpy().transpose(1, 2, 0)
    return out[:, :, 0:3], out[:, :, 3:6], out[:, :, 6:9], out[:, :, 9:10], out[:, :, 10:13]


# --- model: built once instead of once per (scene, type) --------------------
print(snapshot_path)
if os.path.exists(snapshot_path):
    snapshot = torch.load(snapshot_path)  # local checkpoint; torch.load unpickles
    state_dict = snapshot['state_dict']
    args = snapshot['args']
    densenet = models.__dict__[args.arch](pretrained=True).cuda(gpu_num)
    net = GradientNet(densenet=densenet, growth_rate=32,
                      transition_scale=2, pretrained_scale=4,
                      debug=False).cuda(gpu_num)
else:
    net = None

# --- inference ---------------------------------------------------------------
scenes = glob.glob(os.path.join(dataset_root, '*'))
for scene_dir in scenes:
    scene = os.path.basename(scene_dir)
    res_dir = os.path.join(res_root, 'image_split', scene)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    if net is None:
        continue  # no snapshot: output dirs are still created, nothing is predicted
    num_frames = 40 if scene == 'market_6' else 50
    # NOTE(review): both types use the same snapshot and input, so they
    # presumably write the same prediction under different names.
    for type_ in ['albedo', 'shading']:
        for ind in range(1, min(max_frames, num_frames) + 1):
            frame_path = os.path.join(dataset_root, scene, 'frame_%04d.png' % ind)
            print(frame_path)
            # Reload weights each frame: the forward pass below runs in train
            # mode and may update module buffers.
            net.load_state_dict(state_dict)
            # NOTE(review): inference in train mode (kept from the original);
            # confirm this is intended rather than net.eval().
            net.train()
            # NOTE(review): on PyTorch >= 0.4 `volatile` is ignored, so this
            # forward pass builds an autograd graph; wrap it in torch.no_grad() there.
            im = loadimg(frame_path).cuda(gpu_num)
            merged = net(im)
            B, dx, dy, alpha, beta = split_prediction(merged)
            print('alpha', alpha.min(), alpha.max())
            print('beta', beta.min(), beta.max())

            names = result_filenames(type_, ind)
            print('res path', os.path.join(res_dir, names['image']))
            # [:, :, ::-1] reorders RGB -> BGR for cv2; gradients are shifted by +0.5.
            cv2.imwrite(os.path.join(res_dir, names['image']), B[:, :, ::-1] * 255)
            cv2.imwrite(os.path.join(res_dir, names['dx']), (dx[:, :, ::-1] + 0.5) * 255)
            cv2.imwrite(os.path.join(res_dir, names['dy']), (dy[:, :, ::-1] + 0.5) * 255)
            sio.savemat(os.path.join(res_dir, names['alpha']), {'alpha': alpha})
            sio.savemat(os.path.join(res_dir, names['beta']), {'beta': beta})

In [ ]:
# if gradient == False:
#     merged = mergeRGB[5]
#     merged = merged[0]
#     merged = merged.cpu().data.numpy()
#     print (merged.shape)
#     merged = merged.transpose(1,2,0)
#     print (merged.shape)
#     dx = merged[:,:,0:3]
#     cv2.imwrite('out_merge.png', dx[:,:,::-1]*255)

In [ ]:
# if gradient == True:
#     merged = mergeRGB[5]
#     merged = merged[0]
#     merged = merged.cpu().data.numpy()
#     print (merged.shape)
#     merged = merged.transpose(1,2,0)
#     print (merged.shape)
#     dy = merged[:,:,0:3]+0.5
#     dx = merged[:,:,3:6]+0.5
#     cv2.imwrite('out_merge_dx.png', dx[:,:,::-1]*255)
#     cv2.imwrite('out_merge_dy.png', dy[:,:,::-1]*255)