This IPython notebook explores how to load the IAM handwriting dataset in preparation for training the model on real handwriting data.
First-time use: download the dataset from the IAM Handwriting Database. Two folders in that dataset matter: 'ascii' and 'lineStrokes'. Put them in a './data' directory relative to this notebook. The first time the data loader below is constructed, it parses all of the XML data in these folders and saves a processed version to a pickle file. This takes about 10 minutes, but you only need to do it once.
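As a quick sanity check (not part of the original notebook), a snippet like the one below can confirm the expected folders are in place before the loader tries to preprocess them; the folder names simply mirror the layout described above.

import os

# Hypothetical sanity check: confirm the IAM data folders described above exist
# before constructing the data loader.
data_dir = "./data"
for folder in ("lineStrokes", "ascii"):
    path = os.path.join(data_dir, folder)
    print("{:>20s} : {}".format(path, "found" if os.path.isdir(path) else "MISSING"))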
I took some of the code for this class from hardmaru's write-rnn-tensorflow project and modified it to return the ascii labels in addition to the pen stroke data.
In [1]:
import os
import pickle
import numpy as np
import xml.etree.ElementTree as ET
import random
from utils import *
import matplotlib.pyplot as plt
%matplotlib inline
This class is almost identical to the one I used to train the handwriting model.
In [2]:
class DataLoader():
    def __init__(self, batch_size=50, tsteps=300, scale_factor=10, U_items=10, limit=500, alphabet="default"):
        self.data_dir = "./data"
        self.alphabet = alphabet
        self.batch_size = batch_size
        self.tsteps = tsteps
        self.scale_factor = scale_factor # divide data by this factor
        self.limit = limit # removes large noisy gaps in the data
        self.U_items = U_items

        data_file = os.path.join(self.data_dir, "strokes_training_data.cpkl")
        stroke_dir = self.data_dir + "/lineStrokes"
        ascii_dir = self.data_dir + "/ascii"

        if not os.path.exists(data_file):
            print("creating training data cpkl file from raw source")
            self.preprocess(stroke_dir, ascii_dir, data_file)

        self.load_preprocessed(data_file)
        self.reset_batch_pointer()
    def preprocess(self, stroke_dir, ascii_dir, data_file):
        # create data file from raw xml files from iam handwriting source.
        print("Parsing dataset...")

        # build the list of xml files
        filelist = []
        # Set the directory you want to start from
        rootDir = stroke_dir
        for dirName, subdirList, fileList in os.walk(rootDir):
            # print('Found directory: %s' % dirName)
            for fname in fileList:
                # print('\t%s' % fname)
                filelist.append(dirName + "/" + fname)

        # function to read the pen strokes from an individual xml file
        def getStrokes(filename):
            tree = ET.parse(filename)
            root = tree.getroot()
            result = []

            x_offset = 1e20
            y_offset = 1e20
            y_height = 0
            for i in range(1, 4):
                x_offset = min(x_offset, float(root[0][i].attrib['x']))
                y_offset = min(y_offset, float(root[0][i].attrib['y']))
                y_height = max(y_height, float(root[0][i].attrib['y']))
            y_height -= y_offset
            x_offset -= 100
            y_offset -= 100

            for stroke in root[1].findall('Stroke'):
                points = []
                for point in stroke.findall('Point'):
                    points.append([float(point.attrib['x']) - x_offset, float(point.attrib['y']) - y_offset])
                result.append(points)
            return result

        # function to read the ascii transcription that matches an xml stroke file
        def getAscii(filename, line_number):
            with open(filename, "r") as f:
                s = f.read()
            s = s[s.find("CSR"):]
            if len(s.split("\n")) > line_number + 2:
                s = s.split("\n")[line_number + 2]
                return s
            else:
                return ""

        # converts a list of arrays into a 2d numpy int16 array
        def convert_stroke_to_array(stroke):
            n_point = 0
            for i in range(len(stroke)):
                n_point += len(stroke[i])
            stroke_data = np.zeros((n_point, 3), dtype=np.int16)

            prev_x = 0
            prev_y = 0
            counter = 0
            for j in range(len(stroke)):
                for k in range(len(stroke[j])):
                    stroke_data[counter, 0] = int(stroke[j][k][0]) - prev_x
                    stroke_data[counter, 1] = int(stroke[j][k][1]) - prev_y
                    prev_x = int(stroke[j][k][0])
                    prev_y = int(stroke[j][k][1])
                    stroke_data[counter, 2] = 0
                    if k == (len(stroke[j]) - 1): # end of stroke
                        stroke_data[counter, 2] = 1
                    counter += 1
            return stroke_data

        # build stroke database of every xml file inside iam database
        strokes = []
        asciis = []
        for i in range(len(filelist)):
            if filelist[i][-3:] == 'xml':
                stroke_file = filelist[i]
                # print('processing ' + stroke_file)
                stroke = convert_stroke_to_array(getStrokes(stroke_file))
                ascii_file = stroke_file.replace("lineStrokes", "ascii")[:-7] + ".txt"
                line_number = stroke_file[-6:-4]
                line_number = int(line_number) - 1
                ascii = getAscii(ascii_file, line_number)
                if len(ascii) > 10:
                    strokes.append(stroke)
                    asciis.append(ascii)
                else:
                    print("======>>>> Line length was too short. Line was: " + str(ascii))

        assert len(strokes) == len(asciis), "There should be a 1:1 correspondence between stroke data and ascii labels."
        f = open(data_file, "wb")
        pickle.dump([strokes, asciis], f, protocol=2)
        f.close()
        print("Finished parsing dataset. Saved {} lines".format(len(strokes)))
    def load_preprocessed(self, data_file):
        f = open(data_file, "rb")
        [self.raw_stroke_data, self.raw_ascii_data] = pickle.load(f, encoding='latin1')
        f.close()

        # goes thru the list, and only keeps the text entries that have more than tsteps points
        self.stroke_data = []
        self.ascii_data = []
        counter = 0
        for i in range(len(self.raw_stroke_data)):
            data = self.raw_stroke_data[i]
            if len(data) > (self.tsteps + 2):
                # removes large gaps from the data
                data = np.minimum(data, self.limit)
                data = np.maximum(data, -self.limit)
                data = np.array(data, dtype=np.float32)
                data[:, 0:2] /= self.scale_factor
                self.stroke_data.append(data)
                self.ascii_data.append(self.raw_ascii_data[i])

        # minus 1, since we want the ydata to be a shifted version of x data
        self.num_batches = int(len(self.stroke_data) / self.batch_size)
        print("Loaded dataset:")
        print(" -> {} individual data points".format(len(self.stroke_data)))
        print(" -> {} batches".format(self.num_batches))
    def next_batch(self):
        # returns a randomised, tsteps sized portion of the training data
        x_batch = []
        y_batch = []
        ascii_list = []
        for i in range(self.batch_size):
            data = self.stroke_data[self.idx_perm[self.pointer]]
            x_batch.append(np.copy(data[:self.tsteps]))
            y_batch.append(np.copy(data[1:self.tsteps + 1]))
            ascii_list.append(self.ascii_data[self.idx_perm[self.pointer]])
            self.tick_batch_pointer()
        one_hots = [self.one_hot(s) for s in ascii_list]
        return x_batch, y_batch, ascii_list, one_hots
    def one_hot(self, s):
        # index position 0 means "unknown"
        if self.alphabet == "default":
            alphabet = " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
        seq = [alphabet.find(char) + 1 for char in s]
        if len(seq) >= self.U_items:
            seq = seq[:self.U_items]
        else:
            seq = seq + [0] * (self.U_items - len(seq))
        one_hot = np.zeros((self.U_items, len(alphabet) + 1))
        one_hot[np.arange(self.U_items), seq] = 1
        return one_hot
    def tick_batch_pointer(self):
        self.pointer += 1
        if self.pointer >= len(self.stroke_data):
            self.reset_batch_pointer()

    def reset_batch_pointer(self):
        self.idx_perm = np.random.permutation(len(self.stroke_data))
        self.pointer = 0
        print("pointer reset")
In [3]:
batch_size = 5
tsteps = 700
data_scale = 50
U_items = int(tsteps/20)
data_loader = DataLoader(batch_size=batch_size, tsteps=tsteps, \
                         scale_factor=data_scale, U_items=U_items, alphabet="default")
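For reference, here is a hedged sketch of what one call to next_batch returns with these settings; the shapes follow from the class above, the variable names are just for illustration, and running it would also advance the batch pointer.

x, y, ascii_lines, one_hots = data_loader.next_batch()
print(len(x), x[0].shape)    # batch_size arrays of shape (tsteps, 3): [dx, dy, eos]
print(len(y), y[0].shape)    # targets: same shapes, shifted forward by one point
print(ascii_lines[0])        # transcription string for the first sampled line
print(one_hots[0].shape)     # (U_items, len(alphabet) + 1) one-hot label matrix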
In [7]:
def line_plot(strokes, title):
    plt.figure(figsize=(20,2))
    eos_preds = np.where(strokes[:,-1] == 1)
    eos_preds = [0] + list(eos_preds[0]) + [-1] # add start and end indices
    for i in range(len(eos_preds)-1):
        start = eos_preds[i]+1
        stop = eos_preds[i+1]
        plt.plot(strokes[start:stop,0], strokes[start:stop,1], 'b-', linewidth=2.0)
    plt.title(title)
    plt.gca().invert_yaxis()
    plt.show()
In [8]:
x, y, s, c = data_loader.next_batch()
print(data_loader.pointer)
for i in range(batch_size):
    r = x[i]
    strokes = r.copy()
    strokes[:,:-1] = np.cumsum(r[:,:-1], axis=0)
    line_plot(strokes, s[i][:U_items])
These are the five handwriting styles from the blog post. In the future, I will use them to "prime" my model so that it synthesizes handwriting in a particular style.
In [6]:
with open(os.path.join('data', 'styles.p'), 'rb') as f:
    style_strokes, style_strings = pickle.load(f, encoding='latin1')

for i in range(len(style_strokes)):
    strokes = style_strokes[i]
    strokes[:,:-1] = np.cumsum(strokes[:,:-1], axis=0)
    line_plot(strokes, "Style #{}: {}".format(i+1, style_strings[i]))