In [1]:
import tensorflow as tf
import numpy as np

SEED = 0  # fixes numpy sampling and TF graph-level random initializers so runs are reproducible
np.random.seed(SEED)
tf.set_random_seed(SEED)  # must run before any random op is added to the default graph
# InteractiveSession installs itself as the default session, which lets later
# cells call `tensor.eval(...)` without passing a session explicitly.
sess = tf.InteractiveSession()

In [3]:
class batch_norm(object):
    """Callable wrapper around ``tf.contrib.layers.batch_norm`` bound to one variable scope.

    Each call applies batch normalization with a learned scale (``scale=True``).
    Because ``updates_collections=None``, the moving mean/variance updates are
    computed in place rather than collected as separate update ops.
    """

    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        """Store the normalization hyper-parameters.

        Args:
            epsilon: small constant added to the variance for numerical stability.
            momentum: decay used for the moving mean/variance.
            name: variable scope in which the layer's variables are created.
        """
        # No variables are created inside this scope; it is kept so graph
        # naming stays the same as before.
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name

    def __call__(self, x, train=True):
        """Batch-normalize ``x``.

        Args:
            x: input tensor.
            train: forwarded as ``is_training``. If True, normalize with the current
                batch statistics and update the moving averages. If False, normalize
                with the stored moving averages (inference mode).

        Returns:
            The normalized tensor.
        """
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,
                                            updates_collections=None,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,
                                            scope=self.name)

def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    """2-D convolution with 'SAME' padding followed by a learned bias.

    Args:
        input_: 4-D input tensor (NHWC, as implied by the ``strides`` layout); its
            last static dimension is used as the number of input channels.
        output_dim: number of output channels (filters).
        k_h, k_w: kernel height and width.
        d_h, d_w: stride along height and width.
        stddev: standard deviation of the truncated-normal weight initializer.
        name: variable scope holding the ``w`` and ``biases`` variables.

    Returns:
        The convolution output with the bias added.
    """
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        # bias_add preserves the input shape, so no reshape is needed. Reshaping to
        # conv.get_shape() also failed whenever a dimension (e.g. batch) was
        # unknown at graph-construction time.
        return tf.nn.bias_add(conv, biases)

def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    """Transposed 2-D convolution ("deconvolution") followed by a learned bias.

    Args:
        input_: 4-D input tensor (NHWC); its last static dimension is used as the
            number of input channels.
        output_shape: full output shape ``[batch, height, width, channels]``.
        k_h, k_w: kernel height and width.
        d_h, d_w: stride (upsampling factor) along height and width.
        stddev: standard deviation of the normal weight initializer.
        name: variable scope holding the ``w`` and ``biases`` variables.
        with_w: if True, also return the weight and bias variables.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is True.
    """
    with tf.variable_scope(name):
        # filter: [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))

        # Resolve the op explicitly rather than wrapping the call in
        # try/except AttributeError. That form also swallowed AttributeErrors raised
        # *inside* conv2d_transpose and silently retried with another op.
        # tf.nn.deconv2d is the name used by TensorFlow versions before 0.7.0.
        transpose_op = getattr(tf.nn, 'conv2d_transpose', None)
        if transpose_op is None:
            transpose_op = tf.nn.deconv2d
        deconv = transpose_op(input_, w, output_shape=output_shape,
                              strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        # bias_add preserves the shape. Reshaping to deconv.get_shape() failed when
        # output_shape held dimensions unknown at graph-construction time.
        deconv = tf.nn.bias_add(deconv, biases)

        if with_w:
            return deconv, w, biases
        return deconv

def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU computed as ``max(x, leak * x)``.

    This equals ``x`` for ``x >= 0`` and ``leak * x`` for ``x < 0`` only when
    ``0 <= leak <= 1``.

    Args:
        x: input tensor.
        leak: slope applied to negative inputs.
        name: name given to the resulting op.

    Returns:
        The activated tensor.
    """
    return tf.maximum(x, leak * x, name=name)

In [4]:
dim = 16          # base number of feature maps; doubled at each encoder stage
batch_size = 32
X = tf.placeholder(tf.float32, [batch_size, 256, 256, 3])     # low-resolution RGB input image
X_ = tf.placeholder(tf.float32, [batch_size, 1024, 1024, 3])  # high-resolution RGB target; must lie in [0, 1] for the sigmoid loss

# encoder: four stride-2 convolutions, 256x256 -> 16x16
bn1 = batch_norm(name='bn1')
bn2 = batch_norm(name='bn2')
bn3 = batch_norm(name='bn3')

h0 = lrelu(conv2d(X, dim, name='h0_conv'))
h1 = lrelu(bn1(conv2d(h0, dim*2, name='h1_conv')))
h2 = lrelu(bn2(conv2d(h1, dim*4, name='h2_conv')))
h3 = lrelu(bn3(conv2d(h2, dim*8, name='h3_conv')))

# decoder: transposed convolutions upsample 16 -> 32 -> 128 -> 512 -> 1024
dbn1 = batch_norm(name='dbn1')
dbn2 = batch_norm(name='dbn2')
dbn3 = batch_norm(name='dbn3')

dh1, h1_w, h1_b = deconv2d(h3,
    [batch_size, 32, 32, dim*4], name='dh1', with_w=True)
h1_ = tf.nn.relu(dbn1(dh1))

dh2, h2_w, h2_b = deconv2d(h1_,
    [batch_size, 128, 128, dim*2], name='dh2', with_w=True, d_h=4, d_w=4)
h2_ = tf.nn.relu(dbn2(dh2))

dh3, h3_w, h3_b = deconv2d(h2_,
    [batch_size, 512, 512, dim], name='dh3', with_w=True, d_h=4, d_w=4)
h3_ = tf.nn.relu(dbn3(dh3))

dh, h_w, h_b = deconv2d(h3_,
    [batch_size, 1024, 1024, 3], name='dh', with_w=True)
h_ = tf.nn.sigmoid(dh)  # predicted high-resolution image in [0, 1]

# loss: sigmoid cross-entropy expects *logits*, so feed the pre-sigmoid `dh`.
# Passing `h_` applied the sigmoid twice. TF >= 1.0 also requires keyword
# arguments here. The mean reduces the per-pixel loss to a scalar.
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=X_, logits=dh))

In [5]:
# Smoke test: show the decoder tensors, then push one random batch through the net.
print('%s %s %s %s' % (h1_, h2_, h3_, h_))  # same output on Python 2 and 3
sess.run(tf.global_variables_initializer())
sample_batch = np.random.random(tuple(X.get_shape().as_list()))  # matches the input placeholder
output_shape = h_.eval(feed_dict={X: sample_batch}).shape
assert output_shape == tuple(X_.get_shape().as_list()), output_shape
output_shape


Tensor("Relu:0", shape=(32, 32, 32, 64), dtype=float32) Tensor("Relu_1:0", shape=(32, 128, 128, 32), dtype=float32) Tensor("Relu_2:0", shape=(32, 512, 512, 16), dtype=float32) Tensor("Sigmoid:0", shape=(32, 1024, 1024, 3), dtype=float32)

In [7]:


In [ ]: