In [1]:
# TensorFlow + NumPy setup. An InteractiveSession installs itself as the
# default session, so Tensor.eval() below can be called without passing
# the session explicitly.
import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
In [3]:
class batch_norm(object):
    """Callable wrapper around ``tf.contrib.layers.batch_norm``.

    Fixes ``epsilon``, ``momentum`` (decay) and the variable scope at
    construction time; calling the instance applies batch normalization
    to a tensor.

    Args:
        epsilon: small constant added to the variance for numerical stability.
        momentum: exponential moving-average decay for the running statistics.
        name: variable scope under which the BN variables are created.
    """

    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        # NOTE(review): the original entered ``tf.variable_scope(name)`` here,
        # but no variables are created in __init__ — __call__ passes
        # ``scope=self.name`` explicitly — so the scope did nothing and is
        # removed.
        self.epsilon = epsilon
        self.momentum = momentum
        self.name = name

    def __call__(self, x, train=True):
        """Apply batch normalization to ``x``.

        Args:
            x: input tensor.
            train: whether to use batch statistics (True) or the moving
                averages (False). Bug fix: the original accepted this flag
                but never forwarded it, so inference mode was impossible;
                it is now passed as ``is_training``. The default (True)
                matches the contrib default, so existing callers behave
                identically.
        """
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,
                                            updates_collections=None,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,
                                            scope=self.name)
def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    """Strided 2-D convolution with SAME padding and a bias.

    Args:
        input_: 4-D tensor ``[batch, height, width, in_channels]``
            (NHWC is implied by the ``strides`` layout used below).
        output_dim: number of output channels.
        k_h, k_w: kernel height/width.
        d_h, d_w: stride along height/width.
        stddev: stddev of the truncated-normal weight initializer.
        name: variable scope for the ``w`` and ``biases`` variables.

    Returns:
        The convolved, bias-added 4-D tensor.
    """
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
        biases = tf.get_variable('biases', [output_dim],
                                 initializer=tf.constant_initializer(0.0))
        # bias_add preserves the input shape, so the original
        # tf.reshape(..., conv.get_shape()) round-trip was a no-op; dropped.
        return tf.nn.bias_add(conv, biases)
def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    """Strided 2-D transposed convolution ("deconvolution") with a bias.

    Args:
        input_: 4-D tensor ``[batch, height, width, in_channels]``.
        output_shape: full static output shape, e.g. ``[batch, H, W, out_ch]``.
        k_h, k_w: kernel height/width.
        d_h, d_w: stride along height/width.
        stddev: stddev of the random-normal weight initializer.
        name: variable scope for the ``w`` and ``biases`` variables.
        with_w: if True, also return the weight and bias variables.

    Returns:
        The deconvolved tensor, or ``(deconv, w, biases)`` when ``with_w``.
    """
    with tf.variable_scope(name):
        # conv2d_transpose filter layout: [height, width, out_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                            strides=[1, d_h, d_w, 1])
        except AttributeError:
            # Support for versions of TensorFlow before 0.7.0, where the op
            # was named tf.nn.deconv2d.
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                                    strides=[1, d_h, d_w, 1])
        biases = tf.get_variable('biases', [output_shape[-1]],
                                 initializer=tf.constant_initializer(0.0))
        # bias_add preserves the shape; the original reshape back to
        # deconv.get_shape() was redundant and has been removed.
        deconv = tf.nn.bias_add(deconv, biases)
        if with_w:
            return deconv, w, biases
        return deconv
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU: elementwise max of ``x`` and ``leak * x``.

    ``name`` is accepted for signature parity with the other ops but is not
    used to create a scope (matching the original behavior).
    """
    return tf.maximum(leak * x, x)
In [4]:
dim = 16          # base channel width; channels double at every encoder level
batch_size = 32

# Input: batch of 256x256 RGB images.
X = tf.placeholder(tf.float32, [batch_size, 256, 256, 3])
# Target: batch of 1024x1024 high-resolution RGB images.
X_ = tf.placeholder(tf.float32, [batch_size, 1024, 1024, 3])

# --- encoder: 256 -> 128 -> 64 -> 32 -> 16 spatial resolution ---
bn1 = batch_norm(name='bn1')
bn2 = batch_norm(name='bn2')
bn3 = batch_norm(name='bn3')
h0 = lrelu(conv2d(X, dim, name='h0_conv'))           # (32, 128, 128, 16)
h1 = lrelu(bn1(conv2d(h0, dim*2, name='h1_conv')))   # (32, 64, 64, 32)
h2 = lrelu(bn2(conv2d(h1, dim*4, name='h2_conv')))   # (32, 32, 32, 64)
h3 = lrelu(bn3(conv2d(h2, dim*8, name='h3_conv')))   # (32, 16, 16, 128)

# --- decoder: 16 -> 32 -> 128 -> 512 -> 1024 spatial resolution ---
dbn1 = batch_norm(name='dbn1')
dbn2 = batch_norm(name='dbn2')
dbn3 = batch_norm(name='dbn3')
dh1, h1_w, h1_b = deconv2d(h3,
    [batch_size, 32, 32, dim*4], name='dh1', with_w=True)
h1_ = tf.nn.relu(dbn1(dh1))
dh2, h2_w, h2_b = deconv2d(h1_,
    [batch_size, 128, 128, dim*2], name='dh2', with_w=True, d_h=4, d_w=4)
h2_ = tf.nn.relu(dbn2(dh2))
dh3, h3_w, h3_b = deconv2d(h2_,
    [batch_size, 512, 512, dim], name='dh3', with_w=True, d_h=4, d_w=4)
h3_ = tf.nn.relu(dbn3(dh3))
dh, h_w, h_b = deconv2d(h3_,
    [batch_size, 1024, 1024, 3], name='dh', with_w=True)
h_ = tf.nn.sigmoid(dh)   # output image, values in [0, 1]

# Loss. Bug fix: sigmoid_cross_entropy_with_logits expects *pre-sigmoid*
# logits, so feed `dh` rather than the already-squashed `h_` (the original
# applied the sigmoid twice), use the keyword-only TF 1.x calling convention,
# and reduce to a scalar so the loss can be minimized directly.
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=X_, logits=dh))
In [5]:
# Smoke-test the net: show the decoder activation tensors, then run one
# forward pass on random input and inspect the output shape.
print(h1_, h2_, h3_, h_)   # print() works under both Python 2 and Python 3
sess.run(tf.global_variables_initializer())
h_.eval(feed_dict={X: np.random.random((batch_size, 256, 256, 3))}).shape
In [7]:
In [ ]: