Install Jupyter to open this file.
You can use this notebook to run inference on the pretrained models saved to Google Cloud Storage, or modify it to run inference on your own models.
In [ ]:
import io
import os
import numpy as np
import scipy
import scipy.misc
import sys
import tensorflow as tf
import skimage
import math
import matplotlib.pyplot as plt
from io import StringIO
from PIL import Image
from tensorflow.python.platform import gfile
import prediction_input
import prediction_model
def save_png(image_array, path):
  """Saves an image to disk.

  Args:
    image_array: numpy array of shape [image_size, image_size, 3].
    path: str, output file.
  """
  buf = io.BytesIO()
  scipy.misc.imsave(buf, image_array, format='png')
  buf.seek(0)
  f = tf.gfile.GFile(path, 'w')
  f.write(buf.getvalue())
  f.close()
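# save_png writes through tf.gfile, so `path` can be a local file or a gs:// path.
# Minimal usage sketch (not part of the original notebook): save a dummy 64x64 RGB frame
# with float values in [0, 1], the same layout as the frames generated further below.
# The /tmp path is just an example.
save_png(np.random.uniform(size=(64, 64, 3)), '/tmp/example_frame.png')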
In [ ]:
COLOR_CHAN = 3
IMG_WIDTH = 64
IMG_HEIGHT = 64
IMAGE_FEATURE_NAME = "images"
JOINT_POSE_FEATURE_NAME = "joint_poses"
ACTION_FEATURE_NAME = "actions"
def get_input_fn_queue_determ(pattern, batch_size, flags):
  def input_fn(params=None):
    """Input function using queues for GPU, always returns examples in the same order."""
    del params
    filenames = gfile.Glob(os.path.join(flags.data_dir, pattern))
    if not filenames:
      raise RuntimeError('No data files found.')
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    reader = tf.TFRecordReader()
    _, val = reader.read(filename_queue)
    serialized_input = tf.reshape(val, shape=[1])

    image_seq = None
    for i in range(0, flags.sequence_length, flags.skip_num):
      image_name = 'image_' + str(i)

      if flags.dataset_type == 'robot':
        pose_name = 'state_' + str(i)
        action_name = 'action_' + str(i)
        joint_pos_name = 'joint_positions_' + str(i)
        features = {
            pose_name: tf.FixedLenFeature([flags.pose_dim], tf.float32),
            image_name: tf.FixedLenFeature([1], tf.string),
            action_name: tf.FixedLenFeature([flags.pose_dim], tf.float32),
            joint_pos_name: tf.FixedLenFeature([flags.joint_pos_dim], tf.float32)
        }
      else:
        features = {
            image_name: tf.FixedLenFeature([1], tf.string),
        }
      parsed_input = tf.parse_example(serialized_input, features)

      # Process image.
      image_buffer = tf.reshape(parsed_input[image_name], shape=[])
      image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
      image = tf.image.resize_images(
          image, (IMG_HEIGHT, IMG_WIDTH), method=tf.image.ResizeMethod.BICUBIC)
      image = tf.cast(tf.expand_dims(image, 0), tf.float32) / 255.0

      if flags.dataset_type == 'robot':
        pose = tf.reshape(parsed_input[pose_name], shape=[flags.pose_dim])
        pose = tf.expand_dims(pose, 0)
        action = tf.reshape(parsed_input[action_name], shape=[flags.pose_dim])
        action = tf.expand_dims(action, 0)
        joint_pos = tf.reshape(
            parsed_input[joint_pos_name], shape=[flags.joint_pos_dim])
        joint_pos = tf.expand_dims(joint_pos, 0)
      else:
        pose = tf.zeros([1, flags.pose_dim])
        action = tf.zeros([1, flags.pose_dim])
        joint_pos = tf.zeros([1, flags.joint_pos_dim])

      if i == 0:
        image_seq = image
        action_seq, pose_seq, joint_pos_seq = action, pose, joint_pos
      else:
        image_seq = tf.concat([image_seq, image], 0)
        action_seq = tf.concat([action_seq, action], 0)
        pose_seq = tf.concat([pose_seq, pose], 0)
        joint_pos_seq = tf.concat([joint_pos_seq, joint_pos], 0)

    [images, actions, poses, joint_pos] = tf.train.batch(
        [image_seq, action_seq, pose_seq, joint_pos_seq],
        batch_size,
        enqueue_many=False,
        capacity=100 * batch_size)
    print(flags.sequence_length)
    joint_poses = tf.concat([joint_pos, poses], 2)

    output_features = {
        IMAGE_FEATURE_NAME: images,
        JOINT_POSE_FEATURE_NAME: joint_poses,
        ACTION_FEATURE_NAME: actions
    }
    return output_features, None

  return input_fn
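# Optional sanity check (a minimal sketch, not part of the original notebook): pull one batch
# through the queue-based input_fn above and report the tensor shapes. It assumes the data
# under flags.data_dir is readable and that `pattern` matches the TFRecord files there.
def peek_one_batch(flags, pattern="humans-test", batch_size=8):
  with tf.Graph().as_default():
    features, _ = get_input_fn_queue_determ(pattern, batch_size, flags)()
    # MonitoredSession runs the init ops and starts the queue runners for string_input_producer.
    with tf.train.MonitoredSession() as sess:
      batch = sess.run(features)
  # images: [batch_size, len(range(0, sequence_length, skip_num)), 64, 64, 3], floats in [0, 1].
  return {name: value.shape for name, value in batch.items()}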
In [ ]:
def get_flags(dataset_type):
  import prediction_train

  FLAGS = prediction_train.FLAGS
  try:
    # Define a dummy 'f' flag so that the kernel-file argument Jupyter passes on the
    # command line does not break the flag parser.
    tf.app.flags.DEFINE_string('f', '', 'kernel')
  except:
    pass

  FLAGS.is_training = False
  FLAGS.use_tpu = False
  FLAGS.use_image_summary = False
  FLAGS.dataset_type = dataset_type
  FLAGS.use_legacy_vars = True

  if dataset_type == "robot":
    FLAGS.data_dir = "<Your download path>"
    FLAGS.sequence_length = 20
    FLAGS.skip_num = 1
    FLAGS.context_frames = 2
    FLAGS.use_image_summary = False
  else:
    FLAGS.data_dir = "gs://unsupervised-hierarch-video/data/"
    FLAGS.sequence_length = 256
    FLAGS.skip_num = 2
    FLAGS.context_frames = 5
    FLAGS.use_image_summary = False
    FLAGS.use_legacy_vars = False

  return FLAGS
In [ ]:
dataset_type = "human"
def get_images(model_dir, flags, num_to_eval=100, pattern="humans-test"):
  run_config = tf.contrib.learn.RunConfig(
      model_dir=model_dir,
  )
  estimator = tf.estimator.Estimator(
      model_fn=prediction_model.make_model_fn(flags), config=run_config)

  predictions = estimator.predict(
      input_fn=get_input_fn_queue_determ(pattern, 8, flags))

  num_evals = 0
  van_out_psnr_all = []
  van_on_enc_psnr_all = []

  print(predictions)

  all_runs = []
  for prediction in predictions:
    all_rows = {}
    gt_images = prediction["gt_images"]  # [1:]
    # van_on_enc = prediction["van_on_enc_all"]
    mask_out = prediction["mask_out_all"]
    van_out = prediction["van_out_all"]

    gt_images_row = []
    van_out_row = []
    mask_out_row = []
    for frame_i in range(len(van_out)):
      van_out_row.append(van_out[frame_i])
      mask_rgb = np.tile(mask_out[frame_i], [1, 1, 3])
      mask_out_row.append(mask_rgb)
    for frame_i in range(len(gt_images)):
      gt_images_row.append(gt_images[frame_i])

    all_rows["gt_images"] = gt_images_row
    all_rows["van_out"] = van_out_row
    all_rows["mask_out"] = mask_out_row
    # all_rows["van_on_enc"] = van_on_enc

    all_runs.append(all_rows)
    num_evals += 1
    if num_evals >= num_to_eval:
      break

  del predictions
  return all_runs
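# The van_out_psnr_all / van_on_enc_psnr_all lists above are left empty; if you also want
# per-frame PSNR numbers, here is a minimal sketch (not part of the original notebook),
# assuming the frames are floats in [0, 1] as produced by the input pipeline above. Note that
# the ground truth frames are offset by one relative to the predictions (see save_imgs below).
def frame_psnr(frame_a, frame_b, max_val=1.0):
  """Peak signal-to-noise ratio between two float image arrays in [0, max_val]."""
  mse = np.mean(np.square(frame_a.astype(np.float64) - frame_b.astype(np.float64)))
  if mse == 0:
    return float('inf')
  return 10.0 * math.log10((max_val ** 2) / mse)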
In [ ]:
# Change this to your path to save the images.
base_dir = "/mnt/brain6/scratch/rubville/projects/unsupervised-hierarch-video-prediction/gen_frames/"
def save_imgs(images, folder, key="van_out"):
  for run_num in range(len(images)):
    sys.stdout.flush()
    frame_nums = range(len(images[run_num][key]))
    sys.stdout.flush()

    dir_path = os.path.join(folder, str(run_num))
    if not os.path.exists(dir_path):
      os.makedirs(dir_path)

    for frame_i in frame_nums:
      frame = images[run_num][key][frame_i]
      # frame = scipy.misc.imresize(frame, 4.0)
      save_name = frame_i
      if key == "gt_images":
        # Make the numbering of the ground truth frames line up with the predicted frames.
        save_name = frame_i - 1
      save_png(frame, os.path.join(dir_path, "frame" + str(save_name) + '.png'))
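# Optional: view one run inline instead of (or before) saving PNGs. A minimal sketch (not part
# of the original notebook) using the matplotlib import above; `run` is one entry of the list
# returned by get_images, e.g. all_runs_epva_wgan[0] below.
def show_run(run, key="van_out", max_frames=10):
  frames = run[key][:max_frames]
  _, axes = plt.subplots(1, len(frames), figsize=(2 * len(frames), 2))
  for ax, frame in zip(np.atleast_1d(axes), frames):
    ax.imshow(np.clip(frame, 0.0, 1.0))
    ax.axis("off")
  plt.show()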
In [ ]:
# Run to save results from the EPVA GAN model.
# This code will take a while to run since it has to construct a large graph.
# Decrease flags.sequence_length for a faster runtime.
flags = get_flags(dataset_type)
flags.enc_pred_use_l2norm = True
flags.enc_size = 64
flags.pred_noise_std = 1.0
set_model_dir = "gs://unsupervised-hierarch-video/pretrained_models/epva_wgan_human/"
# flags.sequence_length = 64  # Uncomment for a faster runtime; leave commented out to reproduce the results in the paper.
all_runs_epva_wgan = get_images(set_model_dir, flags, num_to_eval=1000)
save_imgs(all_runs_epva_wgan, os.path.join(base_dir, "human_epva_wgan_frames"), key="van_out")
save_imgs(all_runs_epva_wgan, os.path.join(base_dir, "human_epva_wgan_masks"), key="mask_out")
# Also saves the ground truth images.
save_imgs(all_runs_epva_wgan, os.path.join(base_dir, "human_gt"), key="gt_images")
all_runs_epva_wgan = None
del all_runs_epva_wgan
In [ ]:
# Run to save results from EPVA
flags = get_flags(dataset_type)
flags.enc_pred_use_l2norm = False
flags.enc_size = 64
flags.pred_noise_std = 0
flags.sequence_length = 64 # Comment out this line to reproduce the results in the paper.
set_model_dir = "gs://unsupervised-hierarch-video/pretrained_models/epva_human/"
all_runs_epva = get_images(set_model_dir, flags, num_to_eval=1000)
save_imgs(all_runs_epva, os.path.join(base_dir, "human_epva"), key="van_out")
all_runs_epva = None
del all_runs_epva
In [ ]:
# Run to save results from E2E
flags = get_flags(dataset_type)
flags.enc_pred_use_l2norm = False
flags.enc_size = 32
flags.use_legacy_vars = True
flags.sequence_length = 64 # Comment out this line to reproduce the results in the paper.
set_model_dir = "gs://unsupervised-hierarch-video/pretrained_models/e2e_human/"
all_runs_e2e = get_images(set_model_dir, flags, num_to_eval=1000)
save_imgs(all_runs_e2e, os.path.join(base_dir, "human_e2e"), key="van_out")
all_runs_e2e = None
del all_runs_e2e
In [ ]: