In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import time

In [2]:
# --- Run configuration -------------------------------------------------------
# CSV file(s) handed to the input pipeline.
data_filenames = ['./Data/data_19x19_shuffled.csv']

# Board geometry: side length and the number of points on the flattened board.
board_side_length = 19
D = board_side_length * board_side_length

# D zeros separated by single spaces; serves as the default value for the two
# board-valued CSV columns when a record is decoded.
empty_board_string = " ".join(["0"] * D)

# Number of training-loop iterations and the number of examples per batch.
n_batches = 3000
batch_size = 10000

# Not referenced by any of the visible cells.
debug = False

In [3]:
def read_my_csv(filename_queue):
    """Build the graph ops that read and parse a single CSV record.

    A record is decoded into three string columns. The first column is
    decoded but otherwise ignored; the second and third columns each hold
    space-separated numbers that are converted to float32 vectors of
    shape [D].

    Args:
        filename_queue: queue of CSV file names, handed to
            `tf.TextLineReader.read`.

    Returns:
        A `(features, labels)` tuple of float32 tensors of shape [D], parsed
        from the second and third CSV columns respectively.
    """
    def _parse_vector(column):
        # Split the scalar string on single spaces, convert every token to
        # float32, and pin the result to a length-D vector.
        tokens = tf.string_split(tf.expand_dims(column, axis=0), delimiter=" ")
        numbers = tf.string_to_number(tokens.values, out_type=tf.float32)
        return tf.reshape(numbers, [D])

    # Read one line of text from the queued file(s).
    line_reader = tf.TextLineReader()
    _, line = line_reader.read(filename_queue)

    # Decode the three columns; the defaults are "0" for the first column and
    # an all-zero board string for the other two.
    column_defaults = [["0"], [empty_board_string], [empty_board_string]]
    _, features_text, labels_text = tf.decode_csv(line, record_defaults=column_defaults)

    # Preprocessing: turn the two text columns into numeric vectors.
    features = _parse_vector(features_text)
    labels = _parse_vector(labels_text)
    return features, labels

In [4]:
def input_pipeline(filenames, batch_size):
    """Build shuffled mini-batch tensors from the given CSV files.

    Args:
        filenames: list of CSV paths passed to
            `tf.train.string_input_producer`.
        batch_size: number of examples in each produced batch.

    Returns:
        An `(example_batch, label_batch)` pair of tensors produced by
        `tf.train.shuffle_batch`.
    """
    # Queue of file names (shuffled), feeding the single-record reader.
    file_queue = tf.train.string_input_producer(filenames, shuffle=True)
    single_example, single_label = read_my_csv(file_queue)

    # Shuffle-buffer sizing: at least `min_buffered` elements stay queued
    # after a dequeue, and the capacity leaves room for three more batches.
    min_buffered = 100
    queue_capacity = min_buffered + 3 * batch_size

    # shuffle_batch draws random elements from the buffer to form each batch.
    batched = tf.train.shuffle_batch(
        [single_example, single_label],
        batch_size=batch_size,
        capacity=queue_capacity,
        min_after_dequeue=min_buffered)
    example_batch, label_batch = batched
    return example_batch, label_batch

In [5]:
# Build the input graph once; the training cell evaluates these two tensors
# with sess.run to obtain each batch.
example_batch, label_batch = input_pipeline(filenames=data_filenames,
                                            batch_size=batch_size)

In [6]:
# Model input: rows of D float32 values; the batch dimension is left open.
X = tf.placeholder(dtype=tf.float32, shape=[None, D])

# Parameters of a single linear layer (D x D weights, D biases), all starting
# at zero.
W = tf.Variable(tf.zeros(shape=[D, D]))
b = tf.Variable(tf.zeros(shape=[D]))

# Initializer op covering the variables defined above.
init = tf.global_variables_initializer()

In [7]:
# Linear layer followed by a softmax over the D outputs
# (vector representation of board state).
_X_rows = tf.reshape(X, [-1, D])
_scores = tf.matmul(_X_rows, W) + b
Y = tf.nn.softmax(_scores)

# Target vectors fed at run time; same [None, D] shape as the model input.
Y_ = tf.placeholder(dtype=tf.float32, shape=[None, D])

In [8]:
# Cross-entropy between targets and predictions. Predictions are clipped to
# [1e-10, 1.0] before the log so log(0) is never evaluated. The loss is summed
# (not averaged) over the whole batch, so its magnitude grows with batch size.
_safe_probs = tf.clip_by_value(Y, 1e-10, 1.0)
cross_entropy = -tf.reduce_sum(Y_ * tf.log(_safe_probs))

# Index of the largest target entry in every row.
labels = tf.argmax(Y_, 1)

# Top-1 accuracy: fraction of rows whose arg-max prediction equals the label.
_predicted = tf.argmax(Y, 1)
is_correct = tf.equal(_predicted, labels)
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Top-5 accuracy: fraction of rows whose label is among the 5 highest outputs.
top5 = tf.nn.in_top_k(Y, labels, 5)
top5_acc = tf.reduce_mean(tf.cast(top5, tf.float32))

# Plain gradient descent on the summed cross-entropy.
optimizer = tf.train.GradientDescentOptimizer(0.003)
train_step = optimizer.minimize(cross_entropy)

In [9]:
# Per-iteration metric histories; several of them are plotted in the final cell.
train_ces = []
train_accs = []
test_accs = []
test_ces = []
test_top5_accs = []

with tf.Session() as sess:
    # Wall-clock start of the run (not otherwise used in this cell).
    start_time = time.time()
    sess.run(tf.global_variables_initializer())

    # Start the threads that fill the input queues.
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

    try:
        for i in range(n_batches):
            # Pull one batch from the input pipeline and take a gradient step.
            train_x_batch, train_y_batch = sess.run([example_batch, label_batch])
            train_data = {X: train_x_batch, Y_: train_y_batch}
            sess.run(train_step, feed_dict=train_data)
            # Metrics are evaluated on the same batch AFTER the update.
            train_accuracy, train_cross_entropy = sess.run(
                [accuracy, cross_entropy], feed_dict=train_data)

            # NOTE(review): the "test" batch is dequeued from the same
            # example_batch/label_batch tensors as the training batch, so it
            # comes from the same file(s) and is not a separate held-out set.
            test_x_batch, test_y_batch = sess.run([example_batch, label_batch])
            test_data = {X: test_x_batch, Y_: test_y_batch}
            test_accuracy, test_top5_acc, test_cross_entropy = sess.run(
                [accuracy, top5_acc, cross_entropy], feed_dict=test_data)

            # Table header, printed once.
            if i == 0:
                print('Test      Test            Test           Train     Train')
                print('Accuracy  Top-5 Accuracy  Cross_entropy  Accuracy  Cross_entropy  Batch')
                print('--------  --------------  -------------  --------  -------------  -----')
            # One table row every 10th iteration and on the final iteration.
            if i % 10 == 0 or i == n_batches - 1:
                print('%.4f      %.4f           %.0f        %.3f       %.0f          %d' %
                      (test_accuracy, test_top5_acc, test_cross_entropy,
                       train_accuracy, train_cross_entropy, i))

            test_accs.append(test_accuracy)
            test_ces.append(test_cross_entropy)
            test_top5_accs.append(test_top5_acc)
            train_accs.append(train_accuracy)
            train_ces.append(train_cross_entropy)
    finally:
        # Always stop and join the queue-runner threads, even when the loop
        # raises, so no background threads are left running. The session
        # itself is closed by the enclosing `with` block.
        coordinator.request_stop()
        coordinator.join(threads)


Test      Test            Test           Train     Train
Accuracy  Top-5 Accuracy  Cross_entropy  Accuracy  Cross_entropy  Batch
--------  --------------  -------------  --------  -------------  -----
0.0043      0.0223           58852        0.421       55646          0
0.0104      0.0380           58711        0.049       55477          10
0.0103      0.0355           58700        0.036       55567          20
0.0106      0.0397           58746        0.033       55539          30
0.0138      0.0474           58703        0.036       55466          40
0.0101      0.0399           58621        0.038       55501          50
0.0137      0.0452           58738        0.034       55447          60
0.0137      0.0476           58705        0.036       55550          70
0.0119      0.0422           58720        0.036       55500          80
0.0137      0.0489           58647        0.035       55414          90
0.0141      0.0449           58699        0.036       55372          100
0.0157      0.0474           58700        0.035       55536          110
0.0156      0.0510           58710        0.035       55407          120
0.0134      0.0429           58704        0.039       55514          130
0.0163      0.0488           58637        0.036       55404          140
0.0156      0.0499           58661        0.040       55446          150
0.0165      0.0479           58550        0.041       55383          160
0.0150      0.0478           58713        0.038       55406          170
0.0172      0.0539           58634        0.037       55376          180
0.0163      0.0473           58672        0.039       55387          190
0.0155      0.0513           58642        0.035       55385          200
0.0162      0.0506           58664        0.038       55438          210
0.0177      0.0526           58616        0.039       55400          220
0.0163      0.0505           58687        0.039       55526          230
0.0174      0.0564           58525        0.039       55299          240
0.0169      0.0502           58596        0.040       55314          250
0.0138      0.0493           58678        0.039       55270          260
0.0167      0.0529           58566        0.040       55381          270
0.0170      0.0541           58504        0.037       55414          280
0.0184      0.0543           58598        0.036       55335          290
0.0162      0.0497           58677        0.039       55317          300
0.0138      0.0502           58599        0.040       55381          310
0.0154      0.0478           58524        0.038       55292          320
0.0139      0.0498           58689        0.033       55405          330
0.0157      0.0484           58537        0.034       55378          340
0.0169      0.0510           58662        0.036       55359          350
0.0160      0.0505           58596        0.039       55344          360
0.0137      0.0503           58647        0.035       55461          370
0.0167      0.0501           58472        0.037       55356          380
0.0180      0.0527           58597        0.035       55432          390
0.0162      0.0483           58626        0.038       55317          400
0.0171      0.0533           58557        0.038       55541          410
0.0171      0.0506           58645        0.035       55394          420
0.0153      0.0508           58593        0.035       55434          430
0.0176      0.0528           58563        0.033       55382          440
0.0160      0.0496           58592        0.035       55268          450
0.0155      0.0511           58510        0.038       55380          460
0.0192      0.0539           58501        0.039       55280          470
0.0177      0.0519           58532        0.040       55381          480
0.0154      0.0498           58637        0.037       55330          490
0.0189      0.0543           58563        0.039       55308          500
0.0174      0.0492           58654        0.035       55473          510
0.0174      0.0552           58550        0.039       55362          520
0.0161      0.0506           58702        0.041       55340          530
0.0158      0.0511           58718        0.039       55374          540
0.0145      0.0494           58564        0.036       55348          550
0.0171      0.0542           58603        0.035       55413          560
0.0180      0.0564           58547        0.039       55423          570
0.0180      0.0517           58567        0.038       55324          580
0.0178      0.0523           58540        0.037       55602          590
0.0153      0.0494           58546        0.038       55435          600
0.0179      0.0547           58512        0.038       55392          610
0.0181      0.0521           58384        0.038       55325          620
0.0165      0.0537           58642        0.038       55514          630
0.0164      0.0500           58627        0.039       55364          640
0.0176      0.0523           58622        0.037       55386          650
0.0165      0.0509           58661        0.036       55382          660
0.0156      0.0504           58612        0.038       55387          670
0.0172      0.0521           58517        0.035       55460          680
0.0157      0.0497           58636        0.036       55333          690
0.0175      0.0527           58517        0.038       55297          700
0.0173      0.0515           58647        0.041       55221          710
0.0181      0.0544           58524        0.036       55298          720
0.0162      0.0524           58673        0.041       55351          730
0.0169      0.0538           58486        0.035       55456          740
0.0171      0.0497           58597        0.039       55218          750
0.0162      0.0534           58585        0.035       55354          760
0.0186      0.0528           58492        0.037       55278          770
0.0179      0.0518           58551        0.036       55445          780
0.0167      0.0491           58692        0.038       55302          790
0.0173      0.0473           58564        0.036       55286          800
0.0161      0.0499           58572        0.035       55343          810
0.0162      0.0531           58612        0.040       55332          820
0.0175      0.0510           58497        0.037       55276          830
0.0168      0.0534           58592        0.035       55338          840
0.0143      0.0495           58536        0.038       55413          850
0.0156      0.0492           58597        0.039       55318          860
0.0149      0.0498           58491        0.039       55385          870
0.0195      0.0541           58598        0.038       55348          880
0.0183      0.0563           58594        0.037       55466          890
0.0172      0.0519           58640        0.037       55330          900
0.0204      0.0541           58558        0.039       55366          910
0.0165      0.0489           58569        0.036       55240          920
0.0190      0.0560           58453        0.035       55361          930
0.0173      0.0552           58466        0.036       55488          940
0.0188      0.0525           58557        0.038       55314          950
0.0186      0.0566           58460        0.039       55368          960
0.0179      0.0533           58519        0.040       55351          970
0.0192      0.0537           58529        0.041       55381          980
0.0150      0.0534           58515        0.042       55259          990
0.0155      0.0502           58555        0.038       55320          1000
0.0183      0.0534           58558        0.035       55343          1010
0.0186      0.0534           58517        0.037       55401          1020
0.0172      0.0528           58670        0.037       55507          1030
0.0171      0.0536           58611        0.041       55215          1040
0.0187      0.0521           58540        0.035       55404          1050
0.0189      0.0507           58617        0.036       55364          1060
0.0174      0.0517           58639        0.038       55308          1070
0.0171      0.0527           58579        0.038       55278          1080
0.0164      0.0515           58550        0.040       55262          1090
0.0179      0.0531           58516        0.041       55329          1100
0.0170      0.0539           58501        0.039       55302          1110
0.0176      0.0479           58590        0.041       55334          1120
0.0176      0.0533           58600        0.040       55253          1130
0.0154      0.0513           58512        0.035       55392          1140
0.0195      0.0575           58510        0.036       55367          1150
0.0183      0.0582           58473        0.037       55366          1160
0.0189      0.0532           58522        0.040       55372          1170
0.0168      0.0509           58580        0.038       55399          1180
0.0152      0.0479           58563        0.042       55193          1190
0.0148      0.0528           58589        0.043       55285          1200
0.0168      0.0504           58650        0.038       55388          1210
0.0186      0.0547           58523        0.035       55335          1220
0.0148      0.0469           58641        0.036       55321          1230
0.0180      0.0532           58528        0.043       55294          1240
0.0167      0.0517           58547        0.042       55343          1250
0.0153      0.0545           58623        0.038       55427          1260
0.0181      0.0555           58661        0.037       55367          1270
0.0181      0.0541           58520        0.037       55334          1280
0.0184      0.0526           58539        0.033       55376          1290
0.0160      0.0504           58661        0.036       55398          1300
0.0197      0.0560           58576        0.036       55350          1310
0.0199      0.0544           58564        0.038       55282          1320
0.0167      0.0521           58629        0.040       55320          1330
0.0182      0.0541           58442        0.036       55306          1340
0.0163      0.0497           58752        0.037       55318          1350
0.0157      0.0542           58646        0.039       55447          1360
0.0173      0.0507           58583        0.040       55349          1370
0.0201      0.0559           58574        0.038       55405          1380
0.0167      0.0520           58542        0.038       55200          1390
0.0186      0.0521           58534        0.040       55319          1400
0.0170      0.0488           58631        0.039       55400          1410
0.0188      0.0519           58579        0.040       55300          1420
0.0178      0.0553           58499        0.039       55302          1430
0.0167      0.0498           58539        0.041       55396          1440
0.0161      0.0506           58542        0.040       55461          1450
0.0155      0.0487           58582        0.037       55344          1460
0.0154      0.0524           58616        0.038       55361          1470
0.0180      0.0538           58563        0.038       55300          1480
0.0174      0.0563           58627        0.039       55420          1490
0.0177      0.0513           58558        0.036       55291          1500
0.0170      0.0533           58458        0.036       55264          1510
0.0137      0.0492           58611        0.036       55375          1520
0.0183      0.0518           58705        0.040       55419          1530
0.0161      0.0498           58586        0.036       55523          1540
0.0159      0.0504           58489        0.039       55392          1550
0.0171      0.0544           58555        0.036       55390          1560
0.0170      0.0562           58548        0.037       55429          1570
0.0205      0.0557           58550        0.039       55350          1580
0.0165      0.0510           58690        0.043       55258          1590
0.0176      0.0534           58545        0.039       55259          1600
0.0167      0.0525           58557        0.039       55350          1610
0.0176      0.0556           58527        0.037       55412          1620
0.0197      0.0606           58440        0.039       55315          1630
0.0158      0.0525           58611        0.038       55460          1640
0.0181      0.0526           58607        0.039       55333          1650
0.0159      0.0492           58627        0.036       55315          1660
0.0167      0.0483           58654        0.038       55369          1670
0.0177      0.0524           58458        0.039       55431          1680
0.0167      0.0506           58641        0.035       55349          1690
0.0180      0.0554           58514        0.038       55376          1700
0.0171      0.0489           58530        0.038       55306          1710
0.0148      0.0555           58467        0.037       55494          1720
0.0148      0.0548           58538        0.038       55391          1730
0.0169      0.0517           58647        0.040       55381          1740
0.0191      0.0560           58501        0.043       55287          1750
0.0162      0.0497           58506        0.036       55491          1760
0.0148      0.0500           58576        0.034       55341          1770
0.0182      0.0553           58428        0.038       55375          1780
0.0179      0.0527           58547        0.036       55302          1790
0.0189      0.0550           58582        0.037       55460          1800
0.0188      0.0543           58652        0.037       55356          1810
0.0179      0.0522           58433        0.038       55413          1820
0.0168      0.0540           58646        0.037       55296          1830
0.0157      0.0545           58473        0.038       55391          1840
0.0162      0.0549           58699        0.036       55363          1850
0.0144      0.0482           58584        0.040       55213          1860
0.0179      0.0521           58559        0.042       55270          1870
0.0183      0.0526           58656        0.038       55292          1880
0.0165      0.0489           58663        0.034       55396          1890
0.0159      0.0490           58571        0.038       55314          1900
0.0204      0.0542           58581        0.037       55331          1910
0.0161      0.0484           58620        0.038       55361          1920
0.0188      0.0524           58583        0.040       55304          1930
0.0164      0.0521           58497        0.039       55404          1940
0.0162      0.0512           58540        0.039       55279          1950
0.0175      0.0512           58616        0.039       55348          1960
0.0182      0.0531           58566        0.036       55309          1970
0.0190      0.0542           58621        0.041       55308          1980
0.0180      0.0538           58665        0.040       55182          1990
0.0202      0.0564           58609        0.036       55368          2000
0.0196      0.0549           58472        0.040       55329          2010
0.0161      0.0532           58604        0.037       55480          2020
0.0166      0.0540           58506        0.038       55434          2030
0.0165      0.0506           58587        0.038       55312          2040
0.0163      0.0517           58556        0.039       55303          2050
0.0157      0.0497           58492        0.037       55427          2060
0.0179      0.0536           58458        0.039       55254          2070
0.0180      0.0554           58629        0.038       55211          2080
0.0148      0.0513           58460        0.040       55293          2090
0.0158      0.0512           58669        0.036       55264          2100
0.0152      0.0516           58577        0.034       55338          2110
0.0187      0.0534           58570        0.038       55475          2120
0.0177      0.0522           58522        0.037       55430          2130
0.0174      0.0512           58550        0.038       55350          2140
0.0181      0.0527           58521        0.042       55366          2150
0.0185      0.0527           58624        0.039       55327          2160
0.0167      0.0522           58649        0.037       55428          2170
0.0160      0.0499           58675        0.038       55277          2180
0.0162      0.0478           58632        0.036       55263          2190
0.0186      0.0553           58535        0.038       55273          2200
0.0138      0.0488           58678        0.040       55311          2210
0.0162      0.0515           58514        0.037       55314          2220
0.0160      0.0546           58530        0.040       55360          2230
0.0186      0.0565           58558        0.039       55291          2240
0.0156      0.0533           58549        0.037       55367          2250
0.0143      0.0527           58561        0.042       55250          2260
0.0195      0.0561           58586        0.038       55283          2270
0.0177      0.0524           58631        0.036       55326          2280
0.0201      0.0590           58495        0.038       55346          2290
0.0162      0.0508           58469        0.040       55437          2300
0.0197      0.0536           58483        0.036       55377          2310
0.0147      0.0514           58596        0.039       55374          2320
0.0176      0.0521           58576        0.039       55305          2330
0.0190      0.0551           58446        0.040       55299          2340
0.0185      0.0534           58495        0.038       55293          2350
0.0169      0.0531           58448        0.040       55268          2360
0.0163      0.0516           58589        0.039       55319          2370
0.0177      0.0535           58596        0.037       55373          2380
0.0182      0.0555           58536        0.039       55354          2390
0.0146      0.0473           58678        0.037       55355          2400
0.0169      0.0502           58483        0.038       55332          2410
0.0152      0.0489           58701        0.041       55344          2420
0.0174      0.0537           58621        0.033       55439          2430
0.0171      0.0515           58612        0.040       55349          2440
0.0159      0.0514           58684        0.040       55318          2450
0.0175      0.0548           58564        0.041       55519          2460
0.0168      0.0526           58509        0.041       55313          2470
0.0167      0.0566           58421        0.041       55297          2480
0.0177      0.0506           58600        0.043       55311          2490
0.0149      0.0495           58515        0.036       55347          2500
0.0190      0.0551           58499        0.037       55408          2510
0.0168      0.0532           58626        0.039       55429          2520
0.0171      0.0501           58638        0.038       55407          2530
0.0166      0.0542           58428        0.035       55381          2540
0.0171      0.0535           58608        0.041       55259          2550
0.0185      0.0551           58552        0.040       55255          2560
0.0194      0.0525           58650        0.037       55447          2570
0.0184      0.0575           58586        0.039       55287          2580
0.0167      0.0517           58625        0.037       55374          2590
0.0158      0.0526           58594        0.038       55318          2600
0.0154      0.0499           58704        0.040       55281          2610
0.0187      0.0539           58529        0.038       55481          2620
0.0186      0.0532           58617        0.036       55420          2630
0.0174      0.0522           58575        0.037       55257          2640
0.0164      0.0505           58601        0.037       55348          2650
0.0157      0.0505           58643        0.039       55357          2660
0.0182      0.0520           58624        0.038       55307          2670
0.0171      0.0535           58655        0.040       55324          2680
0.0187      0.0517           58636        0.038       55309          2690
0.0150      0.0503           58474        0.039       55296          2700
0.0161      0.0533           58560        0.038       55250          2710
0.0184      0.0559           58588        0.042       55315          2720
0.0175      0.0530           58493        0.037       55376          2730
0.0164      0.0570           58606        0.038       55349          2740
0.0194      0.0540           58573        0.038       55337          2750
0.0190      0.0553           58606        0.037       55430          2760
0.0157      0.0522           58543        0.040       55334          2770
0.0170      0.0574           58517        0.041       55313          2780
0.0171      0.0535           58484        0.036       55502          2790
0.0183      0.0520           58665        0.036       55430          2800
0.0183      0.0564           58529        0.038       55372          2810
0.0185      0.0501           58461        0.039       55426          2820
0.0179      0.0525           58568        0.038       55325          2830
0.0194      0.0561           58573        0.036       55428          2840
0.0159      0.0479           58617        0.036       55442          2850
0.0163      0.0528           58532        0.036       55313          2860
0.0170      0.0501           58611        0.039       55273          2870
0.0175      0.0538           58604        0.038       55445          2880
0.0188      0.0552           58674        0.036       55347          2890
0.0156      0.0478           58624        0.039       55501          2900
0.0193      0.0521           58595        0.038       55333          2910
0.0183      0.0539           58624        0.039       55381          2920
0.0171      0.0505           58525        0.040       55346          2930
0.0163      0.0515           58678        0.039       55382          2940
0.0177      0.0544           58597        0.037       55330          2950
0.0162      0.0495           58609        0.037       55360          2960
0.0176      0.0533           58586        0.036       55374          2970
0.0190      0.0551           58584        0.038       55392          2980
0.0195      0.0549           58586        0.039       55369          2990
0.0183      0.0538           58611        0.035       55521          2999

In [ ]:
# x-axis values: one index per recorded training iteration.
batches = list(range(len(test_accs)))


def _plot_curves(title, ylabel, curves):
    """Draw one figure with a labelled line per (values, label) pair.

    Args:
        title: figure title.
        ylabel: label for the y axis.
        curves: iterable of `(values, label)` pairs; every `values` sequence
            is plotted against `batches`.
    """
    fig, ax = plt.subplots()
    ax.set_title(title)
    for values, label in curves:
        ax.plot(batches, values, label=label)
    ax.legend(loc='best')
    ax.set_xlabel('Batch')
    ax.set_ylabel(ylabel)
    plt.show()


# Accuracy of the softmax model on the "test" batches.
_plot_curves('Softmax Model', 'Accuracy',
             [(test_accs, 'Test Accuracy'),
              (test_top5_accs, 'Top-5 Accuracy')])

# Cross-entropy on the "test" and training batches.
_plot_curves('Softmax Cross-Entropy', 'Cross-Entropy',
             [(test_ces, 'Test Cross-Entropy'),
              (train_ces, 'Training Cross-Entropy')])

In [ ]: