In [1]:
%load_ext autoreload
%autoreload 2
import os
import json
import tabulate
from collections import Counter
from IPython.display import HTML, display
In [3]:
# Enter the path to your wordstat_files directory here.
# NOTE: '~' is not expanded by os.listdir/open, so we expand it explicitly;
# without this, a path starting with '~' raises FileNotFoundError.
models_dir = os.path.expanduser('~/ParlAI/data/controllable_dialogue/wordstat_files')

# Collect every '<model>.wordstats.json' file in the directory.
wordstat_files = [fname for fname in os.listdir(models_dir) if 'wordstats.json' in fname]

mf2data = {}  # master dict mapping model file name to its data dict
print('Loading %i files...' % len(wordstat_files), end='')
for idx, json_file in enumerate(sorted(wordstat_files)):
    # Model file name is everything before the '.wordstats.json' suffix.
    mf = json_file[:json_file.index('.wordstats.json')]
    print('%i, ' % idx, end='')
    with open(os.path.join(models_dir, json_file), "r") as f:
        data = json.load(f)
    mf2data[mf] = data
print('\nFinished loading files')
In [4]:
# This cell makes Table 6 from the paper
# Per-sentence statistics to report, one table column per attribute.
columns = [
'extrep_2gram',
'extrep_nonstopword',
'intrep_2gram',
'intrep_nonstopword',
'partnerrep_2gram',
'avg_nidf',
'lastuttsim',
'question',
]
# First table row: model name followed by the attribute names.
header_row = ['model name'] + columns
# One table row per model configuration, keyed by the wordstats file name
# (must match keys of mf2data loaded above).
rows = [
# gold data and baselines
'goldresponse',
'convai2_finetuned_baseline.valid.usemodelreply.beam1',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10',
# repetition control (WD)
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-0.5',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-1.25',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-1e+20',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
# question control (CT)
'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question0.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question1.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question4.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question7.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_questionb11e10.valid.usemodelreply.beam20.beamminnbest10.setcontrols:question10.beamreorder_best_extrep2gram_qn.WDfeatures:extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
# specificity control (CT)
'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf0.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf2.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf4.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf7.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
'control_avgnidf10b10e.valid.usemodelreply.beam20.beamminnbest10.setcontrols:avg_nidf9.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20',
# specificity control (WD)
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf-10.0',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf-4.0',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf4.0',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf6.0',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_nonstopword-1e+20_nidf8.0',
# response-related control (WD)
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim-10.0_partnerrep_2gram-1e+20',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim0.0_partnerrep_2gram-1e+20',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim5.0_partnerrep_2gram-1e+20',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim10.0_partnerrep_2gram-1e+20',
'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10.WDfeatures:extrep_2gram-3.5_extrep_nonstopword-1e+20_intrep_2gram-1e+20_intrep_nonstopword-1e+20_lastuttsim13.0_partnerrep_2gram-1e+20',
]
def mean(l):
    """Arithmetic mean of the numbers in `l` (raises ZeroDivisionError on empty input)."""
    total = sum(l)
    count = len(l)
    return total / count
def model2row(mf, data):
    """Given the data from a json file, make a row of data for the table.

    mf: model file name (first cell of the row).
    data: parsed wordstats dict; its 'sent_attrs' sub-dict maps attribute
        name -> list of per-sentence values.
    Returns a list of strings: [mf, formatted value per entry of `columns`].
    """
    sent_attrs = data['sent_attrs']
    row = [mf]
    for attr in columns:
        if attr not in sent_attrs:
            # Attribute missing for this model (e.g. gold data): blank cell.
            row.append('')
            continue
        attr_mean = mean(sent_attrs[attr])
        if attr in ['avg_nidf', 'lastuttsim']:
            # Raw-valued attributes: report to 4 decimal places.
            row.append("%.4f" % (attr_mean))
        else:
            # Fraction-valued attributes: report as a percentage.
            row.append("%.2f%%" % (attr_mean*100))
    return row
# Build table: header first, then one formatted row per model configuration.
table = [header_row]
table.extend(model2row(mf, mf2data[mf]) for mf in rows)
html = HTML(tabulate.tabulate(table, tablefmt='html', stralign='center'))
# tabulate emits centered cells; rewrite the inline CSS for left-alignment.
html.data = html.data.replace("text-align: center;", "text-align: left;")
display(html)
In [7]:
# Configuration for the utterance-frequency inspection below.
mf = 'convai2_finetuned_baseline.valid.usemodelreply.beam20.beamminnbest10' # beam search baseline
num_show = 100 # Show the top 100 most common utterances
def show_preds(mf, num_show=None):
    """Print uniqueness statistics and the most frequent predictions for model `mf`.

    mf: model file name, a key of the module-level mf2data dict.
    num_show: how many of the most common utterances to print
        (None prints all of them, most common first).
    """
    # This is the normalized version; use pure_pred_list for unnormalized.
    preds = mf2data[mf]['word_statistics']['pred_list']
    counter = Counter(preds)
    total = sum(counter.values())
    # Utterances produced exactly once across the whole prediction list.
    num_unique = len([p for p, count in counter.items() if count == 1])
    print("%% of utterances that are unique: %.2f%% (%i/%i)\n"
          % (num_unique*100/total, num_unique, total))
    print("COUNT UTTERANCE")
    for p, count in counter.most_common(num_show):
        print("%5i %s" % (count, p))
show_preds(mf, num_show)