In [1]:
import cv2
from super_resolution import *
%matplotlib inline
In [2]:
tf.reset_default_graph()
In [ ]:
# variables
height = 41
width = 41
channel = 3
stride = (21,21)
resize_scale=2
img_list = []
img_list = load_img_list_and_extract_patch_list("./data/urban_hr", size=(width, height),stride=stride)
img_list.extend(img_list)
X_all = np.array(blur_img_list(img_list, resize_scale))
Y_all = np.array(img_list)
mean_img = np.mean(Y_all, axis=0)
std_img = np.std(Y_all, axis=0)
train_ratio = 0.8
valid_ratio = 0.1
test_ratio = 0.1
data_num = len(X_all)
train_data_num = round(data_num * train_ratio)
valid_data_num = round(data_num * valid_ratio)
test_data_num = round(data_num * test_ratio)
X_train = X_all[:train_data_num]
Y_train = Y_all[:train_data_num]
X_valid = X_all[train_data_num:train_data_num + valid_data_num]
Y_valid = Y_all[train_data_num:train_data_num + valid_data_num]
X_test = X_all[train_data_num + valid_data_num:train_data_num + valid_data_num + test_data_num]
Y_test = Y_all[train_data_num + valid_data_num:train_data_num + valid_data_num + test_data_num]
# data shuffle
rand_idx = np.random.permutation(range(len(Y_valid)))
X_valid = X_valid[rand_idx]
Y_valid = Y_valid[rand_idx]
rand_idx = np.random.permutation(range(len(Y_test)))
X_test = X_test[rand_idx]
Y_test = Y_test[rand_idx]
In [ ]:
img_idx = 2
fig, axs = plt.subplots(1, 2, figsize=(20, 6))
axs[0].imshow(X_train[img_idx])
axs[0].set_title("X_train")
axs[1].imshow(Y_train[img_idx])
axs[1].set_title("Y_train")
In [ ]:
print("num data")
print("="*30)
print("length X_train : {} ". format(len(X_train)))
In [9]:
sess = tf.Session()
usrcnn = USRCNN(sess)
In [ ]:
usrcnn.load(sess,'./model/weights','./model/meta_data.pickle')
In [ ]:
usrcnn.train(X_train, Y_train, X_valid=X_valid[:200], Y_valid=Y_valid[:200],
batch_size=64, n_epoch=100, start_learning_rate=0.005, save_dir_path='./model')
In [ ]:
import os
def load_imgs_2(dir_path):
name_list = os.listdir(dir_path)
img_list = []
for name in name_list:
img_path = "{}/{}".format(dir_path, name)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_list.append(img)
return img_list
In [3]:
from super_resolution import load_img_list
In [4]:
test_img_list = load_img_list("./data/urban_hr")
In [18]:
test_idx = 52
In [19]:
plt.imshow(test_img_list[test_idx])
Out[19]:
In [26]:
test_img = test_img_list[test_idx]
#test_img = test_img[200:350,250:400]
test_img_resized = cv2.resize(test_img, (round(test_img.shape[1]/2), round(test_img.shape[0]/2)))
test_img_resized = cv2.resize(test_img_resized, (test_img.shape[1],test_img.shape[0]), interpolation = cv2.INTER_CUBIC)
# test_img = Y_test[test_idx]
# test_img_resized = X_test[test_idx]
In [27]:
usrcnn.load(sess,'./model/weights','./model/meta_data.pickle')
In [28]:
result = usrcnn.enhance_resolution(test_img_resized)
In [29]:
fig, axs = plt.subplots(1, 3, figsize=(20, 6))
axs[0].imshow(test_img)
axs[0].set_title("origin")
axs[1].imshow(test_img_resized)
axs[1].set_title("query")
axs[2].imshow(result)
axs[2].set_title("reconstructed image")
Out[29]:
In [30]:
plt.imshow(result[300:400,300:400])
Out[30]:
In [31]:
plt.imshow(test_img_resized[300:400,300:400])
Out[31]:
In [ ]:
plt.imshow(usrcnn.mean_img.astype(np.uint8))
In [ ]:
plt.imshow(usrcnn.std_img.astype(np.uint8))
In [ ]:
usrcnn.std_img
In [ ]:
img = X_valid[1]
In [ ]:
plt.imshow(img)
In [ ]:
In [ ]: