In [ ]:
import numpy as np
import pycrfsuite

In [ ]:
# Load the pre-computed features. The file is a pickle (or an .npz holding
# object arrays); since NumPy 1.16.3 np.load refuses to unpickle unless
# allow_pickle=True is passed explicitly.
# SECURITY: unpickling can execute arbitrary code — only load files you produced.
data = np.load('../features.pkl', allow_pickle=True)
print(len(data['X']))
print(len(data['y']))
# Later cells pair X and y with zip(), which would silently drop unmatched items.
assert len(data['X']) == len(data['y']), 'X and y must have the same number of sequences'

In [ ]:
# Reproducible 50/50 split of the sequences into a training and a test half.
# Named `rng` rather than `random` so the stdlib `random` module isn't shadowed;
# the legacy RandomState(0) is kept so the split is identical to before.
rng = np.random.RandomState(0)

# Fetch the columns once instead of once per element (cheap for a dict, but an
# NpzFile re-reads the array on every `data[...]` access).
X_all, y_all = data['X'], data['y']

indices = rng.permutation(len(X_all))
split_at = len(indices) // 2
train_indices = indices[:split_at]
test_indices = indices[split_at:]

X_train = [X_all[i] for i in train_indices]
y_train = [y_all[i] for i in train_indices]

X_test = [X_all[i] for i in test_indices]
y_test = [y_all[i] for i in test_indices]

In [ ]:
# Peek at the 'word' and 'known_journal' feature values of one training sequence.
# Features are 'key=value' strings; partition() splits on the FIRST '=' only,
# so values that themselves contain '=' are shown whole (split('=')[1] truncated them).
[
    [value
     for key, _, value in (feature.partition('=') for feature in token_features)
     if key in ('word', 'known_journal')]
    for token_features in X_train[2]
]

In [ ]:
# Fit a scratch CRF on the training half only, so the held-out half can be
# scored afterwards. The model is written to 'scratch.crfsuite'.
scratch_params = {
    'c1': 0.1,              # L1 regularisation coefficient
    'c2': 1,                # L2 regularisation coefficient
    'max_iterations': 100,  # upper bound on optimiser iterations
}

trainer = pycrfsuite.Trainer()
for features, labels in zip(X_train, y_train):
    trainer.append(features, labels)
trainer.set_params(scratch_params)
trainer.train('scratch.crfsuite')

In [ ]:
# Load the scratch model produced by the training cell for evaluation.
scratch_model_path = 'scratch.crfsuite'
tagger = pycrfsuite.Tagger()
tagger.open(scratch_model_path)

In [ ]:
def score(X, y, tagger):
    """Print token-level accuracy of ``tagger`` on a labelled dataset.

    Prints three values on one line: the number of tokens, the number tagged
    correctly, and their ratio (``nan`` when ``X`` contains no tokens).

    Parameters
    ----------
    X : sequence of per-example feature sequences, each passed to ``tagger.tag``.
    y : sequence of gold label sequences, aligned with ``X``.
    tagger : object with a ``tag(features) -> labels`` method
        (e.g. a ``pycrfsuite.Tagger`` with a model opened).
    """
    total_tokens = 0
    total_correct = 0
    for xx, yy in zip(X, y):
        predicted = tagger.tag(xx)
        total_tokens += len(xx)
        total_correct += sum(pred == true for pred, true in zip(predicted, yy))

    # Avoid ZeroDivisionError on an empty dataset; accuracy is undefined there.
    accuracy = total_correct / total_tokens if total_tokens else float('nan')
    print(total_tokens, total_correct, accuracy)
    
# Accuracy on the training half first, then on the held-out test half.
for split_X, split_y in ((X_train, y_train), (X_test, y_test)):
    score(split_X, split_y, tagger)

In [ ]:
from pprint import pprint

def errors(X, y, tagger, max_errors=6):
    """Pretty-print sequences that ``tagger`` labels incorrectly.

    For each mis-tagged sequence, prints a list of
    ``(true_label, predicted_label, is_correct, word)`` tuples, one per token,
    followed by a blank line.

    Parameters
    ----------
    X : sequence of per-token feature lists; each feature is a ``'key=value'`` string.
    y : sequence of gold label sequences, aligned with ``X``.
    tagger : object with a ``tag(features) -> labels`` method.
    max_errors : stop after this many mis-tagged sequences have been shown.
        The default of 6 matches the previous hard-coded ``> 5`` cut-off.
    """
    n_errored = 0
    for xx, yy in zip(X, y):
        if n_errored >= max_errors:
            break
        predicted = tagger.tag(xx)
        correct = [a == b for a, b in zip(predicted, yy)]
        if all(correct):
            continue
        # partition() splits on the first '=' only, so word values containing
        # '=' are kept whole; a token without a 'word' feature shows None
        # instead of raising StopIteration.
        phrase = [
            next((f.partition('=')[2] for f in token
                  if f.partition('=')[0] == 'word'), None)
            for token in xx
        ]
        pprint(list(zip(yy, predicted, correct, phrase)))
        n_errored += 1
        print()

# Inspect a handful of mis-tagged sequences from the training half.
errors(X=X_train, y=y_train, tagger=tagger)

In [ ]:
# Gold label vs predicted label for each token of the first test sequence.
gold_first = y_test[0]
predicted_first = tagger.tag(X_test[0])
print(list(zip(gold_first, predicted_first)))

In [ ]:
# print('|'.join([e[0].split('=')[1] for e in X_test[4]]))
# tagger.tag(X_test[4])

In [ ]:
# Parsed dump of the model currently opened in `tagger`.
# NOTE(review): presumably exposes `state_features` as a mapping of
# (attribute, label) -> weight — confirm against the installed pycrfsuite.
info = tagger.info()

In [ ]:
from collections import Counter
# Ten state features with the largest values (assumes state_features maps
# feature -> weight, so these are the most strongly positive weights).
weight_by_feature = Counter(info.state_features)
weight_by_feature.most_common(10)

In [ ]:
# Final model: same regularisation as the scratch run, but fitted on ALL
# sequences (training and test halves together) with a higher iteration cap.
# The test-half scores computed earlier came from the scratch model, not this one.
final_params = {
    'c1': 0.1,              # L1 regularisation coefficient
    'c2': 1,                # L2 regularisation coefficient
    'max_iterations': 500,  # upper bound on optimiser iterations
}

trainer = pycrfsuite.Trainer()
for features, labels in zip(data['X'], data['y']):
    trainer.append(features, labels)
trainer.set_params(final_params)
trainer.train('model.crfsuite')

In [ ]: