In [ ]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import scipy as sp
from scipy import io
from scipy import interpolate
from scipy import ndimage
import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

Parameters: see parameters_to_use.txt


In [ ]:
#camera parameters
cf = 3.19 #sensor crop factor
ml_pitch = 0.02 #distance between microlenses: 20 um
num_sa = 14.0 #number of subaperture images
num_sa_d = 8.0 #number of desired subaperture images (cropped to only consider those that lie within camera aperture)
lfsize = [8, 8, 200, 200] #size of 4D light field
f35 = 30.0 #focal length (35 mm equivalent)
f_num = 2.0 #f-number
f = f35/cf
aperture_size = f/f_num
aperture_res = (aperture_size/num_sa)/ml_pitch

#forward model parameters
path_order = 3 #implemented for either 2 (quadratic) or 3 (cubic) Bezier curves
num_exp_pts = 7 #number of exposure points to average along motion curve (try {5, 7, 10})

#blind deblurring parameters
lam_init = 8e-3 #initial regularization weight
num_iters = 1000 #optimization iterations
eta_pts = 1.0 #step size for control points
eta_lf = 0.4 #step size for sharp light field
eps_init = 0.12 #initial parameter for evolving L0 approximation (try 0.1 or 0.15 for real data, 0.12 for synthetic data)
lam_decay = 0.997 #regularization weight decay
eps_decay = 0.997 #evolving L0 approximation decay
lam_min = 0.1*lam_init #minimum regularization weight (try 0.25*lam_init for real data, 0.1*lam_init for synthetic data)
eps_min = 0.04 #minimum L0 approximation parameter 

#final light field restoration parameters
lam_lf = 0.004 #TV regularization weight (try {0.002, 0.04})
iters_lf = 100 #TV regularization iterations
eta_lf_tv = 0.04 #TV regularization step size (named separately so it does not overwrite eta_lf above)
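
As a quick sanity check (not required by the pipeline), the derived optics quantities and the annealing schedules implied by the decay parameters can be inspected directly; the schedule expressions below mirror the per-iteration updates used in the blind deblurring loop.


In [ ]:
#optional sanity check: derived optics values and the lam/eps annealing schedules
print('focal length f = %.3f mm' % f)
print('aperture size = %.3f mm' % aperture_size)
print('aperture_res = %.3f (angular sample spacing, in microlens pitches)' % aperture_res)

iters = np.arange(num_iters)
lam_sched = np.clip(lam_init*(lam_decay**iters), lam_min, lam_init) #regularization weight at each iteration
eps_sched = np.clip(eps_init*(eps_decay**iters), eps_min, eps_init) #L0 approximation parameter at each iteration
plt.figure()
plt.plot(iters, lam_sched)
plt.title('regularization weight schedule')
plt.figure()
plt.plot(iters, eps_sched)
plt.title('L0 approximation parameter schedule')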

Functions


In [ ]:
def normalize_lf(lf):
    #normalize to between 0 and 1
    return ((lf-np.amin(lf))/(np.amax(lf)-np.amin(lf)))

In [ ]:
def meshgrid_4D(v_vals, u_vals, y_vals, x_vals):
    #4D meshgrid
    with tf.name_scope('meshgrid_4D'):
        
        size = [tf.size(v_vals), tf.size(u_vals), tf.size(y_vals), tf.size(x_vals)]
        v_vals_e = tf.expand_dims(tf.expand_dims(tf.expand_dims(v_vals, 1), 2), 3)
        u_vals_e = tf.expand_dims(tf.expand_dims(tf.expand_dims(u_vals, 1), 2), 3)
        y_vals_e = tf.expand_dims(tf.expand_dims(tf.expand_dims(y_vals, 1), 2), 3)
        x_vals_e = tf.expand_dims(tf.expand_dims(tf.expand_dims(x_vals, 1), 2), 3)
        v = tf.tile(tf.reshape(v_vals_e, [-1, 1, 1, 1]), [1, size[1], size[2], size[3]])
        u = tf.tile(tf.reshape(u_vals_e, [1, -1, 1, 1]), [size[0], 1, size[2], size[3]])
        y = tf.tile(tf.reshape(y_vals_e, [1, 1, -1, 1]), [size[0], size[1], 1, size[3]])
        x = tf.tile(tf.reshape(x_vals_e, [1, 1, 1, -1]), [size[0], size[1], size[2], 1])
        
        return v, u, y, x
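
Optional check of meshgrid_4D (a small sketch using a throwaway session): its outputs should match numpy's ij-indexed meshgrid.


In [ ]:
#optional check: meshgrid_4D should agree with np.meshgrid(..., indexing='ij')
v_t, u_t, y_t, x_t = meshgrid_4D(tf.range(2, dtype=tf.float32), tf.range(3, dtype=tf.float32),
                                 tf.range(4, dtype=tf.float32), tf.range(5, dtype=tf.float32))
with tf.Session() as sess:
    v_np, u_np, y_np, x_np = sess.run([v_t, u_t, y_t, x_t])
v_ref, u_ref, y_ref, x_ref = np.meshgrid(np.arange(2), np.arange(3), np.arange(4), np.arange(5), indexing='ij')
print(np.allclose(v_np, v_ref), np.allclose(u_np, u_ref), np.allclose(y_np, y_ref), np.allclose(x_np, x_ref))
tf.reset_default_graph() #discard the throwaway check ops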

In [ ]:
def gather_4D(params, indices):
    #gather values from tensor params at indices
    with tf.name_scope('gather_4D'):
        
        flat_params = tf.reshape(params, [-1])
        multipliers = [lfsize[3]*lfsize[2]*lfsize[1], lfsize[3]*lfsize[2], lfsize[3], 1]
        ind_0 = tf.multiply(tf.slice(indices, [0,0], [-1,1]), multipliers[0])
        ind_1 = tf.multiply(tf.slice(indices, [0,1], [-1,1]), multipliers[1])
        ind_2 = tf.multiply(tf.slice(indices, [0,2], [-1,1]), multipliers[2])
        ind_3 = tf.multiply(tf.slice(indices, [0,3], [-1,1]), multipliers[3])
        flat_indices = tf.add_n([ind_0, ind_1, ind_2, ind_3])
                
        return tf.gather(flat_params, flat_indices)
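
The strides used by gather_4D are the standard row-major (C-order) raveling of (v, u, y, x) indices, i.e. the same computation as np.ravel_multi_index. A small sketch of the equivalence, assuming indices are given as an N x 4 integer array:


In [ ]:
#the flat index computed by gather_4D is the row-major raveling of (v, u, y, x) into a 1D index
idx = np.array([[0, 1, 2, 3], [7, 7, 199, 199]])
flat_manual = (idx[:, 0]*lfsize[3]*lfsize[2]*lfsize[1] + idx[:, 1]*lfsize[3]*lfsize[2]
               + idx[:, 2]*lfsize[3] + idx[:, 3])
flat_ravel = np.ravel_multi_index((idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]), lfsize)
print(flat_manual, flat_ravel) #the two should match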

In [ ]:
def lf_reparam_tf(lf, coord):
    #tensorflow forward model to reparameterize light field by moving the two parameterization planes by coord
    with tf.name_scope('lf_reparam_tf'):
        
        z_multiple = 20.0 #scale z coordinates to similar magnitude as x/y coordinates for optimization
                
        #create and reparameterize light field grid
        v_vals = tf.multiply(aperture_res, tf.subtract(tf.cast(tf.range(lfsize[0]), tf.float32), 
                                             tf.divide(tf.subtract(tf.cast(lfsize[0], tf.float32), 1.0), 2.0)))
        u_vals = tf.multiply(aperture_res, tf.subtract(tf.cast(tf.range(lfsize[1]), tf.float32), 
                                             tf.divide(tf.subtract(tf.cast(lfsize[1], tf.float32), 1.0), 2.0)))
        y_vals = tf.subtract(tf.cast(tf.range(lfsize[2]), tf.float32), 
                                tf.divide(tf.subtract(tf.cast(lfsize[2], tf.float32), 1.0), 2.0))
        x_vals = tf.subtract(tf.cast(tf.range(lfsize[3]), tf.float32), 
                                tf.divide(tf.subtract(tf.cast(lfsize[3], tf.float32), 1.0), 2.0))
    
        v, u, y, x = meshgrid_4D(v_vals, u_vals, y_vals, x_vals)
        
        v_r = tf.add(tf.subtract(v, tf.multiply(coord[2]/z_multiple, tf.subtract(v, y))), coord[1])
        u_r = tf.add(tf.subtract(u, tf.multiply(coord[2]/z_multiple, tf.subtract(u, x))), coord[0])
        y_r = tf.add(tf.subtract(y, tf.multiply(coord[2]/z_multiple, tf.subtract(v, y))), coord[1])
        x_r = tf.add(tf.subtract(x, tf.multiply(coord[2]/z_multiple, tf.subtract(u, x))), coord[0])
        
        v_r = tf.add(tf.divide(v_r, aperture_res), tf.divide(tf.subtract(tf.cast(lfsize[0], tf.float32), 1.0), 2.0))
        u_r = tf.add(tf.divide(u_r, aperture_res), tf.divide(tf.subtract(tf.cast(lfsize[1], tf.float32), 1.0), 2.0))
        y_r = tf.add(y_r, tf.divide(tf.subtract(tf.cast(lfsize[2], tf.float32), 1.0), 2.0))
        x_r = tf.add(x_r, tf.divide(tf.subtract(tf.cast(lfsize[3], tf.float32), 1.0), 2.0))
        
        v_r = tf.reshape(v_r, [-1,1])
        u_r = tf.reshape(u_r, [-1,1])
        y_r = tf.reshape(y_r, [-1,1])
        x_r = tf.reshape(x_r, [-1,1])
    
        v_r_1 = tf.cast(tf.floor(v_r), tf.int32)
        v_r_2 = v_r_1 + 1
        u_r_1 = tf.cast(tf.floor(u_r), tf.int32)
        u_r_2 = u_r_1 + 1
        y_r_1 = tf.cast(tf.floor(y_r), tf.int32)
        y_r_2 = y_r_1 + 1
        x_r_1 = tf.cast(tf.floor(x_r), tf.int32)
        x_r_2 = x_r_1 + 1
        
        v_r_1 = tf.clip_by_value(v_r_1, 0, lfsize[0]-1)
        v_r_2 = tf.clip_by_value(v_r_2, 0, lfsize[0]-1)
        u_r_1 = tf.clip_by_value(u_r_1, 0, lfsize[1]-1)
        u_r_2 = tf.clip_by_value(u_r_2, 0, lfsize[1]-1)
        y_r_1 = tf.clip_by_value(y_r_1, 0, lfsize[2]-1)
        y_r_2 = tf.clip_by_value(y_r_2, 0, lfsize[2]-1)
        x_r_1 = tf.clip_by_value(x_r_1, 0, lfsize[3]-1)
        x_r_2 = tf.clip_by_value(x_r_2, 0, lfsize[3]-1)
        
        #interpolate reparameterized points (quadrilinear)
        interp_pts_1 = tf.concat([v_r_1, u_r_1, y_r_1, x_r_1], 1)
        interp_pts_2 = tf.concat([v_r_2, u_r_1, y_r_1, x_r_1], 1)
        interp_pts_3 = tf.concat([v_r_1, u_r_2, y_r_1, x_r_1], 1)
        interp_pts_4 = tf.concat([v_r_1, u_r_1, y_r_2, x_r_1], 1)
        interp_pts_5 = tf.concat([v_r_1, u_r_1, y_r_1, x_r_2], 1)
        interp_pts_6 = tf.concat([v_r_2, u_r_2, y_r_1, x_r_1], 1)
        interp_pts_7 = tf.concat([v_r_2, u_r_1, y_r_2, x_r_1], 1)
        interp_pts_8 = tf.concat([v_r_2, u_r_1, y_r_1, x_r_2], 1)
        interp_pts_9 = tf.concat([v_r_1, u_r_2, y_r_2, x_r_1], 1)
        interp_pts_10 = tf.concat([v_r_1, u_r_2, y_r_1, x_r_2], 1)
        interp_pts_11 = tf.concat([v_r_1, u_r_1, y_r_2, x_r_2], 1)
        interp_pts_12 = tf.concat([v_r_2, u_r_2, y_r_2, x_r_1], 1)
        interp_pts_13 = tf.concat([v_r_2, u_r_2, y_r_1, x_r_2], 1)
        interp_pts_14 = tf.concat([v_r_2, u_r_1, y_r_2, x_r_2], 1)
        interp_pts_15 = tf.concat([v_r_1, u_r_2, y_r_2, x_r_2], 1)
        interp_pts_16 = tf.concat([v_r_2, u_r_2, y_r_2, x_r_2], 1)
        
        lf_r_1 = gather_4D(tf.squeeze(lf), interp_pts_1)
        lf_r_2 = gather_4D(tf.squeeze(lf), interp_pts_2)
        lf_r_3 = gather_4D(tf.squeeze(lf), interp_pts_3)
        lf_r_4 = gather_4D(tf.squeeze(lf), interp_pts_4)
        lf_r_5 = gather_4D(tf.squeeze(lf), interp_pts_5)
        lf_r_6 = gather_4D(tf.squeeze(lf), interp_pts_6)
        lf_r_7 = gather_4D(tf.squeeze(lf), interp_pts_7)
        lf_r_8 = gather_4D(tf.squeeze(lf), interp_pts_8)
        lf_r_9 = gather_4D(tf.squeeze(lf), interp_pts_9)
        lf_r_10 = gather_4D(tf.squeeze(lf), interp_pts_10)
        lf_r_11 = gather_4D(tf.squeeze(lf), interp_pts_11)
        lf_r_12 = gather_4D(tf.squeeze(lf), interp_pts_12)
        lf_r_13 = gather_4D(tf.squeeze(lf), interp_pts_13)
        lf_r_14 = gather_4D(tf.squeeze(lf), interp_pts_14)
        lf_r_15 = gather_4D(tf.squeeze(lf), interp_pts_15)
        lf_r_16 = gather_4D(tf.squeeze(lf), interp_pts_16)
        
        v_r_1_f = tf.cast(v_r_1, tf.float32)
        v_r_2_f = tf.cast(v_r_2, tf.float32)
        u_r_1_f = tf.cast(u_r_1, tf.float32)
        u_r_2_f = tf.cast(u_r_2, tf.float32)
        y_r_1_f = tf.cast(y_r_1, tf.float32)
        y_r_2_f = tf.cast(y_r_2, tf.float32)
        x_r_1_f = tf.cast(x_r_1, tf.float32)
        x_r_2_f = tf.cast(x_r_2, tf.float32)
        
        d_v_1 = 1.0 - (v_r - v_r_1_f)
        d_v_2 = 1.0 - d_v_1
        d_u_1 = 1.0 - (u_r - u_r_1_f)
        d_u_2 = 1.0 - d_u_1
        d_y_1 = 1.0 - (y_r - y_r_1_f)
        d_y_2 = 1.0 - d_y_1
        d_x_1 = 1.0 - (x_r - x_r_1_f)
        d_x_2 = 1.0 - d_x_1
        
        w1 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_1), d_y_1), d_x_1)
        w2 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_1), d_y_1), d_x_1)
        w3 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_2), d_y_1), d_x_1)
        w4 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_1), d_y_2), d_x_1)
        w5 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_1), d_y_1), d_x_2)
        w6 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_2), d_y_1), d_x_1)
        w7 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_1), d_y_2), d_x_1)
        w8 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_1), d_y_1), d_x_2)
        w9 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_2), d_y_2), d_x_1)
        w10 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_2), d_y_1), d_x_2)
        w11 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_1), d_y_2), d_x_2)
        w12 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_2), d_y_2), d_x_1)
        w13 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_2), d_y_1), d_x_2)
        w14 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_1), d_y_2), d_x_2)
        w15 = tf.multiply(tf.multiply(tf.multiply(d_v_1, d_u_2), d_y_2), d_x_2)
        w16 = tf.multiply(tf.multiply(tf.multiply(d_v_2, d_u_2), d_y_2), d_x_2)
        
        
        lf_r = tf.add_n([w1*lf_r_1, w2*lf_r_2, w3*lf_r_3, w4*lf_r_4, w5*lf_r_5, w6*lf_r_6, w7*lf_r_7, w8*lf_r_8, 
                         w9*lf_r_9, w10*lf_r_10, w11*lf_r_11, w12*lf_r_12, w13*lf_r_13, w14*lf_r_14, w15*lf_r_15, w16*lf_r_16])
        
        lf_r = tf.reshape(lf_r, lfsize)
    
        return lf_r
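
Optional consistency check (a sketch on a random light field): with coord = (0, 0, 0) the reparameterized sampling grid lands exactly on the original grid, so the output should match the input up to floating point interpolation error.


In [ ]:
#optional check: reparameterizing by a zero offset should return (approximately) the input light field
lf_test = np.random.rand(*lfsize).astype(np.float32)
with tf.Session() as sess:
    lf_same = lf_reparam_tf(lf_test, tf.constant([0.0, 0.0, 0.0])).eval()
print(np.max(np.abs(lf_same - lf_test))) #should be close to zero
tf.reset_default_graph() #discard the throwaway check ops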

In [ ]:
def lf_blur_forward_tf(lf, pts_1, pts_2, pts_3, order):
    #tensorflow forward model to blur light field along Bezier curve motion path, with given control points
    #assumes first control point is origin
    #implemented for orders 2 (quadratic) and 3 (cubic)
    b = tf.zeros(lfsize)
    for i in range(num_exp_pts):
        t = np.true_divide(i, num_exp_pts).astype(np.float32)
        coord_2 = lambda: tf.multiply((2.0*t*(1.0-t)).astype(np.float32), pts_1) + tf.multiply(tf.square(t), pts_2) #quadratic
        coord_3 = lambda: tf.multiply((3.0*t*(1.0-t)*(1.0-t)).astype(np.float32), pts_1) + tf.multiply((3.0*t*t*(1.0-t)).astype(np.float32), pts_2) + tf.multiply(tf.pow(t, 3), pts_3) #cubic
        coord = tf.cond(tf.equal(order, tf.constant(2)), coord_2, coord_3)
        b = tf.add(b, lf_reparam_tf(lf, tf.squeeze(coord)))
    b = tf.divide(b, tf.cast(num_exp_pts, tf.float32))
    
    return b
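
The forward model averages the light field reparameterized at num_exp_pts samples along a Bezier path whose first control point is fixed at the origin. The sketch below evaluates the cubic case (mirroring coord_3 above) for illustrative control points and plots the sampled in-plane motion path.


In [ ]:
#sketch: (x, y, z) offsets sampled along a cubic Bezier path with illustrative control points p1, p2, p3 (origin implied)
p1 = np.array([-2.0, 2.0, 1.0])
p2 = np.array([-4.0, -4.0, 2.0])
p3 = np.array([-6.0, 6.0, 3.0])
t = np.arange(num_exp_pts)/float(num_exp_pts) #exposure sample times, as in lf_blur_forward_tf
path = (3.0*t*(1.0-t)**2)[:, None]*p1 + (3.0*t*t*(1.0-t))[:, None]*p2 + (t**3)[:, None]*p3
plt.figure()
plt.plot(path[:, 0], path[:, 1], 'o-')
plt.xlabel('x offset')
plt.ylabel('y offset')
plt.title('sampled motion path (cubic Bezier)')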

In [ ]:
def data_loss(lf_observed, lf_sharp, pts_1, pts_2, pts_3, order):
    #data term (l2 norm of difference between observed blurred light field and forward model predicted light field)
    return tf.reduce_mean(tf.squared_difference(lf_blur_forward_tf(lf_sharp, pts_1, pts_2, pts_3, order), lf_observed))

In [ ]:
def tv_loss_s(x):
    #spatial total variation loss (l1 norm of spatial derivatives)
    temp = x[:,:,0:lfsize[2]-1,0:lfsize[3]-1]
    dy = (x[:,:,1:lfsize[2],0:lfsize[3]-1] - temp)
    dx = (x[:,:,0:lfsize[2]-1,1:lfsize[3]] - temp)
    l_1 = tf.reduce_mean(tf.abs(dy)+tf.abs(dx))
    return l_1

In [ ]:
def tv_loss_a(x):
    #angular total variation loss (l1 norm of angular derivatives)
    temp = x[0:lfsize[0]-1,0:lfsize[1]-1,:,:]
    dv = (x[1:lfsize[0],0:lfsize[1]-1,:,:] - temp)
    du = (x[0:lfsize[0]-1,1:lfsize[1],:,:] - temp)
    l_1 = tf.reduce_mean(tf.abs(dv)+tf.abs(du))
    return l_1

In [ ]:
def sp_loss_s(x, eps):
    #gradual non-convex approximation of l0 norm of spatial derivatives
    temp = x[:,:,0:lfsize[2]-1,0:lfsize[3]-1]
    dy = tf.abs(x[:,:,1:lfsize[2],0:lfsize[3]-1] - temp)
    dx = tf.abs(x[:,:,0:lfsize[2]-1,1:lfsize[3]] - temp)
    dy_c = tf.clip_by_value((1/(tf.square(eps)))*tf.square(dy), 0.0, 1.0)
    dx_c = tf.clip_by_value((1/(tf.square(eps)))*tf.square(dx), 0.0, 1.0)
    return tf.reduce_mean(dy_c+dx_c)

In [ ]:
def sp_loss_a(x, eps):
    #gradual non-convex approximation of l0 norm of angular derivatives
    temp = x[0:lfsize[0]-1,0:lfsize[1]-1,:,:]
    dv = tf.abs(x[1:lfsize[0],0:lfsize[1]-1,:,:] - temp)
    du = tf.abs(x[0:lfsize[0]-1,1:lfsize[1],:,:] - temp)
    dv_c = tf.clip_by_value((1/(tf.square(eps)))*tf.square(dv), 0.0, 1.0)
    du_c = tf.clip_by_value((1/(tf.square(eps)))*tf.square(du), 0.0, 1.0)
    return tf.reduce_mean(dv_c+du_c)
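
Each derivative contributes min((|d|/eps)^2, 1) to the prior, so as eps decays toward eps_min the penalty approaches a count of nonzero derivatives (an L0 norm). A quick sketch of this per-derivative penalty for the initial, an intermediate, and the minimum eps:


In [ ]:
#sketch: the per-derivative penalty min((d/eps)^2, 1) approaches an indicator of d != 0 as eps decreases
d = np.linspace(0.0, 0.3, 500)
plt.figure()
for eps_val in [eps_init, 0.08, eps_min]:
    plt.plot(d, np.clip((d/eps_val)**2, 0.0, 1.0), label='eps = %.2f' % eps_val)
plt.xlabel('|derivative|')
plt.ylabel('penalty')
plt.legend()
plt.title('evolving L0 approximation')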

In [ ]:
def lf_subproblem(pts_1, pts_2, pts_3, lf_blur, num_iters, lam_lf, eta_lf, order, record):
    #solve for latent sharp light field, given observed blurred light field and control points
    lam = tf.constant(lam_lf, dtype=tf.float32)
    lf_blur_placeholder = tf.placeholder(tf.float32, shape=[lfsize[0], lfsize[1], lfsize[2], lfsize[3]])
    pts_1_placeholder = tf.placeholder(tf.float32, shape=[1,3])
    pts_2_placeholder = tf.placeholder(tf.float32, shape=[1,3])
    pts_3_placeholder = tf.placeholder(tf.float32, shape=[1,3])
    lf_var = tf.Variable(tf.constant(lf_blur, dtype=tf.float32), name='lf_var')
    data_term = data_loss(lf_blur_placeholder, lf_var, pts_1_placeholder, pts_2_placeholder, pts_3_placeholder, order)
    prior_loss = (lam*tv_loss_s(lf_var)) + (lam*tv_loss_a(lf_var))
    full_loss = data_term + prior_loss
    train_step_lf = tf.train.AdamOptimizer(learning_rate=eta_lf).minimize(full_loss, var_list=[lf_var])
    clip_lf = tf.assign(lf_var, tf.clip_by_value(lf_var, 0.0, 1.0)) #projection op to keep the light field in [0, 1]
    
    iter_full_loss = np.zeros((num_iters))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(num_iters):   
            print ('lf subproblem iteration %i'%(i))
            if (record):
                curr_full_loss = full_loss.eval(feed_dict={lf_blur_placeholder:lf_blur, pts_1_placeholder:pts_1, pts_2_placeholder:pts_2, pts_3_placeholder:pts_3})
                iter_full_loss[i] = curr_full_loss        
            sess.run([train_step_lf], feed_dict={lf_blur_placeholder:lf_blur, pts_1_placeholder:pts_1, pts_2_placeholder:pts_2, pts_3_placeholder:pts_3})
            sess.run(clip_lf) #project latent sharp light field onto [0, 1]
        return lf_var.eval(), iter_full_loss

Simulate Motion Blur


In [ ]:
#load light field
temp = sp.io.loadmat('Data/Synthetic/lf_12.mat')
lf = normalize_lf(np.sum(np.array(temp['lf']), axis=4)).astype(np.float32)
lf_r = normalize_lf(np.array(temp['lf'])[:,:,:,:,0]).astype(np.float32)
lf_g = normalize_lf(np.array(temp['lf'])[:,:,:,:,1]).astype(np.float32)
lf_b = normalize_lf(np.array(temp['lf'])[:,:,:,:,2]).astype(np.float32)
#simulate motion blurred light field
pts_1_blur = np.array([-2.0, 2.0, 1.0]).astype(np.float32)
pts_2_blur = np.array([-4.0, -4.0, 2.0]).astype(np.float32)
pts_3_blur = np.array([-6.0, 6.0, 3.0]).astype(np.float32)
with tf.Session() as sess:
    lf_blur = lf_blur_forward_tf(lf, pts_1_blur, pts_2_blur, pts_3_blur, tf.constant(path_order)).eval()
    lf_blur_r = lf_blur_forward_tf(lf_r, pts_1_blur, pts_2_blur, pts_3_blur, tf.constant(path_order)).eval()
    lf_blur_g = lf_blur_forward_tf(lf_g, pts_1_blur, pts_2_blur, pts_3_blur, tf.constant(path_order)).eval()
    lf_blur_b = lf_blur_forward_tf(lf_b, pts_1_blur, pts_2_blur, pts_3_blur, tf.constant(path_order)).eval()
tf.reset_default_graph()
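
A quick visual check of the simulated blur (a sketch for the synthetic path only): compare the central sub-aperture view of the sharp and blurred light fields.


In [ ]:
#visualize the central sub-aperture view before and after simulated motion blur
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(lf[lfsize[0]//2, lfsize[1]//2, :, :], cmap='gray')
plt.title('sharp (central view)')
plt.subplot(1, 2, 2)
plt.imshow(lf_blur[lfsize[0]//2, lfsize[1]//2, :, :], cmap='gray')
plt.title('blurred (central view)')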

Load Real Motion Blur (Illum)


In [ ]:
#load light field
temp = sp.io.loadmat('Data/Real/illum_73.mat')
lf = normalize_lf(np.sum(np.array(temp['lf']), axis=4)).astype(np.float32)
lf_r = normalize_lf(np.array(temp['lf'])[:,:,:,:,0]).astype(np.float32)
lf_g = normalize_lf(np.array(temp['lf'])[:,:,:,:,1]).astype(np.float32)
lf_b = normalize_lf(np.array(temp['lf'])[:,:,:,:,2]).astype(np.float32)

lf_blur = np.copy(lf)
lf_blur_r = np.copy(lf_r)
lf_blur_g = np.copy(lf_g)
lf_blur_b = np.copy(lf_b)

#placeholder blur control points (to avoid error when saving results)
pts_1_blur = np.array([0, 0, 0]).astype(np.float32) 
pts_2_blur = np.array([0, 0, 0]).astype(np.float32) 
pts_3_blur = np.array([0, 0, 0]).astype(np.float32)

Blind Motion Deblurring


In [ ]:
#set up tensorflow graph
lam = tf.placeholder(tf.float32)
eps = tf.placeholder(tf.float32)
order = tf.placeholder(tf.int32)
lf_blur_placeholder = tf.placeholder(tf.float32, shape=[lfsize[0], lfsize[1], lfsize[2], lfsize[3]])
lf_var = tf.Variable(tf.constant(lf_blur), name='lf_var')
pts_1_var = tf.Variable(tf.zeros([1,3]), name='pts_1_var') #first non-origin control point
pts_2_var = tf.Variable(tf.zeros([1,3]), name='pts_2_var') #second non-origin control point
pts_3_var = tf.Variable(tf.zeros([1,3]), name='pts_3_var') #third non-origin control point
with tf.variable_scope('loss'):
    data_term = data_loss(lf_blur_placeholder, lf_var, pts_1_var, pts_2_var, pts_3_var, order)
    prior_loss = (lam*sp_loss_s(lf_var, eps)) + (lam*sp_loss_a(lf_var, eps))
    full_loss = data_term + prior_loss
with tf.variable_scope('train'):
    train_step_lf = tf.train.AdamOptimizer(learning_rate=eta_lf).minimize(full_loss, var_list=[lf_var])
    train_step_pts = tf.train.AdamOptimizer(learning_rate=eta_pts).minimize(full_loss, var_list=[pts_1_var, pts_2_var, pts_3_var])
    train_step = tf.group(train_step_lf, train_step_pts)
    clip_lf = tf.assign(lf_var, tf.clip_by_value(lf_var, 0.0, 1.0)) #projection op to keep the light field in [0, 1]
    
#losses to record
iter_data_loss = np.zeros((num_iters))
iter_prior_loss = np.zeros((num_iters))
iter_full_loss = np.zeros((num_iters))
lf_mse = np.zeros((num_iters)) #only useful for synthetic examples

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    lam_curr = lam_init
    eps_curr = eps_init
    for i in range(num_iters):
        #print iteration information
        print('iteration %i' % i)
        print('current control points')
        print(pts_1_var.eval())
        print(pts_2_var.eval())
        print(pts_3_var.eval())
        #calculate losses
        curr_data_loss = data_term.eval(feed_dict={lf_blur_placeholder:lf_blur, lam:lam_curr, eps:eps_curr, order:path_order})
        curr_prior_loss = prior_loss.eval(feed_dict={lf_blur_placeholder:lf_blur, lam:lam_curr, eps:eps_curr, order:path_order})
        curr_full_loss = full_loss.eval(feed_dict={lf_blur_placeholder:lf_blur, lam:lam_curr, eps:eps_curr, order:path_order})
        iter_data_loss[i] = curr_data_loss
        iter_prior_loss[i] = curr_prior_loss
        iter_full_loss[i] = curr_full_loss
        lf_mse[i] = np.mean(np.square(lf - lf_var.eval()))
        #run tensorflow session (optimization step)
        sess.run([train_step], feed_dict={lf_blur_placeholder:lf_blur, lam:lam_curr, eps:eps_curr, order:path_order})
        #project latent sharp light field onto [0, 1]
        sess.run(clip_lf)
        #decay regularization weight and increase regularization non-convexity
        lam_curr = np.clip(lam_curr*lam_decay, lam_min, lam_init)
        eps_curr = np.clip(eps_curr*eps_decay, eps_min, eps_init)
    #evaluate final latent sharp light field and control points
    lf_final = lf_var.eval()
    pts_1_final = pts_1_var.eval()
    pts_2_final = pts_2_var.eval()
    pts_3_final = pts_3_var.eval()
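
For the synthetic example, the recovered control points can be compared directly against the ones used to simulate the blur (for real data, pts_*_blur are the zero placeholders defined above):


In [ ]:
#compare estimated control points with the control points used to simulate the blur (meaningful for synthetic data only)
print('estimated:', pts_1_final, pts_2_final, pts_3_final)
print('ground truth:', pts_1_blur, pts_2_blur, pts_3_blur)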

In [ ]:
#plot losses
plt.figure()
plt.plot(iter_full_loss) #full objective
plt.figure()
plt.plot(iter_data_loss) #data term
plt.figure()
plt.plot(iter_prior_loss) #prior term
plt.figure()
plt.plot(lf_mse) #MSE against the sharp light field (synthetic data only)

In [ ]:
#solve LF subproblem for each color channel for final light field estimate
tf.reset_default_graph()
lf_deblur_r, loss_deblur_r = lf_subproblem(pts_1_final, pts_2_final, pts_3_final, lf_blur_r, iters_lf, lam_lf, eta_lf_tv, path_order, False)
tf.reset_default_graph()
lf_deblur_g, loss_deblur_g = lf_subproblem(pts_1_final, pts_2_final, pts_3_final, lf_blur_g, iters_lf, lam_lf, eta_lf_tv, path_order, False)
tf.reset_default_graph()
lf_deblur_b, loss_deblur_b = lf_subproblem(pts_1_final, pts_2_final, pts_3_final, lf_blur_b, iters_lf, lam_lf, eta_lf_tv, path_order, False)

In [ ]:
#save inputs and results
deblur_dict = {'lf_in_r':lf_blur_r, 'lf_in_g':lf_blur_g, 'lf_in_b':lf_blur_b, 
             'lf_out_r':lf_deblur_r, 'lf_out_g':lf_deblur_g, 'lf_out_b':lf_deblur_b,
             'pts_1_in':pts_1_blur, 'pts_2_in':pts_2_blur, 'pts_3_in':pts_3_blur,
             'pts_1_out':pts_1_final, 'pts_2_out':pts_2_final, 'pts_3_out':pts_3_final}
sp.io.savemat('deblur_results.mat', deblur_dict)
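
For reference, a sketch of how the saved .mat file can be reloaded and the restored color light field reassembled channel by channel (values clipped to [0, 1] for display):


In [ ]:
#sketch: reload the saved results and display the central view of the restored color light field
res = sp.io.loadmat('deblur_results.mat')
lf_rgb = np.stack([res['lf_out_r'], res['lf_out_g'], res['lf_out_b']], axis=-1)
center_view = np.clip(lf_rgb[lfsize[0]//2, lfsize[1]//2, :, :, :], 0.0, 1.0)
plt.figure()
plt.imshow(center_view)
plt.title('restored central view')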