Evaluating OCR Models


In [ ]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
import time
import math
from collections import Counter
import unidecode
from abc import ABC, abstractmethod

# Import Widgets
from ipywidgets import Button, Text, HBox, VBox
from IPython.display import display, clear_output

sys.path.append('../src')
from ocr import characters
from ocr.normalization import word_normalization, letter_normalization
# Helpers
from ocr.helpers import implt, resize, img_extend
from ocr.datahelpers import load_words_data, idx2char
from ocr.tfhelpers import Model
from ocr.viz import print_progress_bar

Global Variables


In [5]:
# Language code used to pick the model folders and the word dictionary, and
# to decide whether labels are transliterated to ASCII.
# ONLY 'en' is supported right now
LANG = 'en'

Load Trained Models


In [ ]:
# Folder holding the word-level classifiers of the selected language.
word_model_dir = '../models/word-clas/' + LANG

# Character classifier (later handed to CharCycler).
charClass_1 = Model('../models/char-clas/' + LANG + '/CharClassifier')

# Word classifiers (later handed to WordCycler).
# NOTE(review): the second argument presumably names the graph operation
# that is evaluated - confirm against ocr.tfhelpers.Model.
wordClass = Model(word_model_dir + '/WordClassifier2', 'prediction_infer')
wordClass2 = Model(word_model_dir + '/SeqRNN/Classifier', 'word_prediction') # None
wordClass3 = Model(word_model_dir + '/CTC/Classifier2', 'word_prediction')

Load Images


In [4]:
# Read the test set: word images together with their text labels.
images, labels = load_words_data('../data/sets/test.csv', is_csv=True)

# Normalize every word image in place (the input is converted from
# grayscale to RGB first), reporting progress along the way.
image_count = len(images)
for idx in range(image_count):
    print_progress_bar(idx, image_count)
    rgb_image = cv2.cvtColor(images[idx], cv2.COLOR_GRAY2RGB)
    images[idx] = word_normalization(rgb_image,
                                     60,
                                     border=False,
                                     tilt=True,
                                     hystNorm=True)

# English models only know unaccented characters, so transliterate labels.
if LANG == 'en':
    for idx, label in enumerate(labels):
        labels[idx] = unidecode.unidecode(label)

print()
print('Number of chars:', sum(len(label) for label in labels))


Loading words...
-> Number of words: 267
 |████████████████████████████████████████| 100.0% 

Number of chars: 1356

Testing


In [5]:
# Load Words
# Build the word -> frequency table used by the spelling corrector below.
WORDS = {}
# The path separator after `dictionaries` was missing, which produced
# '../data/dictionariesen_50k.txt'.
# NOTE(review): assumes the files are stored as
# ../data/dictionaries/<LANG>_50k.txt - confirm against the repository layout.
with open('../data/dictionaries/' + LANG + '_50k.txt') as f:
    for line in f:
        # Each line is expected to hold "<word> <count>"; split it only once.
        fields = line.split(" ")
        word, count = fields[0], int(fields[1])
        if LANG == 'en':
            # Transliterate to ASCII so the keys match the transliterated labels.
            word = unidecode.unidecode(word)
        # A later line overwrites an earlier one that maps to the same key.
        WORDS[word] = count
WORDS = Counter(WORDS)

def P(word, N=sum(WORDS.values())):
    """Relative frequency of `word` in WORDS.

    `N` defaults to the total of all counts, computed once when the
    function is defined.
    """
    occurrences = WORDS[word]
    return occurrences / N

def correction(word):
    """Return the most probable spelling correction for `word`.

    Words already present in WORDS are returned unchanged; otherwise the
    candidate with the highest probability `P` wins.
    """
    if word not in WORDS:
        return max(candidates(word), key=P)
    return word

def candidates(word):
    """Generate possible spelling corrections for `word`.

    Tries, in order, the word itself, everything one edit away and
    everything two edits away, and returns the first non-empty set of
    known words. Falls back to `[word]` when nothing is known.
    """
    # Callables keep the more expensive edit sets lazy: they are only
    # generated when the cheaper stage found no known word.
    stages = (
        lambda: [word],
        lambda: edits1(word),
        lambda: edits2(word),
    )
    for generate in stages:
        found = known(generate())
        if found:
            return found
    return [word]

def known(words):
    """Return the set of entries from `words` that are present in WORDS."""
    return {candidate for candidate in words if candidate in WORDS}

def edits1(word):
    """Return the set of all strings one edit away from `word`.

    An edit is a single deletion, a transposition of two adjacent
    characters, a replacement, or an insertion of one letter of the
    alphabet selected by LANG.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    if LANG == 'cz':
        alphabet = 'aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž'

    results = set()
    for cut in range(len(word) + 1):
        head, tail = word[:cut], word[cut:]
        if tail:
            # Deletion of the character at the cut
            results.add(head + tail[1:])
            # Replacement of the character at the cut
            results.update(head + letter + tail[1:] for letter in alphabet)
        if len(tail) > 1:
            # Transposition of the two characters after the cut
            results.add(head + tail[1] + tail[0] + tail[2:])
        # Insertion at the cut
        results.update(head + letter + tail for letter in alphabet)
    return results

def edits2(word):
    """Lazily yield all strings two edits away from `word`.

    The result is a generator and may contain duplicates.
    """
    first_round = edits1(word)
    return (twice for once in first_round for twice in edits1(once))

In [6]:
def cer(r, h):
    """
    Count the character errors between two strings.

    Computes the edit distance (minimum number of single-character
    insertions, deletions or substitutions) that turns `r` into `h`.
    Despite the name this is an absolute count, not a rate; callers
    normalise it by the number of reference characters.

    Args:
        r: Reference (ground-truth) string.
        h: Hypothesis (predicted) string.

    Returns:
        The edit distance as a plain Python int, so that sums accumulated
        by callers cannot overflow a fixed-width integer type.
    """
    r = list(r)
    h = list(h)

    # Only the previous row of the distance table is needed at any time.
    # Row 0 is the cost of building each prefix of `h` from nothing.
    previous = list(range(len(h) + 1))
    for i in range(1, len(r) + 1):
        # Column 0 is the cost of deleting the first `i` reference characters.
        current = [i]
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                current.append(previous[j - 1])
            else:
                substitution = previous[j - 1] + 1
                insertion = current[j - 1] + 1
                deletion = previous[j] + 1
                current.append(min(substitution, insertion, deletion))
        previous = current
    return previous[len(h)]

Cycler


In [7]:
class Cycler(ABC):
    """ Abstract cycler class

    Stores the evaluation set together with the model and immediately
    evaluates it: constructing an instance prints the statistics.
    Subclasses implement `recogniseWord`.
    """
    def __init__(self,
                 images,
                 labels,
                 charClass,
                 stats="No Stats Provided",
                 slider=(60, 15),
                 ctc=False,
                 seq2seq=False,
                 charRNN=False):
        """
        Args:
            images: Word images to recognise.
            labels: Ground-truth strings, aligned with `images`.
            charClass: Model object used by `recogniseWord`.
            stats: Title printed at the top of the report.
            slider: (height, width) of the sliding window.
            ctc, seq2seq, charRNN: Flags selecting how the subclass feeds
                the model.
        """
        self.images = images
        self.labels = labels
        self.charClass = charClass
        self.slider = slider
        self.totalChars = sum([len(l) for l in labels])
        self.ctc = ctc
        self.seq2seq = seq2seq
        self.charRNN = charRNN
        self.stats = stats

        self.evaluate()

    @abstractmethod
    def recogniseWord(self, img):
        """ Return the text recognised in the word image `img` """
        pass

    def countCorrect(self, pred, label, lower=False):
        """ Count the positions at which `pred` and `label` agree

        Only the overlapping prefix is compared. With `lower=True` the
        label is lower-cased first (the prediction is expected to be
        lower-case already).
        """
        target = label.lower() if lower else label
        overlap = min(len(pred), len(label))
        return sum(1 for i in range(overlap) if pred[i] == target[i])

    @staticmethod
    def _ratio(part, total):
        """ Return part / total, or 0.0 when there is nothing to divide by """
        return part / total if total else 0.0

    def evaluate(self):
        """ Evaluate accuracy of the word classification

        Prints one sample prediction, the recognition time for the whole
        set and the letter/word accuracies with and without dictionary
        correction.
        """
        print()
        print("STATS:", self.stats)

        total = len(self.images)
        if total == 0:
            print("No images to evaluate.")
            return

        # Show one sample (the second item when there is one).
        sample = min(1, total - 1)
        print(self.labels[sample], ':', self.recogniseWord(self.images[sample]))

        # Time the recognition and keep the predictions, so the metrics
        # below do not have to run the model over the whole set again.
        start_time = time.time()
        predictions = [self.recogniseWord(self.images[i]) for i in range(total)]
        print("--- %s seconds ---" % round(time.time() - start_time, 2))

        ccer = 0
        correctLetters = 0
        correctWords = 0
        correctWordsCorrection = 0
        correctLettersCorrection = 0
        for i in range(total):
            word = predictions[i]
            label = self.labels[i]
            # Correction works only for lower letters
            corrected = correction(word.lower())

            correctLetters += self.countCorrect(word, label)
            correctLettersCorrection += self.countCorrect(corrected,
                                                          label,
                                                          lower=True)
            # int() keeps the running sum a Python int even when cer()
            # returns a fixed-width NumPy scalar.
            ccer += int(cer(word, label))
            # Words accuracy
            if word == label:
                correctWords += 1
            if corrected == label.lower():
                correctWordsCorrection += 1

        letters = self.totalChars
        print("Correct/Total: %s / %s" % (correctLetters, letters))
        print("CERacc: %s %%" % round(100 - self._ratio(ccer, letters) * 100, 4))
        print("Letter Accuracy: %s %%" % round(self._ratio(correctLetters, letters) * 100, 4))
        print("Letter Accuracy with Correction: %s %%" % round(self._ratio(correctLettersCorrection, letters) * 100, 4))
        print("Word Accuracy: %s %%" % round(self._ratio(correctWords, total) * 100, 4))
        print("Word Accuracy with Correction: %s %%" % round(self._ratio(correctWordsCorrection, total) * 100, 4))

In [8]:
class WordCycler(Cycler):
    """ Recognise whole words by feeding sliding-window slices to a sequence model

    The image is cut into windows of `self.slider` = (height, width)
    which are flattened and passed to the model as one sequence.
    """
    def recogniseWord(self, img):
        """ Return the text predicted for the word image `img` """
        if self.ctc:
            return self._recogniseCTC(img)
        return self._recogniseSequence(img)

    def _sliceWindows(self, img, count, stride):
        """ Build the (count, 1, height * width) model input from `img` """
        height, width = self.slider
        windows = np.zeros((1, count, height * width), dtype=np.float32)
        windows[0][:] = [img[:, loc * stride: loc * stride + width].flatten()
                         for loc in range(count)]
        # Swap the first two axes so the window index comes first.
        return windows.swapaxes(0, 1)

    def _recogniseCTC(self, img):
        """ Recognise `img` with the CTC model using overlapping windows """
        step = 10    # 10 for (60, 60) slider
        width = self.slider[1]

        # Pad half a window on both sides of the word.
        img = cv2.copyMakeBorder(
            img,
            0, 0, width // 2, width // 2,
            cv2.BORDER_CONSTANT,
            value=[0, 0, 0])
        # Round the image width up to a multiple of `step`, and make it at
        # least one window plus one step wide.
        img = img_extend(
            img,
            (img.shape[0], max(-(-img.shape[1] // step) * step, width + step)))
        length = (img.shape[1] - width) // step
        input_seq = self._sliceWindows(img, length, step)

        pred = self.charClass.eval_feed({'inputs:0': input_seq,
                                         'inputs_length:0': [length],
                                         'keep_prob:0': 1})[0]

        # NOTE(review): the previous loop tried to stop on
        # `word == 0 and i != 0`, comparing the accumulated *string* with an
        # int, which is never true - every predicted index has always been
        # decoded. That behaviour is kept so reported metrics stay
        # comparable; confirm whether an early stop was intended.
        return ''.join(idx2char(i) for i in pred)

    def _recogniseSequence(self, img):
        """ Recognise `img` with a seq2seq model using adjacent windows """
        width = self.slider[1]
        length = img.shape[1] // width
        input_seq = self._sliceWindows(img, length, width)

        if self.seq2seq:
            targets = np.zeros((1, 1), dtype=np.int32)
            pred = self.charClass.eval_feed({'encoder_inputs:0': input_seq,
                                             'encoder_inputs_length:0': [length],
                                             'decoder_targets:0': targets,
                                             'keep_prob:0': 1})[0]
        else:
            targets = np.zeros((1, 1, 4096), dtype=np.int32)
            pred = self.charClass.eval_feed({'encoder_inputs:0': input_seq,
                                             'encoder_inputs_length:0': [length],
                                             'letter_targets:0': targets,
                                             'is_training:0': False,
                                             'keep_prob:0': 1})[0]

        # NOTE(review): the previous loop tested `word == 1` (string vs int,
        # never true), so decoding never stopped early. Presumably the
        # intent was to stop at an end-of-sequence index - confirm before
        # changing, as that would alter the reported metrics.
        return ''.join(idx2char(i, True) for i in pred)

In [9]:
class CharCycler(Cycler):
    """ Recognise words by segmenting them and classifying each character """
    def recogniseWord(self, img):
        """ Return the text predicted for the word image `img` """
        # Pad 30 px on the left and right before segmenting.
        padded = cv2.copyMakeBorder(img,
                                    0, 0, 30, 30,
                                    cv2.BORDER_CONSTANT,
                                    value=[0, 0, 0])
        gaps = characters.segment(padded, RNN=True)

        # Every pair of neighbouring gaps delimits one character candidate.
        letters = []
        for left, right in zip(gaps[:-1], gaps[1:]):
            # TODO None type error after treshold
            letter, dim = letter_normalization(padded[:, left:right],
                                               is_thresh=True,
                                               dim=True)
            # TODO Test different values
            if dim[0] > 4 and dim[1] > 4:
                letters.append(letter.flatten())

        letters = np.array(letters)
        if len(letters) == 0:
            return ''

        if self.charRNN:
            pred = self.charClass.eval_feed({'inputs:0': [letters],
                                             'length:0': [len(letters)],
                                             'keep_prob:0': 1})[0]
        else:
            pred = self.charClass.run(letters)

        return ''.join(idx2char(c) for c in pred)

In [10]:
# Evaluate each recogniser on the test set; constructing a cycler runs the
# evaluation and prints its statistics.

WordCycler(images=images,
           labels=labels,
           charClass=wordClass,
           stats='Seq2Seq',
           slider=(60, 2),
           seq2seq=True)

WordCycler(images=images,
           labels=labels,
           charClass=wordClass2,
           stats='Seq2SeqX',
           slider=(60, 2))

WordCycler(images=images,
           labels=labels,
           charClass=wordClass3,
           stats='CTC',
           slider=(60, 60),
           ctc=True)

CharCycler(images=images,
           labels=labels,
           charClass=charClass_1,
           stats='Bi-RNN and CNN',
           charRNN=False)

# Cycler(images,
#        labels,
#        charClass_2,
#        charRNN=True)

# Cycler(images,
#        labels,
#        charClass_3,
#        charRNN=True)


STATS: Seq2Seq
spreads : zpaside
--- 68.82 seconds ---
Correct/Total: 641 / 1356
CERacc: 54.351 %
Letter Accuracy: 47.2714 %
Letter Accuracy with Correction: 46.2389 %
Word Accuracy: 23.5955 %
Word Accuracy with Correction: 31.8352 %

STATS: Seq2SeqX
spreads : spreads
--- 547.08 seconds ---
Correct/Total: 979 / 1356
CERacc: 76.3274 %
Letter Accuracy: 72.1976 %
Letter Accuracy with Correction: 76.2537 %
Word Accuracy: 40.4494 %
Word Accuracy with Correction: 64.0449 %

STATS: CTC
spreads : spreads
--- 761.05 seconds ---
Correct/Total: 985 / 1356
CERacc: 82.9646 %
Letter Accuracy: 72.6401 %
Letter Accuracy with Correction: 77.8761 %
Word Accuracy: 49.8127 %
Word Accuracy with Correction: 68.9139 %

STATS: Bi-RNN and CNN
spreads : spreads
--- 581.9 seconds ---
Correct/Total: 1057 / 1356
CERacc: 84.4395 %
Letter Accuracy: 77.9499 %
Letter Accuracy with Correction: 82.2271 %
Word Accuracy: 64.0449 %
Word Accuracy with Correction: 74.5318 %
Out[10]:
<__main__.CharCycler at 0x7f5d1ddf8d30>