In [1]:
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import reader
import model
import time
import os
import numpy as np

# Probe the input image once to learn its size; the stylization graph in the
# next cell is built for exactly this height/width.
# NOTE: `with tf.Session().as_default()` does NOT close the session on exit
# (it only installs it as the default). `with tf.Session()` both installs it
# and closes it. The probe op is built in a throwaway graph so it does not
# linger in the global default graph.
with open('./camera.jpg', 'rb') as image_file:
    image_bytes = image_file.read()

with tf.Graph().as_default(), tf.Session() as sess:
    # decode_jpeg only handles JPEG input; this cell assumes a .jpg file.
    image = sess.run(tf.image.decode_jpeg(image_bytes))

# The decoded array is indexed (height, width, ...).
height, width = image.shape[0], image.shape[1]
tf.logging.info('Image size: %dx%d' % (width, height))


INFO:tensorflow:Image size: 252x252

In [3]:
# Build the single-image inference graph, restore the trained style weights,
# run it once and write the stylized result to disk as a JPEG.
# Relies on `height` and `width` from the probe cell above.
with tf.Graph().as_default():
    # `with tf.Session()` (unlike `tf.Session().as_default()`) closes the
    # session when the block exits.
    with tf.Session() as sess:
        image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
            'vgg_19',
            is_training=False)
        image = reader.get_image('./camera.jpg', height, width, image_preprocessing_fn)
        # The network consumes a batch: add a batch axis of size 1, drop it afterwards.
        image = tf.expand_dims(image, 0)
        generated = model.net(image, training=False)
        generated = tf.squeeze(generated, [0])
        # Clamp into [0, 255] while converting: a plain tf.cast of out-of-range
        # floats to uint8 does not saturate. NOTE(review): if model.net already
        # bounds its output to [0, 255], this is a no-op.
        generated = tf.saturate_cast(generated, tf.uint8)
        # Build the encode op up front so nothing is added to the graph after
        # inference (and the image is not re-embedded as a graph constant).
        encoded = tf.image.encode_jpeg(generated)

        saver = tf.train.Saver(tf.global_variables())
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        saver.restore(sess, './denoised_starry.ckpt-done')

        # Timed span covers reading/preprocessing, inference and JPEG encoding.
        start_time = time.time()
        jpeg_bytes = sess.run(encoded)
        end_time = time.time()
        tf.logging.info('Elapsed time: %fs' % (end_time - start_time))

generated_file = 'generated/camerafangao.jpg'
output_dir = os.path.dirname(generated_file)
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)
with open(generated_file, 'wb') as img:
    img.write(jpeg_bytes)
tf.logging.info('Done. Please check %s.' % generated_file)


INFO:tensorflow:Elapsed time: 1.455744s
INFO:tensorflow:Done. Please check generated/camerafangao.jpg.

In [3]:
%matplotlib inline
import cv2
from matplotlib import pyplot as plt

# Display the file written by the stylization cell (`generated_file`), not a
# stale output from an earlier run.
# cv2.imread returns None (it does not raise) when the file is missing or
# unreadable, and it returns pixels in BGR order, while matplotlib expects RGB.
img = cv2.imread(generated_file)
if img is None:
    raise IOError('Could not read %s' % generated_file)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.title(generated_file)
plt.axis('off')
plt.show()


Out[3]:
<matplotlib.image.AxesImage at 0x1196fee10>

In [ ]: