Inference on pretrained models.

Install Jupyter to open this file.

You can use this notebook to run inference with the pretrained models saved on Google Cloud Storage, or modify it to run inference on your own models.
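
The gs:// paths below are read from Google Cloud Storage, so the notebook needs GCS credentials. A minimal setup sketch, assuming the Google Cloud SDK is installed:

In [ ]:
# Optional: authenticate for Google Cloud Storage access. Assumes the Google
# Cloud SDK is installed; skip this if your environment already has credentials.
!gcloud auth application-default login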


In [ ]:
import io
import os
import numpy as np
import scipy
import scipy.misc
import sys
import tensorflow as tf
import skimage
import math
import matplotlib.pyplot as plt

from PIL import Image

from tensorflow.python.platform import gfile

import prediction_input
import prediction_model

def save_png(image_array, path):
  """Saves an image to disk.

  Args:
    image_array: numpy array of shape [image_size, image_size, 3].
    path: str, output file.
  """
  buf = io.BytesIO()
  scipy.misc.imsave(buf, image_array, format='png')  # Requires scipy < 1.2.
  buf.seek(0)
  with tf.gfile.GFile(path, 'wb') as f:  # Binary mode: we write PNG bytes.
    f.write(buf.getvalue())
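
As a quick check of the helper above, the sketch below writes a random image to a local path; the output filename is just illustrative.

In [ ]:
# Sanity-check sketch for save_png: write a random 64x64 RGB image to a
# hypothetical local path.
test_image = np.random.rand(64, 64, 3)  # Float values in [0, 1].
save_png(test_image, '/tmp/test_frame.png')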

In [ ]:
COLOR_CHAN = 3
IMG_WIDTH = 64
IMG_HEIGHT = 64

IMAGE_FEATURE_NAME = "images"
JOINT_POSE_FEATURE_NAME = "joint_poses"
ACTION_FEATURE_NAME = "actions"

def get_input_fn_queue_determ(pattern, batch_size, flags):
  def input_fn(params=None):
    """Input function using queues for GPU, always returns examples in the same order."""
    del params
    filenames = gfile.Glob(os.path.join(flags.data_dir, pattern))
    if not filenames:
      raise RuntimeError('No data files found.')
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
    reader = tf.TFRecordReader()

    _, val = reader.read(filename_queue)
    serialized_input = tf.reshape(val, shape=[1])

    image_seq = None

    for i in range(0, flags.sequence_length, flags.skip_num):
      image_name = 'image_' + str(i)

      if flags.dataset_type == 'robot':
        pose_name = 'state_' + str(i)
        action_name = 'action_' + str(i)
        joint_pos_name = 'joint_positions_' + str(i)
        features = {
            pose_name:
                tf.FixedLenFeature([flags.pose_dim], tf.float32),
            image_name:
                tf.FixedLenFeature([1], tf.string),
            action_name:
                tf.FixedLenFeature([flags.pose_dim], tf.float32),
            joint_pos_name:
                tf.FixedLenFeature([flags.joint_pos_dim], tf.float32)
        }
      else:
        features = {
            image_name: tf.FixedLenFeature([1], tf.string),
        }

      parsed_input = tf.parse_example(serialized_input, features)

      # Process image
      image_buffer = tf.reshape(parsed_input[image_name], shape=[])
      image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
      image = tf.image.resize_images(
          image, (IMG_HEIGHT, IMG_WIDTH), method=tf.image.ResizeMethod.BICUBIC)
      image = tf.cast(tf.expand_dims(image, 0), tf.float32) / 255.0

      if flags.dataset_type == 'robot':
        pose = tf.reshape(parsed_input[pose_name], shape=[flags.pose_dim])
        pose = tf.expand_dims(pose, 0)
        action = tf.reshape(parsed_input[action_name], shape=[flags.pose_dim])
        action = tf.expand_dims(action, 0)
        joint_pos = tf.reshape(
            parsed_input[joint_pos_name], shape=[flags.joint_pos_dim])
        joint_pos = tf.expand_dims(joint_pos, 0)
      else:
        pose = tf.zeros([1, flags.pose_dim])
        action = tf.zeros([1, flags.pose_dim])
        joint_pos = tf.zeros([1, flags.joint_pos_dim])

      if i == 0:
        image_seq = image
        action_seq, pose_seq, joint_pos_seq = action, pose, joint_pos
      else:
        image_seq = tf.concat([image_seq, image], 0)
        action_seq = tf.concat([action_seq, action], 0)
        pose_seq = tf.concat([pose_seq, pose], 0)
        joint_pos_seq = tf.concat([joint_pos_seq, joint_pos], 0)

    [images, actions, poses, joint_pos] = tf.train.batch(
        [image_seq, action_seq, pose_seq, joint_pos_seq],
        batch_size,
        enqueue_many=False,
        capacity=100 * batch_size)

    # Concatenate joint positions and poses along the feature axis.
    joint_poses = tf.concat([joint_pos, poses], 2)

    output_features = {
        IMAGE_FEATURE_NAME: images,
        JOINT_POSE_FEATURE_NAME: joint_poses,
        ACTION_FEATURE_NAME: actions
    }

    return output_features, None
  return input_fn
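
If parsing fails, the feature names above may not match your data. A debugging sketch (the directory and glob pattern are placeholders) that prints the keys of the first serialized example:

In [ ]:
# Debugging sketch: inspect the feature keys of the first record in a
# TFRecord file to verify they match the names input_fn expects.
# "<your data_dir>" is a placeholder.
record_path = gfile.Glob(os.path.join("<your data_dir>", "humans-test*"))[0]
for record in tf.python_io.tf_record_iterator(record_path):
  example = tf.train.Example.FromString(record)
  print(sorted(example.features.feature.keys()))
  break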

In [ ]:
def get_flags(dataset_type):
  import prediction_train
  
  FLAGS = prediction_train.FLAGS
  # Jupyter passes a '-f <kernel file>' argument to the process; define a
  # dummy 'f' flag so flag parsing does not fail on it. Ignore the error if
  # the flag is already defined (e.g., when re-running this cell).
  try:
    tf.app.flags.DEFINE_string('f', '', 'kernel')
  except Exception:
    pass
  
  FLAGS.is_training = False
  FLAGS.use_tpu = False
  FLAGS.use_image_summary = False
  FLAGS.dataset_type = dataset_type
  FLAGS.use_legacy_vars = True
  
  if dataset_type == "robot":
    FLAGS.data_dir="<Your download path>"
    FLAGS.sequence_length = 20
    FLAGS.skip_num = 1
    FLAGS.context_frames = 2
    FLAGS.use_image_summary = False
  else:
    FLAGS.data_dir="gs://unsupervised-hierarch-video/data/"
    FLAGS.sequence_length = 256
    FLAGS.skip_num = 2
    FLAGS.context_frames = 5
    FLAGS.use_image_summary = False
    FLAGS.use_legacy_vars = False
    
  return FLAGS
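
Because the input pipeline above is queue-based, pulling tensors from it outside of an Estimator requires starting the queue runners by hand. A minimal sketch, assuming the flags and data paths above are readable:

In [ ]:
# Sketch: fetch one batch from the queue-based input_fn directly. Assumes
# flags.data_dir points at readable data.
flags = get_flags("human")
features, _ = get_input_fn_queue_determ("humans-test", 2, flags)()
with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  batch = sess.run(features[IMAGE_FEATURE_NAME])
  # Shape: [2, len(range(0, sequence_length, skip_num)), 64, 64, 3].
  print(batch.shape)
  coord.request_stop()
  coord.join(threads)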

In [ ]:
dataset_type = "human"

def get_images(model_dir, flags, num_to_eval=100, pattern="humans-test"):

  run_config = tf.contrib.learn.RunConfig(
      model_dir=model_dir,  # Use the model_dir argument, not a global.
  )

  estimator = tf.estimator.Estimator(
      model_fn=prediction_model.make_model_fn(flags), config=run_config)

  predictions = estimator.predict(
      input_fn=get_input_fn_queue_determ(pattern, 8, flags))
    
  num_evals = 0
  all_runs = []

  for prediction in predictions:
    all_rows = {}
    gt_images = prediction["gt_images"] #[1:]
    #van_on_enc = prediction["van_on_enc_all"]
    mask_out = prediction["mask_out_all"]
    van_out = prediction["van_out_all"]
    
    gt_images_row = []
    van_out_row = []
    mask_out_row = []
    for frame_i in range(len(van_out)):
      van_out_row.append(van_out[frame_i])
      mask_rgb = np.tile(mask_out[frame_i], [1, 1, 3])
      mask_out_row.append(mask_rgb)
    for frame_i in range(len(gt_images)):
      gt_images_row.append(gt_images[frame_i])
    all_rows["gt_images"] = gt_images_row
    all_rows["van_out"]= van_out_row
    all_rows["mask_out"] = mask_out_row
    #all_rows["van_on_enc"]= van_on_enc

    all_runs.append(all_rows)

    num_evals += 1
    if num_evals >= num_to_eval:
      break
      
  del predictions
      
  return all_runs
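
The frames returned above make it easy to compute quantitative metrics. For example, per-frame PSNR between ground truth and prediction could be computed as sketched below; this assumes float images in [0, 1] and that predicted frame i lines up with ground-truth frame i + 1, matching the offset used in save_imgs below.

In [ ]:
# Sketch: per-frame PSNR between the ground-truth and predicted frames of
# one run. Assumes float images in [0, 1] and a one-frame offset.
def psnr(a, b):
  mse = np.mean((np.asarray(a, np.float64) - np.asarray(b, np.float64)) ** 2)
  return float('inf') if mse == 0 else 10.0 * np.log10(1.0 / mse)

def run_psnr(run):
  return [psnr(run["gt_images"][i + 1], run["van_out"][i])
          for i in range(min(len(run["van_out"]), len(run["gt_images"]) - 1))]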

In [ ]:
# Change this to your path to save the images.
base_dir = "/mnt/brain6/scratch/rubville/projects/unsupervised-hierarch-video-prediction/gen_frames/"

def save_imgs(images, folder, key="van_out"):
  for run_num in range(len(images)):
    sys.stdout.flush()
    frame_nums = range(len(images[run_num][key]))
    sys.stdout.flush()
    
    dir_path = os.path.join(folder, str(run_num))
    if not os.path.exists(dir_path):
      os.makedirs(dir_path)

    for frame_i in frame_nums:     
      frame = images[run_num][key][frame_i]
      #frame = scipy.misc.imresize(frame, 4.0)
      save_name = frame_i
      if key == "gt_images":
        # Make the number of the ground truth frames line up with the predicted frames.
        save_name = frame_i - 1
      save_png(frame,  os.path.join(dir_path, "frame"+str(save_name)+'.png'))
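
Besides writing PNGs to disk, it can help to eyeball a run inline; a small sketch using the matplotlib import from the first cell:

In [ ]:
# Sketch: display the first few frames of one run inline.
def show_frames(run, key="van_out", num_frames=8):
  frames = run[key][:num_frames]
  fig, axes = plt.subplots(1, len(frames), figsize=(2 * len(frames), 2))
  for ax, frame in zip(np.atleast_1d(axes), frames):
    ax.imshow(np.clip(frame, 0.0, 1.0))
    ax.axis('off')
  plt.show()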

In [ ]:
# Run to save results from EPVA GAN.
# This code will take a while to run since it has to construct a large graph.
# Decrease flags.sequence_length for a faster runtime.

flags = get_flags(dataset_type)

flags.enc_pred_use_l2norm = True
flags.enc_size = 64
flags.pred_noise_std = 1.0
set_model_dir = "gs://unsupervised-hierarch-video/pretrained_models/epva_wgan_human/"
# flags.sequence_length = 64  # Uncomment for a faster run; leave commented out to reproduce the results in the paper.
all_runs_epva_wgan = get_images(set_model_dir, flags, num_to_eval=1000)
save_imgs(all_runs_epva_wgan, os.path.join(base_dir, "human_epva_wgan_frames"), key="van_out")
save_imgs(all_runs_epva_wgan, os.path.join(base_dir, "human_epva_wgan_masks"), key="mask_out")
# Also saves the ground truth images.
save_imgs(all_runs_epva_wgan, os.path.join(base_dir, "human_gt"), key="gt_images")
del all_runs_epva_wgan  # Free memory.

In [ ]:
# Run to save results from EPVA

flags = get_flags(dataset_type)

flags.enc_pred_use_l2norm = False
flags.enc_size = 64
flags.pred_noise_std = 0
flags.sequence_length = 64 # Comment out to reproduce the results in the paper.
set_model_dir = "gs://unsupervised-hierarch-video/pretrained_models/epva_human/"
all_runs_epva = get_images(set_model_dir, flags, num_to_eval=1000)
save_imgs(all_runs_epva, os.path.join(base_dir, "human_epva"), key="van_out")
del all_runs_epva  # Free memory.

In [ ]:
# Run to save results from E2E

flags = get_flags(dataset_type)

flags.enc_pred_use_l2norm = False
flags.enc_size = 32
flags.use_legacy_vars = True
flags.sequence_length = 64 # Comment out to reproduce the results in the paper.
set_model_dir = "gs://unsupervised-hierarch-video/pretrained_models/e2e_human/"
all_runs_e2e = get_images(set_model_dir, flags, num_to_eval=1000)
save_imgs(all_runs_e2e, os.path.join(base_dir, "human_e2e"), key="van_out")
del all_runs_e2e  # Free memory.
