In [ ]:
import os
import json
import sys
import random
sys.path.insert(0, '.')
import tensorflow as tf
import numpy as np
from models import grid_nets, im_nets, model_dlsm
from ops import conv_rnns
from mvnet import MVNet
from utils import Bunch, get_session_config
In [ ]:
# Get pretrained models
# Fetches the released LSM checkpoint files used by the cells below
# (shell script; requires network access).
!sh get_models.sh
In [ ]:
# Get sample data
# Downloads a small ShapeNet sample (renders + split file) into data/.
!sh download_sample.sh
# Root of the downloaded sample; later cells read renders and splits from here.
SAMPLE_DIR = os.path.join('data', 'shapenet_sample')
In [ ]:
# Locate the rendered images and the pretrained model's training arguments,
# then wrap the argument dict for attribute-style access (args.nvox, ...).
im_dir = os.path.join(SAMPLE_DIR, 'renders')
# Join the path components explicitly rather than passing a single pre-joined
# string through os.path.join (a single-argument join is a no-op and hides
# hard-coded '/' separators).
log_dir = os.path.join('models_lsm_v1', 'dlsm-release', 'train')
with open(os.path.join(log_dir, 'args.json'), 'r') as f:
    args = json.load(f)
args = Bunch(args)
In [ ]:
# Setup TF graph and initialize VLSM model
# NOTE: statement order matters throughout this cell: reset the default graph,
# build the network, create the Saver, then restore into a live session.
tf.reset_default_graph()
# Change the ims_per_model to run on different number of views
bs, ims_per_model = 1, 8  # voxel batch size; number of input views per model
ckpt = 'mvnet-200000'  # checkpoint filename prefix inside log_dir
# Network container: voxel grid spans [-0.5, 0.5]; image size, grid resolution
# and normalization come from the training-time args.json loaded above.
net = MVNet(vmin=-0.5, vmax=0.5, vox_bs=bs,
            im_bs=ims_per_model, grid_size=args.nvox,
            im_h=args.im_h, im_w=args.im_w,
            norm=args.norm, mode="TEST")
# Attach the depth-LSM heads; sub-network choices (image net, grid net, RNN)
# are looked up from the registries imported from models/ops by name.
net = model_dlsm(
    net,
    im_nets[args.im_net],
    grid_nets[args.grid_net],
    conv_rnns[args.rnn],
    im_skip=args.im_skip,
    ray_samples=args.ray_samples,
    sepup=args.sepup,
    proj_x=args.proj_x,
    proj_last=True)
# Restore only the variables under the 'MVNet' scope from the checkpoint.
vars_restore = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='MVNet')
sess = tf.InteractiveSession(config=get_session_config())
saver = tf.train.Saver(var_list=vars_restore)
saver.restore(sess, os.path.join(log_dir, ckpt))
In [ ]:
from shapenet import ShapeNet
# Load the sample ShapeNet dataset and enumerate its test (synset, model) ids.
split_path = os.path.join(SAMPLE_DIR, 'splits_sample.json')
dset = ShapeNet(im_dir=im_dir, split_file=split_path, rng_seed=1)
test_mids = dset.get_smids('test')
In [ ]:
# Re-run this cell and the two below to test on a different input.
rand_sid, rand_mid = random.choice(test_mids)  # model under test
# Distinct render indices for this model, one per network input slot.
rand_views = np.random.choice(dset.num_renders, size=(net.im_batch, ), replace=False)
# Load images, rotations and intrinsics for the chosen views, each with a
# leading batch axis of size 1 (loaders invoked in the order im, R, K,
# matching the original cell).
ims, R, K = [
    np.expand_dims(dset.load_func[key](rand_sid, rand_mid, rand_views), 0)
    for key in ('im', 'R', 'K')
]
In [ ]:
# Run DLSM: evaluate the depth-prediction output for the loaded views,
# keeping only index 0 along axis 1 of the network output.
pred_depth = sess.run(
    net.depth_out,
    feed_dict={net.K: K, net.Rcam: R, net.ims: ims},
)[:, 0, ...]
In [ ]:
from vis_utils import image_grid
# Visualize views
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(16, 10))
plt.subplot(2, 1, 1)
plt.imshow(image_grid(ims))
plt.title('Views: {:s} / {:s}'.format(dset.splits[rand_sid]['name'], rand_mid))
plt.axis('off')
plt.subplot(2, 1, 2)
mask = np.logical_and(pred_depth > 2 - 0.5 * np.sqrt(3), pred_depth < 2 + 0.5 * np.sqrt(3))
plt.imshow(image_grid(pred_depth, mask=mask)[..., 0])
plt.title('Predicted Depths')
plt.axis('off')
plt.show()
In [ ]:
from IPython.display import display
from IPython.core.display import HTML
from vis_utils import unproject_depth, plot_points
# Visualize unprojected point cloud. Feel free to play around with the model!
# Depths outside (dmin, dmax) are discarded by unproject_depth; the remaining
# pixels become 3D points colored by the corresponding input image.
dmin, dmax = 2.0 - 0.5 * np.sqrt(3), 2.0 + 0.5 * np.sqrt(3)
per_view = [
    unproject_depth(pred_depth[0, ix, ..., 0], K[0, ix], R[0, ix],
                    im=ims[0, ix], dmin=dmin, dmax=dmax)
    for ix in range(pred_depth.shape[1])
]
pts = np.concatenate([p for p, _ in per_view], axis=0)
clr = np.concatenate([c for _, c in per_view], axis=0)
display(plot_points(pts, clr, size=0.005, title='Predicted Point Cloud'))
# Center outputs (notebook display styling only).
HTML("""
<style>
.output {
    align-items: center;
}
</style>
""")