In [1]:
import numpy as np
import tensorflow as tf
from scipy import misc
from sklearn.model_selection import train_test_split
import time
import os
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = "whitegrid", palette = "muted")
import matplotlib.gridspec as gridspec
import csv
import pandas as pd

In [2]:
def gettype(location):
    """Parse the tab-separated ``<location>/type`` file into pokemon types.

    Each line is expected to be tab-separated with the type names starting
    at column index 4 (one or two types per pokemon).

    Returns:
        (type_pokemon, unique_type):
            type_pokemon -- np.ndarray of shape (n, 2); single-typed rows
                            are padded with the literal string 'none'.
            unique_type  -- sorted list of every distinct type name
                            (including 'none').
    """
    with open(location + '/type', 'r') as fopen:
        # Drop blank lines: a trailing newline in the file would otherwise
        # yield an empty row ([] after [4:]), making the list ragged and
        # breaking np.array stacking / the [:, 0] indexing below.
        lines = [line for line in fopen.read().split('\n') if line.strip()]
        type_pokemon = [line.split('\t')[4:] for line in lines]

        # Pad single-typed pokemon with 'none' so every row has 2 entries.
        # range (not xrange) keeps this portable across Python 2 and 3.
        for i in range(len(type_pokemon)):
            if len(type_pokemon[i]) == 1:
                type_pokemon[i].append('none')

        type_pokemon = np.array(type_pokemon)

        # Union of primary- and secondary-type names; np.unique also sorts.
        type_list = np.array(np.unique(type_pokemon[:, 0]).tolist() + np.unique(type_pokemon[:, 1]).tolist())

        return type_pokemon, np.unique(type_list).tolist()
        
def getpictures(location):
    """List the picture files in ``location`` in natural numeric order.

    File names are assumed to look like ``<number>.png``; a plain string
    sort would place '10.png' before '2.png', so the names are ordered by
    their integer stem and rebuilt as '<number>.png'.
    """
    numbers = sorted(int(name.replace('.png', '')) for name in os.listdir(location))
    return [str(number) + '.png' for number in numbers]

def generategraph(x, accuracy, lost):
    """Show loss (left) and accuracy (right) curves against epoch.

    Args:
        x        -- sequence of epoch indices (shared x-axis).
        accuracy -- per-epoch accuracy values.
        lost     -- per-epoch loss values.
    """
    fig = plt.figure(figsize = (10, 5))

    # (series, y-label, panel title) for the two side-by-side panels.
    panels = [(lost, 'lost', 'LOST'), (accuracy, 'Accuracy', 'ACCURACY')]

    for position, (series, ylabel, title) in enumerate(panels):
        plt.subplot(1, 2, position + 1)
        plt.plot(x, series)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)

    fig.tight_layout()
    plt.show()

In [3]:
class Model:
    """Two-headed CNN that predicts a pokemon's primary and secondary type.

    Four conv+pool stages (channels 4 -> 64 -> 32 -> 16 -> 8, each pool
    halving spatial size) feed a shared 128-unit dense layer, which feeds
    two independent softmax heads (Y_1 = primary type, Y_2 = secondary
    type).  The joint cost is the sum of the two cross-entropies.

    NOTE(review): the flatten below hard-codes 4 * 4 * 8, so this graph is
    only valid when dimension_picture == 64 (64 halved four times is 4).
    """

    def __init__(self, dimension_picture, learning_rate, dimension_output):
        
        # Inputs: RGBA images (4 channels) and two one-hot label tensors,
        # one per prediction head.
        self.X = tf.placeholder(tf.float32, (None, dimension_picture, dimension_picture, 4))
        self.Y_1 = tf.placeholder(tf.float32, (None, dimension_output))
        self.Y_2 = tf.placeholder(tf.float32, (None, dimension_output))
        
        def convolutionize(x, w):
            # Stride-1 'SAME' convolution: spatial size is preserved.
            return tf.nn.conv2d(input = x, filter = w, strides = [1, 1, 1, 1], padding = 'SAME')
        
        def pooling(wx):
            # 2x2 max pool with stride 2: halves height and width.
            return tf.nn.max_pool(wx, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')
        
        # Stage 1: 4 -> 64 channels, then pool (64x64 -> 32x32 for 64px input).
        first_W_conv = tf.Variable(tf.random_normal([5, 5, 4, 64], stddev = 0.5))
        first_b_conv = tf.Variable(tf.random_normal([64], stddev = 0.1))
        first_hidden_conv = tf.nn.relu(convolutionize(self.X, first_W_conv) + first_b_conv)
        first_hidden_pool = pooling(first_hidden_conv)
        
        # Stage 2: 64 -> 32 channels, then pool (-> 16x16).
        second_W_conv = tf.Variable(tf.random_normal([5, 5, 64, 32], stddev = 0.5))
        second_b_conv = tf.Variable(tf.random_normal([32], stddev = 0.1))
        second_hidden_conv = tf.nn.relu(convolutionize(first_hidden_pool, second_W_conv) + second_b_conv)
        second_hidden_pool = pooling(second_hidden_conv)
        
        # Stage 3: 32 -> 16 channels, then pool (-> 8x8).
        third_W_conv = tf.Variable(tf.random_normal([5, 5, 32, 16], stddev = 0.5))
        third_b_conv = tf.Variable(tf.random_normal([16], stddev = 0.1))
        third_hidden_conv = tf.nn.relu(convolutionize(second_hidden_pool, third_W_conv) + third_b_conv)
        third_hidden_pool = pooling(third_hidden_conv)
        
        # Stage 4: 16 -> 8 channels, then pool (-> 4x4).
        fourth_W_conv = tf.Variable(tf.random_normal([5, 5, 16, 8], stddev = 0.5))
        fourth_b_conv = tf.Variable(tf.random_normal([8], stddev = 0.1))
        fourth_hidden_conv = tf.nn.relu(convolutionize(third_hidden_pool, fourth_W_conv) + fourth_b_conv)
        fourth_hidden_pool = pooling(fourth_hidden_conv)

        # Shared dense layer over the flattened 4x4x8 feature map.
        first_linear_W = tf.Variable(tf.random_normal([4 * 4 * 8, 128], stddev = 0.5))
        first_linear_b = tf.Variable(tf.random_normal([128], stddev = 0.1))
        fifth_hidden_flatted = tf.reshape(fourth_hidden_pool, [-1, 4 * 4 * 8])
        linear_layer = tf.nn.relu(tf.matmul(fifth_hidden_flatted, first_linear_W) + first_linear_b)
        
        # Two independent output heads sharing `linear_layer`.
        W_1 = tf.Variable(tf.random_normal([128, dimension_output], stddev = 0.5))
        b_1 = tf.Variable(tf.random_normal([dimension_output], stddev = 0.1))
        
        W_2 = tf.Variable(tf.random_normal([128, dimension_output], stddev = 0.5))
        b_2 = tf.Variable(tf.random_normal([dimension_output], stddev = 0.1))
        
        # Raw logits (softmax is applied inside the loss op below).
        self.y_hat_1 = tf.matmul(linear_layer, W_1) + b_1
        self.y_hat_2 = tf.matmul(linear_layer, W_2) + b_2
        
        self.cost_1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = self.y_hat_1, labels = self.Y_1))
        self.cost_2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = self.y_hat_2, labels = self.Y_2))
        
        # Joint objective: both heads trained together through the shared trunk.
        self.cost = self.cost_1 + self.cost_2
        
        self.optimizer = tf.train.AdagradOptimizer(learning_rate).minimize(self.cost)
        
        # Per-head top-1 accuracy against the one-hot labels.
        correct_prediction_1 = tf.equal(tf.argmax(self.y_hat_1, 1), tf.argmax(self.Y_1, 1))
        self.accuracy_1 = tf.reduce_mean(tf.cast(correct_prediction_1, "float"))
        
        correct_prediction_2 = tf.equal(tf.argmax(self.y_hat_2, 1), tf.argmax(self.Y_2, 1))
        self.accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))

In [4]:
# --- Configuration and data loading ---------------------------------------
current_location = os.getcwd()
learning_rate = 0.001
epoch = 2500
batch_size = 5
split_percentage = 0.2          # fraction of pokemon held out for testing
test_number = 10                # NOTE(review): defined but not used below
type_pokemon, unique_type = gettype(current_location)
pokemon_pictures = getpictures(current_location + '/pokemon')
output_dimension = len(unique_type)
picture_dimension = 64          # must be 64: Model flattens to 4 * 4 * 8

# Split file names and their (primary, secondary) type labels together so
# rows stay aligned.  NOTE(review): no random seed is fixed here, so the
# split differs on every run.
pokemon_pictures_train, pokemon_pictures_test, pokemon_types_train, pokemon_types_test = train_test_split(pokemon_pictures, 
                                                                                                          type_pokemon, 
                                                                                                          test_size = split_percentage)

In [5]:
# Build the graph, initialize variables, then try to resume from a checkpoint.
sess = tf.InteractiveSession()
model = Model(picture_dimension, learning_rate, output_dimension)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())

try:
    saver.restore(sess, current_location + "/model.ckpt")
    print "load model.."
except:
    # NOTE(review): bare except silently swallows every failure (including
    # a genuinely corrupt checkpoint) and falls back to fresh variables.
    print "start from fresh variables"


INFO:tensorflow:Restoring parameters from /home/husein/space/pokemon-type/model.ckpt
start from fresh variables

In [6]:
# Training loop: mini-batch SGD over the training split for `epoch` passes.
# History lists feed generategraph() later.
ACCURACY, EPOCH, LOST = [], [], []
    
for i in xrange(epoch):
    total_cost = 0
    total_accuracy = 0
    last_time = time.time()
    EPOCH.append(i)
        
    # Iterate in full batches only; the final len % batch_size images of
    # each epoch are dropped.
    for k in xrange(0, (len(pokemon_pictures_train) // batch_size) * batch_size, batch_size):
            
        emb_data = np.zeros((batch_size, picture_dimension, picture_dimension, 4), dtype = np.float32)
        emb_data_label_1 = np.zeros((batch_size, output_dimension), dtype = np.float32)
        emb_data_label_2 = np.zeros((batch_size, output_dimension), dtype = np.float32)
            
        for x in xrange(batch_size):
                
            # Load and resize each image; one-hot encode the primary
            # (column 0) and secondary (column 1) type labels.
            image = misc.imread(current_location + '/pokemon/' + pokemon_pictures_train[k + x])
            image = misc.imresize(image, (picture_dimension, picture_dimension))
            emb_data_label_1[x, unique_type.index(pokemon_types_train[k + x, 0])] = 1.0
            emb_data_label_2[x, unique_type.index(pokemon_types_train[k + x, 1])] = 1.0
                
            # NOTE(review): raw pixel values are fed in unnormalized
            # (0-255), which matches the very large loss values printed.
            emb_data[x, :, :, :] = image
            
        # One optimizer step, then a second forward pass for the metrics.
        _, loss = sess.run([model.optimizer, model.cost], feed_dict = {model.X : emb_data, model.Y_1 : emb_data_label_1, model.Y_2 : emb_data_label_2})
        accuracy_1, accuracy_2 = sess.run([model.accuracy_1, model.accuracy_2], feed_dict = {model.X : emb_data, model.Y_1 : emb_data_label_1, model.Y_2 : emb_data_label_2})
        total_cost += loss
        total_accuracy += ((accuracy_1 + accuracy_2) / 2.0) 
        
    # Average loss/accuracy over the batches of this epoch.
    total_accuracy /= (len(pokemon_pictures_train) // batch_size)
    total_cost /= (len(pokemon_pictures_train) // batch_size)
    ACCURACY.append(total_accuracy)
    LOST.append(total_cost)
    
    # Log and checkpoint every 50 epochs.
    if (i + 1) % 50 == 0:
        print "epoch: " + str(i + 1) + ", loss: " + str(total_cost) + ", accuracy: " + str(total_accuracy) + ", s / batch: " + str((time.time() - last_time) / (len(pokemon_pictures_train) // batch_size))
        saver.save(sess, current_location + "/model.ckpt")


epoch: 50, loss: 2445043.2377, accuracy: 0.281967219393, s / batch: 0.00851168788847
epoch: 100, loss: 1615285.45389, accuracy: 0.326229516234, s / batch: 0.0084548035606
epoch: 150, loss: 1254219.81865, accuracy: 0.357377058048, s / batch: 0.00857311389485
epoch: 200, loss: 1034156.81378, accuracy: 0.391803289168, s / batch: 0.00849288408874
epoch: 250, loss: 878505.336578, accuracy: 0.426229519556, s / batch: 0.00853291495902
epoch: 300, loss: 758454.909324, accuracy: 0.460655747745, s / batch: 0.00846195220947
epoch: 350, loss: 663250.191214, accuracy: 0.485245914855, s / batch: 0.00849534253605
epoch: 400, loss: 583138.13864, accuracy: 0.503278703475, s / batch: 0.00849055462196
epoch: 450, loss: 517479.384766, accuracy: 0.509836076835, s / batch: 0.0084807208327
epoch: 500, loss: 460625.810403, accuracy: 0.549180341793, s / batch: 0.00846767034687
epoch: 550, loss: 411705.417162, accuracy: 0.572131162784, s / batch: 0.00850280386503
epoch: 600, loss: 368906.780289, accuracy: 0.60655739366, s / batch: 0.00849873902368
epoch: 650, loss: 332390.697234, accuracy: 0.593442639122, s / batch: 0.00849502985595
epoch: 700, loss: 300170.120232, accuracy: 0.62295083726, s / batch: 0.00850255762944
epoch: 750, loss: 272683.775544, accuracy: 0.642622961861, s / batch: 0.00847809822833
epoch: 800, loss: 247918.055232, accuracy: 0.670491818033, s / batch: 0.00848911629349
epoch: 850, loss: 225767.634231, accuracy: 0.721311488601, s / batch: 0.00847518248636
epoch: 900, loss: 206018.16019, accuracy: 0.718032801738, s / batch: 0.00849123079269
epoch: 950, loss: 188350.509512, accuracy: 0.726229518163, s / batch: 0.00850821323082
epoch: 1000, loss: 172920.316253, accuracy: 0.737704927315, s / batch: 0.00853024545263
epoch: 1050, loss: 158048.739298, accuracy: 0.755737715569, s / batch: 0.00852191643637
epoch: 1100, loss: 145316.934232, accuracy: 0.767213122767, s / batch: 0.0084828196979
epoch: 1150, loss: 133588.913075, accuracy: 0.786885255673, s / batch: 0.0084517861976
epoch: 1200, loss: 122670.98167, accuracy: 0.79508197503, s / batch: 0.00846818236054
epoch: 1250, loss: 112701.878448, accuracy: 0.798360665314, s / batch: 0.00849770717934
epoch: 1300, loss: 102890.85578, accuracy: 0.806557387602, s / batch: 0.00849644082492
epoch: 1350, loss: 93988.5902422, accuracy: 0.826229515623, s / batch: 0.00846160435286
epoch: 1400, loss: 85646.165101, accuracy: 0.837704924775, s / batch: 0.00845116474589
epoch: 1450, loss: 77916.5365258, accuracy: 0.842622957269, s / batch: 0.0084653408801
epoch: 1500, loss: 70945.1953826, accuracy: 0.834426232537, s / batch: 0.0084547879266
epoch: 1550, loss: 64950.0038569, accuracy: 0.854098362024, s / batch: 0.00851769134647
epoch: 1600, loss: 59128.5014638, accuracy: 0.873770491999, s / batch: 0.00852559042759
epoch: 1650, loss: 53841.9543208, accuracy: 0.86721311534, s / batch: 0.00846470379439
epoch: 1700, loss: 49026.3736593, accuracy: 0.891803278298, s / batch: 0.00849149266227
epoch: 1750, loss: 44955.7656867, accuracy: 0.881967211844, s / batch: 0.00845365836972
epoch: 1800, loss: 40819.5069032, accuracy: 0.888524588014, s / batch: 0.0085196065121
epoch: 1850, loss: 37098.7844782, accuracy: 0.893442622951, s / batch: 0.00852191643637
epoch: 1900, loss: 33605.1337534, accuracy: 0.909836061665, s / batch: 0.00849244242809
epoch: 1950, loss: 30689.0631416, accuracy: 0.904918032591, s / batch: 0.00844675204793
epoch: 2000, loss: 28042.5949062, accuracy: 0.903278685984, s / batch: 0.00854498441102
epoch: 2050, loss: 25424.5343565, accuracy: 0.909836062642, s / batch: 0.0084936696975
epoch: 2100, loss: 23083.6408957, accuracy: 0.921311476192, s / batch: 0.00850349176125
epoch: 2150, loss: 20993.099367, accuracy: 0.931147538248, s / batch: 0.00860282241321
epoch: 2200, loss: 19122.2475821, accuracy: 0.940983601281, s / batch: 0.00856262347737
epoch: 2250, loss: 17495.1458798, accuracy: 0.937704917837, s / batch: 0.00855098396051
epoch: 2300, loss: 15971.103852, accuracy: 0.94918032357, s / batch: 0.00847742596611
epoch: 2350, loss: 14502.0796643, accuracy: 0.957377044881, s / batch: 0.00851955961009
epoch: 2400, loss: 13182.6744004, accuracy: 0.950819665291, s / batch: 0.00851493585305
epoch: 2450, loss: 11883.868016, accuracy: 0.952459013853, s / batch: 0.00852031004233
epoch: 2500, loss: 10679.247534, accuracy: 0.967213110846, s / batch: 0.00859698311227

In [7]:
generategraph(EPOCH, ACCURACY, LOST)



In [8]:
# Visualize predictions on the held-out test split: a num_print x num_print
# grid where each tile shows the image titled "<primary> + <secondary>".
num_print = int(np.sqrt(len(pokemon_pictures_test)))
fig = plt.figure(figsize = (1.5 * num_print, 1.5 * num_print))
    
for k in xrange(0, num_print * num_print):
        
    plt.subplot(num_print, num_print, k + 1)
        
    # Single-image batch, same preprocessing as training.
    # NOTE(review): scipy.misc.imread/imresize were removed in SciPy >= 1.2;
    # this requires an old SciPy (or porting to imageio/PIL) to run.
    emb_data = np.zeros((1, picture_dimension, picture_dimension, 4), dtype = np.float32)
            
    image = misc.imread(current_location + '/pokemon/' + pokemon_pictures_test[k])
    image = misc.imresize(image, (picture_dimension, picture_dimension))
                
    emb_data[0, :, :, :] = image
           
    # Raw logits are enough here: only the argmax is needed for the label.
    y_hat_1, y_hat_2 = sess.run([model.y_hat_1, model.y_hat_2], feed_dict = {model.X : emb_data})
        
    label_1 = unique_type[np.argmax(y_hat_1[0])]
    label_2 = unique_type[np.argmax(y_hat_2[0])]
        
    plt.imshow(image)
    plt.title(label_1 + " + " + label_2)
    
fig.tight_layout()
plt.show()



In [9]:
# Same prediction grid as the previous cell, but over an unseen sprite set
# ('diamond-pearl') to eyeball generalization.  NOTE(review): this duplicates
# the cell above except for the source directory — a shared helper taking the
# folder path would remove the copy-paste.
list_folder = os.listdir(current_location + '/diamond-pearl')
    
num_print = int(np.sqrt(len(list_folder)))
fig = plt.figure(figsize = (1.5 * num_print, 1.5 * num_print))
    
for k in xrange(0, num_print * num_print):
        
    plt.subplot(num_print, num_print, k + 1)
        
    # Single-image batch, same preprocessing as training.
    emb_data = np.zeros((1, picture_dimension, picture_dimension, 4), dtype = np.float32)
            
    image = misc.imread(current_location + '/diamond-pearl/' + list_folder[k])
    image = misc.imresize(image, (picture_dimension, picture_dimension))
                
    emb_data[0, :, :, :] = image
           
    y_hat_1, y_hat_2 = sess.run([model.y_hat_1, model.y_hat_2], feed_dict = {model.X : emb_data})
        
    label_1 = unique_type[np.argmax(y_hat_1[0])]
    label_2 = unique_type[np.argmax(y_hat_2[0])]
        
    plt.imshow(image)
    plt.title(label_1 + " + " + label_2)
    
fig.tight_layout()
plt.show()



In [ ]: