Я решил воспользоваться предоставленным корпусом:
In [1]:
!ls -l corpus
Подключим необходимые библиотеки. Стоит выделить nltk - она используется в основном для демонстрации чего можно ожидать от ngram модели.
NTLK потребуется версии 2.0.5 (чтобы сразу получить NgramModel, в 3.0+ ее придется доставлять отдельно).
In [2]:
import os
import numpy as np
import sys
import nltk
import unicodedata
from collections import Counter, namedtuple
import pickle
import numpy as np
from copy import deepcopy
%matplotlib inline
In [3]:
def find_text_files(basedir):
filepaths = []
for root, dirs, files in os.walk(basedir):
for file in files:
if file.endswith(".txt"):
filepaths.append(os.path.join(root, file))
return filepaths
Воспользуемся семинарским кодом для удаления пунктуации. Но "." нам еще пригодится. И еще несколько модификаторов исходного текста, перед его анализом.
В итоге все соберем в последовательность фильтров.
In [31]:
PUNCTUATION_TRANSLATE_TABLE = {i: None \
for i in range(sys.maxunicode) \
if unicodedata.category(unichr(i)).startswith('P') and unichr(i) not in ['.', '\'']}
def fix_case(document):
words = document.split()
capitalize_counter = Counter()
lower_counter = Counter()
for idx, word in enumerate(words):
lower_word = word.lower()
if word == word.capitalize():
if idx > 0 and words[idx - 1] not in ['.', '?', '!']:
capitalize_counter[lower_word] += 1
else:
lower_counter[lower_word] += 1
for idx, word in enumerate(words):
lower_word = word.lower()
if lower_counter[lower_word] == 0 \
or float(capitalize_counter[lower_word]) / lower_counter[lower_word] > 0.75:
words[idx] = lower_word.capitalize()
else:
words[idx] = lower_word
return ' '.join(words)
def remove_punkt(document):
return document.translate(PUNCTUATION_TRANSLATE_TABLE).replace('.', ' . ')
def preprocessing(document):
document = fix_case(document)
document = remove_punkt(document)
# a long filter chain could be placed here
return document
Сделаем постпроцессинг текста сразу.
In [5]:
def title_sentence(sentence):
words = sentence.split()
words[0] = words[0][0].upper() + words[0][1:]
return ' '.join(words)
def uppercase_start(document):
sentences = map(lambda sentence: sentence.strip(), document.split('.'))
sentences = [sentence for sentence in sentences if sentence != '']
return '. '.join(map(title_sentence, sentences)) + '.'
def glue_single_quote(document):
return document.replace(' \'', '\'')
def postprocessing(document):
document = uppercase_start(document)
document = glue_single_quote(document)
return document
Запустим nltk генератор на основе марковских цепей на нашем корпусе и посмотрим, что от него можно ожидать. При обучении на триграммах. При этом замерим время работы отдельных частей процесса.
In [35]:
import warnings
warnings.filterwarnings('ignore')
ngram_length = 3
text_length = 200
def read_data(path):
corpus = ''
for docpath in find_text_files(path):
with open(docpath) as doc:
doc = doc.read().decode('utf-8')
corpus += preprocessing(doc)
return corpus
def learn(corpus, ngram_length):
tokens = nltk.word_tokenize(corpus)
content_model = nltk.model.ngram.NgramModel(ngram_length, tokens)
return content_model
def generate(content_model):
# text generation without seed to get the seed
starting_words = content_model.generate(100)[-(ngram_length - 1):]
# generate text starting with random words
content = content_model.generate(text_length, starting_words)
return content
corpus = read_data('corpus')
content_model = learn(corpus, ngram_length)
content = generate(content_model)
print postprocessing(' '.join(content).encode('utf-8'))
warnings.filterwarnings('always')
Текст на удивление получился довольно связным. Использование ngram модели сделала его похожим на тексты Джима Моррисона - два слова рядом стоят красиво, но общий смысл где-то за гранью человеческого понимания. Что, в принципе, и ожидалось.
Есть некоторые огрехи постпроцессинга.
Еще один важный вывод: основное время тратится на обучения модели. И сейчас есть достаточно точная оценка к чему стоит стремиться.
Далее идет хардкорная реализация примерно того же. С примерно теми же результатами. С Unicode работает. Как и весь код. Проверял куске Войны и Мира (в конце).
При этом, делаем все быстро, без выделения сущностей. Codestyle старался придерживаться, но ipython notebook не позволяет нормально контролировать длину строки. Поэтому, местами они превышают 80 символов.
Если делать правильно - можно просто копировать NGramModel.
Что было вырезано из финальной версии:
In [6]:
from itertools import izip
def build_ngrams(text, n):
input_list = text.split()
return izip(*[input_list[i:] for i in range(n)])
list(build_ngrams('hello sad cruel cold world', 2))
Out[6]:
По заданию, нужно хранить каскад ngram. По 1 слову, затем по 2 слова. То есть, нужно уметь из ngram получать (n-1)-gram. Еще нам требуется знать распределение продолжений ngram. Выделим этот функционал (получение производных ngram) в класс.
In [38]:
class NGramDistribution(object):
def __init__(self, ngrams):
self.distribution = {}
for long_gram in ngrams:
short_gram = long_gram[0:-1]
last_word = long_gram[-1]
if short_gram not in self.distribution:
self.distribution[short_gram] = {'total': 0, 'counter': Counter()}
self.distribution[short_gram]['total'] += ngrams[long_gram]
self.distribution[short_gram]['counter'].update({last_word: ngrams[long_gram]})
@property
def counter(self):
counter_pairs = [(key, self.distribution[key]['total']) \
for key in self.distribution]
return Counter(dict(counter_pairs))
Не будем обращать внимание на непопулярные ngram'ы.
In [8]:
from itertools import dropwhile
def remove_rare_ngrams(counter):
lower_bound = 1
for key, count in dropwhile(lambda key_count: \
key_count[1] > lower_bound, counter.most_common()):
del counter[key]
return counter
def remove_splited_sentences(counter):
for key in counter.keys():
if key[-1] == '.':
del counter[key]
return counter
def simple_stats_filter(counter):
counter = remove_rare_ngrams(counter)
counter = remove_splited_sentences(counter)
# some others filters
# ...
return counter
In [9]:
from datetime import datetime
class Index(object):
def __init__(self, depth):
self.depth = depth
self.ngram = Counter()
self.normalize_document = lambda doc: doc
self.stats_filter = lambda ngram: ngram
def __reset(self):
self.__dist = None
def add_document(self, document):
normalized_document = self.normalize_document(document)
doc_counter = build_ngrams(normalized_document, self.depth + 1)
self.ngram.update(doc_counter)
self.__reset()
@property
def dist(self):
if self.__dist is not None:
return self.__dist
self.__dist = {}
current_counter = self.stats_filter(self.ngram)
for depth in reversed(range(1, self.depth + 1)):
ngram_dist = NGramDistribution(current_counter)
self.__dist[depth] = ngram_dist.distribution
current_counter = ngram_dist.counter
return self.__dist
np.random.choice ломается, когда сумма по вектору вероятностей отлична от 1. Если мы будем выбирать лидирующую биграмму для старта преложения, то вариантов получится огромное количество (проверял с пустым stats_filter - т.е. на всех биграммах). И из-за неточности floating point арифметики сумма по всем вероятностям незначительно, но отличается от 1, что ведет к поломке функции.
In [39]:
import bisect
class MarkovChain(object):
def __init__(self, dist):
self.dist = dist
cumsum = np.cumsum([ngram['total'] for ngram in dist.values()])
self.__segments = dict(zip(cumsum, dist.keys()))
self.__sorted_keys = sorted(self.__segments.keys())
self.state = self.__start_sentence()
def __start_sentence(self):
rnd = np.random.randint(0, self.__sorted_keys[-1])
position = bisect.bisect_right(self.__sorted_keys, rnd)
return self.__segments[self.__sorted_keys[position]]
@property
def word(self):
if self.state[-1] == '.':
return ' '.join(self.state)
self.state = self.__start_sentence()
drop_word = self.state[0]
next_word = '.'
try:
next_word = np.random.choice(\
self.dist[self.state]['counter'].keys(),
p = map(lambda cnt: \
float(cnt) / self.dist[self.state]['total'],
self.dist[self.state]['counter'].values()))
except KeyError:
pass
self.state = (self.state[1], next_word)
return drop_word
def generate(self, length):
for num in xrange(length):
yield self.word
In [36]:
index = Index(2)
index.normalize_document = preprocessing
index.stats_filter = simple_stats_filter
for docpath in find_text_files('corpus'):
with open(docpath) as doc:
index.add_document(doc.read().decode('utf-8'))
dist = index.dist[2]
Сохраним распределения в файл, как того требует пункт 2.
In [12]:
print len(index.dist[2])
with open('distribution.dat', 'w') as fh:
pickle.dump(index.dist, fh)
In [17]:
!ls -lh distribution.dat
In [14]:
restored_dist = None
with open('distribution.dat') as fh:
restored_dist = pickle.load(fh)
len(restored_dist[2])
Out[14]:
In [33]:
generator = MarkovChain(dist)
content = generator.generate(11000)
print postprocessing(' '.join(content))
PoC. Цель показать работу с unicode, а не какую-то качественную генерацию.
In [37]:
index = Index(2)
index.normalize_document = preprocessing
index.stats_filter = simple_stats_filter
for docpath in find_text_files('russian'):
with open(docpath) as doc:
index.add_document(doc.read().decode('utf-8'))
generator = MarkovChain(index.dist[2])
content = generator.generate(250)
print postprocessing(' '.join(content))