In [3]:
% matplotlib inline
import re
import sys
import matplotlib.pyplot as plt
# Substrings that must ALL appear in a log line for it to be parsed.
LINE_SUBS = [ 'iter', 'loss' ]


def process_log( file_name ):
    """Extract (iteration, loss) pairs from a text log file.

    A line is considered only if it contains every substring in
    ``LINE_SUBS``. The line is then split on commas and single spaces;
    only lines that yield exactly five tokens are used, taking the second
    token as the iteration and the fifth token as the loss. For example,
    a line such as ``iter 100, loss 0.25`` splits into
    ``['iter', '100', '', 'loss', '0.25\\n']`` and contributes
    ``(100.0, 0.25)``.

    Args:
        file_name: Path of the log file to read (opened in text mode).

    Returns:
        A tuple ``(iterations, loss)`` of two equally long lists of floats,
        in the order the matching lines appear in the file.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a matching five-token line holds a non-numeric
            value in the iteration or loss position.
    """
    iterations, loss = [], []
    # 'with' guarantees the handle is closed even if float() raises below.
    with open( file_name, 'rt' ) as fin:
        for line in fin:
            if all( x in line for x in LINE_SUBS ):
                temp = re.split( ',| ', line )
                # ", " produces an empty token, hence five tokens in total.
                if len( temp ) == 5:
                    # float() tolerates the trailing newline on the last token.
                    iterations.append( float( temp[1] ) )
                    loss.append( float( temp[4] ) )
    return iterations, loss
# Load the (iteration, loss) series for each run. Each variable is named
# after the log file it comes from, so names, files and legend labels agree.
tanh_i, tanh_l = process_log( 'tanh.log' )
relu_i, relu_l = process_log( 'relu.log' )

# Plot both loss curves on one set of axes: tanh in blue, relu in green.
fig, ax = plt.subplots( figsize=(10,6) )
ax.plot( tanh_i, tanh_l, 'b-', label = 'tanh' )
ax.plot( relu_i, relu_l, 'g-', label = 'relu' )
# Use the Axes method so the legend is tied explicitly to this figure.
legend = ax.legend( loc='upper center', shadow=True, fontsize='x-large' )
ax.set_ylabel( 'loss' )
ax.set_xlabel( 'iterations' )
plt.show()
In [ ]: