In [18]:
"""
A small Jupyter demo of the fast image stylization. To use, install jupyter and from this
directory run 'jupyter notebook'
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import matplotlib.pyplot as plt
%matplotlib inline
import ast
import os
import sys
import random
import numpy as np
import tensorflow as tf
from six.moves.urllib.request import urlopen
from magenta.models.image_stylization import image_utils
from magenta.models.image_stylization import model
def DownloadCheckpointFiles(checkpoint_dir='checkpoints'):
    """Download the pretrained stylization checkpoints if they are missing.

    A checkpoint is fetched from the Magenta download server only when no file
    with that name exists in `checkpoint_dir` yet. The data is streamed into a
    temporary `<name>.partial` file, which is renamed to its final name only
    after the download completes. This prevents an interrupted download from
    leaving a truncated checkpoint that a later call would treat as complete
    and skip.

    Args:
      checkpoint_dir: Directory that holds the checkpoint files. It is created,
        together with any missing parents, if it does not exist.

    Raises:
      Any error raised by `urlopen`, by reading the response or by writing the
      file. Before the error propagates, the `.partial` file is removed.
    """
    import contextlib
    import shutil

    url_prefix = 'http://download.magenta.tensorflow.org/models/'
    checkpoints = ['multistyle-pastiche-generator-monet.ckpt',
                   'multistyle-pastiche-generator-varied.ckpt']
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    for checkpoint in checkpoints:
        full_checkpoint = os.path.join(checkpoint_dir, checkpoint)
        if os.path.exists(full_checkpoint):
            continue
        print('Downloading {}'.format(full_checkpoint))
        partial_path = full_checkpoint + '.partial'
        try:
            # closing() releases the HTTP response even if the copy fails.
            # copyfileobj streams the body in chunks instead of reading the
            # whole checkpoint into memory.
            with contextlib.closing(urlopen(url_prefix + checkpoint)) as response:
                with open(partial_path, 'wb') as fh:
                    shutil.copyfileobj(response, fh)
        except BaseException:
            # Never leave half-written data behind.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        # Publish the file only once it is complete.
        os.rename(partial_path, full_checkpoint)
# Image to stylize: any JPEG or PNG path works. The default is a sample image
# located inside the installed magenta image_stylization package.
example_path = os.path.dirname(
    sys.modules['magenta.models.image_stylization'].__file__)
input_image = os.path.join(
    example_path, 'evaluation_images/guerrillero_heroico.jpg')

# Pretrained demo to run: 'varied' or 'monet'.
demo = 'varied'

# Ensure the local 'checkpoints' directory is present, then fetch whichever
# checkpoint files are not in it yet.
if not os.path.isdir('checkpoints'):
    os.makedirs('checkpoints')
DownloadCheckpointFiles()

# Load the image (a leading '~' in the path is expanded) and add a leading
# axis of size 1, so that copies of it can be stacked along axis 0 later.
image = np.expand_dims(
    image_utils.load_np_image(os.path.expanduser(input_image)),
    axis=0)
# Each pretrained demo maps to (checkpoint path, number of styles stored in
# that checkpoint). The style count is fixed by the checkpoint file itself and
# must not be changed.
demo_checkpoints = {
    'monet': ('checkpoints/multistyle-pastiche-generator-monet.ckpt', 10),
    'varied': ('checkpoints/multistyle-pastiche-generator-varied.ckpt', 32),
}
if demo not in demo_checkpoints:
    # Fail loudly. Otherwise `checkpoint`/`num_styles` are undefined on a fresh
    # kernel, or silently keep stale values from an earlier run of this cell.
    raise ValueError('Unknown demo {!r}; expected one of {}.'.format(
        demo, sorted(demo_checkpoints)))
checkpoint, num_styles = demo_checkpoints[demo]
# Styles from the checkpoint file to render. They are rendered in one batch, so
# the more styles you pick, the longer it takes and the more memory it uses.
# Both settings below can be changed as you like:
#   num_to_render - how many distinct styles to draw at random (6 by default).
#   style_seed    - set to an int to get the same selection on every run;
#                   None draws a new random selection each time.
num_to_render = 6
style_seed = None
styles = list(range(num_styles))
random.Random(style_seed).shuffle(styles)
which_styles = styles[:num_to_render]
num_rendered = len(which_styles)
# Stylize the image with every selected style in one batched run, inside a new,
# empty graph. The result is pulled back out of the session as an array.
with tf.Graph().as_default(), tf.Session() as sess:
    # The input is one copy of the image per selected style, concatenated along
    # axis 0. The normalizer params pass the selected style indices as
    # `labels`, and the checkpoint's style count as `num_categories`.
    stylized_images = model.transform(
        tf.concat([image] * num_rendered, 0),
        normalizer_params={
            'labels': tf.constant(which_styles),
            'num_categories': num_styles,
            'center': True,
            'scale': True,
        })
    # Load the pretrained weights for all global variables from the checkpoint.
    model_saver = tf.train.Saver(tf.global_variables())
    model_saver.restore(sess, checkpoint)
    # Run the graph; the tensor handle is replaced by its evaluated value.
    stylized_images = sess.run(stylized_images)
# Plot the stylized images on a grid, num_cols images per row. Each image is
# labelled with the index of the style that produced it.
num_cols = 3
# Round the row count up, so that a style count that is not a multiple of
# num_cols gets a partially filled last row instead of losing images. max()
# keeps at least one row, so plt.subplots never receives 0.
num_rows = max(1, (num_rendered + num_cols - 1) // num_cols)
# Figure height grows with the row count (25 for the default two rows).
# squeeze=False always returns a 2-D array of axes, even for a single row.
f, axes = plt.subplots(num_rows, num_cols, figsize=(25, 12.5 * num_rows),
                       squeeze=False)
flat_axes = axes.ravel()
for axis, stylized_image, style in zip(flat_axes, stylized_images, which_styles):
    axis.imshow(stylized_image)
    axis.set_xlabel('Style %i' % style)
# Hide the empty cells of a partially filled last row.
for axis in flat_axes[num_rendered:]:
    axis.axis('off')
In [ ]: