In [12]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import time

In [13]:
# Configuration for the Go board-position value network.
data_filenames = ['./Data/data_19x19_filter.csv']  # CSV rows: label column, board string, second board string
board_side_length = 19  # standard 19x19 Go board
D = board_side_length ** 2  # flattened board vector length (361)
empty_board_string = "0 " * (D - 1) + "0"  # D space-separated zeros; used as decode_csv defaults
n_batches = 100000  # max training iterations (the run below was interrupted around batch 2010)
batch_size = 10000  # examples per training batch
debug = False  # NOTE(review): never read in the visible cells -- candidate for removal

In [14]:
def read_my_csv(filename_queue):
    """Read and decode one CSV record from `filename_queue`.

    Returns a (features, labels) pair of tensors: a length-D float32
    board vector and a single-element float32 label vector. Relies on
    the module-level globals `D` and `empty_board_string` for sizing
    and record defaults.
    """
    line_reader = tf.TextLineReader()
    _, csv_row = line_reader.read(filename_queue)
    # Three string columns: label, board state, and a second board string
    # (decoded but not used here).
    defaults = [["0"], [empty_board_string], [empty_board_string]]
    label_str, board_str, _extra = tf.decode_csv(csv_row, record_defaults=defaults)
    # Preprocessing: split the space-separated board string and parse it
    # into a length-D float vector.
    board_tokens = tf.string_split(tf.expand_dims(board_str, axis=0), delimiter=" ")
    features = tf.reshape(
        tf.string_to_number(board_tokens.values, out_type=tf.float32), [D])
    # Parse the label column into a single-element float vector.
    label_tokens = tf.string_split(tf.expand_dims(label_str, axis=0), delimiter=" ")
    labels = tf.reshape(
        tf.string_to_number(label_tokens.values, out_type=tf.float32), [1])
    return features, labels

In [15]:
def input_pipeline(filenames, batch_size, min_after_dequeue=100):
    """Build a shuffled batch-reading pipeline over CSV files.

    Args:
        filenames: list of CSV file paths to read from.
        batch_size: number of (example, label) pairs per emitted batch.
        min_after_dequeue: minimum elements kept in the shuffle queue after a
            dequeue; larger values give better shuffling at the cost of memory.
            Defaults to 100 (the previously hard-coded value).

    Returns:
        A (example_batch, label_batch) pair of tensors that dequeue
        shuffled batches once queue runners are started.
    """
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
    example, label = read_my_csv(filename_queue)
    # TF-recommended sizing: capacity = min_after_dequeue + 3 * batch_size.
    capacity = min_after_dequeue + 3 * batch_size
    # shuffle_batch performs the random shuffling across queued examples.
    example_batch, label_batch = tf.train.shuffle_batch(
        [example, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch

In [16]:
# Build separate pipelines: one for training batches, one fixed 10000-example
# batch used for evaluation.
# NOTE(review): both pipelines read the SAME files, so the "test" batch is not
# a held-out split -- test accuracy will be optimistic; verify intent.
example_batch, label_batch = input_pipeline(data_filenames, batch_size)
test_batch, test_label_batch = input_pipeline(data_filenames, 10000)

In [17]:
# Model parameters: a single linear layer mapping D board features to 2 classes.
X = tf.placeholder(tf.float32, [None, D])  # batch of flattened board states
W = tf.Variable(tf.zeros([D, 2]))  # zero init is acceptable for a single softmax layer (no symmetry problem)
b = tf.Variable(tf.zeros([2]))

# NOTE(review): `init` is never run -- the session cell calls
# tf.global_variables_initializer() directly; this op is dead.
init = tf.global_variables_initializer()

In [18]:
# Forward pass: softmax over the linear layer -> predicted class probabilities.
# (The reshape is a no-op for [None, D] input but tolerates flat input.)
Y = tf.nn.softmax(tf.matmul(tf.reshape(X,[-1,D]), W) + b)

# One-hot target labels: column 0 = positive outcome, column 1 = negative.
Y_ = tf.placeholder(tf.float32, [None, 2])

In [19]:
# Summed (not averaged) cross-entropy over the batch; the clip guards log(0).
# NOTE(review): because this is a sum, its magnitude -- and the effective
# learning rate -- scales with batch_size; consider reduce_mean.
cross_entropy = -tf.reduce_sum(Y_ * tf.log(tf.clip_by_value(Y,1e-10,1.0)))
# cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=Y, labels=Y_)
# NOTE(review): the commented-out alternative above would double-apply softmax,
# since Y is already softmax output; it would need raw logits instead.
is_correct = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

optimizer = tf.train.GradientDescentOptimizer(0.003)
train_step = optimizer.minimize(cross_entropy)

In [21]:
# Metric histories, recorded every 10 batches for plotting in a later cell.
train_ces = []
train_accs = []
test_accs = []
test_ces = []

with tf.Session() as sess:
    start_time = time.time()  # NOTE(review): never read afterwards -- dead variable
    sess.run(tf.global_variables_initializer())
    
    # Start the queue-runner threads that feed both input pipelines.
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coordinator)
    
    # Fetch one fixed evaluation batch up front and reuse it for every
    # test-accuracy measurement below.
    test_x_batch, test_y_batch = sess.run([test_batch, test_label_batch])
    # Encode each scalar label yy as a 2-vector: positive -> [yy, 0],
    # negative -> [0, |yy|].
    # NOTE(review): a zero label would encode as [0, 0]; assumes labels are
    # strictly nonzero (presumably +/-1) -- TODO confirm against the data.
    test_y_batch = [np.concatenate((np.where(yy > 0, yy, 0), np.where(yy < 0, abs(yy), 0))) for yy in test_y_batch]
    test_data = {X: test_x_batch, Y_: test_y_batch}
    
    for i in xrange(n_batches):
        # Pull a fresh shuffled training batch and apply one SGD step.
        train_x_batch, train_y_batch = sess.run([example_batch, label_batch])
        train_y_batch = [np.concatenate((np.where(yy > 0, yy, 0), np.where(yy < 0, abs(yy), 0))) for yy in train_y_batch]
        train_data = {X: train_x_batch, Y_: train_y_batch}
        sess.run(train_step, feed_dict=train_data)
        
        if i==0:                
            print 'Test       Test           Train     Train'
            print 'Accuracy  Cross_entropy  Accuracy  Cross_entropy  Batch'
            print '--------  -------------  --------  -------------  -----'
        # Record and print metrics every 10 batches (and on the final batch).
        if i%10==0 or i==n_batches-1:
            train_accuracy, train_cross_entropy = sess.run([accuracy, cross_entropy], feed_dict = train_data)            
            test_accuracy, test_cross_entropy = sess.run([accuracy, cross_entropy], feed_dict = test_data)
        
            print '%.4f        %.0f        %.3f       %.0f          %d' % \
            (test_accuracy, test_cross_entropy, train_accuracy, train_cross_entropy, i)
            
            test_accs.append(test_accuracy)
            test_ces.append(test_cross_entropy)
            train_accs.append(train_accuracy)
            train_ces.append(train_cross_entropy)

    # Shut the queue-runner threads down cleanly before the session closes.
    coordinator.request_stop()
    coordinator.join(threads)
    # NOTE(review): redundant -- the `with` block already closes the session on
    # exit; this likely contributes to the "Attempted to use a closed Session"
    # error the coordinator reports in the output below.
    sess.close()


Test       Test           Train     Train
Accuracy  Cross_entropy  Accuracy  Cross_entropy  Batch
--------  -------------  --------  -------------  -----
0.5740        10193        0.575       10169          0
0.5022        33143        0.510       32431          10
0.4938        58081        0.499       57329          20
0.5215        62446        0.532       62249          30
0.5125        26912        0.519       26633          40
0.4826        74222        0.482       76330          50
0.5166        28274        0.527       27528          60
0.4878        82601        0.495       82362          70
0.4892        80104        0.485       81222          80
0.4892        42939        0.495       43025          90
0.5161        64521        0.516       63723          100
0.5229        68857        0.531       67394          110
0.5275        49561        0.534       49061          120
0.5051        34218        0.523       33476          130
0.5130        21869        0.521       21465          140
0.4942        57044        0.499       56357          150
0.4979        41993        0.515       40675          160
0.4910        81389        0.487       81151          170
0.5068        63154        0.523       61191          180
0.4960        33992        0.508       33907          190
0.4923        71748        0.494       71810          200
0.5212        50016        0.534       48010          210
0.5089        55561        0.518       55300          220
0.5214        55084        0.528       54412          230
0.4871        81032        0.493       81006          240
0.4856        83650        0.483       84289          250
0.5271        39909        0.538       38537          260
0.5172        57491        0.525       57148          270
0.5302        20546        0.526       21101          280
0.4985        61330        0.503       60473          290
0.5004        56679        0.511       55694          300
0.5073        59544        0.512       58014          310
0.5075        56092        0.514       53899          320
0.5089        62547        0.513       60627          330
0.5091        61248        0.512       59364          340
0.4891        79616        0.493       79404          350
0.4912        77679        0.491       78994          360
0.5036        27889        0.514       27266          370
0.5127        62432        0.514       62835          380
0.5100        60717        0.510       61455          390
0.5117        58701        0.518       58111          400
0.5160        58538        0.526       59130          410
0.5121        61885        0.523       61486          420
0.5069        67896        0.506       68213          430
0.5131        56532        0.513       56658          440
0.5242        55209        0.523       55490          450
0.5120        62952        0.512       63265          460
0.5108        55157        0.520       54047          470
0.5109        60797        0.512       61280          480
0.5151        48429        0.514       48311          490
0.5164        54970        0.520       54391          500
0.5145        58405        0.516       57977          510
0.5068        69436        0.508       68956          520
0.5153        52990        0.513       53238          530
0.5182        49945        0.525       49857          540
0.5196        54698        0.517       56473          550
0.5130        59910        0.519       59041          560
0.5051        64854        0.510       64479          570
0.5161        58788        0.516       58395          580
0.5120        62022        0.521       60614          590
0.5092        58929        0.511       58659          600
0.5228        43842        0.518       43302          610
0.5116        62604        0.516       62632          620
0.5094        65806        0.507       66009          630
0.5117        57261        0.515       57325          640
0.5228        55707        0.525       55832          650
0.5065        64565        0.507       64980          660
0.5175        49139        0.524       48991          670
0.5123        57094        0.516       57754          680
0.5086        65312        0.512       64775          690
0.5196        54545        0.523       54823          700
0.5145        57629        0.515       58156          710
0.5109        60000        0.517       59979          720
0.5102        61775        0.520       61340          730
0.5086        60722        0.514       60668          740
0.5111        57404        0.518       57210          750
0.5140        60956        0.522       59825          760
0.5059        68356        0.513       67618          770
0.5128        55574        0.516       54210          780
0.5235        56020        0.528       55565          790
0.5113        61842        0.512       61841          800
0.5198        43098        0.533       41775          810
0.5185        55150        0.520       55012          820
0.5146        54238        0.513       55023          830
0.5174        51643        0.518       51786          840
0.5151        58256        0.520       58508          850
0.5089        62806        0.502       64269          860
0.5156        59275        0.528       58354          870
0.5112        60461        0.507       61753          880
0.5187        42998        0.524       42335          890
0.5122        57525        0.522       56474          900
0.5149        58687        0.517       58766          910
0.5073        62637        0.506       62966          920
0.5236        56808        0.530       56546          930
0.5162        55474        0.512       56248          940
0.5169        54979        0.529       54161          950
0.5140        46394        0.525       45309          960
0.5141        58793        0.510       60008          970
0.5131        53232        0.514       53461          980
0.5237        43623        0.525       44349          990
0.5164        58463        0.521       59074          1000
0.5176        54877        0.522       54565          1010
0.5147        58111        0.520       57890          1020
0.5084        57746        0.511       57666          1030
0.5159        57483        0.513       58777          1040
0.5183        45933        0.517       45859          1050
0.5131        57692        0.519       56767          1060
0.5117        61167        0.506       61749          1070
0.5113        56813        0.521       56789          1080
0.5137        57361        0.513       58172          1090
0.5182        55330        0.525       55045          1100
0.5124        60632        0.522       59461          1110
0.5113        61863        0.519       61008          1120
0.5132        62367        0.517       62360          1130
0.5106        61900        0.511       61343          1140
0.5220        53373        0.519       53997          1150
0.5135        59515        0.514       59242          1160
0.5065        68312        0.518       66742          1170
0.5130        63921        0.517       62862          1180
0.5064        67352        0.505       67680          1190
0.5114        52313        0.516       52499          1200
0.5140        58219        0.518       58321          1210
0.5081        58811        0.507       59006          1220
0.5186        56784        0.512       57681          1230
0.5105        59716        0.507       60478          1240
0.5150        53895        0.511       55585          1250
0.5102        63402        0.509       64371          1260
0.5114        57167        0.514       56905          1270
0.5145        58796        0.524       58015          1280
0.5246        54275        0.542       52220          1290
0.5246        67980        0.525       67422          1300
0.5218        41845        0.527       41920          1310
0.5153        22348        0.524       21807          1320
0.4922        80231        0.494       79762          1330
0.4842        83907        0.485       83733          1340
0.4920        58281        0.501       59012          1350
0.5240        44390        0.527       43087          1360
0.5168        51969        0.532       50646          1370
0.4966        43715        0.515       42299          1380
0.5181        41679        0.532       40618          1390
0.5277        28736        0.543       27275          1400
0.4914        37388        0.495       36564          1410
0.5242        48229        0.535       47290          1420
0.5075        46608        0.510       45593          1430
0.4944        55138        0.491       55835          1440
0.5221        25744        0.541       24595          1450
0.4943        77565        0.493       78768          1460
0.4944        61699        0.498       62975          1470
0.5213        68938        0.537       67036          1480
0.4999        47505        0.502       46700          1490
0.5117        38508        0.519       37505          1500
0.5173        64096        0.524       64229          1510
0.5045        39919        0.521       38916          1520
0.4911        82814        0.495       82485          1530
0.5219        64155        0.530       62074          1540
0.4957        50679        0.498       50556          1550
0.5204        67812        0.534       67064          1560
0.4917        78879        0.486       79421          1570
0.5184        66387        0.528       64612          1580
0.5285        37211        0.531       36790          1590
0.4896        83062        0.499       81946          1600
0.5112        47904        0.522       46702          1610
0.5279        31285        0.542       29909          1620
0.5115        45250        0.513       45465          1630
0.5184        78531        0.524       77488          1640
0.4905        74330        0.492       75366          1650
0.5104        39395        0.519       38380          1660
0.4874        87761        0.495       86829          1670
0.5200        46259        0.530       45437          1680
0.4882        49550        0.491       49685          1690
0.5168        84531        0.528       82736          1700
0.4890        48224        0.498       48659          1710
0.5183        78666        0.518       79428          1720
0.5199        75525        0.521       75234          1730
0.4869        70517        0.487       71350          1740
0.5350        37381        0.545       36282          1750
0.4892        81063        0.489       81677          1760
0.5227        69678        0.525       68762          1770
0.4864        81441        0.488       80586          1780
0.4896        72155        0.492       73631          1790
0.5198        76989        0.517       77812          1800
0.5251        58759        0.535       57609          1810
0.5248        24423        0.538       23423          1820
0.4971        48613        0.507       49064          1830
0.5187        48933        0.526       48286          1840
0.5067        52848        0.511       52935          1850
0.5249        69272        0.525       69825          1860
0.5442        22788        0.554       21553          1870
0.5235        69176        0.524       68638          1880
0.5072        46139        0.516       45526          1890
0.4938        33673        0.517       32805          1900
0.5113        58850        0.515       58550          1910
0.5212        74872        0.527       73956          1920
0.4910        83354        0.497       82907          1930
0.4884        77622        0.495       78105          1940
0.5168        69050        0.526       68544          1950
0.5206        70125        0.520       69528          1960
0.5080        40426        0.518       39708          1970
0.5190        76012        0.536       74137          1980
0.5140        32327        0.528       30852          1990
0.5003        59420        0.502       58978          2000
0.5169        66173        0.526       64361          2010
INFO:tensorflow:Error reported to Coordinator: <type 'exceptions.RuntimeError'>, Attempted to use a closed Session.
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-21-2db2a4801c25> in <module>()
     16 
     17     for i in xrange(n_batches):
---> 18         train_x_batch, train_y_batch = sess.run([example_batch, label_batch])
     19         train_y_batch = [np.concatenate((np.where(yy > 0, yy, 0), np.where(yy < 0, abs(yy), 0))) for yy in train_y_batch]
     20         train_data = {X: train_x_batch, Y_: train_y_batch}

/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    963     if final_fetches or final_targets:
    964       results = self._do_run(handle, final_targets, final_fetches,
--> 965                              feed_dict_string, options, run_metadata)
    966     else:
    967       results = []

/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1013     if handle is None:
   1014       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1015                            target_list, options, run_metadata)
   1016     else:
   1017       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1020   def _do_call(self, fn, *args):
   1021     try:
-> 1022       return fn(*args)
   1023     except errors.OpError as e:
   1024       message = compat.as_text(e.message)

/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1002         return tf_session.TF_Run(session, options,
   1003                                  feed_dict, fetch_list, target_list,
-> 1004                                  status, run_metadata)
   1005 
   1006     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [27]:
# Plot accuracy and cross-entropy histories; metrics were recorded every
# 10 batches, so the x-axis is i * 10.
batches = [i*10 for i in xrange(len(test_accs))]
# Mean test accuracy over the last 10 recorded checkpoints.
print sum(test_accs[-10::])/10
plt.title('Softmax Value Network')
plt.plot(batches, test_accs, label='Test Accuracy')
plt.plot(batches, train_accs, label='Train Accuracy')
plt.legend(loc='best')
# plt.xlabel('Batch')
plt.ylabel('Accuracy')
plt.xlim([0,2000])
plt.show()
plt.title('Softmax Value Network Cross-Entropy')
plt.plot(batches, test_ces, label='Test Cross-Entropy')
plt.plot(batches, train_ces, label='Training Cross-Entropy')
plt.legend(loc='best')
plt.ylabel('Cross-Entropy')
# plt.xlabel('Batch')
plt.xlim([0,2000])
plt.show()


0.509620001912

In [ ]: