In [1]:
import json

import matplotlib.pyplot as plt

# Locations of the two JSON artifacts this notebook reads.
report_file = 'reports/encdec_200_512_2.json'
log_file = 'logs/encdec_200_512_logs.json'

# Parse each file directly from its handle; the `with` blocks close the
# handles as soon as parsing finishes.
with open(report_file) as f:
    report = json.load(f)
with open(log_file) as f:
    logs = json.load(f)

# The original line was a Python 2 print statement, which is a SyntaxError on
# Python 3 and inconsistent with the print() calls in the rest of the notebook.
# sep='' and end='' reproduce what the statement form emitted: no space after
# the blank line, and no trailing newline (the old trailing comma).
print('Architecture: \n\n', report['architecture'], sep='', end='')
In [ ]:
# One (label, report key) pair per perplexity figure stored in the report,
# printed in the same order as before.
perplexity_rows = (
    ('Train Perplexity: ', 'train_perplexity'),
    ('Valid Perplexity: ', 'valid_perplexity'),
    ('Test Perplexity: ', 'test_perplexity'),
)
for label, key in perplexity_rows:
    print(label, report[key])
In [ ]:
%matplotlib inline
for k in logs.keys():
plt.plot(logs[k][0], logs[k][1], label=str(k) + ' (train)')
plt.plot(logs[k][0], logs[k][2], label=str(k) + ' (valid)')
plt.title('Loss v. Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
In [ ]:
%matplotlib inline
for k in logs.keys():
plt.plot(logs[k][0], logs[k][3], label=str(k) + ' (train)')
plt.plot(logs[k][0], logs[k][4], label=str(k) + ' (valid)')
plt.title('Perplexity v. Epoch')
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.legend()
plt.show()
In [ ]:
def print_sample(sample):
    """Print one sample's input, generated text and reference text.

    `sample` is indexed with the keys 'encoder_input', 'generated' and 'gold',
    each of which must hold a string. Words equal to '<pad>' are removed from
    the encoder input and words equal to '<mask>' from the gold text; the
    generated text is printed unchanged. Every line is followed by a blank
    line, and an extra blank gap closes the sample.
    """
    def without(text, token):
        # Split on single spaces so the remaining words keep their spacing.
        return ' '.join(word for word in text.split(' ') if word != token)

    source_text = without(sample['encoder_input'], '<pad>')
    reference_text = without(sample['gold'], '<mask>')

    print('Input: ' + source_text + '\n')
    print('Gend: ' + sample['generated'] + '\n')
    print('True: ' + reference_text + '\n')
    print('\n')
In [ ]:
# Print every sample stored under report['train_samples'].
train_samples = report['train_samples']
for sample in train_samples:
    print_sample(sample)
In [ ]:
# Print every sample stored under report['valid_samples'].
valid_samples = report['valid_samples']
for sample in valid_samples:
    print_sample(sample)
In [ ]:
# Print every sample stored under report['test_samples'].
test_samples = report['test_samples']
for sample in test_samples:
    print_sample(sample)