In [4]:
# Setup: lower TensorFlow's C++ log verbosity (this must happen before
# tensorflow is imported), then import everything the notebook uses.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import sys
import time

import numpy as np
# Print arrays in full instead of summarising them. Newer numpy releases
# reject threshold=np.nan ("must be numeric and non-NAN"); sys.maxsize is the
# supported way to disable truncation.
np.set_printoptions(threshold=sys.maxsize)
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import progressbar
import tensorflow as tf

import iqa_tools
In [5]:
# Restore the trained SSIM network: rebuild the graph from its .meta file,
# then load the most recent checkpoint found in the same directory.
model_dir = './'
sess = tf.Session()
new_saver = tf.train.import_meta_graph(os.path.join(model_dir, 'ssimNET0.meta'))
checkpoint_path = tf.train.latest_checkpoint(model_dir)
# latest_checkpoint returns None when no checkpoint exists; fail with a clear
# message instead of letting restore() choke on None.
if checkpoint_path is None:
    raise FileNotFoundError('no TensorFlow checkpoint found in %r' % model_dir)
new_saver.restore(sess, checkpoint_path)
In [6]:
graph = tf.get_default_graph()
# Grab the tensors used below by their (auto-generated) names from the saved
# graph: `op_restore` is what gets evaluated as the prediction and `x` is the
# placeholder that the input features are fed into.
op_restore, x = (graph.get_tensor_by_name(tensor_name)
                 for tensor_name in ('BiasAdd_4:0', 'Placeholder:0'))
In [7]:
# Load the sample train/test split with the project helper. Set the
# IQA_DATA_PATH environment variable to point at the data on another machine;
# the original author's path remains the fallback.
data_path = os.environ.get(
    'IQA_DATA_PATH',
    '/home/dirty_mike/Dropbox/github/image_quality_analysis/data/sample_data/')
train_features, train_target, test_features, test_target = iqa_tools.load_data(local=True, path=data_path)
In [8]:
# Shape of the training feature array (displayed as the cell output).
np.shape(train_features)
Out[8]:
In [12]:
# Run the restored network on the whole test set in a single call.
feed = {x: test_features}
pred = sess.run(fetches=op_restore, feed_dict=feed)
In [13]:
# Shape of the prediction array (displayed as the cell output).
np.shape(pred)
Out[13]:
In [20]:
# Visual check of the SSIM network. For a few randomly chosen test samples,
# show the two input channels next to the true SSIM map and the predicted one.
N_EXAMPLES = 3   # number of rows (samples) in the figure
SEED = 0         # fixed so the saved prediction.png is reproducible

rng = np.random.RandomState(SEED)
n_samples = pred.shape[0]
# Distinct samples (no repeated rows) unless there are fewer samples than rows.
sample_ids = rng.choice(n_samples, size=N_EXAMPLES, replace=n_samples < N_EXAMPLES)
column_titles = ('original', 'reconstructed', 'ssim', 'ssim net prediction')

fig, axes = plt.subplots(N_EXAMPLES, 4, figsize=(16, 4 * N_EXAMPLES), squeeze=False,
                         gridspec_kw=dict(wspace=0, hspace=0.03))
for ax_row, idx in zip(axes, sample_ids):
    ssim_true = test_target[idx, :, :, 0]
    ssim_pred = pred[idx, :, :, 0]
    # imshow normalises each panel on its own by default. Put both SSIM maps
    # on one colour scale so that equal colours mean equal values.
    vmin = min(ssim_true.min(), ssim_pred.min())
    vmax = max(ssim_true.max(), ssim_pred.max())
    ax_row[0].imshow(test_features[idx, :, :, 0], cmap='gray')
    ax_row[1].imshow(test_features[idx, :, :, 1], cmap='gray')
    ax_row[2].imshow(ssim_true, cmap='plasma', vmin=vmin, vmax=vmax)
    ax_row[3].imshow(ssim_pred, cmap='plasma', vmin=vmin, vmax=vmax)
    for ax in ax_row:
        ax.set_xticks([])
        ax.set_yticks([])
for ax, title in zip(axes[0], column_titles):
    ax.set_title(title, size=20)
fig.savefig('prediction.png')
plt.show()
The SSIM-net predictions look great.
In [26]:
# Evaluate the trained 'weights1' tensor into a numpy array and display its
# shape. The next cell indexes it as [:, :, channel, filter].
weights1_tensor = graph.get_tensor_by_name('weights1:0')
weights1 = sess.run(weights1_tensor)
np.shape(weights1)
Out[26]:
In [44]:
# Visualise a few trained filters from `weights1`. Each row is one randomly
# chosen filter (last axis) and each column is one of input channels 0-3,
# labelled with the titles below.
N_FILTER_ROWS = 3   # number of filters shown
FILTER_SEED = 0     # fixed so the saved trained_filters.png is reproducible

# Sample from the real number of filters rather than a hard-coded 100, which
# would raise IndexError with fewer filters and ignore any filters past 100.
n_filters = weights1.shape[-1]
filter_rng = np.random.RandomState(FILTER_SEED)
filter_ids = filter_rng.choice(n_filters, size=N_FILTER_ROWS,
                               replace=n_filters < N_FILTER_ROWS)
channel_titles = ('original', 'recon', 'original sq', 'recon sq')

fig, axes = plt.subplots(N_FILTER_ROWS, 4, figsize=(16, 4 * N_FILTER_ROWS), squeeze=False,
                         gridspec_kw=dict(wspace=0, hspace=0.03))
for ax_row, filt in zip(axes, filter_ids):
    for channel, ax in enumerate(ax_row):
        ax.imshow(weights1[:, :, channel, filt], cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
for ax, title in zip(axes[0], channel_titles):
    ax.set_title(title, size=20)
fig.savefig('trained_filters.png')
plt.show()