In [1]:
# An implementation following the algorithm of the paper
# "Learning a Deep Convolutional Network for Image Super-Resolution"
# 2017. 07. 06

import tensorflow as tf
from PIL import Image
import numpy as np

# Images and constants
learning_rate = 1e-4
# 720p target frame size (4:3, so 960x720); renamed from W2/H2 so the conv
# weights W1..W6 defined below do not shadow them (getimage() runs after
# those names are reassigned to tf.Variables).
IMG_W = 960
IMG_H = 720
path = "../06/"
pref1 = "360p/"
pref2 = "720p/"
suff1 = "_360.jpg"
suff2 = "_720.jpg"
train_num = 500  # 1000
file_num = 2     # 6 # 30
# batch_num = 1000


# Weight initialization helper
def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)
# Bias initialization helper
def bias_variable(shape, name):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)
# 2D convolution with stride 1 and SAME padding, plus bias
def conv2d(x, W, B):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') + B
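# With stride 1 and SAME padding the spatial size is preserved, so each layer
# maps (N, IMG_H, IMG_W, C_in) -> (N, IMG_H, IMG_W, C_out).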

def getimage(idx):
    # Network input: the 360p frame, upscaled to 720p with bicubic
    # interpolation. PIL's resize() takes (width, height).
    img_1 = Image.open(path + pref1 + str(idx) + suff1).convert('RGB')
    img_1_720 = img_1.resize((IMG_W, IMG_H), Image.BICUBIC)
    array_1_720 = np.array(img_1_720).astype(np.float32)

    # Ground truth: the 720p frame; convert('RGB') drops any alpha channel,
    # replacing the original [:, :, 0:3] slice.
    img_2 = Image.open(path + pref2 + str(idx) + suff2).convert('RGB')
    array_2 = np.array(img_2).astype(np.float32)
    return array_1_720, array_2
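# Both returned arrays have shape (IMG_H, IMG_W, 3) = (720, 960, 3), matching
# the placeholders defined below.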

def l_relu(x, alpha=0.):
    return tf.nn.relu(x)-alpha*tf.nn.relu(-x)
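# l_relu is a leaky ReLU: relu(x) - alpha*relu(-x) equals max(alpha*x, x) for
# 0 <= alpha <= 1. Newer TensorFlow releases ship this as tf.nn.leaky_relu.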

def asImage(tensor):
    # Take the first image in the batch and convert to an 8-bit RGB PIL image.
    result = tensor[0].astype(np.uint8)
    return Image.fromarray(result, 'RGB')

def showres(index, steps):
    test360, _ = getimage(index)
    # y_result depends only on x_image, so only the input needs to be fed.
    A = sess.run(y_result, feed_dict={x_image: [test360]})
    # Clip before the uint8 cast in asImage so out-of-range values don't wrap.
    A = np.clip(A, 0, 255)
    asImage(A).save('results/result_06_' + str(steps) + '.jpg')


x_image = tf.placeholder(tf.float32, shape=[None, IMG_H, IMG_W, 3])
y_image = tf.placeholder(tf.float32, shape=[None, IMG_H, IMG_W, 3])

W1 = weight_variable([11, 11, 3, 15], name = 'W1')
B1 = bias_variable([15], name = 'B1')
W2 = weight_variable([1, 1, 15, 12], name = 'W2')
B2 = bias_variable([12], name='B2')
W3 = weight_variable([5, 5, 12, 3], name = 'W3')
B3 = bias_variable([3], name = 'B3')

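# Note: this second set of weights (W4..W6, B4..B6) is never used in the
# graph below; it looks like a stub for a second, deeper variant.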
W4 = weight_variable([11, 11, 3, 15], name = 'W4')
B4 = bias_variable([15], name = 'B4')
W5 = weight_variable([1, 1, 15, 12], name = 'W5')
B5 = bias_variable([12], name='B5')
W6 = weight_variable([5, 5, 12, 3], name = 'W6')
B6 = bias_variable([3], name = 'B6')

F1 = l_relu(conv2d(x_image, W1, B1), alpha = 0.3)
F2 = l_relu(conv2d(F1, W2, B2), alpha = 0.3)
y_result = tf.nn.relu(x_image+l_relu(conv2d(F2, W3, B3), alpha = 0.3))
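# Note: this deviates from the paper's SRCNN, which uses 9-1-5 filters with
# 64/32 channels, plain ReLU, and no skip connection; here the bicubic input
# is added back in, so the convolutions learn a residual.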

cost = tf.reduce_mean(tf.square(y_image-y_result))
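# Hedged addition (not in the original): PSNR, the usual super-resolution
# metric, computed from the same MSE; assumes pixel values stay in [0, 255].
psnr = 10.0 * tf.log((255.0 ** 2) / cost) / tf.log(10.0)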
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cost)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#saver.restore(sess, "01/models.ckpt")

for steps in range(train_num):
    for index in range(1, file_num):
        # array360 is the 360p image bicubic-upscaled to 720p (network input).
        array360, array720 = getimage(index)
        #asImage([array720]).show()
        sess.run(train_step, feed_dict={x_image: [array360], y_image: [array720]})
    print(str(steps).zfill(3), sess.run(cost, feed_dict={x_image: [array360], y_image: [array720]}))
    if steps % 5 == 0:
        showres(index, steps)
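    # Hedged addition: periodically checkpoint, so that the commented
    # saver.restore(...) line above can resume training; uncomment to enable.
    #if steps % 50 == 0:
    #    saver.save(sess, "01/models.ckpt")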
print ("끝났다")


000 219.04
001 196.244
002 176.736
003 160.233
004 146.442
005 134.989
006 125.411
007 117.458
008 110.671
009 104.965
010 100.142
011 96.057
012 92.592
013 89.6539
014 87.1493
015 85.0066
016 83.1712
017 81.5998
018 80.2485
019 79.082
020 78.0695
021 77.1809
022 76.3982
023 75.6993
024 75.0699
025 74.494
026 73.9641
027 73.4696
028 73.0045
029 72.5618
030 72.1354
031 71.7239
032 71.3218
033 70.9304
034 70.5487
035 70.1774
036 69.8162
037 69.4625
038 69.1166
039 68.777
040 68.4437
041 68.1184
042 67.8002
043 67.4883
044 67.1813
045 66.88
046 66.5843
047 66.2932
048 66.007
049 65.7241
050 65.4453
051 65.1709
052 64.9001
053 64.6321
054 64.3688
055 64.1081
056 63.8515
057 63.5991
058 63.3495
059 63.1031
060 62.8595
061 62.6186
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-1-eee30d751168> in <module>()
     90         #asImage([array720]).show()
     91         sess.run(train_step, feed_dict={x_image:[array360], y_image:[array720]})
---> 92     print (str(steps).zfill(3), sess.run(cost, feed_dict={x_image:[array360], y_image:[array720]}))
     93     if(steps%5==0):
     94         showres(index, steps)

/home/alpha/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    776     try:
    777       result = self._run(None, fetches, feed_dict, options_ptr,
--> 778                          run_metadata_ptr)
    779       if run_metadata:
    780         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/alpha/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    980     if final_fetches or final_targets:
    981       results = self._do_run(handle, final_targets, final_fetches,
--> 982                              feed_dict_string, options, run_metadata)
    983     else:
    984       results = []

/home/alpha/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1030     if handle is None:
   1031       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1032                            target_list, options, run_metadata)
   1033     else:
   1034       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/alpha/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1037   def _do_call(self, fn, *args):
   1038     try:
-> 1039       return fn(*args)
   1040     except errors.OpError as e:
   1041       message = compat.as_text(e.message)

/home/alpha/.local/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1019         return tf_session.TF_Run(session, options,
   1020                                  feed_dict, fetch_list, target_list,
-> 1021                                  status, run_metadata)
   1022 
   1023     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]:
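# A follow-up sketch, assuming the interrupted session above is still alive:
# checkpoint the current weights (path mirrors the commented restore above)
# and render one more test frame with the existing helpers.
saver.save(sess, "01/models.ckpt")
showres(1, steps)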