Extract boxes with PyTorch


In [1]:
# !pip install -U toposort
# !pip install pydensecrf
# !pip install pyro-ppl
# !pip install fastcluster

In [2]:
%load_ext autoreload
%autoreload 2

import multiprocessing

if __name__ == '__main__':
    multiprocessing.set_start_method('spawn', force=True)

import os, glob, json, tqdm, pandas, pickle, rtree, gc, toposort, joblib, tabulate, lxml.etree, bisect
import matplotlib.pyplot as plt
%pylab inline

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

# import pydensecrf.densecrf as dcrf
# from pydensecrf.utils import create_pairwise_bilateral

from imgaug import augmenters as iaa
from imgaug import imgaug as ia
from PIL import Image

from sklearn import metrics as sk_metrics
import scipy.cluster
import heapq
# import fastcluster

from IPython.display import display

from intracell_utils.data_source import *

from prepare_images_utils import *
from latex_dataset import *


Populating the interactive namespace from numpy and matplotlib

Prepare data source

Table detection


In [3]:
DET_SRC_DIR = './data/arxiv/inout_pairs/'
all_det_image_ids = [fname[:-9]
                     for fname in glob.glob(os.path.join(DET_SRC_DIR, '*_out.json'))]
random.shuffle(all_det_image_ids)

In [4]:
# all_det_image_ids, strange_det_images = leave_only_valid_samples(all_det_image_ids, prepare_det_batch, fake_imgaug_pipeline)
all_det_image_ids, strange_det_images = all_det_image_ids, []
print('all_det_image_ids', len(all_det_image_ids), 'strange_det_images', len(strange_det_images))
TOTAL_DET_SAMPLES = len(all_det_image_ids)
TRAIN_DET_SAMPLES = int(TOTAL_DET_SAMPLES * 0.8)
VAL_DET_SAMPLES = TOTAL_DET_SAMPLES - TRAIN_DET_SAMPLES
train_det_image_ids = all_det_image_ids[:TRAIN_DET_SAMPLES]
val_det_image_ids = all_det_image_ids[TRAIN_DET_SAMPLES:]
print('train', len(train_det_image_ids), 'val', len(val_det_image_ids))


all_det_image_ids 18208 strange_det_images 0
train 14566 val 3642

Internal table structure segmentation


In [5]:
# INT_SRC_DIR = './data/generated/full/src/'
INT_SRC_DIR = './data/generated_with_char_info/big_simple_lined/src/'
all_int_image_ids = [fname[:-9]
                     for fname in glob.glob(os.path.join(INT_SRC_DIR, '*_out.json'))]
random.shuffle(all_int_image_ids)
len(all_int_image_ids)


Out[5]:
2000

In [6]:
# all_int_image_ids, strange_int_images = leave_only_valid_samples(all_int_image_ids, prepare_int_batch, imgaug_pipeline)
strange_int_images = []
print(len(all_int_image_ids), len(strange_int_images))
TOTAL_INT_SAMPLES = len(all_int_image_ids)
TRAIN_INT_SAMPLES = int(TOTAL_INT_SAMPLES * 0.8)
VAL_INT_SAMPLES = TOTAL_INT_SAMPLES - TRAIN_INT_SAMPLES
train_int_image_ids = all_int_image_ids[:TRAIN_INT_SAMPLES]
val_int_image_ids = all_int_image_ids[TRAIN_INT_SAMPLES:]
print('train', len(train_int_image_ids), 'val', len(val_int_image_ids))


2000 0
train 1600 val 400

In [7]:
vocab = build_vocab(all_int_image_ids)


100%|██████████| 2000/2000 [00:34<00:00, 57.47it/s]

Generators playground


In [8]:
# def get_demo(image_id):
#     dirname = os.path.dirname(os.path.dirname(image_id))
#     filename = os.path.basename(image_id)
#     return load_image_opaque(os.path.join(dirname, 'demo', filename + '_demo.png'))
# 
# get_demo(all_int_image_ids[0])

In [9]:
# qq = prepare_int_batch(all_int_image_ids[0:10], imgaug_pipeline)
# qq = prepare_int_ext_batch(all_int_image_ids[10:11], fake_imgaug_pipeline, vocab)
# qq = prepare_det_batch(all_det_image_ids[:10], imgaug_pipeline)

In [10]:
# arr_to_img(mask_to_img(qq[4][0]))

In [11]:
# arr_to_img(mask_to_img(qq[2][5]))

Compare data generators


In [12]:
# train_gen_mt = DataLoader(SegmDataset(train_image_ids, imgaug_pipeline),
#                           batch_size=8,
#                           shuffle=True,
#                           num_workers=4)
# train_gen_mt_iter = iter(train_gen_mt)
# _ = next(train_gen_mt_iter)

In [13]:
# %%prun
# for _ in range(10):
#     next(train_gen_mt_iter)

In [14]:
# train_gen_st = data_gen(train_image_ids, imgaug_pipeline, batch_size=8)
# train_gen_st_iter = iter(train_gen_st)
# _ = next(train_gen_st_iter)

In [15]:
# %%prun
# for _ in range(10):
#     next(train_gen_st_iter)

Define losses and metrics

Discrete


In [62]:
def unsharp_mask(image, kernel=(9, 9), sigma=10, power=0.7):
    blurred = cv2.GaussianBlur(image, kernel, sigma)
    return cv2.addWeighted(image, 1.0, blurred, -power, 0, blurred)


# def get_all_boxes(mask, min_area=100, min_size=(5, 5), threshold=0.5):
def get_all_boxes(mask, min_area=1, min_size=(1, 1)):
    result = []
#     binarized_mask = cv2.adaptiveThreshold((mask*255).astype('uint8'),
#                                            255,
#                                            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
#                                            cv2.THRESH_BINARY,
#                                            11,
#                                            50)
    mask_img = (mask*255).astype('uint8')
#     display('src mask')
#     display(arr_to_img((mask_img / 255).astype('float32')))
    
#     mask_img = unsharp_mask(mask_img)
    
#     display('unsharp mask')
#     display(arr_to_img((mask_img / 255).astype('float32')))
    
    _, binarized_mask = cv2.threshold(mask_img,
                                      0,
                                      255,
                                      cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    binarized_mask[:, 0] = binarized_mask[:, -1] = binarized_mask[0, :] = binarized_mask[-1, :] = 0
#     display(arr_to_img((binarized_mask / 255).astype('float32')))
    contours = cv2.findContours(binarized_mask,
                                cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)[1]
    for cnt in contours:
        if cv2.contourArea(cnt) < min_area:
            continue
        x, y, w, h = cv2.boundingRect(cnt)
        if h < min_size[0] or w < min_size[1]:
            continue
        result.append((y, x, y+h, x+w))
    result.sort()
    return result


def filter_boxes_by_overlap(boxes, min_overlap=0.9):
    boxes = list(boxes)
    boxes.sort(key=box_area)
    while True:
        boxes_to_remove = set()
        for i, cur in enumerate(boxes):
            for j, bigger in enumerate(boxes[i+1:]):
                overlap = box_inter_area(cur, bigger) / box_area(cur)
                if overlap >= min_overlap:
                    boxes_to_remove.add(i)
                    break
        if len(boxes_to_remove) == 0:
            break
        boxes = [b for i, b in enumerate(boxes) if i not in boxes_to_remove]
    return boxes
        

def get_boxes_by_channel(pred, **kwargs):
    pred_boxes = collections.defaultdict(list)
    for channel in range(pred.shape[0]):
        pred_boxes[channel] = filter_boxes_by_overlap(get_all_boxes(pred[channel], **kwargs))
    return pred_boxes


def calc_dice(box, other):
    return 2 * box_inter_area(box, other) / (box_area(box) + box_area(other))


def calc_inter_over_other_area(box, other):
    return box_inter_area(box, other) / box_area(other)


def find_closest_box(box, others, min_overlap=0.8, calc_overlap=calc_dice):
    best_overlap = max(0, min_overlap - 1e-3)
    best_idx = None
    for i, other in enumerate(others):
        overlap = calc_overlap(box, other)
        if overlap > best_overlap:
            best_idx = i
            best_overlap = overlap
    return best_idx


def match_boxes_by_biggest_overlap_1to1(gold_regions, pred_regions):
    overlaps = [(gi, pi, intersection, len(pelems - gelems))
                for gi, gelems in enumerate(gold_regions)
                for pi, pelems in enumerate(pred_regions)
                for intersection in [len(gelems & pelems)]
                if intersection > 0]
    overlaps = numpy.array(overlaps)
    
    g_idx = collections.defaultdict(set)
    p_idx = collections.defaultdict(set)
    for i, (gi, pi, _, _) in enumerate(overlaps):
        g_idx[gi].add(i)
        p_idx[pi].add(i)

    matched_golds = set()
    matched_preds = set()
    matched_pairs = set()

    for gi, pi, intersection, foreign in overlaps:
        if gi in matched_golds or pi in matched_preds:
            continue
        matched_golds.add(gi)
        matched_preds.add(pi)
        matched_pairs.add((gi, pi))
    return [(gold_regions[gi], pred_regions[pi]) for gi, pi in matched_pairs]


def classify_boxes(pred_boxes, gold_boxes, strictness=0.8, calc_overlap=calc_dice):
    true_positive = []
    false_positive = []

    found_gold = set()
    pred_i_to_gold_i = {}
    for i, box in enumerate(pred_boxes):
        closest_gold_i = find_closest_box(box, gold_boxes, min_overlap=strictness, calc_overlap=calc_overlap)
        if not closest_gold_i is None:
            pred_i_to_gold_i[i] = closest_gold_i
            true_positive.append(i)
            found_gold.add(closest_gold_i)
        else:
            false_positive.append(i)

    false_negative = set(range(len(gold_boxes))) - found_gold
    return (true_positive, false_positive, false_negative, pred_i_to_gold_i)


def calc_precision(tp, fp, fn):
    denom = float(tp + fp)
    return (tp / denom) if denom > 1e-4 else 0.0


def calc_recall(tp, fp, fn):
    denom = float(tp + fn)
    return (tp / denom) if denom > 1e-4 else 0.0


def calc_f1(tp, fp, fn):
    p = calc_precision(tp, fp, fn)
    r = calc_recall(tp, fp, fn)
    denom = p + r
    return (2 * p * r / denom) if denom > 1e-4 else 0.0


def calc_metric_over_boxes(pred_boxes, gold_boxes, metric, strictness=0.8, calc_overlap=calc_dice):
    result = []
    for ch in range(len(gold_boxes)):
        tp, fp, fn, _ = classify_boxes(pred_boxes.get(ch, []),
                                       gold_boxes.get(ch, []),
                                       strictness=strictness,
                                       calc_overlap=calc_overlap)
        tp, fp, fn = len(tp), len(fp), len(fn)
        result.append(metric(tp, fp, fn))
    return result


def box_match_single_image(pred, gold_boxes, metric, strictness=0.8, calc_overlap=calc_dice):
    pred_boxes = get_boxes_by_channel(pred)
    return calc_metric_over_boxes(pred_boxes,
                                  gold_boxes,
                                  metric,
                                  strictness=strictness,
                                  calc_overlap=calc_overlap)


def box_match_batch(pred, gold_boxes, metric, strictness=0.8, calc_overlap=calc_dice):
    if not isinstance(pred, numpy.ndarray):
        pred = pred.data.cpu().numpy()
    image_metrics = [box_match_single_image(pred[i],
                                            gold_boxes[i],
                                            metric,
                                            strictness=strictness,
                                            calc_overlap=calc_overlap)
                     for i in range(pred.shape[0])]
    return numpy.array(image_metrics).mean(0)


def box_match_precision(pred, target, gold_boxes, strictness=0.5):
    return box_match_batch(pred, gold_boxes, calc_precision,
                           strictness=strictness,
                           calc_overlap=calc_dice)


def box_match_recall(pred, target, gold_boxes, strictness=0.5):
    return box_match_batch(pred, gold_boxes, calc_recall,
                           strictness=strictness,
                           calc_overlap=calc_inter_over_other_area)


def box_match_f1(pred, target, gold_boxes, strictness=0.5):
    return box_match_batch(pred, gold_boxes, calc_f1,
                           strictness=strictness,
                           calc_overlap=calc_inter_over_other_area)


def fill_boxes_on_mask_single_image(mask, max_iterations=3):
    for _ in range(max_iterations):
        boxes_dict = get_boxes_by_channel(mask)
        boxes_lst = [boxes_dict[i] for i in range(mask.shape[0])]
        mask = make_mask_for_nn_base(mask.shape[1:], mask.shape[0], boxes_lst)
    return mask


def fill_boxes_on_mask_batch(mask_batch, max_iterations=3):
    result = numpy.zeros_like(mask_batch)
    for i in range(mask_batch.shape[0]):
        result[i] = fill_boxes_on_mask_single_image(mask_batch[i], max_iterations=max_iterations)
    return result


# box_match_precision(numpy.tile(numpy.array([[[[1, 1, 1, 1], [1, 1, 1, 1]]]]), (1, 5, 1, 1)),
#                     None,
#                     [{2 : [[0, 0, 2, 4]]}])

Table level


In [17]:
def flatten_cell_dependencies(grid):
    full_deps = { i : set(itertools.chain.from_iterable(neighbors.values()))
                 for i, neighbors in enumerate(grid) }
    for cell_i, neigbors in full_deps.items():
        neighs_to_remove = set()
        for neigh in neigbors:
            if cell_i in full_deps[neigh]:
                neighs_to_remove.add(neigh)
        neigbors -= neighs_to_remove
    return full_deps


def filter_boxes_by_real_pixels(boxes, image, max_mean=0.99, min_mean=0.01):
    result = []
    for box in boxes:
        y1, x1, y2, x2 = box
        mean = image[y1:y2+1, x1:x2+1].mean()
        if mean >= min_mean and mean <= max_mean:
            result.append(box)
    return result


def merge_close_cells(cells, dx=0.5, dy=0.5):
    cells = list(cells)
    idx = rtree.index.Index(interleaved=True)
    idx2cell = {}
    for i, c in enumerate(cells):
        idx.insert(i, c)
        idx2cell[i] = c
    next_id = len(idx2cell)

    merged_anything = True
    while merged_anything:
        merged_anything = False
        for i, (y1, x1, y2, x2) in idx2cell.items():
            neighbors = set()
            neighbors.update(idx.intersection((y1-dy, x1, y2+dy, x2)))
            neighbors.update(idx.intersection((y1, x1-dx, y2, x2+dx)))
            neighbors = {j for j in neighbors if j in idx2cell}
            if len(neighbors) > 1:
                idx2cell[next_id] = just_box_union([idx2cell[j] for j in neighbors])
                next_id += 1
                for j in neighbors:
                    del idx2cell[j]
                merged_anything = True
                break
    return list(idx2cell.values())


NO_GRID = (None, (), (), (), {})
def table_grid_from_intracell_mask(mask, input_image=None, box_content_getter=None, content_filter=None):
    body_candidates = get_all_boxes(mask[0])
#     if len(body_candidates) == 0:
#         return NO_GRID
    body = (0, 0, 0, 0) #get_biggest_box(body_candidates)
#     cells = filter_boxes_by_overlap(get_all_boxes(mask[1]))
    cells = get_all_boxes(mask[1])
    if len(cells) == 0:
        return NO_GRID
    if not input_image is None:
        cells = filter_boxes_by_real_pixels(cells, input_image)
    if not box_content_getter is None:
        cell_contents = [box_content_getter(cell) for cell in cells]
        cells_with_contents = [(cell, content)
                               for cell, content in zip(cells, cell_contents)
                               if content]
        if content_filter is None:
            content_filter = lambda x: True
        cells = [cell for cell, content in cells_with_contents if content_filter(content)]

    cells = merge_close_cells(cells)
    grid = make_grid(cells)
    full_deps = flatten_cell_dependencies(grid)
    cell_idx = list(full_deps.keys()) # toposort.toposort_flatten(full_deps)
    intracell_relations = {}
    for cur_cell_i in cell_idx:
        cur_cell = cells[cur_cell_i]
        for direction, dir_neighbors in grid[cur_cell_i].items():
            for neigh_cell_i in dir_neighbors:
                intracell_relations[(cur_cell_i, neigh_cell_i)] = direction
    return (body, cells, grid, cell_idx, intracell_relations)


def reconstruct_table_from_intracell_mask(mask, interbox_classifier):
    body, cells, grid, cell_idx, intracell_relations = table_grid_from_intracell_mask(mask)
    intracell_space_classes = { k : interbox_classifier([mask], [pair])
                               for pair in intracell_relations }
    return reconstruct_table_from_grid(body, cells, grid, cell_idx, intracell_space_classes)


def boxes_by_channel_from_reconstructed_table(mask):
    table_info = reconstruct_table_from_intracell_mask(mask)
    if table_info is None:
        return {}
    body, cells, rows, cols = table_info
    result = { 1 : [body] }
    result[2] = cells
    result[3] = rows
    result[4] = cols
    return result


def table_level_metric_single_image(pred, gold_boxes, metric, strictness=0.5):
    pred_boxes = boxes_by_channel_from_reconstructed_table(pred)
    return calc_metric_over_boxes(pred_boxes, gold_boxes, calc_precision, strictness=strictness)


def table_level_metric_by_batch(pred, gold_boxes, metric, strictness=0.5):
    if not isinstance(pred, numpy.ndarray):
        pred = pred.data.cpu().numpy()
    image_metrics = [table_level_metric_single_image(pred[i],
                                                     gold_boxes[i],
                                                     metric,
                                                     strictness=strictness)
                     for i in range(pred.shape[0])]
    return numpy.array(image_metrics).mean(0)

def table_level_precision(pred, target, gold_boxes, strictness=0.5):
    return table_level_metric_by_batch(pred, gold_boxes, calc_precision, strictness=strictness)


def table_level_recall(pred, target, gold_boxes, strictness=0.5):
    return table_level_metric_by_batch(pred, gold_boxes, calc_recall, strictness=strictness)

Fuzzy

$Dice(p, t, w) = 1 - \frac{ p \cdot t + 1 }{ p + t + 1 }$

$WDice(p, t, w) = 1 - \frac{ p \cdot t \cdot w^{-1} + 1 }{ p + t + p \cdot (1 - t) \cdot w + 1 }$


In [18]:
DICE_SMOOTH = 1.0
def dice_coef(pred, target, gold_boxes):
    intersection = pred * target
    union = pred + target
    return ((2. * intersection.sum(3).sum(2).sum(0) + DICE_SMOOTH) /
            (union.sum(3).sum(2).sum(0) + DICE_SMOOTH))


def px_precision(pred, target, gold_boxes, threshold=0.5):
    pred = (pred >= threshold).float()
    target = (target >= threshold).float()
    tp = (pred * target).float().sum(3).sum(2).sum(0)
    fp = ((target - pred) < 0).float().sum(3).sum(2).sum(0)
    denum = tp + fp
    return tp / (denum + (denum == 0).float())


def px_recall(pred, target, gold_boxes, threshold=0.5):
    pred = (pred >= threshold).float()
    target = (target >= threshold).float()
    tp = (pred * target).float().sum(3).sum(2).sum(0)
    fn = ((pred - target) < 0).float().sum(3).sum(2).sum(0)
    denum = tp + fn
    return tp / (denum + (denum == 0).float())


def dice_on_boxes(pred, target, gold_boxes):
    return dice_coef(fill_boxes_on_mask_batch(pred.data.cpu().numpy()), target.data.cpu().numpy(), gold_boxes)


def make_single_channel(f, channel):
    def _impl(pred, target):
        return f(pred[:, channel:channel+1], target[:, channel:channel+1])
    return _impl


def make_cpu(f):
    def _impl(pred, target):
        return f(pred.cpu(), target.cpu())
    return _impl


def _make_key_from_metric_title(txt):
    levels = txt.split('\n', 1)
    if len(levels) < 2:
        levels.append('value')
    return tuple(levels)


def compactify_metrics_table(metrics_df):
    data = { _make_key_from_metric_title(c) : '{:.2f}±{:.2f}'.format(metrics_df[c]['mean'], metrics_df[c]['std'])
            for c in metrics_df.columns }
    idx = pandas.MultiIndex.from_tuples(data.keys())
    return pandas.Series(data, index=idx).unstack(level=0)


def format_metrics_table_md(df):
    df = compactify_metrics_table(df.describe())
    print(tabulate.tabulate(df, df.columns, tablefmt='pipe'))


TRAIN_DET_METRICS = {'d' : (dice_coef, DET_MASK_CHANNELS),
                     }
VAL_DET_METRICS = {'d' : (dice_coef, DET_MASK_CHANNELS),
                   'bd' : (dice_on_boxes, DET_MASK_CHANNELS),
                   'px_p' : (px_precision, DET_MASK_CHANNELS),
                   'px_r' : (px_recall, DET_MASK_CHANNELS)
                   }


TRAIN_INT_METRICS = {'d' : (dice_coef, INT_MASK_CHANNELS),
#                      'bp' : (box_match_precision, INT_MASK_CHANNELS),
#                      'br' : (box_match_recall, INT_MASK_CHANNELS),
#                      'bf' : (box_match_f1, INT_MASK_CHANNELS),
#                      'px_p' : (px_precision, INT_MASK_CHANNELS),
#                      'px_r' : (px_recall, INT_MASK_CHANNELS)
                     }
VAL_INT_METRICS = {'d' : (dice_coef, INT_MASK_CHANNELS),
#                'bp' : (box_match_precision, INT_MASK_CHANNELS),
#                'br' : (box_match_recall, INT_MASK_CHANNELS)
#                'tp' : (table_level_precision, INT_MASK_CHANNELS),
#                'tr' : (table_level_recall, INT_MASK_CHANNELS),
                   'px_p' : (px_precision, INT_MASK_CHANNELS),
                   'px_r' : (px_recall, INT_MASK_CHANNELS)
                   }
TEST_INT_METRICS = {'d' : (dice_coef, INT_MASK_CHANNELS),
                    'bp' : (box_match_precision, INT_MASK_CHANNELS),
                    'br' : (box_match_recall, INT_MASK_CHANNELS),
                    'bf' : (box_match_f1, INT_MASK_CHANNELS),
                    'px_p' : (px_precision, INT_MASK_CHANNELS),
                    'px_r' : (px_recall, INT_MASK_CHANNELS)
                    }

IID Classification


In [19]:
def prepare_metric_data_by_columns(pred, target_idx):
    all_rows = list(range(len(target_idx)))

    pred = pred.data.cpu().numpy()
    pred_idx = pred.argmax(axis=1)
    pred[:] = 0
    pred[all_rows, pred_idx] = 1

    target = numpy.zeros_like(pred, dtype='int')
    target[all_rows, target_idx.data.cpu().numpy()] = 1

    tp = numpy.array([(pred[:, col] == target[:, col]).sum()
                      for col in range(pred.shape[1])]).astype('float32')
    fp = numpy.array([((pred[:, col] - target[:, col]) > 0).sum()
                      for col in range(pred.shape[1])]).astype('float32')
    fn = numpy.array([((target[:, col] - pred[:, col]) > 0).sum()
                      for col in range(pred.shape[1])]).astype('float32')
    return tp, fp, fn


def precision_by_col(pred, target):
    tp, fp, fn = prepare_metric_data_by_columns(pred, target)
    return numpy.nan_to_num(tp / (tp + fp))


def recall_by_col(pred, target):
    tp, fp, fn = prepare_metric_data_by_columns(pred, target)
    return numpy.nan_to_num(tp / (tp + fn))


def f1_by_col(pred, target):
    tp, fp, fn = prepare_metric_data_by_columns(pred, target)
    p = tp / (tp + fp)
    r = tp / (tp + fn)
    return numpy.nan_to_num(2 * p * r / (p + r))


def var2flatarray(x):
    return x.data.cpu().numpy().reshape(-1)


def binary_precision(pred, target, *args, **kwargs):
    return [sk_metrics.precision_score(var2flatarray(target) > 0.5,
                                       var2flatarray(pred) > 0.5)]


def binary_recall(pred, target, *args, **kwargs):
    return [sk_metrics.recall_score(var2flatarray(target) > 0.5,
                                    var2flatarray(pred) > 0.5)]

def binary_f1(pred, target, *args, **kwargs):
    return [sk_metrics.f1_score(var2flatarray(target) > 0.5,
                                var2flatarray(pred) > 0.5)]


CELL_CLASSES = ('no_rel', 'same_row', 'same_col', 'same_cell_h', 'same_cell_v')
CELL_TRAIN_METRICS = {'f1' : (f1_by_col, CELL_CLASSES),
                      'p' : (precision_by_col, CELL_CLASSES),
                      'r' : (recall_by_col, CELL_CLASSES) }
CELL_VAL_METRICS = {'f1' : (f1_by_col, CELL_CLASSES),
                    'p' : (precision_by_col, CELL_CLASSES),
                    'r' : (recall_by_col, CELL_CLASSES) }

Superpixel Postprocessing


In [20]:
def mask_from_superpixels(grid_ys, grid_xs, superpixel_values):
    result = numpy.zeros((grid_ys[-1], grid_xs[-1]), dtype='float32')
    for row in range(len(grid_ys) - 1):
        for col in range(len(grid_xs) - 1):
            result[grid_ys[row]:grid_ys[row+1]+1,
                   grid_xs[col]:grid_xs[col+1]+1] = superpixel_values[row, col]
    return result


def mask_from_superpixels_batch(pred, grids):
    return numpy.expand_dims(numpy.array([mask_from_superpixels(grid_ys, grid_xs, cur_pred[0])
                                          for cur_pred, (grid_ys, grid_xs) in zip(pred, grids)]),
                             1)


def sp_pp_box_match_precision(pred, target, gold_boxes, grids=None, strictness=0.99999):
    pred = mask_from_superpixels_batch(pred.data.cpu().numpy(),
                                       grids)
    return box_match_precision(pred, target, gold_boxes)


def sp_pp_box_match_recall(pred, target, gold_boxes, grids=None, strictness=0.99999):
    pred = mask_from_superpixels_batch(pred.data.cpu().numpy(),
                                       grids)
    return box_match_recall(pred, target, gold_boxes)


def sp_pp_box_match_f1(pred, target, gold_boxes, grids=None, strictness=0.99999):
    pred = mask_from_superpixels_batch(pred.data.cpu().numpy(),
                                       grids)
    return box_match_f1(pred, target, gold_boxes)


SP_CLASSES = ('sp',)
SP_PP_METRICS = {'f1' : (binary_f1, SP_CLASSES),
                 'p' : (binary_precision, SP_CLASSES),
                 'r' : (binary_recall, SP_CLASSES),
                 'bp' : (sp_pp_box_match_precision, SP_CLASSES),
                 'br' : (sp_pp_box_match_recall, SP_CLASSES),
                 'bf' : (sp_pp_box_match_f1, SP_CLASSES)
                }

Losses


In [21]:
def dice_score(pred, target):
    intersection = pred * target
    union = pred + target
    return ((2. * intersection.sum() + DICE_SMOOTH) /
                (union.sum() + DICE_SMOOTH))

def dice_loss(pred, target, weights):
    return 1 - dice_score(pred, target)


def weighted_dice_loss(pred, target, weights):
    # the idea is to lower actual intersection in important areas
    inv_weights = 1 / weights
    intersection = pred * inv_weights * target
    
    # the idea is to increase actual predicted values
    # where they have to be zero
    inv_target = 1 - target
    inv_intersection = pred * inv_target * weights
    union = pred + target + inv_intersection

    return 1 - ((2. * intersection.sum() + DICE_SMOOTH) /
                (union.sum() + DICE_SMOOTH))


def dice_bce_loss(pred, target, weights):
#     return dice_loss(pred, target, weights) + F.binary_cross_entropy(pred, target, weights)
    return F.binary_cross_entropy(pred, target, weights) - torch.log(dice_score(pred, target))

Define a network

Basic


In [22]:
def mcuda(x, cuda):
    return x.cuda() if cuda else x


def npten(arr, cuda):
    return mcuda(torch.from_numpy(arr), cuda)


def npvar(arr, cuda):
    if not torch.is_tensor(arr):
        arr = torch.from_numpy(arr)
    return mcuda(Variable(arr), cuda)


def is_module_on_cuda(m):
    return next(m.parameters()).is_cuda


class LocalAttention(nn.Module):
    def __init__(self, in_channels):
        super(LocalAttention, self).__init__()
        self.att = nn.Conv2d(in_channels, 1, (1, 1))

    def forward(self, x):
        unnorm = self.att(x)
        norm = F.softmax(unnorm.view(unnorm.size()[0], -1)).view(*unnorm.size())
        return x * norm


def identity_func(x):
    return x


def masked_average(x, masks):
    return (x * masks).mean(3).mean(2)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilations=[1], bn=True, out_act=F.relu,
                 conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d):
        super(ConvBlock, self).__init__()
        self.bn = norm_layer(int(in_channels))
        assert out_channels % len(dilations) == 0
        channels_per_dilation = int(out_channels // len(dilations))
        self.convs = nn.ModuleList([conv_layer(in_channels,
                                               channels_per_dilation,
                                               kernel_size,
                                               padding=dil,
                                               dilation=dil)
                                    for dil in dilations])
        self.out_act = out_act
        self.out_channels = out_channels

    def forward(self, x):
        x = self.bn(x)
        conv_outs = [conv(x) for conv in self.convs]
        return self.out_act(torch.cat(conv_outs, dim=1))

    @property
    def receptive_field(self):
        dils = numpy.array([c.dilation for c in self.convs])
        kernels = numpy.array([c.kernel_size for c in self.convs])
        kernels_odd = kernels % 2
        return (dils * (kernels - 1) + kernels_odd).max(axis=0)


class ConvBlock1d(ConvBlock):
    def __init__(self, in_channels, out_channels, kernel_size, dilations=[1], bn=True, out_act=F.relu):
        super(ConvBlock1d, self).__init__(in_channels, out_channels, kernel_size,
                                          dilations=dilations,
                                          bn=bn,
                                          out_act=out_act,
                                          conv_layer=nn.Conv1d,
                                          norm_layer=nn.BatchNorm1d)


def ensure_function(x):
    if callable(x):
        return x
    def _f(*args, **kwargs):
        return x
    return _f


class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=TOTAL_INT_CLASSES, first_conv_channels=4, depth=2, out_layers=1, conv_kernel=3,
                 enc_dilations=[1], dec_dilations=[1], out_dilations=[1], convblock=ConvBlock):
        super(UNet, self).__init__()
        self.out_channels = out_channels
        enc_channels = [in_channels] + [int(first_conv_channels * (2**step)) for step in range(depth)]
        enc_dilations = ensure_function(enc_dilations)
        dec_dilations = ensure_function(dec_dilations)
        out_dilations = ensure_function(out_dilations)
        self.encoder = nn.ModuleList([convblock(enc_channels[i],
                                                enc_channels[i+1],
                                                kernel_size=conv_kernel,
                                                dilations=enc_dilations(i))
                                      for i in range(depth)])
        bottleneck_channels = enc_channels[-1] * 2
        self.bottleneck = convblock(enc_channels[-1],
                                    bottleneck_channels,
                                    kernel_size=conv_kernel,
                                    dilations=enc_dilations(depth))
        dec_channels = [bottleneck_channels] + enc_channels[:0:-1]

        self.dec_conv = nn.ModuleList([convblock(dec_channels[i],
                                                 dec_channels[i+1],
                                                 kernel_size=conv_kernel,
                                                 dilations=dec_dilations(i))
                                      for i in range(depth)])
        self.dec_deconv = nn.ModuleList([nn.ConvTranspose2d(dec_channels[i],
                                                            dec_channels[i+1],
                                                            (2, 2),
                                                            stride=2)
                                         for i in range(depth)])
        self.out_layers = nn.ModuleList([convblock(dec_channels[-1],
                                                   dec_channels[-1],
                                                   kernel_size=conv_kernel,
                                                   dilations=out_dilations(i))
                                         for i in range(out_layers)])
        self.out_conv = nn.Conv2d(dec_channels[-1],
                                  out_channels,
                                  (1, 1))

    def forward(self, x):
        enc_conv_outs = []
        enc_pool_outs = [x]
        for enc_conv in self.encoder:
            cur_conv_out = enc_conv(enc_pool_outs[-1])
            enc_conv_outs.append(cur_conv_out)
            cur_pool_out = F.max_pool2d(cur_conv_out, (2, 2))
            enc_pool_outs.append(cur_pool_out)

        internal_outs = []
        internal_outs.extend(enc_conv_outs)

        cur_out = self.bottleneck(enc_pool_outs[-1])
        internal_outs.append(cur_out)

        for dec_step, (dec_conv, dec_deconv) in enumerate(zip(self.dec_conv, self.dec_deconv)):
            up = dec_deconv(cur_out)
            cur_out = torch.cat([up, enc_conv_outs[-dec_step-1]], dim=1)
            cur_out = dec_conv(cur_out)
            internal_outs.append(cur_out)

        for out_layer in self.out_layers:
            cur_out = F.relu(out_layer(cur_out))
            internal_outs.append(cur_out)

        final_out = self.out_conv(cur_out)
        internal_outs.append(final_out)

        return F.sigmoid(final_out), internal_outs


def SeparableUNet(*args, **kwargs):
    kwargs['convblock'] = SeparableConvBlock2d
    return UNet(*args, **kwargs)


class SimpleConvFCN(nn.Module):
    def __init__(self, in_channels, out_channels, conv_layers=1, conv_kernel=(3, 3), dilations=[1, 2, 4], conv_block_class=ConvBlock, out_act=F.sigmoid):
        super(SimpleConvFCN, self).__init__()
        conv_channels = [in_channels] + [len(dilations) * (2 ** (i + 1)) for i in range(conv_layers)]
        print('conv channels', conv_channels)
        self.convs = nn.ModuleList([conv_block_class(conv_channels[i],
                                                     conv_channels[i+1],
                                                     conv_kernel,
                                                     dilations=dilations)
                                    for i in range(conv_layers)])
        self.output = nn.Conv2d(conv_channels[-1],
                                out_channels,
                                1)
        self.out_act = out_act

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        return self.out_act(self.output(x))

    @property
    def receptive_field(self):
        conv_fields = [c.receptive_field for c in self.convs]
        result = conv_fields[0]
        for i in range(1, len(conv_fields)):
            result += numpy.clip(conv_fields[i] // 2, 1, None)
        return result


class RectRNN(nn.Module):
    def __init__(self, in_channels, out_channels, inner_channels=32, num_layers=1, out_act=F.sigmoid):
        super(RectRNN, self).__init__()
        self.inner_channels = inner_channels
        self.num_layers = num_layers
        self.rows_rnn = nn.LSTM(in_channels,
                                self.inner_channels,
                                num_layers=self.num_layers,
                                bidirectional=True)
        self.cols_rnn = nn.LSTM(in_channels,
                                self.inner_channels,
                                num_layers=self.num_layers,
                                bidirectional=True)
        self.out = nn.Conv2d(inner_channels*4,
                             out_channels,
                             1)
        self.out_act = out_act

    def _make_hidden(self, batch_size, cuda):
        return (mcuda(Variable(torch.randn(self.num_layers * 2, batch_size, self.inner_channels)),
                      cuda),
                mcuda(Variable(torch.randn(self.num_layers * 2, batch_size, self.inner_channels)),
                      cuda))
        
    def forward(self, x):
        """x: (B, C, H, W)"""
        rows_batch_size = x.size(0) * x.size(2)
        rows_in = x.permute(3, 0, 2, 1).contiguous().view(x.size(3), -1, x.size(1)) # (W, B*H, C)
        rows_out, _ = self.rows_rnn(rows_in, self._make_hidden(rows_batch_size, x.data.is_cuda)) # (W, B*H, I)
        rows_rect = rows_out.view(x.size(3), x.size(0), x.size(2), self.inner_channels*2).permute(1, 3, 2, 0).contiguous()

        cols_batch_size = x.size(0) * x.size(3)
        cols_in = x.permute(2, 0, 3, 1).contiguous().view(x.size(2), -1, x.size(1)) # (H, B*W, C)
        cols_out, _ = self.cols_rnn(cols_in, self._make_hidden(cols_batch_size, x.data.is_cuda)) # (H, B*W, I)
        cols_rect = cols_out.view(x.size(2), x.size(0), x.size(3), self.inner_channels*2).permute(1, 3, 0, 2).contiguous()
        
        features = torch.cat([rows_rect, cols_rect], 1)

        return self.out_act(self.out(features))


class DenseNetWrapper(nn.Module):
    def __init__(self, *blocks):
        super(DenseNetWrapper, self).__init__()
        self.blocks = nn.ModuleList(list(blocks))

    def forward(self, x):
        last_out = x
        for block in self.blocks:
            last_out = block(x)
            x = torch.cat([x, last_out], 1)
        return last_out


def DenseRectRNN(in_channels, out_channels, inner_out_channels=32, inner_channels=32, layers=1, inner_act=lambda x: x, out_act=F.sigmoid):
    rnn_out_channels = [in_channels] + [inner_out_channels for _ in range(layers-1)] + [out_channels]
    rnn_in_channels = [int(sum(rnn_out_channels[:i+1])) for i in range(layers+1)]
    rnns = [RectRNN(rnn_in_channels[i],
                    rnn_out_channels[i+1],
                    inner_channels=inner_channels,
                    num_layers=1,
                    out_act=out_act if i == layers-1 else inner_act)
            for i in range(layers)]
    return DenseNetWrapper(*rnns)


class Linear1dMixture(nn.Module):
    def __init__(self, n_sources, n_channels):
        super(Linear1dMixture, self).__init__()
        self.weights = nn.ModuleList([nn.Linear(n_sources, 1, bias=False)
                                      for _ in range(n_channels)])
        self.n_sources = n_sources

    def forward(self, x):
        """
        Expects tensor of shape (n_sources, B, n_channels, *)
        """
        src_size = tuple(x.size())
        # transpose to (n_channels, B, *, n_sources)
        forward_axis_order = (2, 1) + tuple(range(3, len(src_size))) + (0,)
        x = x.permute(*forward_axis_order)
        weighted_sums = [channel_weights(x[i])
                         for i, channel_weights
                         in enumerate(self.weights)]
        print('weighted_sums', weighted_sums[0].size())
        1 / 0
        return torch.cat(weighted_sums) / n_sources
        


class DenseStack(nn.Module):
    def __init__(self, *blocks):
        super(DenseStack, self).__init__()
        self.blocks = nn.ModuleList(list(blocks))

    def forward(self, x):
        outs = []
        for block in self.blocks:
            last_out = block(x)
            outs.append(last_out)
            x = torch.cat([x, last_out], 1)
        return torch.cat([t.unsqueeze(2) for t in outs]).mean(0)


class DenseMixture(nn.Module):
    def __init__(self, out_channels, *blocks):
        super(DenseMixture, self).__init__()
        self.blocks = nn.ModuleList(list(blocks))
        self.mix = Linear1dMixture(len(self.blocks), out_channels)

    def forward(self, x):
        outs = []
        for block in self.blocks:
            last_out = block(x)
            outs.append(last_out)
            x = torch.cat([x, last_out], 1)
        return self.mix(torch.cat([t.unsqueeze(0) for t in outs]))


class StackedUNet1(DenseStack):
    def __init__(self, out_channels=TOTAL_INT_CLASSES):
        super(StackedUNet1, self).__init__(
            UNet(out_channels=out_channels,
                 first_conv_channels=6,
                 depth=4,
                 enc_dilations=[1, 2, 4, 8, 16, 32]),
            SimpleConvFCN(out_channels+1,
                          out_channels)
        )


class DenseNet1(DenseNetWrapper):
    def __init__(self, out_channels=TOTAL_INT_CLASSES):
        super(DenseNet1, self).__init__(
            UNet(out_channels=out_channels,
                 first_conv_channels=6,
                 depth=4,
                 enc_dilations=[1, 2, 4, 8, 16, 32]),
            SimpleConvFCN(out_channels+1,
                          out_channels,
                          dilations=[1, 2, 4, 8, 16, 32]),
            SimpleConvFCN(out_channels*2+1,
                          out_channels,
                          conv_layers=0)
        )


class RowwiseAtt(nn.Module):
    def __init__(self, input_channels, dim):
        super(RowwiseAtt, self).__init__()
        self.dim = dim
        self.weights = nn.Linear(input_channels, 1)

    def forward(self, x):
        """x.size = (N, C, H, W)"""
        att_unnorm = self.weights(x.transpose(1, -1)).transpose(1, -1)
        att_norm = F.softmax(att_unnorm, dim=self.dim)
        weighted_x = x * att_norm.expand_as(x)
        return weighted_x.sum(self.dim)


class SimpleStack(nn.Module):
    def __init__(self, blocks, forward_input=True):
        super(SimpleStack, self).__init__()
        self.blocks = nn.ModuleList(blocks)
        self.out_channels = self.blocks[-1].out_channels
        self.forward_input = forward_input

    def forward(self, x):
        if self.forward_input:
            out = x
        else:
            out = None
        for block in self.blocks:
            last_out = block(x)
            if out is None:
                out = last_out
            else:
                out = torch.cat([out, last_out], dim=1)
            x = out
        return out


class ConvWithMean(ConvBlock):
    def __init__(self, in_channels, *args, **kwargs):
        super(ConvWithMean, self).__init__(in_channels*3, *args, **kwargs)

    def forward(self, x):
        row_mean = x.mean(2).unsqueeze(2).expand(-1, -1, x.size(2), -1)
        col_mean = x.mean(3).unsqueeze(3).expand(-1, -1, -1, x.size(3))
        x = torch.cat([x, row_mean, col_mean], dim=1)
        return super(ConvWithMean, self).forward(x)


class ConvWithMaxPool(ConvBlock):
    def __init__(self, in_channels, *args, **kwargs):
        super(ConvWithMaxPool, self).__init__(in_channels*3, *args, **kwargs)

    def forward(self, x):
        row_max_pool = F.max_pool2d(x, (1, x.size(3))).expand(-1, -1, -1, x.size(3))
        col_max_pool = F.max_pool2d(x, (x.size(2), 1)).expand(-1, -1, x.size(2), -1)
        x = torch.cat([x, row_max_pool, col_max_pool], dim=1)
        return super(ConvWithMaxPool, self).forward(x)


class SimpleDenseNet(nn.Module):
    def __init__(self, out_channels=TOTAL_INT_CLASSES, first_conv_channels=4, depth=2, out_layers=1, conv_kernel=3,
                 enc_dilations=[1], channel_growth_factor=2, convblock=ConvBlock):
        super(SimpleDenseNet, self).__init__()
        channels = [int(1)] + [int(first_conv_channels * (channel_growth_factor**step)) for step in range(depth)]
        self.stack = SimpleStack([convblock(int(sum(channels[:i+1])),
                                            channels[i+1],
                                            kernel_size=conv_kernel,
                                            dilations=enc_dilations)
                                  for i in range(depth)])
        self.out = nn.Conv2d(int(sum(channels)),
                             out_channels,
                             1)

    def forward(self, x):
        x = self.stack(x)
        return F.sigmoid(self.out(x))


class SimpleDenseNetWithMean(SimpleDenseNet):
    def __init__(self, *args, **kwargs):
        kwargs['convblock'] = ConvWithMean
        super(SimpleDenseNetWithMean, self).__init__(*args, **kwargs)

        
class UNetWithMean(UNet):
    def __init__(self, *args, **kwargs):
        kwargs['convblock'] = ConvWithMean
        super(UNetWithMean, self).__init__(*args, **kwargs)


class UNetWithMaxPool(UNet):
    def __init__(self, *args, **kwargs):
        kwargs['convblock'] = ConvWithMaxPool
        super(UNetWithMaxPool, self).__init__(*args, **kwargs)


class RowwiseAttConvStack(SimpleStack):
    def __init__(self, att_dim, in_channels, out_channels, out_channels_growth_factor=1, conv_layers=1, kernel_size=3,
                 dilations=[1], bn=True, out_act=F.relu):
        channels = [in_channels] + [out_channels * (out_channels_growth_factor**(i + 1)) for i in range(conv_layers)]
        super(RowwiseAttConvStack, self).__init__(
            [RowwiseAtt(in_channels,  att_dim)] +
            [ConvBlock1d(int(sum(channels[:i+1])),
                         channels[i+1],
                         kernel_size,
                         dilations=dilations,
                         bn=bn,
                         out_act=out_act)
             for i in range(conv_layers)],
            forward_input=False)
        self.out_channels = int(sum(channels))


class RowColAttConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, joint_out_act=F.relu, **kwargs):
        super(RowColAttConvLayer, self).__init__()
        self.row_block = RowwiseAttConvStack(3, in_channels, out_channels, **kwargs)
        self.col_block = RowwiseAttConvStack(2, in_channels, out_channels, **kwargs)
        self.out = ConvBlock(in_channels + self.row_block.out_channels + self.col_block.out_channels,
                             out_channels,
                             **kwargs)
        self.joint_out_act = joint_out_act

    def forward(self, x):
        row_att_1d = self.row_block(x)
        row_att_2d = row_att_1d.unsqueeze(3).expand(-1, -1, -1, x.size(3))
        col_att_1d = self.col_block(x)
        col_att_2d = col_att_1d.unsqueeze(2).expand(-1, -1, x.size(2), -1)
        x = torch.cat([x, row_att_2d, col_att_2d], dim=1)
#         out = self.out(x.transpose(1, -1)).transpose(1, -1)
        out = self.out(x)
        return self.joint_out_act(out)


class RowCollAttUNet(UNet):
    def __init__(self, *args, **kwargs):
        kwargs['convblock'] = RowColAttConvLayer
        super(RowCollAttUNet, self).__init__(*args, **kwargs)


class Conv2dHV(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, dilation=1):
        assert out_channels % 2 == 0
        assert isinstance(kernel_size, int)
        assert isinstance(padding, int)
        assert isinstance(dilation, int)
        super(Conv2dHV, self).__init__()
        self.h_conv = nn.Conv2d(in_channels,
                                out_channels // 2,
                                (1, kernel_size),
                                padding=(0, padding),
                                dilation=(1, dilation))
        self.v_conv = nn.Conv2d(in_channels,
                                out_channels // 2,
                                (kernel_size, 1),
                                padding=(padding, 0),
                                dilation=(dilation, 1))

    def forward(self, x):
        return torch.cat([self.h_conv(x), self.v_conv(x)], 1)


class ConvBlockHV(ConvBlock):
    def __init__(self, *args, **kwargs):
        kwargs['conv_layer'] = Conv2dHV
        super(ConvBlockHV, self).__init__(*args, **kwargs)


class UNetHV(UNet):
    def __init__(self, *args, **kwargs):
        kwargs['convblock'] = ConvBlockHV
        super(UNetHV, self).__init__(*args, **kwargs)


class Conv2dHVS(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0, dilation=1):
        assert isinstance(kernel_size, int)
        assert isinstance(padding, int)
        assert isinstance(dilation, int)
        super(Conv2dHVS, self).__init__()
        self.h_conv = nn.Conv2d(in_channels,
                                out_channels,
                                (1, kernel_size),
                                padding=(0, padding),
                                dilation=(1, dilation))
        self.v_conv = nn.Conv2d(out_channels,
                                out_channels,
                                (kernel_size, 1),
                                padding=(padding, 0),
                                dilation=(dilation, 1))

    def forward(self, x):
        return self.v_conv(self.h_conv(x))
    

class ConvBlockHVS(ConvBlock):
    def __init__(self, *args, **kwargs):
        kwargs['conv_layer'] = Conv2dHVS
        super(ConvBlockHVS, self).__init__(*args, **kwargs)


class UNetHVS(UNet):
    def __init__(self, *args, **kwargs):
        kwargs['convblock'] = ConvBlockHVS
        super(UNetHVS, self).__init__(*args, **kwargs)
        

def seq_smooth_reg(x):
    return ((1 - x[:, :-1]) * x[:, 1:]).mean(1)


def contrast_reg(x):
    return ((1 - x) * x).mean(1)


def gauss_reg(x):
    return (x * x).mean(1)


def smooth_contrast_reg(x, contrast_w=1.5):
    return seq_smooth_reg(x) + contrast_w * contrast_reg(x)


def row_col_reg(internal_outs, impl):
    result = 0
    norm = 0.0
    for x in internal_outs:
        # x.size = (B, C, H, W)
        x = F.sigmoid(x)
        size = x.size()
#         rows_pooled = x.mean(3)
        rows_pooled = F.max_pool1d(x.view(size[0], size[1] * size[2], size[3]), size[3]).view(size[0], size[1], size[2])
        rows = rows_pooled.transpose(1, 2) # (B, H, C)
        rows_reg = impl(rows).mean()
        
        cols_pooled = F.max_pool1d(x.transpose(2, 3).contiguous().view(size[0], size[1] * size[3], size[2]), size[2]).view(size[0], size[1], size[3])
        cols = x.max(2)[0].transpose(1, 2) # (B, W, C)
        cols_reg = impl(cols).mean()
        result = rows_reg + cols_reg + result
        norm += 1.0
    return result / norm


def row_col_smooth_contrast_reg(internal_outs):
    return row_col_reg(internal_outs, smooth_contrast_reg)


def run_network(network, generator, num_batches, criterion=dice_bce_loss, optimizer=None, regularizer=None, reg_weight=2e-1, metrics=TRAIN_INT_METRICS, cuda=True):
    metric_values = []
    gen_iter = iter(generator)
    for _ in tqdm.tqdm(range(num_batches)):
        image_ids, images, mask, loss_weights, boxes = next(gen_iter)
        images_var = npvar(images, cuda)
        mask_var = npvar(mask, cuda)
        loss_weights_var = npvar(loss_weights, cuda)
        boxes = [pickle.loads(b) for b in boxes]

        cur_out = network(images_var)
        if isinstance(cur_out, tuple):
            cur_out, internal_outs = cur_out
        else:
            internal_outs = None

        will_regularize = regularizer and internal_outs
        if will_regularize:
            reg_loss = reg_weight * regularizer(internal_outs)
        else:
            reg_loss = npvar(numpy.array([0], dtype='float32'), cuda)

        loss = criterion(cur_out, mask_var, loss_weights_var)
        full_loss = loss + reg_loss

        if optimizer:
            optimizer.zero_grad()
            full_loss.backward()
            nn.utils.clip_grad_norm(network.parameters(), 10)
            optimizer.step()

        cur_metrics = { 'loss' : loss.data[0] }

        if will_regularize:
            cur_metrics['reg'] = reg_loss.data[0]

        for name, (func, elem_names) in metrics.items():
            metric_value = func(cur_out, mask_var, boxes)
            if not isinstance(metric_value, (list, numpy.ndarray)):
                metric_value = metric_value.cpu().data.numpy()
            cur_metrics.update(('\n'.join((name, n)), v) for n, v in zip(elem_names, metric_value) if n)
        metric_values.append(cur_metrics)
    return metric_values


def run_network_ext(network, generator, num_batches, criterion=dice_bce_loss, optimizer=None, regularizer=None, reg_weight=2e-1, metrics=TRAIN_INT_METRICS, cuda=True):
    metric_values = []
    gen_iter = iter(generator)
    for _ in tqdm.tqdm(range(num_batches)):
        image_ids, images, txt_idx, mask, loss_weights, boxes = next(gen_iter)
        images_var = npvar(images, cuda)
        mask_var = npvar(mask, cuda)
        loss_weights_var = npvar(loss_weights, cuda)
        boxes = [pickle.loads(b) for b in boxes]

        cur_out = network(images_var)
        if isinstance(cur_out, tuple):
            cur_out, internal_outs = cur_out
        else:
            internal_outs = None

        will_regularize = regularizer and internal_outs
        if will_regularize:
            reg_loss = reg_weight * regularizer(internal_outs)
        else:
            reg_loss = npvar(numpy.array([0], dtype='float32'), cuda)

        loss = criterion(cur_out, mask_var, loss_weights_var)
        full_loss = loss + reg_loss

        if optimizer:
            optimizer.zero_grad()
            full_loss.backward()
            nn.utils.clip_grad_norm(network.parameters(), 10)
            optimizer.step()

        cur_metrics = { 'loss' : loss.data[0] }

        if will_regularize:
            cur_metrics['reg'] = reg_loss.data[0]

        for name, (func, elem_names) in metrics.items():
            metric_value = func(cur_out, mask_var, boxes)
            if not isinstance(metric_value, (list, numpy.ndarray)):
                metric_value = metric_value.cpu().data.numpy()
            cur_metrics.update(('\n'.join((name, n)), v) for n, v in zip(elem_names, metric_value) if n)
        metric_values.append(cur_metrics)
    return metric_values


def display_3dtensor(tensor, title='', fig_width=100, max_width=1000):
    """tensor.shape = (C, H, W)"""
    images = tensor.shape[0]
    max_images_per_row = max_width // tensor.shape[2]
    rows_n = images // max_images_per_row + (1 if images % max_images_per_row > 0 else 0)
    fig, axes = plt.subplots(rows_n, max_images_per_row)
    fig.set_size_inches(fig_width, fig_width * tensor.shape[1] / (tensor.shape[2]))
    axes = axes.reshape(-1)
    for i in range(images):
        axes[i].imshow(tensor[i], cmap='gray')
        axes[i].set_title('{} {}'.format(title, i))
    fig.suptitle(title)
    fig.set_tight_layout(True)


def display_inner_outs(tensors_list, fig_width=100, max_width=1000):
    for i, tensor in enumerate(tensors_list):
        display_3dtensor(tensor.data.cpu().numpy()[0], title=i, fig_width=fig_width, max_width=max_width)

Structured End-to-End


In [23]:
def bbox_shift(box, y, x):
    y1, x1, y2, x2 = box
    return (y1 + y, x1 + x, y2 + y, x2 + x)


class ConvFCNClassifier(nn.Module):
    def __init__(self, in_channels, out_channels, conv_layers=1, conv_kernel=(3, 3), dilations=[1, 2, 4], conv_block_class=ConvBlock):
        super(ConvFCNClassifier, self).__init__()
        conv_channels = [in_channels] + [len(dilations) * (2 ** (i + 1)) for i in range(conv_layers)]
        self.convs = nn.ModuleList([conv_block_class(conv_channels[i],
                                                     conv_channels[i+1],
                                                     conv_kernel,
                                                     dilations=dilations)
                                    for i in range(conv_layers)])
        self.output = nn.Linear(conv_channels[-1], out_channels)

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        x = F.max_pool2d(x, (x.size(2), x.size(3))).squeeze(3).squeeze(2)
        return self.output(x)

    @property
    def receptive_field(self):
        conv_fields = [c.receptive_field for c in self.convs]
        result = conv_fields[0]
        for i in range(1, len(conv_fields)):
            result += numpy.clip(conv_fields[i] // 2, 1, None)
        return result


def inflate_box(box, min_y=0, min_x=0, max_y=1000, max_x=1000, amount=50):
    y1, x1, y2, x2 = box
    return (max(min_y, y1 - amount),
            max(min_x, x1 - amount),
            min(max_y, y2 + amount),
            min(max_x, x2 + amount))


class CellRelConvFCNClassifier(nn.Module):
    def __init__(self, base_classifier):
        super(CellRelConvAttention, self).__init__()
        self.base_classifier = base_classifier

    def forward(self, table_masks, relations_by_image):
        out = self.make_batch(table_masks, relations_by_image)
        if out is None:
            return mcuda(Variable(torch.DoubleTensor()), table_masks.data.is_cuda)
        return self.base_classifier(out)

    @staticmethod
    def make_batch(src_images, table_masks, relations_by_image, pad=50):
        assert src_images.size(0) == table_masks.size(0)
        assert src_images.size(1) == table_masks.size(2)
        assert src_images.size(2) == table_masks.size(3)
        assert table_masks.size(0) == len(relations_by_image)

        samples_num = sum(len(img_rels) for img_rels in relations_by_image)
        if samples_num == 0:
            return None

        rel_boxes = [inflate_box(get_intercell_line_bbox(cell1, cell2, direction),
                                 max_y=src_images.size(1),
                                 max_x=src_images.size(2),
                                 amount=pad)
                     for img_i, img_rels in enumerate(relations_by_image)
                     for (cell1, cell2, direction) in img_rels]
        max_rel_width = max(x2 - x1 for _, x1, _, x2 in rel_boxes) + 1
        max_rel_height = max(y2 - y1 for y1, _, y2, _ in rel_boxes) + 1

        src_images_for_samples = npvar(numpy.zeros((len(rel_boxes),
                                                    1,
                                                    max_rel_height,
                                                    max_rel_width),
                                                   dtype='float32'),
                                       table_masks.data.is_cuda)
        table_masks_for_samples = npvar(numpy.zeros((len(rel_boxes),
                                                     table_masks.size()[1],
                                                     max_rel_height,
                                                     max_rel_width),
                                                    dtype='float32'),
                                        table_masks.data.is_cuda)
        relation_masks_for_samples = numpy.zeros((table_masks_for_samples.size()[0],
                                                  max_rel_height,
                                                  max_rel_width),
                                                 dtype='float32')
        sample_i = 0
        for img_i, img_rels in enumerate(relations_by_image):
            for cell1, cell2, direction in img_rels:
                by1, bx1, by2, bx2 = rel_boxes[sample_i]
                cur_table_mask = table_masks[img_i, :, by1:by2+1, bx1:bx2+1]
                mask_size = cur_table_mask.size()
                table_masks_for_samples[sample_i, :, :mask_size[1], :mask_size[2]] = cur_table_mask

                src_images_for_samples[sample_i, 0, :mask_size[1], :mask_size[2]] = src_images[img_i, by1:by2+1, bx1:bx2+1]

                draw_intercell_mask(relation_masks_for_samples[sample_i],
                                    bbox_shift(cell1, -by1, -bx1),
                                    bbox_shift(cell2, -by1, -bx1),
                                    direction)
                sample_i += 1

        out = torch.cat([table_masks_for_samples,
                         npvar(relation_masks_for_samples, table_masks.data.is_cuda).unsqueeze(1),
                         src_images_for_samples],
                        dim=1)
        return out


class CellRelationsClassifier(nn.Module):
    def __init__(self, base_classifier, return_grids=False):
        super(CellRelationsClassifier, self).__init__()
        self.base_classifier = base_classifier
        self.return_grids = return_grids

    def forward(self, simple_mask):
        simple_mask_np = simple_mask.data.cpu().numpy()
        grids = list(map(table_grid_from_intracell_mask, simple_mask_np))
        intracell_relation_by_image = [[(cells[i1], cells[i2], direction)
                                        for (i1, i2), direction in rels.items()]
                                       for body, cells, grid, cell_idx, rels in grids]
        intracell_classes_flat = self.base_classifier(simple_mask, intracell_relation_by_image)
        if self.return_grids:
            return intracell_classes_flat, grids
        else:
            return intracell_classes_flat


class TableSegmenter(nn.Module):
    def __init__(self):
        super(TableSegmenter, self).__init__()
        self.unet = UNet(first_conv_channels=12,
                         depth=2,
                         enc_dilations=[1, 2, 4, 8])
        self.intra_cls = CellRelationsClassifier(CellRelConvAttention(self.unet.out_channels, 2, conv_layers=1),
                                                 return_grids=True)

    def forward(self, x):
        simple_mask = self.unet(x)
        intracell_classes_flat, grids = self.intra_cls(simple_mask, intracell_relation_by_image)
        return simple_mask, grids, intracell_classes_flat


STRUCT_NO_REL_I = 0
STRUCT_SAME_ROWS_I = 1
STRUCT_SAME_COLS_I = 2
STRUCT_SAME_CELL_H_I = 3
STRUCT_SAME_CELL_V_I = 4
STRUCT_SAME_CELL_H_RELS = {'right', 'left'}
STRUCT_SAME_CELL_V_RELS = {'upper', 'lower'}
def make_structured_gold_single_image(gold_boxes, grid_info, relation_keys_in_order, cell_strictness=0.3):
    body, pred_cells, grid, cell_idx, rels = grid_info
    gold_cells = gold_boxes[1]
    gold_rows = gold_boxes[2]
    gold_cols = gold_boxes[3]
    gold_cell2rows = group_by_centers(gold_rows, gold_cells)
    gold_cell2cols = group_by_centers(gold_cols, gold_cells)
    
    pred_cell_to_gold = group_by_intersection(gold_cells,
                                              pred_cells,
                                              threshold=cell_strictness)
    gold_intracell_classes = numpy.zeros(len(rels),
                                         dtype='int')
    gold_intracell_classes[:] = STRUCT_NO_REL_I
    for row_i, (pred_i1, pred_i2) in enumerate(relation_keys_in_order):
        gold_ids1 = pred_cell_to_gold.get(pred_i1, ())
        gold_ids2 = pred_cell_to_gold.get(pred_i2, ())
        if len(gold_ids1) == 0  or len(gold_ids2) == 0:
            continue
        gold_i1 = next(iter(gold_ids1))
        gold_i2 = next(iter(gold_ids2))
        if gold_i1 == gold_i2:
            rel = rels[(pred_i1, pred_i2)]
            if rel in STRUCT_SAME_CELL_H_RELS:
                gold_intracell_classes[row_i] = STRUCT_SAME_CELL_H_I
            elif rel in STRUCT_SAME_CELL_V_RELS:
                gold_intracell_classes[row_i] = STRUCT_SAME_CELL_V_I
        else:
            rows_rel, cols_rel = get_cells_relation_def(rels[(pred_i1, pred_i2)],
                                                        gold_cell2rows[gold_i1],
                                                        gold_cell2cols[gold_i1],
                                                        gold_cell2rows[gold_i2],
                                                        gold_cell2cols[gold_i2])
            if rows_rel == CellsRel.SAME and cols_rel == CellsRel.SAME:
                continue
            if rows_rel == CellsRel.SAME:
                gold_intracell_classes[row_i] = STRUCT_SAME_ROWS_I
            elif cols_rel == CellsRel.SAME:
                gold_intracell_classes[row_i] = STRUCT_SAME_COLS_I
    return gold_intracell_classes


# def make_structured_gold_batch(gold_boxes_by_image, grid_info_by_image, cell_strictness=0.5):
#     result = [make_structured_gold_single_image(gold_boxes, grid_info)
#               for gold_boxes, grid_info in zip(gold_boxes_by_image, grid_info_by_image)]
#     return numpy.concatenate(result)


# def run_network_structured(network, generator, num_batches,
#                            raw_criterion=dice_loss, structured_criterion=F.binary_cross_entropy,
#                            optimizer=None, metrics=TRAIN_INT_METRICS, cuda=True):
#     metric_values = []
#     gen_iter = iter(generator)
#     for _ in tqdm.tqdm(range(num_batches)):
#         image_ids, images, mask, loss_weights, boxes = next(gen_iter)
#         images_var = npvar(images, cuda)
#         mask_var = npvar(mask, cuda)
#         loss_weights_var = npvar(loss_weights, cuda)
#         boxes = [pickle.loads(b) for b in boxes]

#         cur_out, pred_grids, pred_intracell_classes = network(images_var)

#         raw_loss = raw_criterion(cur_out, mask_var, loss_weights_var)
#         gold_intracell_classes = make_structured_gold_batch(boxes, pred_grids)
#         if gold_intracell_classes.shape[0] > 0:
#             gold_intracell_classes = npvar(gold_intracell_classes,
#                                            pred_intracell_classes.data.is_cuda)
#             structured_loss = structured_criterion(pred_intracell_classes,
#                                                    gold_intracell_classes)
#         else:
#             structured_loss = npvar(numpy.zeros(1).astype('float32'), raw_loss.data.is_cuda)
#         full_loss = raw_loss.contiguous() + structured_loss.contiguous()

#         if optimizer:
#             optimizer.zero_grad()
#             full_loss.backward()
#             nn.utils.clip_grad_norm(network.parameters(), 10)
#             optimizer.step()

#         cur_metrics = {'raw_loss' : raw_loss.data[0],
#                        'structured_loss' : structured_loss.data[0],
#                        'full_loss' : full_loss.data[0] }
#         for name, (func, elem_names) in metrics.items():
#             metric_value = func(cur_out, mask_var, boxes)
#             if not isinstance(metric_value, (list, numpy.ndarray)):
#                 metric_value = metric_value.cpu().data.numpy()
#             cur_metrics.update(('\n'.join((name, n)), v) for n, v in zip(elem_names, metric_value) if n)
#         metric_values.append(cur_metrics)
#     return metric_values

Structured Two-step


In [24]:
def pad_and_concat(tensors):
    max_height = max(t.size()[2] for t in tensors)
    max_width = max(t.size()[3] for t in tensors)
    return torch.cat([F.pad(t, (0, max_width-t.size()[3], 0, max_height-t.size()[2]))
                      for t in tensors])


SIMPLE_CELL_NEIGHBORHOODS = {'left', 'right', 'upper', 'lower'}
SIMPLE_CELL_NEIGHBORHOODS_SUBSAMPLE = 1.0
def structured_two_step_datagen(orig_datagen, model=None, cuda=True, batch_size=32, tamper_channels=[]):
    batch_input = []
    batch_output = []
    out_i = 0
    for image_ids, in_img, mask, loss_weights, boxes_aug in orig_datagen:
        if model:
            mask = model(npvar(in_img, cuda))
            if isinstance(mask, tuple):
                mask = mask[0]
            mask = mask.data
        mask = mask.cpu().numpy()
        in_img_np = in_img.cpu().numpy()
        for table_in_img, table_mask, gold_boxes in zip(in_img_np, mask, boxes_aug):
            gold_boxes = pickle.loads(gold_boxes)
            body, cells, grid, cell_idx, rels = table_grid_from_intracell_mask(table_mask, input_image=table_in_img[0])
            intracell_relation_keys = [k for k, neigh_type in rels.items()
                                       if (not neigh_type in SIMPLE_CELL_NEIGHBORHOODS)
                                           or (random.random() <= SIMPLE_CELL_NEIGHBORHOODS_SUBSAMPLE)]

            in_i = 0
            while in_i < len(intracell_relation_keys):
                step = min(batch_size - out_i, len(intracell_relation_keys) - in_i)

                batch_relation_keys = intracell_relation_keys[in_i:in_i+step]
                batch_relations = [(cells[i1], cells[i2], rels[(i1, i2)])
                                   for (i1, i2) in batch_relation_keys]
                batch_input.append(CellRelConvFCNClassifier.make_batch(npvar(table_in_img, cuda),
                                                                       npvar(table_mask, cuda).unsqueeze(0),
                                                                       [batch_relations]))
                gold_cls = make_structured_gold_single_image(gold_boxes,
                                                             (body, cells, grid, cell_idx,
                                                              { k : rels[k] for k in batch_relation_keys }),
                                                             batch_relation_keys)
                batch_output.append(npvar(gold_cls, cuda))
                in_i += step
                out_i += step

                if out_i >= batch_size:
                    out_input = pad_and_concat(batch_input)
                    if tamper_channels:
                        out_input[:, tamper_channels] = 0
                    yield out_input, torch.cat(batch_output)
                    batch_input = []
                    batch_output = []
                    out_i = 0


def run_cell_network(network, generator, num_batches, criterion=F.cross_entropy, optimizer=None, metrics=CELL_TRAIN_METRICS):
    metric_values = []
    gen_iter = iter(generator)
    for _ in tqdm.tqdm(range(num_batches)):
        inp, gold = next(gen_iter)

        pred = network(inp)

        loss = criterion(pred, gold)

        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm(network.parameters(), 10)
            optimizer.step()

        cur_metrics = { 'loss' : loss.data[0] }
        for name, (func, elem_names) in metrics.items():
            metric_value = func(pred, gold)
            if not isinstance(metric_value, (list, numpy.ndarray)):
                metric_value = metric_value.cpu().data.numpy()
            cur_metrics.update(('\n'.join((name, n)), v) for n, v in zip(elem_names, metric_value) if n)
        metric_values.append(cur_metrics)
    return metric_values

Postprocess using superpixels and RNN


In [25]:
def remove_too_close_points_inner(lst, th=0.8):
    result = []
    last_group = [lst[0]]
    for cur in lst[1:]:
        if cur - last_group[-1] < th:
            last_group.append(cur)
        else:
            
            if last_group:
                result.append(last_group)
            last_group = [cur]
    if last_group:
        result.append(last_group)
#     return [int(round(numpy.mean(group))) for group in result]
    return [elem
            for group in result
            for elem in ([group[0], group[-1]] if len(group) > 1 else group)]


def remove_too_close_points(lst, th=1.1, max_iter=2):
    if not lst:
        return lst
    new_lst = lst
    for _ in range(max_iter):
        new_lst = remove_too_close_points_inner(lst, th=th)
        if len(new_lst) == len(lst):
            break
        lst = new_lst
    return new_lst


def make_superpixels_grid(mask, boxes_th=1.01):
    ys = [0]
    xs = [0]
    for (y1, x1, y2, x2) in get_all_boxes(mask):
        ys.append(y1)
        ys.append(y2)
        xs.append(x1)
        xs.append(x2)
    ys.append(mask.shape[0])
    xs.append(mask.shape[1])
    ys.sort()
    xs.sort()
    ys = remove_too_close_points(ys, th=boxes_th)
    xs = remove_too_close_points(xs, th=boxes_th)
    return ys, xs


# ML_SP_LEVELS = [0.1, 0.3, 0.6, 0.8, 1.0]
ML_SP_LEVELS = [0.1, 0.3, 1.0]
def make_multilevel_superpixels_grid(mask, levels=ML_SP_LEVELS, boxes_th=1.01):
    ys = { 0, mask.shape[0] }
    xs = { 0, mask.shape[1] }
    for i in range(len(levels) - 1):
        level_mask = numpy.where((mask > levels[i]) & (mask <= levels[i+1]), 1.0, 0.0)
        level_boxes = get_all_boxes(level_mask)
        for y1, x1, y2, x2 in level_boxes:
            ys.add(y1)
            ys.add(y2)
            xs.add(x1)
            xs.add(x2)
    ys = sorted(ys)
    xs = sorted(xs)
    ys = remove_too_close_points(ys, th=boxes_th)
    xs = remove_too_close_points(xs, th=boxes_th)
#     display(make_demo_mask(arr_to_img(mask),
#                            [(1, b) for b in make_boxes_from_grid(ys, xs)]))
    return ys, xs


def make_boxes_from_grid(ys, xs):
    return [(ys[i], xs[j], ys[i+1], xs[j+1])
            for i in range(len(ys) - 1)
            for j in range(len(xs) - 1)]


ML_SP_SRC_IMG_LEVELS = [0.2, 1.0]
def make_multilevel_superpixels_from_img_and_mask(src_image, mask, src_img_levels=ML_SP_SRC_IMG_LEVELS, mask_levels=ML_SP_LEVELS, boxes_th=1.01):
    img_ys, img_xs = make_multilevel_superpixels_grid(1.0 - src_image,
                                                      levels=src_img_levels,
                                                      boxes_th=boxes_th)
    mask_ys, mask_xs = make_multilevel_superpixels_grid(mask,
                                                        levels=mask_levels,
                                                        boxes_th=boxes_th)
    ys = sorted(set(img_ys) | set(mask_ys))
    xs = sorted(set(img_xs) | set(mask_xs))
    ys = remove_too_close_points(ys, th=boxes_th)
    xs = remove_too_close_points(xs, th=boxes_th)
    display(make_demo_mask(arr_to_img(test_segm_mask),
                           [(1, b) for b in make_boxes_from_grid(ys, xs)]))
    return ys, xs


def make_superpixels(mask, boxes_th=1.01):
    grid = make_superpixels_grid(mask, boxes_th=boxes_th)
    return make_boxes_from_grid(*grid)


def make_sp_pp_fex_batch(src_images, predicted_masks, boxes_by_image):
    total_samples = sum(len(img_boxes) for img_boxes in boxes_by_image)
    max_height = max(y2 - y1 + 1
                     for img_boxes in boxes_by_image
                     for y1, _, y2, _ in img_boxes)
    max_width = max(x2 - x1 + 1
                    for img_boxes in boxes_by_image
                    for _, x1, _, x2 in img_boxes)
    batch_input = numpy.zeros((total_samples,
                               predicted_masks[0].size(0) + 1,
                               max_height,
                               max_width),
                              dtype='float32')
    sample_i = 0
    for img_i, img_boxes in enumerate(boxes_by_image):
        for y1, x1, y2, x2 in img_boxes:
            h = y2 - y1 + 1
            w = x2 - x1 + 1
            batch_input[sample_i, 0, :h, :w] = src_images[img_i][0, y1:y2+1, x1:x2+1]
            batch_input[sample_i, 1:, :h, :w] = predicted_masks[img_i][:, y1:y2+1, x1:x2+1]
    return batch_input


# class SuperpixelPostprocessor(nn.Module):
#     def __init__(self, feature_extractor, sp_classifier, fex_batch_size=64):
#         super(SuperpixelPostprocessor, self).__init__()
#         self.feature_extractor = feature_extractor
#         self.sp_classifier = sp_classifier
#         self.fex_batch_size = fex_batch_size

#     def forward(self, inp):
#         src_images, predicted_masks, grids = inp
#         sp_input = []
#         for img_i in range(len(src_images)):
#             cur_img = src_images[img_i]
#             cur_mask = predicted_masks[img_i]
#             grid_ys, grid_xs = grids[img_i]
#             superpixels = make_boxes_from_grid(grid_ys, grid_xs)

#             cur_img_features = []
#             for fex_batch_start in range(0, len(superpixels), self.fex_batch_size):
#                 batch_superpixels = superpixels[fex_batch_start:fex_batch_start+self.fex_batch_size]
#                 fex_input = npvar(SuperpixelPostprocessor.make_fex_batch(src_images[img_i:img_i+1],
#                                                                          predicted_masks[img_i:img_i+1],
#                                                                          [batch_superpixels]),
#                                   src_images.data.is_cuda)
#                 cur_img_features.append(self.feature_extractor(fex_input))
#             cur_img_features = torch.cat(cur_img_features, 0).view(1,
#                                                                    len(grid_ys) - 1,
#                                                                    len(grid_xs) - 1,
#                                                                    -1)
#             sp_input.append(cur_img_features.transpose(0, 3, 1, 2))
#         sp_input = torch.cat(sp_input, 0)
#         return self.sp_classifier(s)


def sp_pp_fex_agg_maxpool(t):
    return F.max_pool2d(t, (t.size(2), t.size(3))).squeeze(3).squeeze(2)


def sp_pp_fex_agg_avgpool(t):
    return F.avg_pool2d(t, (t.size(2), t.size(3))).squeeze(3).squeeze(2)


def dummy_sp_pp_fex(t):
    return t


def dummy_sp_pp_cls(t):
    return t[:, 2:3]


def dummy_sp_pp_cls_with_th(t, th=0.5):
    return (t[:, 2:3] >= th).float()
    


class SuperpixelPostprocessor(nn.Module):
    def __init__(self, feature_extractor, sp_classifier, fex_agg=sp_pp_fex_agg_avgpool):
        super(SuperpixelPostprocessor, self).__init__()
        self.feature_extractor = feature_extractor
        self.sp_classifier = sp_classifier
        self.fex_agg = fex_agg

    def forward(self, inp):
        src_images, predicted_masks, grids = inp
        fex_input = torch.cat([src_images, predicted_masks], 1)
        features = self.feature_extractor(fex_input)
#         display(arr_to_img(mask_to_img(features.data.cpu().numpy()[0])))

        sp_input = []
        for img_i in range(len(src_images)):
            cur_mask = predicted_masks[img_i]

            grid_ys, grid_xs = grids[img_i]

            cur_sp_feats = [self.fex_agg(features[img_i:img_i+1,
                                                  :,
                                                  grid_ys[row]:grid_ys[row+1]+1, 
                                                  grid_xs[col]:grid_xs[col+1]+1])
                            for row in range(len(grid_ys) - 1)
                            for col in range(len(grid_xs) - 1)]
            cur_sp_feats = torch.cat(cur_sp_feats, 0).view(len(grid_ys) - 1, len(grid_xs) - 1, -1).permute(2, 0, 1).unsqueeze(0)
            sp_input.append(cur_sp_feats)
        sp_input = pad_and_concat(sp_input)
        return self.sp_classifier(sp_input)


def superpixel_pp_datagen(orig_datagen, model=None, cuda=True, min_mean_for_pos=0.5):
    for image_ids, in_img, gold_mask, loss_weights, boxes_aug in orig_datagen:
        if model:
            pred_mask = model(npvar(in_img, cuda))
            if isinstance(pred_mask, tuple):
                pred_mask = pred_mask[0]
            pred_mask = pred_mask.data
        else:
            pred_mask = gold_mask
            

        gold_cells_np = gold_mask.cpu().numpy()
        pred_mask_np = pred_mask.cpu().numpy()
        in_img_np = in_img.cpu().numpy()

#         display(make_demo_mask(arr_to_img(pred_mask_np[0][1]),
#                                [(1, b) for b in make_superpixels(pred_mask_np[0][1])]))
        grids_by_img = [make_multilevel_superpixels_grid(table_mask[1]) for table_mask in pred_mask_np]

        max_y_superpixels = max(len(grid_ys) - 1 for grid_ys, _ in grids_by_img)
        max_x_superpixels = max(len(grid_xs) - 1 for _, grid_xs in grids_by_img)
        gold_pred = numpy.zeros((len(grids_by_img),
                                 max_y_superpixels,
                                 max_x_superpixels),
                                dtype='float32')
        for i, (grid_ys, grid_xs) in enumerate(grids_by_img):
            for row in range(len(grid_ys) - 1):
                for col in range(len(grid_xs) - 1):
                    mean_gold_label = gold_cells_np[i,
                                                    1,
                                                    grid_ys[row]:grid_ys[row+1]+1,
                                                    grid_xs[col]:grid_xs[col+1]+1].mean()
                    if mean_gold_label >= min_mean_for_pos:
                        gold_pred[i, row, col] = 1            
            
        yield (npvar(in_img, cuda),
               gold_mask,
               boxes_aug,
               mcuda(Variable(pred_mask), cuda),
               grids_by_img,
               npvar(gold_pred, cuda).unsqueeze(1))


def run_sp_pp_network(network, generator, num_batches, criterion=F.binary_cross_entropy, optimizer=None, metrics=SP_PP_METRICS):
    metric_values = []
    gen_iter = iter(generator)
    for _ in tqdm.tqdm(range(num_batches)):
        in_img, gold_mask, boxes_aug, pred_mask, grids, gold = next(gen_iter)
        boxes = [{ 0: pickle.loads(b)[1] } for b in boxes_aug]

        pred = network((in_img, pred_mask, grids))

        loss = criterion(pred, gold)

        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm(network.parameters(), 10)
            optimizer.step()

        cur_metrics = { 'loss' : loss.data[0] }
        for name, (func, elem_names) in metrics.items():
            metric_value = func(pred, gold,
                                gold_boxes=boxes,
                                grids=grids)
            if not isinstance(metric_value, (list, numpy.ndarray)):
                metric_value = metric_value.cpu().data.numpy()
            cur_metrics.update(('\n'.join((name, n)), v) for n, v in zip(elem_names, metric_value) if n)
        metric_values.append(cur_metrics)
    return metric_values

Train


In [26]:
def clean_pytorch_dataloadergen(gen_iter):
    try:
        gen_iter._shutdown_workers()
        for _ in range(100):
            gen_iter.index_queue.put(None)
#         while not gen_iter.data_queue.empty():
#             b = gen_iter.data_queue.get()
        for w in gen_iter.workers:
            w.terminate()
        del gen_iter
    except ex:
        print(ex)


def const_lr(epoch, value=1e-3):
    return value


def step_annealed_lr(epoch, start=1e-3, step_epochs=15, decay_factor=0.5):
    lr_factor = decay_factor ** (epoch // step_epochs)
    return start * lr_factor


def step_annealed_big_lr(epoch):
    return step_annealed_lr(epoch, start=1e-1, step_epochs=15, decay_factor=0.5)


def sin_annealed_lr(epoch, freq=0.6, start=1e-3, sin_amp=5e-4, amp_decay_factor=0.97, mean_decay_factor=0.94, min_lr=1e-5):
    base = (numpy.sin(epoch * freq) + 1) * sin_amp * (amp_decay_factor ** epoch)
    mean = start * (mean_decay_factor ** epoch)
    return numpy.clip(base + mean, min_lr, None)


lr_vis_x = numpy.arange(0, 60, 1)
pandas.DataFrame(data=dict(const=const_lr(lr_vis_x),
                           step=step_annealed_lr(lr_vis_x),
                           sin=sin_annealed_lr(lr_vis_x))).plot()


Out[26]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f5b480e2eb8>

Run

Detect tables


In [27]:
# # %%prun

# train_gen = None
# val_gen = None

# EPOCHS_NUM = 30
# BATCH_SIZE = 1
# PART_PER_EPOCH = 0.2
# BATCHES_PER_EPOCH_TRAIN = int(len(train_det_image_ids) * PART_PER_EPOCH // BATCH_SIZE)
# BATCHES_PER_EPOCH_VAL = int(len(val_det_image_ids) * PART_PER_EPOCH // BATCH_SIZE)


# # det_net = UNet(out_channels=TOTAL_DET_CLASSES,
# #                first_conv_channels=7,
# #                depth=5,
# #                enc_dilations=[1, 2, 4, 8, 16, 32, 64]).cuda()
# det_net = RowCollAttUNet(out_channels=TOTAL_DET_CLASSES,
#                          first_conv_channels=6,
#                          depth=3,
#                          enc_dilations=[1, 2, 4, 8, 16, 32],
#                          dec_dilations=[1, 2, 4, 8, 16, 32]).cuda()
# LOSS = dice_bce_loss
# # LOSS = dice_loss
# # LOSS = weighted_dice_loss

# LR_SCHEDULE = step_annealed_lr
# # LR_SCHEDULE = sin_annealed_lr

# det_augmenter = imgaug_pipeline
# # det_augmenter = fake_imgaug_pipeline

# print('total parameters', sum(numpy.product(p.size()) for p in det_net.parameters()))

# train_metrics = []
# val_metrics = []
# for epoch in range(EPOCHS_NUM):
#     try:
#         train_gen = iter(DataLoader(SegmDataset(train_det_image_ids,
#                                                 det_augmenter,
#                                                 prepare_det_batch),
#                                     batch_size=BATCH_SIZE,
#                                     shuffle=True,
#                                     num_workers=4))
#         val_gen = iter(DataLoader(SegmDataset(val_det_image_ids,
#                                               det_augmenter,
#                                               prepare_det_batch),
#                                   batch_size=BATCH_SIZE,
#                                   shuffle=True,
#                                   num_workers=2))
#         print('epoch', epoch)

#         lr = LR_SCHEDULE(epoch)
#         print('lr', lr)
#         optimizer = Adam(det_net.parameters(), lr=lr)

#         det_net.train()
#         cur_train_metrics = run_network(det_net, train_gen, BATCHES_PER_EPOCH_TRAIN,
#                                         criterion=LOSS,
#                                         metrics=TRAIN_DET_METRICS,
#                                         optimizer=optimizer)
#         train_metrics.extend(cur_train_metrics)
#         display(pandas.DataFrame(cur_train_metrics).describe().loc[['mean', 'std']])

#         det_net.eval()
#         cur_val_metrics = run_network(det_net, val_gen, BATCHES_PER_EPOCH_VAL,
#                                       criterion=LOSS,
#                                       metrics=VAL_DET_METRICS)
#         val_metrics.extend(cur_val_metrics)
#         display(pandas.DataFrame(cur_val_metrics).describe().loc[['mean', 'std']])
#     finally:
#         if train_gen:
#             clean_pytorch_dataloadergen(train_gen)
#         if val_gen:
#             clean_pytorch_dataloadergen(val_gen)

#         gc.collect()

Segment structure


In [28]:
# train_gen = None
# val_gen = None

# EPOCHS_NUM = 40
# TRAIN_BATCH_SIZE = 16
# VAL_BATCH_SIZE = 1
# PART_PER_EPOCH = 0.1
# BATCHES_PER_EPOCH_TRAIN = int(len(train_int_image_ids) * PART_PER_EPOCH // TRAIN_BATCH_SIZE)
# BATCHES_PER_EPOCH_VAL = int(len(val_int_image_ids) * PART_PER_EPOCH // VAL_BATCH_SIZE)


# ENC_DILATIONS = [
#     [1, 2],
#     [1, 2],
#     [1, 2],
#     [1, 2],
#     [1]
# ]

# DEC_DILATIONS = [
#     [1],
#     [1],
#     [1],
#     [1]
# ]

# OUT_DILATIONS = [
#     [1]
# ]

# # net = StackedUNet1().cuda()
# # net = DenseNet1().cuda()

# # best!!!
# net = UNet(first_conv_channels=6,
#            depth=4,
#            enc_dilations=[1, 2, 4, 8, 16, 32],
#            dec_dilations=[1]).cuda() 

# # net = UNetWithMaxPool(first_conv_channels=6,
# #                       depth=4,
# #                       enc_dilations=[1, 2, 4, 8, 16, 32],
# #                       dec_dilations=[1]).cuda()

# # net = UNetHVS(first_conv_channels=6,
# #               depth=4,
# #               enc_dilations=[1, 2, 4, 8, 16, 32],
# #               dec_dilations=[1]).cuda()

# # net = UNet(first_conv_channels=6,
# #            depth=4,
# #            enc_dilations=ENC_DILATIONS.__getitem__,
# #            dec_dilations=DEC_DILATIONS.__getitem__,
# #            out_dilations=OUT_DILATIONS.__getitem__).cuda()

# # net = UNet(first_conv_channels=5,
# #                     depth=4,
# #                     enc_dilations=[1],
# #                     dec_dilations=[1],
# #                     out_dilations=[1]).cuda()
# # net = SimpleDenseNet(first_conv_channels=7,
# #                      depth=3,
# #                      enc_dilations=[1, 2, 4, 8, 16, 32, 64]).cuda()
# # net = SimpleDenseNetWithMean(first_conv_channels=5,
# #                              depth=6,
# #                              enc_dilations=[1, 2, 4, 8, 16]).cuda()
# # net = UNetWithMean(first_conv_channels=4,
# #                    depth=4,
# #                    enc_dilations=[1, 2, 4, 8],
# #                    dec_dilations=[1, 2, 4, 8]).cuda()

# LOSS = dice_bce_loss
# # LOSS = dice_loss
# # LOSS = weighted_dice_loss

# LR_SCHEDULE = step_annealed_lr
# # LR_SCHEDULE = sin_annealed_lr

# int_augmenter = imgaug_pipeline
# # int_augmenter = fake_imgaug_pipeline

# print('total parameters', sum(numpy.product(p.size()) for p in net.parameters()))

# train_metrics = []
# val_metrics = []
# for epoch in range(EPOCHS_NUM):
#     try:
#         train_gen = iter(DataLoader(SegmDataset(train_int_image_ids,
#                                                 int_augmenter,
#                                                 prepare_int_batch),
#                                     batch_size=TRAIN_BATCH_SIZE,
#                                     shuffle=True,
#                                     num_workers=4))
#         val_gen = iter(DataLoader(SegmDataset(val_int_image_ids,
#                                               int_augmenter,
#                                               prepare_int_batch),
#                                   batch_size=VAL_BATCH_SIZE,
#                                   shuffle=True,
#                                   num_workers=2))
#         print('epoch', epoch)

#         lr = LR_SCHEDULE(epoch)
#         print('lr', lr)
#         optimizer = Adam(net.parameters(), lr=lr)

#         net.train()
#         cur_train_metrics = run_network(net, train_gen, BATCHES_PER_EPOCH_TRAIN,
#                                         criterion=LOSS,
#                                         metrics=TRAIN_INT_METRICS,
#                                         optimizer=optimizer,
#                                         regularizer=row_col_smooth_contrast_reg)
#         train_metrics.extend(cur_train_metrics)
#         display(pandas.DataFrame(cur_train_metrics).describe().loc[['mean', 'std']])

#         net.eval()
#         cur_val_metrics = run_network(net, val_gen, BATCHES_PER_EPOCH_VAL,
#                                       criterion=LOSS,
#                                       metrics=VAL_INT_METRICS,
#                                       regularizer=row_col_smooth_contrast_reg)
#         val_metrics.extend(cur_val_metrics)
#         display(pandas.DataFrame(cur_val_metrics).describe().loc[['mean', 'std']])
#     finally:
#         if train_gen:
#             clean_pytorch_dataloadergen(train_gen)
#         if val_gen:
#             clean_pytorch_dataloadergen(val_gen)

#         gc.collect()

Segment structure with extra info


In [29]:
train_gen = None
val_gen = None

EPOCHS_NUM = 40
TRAIN_BATCH_SIZE = 16
VAL_BATCH_SIZE = 1
PART_PER_EPOCH = 0.1
BATCHES_PER_EPOCH_TRAIN = int(len(train_int_image_ids) * PART_PER_EPOCH // TRAIN_BATCH_SIZE)
BATCHES_PER_EPOCH_VAL = int(len(val_int_image_ids) * PART_PER_EPOCH // VAL_BATCH_SIZE)


ENC_DILATIONS = [
    [1, 2],
    [1, 2],
    [1, 2],
    [1, 2],
    [1]
]

DEC_DILATIONS = [
    [1],
    [1],
    [1],
    [1]
]

OUT_DILATIONS = [
    [1]
]

# net = StackedUNet1().cuda()
# net = DenseNet1().cuda()

# best!!!
net = UNet(in_channels=2,
           first_conv_channels=6,
           depth=4,
           enc_dilations=[1, 2, 4, 8, 16, 32],
           dec_dilations=[1]).cuda() 

# net = UNetWithMaxPool(first_conv_channels=6,
#                       depth=4,
#                       enc_dilations=[1, 2, 4, 8, 16, 32],
#                       dec_dilations=[1]).cuda()

# net = UNetHVS(first_conv_channels=6,
#               depth=4,
#               enc_dilations=[1, 2, 4, 8, 16, 32],
#               dec_dilations=[1]).cuda()

# net = UNet(first_conv_channels=6,
#            depth=4,
#            enc_dilations=ENC_DILATIONS.__getitem__,
#            dec_dilations=DEC_DILATIONS.__getitem__,
#            out_dilations=OUT_DILATIONS.__getitem__).cuda()

# net = UNet(first_conv_channels=5,
#                     depth=4,
#                     enc_dilations=[1],
#                     dec_dilations=[1],
#                     out_dilations=[1]).cuda()
# net = SimpleDenseNet(first_conv_channels=7,
#                      depth=3,
#                      enc_dilations=[1, 2, 4, 8, 16, 32, 64]).cuda()
# net = SimpleDenseNetWithMean(first_conv_channels=5,
#                              depth=6,
#                              enc_dilations=[1, 2, 4, 8, 16]).cuda()
# net = UNetWithMean(first_conv_channels=4,
#                    depth=4,
#                    enc_dilations=[1, 2, 4, 8],
#                    dec_dilations=[1, 2, 4, 8]).cuda()

LOSS = dice_bce_loss
# LOSS = dice_loss
# LOSS = weighted_dice_loss

LR_SCHEDULE = step_annealed_lr
# LR_SCHEDULE = sin_annealed_lr

int_augmenter = imgaug_pipeline
# int_augmenter = fake_imgaug_pipeline

print('total parameters', sum(numpy.product(p.size()) for p in net.parameters()))

train_metrics = []
val_metrics = []
for epoch in range(EPOCHS_NUM):
    try:
        train_gen = iter(DataLoader(SegmExtDataset(train_int_image_ids,
                                                   int_augmenter,
                                                   prepare_int_ext_batch,
                                                   vocab),
                                    batch_size=TRAIN_BATCH_SIZE,
                                    shuffle=True,
                                    num_workers=4))
        qq = next(train_gen)
        val_gen = iter(DataLoader(SegmExtDataset(val_int_image_ids,
                                                 int_augmenter,
                                                 prepare_int_ext_batch,
                                                 vocab),
                                  batch_size=VAL_BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=2))
        print('epoch', epoch)

        lr = LR_SCHEDULE(epoch)
        print('lr', lr)
        optimizer = Adam(net.parameters(), lr=lr)

        net.train()
        cur_train_metrics = run_network_ext(net, train_gen, BATCHES_PER_EPOCH_TRAIN,
                                            criterion=LOSS,
                                            metrics=TRAIN_INT_METRICS,
                                            optimizer=optimizer,
#                                             regularizer=row_col_smooth_contrast_reg
                                           )
        train_metrics.extend(cur_train_metrics)
        display(pandas.DataFrame(cur_train_metrics).describe().loc[['mean', 'std']])

        net.eval()
        cur_val_metrics = run_network_ext(net, val_gen, BATCHES_PER_EPOCH_VAL,
                                          criterion=LOSS,
                                          metrics=VAL_INT_METRICS,
#                                           regularizer=row_col_smooth_contrast_reg
                                         )
        val_metrics.extend(cur_val_metrics)
        display(pandas.DataFrame(cur_val_metrics).describe().loc[['mean', 'std']])
    finally:
        if train_gen:
            clean_pytorch_dataloadergen(train_gen)
        if val_gen:
            clean_pytorch_dataloadergen(val_gen)

        gc.collect()


total parameters 136028
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 0
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.056901e-07 0.189660 4.524311e-07 4.114571e-07 3.905798
std 4.057957e-09 0.024889 1.157371e-09 2.605503e-09 0.127541
100%|██████████| 40/40 [00:03<00:00, 10.82it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 6.776977e-06 0.147116 7.158827e-06 6.290724e-06 4.182329 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
std 5.854985e-09 0.046519 3.924294e-09 7.227305e-10 0.287012 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 1
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.23it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.154608e-07 0.252438 4.592797e-07 4.117624e-07 3.592064
std 2.197272e-09 0.019839 2.741174e-09 1.624643e-09 0.085070
100%|██████████| 40/40 [00:03<00:00, 10.28it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 6.797402e-06 0.173674 7.245313e-06 6.390528e-06 4.012378 0.0 0.965076 0.0 0.0 0.0 0.766685 0.0 0.0
std 5.000844e-08 0.055986 2.170645e-08 4.983042e-09 0.306989 0.0 0.017582 0.0 0.0 0.0 0.038594 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 2
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.13it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.233518e-07 0.292192 4.651860e-07 4.246418e-07 3.385190
std 1.046931e-09 0.017101 2.994926e-09 4.882025e-09 0.073071
100%|██████████| 40/40 [00:03<00:00, 10.27it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 6.806449e-06 0.223400 7.358388e-06 6.660302e-06 3.724190 0.0 0.820546 0.0 0.0 0.0 0.991150 0.0 0.0
std 6.917138e-08 0.079845 1.466244e-08 2.525795e-08 0.387293 0.0 0.032899 0.0 0.0 0.0 0.004538 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 3
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.12it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.246457e-07 0.314825 4.737295e-07 4.441133e-07 3.237692
std 5.434927e-10 0.016364 2.090081e-09 6.275182e-09 0.076260
100%|██████████| 40/40 [00:03<00:00, 10.51it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 6.806978e-06 0.280529 7.550104e-06 7.157312e-06 3.424852 0.0 0.782846 0.0 0.0 0.0 0.994719 0.0 0.0
std 9.839478e-08 0.093569 9.813464e-08 9.840434e-08 0.361496 0.0 0.035278 0.0 0.0 0.0 0.011517 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 4
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.13it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.299966e-07 0.321230 4.786856e-07 4.755539e-07 3.144114
std 3.135448e-09 0.022883 1.689156e-09 1.332421e-08 0.080410
100%|██████████| 40/40 [00:04<00:00,  9.96it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 6.931040e-06 0.337385 7.717075e-06 7.993394e-06 3.135847 0.0 0.789889 0.0 0.0 0.0 0.997019 0.0 0.0
std 1.685753e-07 0.088528 1.458626e-07 1.701126e-07 0.296859 0.0 0.028815 0.0 0.0 0.0 0.001604 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 5
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.407959e-07 0.346935 4.831206e-07 5.268508e-07 2.997763
std 3.625473e-09 0.024864 1.218617e-09 1.556222e-08 0.077687
100%|██████████| 40/40 [00:04<00:00,  9.62it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 7.173713e-06 0.348588 7.768068e-06 9.088657e-06 3.005106 0.0 0.512988 0.0 0.0 0.0 0.992994 0.0 0.0
std 3.004655e-07 0.087946 1.542024e-07 2.643650e-07 0.357718 0.0 0.058309 0.0 0.0 0.0 0.003416 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 6
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.18it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.575021e-07 0.343521 4.878731e-07 6.096537e-07 2.930743
std 5.289791e-09 0.021118 3.259234e-09 3.472831e-08 0.081768
100%|██████████| 40/40 [00:04<00:00,  9.53it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 7.867448e-06 0.321398 8.055607e-06 1.111476e-05 2.951208 0.0 0.315743 0.0 0.0 0.0 0.987009 0.0 0.0
std 7.966028e-07 0.075210 2.897676e-07 6.235242e-07 0.328685 0.0 0.062186 0.0 0.0 0.0 0.004033 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 7
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.678431e-07 0.354548 5.033243e-07 7.553149e-07 2.820611
std 2.588499e-09 0.016504 5.396402e-09 5.139881e-08 0.040407
100%|██████████| 40/40 [00:04<00:00,  9.42it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 7.986119e-06 0.345471 8.384401e-06 1.371020e-05 2.799487 0.0 0.308395 0.0 0.0 0.0 0.987091 0.0 0.0
std 9.208194e-07 0.081570 3.740185e-07 5.964570e-07 0.365114 0.0 0.058670 0.0 0.0 0.0 0.009280 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 8
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.696878e-07 0.361405 5.302940e-07 9.718822e-07 2.727694
std 2.513685e-09 0.019269 9.045297e-09 6.403969e-08 0.063380
100%|██████████| 40/40 [00:03<00:00, 11.11it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000008 0.308644 8.800283e-06 1.790338e-05 2.919124 0.0 0.292643 0.0 0.0 0.0 0.981026 0.0 0.0
std 0.000001 0.085783 6.330175e-07 8.569461e-07 0.448301 0.0 0.048701 0.0 0.0 0.0 0.007787 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 9
lr 0.001
100%|██████████| 10/10 [00:09<00:00,  1.10it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.723513e-07 0.393554 5.645739e-07 1.216482e-06 2.567290
std 2.670330e-09 0.027081 9.522667e-09 7.823057e-08 0.084079
100%|██████████| 40/40 [00:03<00:00, 10.08it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 7.559051e-06 0.387017 9.209408e-06 0.000022 2.649665 0.0 0.333567 0.0 0.0 0.0 0.993655 0.0 0.0
std 8.585435e-07 0.078417 6.265015e-07 0.000001 0.361124 0.0 0.060072 0.0 0.0 0.0 0.002684 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 10
lr 0.001
100%|██████████| 10/10 [00:07<00:00,  1.28it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.781836e-07 0.394806 5.962996e-07 1.502126e-06 2.530196
std 3.485911e-09 0.023124 9.592594e-09 9.922580e-08 0.075815
100%|██████████| 40/40 [00:03<00:00, 10.20it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 7.779767e-06 0.396890 9.788359e-06 2.701040e-05 2.520801 0.0 0.333075 0.0 0.0 0.0 0.978330 0.0 0.0
std 8.193675e-07 0.068999 7.159282e-07 8.718568e-07 0.309079 0.0 0.067545 0.0 0.0 0.0 0.006148 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 11
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.17it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 4.965779e-07 0.420561 6.329131e-07 1.835616e-06 2.394897
std 7.890650e-09 0.020641 1.135715e-08 1.076554e-07 0.078355
100%|██████████| 40/40 [00:03<00:00, 10.92it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000008 0.396272 0.000010 0.000035 2.528565 0.0 0.324123 0.0 0.0 0.0 0.995379 0.0 0.0
std 0.000001 0.066415 0.000001 0.000001 0.316243 0.0 0.074986 0.0 0.0 0.0 0.002211 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 12
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.14it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 5.272059e-07 0.442754 6.663432e-07 2.204662e-06 2.284052
std 1.120027e-08 0.017337 1.001487e-08 1.294353e-07 0.058011
100%|██████████| 40/40 [00:03<00:00, 11.44it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000009 0.406362 0.000011 0.000042 2.427863 0.0 0.321298 0.0 0.0 0.0 0.974148 0.0 0.0
std 0.000001 0.058815 0.000001 0.000002 0.271971 0.0 0.075108 0.0 0.0 0.0 0.009732 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 13
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 5.637194e-07 0.445014 6.969446e-07 2.664660e-06 2.228165
std 1.224645e-08 0.017822 1.231571e-08 2.083687e-07 0.058812
100%|██████████| 40/40 [00:03<00:00, 10.98it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000009 0.449537 0.000011 0.000045 2.312730 0.0 0.333375 0.0 0.0 0.0 0.991989 0.0 0.0
std 0.000001 0.063453 0.000001 0.000004 0.350057 0.0 0.069113 0.0 0.0 0.0 0.003696 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 14
lr 0.001
100%|██████████| 10/10 [00:08<00:00,  1.21it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 6.074658e-07 0.467663 7.326023e-07 3.157316e-06 2.115110
std 1.826640e-08 0.012100 1.620846e-08 1.895829e-07 0.049588
100%|██████████| 40/40 [00:04<00:00,  9.92it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000011 0.440490 0.000012 0.000064 2.217857 0.0 0.324599 0.0 0.0 0.0 0.973709 0.0 0.0
std 0.000002 0.081561 0.000002 0.000004 0.360991 0.0 0.086102 0.0 0.0 0.0 0.022678 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 15
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 6.436329e-07 0.469534 7.580789e-07 3.686112e-06 2.075336
std 1.101252e-08 0.024338 1.072071e-08 2.071961e-07 0.078159
100%|██████████| 40/40 [00:03<00:00, 10.13it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000011 0.462069 0.000013 0.000062 2.124047 0.0 0.318995 0.0 0.0 0.0 0.980194 0.0 0.0
std 0.000002 0.071356 0.000002 0.000007 0.372804 0.0 0.072028 0.0 0.0 0.0 0.010592 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 16
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.15it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 6.704073e-07 0.470069 7.778506e-07 4.044111e-06 2.046208
std 1.439625e-08 0.028854 1.597843e-08 1.551180e-07 0.117343
100%|██████████| 40/40 [00:03<00:00, 10.37it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000011 0.472956 0.000013 0.000066 2.086362 0.0 0.314791 0.0 0.0 0.0 0.989034 0.0 0.0
std 0.000001 0.074405 0.000001 0.000006 0.296834 0.0 0.077663 0.0 0.0 0.0 0.012916 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 17
lr 0.0005
100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 7.013876e-07 0.484372 7.991458e-07 4.481882e-06 1.981633
std 1.116219e-08 0.024037 1.085147e-08 2.200686e-07 0.089767
100%|██████████| 40/40 [00:03<00:00, 10.36it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000012 0.486938 0.000013 0.000070 1.987426 0.0 0.310268 0.0 0.0 0.0 0.991668 0.0 0.0
std 0.000001 0.065029 0.000001 0.000009 0.295425 0.0 0.066796 0.0 0.0 0.0 0.003315 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 18
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.21it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 7.352651e-07 0.512494 8.222158e-07 4.898392e-06 1.881299
std 1.005457e-08 0.015508 1.183388e-08 2.865118e-07 0.062425
100%|██████████| 40/40 [00:03<00:00, 10.54it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000012 0.506746 0.000013 0.000084 1.955938 0.0 0.329113 0.0 0.0 0.0 0.988106 0.0 0.0
std 0.000001 0.057350 0.000001 0.000009 0.279165 0.0 0.059826 0.0 0.0 0.0 0.021439 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 19
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 7.675077e-07 0.508328 8.419666e-07 5.264745e-06 1.867416
std 1.827367e-08 0.015662 1.764448e-08 4.638427e-07 0.072337
100%|██████████| 40/40 [00:03<00:00, 10.35it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000013 0.508288 0.000014 0.000103 1.914061 0.0 0.340286 0.0 0.0 0.0 0.987672 0.0 0.0
std 0.000002 0.064047 0.000002 0.000010 0.319403 0.0 0.061229 0.0 0.0 0.0 0.020597 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 20
lr 0.0005
100%|██████████| 10/10 [00:07<00:00,  1.26it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 8.021081e-07 0.517115 8.579039e-07 5.983006e-06 1.832522
std 2.860042e-08 0.020202 2.327583e-08 5.076560e-07 0.089254
100%|██████████| 40/40 [00:03<00:00, 10.71it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000014 0.490931 0.000014 0.000109 1.929016 0.0 0.305960 0.0 0.0 0.0 0.992342 0.0 0.0
std 0.000002 0.075406 0.000002 0.000012 0.319662 0.0 0.077085 0.0 0.0 0.0 0.003423 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 21
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 8.485857e-07 0.535772 8.887929e-07 6.716074e-06 1.751647
std 1.417976e-08 0.013363 1.305302e-08 4.161513e-07 0.052792
100%|██████████| 40/40 [00:03<00:00, 10.78it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000015 0.518763 0.000015 0.000129 1.854332 0.0 0.330046 0.0 0.0 0.0 0.986644 0.0 0.0
std 0.000002 0.066386 0.000002 0.000015 0.315699 0.0 0.067962 0.0 0.0 0.0 0.046382 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 22
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.24it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 8.988887e-07 0.536484 9.150536e-07 7.762128e-06 1.726291
std 1.249380e-08 0.017175 1.079329e-08 4.067237e-07 0.053727
100%|██████████| 40/40 [00:03<00:00, 10.03it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000015 0.559976 0.000015 0.000141 1.716911 0.0 0.352083 0.0 0.0 0.0 0.979360 0.0 0.0
std 0.000001 0.060852 0.000002 0.000016 0.271989 0.0 0.069479 0.0 0.0 0.0 0.044414 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 23
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 9.482100e-07 0.555734 9.429784e-07 8.640225e-06 1.656072
std 3.238485e-08 0.020435 2.685593e-08 3.668440e-07 0.088438
100%|██████████| 40/40 [00:03<00:00, 10.15it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000017 0.496053 0.000017 0.000146 1.738916 0.0 0.302521 0.0 0.0 0.0 0.994840 0.0 0.0
std 0.000002 0.065417 0.000002 0.000008 0.274936 0.0 0.062216 0.0 0.0 0.0 0.002509 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 24
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.23it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.013733e-06 0.572098 9.760464e-07 1.008319e-05 1.592463
std 2.237418e-08 0.020028 1.550510e-08 5.685652e-07 0.065901
100%|██████████| 40/40 [00:03<00:00, 10.98it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000017 0.591408 0.000016 0.000180 1.645308 0.0 0.362868 0.0 0.0 0.0 0.990518 0.0 0.0
std 0.000001 0.074977 0.000002 0.000021 0.295965 0.0 0.086541 0.0 0.0 0.0 0.006390 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 25
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.100349e-06 0.583318 1.019814e-06 1.164305e-05 1.524867
std 2.403200e-08 0.019824 1.518060e-08 5.581849e-07 0.057432
100%|██████████| 40/40 [00:03<00:00, 10.07it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000019 0.574778 0.000017 0.000212 1.573695 0.0 0.339245 0.0 0.0 0.0 0.992866 0.0 0.0
std 0.000002 0.062442 0.000002 0.000024 0.258720 0.0 0.066204 0.0 0.0 0.0 0.012912 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 26
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.17it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.162765e-06 0.596255 1.045548e-06 1.317354e-05 1.492137
std 3.816768e-08 0.017220 3.048885e-08 6.876149e-07 0.077274
100%|██████████| 40/40 [00:03<00:00, 10.84it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000020 0.596355 0.000017 0.000238 1.534820 0.0 0.436268 0.0 0.0 0.0 0.985670 0.0 0.0
std 0.000001 0.068184 0.000002 0.000021 0.238188 0.0 0.085546 0.0 0.0 0.0 0.009845 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 27
lr 0.0005
100%|██████████| 10/10 [00:07<00:00,  1.29it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.272817e-06 0.606993 1.106150e-06 1.524626e-05 1.418599
std 4.560539e-08 0.020047 3.938241e-08 9.655695e-07 0.092408
100%|██████████| 40/40 [00:03<00:00, 10.37it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000022 0.617299 0.000018 0.000294 1.465348 0.0 0.522851 0.0 0.0 0.0 0.99508 0.0 0.0
std 0.000002 0.064092 0.000002 0.000035 0.241676 0.0 0.083930 0.0 0.0 0.0 0.00403 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 28
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.19it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.369895e-06 0.622623 1.147918e-06 1.812149e-05 1.372976
std 3.627153e-08 0.019790 2.770903e-08 9.551123e-07 0.079096
100%|██████████| 40/40 [00:04<00:00,  9.36it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000023 0.656064 0.000019 0.000315 1.352967 0.0 0.977001 0.0 0.0 0.0 0.989433 0.0 0.0
std 0.000001 0.070508 0.000002 0.000045 0.281273 0.0 0.025069 0.0 0.0 0.0 0.013052 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 29
lr 0.0005
100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.500074e-06 0.636635 1.230565e-06 0.000022 1.301772
std 4.181190e-08 0.020611 2.695326e-08 0.000002 0.055873
100%|██████████| 40/40 [00:03<00:00, 10.89it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000026 0.635561 0.000021 0.000419 1.352153 0.0 0.919176 0.0 0.0 0.0 0.983354 0.0 0.0
std 0.000002 0.068684 0.000002 0.000059 0.281215 0.0 0.024224 0.0 0.0 0.0 0.012470 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 30
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.15it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.577746e-06 0.649867 1.280698e-06 2.449065e-05 1.261581
std 5.126670e-08 0.017249 4.803407e-08 9.914197e-07 0.068829
100%|██████████| 40/40 [00:03<00:00, 10.19it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000027 0.643658 0.000021 0.000438 1.309131 0.0 0.985446 0.0 0.0 0.0 0.992116 0.0 0.0
std 0.000002 0.076818 0.000003 0.000066 0.281658 0.0 0.023572 0.0 0.0 0.0 0.002771 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 31
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.659675e-06 0.658592 1.327976e-06 2.746361e-05 1.224913
std 2.853164e-08 0.009596 1.978852e-08 9.353681e-07 0.033757
100%|██████████| 40/40 [00:03<00:00, 11.03it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000027 0.661731 0.000021 0.000498 1.316361 0.0 0.983167 0.0 0.0 0.0 0.989353 0.0 0.0
std 0.000002 0.069066 0.000002 0.000067 0.279384 0.0 0.019594 0.0 0.0 0.0 0.038384 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 32
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.732516e-06 0.676819 1.369803e-06 2.965649e-05 1.174078
std 3.019353e-08 0.016651 2.782046e-08 7.580107e-07 0.059281
100%|██████████| 40/40 [00:03<00:00, 10.77it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000028 0.669680 0.000022 0.000544 1.266303 0.0 0.983729 0.0 0.0 0.0 0.983225 0.0 0.0
std 0.000002 0.065999 0.000002 0.000080 0.211907 0.0 0.026099 0.0 0.0 0.0 0.045972 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 33
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.21it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.799888e-06 0.681146 1.414841e-06 0.000032 1.155716
std 4.865850e-08 0.019984 4.611424e-08 0.000001 0.072840
100%|██████████| 40/40 [00:03<00:00, 10.31it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000029 0.682802 0.000023 0.000563 1.243143 0.0 0.983709 0.0 0.0 0.0 0.995291 0.0 0.0
std 0.000002 0.060236 0.000003 0.000093 0.284146 0.0 0.022795 0.0 0.0 0.0 0.002656 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 34
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.862290e-06 0.686934 1.458404e-06 0.000035 1.137857
std 7.735195e-08 0.019218 6.145749e-08 0.000002 0.095763
100%|██████████| 40/40 [00:03<00:00, 10.28it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000031 0.699686 0.000024 0.000610 1.167262 0.0 0.984951 0.0 0.0 0.0 0.991993 0.0 0.0
std 0.000002 0.058780 0.000002 0.000092 0.231188 0.0 0.021716 0.0 0.0 0.0 0.008972 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 35
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.16it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 1.965175e-06 0.697720 1.517793e-06 0.000038 1.098333
std 6.293186e-08 0.014752 5.631265e-08 0.000001 0.068941
100%|██████████| 40/40 [00:03<00:00, 11.32it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000032 0.705084 0.000024 0.000730 1.204058 0.0 0.987946 0.0 0.0 0.0 0.984649 0.0 0.0
std 0.000002 0.070402 0.000002 0.000125 0.242602 0.0 0.017997 0.0 0.0 0.0 0.034994 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 36
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.19it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 2.042392e-06 0.704541 1.561376e-06 0.000042 1.082145
std 6.496223e-08 0.013167 4.965266e-08 0.000002 0.062781
100%|██████████| 40/40 [00:03<00:00, 11.56it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000033 0.700990 0.000025 0.000812 1.231784 0.0 0.986605 0.0 0.0 0.0 0.980493 0.0 0.0
std 0.000002 0.074356 0.000003 0.000149 0.291156 0.0 0.018601 0.0 0.0 0.0 0.069881 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 37
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.14it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 2.150527e-06 0.726345 1.632555e-06 0.000045 1.017095
std 7.344926e-08 0.012793 6.154299e-08 0.000002 0.063205
100%|██████████| 40/40 [00:04<00:00,  9.58it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000036 0.717228 0.000027 0.000856 1.090806 0.0 0.984360 0.0 0.0 0.0 0.982790 0.0 0.0
std 0.000003 0.079988 0.000003 0.000143 0.239776 0.0 0.023867 0.0 0.0 0.0 0.045458 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 38
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 2.245151e-06 0.724859 1.683035e-06 0.000050 1.020063
std 5.269124e-08 0.015734 4.235574e-08 0.000002 0.045483
100%|██████████| 40/40 [00:03<00:00, 10.31it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000037 0.714467 0.000028 0.000913 1.085198 0.0 0.993840 0.0 0.0 0.0 0.990413 0.0 0.0
std 0.000003 0.059678 0.000004 0.000170 0.218067 0.0 0.011644 0.0 0.0 0.0 0.003608 0.0 0.0
  0%|          | 0/10 [00:00<?, ?it/s]
epoch 39
lr 0.00025
100%|██████████| 10/10 [00:08<00:00,  1.15it/s]
d body d cell d same_col_other_row d same_row_other_col loss
mean 2.391169e-06 0.744293 1.780677e-06 0.000054 0.951571
std 6.446971e-08 0.012642 5.489111e-08 0.000002 0.054607
100%|██████████| 40/40 [00:03<00:00, 10.67it/s]
d body d cell d same_col_other_row d same_row_other_col loss px_p body px_p cell px_p same_col_other_row px_p same_row_other_col px_r body px_r cell px_r same_col_other_row px_r same_row_other_col
mean 0.000040 0.735458 0.000029 0.001032 1.041447 0.0 0.987958 0.0 0.0 0.0 0.993013 0.0 0.0
std 0.000002 0.063813 0.000003 0.000193 0.216088 0.0 0.018660 0.0 0.0 0.0 0.007584 0.0 0.0

Structured End-to-End


In [30]:
# train_gen = None
# val_gen = None

# EPOCHS_NUM = 30
# BATCH_SIZE = 1
# PART_PER_EPOCH = 0.5
# BATCHES_PER_EPOCH_TRAIN = int(len(train_image_ids) * PART_PER_EPOCH // BATCH_SIZE)
# BATCHES_PER_EPOCH_VAL = int(len(val_image_ids) * PART_PER_EPOCH // BATCH_SIZE)

# net = TableSegmenter().cuda()
# RAW_LOSS = dice_loss
# print('total parameters', sum(numpy.product(p.size()) for p in net.parameters()))

# train_metrics = []
# val_metrics = []
# for epoch in range(EPOCHS_NUM):
#     try:
#         train_gen = iter(DataLoader(SegmDataset(train_image_ids, imgaug_pipeline, prepare_int_batch),
#                                     batch_size=BATCH_SIZE,
#                                     shuffle=True,
#                                     num_workers=4))
#         val_gen = iter(DataLoader(SegmDataset(val_image_ids, imgaug_pipeline, prepare_int_batch),
#                                   batch_size=BATCH_SIZE,
#                                   shuffle=True,
#                                   num_workers=2))
#         print('epoch', epoch)

#         lr_factor = 0.5 ** (epoch // 20)
#         lr = 1e-3 * lr_factor
#         print('lr', lr)
#         optimizer = Adam(net.parameters(), lr=lr)

#         net.train()
#         cur_train_metrics = run_network_structured(net, train_gen,
#                                                    BATCHES_PER_EPOCH_TRAIN,
#                                                    raw_criterion=RAW_LOSS,
#                                                    metrics=TRAIN_METRICS,
#                                                    optimizer=optimizer)
#         train_metrics.extend(cur_train_metrics)
#         display(pandas.DataFrame(cur_train_metrics).describe().loc[['mean', 'std']])

#         net.eval()
#         cur_val_metrics = run_network_structured(net, val_gen,
#                                                  BATCHES_PER_EPOCH_VAL,
#                                                  raw_criterion=RAW_LOSS,
#                                                  metrics=VAL_METRICS)
#         val_metrics.extend(cur_val_metrics)
#         display(pandas.DataFrame(cur_val_metrics).describe().loc[['mean', 'std']])
#     finally:
#         if train_gen:
#             clean_pytorch_dataloadergen(train_gen)
#         if val_gen:
#             clean_pytorch_dataloadergen(val_gen)

#         gc.collect()

Structured Two-step


In [31]:
# base_segm_model = torch.load('./models/segm_unet1_step_aug_full2_rc_sc_reg').cuda()

In [32]:
# base_train_gen = None
# base_val_gen = None

# EPOCHS_NUM = 30
# BASE_BATCH_SIZE = 4
# BATCH_SIZE = 10
# PART_PER_EPOCH_TRAIN = 0.05
# PART_PER_EPOCH_VAL = 1.0
# MEAN_RELS_PER_IMAGE = 2
# BATCHES_PER_EPOCH_TRAIN = int(len(train_int_image_ids) * PART_PER_EPOCH_TRAIN // BASE_BATCH_SIZE) * MEAN_RELS_PER_IMAGE
# BATCHES_PER_EPOCH_VAL = int(len(val_int_image_ids) * PART_PER_EPOCH_VAL // BASE_BATCH_SIZE) * MEAN_RELS_PER_IMAGE

# cell_net = ConvFCNClassifier(6, 5,
#                              conv_layers=3,
#                              dilations=[1, 2, 4, 8, 16, 32]).cuda()
# print('receptive field', cell_net.receptive_field)
# CELL_REL_LOSS = F.cross_entropy
# print('total parameters', sum(numpy.product(p.size()) for p in cell_net.parameters()))

# # cell_tamper_channels = [0, 2, 3]
# cell_tamper_channels = []

# # cell_augmenter = imgaug_pipeline
# cell_augmenter = fake_imgaug_pipeline

# CELL_LR_SCHEDULE = step_annealed_lr

# train_metrics = []
# val_metrics = []
# for epoch in range(EPOCHS_NUM):
#     try:
#         base_train_gen = iter(DataLoader(SegmDataset(train_int_image_ids,
#                                                      cell_augmenter,
#                                                      prepare_int_batch),
#                                          batch_size=BASE_BATCH_SIZE,
#                                          shuffle=True,
#                                          num_workers=4))
#         train_gen = structured_two_step_datagen(base_train_gen,
#                                                 model=base_segm_model,
#                                                 cuda=True,
#                                                 batch_size=BATCH_SIZE,
#                                                 tamper_channels=cell_tamper_channels)
#         qq = next(iter(train_gen))[0].data.cpu().numpy()
#         base_val_gen = iter(DataLoader(SegmDataset(val_int_image_ids,
#                                                    cell_augmenter,
#                                                    prepare_int_batch),
#                                        batch_size=BASE_BATCH_SIZE,
#                                        shuffle=True,
#                                        num_workers=2))
#         val_gen = structured_two_step_datagen(base_val_gen,
#                                               model=base_segm_model,
#                                               cuda=True,
#                                               batch_size=BATCH_SIZE,
#                                               tamper_channels=cell_tamper_channels)

#         print('epoch', epoch)

#         lr = CELL_LR_SCHEDULE(epoch)
#         print('lr', lr)
#         optimizer = Adam(cell_net.parameters(), lr=lr)

#         cell_net.train()
#         cur_train_metrics = run_cell_network(cell_net, train_gen, BATCHES_PER_EPOCH_TRAIN,
#                                              criterion=CELL_REL_LOSS,
#                                              metrics=CELL_TRAIN_METRICS,
#                                              optimizer=optimizer)
#         train_metrics.extend(cur_train_metrics)
#         display(pandas.DataFrame(cur_train_metrics).describe().loc[['mean', 'std']])

#         cell_net.eval()
#         cur_val_metrics = run_cell_network(cell_net, val_gen, BATCHES_PER_EPOCH_VAL,
#                                            criterion=CELL_REL_LOSS,
#                                            metrics=CELL_VAL_METRICS)
#         val_metrics.extend(cur_val_metrics)
#         display(pandas.DataFrame(cur_val_metrics).describe().loc[['mean', 'std']])
#     finally:
#         if base_train_gen:
#             clean_pytorch_dataloadergen(base_train_gen)
#         if base_val_gen:
#             clean_pytorch_dataloadergen(base_val_gen)

#         gc.collect()

Postprocess using superpixels


In [33]:
# base_segm_model = torch.load('./models/segm_unet1_step_aug_full2_rc_sc_reg').cuda()

In [34]:
# base_train_gen = None
# base_val_gen = None

# EPOCHS_NUM = 60
# BATCH_SIZE = 8
# PART_PER_EPOCH_TRAIN = 0.05
# PART_PER_EPOCH_VAL = 1.0
# MEAN_RELS_PER_IMAGE = 2
# BATCHES_PER_EPOCH_TRAIN = 8 # int(len(train_int_image_ids) * PART_PER_EPOCH_TRAIN // BASE_BATCH_SIZE) * MEAN_RELS_PER_IMAGE
# BATCHES_PER_EPOCH_VAL = 8 # int(len(val_int_image_ids) * PART_PER_EPOCH_VAL // BASE_BATCH_SIZE) * MEAN_RELS_PER_IMAGE

# pp_fex_net = dummy_sp_pp_fex
# # SimpleConvFCN(5, 16,
# #                            conv_layers=3,
# #                            dilations=[1, 2, 4, 8],
# #                            out_act=F.relu).cuda()
# # pp_sp_net = SimpleConvFCN(5, 1,
# #                           conv_layers=4,
# #                           dilations=[1, 2, 4, 8, 16, 32],
# #                           out_act=F.sigmoid).cuda()
# # pp_sp_net = RectRNN(5,
# #                     1,
# #                     num_layers=3)
# pp_sp_net = DenseRectRNN(5,
#                          1,
#                          inner_out_channels=32,
#                          inner_channels=32,
#                          layers=2)
# pp_net = SuperpixelPostprocessor(pp_fex_net,
#                                  pp_sp_net,
#                                  fex_agg=sp_pp_fex_agg_avgpool).cuda()
# # print('receptive field', pp_fex_net.receptive_field)
# PP_LOSS = F.binary_cross_entropy
# print('total parameters', sum(numpy.product(p.size()) for p in pp_net.parameters()))

# pp_augmenter = imgaug_pipeline
# # pp_augmenter = fake_imgaug_pipeline

# PP_LR_SCHEDULE = step_annealed_big_lr

# train_metrics = []
# val_metrics = []
# for epoch in range(EPOCHS_NUM):
#     try:
#         base_train_gen = iter(DataLoader(SegmDataset(train_int_image_ids,
#                                                      pp_augmenter,
#                                                      prepare_int_batch),
#                                          batch_size=BATCH_SIZE,
#                                          shuffle=True,
#                                          num_workers=4))
#         train_gen = superpixel_pp_datagen(base_train_gen,
#                                           model=base_segm_model,
#                                           cuda=True,
#                                           min_mean_for_pos=0.5)
# #         qq = next(iter(train_gen))[0].data.cpu().numpy()
#         base_val_gen = iter(DataLoader(SegmDataset(val_int_image_ids,
#                                                    pp_augmenter,
#                                                    prepare_int_batch),
#                                        batch_size=BATCH_SIZE,
#                                        shuffle=True,
#                                        num_workers=2))
#         val_gen = superpixel_pp_datagen(base_val_gen,
#                                         model=base_segm_model,
#                                         cuda=True,
#                                         min_mean_for_pos=0.5)

#         print('epoch', epoch)

#         lr = PP_LR_SCHEDULE(epoch)
#         print('lr', lr)
#         optimizer = Adam(pp_net.parameters(), lr=lr)

#         pp_net.train()
#         cur_train_metrics = run_sp_pp_network(pp_net, train_gen, BATCHES_PER_EPOCH_TRAIN,
#                                               criterion=PP_LOSS,
#                                               metrics=SP_PP_METRICS,
#                                               optimizer=optimizer)
#         train_metrics.extend(cur_train_metrics)
#         display(pandas.DataFrame(cur_train_metrics).describe().loc[['mean', 'std']])

#         pp_net.eval()
#         cur_val_metrics = run_sp_pp_network(pp_net, val_gen, BATCHES_PER_EPOCH_VAL,
#                                             criterion=PP_LOSS,
#                                             metrics=SP_PP_METRICS)
#         val_metrics.extend(cur_val_metrics)
#         display(pandas.DataFrame(cur_val_metrics).describe().loc[['mean', 'std']])
#     finally:
#         if base_train_gen:
#             clean_pytorch_dataloadergen(base_train_gen)
#         if base_val_gen:
#             clean_pytorch_dataloadergen(base_val_gen)

#         gc.collect()

Plot metrics


In [35]:
train_metrics = pandas.DataFrame(train_metrics)
val_metrics = pandas.DataFrame(val_metrics)
train_metrics.rolling(10).mean().plot(figsize=(13, 8), grid=True)
val_metrics.rolling(10).mean().plot(figsize=(13, 8), grid=True)


Out[35]:
<matplotlib.axes._subplots.AxesSubplot at 0x7f5b47e5dda0>

Predict and visualize

Table detection


In [36]:
# torch.save(det_net, 'models/det_unet1_aug_ds2_aug_dil')

In [37]:
# det_test_net = det_net
det_test_net = torch.load('models/det_unet3_aug_ds2').cuda()

In [38]:
# det_test_gen = DataLoader(SegmDataset(val_det_image_ids,
# #                                       imgaug_pipeline,
#                                       fake_imgaug_pipeline,
#                                       prepare_det_batch),
#                           batch_size=1,
#                           shuffle=True,
#                           num_workers=0)

In [39]:
# det_test_metrics = pandas.DataFrame(run_network(det_test_net,
#                                                 det_test_gen,
#                                                 len(val_det_image_ids) // 4,
#                                                 metrics=VAL_DET_METRICS))
# format_metrics_table_md(det_test_metrics)

In [40]:
# det_test_batch = next(iter(det_test_gen))
# det_test_img = det_test_batch[1].cpu().numpy()
# det_test_mask = det_test_batch[2].cpu().numpy()
# det_test_pred = det_test_net(npvar(det_test_batch[1], True))
# det_test_pred_np = det_test_pred.cpu().data.numpy()

In [41]:
# arr_to_img(det_test_img[2][0])

In [42]:
# arr_to_img(mask_to_img(det_test_mask[0]))

In [43]:
# arr_to_img(mask_to_img(det_test_pred_np[2]))

In [44]:
# arr_to_img(mask_to_img(fill_boxes_on_mask_single_image(det_test_pred_np[2])))

Results

Trained Without Augmentation
U-net1+step
bd d loss px_p px_r
body 0.90±0.10 0.90±0.10 0.91±0.12 0.92±0.11
caption 0.76±0.18 0.75±0.18 0.79±0.17 0.79±0.21
loss 0.21±0.24
DenseNet1
Trained With Augmentation
U-net1+step
bd d loss px_p px_r
body 0.90±0.10 0.89±0.10 0.93±0.11 0.88±0.12
caption 0.72±0.20 0.66±0.20 0.81±0.19 0.62±0.24
loss 0.25±0.24
U-net2+step+new dataset
bd d loss px_p px_r
body 0.89±0.10 0.89±0.10 0.92±0.12 0.89±0.11
caption 0.76±0.16 0.74±0.17 0.83±0.16 0.74±0.20
value 0.23±0.22
det_unet3_aug_ds2
bd d loss px_p px_r
body 0.90±0.10 0.89±0.10 0.93±0.12 0.90±0.11
caption 0.78±0.16 0.79±0.16 0.85±0.15 0.80±0.20
value 0.21±0.23
det_attnet1_aug_ds2
bd d loss px_p px_r
body 0.85±0.21 0.84±0.20 0.87±0.22 0.88±0.20
caption 0.68±0.32 0.64±0.30 0.73±0.34 0.67±0.31
value 0.43±0.96

Old Results

torch4_net
d body  d caption   loss
mean    0.889029    0.724086    0.238287
std 0.104200    0.185334    0.262712

Table structure segmentation


In [45]:
# torch.save(net, 'models/segm_unet1_step_aug_full2_rc_sc_reg_var_dil')

In [46]:
test_net = net

In [47]:
# test_net = torch.load('models/segm_unet1_step_aug_full2_rc_sc_reg_var_dil').cuda()

In [48]:
test_gen = DataLoader(SegmExtDataset(val_int_image_ids,
                                     imgaug_pipeline,
#                                      fake_imgaug_pipeline,
                                     prepare_int_ext_batch,
                                     vocab),
                      batch_size=1,
                      shuffle=True,
                      num_workers=0)

In [49]:
test_metrics = pandas.DataFrame(run_network_ext(test_net,
                                                test_gen,
                                                len(val_int_image_ids),
                                                metrics=TEST_INT_METRICS))
format_metrics_table_md(test_metrics)


100%|██████████| 400/400 [03:47<00:00,  1.76it/s]
|                    | bf        | bp        | br        | d         | loss      | px_p      | px_r      |
|:-------------------|:----------|:----------|:----------|:----------|:----------|:----------|:----------|
| body               | 0.03±0.16 | 0.02±0.09 | 0.04±0.21 | 0.00±0.00 |           | 0.00±0.00 | 0.00±0.00 |
| cell               | 0.13±0.10 | 0.98±0.03 | 0.12±0.10 | 0.75±0.06 |           | 0.99±0.02 | 0.99±0.03 |
| same_col_other_row | 0.20±0.05 | 0.01±0.07 | 0.11±0.03 | 0.00±0.00 |           | 0.00±0.00 | 0.00±0.00 |
| same_row_other_col | 0.00±0.01 | 0.00±0.00 | 0.00±0.01 | 0.00±0.00 |           | 0.00±0.00 | 0.00±0.00 |
| value              |           |           |           |           | 1.04±0.24 |           |           |


In [56]:
test_iter = iter(test_gen)

In [57]:
test_batch = next(test_iter)

In [58]:
test_pred, test_inner_outs = test_net(npvar(test_batch[1], True))
test_pred_np = test_pred.cpu().data.numpy()

In [63]:
print('bf', box_match_f1(test_pred, test_batch[2], [pickle.loads(test_batch[-1][0])]))
print('bp', box_match_precision(test_pred, test_batch[2], [pickle.loads(test_batch[-1][0])]))
print('br', box_match_recall(test_pred, test_batch[2], [pickle.loads(test_batch[-1][0])]))


bf [ 0.          0.828125    0.          0.18181818]
bp [ 0.          0.99173554  0.          0.        ]
br [ 0.          0.78518519  0.          0.1       ]

In [64]:
arr_to_img(mask_to_img(test_batch[1][0].cpu().numpy()))


Out[64]:

In [65]:
arr_to_img(mask_to_img(test_pred_np[0]))


Out[65]:

In [66]:
display_inner_outs(test_inner_outs)


/root/.pyenv/versions/3.6.0/lib/python3.6/site-packages/matplotlib/figure.py:2022: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  warnings.warn("This figure includes Axes that are not compatible "

In [ ]:
# arr_to_img(mask_to_img(test_batch[2][0].cpu().numpy()))

Results

Trained Without Augmentation
U-net1+step

Eval With Augmentation

bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 0.86±0.05
cell 0.43±0.20 0.60±0.25 0.38±0.19 0.47±0.19
loss 1.33±0.45
same_col_other_row 0.01±0.03 0.06±0.10 0.01±0.02 0.37±0.18
same_row_other_col 0.01±0.02 0.03±0.09 0.00±0.02 0.50±0.17

Eval Without Augmentation

bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 1.00±0.00
cell 0.97±0.02 1.00±0.01 0.95±0.03 0.96±0.00
loss 0.11±0.02
same_col_other_row 0.01±0.02 0.02±0.07 0.00±0.01 0.87±0.02
same_row_other_col 0.02±0.04 0.10±0.16 0.01±0.02 0.91±0.01
DenseNet1+step

Eval With Augmentation

Eval Without Augmentation

bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 0.98±0.01
cell 0.97±0.02 1.00±0.01 0.95±0.03 0.93±0.01
loss 0.32±0.02
same_col_other_row 0.00±0.00 0.00±0.00 0.00±0.00 0.27±0.01
same_row_other_col 0.13±0.08 0.50±0.23 0.08±0.05 0.64±0.02
Trained With Augmentation
U-net1+step

Eval With Augmentation

bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 0.99±0.01
cell 0.92±0.06 0.91±0.07 0.95±0.08 0.91±0.03
same_col_other_row 0.07±0.15 0.20±0.40 0.04±0.11 0.81±0.06
same_row_other_col 0.20±0.26 0.48±0.50 0.14±0.21 0.82±0.11
value 0.20±0.11
Combined Dataset (full)
U-Net1
bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 1.00±0.00
cell 0.96±0.05 0.96±0.04 0.96±0.07 0.94±0.02
same_col_other_row 0.04±0.12 0.13±0.34 0.03±0.08 0.85±0.05
same_row_other_col 0.09±0.19 0.28±0.45 0.06±0.14 0.88±0.04
value 0.14±0.05
DenseNet2
bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 0.99±0.01
cell 0.93±0.05 0.91±0.07 0.96±0.06 0.93±0.02
same_col_other_row 0.30±0.30 0.59±0.49 0.22±0.24 0.80±0.06
same_row_other_col 0.43±0.29 0.83±0.37 0.32±0.26 0.80±0.08
value 0.18±0.07
segm_densenetmean1_step_aug_full
bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 0.76±0.22
cell 0.93±0.12 0.91±0.12 0.95±0.12 0.86±0.15
same_col_other_row 0.21±0.26 0.48±0.50 0.15±0.20 0.63±0.20
same_row_other_col 0.29±0.30 0.61±0.49 0.22±0.25 0.57±0.28
value 1.47±1.78
segm_unetmean1_step_aug_full
bf bp br d loss
body 0.00±0.00 0.00±0.00 0.00±0.00 0.72±0.29
cell 0.83±0.24 0.79±0.25 0.89±0.24 0.76±0.24
same_col_other_row 0.19±0.26 0.40±0.48 0.14±0.21 0.60±0.24
same_row_other_col 0.24±0.28 0.55±0.49 0.17±0.23 0.67±0.26
value 1.71±3.18
segm_unetmean2_step_aug_full_shrinked_reweighted_cells_only_dec_dil
bf bp br d loss px_p px_r
body 0.00±0.00 0.00±0.00 0.00±0.00 0.95±0.12 0.00±0.00 0.00±0.00
cell 0.91±0.19 0.93±0.20 0.91±0.16 0.89±0.15 0.89±0.17 0.95±0.03
same_col_other_row 0.00±0.00 0.00±0.00 0.00±0.00 0.97±0.10 0.00±0.00 0.00±0.00
same_row_other_col 0.00±0.00 0.00±0.00 0.00±0.00 0.93±0.16 0.00±0.00 0.00±0.00
value 0.36±0.85
segm_unet3_step_aug_full_shrinked_reweighted_cells_only_dec_dil_out
bf bp br d loss px_p px_r
body 0.00±0.00 0.00±0.00 0.00±0.00 0.88±0.12 0.00±0.00 0.00±0.00
cell 0.95±0.06 0.97±0.05 0.94±0.09 0.93±0.03 0.94±0.03 0.95±0.03
same_col_other_row 0.00±0.00 0.00±0.00 0.00±0.00 0.72±0.18 0.00±0.00 0.00±0.00
same_row_other_col 0.00±0.00 0.00±0.00 0.00±0.00 0.83±0.15 0.00±0.00 0.00±0.00
value 0.17±0.08
Full2 (with big_simple)
segm_unet1_step_aug_full2
bf bp br d loss px_p px_r
body 0.00±0.00 0.00±0.00 0.00±0.00 0.89±0.13 0.00±0.00 0.00±0.00
cell 0.96±0.05 0.99±0.03 0.94±0.08 0.93±0.02 0.95±0.03 0.94±0.03
same_col_other_row 0.00±0.00 0.00±0.00 0.00±0.00 0.95±0.09 0.00±0.00 0.00±0.00
same_row_other_col 0.00±0.00 0.00±0.00 0.00±0.00 0.92±0.08 0.00±0.00 0.00±0.00
value 0.17±0.07
segm_unet1_step_aug_full2_rc_sc_reg (best currently!)
bf bp br d loss px_p px_r
body 0.00±0.00 0.00±0.00 0.00±0.00 1.00±0.00 0.00±0.00 0.00±0.00
cell 0.95±0.05 0.98±0.04 0.93±0.09 0.92±0.02 0.93±0.03 0.94±0.03
same_col_other_row 0.00±0.00 0.00±0.00 0.00±0.00 1.00±0.00 0.00±0.00 0.00±0.00
same_row_other_col 0.00±0.00 0.00±0.00 0.00±0.00 0.99±0.00 0.00±0.00 0.00±0.00
value 0.20±0.07
segm_unet1_step_aug_full2_rc_sc_reg_maxpool
bf bp br d loss px_p px_r
body 0.00±0.00 0.00±0.00 0.00±0.00 0.96±0.10 0.00±0.00 0.00±0.00
cell 0.95±0.05 0.95±0.07 0.96±0.07 0.92±0.02 0.93±0.03 0.94±0.03
same_col_other_row 0.00±0.03 0.01±0.07 0.00±0.02 0.94±0.12 0.00±0.00 0.00±0.00
same_row_other_col 0.00±0.02 0.00±0.05 0.00±0.01 0.97±0.10 0.00±0.00 0.00±0.00
value 0.20±0.07
segm_unet1_step_aug_full2_rc_sc_reg_with_ic
bf bp br d loss px_p px_r
body 0.00±0.00 0.00±0.00 0.00±0.00 1.00±0.00 0.00±0.00 0.00±0.00
cell 0.92±0.07 0.93±0.10 0.93±0.08 0.90±0.03 0.91±0.04 0.95±0.03
same_col_other_row 0.15±0.23 0.38±0.48 0.10±0.17 0.80±0.06 0.82±0.06 0.83±0.09
same_row_other_col 0.18±0.27 0.39±0.49 0.13±0.23 0.88±0.06 0.89±0.06 0.90±0.08
value 0.32±0.11

Intracell classification (second step)


In [ ]:
# torch.save(cell_net, './models/cell_fcn1_no_aug_diag_with_unet1_reg')

In [123]:
cell_test_net = torch.load('models/cell_fcn1_no_aug_diag_with_unet1_reg').cuda()
cell_test_base_segm_model = torch.load('models/segm_unet1_step_aug_full2_rc_sc_reg_var_dil').cuda()

In [ ]:
# cell_test_gen = structured_two_step_datagen(DataLoader(SegmDataset(val_int_image_ids,
#                                                                    imgaug_pipeline,
# #                                                                    fake_imgaug_pipeline,
#                                                                    prepare_int_batch),
#                                                        batch_size=4,
#                                                        shuffle=True,
#                                                        num_workers=0),
#                                             model=cell_test_base_segm_model,
#                                             cuda=True,
#                                             batch_size=10, 
#                                             tamper_channels=[])

In [ ]:
# cell_test_metrics = pandas.DataFrame(run_cell_network(cell_test_net,
#                                                       cell_test_gen,
#                                                       1000,
#                                                       metrics=CELL_TRAIN_METRICS))
# format_metrics_table_md(cell_test_metrics)

In [ ]:
# cell_test_iter = iter(cell_test_gen)

In [ ]:
# while True:
#     cell_test_batch = next(cell_test_iter)

#     cell_test_pred = cell_test_net(cell_test_batch[0])
#     cell_test_pred_np = cell_test_pred.cpu().data.numpy()

#     cell_test_in_np = cell_test_batch[0].data.cpu().numpy()
#     cell_test_gold_np = cell_test_batch[1].data.cpu().numpy()
#     if cell_test_gold_np.sum(0).sum() > 0:
#         break
# print(cell_test_gold_np),
# print(cell_test_pred_np)

In [ ]:
# cell_err_idx = numpy.where(cell_test_gold_np != cell_test_pred_np.argmax(1))[0]
# cell_err_idx

In [ ]:
# arr_to_img(mask_to_img(cell_test_in_np[0]))

Results

Trained Without Augmentation
cell_fcn1_no_aug + segm_unet1_step_aug dirty

eval no aug

f1 loss p r
loss 0.12±0.17
same_col 0.98±0.07 0.97±0.10 0.99±0.03
same_row 0.98±0.05 1.00±0.02 0.97±0.08

eval with aug

f1 loss p r
loss 0.12±0.19
same_col 0.97±0.07 0.96±0.10 0.99±0.03
same_row 0.98±0.05 0.99±0.03 0.97±0.08
cell_fcn1_no_aug + segm_unet1_step_aug clean

eval no aug

f1 loss p r
loss 0.06±0.09
same_col 0.99±0.04 0.99±0.06 1.00±0.02
same_row 0.99±0.03 1.00±0.02 0.99±0.04

eval with aug

f1 loss p r
loss 0.08±0.16
same_col 0.99±0.05 0.98±0.06 1.00±0.02
same_row 0.98±0.04 0.98±0.06 0.99±0.05
Trained With Augmentation

Postprocess using superpixels


In [45]:
# torch.save(pp_net, './models/pp_lstm_stacked2')

In [128]:
# sp_pp_test_net = SuperpixelPostprocessor(dummy_sp_pp_fex, dummy_sp_pp_cls_with_th, fex_agg=sp_pp_fex_agg_avgpool).cuda()
# sp_pp_test_net = pp_net
# sp_pp_test_net = torch.load('./models/pp_lstm_stacked2')
# sp_pp_test_base_segm_model = torch.load('models/segm_unet1_step_aug_full2_rc_sc_reg').cuda()

In [71]:
# test_pp_gen = superpixel_pp_datagen(DataLoader(SegmDataset(val_int_image_ids,
#                                                            pp_augmenter,
#                                                            prepare_int_batch),
#                                                batch_size=1,
#                                                shuffle=True,
#                                                num_workers=0),
#                                     model=sp_pp_test_base_segm_model,
#                                     cuda=True,
#                                     min_mean_for_pos=0.6)

In [70]:
# sp_pp_test_metrics = pandas.DataFrame(run_sp_pp_network(sp_pp_test_net,
#                                                         test_pp_gen,
#                                                         1000,
#                                                         criterion=F.binary_cross_entropy,
#                                                         metrics=SP_PP_METRICS))
# format_metrics_table_md(sp_pp_test_metrics)

In [69]:
# sp_pp_test_iter = iter(test_pp_gen)

In [68]:
# sp_pp_test_in_img, sp_pp_test_gold_mask, sp_pp_test_boxes_aug, sp_pp_test_pred_mask, sp_pp_test_grids, sp_pp_test_gold = next(sp_pp_test_iter)
# sp_pp_test_pred = sp_pp_test_net((sp_pp_test_in_img, sp_pp_test_pred_mask, sp_pp_test_grids))
# sp_pp_test_pred_np = sp_pp_test_pred.data.cpu().numpy()
# sp_pp_test_pred_mask_np = sp_pp_test_pred_mask.data.cpu().numpy()
# gold_out_np = sp_pp_test_gold.data.cpu().numpy()

In [67]:
# arr_to_img(mask_from_superpixels(sp_pp_test_grids[0][0], sp_pp_test_grids[0][1], gold_out_np[0][0]))

In [66]:
# arr_to_img(sp_pp_test_pred_mask_np[0][1])

In [65]:
# arr_to_img(mask_from_superpixels(sp_pp_test_grids[0][0], sp_pp_test_grids[0][1], sp_pp_test_pred_np[0][0]))

In [ ]:
# arr_to_img(sp_pp_test_pred_np[0][0])

In [ ]:
# arr_to_img(gold_out_np[0][0])

Results

Baseline
bf bp br f1 loss p r
sp 0.94±0.07 0.97±0.06 0.92±0.11 0.93±0.04 0.89±0.06 0.96±0.05
value 0.90±0.52
sp_pp1
bf bp br f1 loss p r
sp 0.96±0.05 0.99±0.03 0.93±0.08 0.88±0.05 0.83±0.08 0.94±0.05
value 0.13±0.05
lstm1
bf bp br f1 loss p r
sp 0.95±0.06 0.99±0.03 0.91±0.10 0.88±0.05 0.83±0.08 0.95±0.04
value 0.14±0.06
lstm2
bf bp br f1 loss p r
sp 0.95±0.06 0.99±0.03 0.92±0.10 0.88±0.04 0.83±0.07 0.95±0.04
value 0.14±0.06
lstm3
bf bp br f1 loss p r
sp 0.95±0.06 0.98±0.04 0.92±0.10 0.85±0.05 0.79±0.08 0.94±0.04
value 0.14±0.07

Parse tables using all built models


In [132]:
def reorder_cells(c1, c2, cells, dim):
    b1, b2 = cells[c1], cells[c2]
    coord1, coord2 = b1[dim], b2[dim]
    coord1o, coord2o = b1[dim+2], b2[dim+2]
    if coord1 < coord2:
        return c1, c2, coord2 > coord1o
    else:
        return c2, c1, coord1 > coord2o


class SeqIdGen(object):
    def __init__(self):
        self.next_id = 0

    def __call__(self):
        res = self.next_id
        self.next_id += 1
        return res


def make_equivalence_classes(neighborhood):
    cell2cls = {}
    gen_cls_id = SeqIdGen()

    refs_by_cell = { c : 0
                    for cur_cell, neighs in neighborhood.items()
                    for lst in ((cur_cell,), neighs)
                    for c in lst }
    for cur_cell, neighs in neighborhood.items():
        for oc in neighs:
            refs_by_cell[oc] += 1
    cells_with_no_refs = [c
                          for c, refs_n in refs_by_cell.items()
                          if refs_n == 0]

    for cls_seed in cells_with_no_refs:
        cur_queue = collections.deque([(cls_seed, { gen_cls_id() })])

        while len(cur_queue) > 0:
            cur_cell, cur_cls = cur_queue.popleft()
            cell2cls[cur_cell] = cur_cls

            cur_neighbors = neighborhood.get(cur_cell, [])
            many_neighbors = len(cur_neighbors) > 1
            for oc in cur_neighbors:
                oc_cls = set(cur_cls)
                if many_neighbors:
                    oc_cls.add(gen_cls_id())
                cur_queue.append((oc, oc_cls))

    cls2cell = collections.defaultdict(set)
    for cell, classes in cell2cls.items():
        for cls in classes:
            cls2cell[cls].add(cell)
    classes_to_remove = {cls
                         for cls, cells in cls2cell.items()
                         if len(cells) == 1 and len(cell2cls[next(iter(cells))]) > 1}
    for cls in classes_to_remove:
        for cell in cls2cell[cls]:
            cell2cls[cell].remove(cls)
        del cls2cell[cls]

    return cell2cls


RECONST_SAME_ROW_NEIGH = {'left', 'right'}
RECONST_SAME_COL_NEIGH = {'upper', 'lower'}
def reconstruct_table_from_grid(body, cells, grid, cell_idx, intracell_space_classes, prob_threshold=0.3):
    row_neighbors = { i : set() for i in range(len(cells)) }
    col_neighbors = { i : set() for i in range(len(cells)) }

    for (c1, c2), (neigh_type, probs) in intracell_space_classes.items():
        b1 = cells[c1]
        b2 = cells[c2]
        rel_cls = probs.argmax()
        if (neigh_type in RECONST_SAME_ROW_NEIGH): # or (rel_cls == STRUCT_SAME_ROWS_I):
            rc1, rc2, do_not_overlap = reorder_cells(c1, c2, cells, 1) # ensure rc1 is in the left of rc2
            if do_not_overlap:
                row_neighbors[rc1].add(rc2)
        if (neigh_type in RECONST_SAME_COL_NEIGH): # or (rel_cls == STRUCT_SAME_COLS_I):
            cc1, cc2, do_not_overlap = reorder_cells(c1, c2, cells, 0) # ensure cc1 is above of cc2
            if do_not_overlap:
                col_neighbors[cc1].add(cc2)

    cell2rows = make_equivalence_classes(row_neighbors)
    cell2cols = make_equivalence_classes(col_neighbors)

    rows = collections.defaultdict(set)
    for cell_i, cell_rows_idx in cell2rows.items():
        for row_i in cell_rows_idx:
            rows[row_i].add(cell_i)

    cols = collections.defaultdict(set)
    for cell_i, cell_cols_idx in cell2cols.items():
        for col_i in cell_cols_idx:
            cols[col_i].add(cell_i)

    row_boxes = [just_box_union([cells[i] for i in row_idx])
                 for row_idx in rows.values()]
    row_boxes.sort(key=lambda b: b[0])

    col_boxes = [just_box_union([cells[i] for i in col_idx])
                 for col_idx in cols.values()]
    col_boxes.sort(key=lambda b: b[1])
    return (body,
            cells,
            row_boxes,
            col_boxes)


def get_segm_mask_batch(images, segm_model, cuda=True):
    batch_segm_input = npvar(numpy.expand_dims(numpy.array([numpy.array(img).astype('float32')
                                                            for img in images]),
                                               1) / 255.0,
                             cuda)
    out = segm_model(batch_segm_input)
    if isinstance(out, tuple):
        out = out[0]
    return out.data.cpu().numpy()


# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# подключить det_model
# копировать выделенные фрагменты в отдельные картинки и их сегментировать с помощью segm_model


def ceil_mod(x, mod):
    if x % mod == 0:
        return x
    return ((x // mod) + 1) * mod


def choose_image_size_multiple_of(min_size, multiplier=32):
    return (ceil_mod(min_size[0], multiplier),
            ceil_mod(min_size[1], multiplier))


def fit_image_size(src_image, multiplier=128, mode='L', filler=255):
    new_size = choose_image_size_multiple_of(src_image.size,
                                             multiplier=multiplier)
    result = Image.new(mode,
                       new_size,
                       filler)
    result.paste(src_image)
    return result


def box_to_abs_coords(box, parent):
    py1, px1, py2, px2 = parent
    by1, bx1, by2, bx2 = box
    return (by1+py1, bx1+px1, by2+py1, bx2+px1)


def box_lst_to_abs_coords(boxes, parent):
    return [box_to_abs_coords(box, parent) for box in boxes]


def preprocess_page_image(img, kernel=11, ):
    img_arr = numpy.array(img)
    bin_arr = cv2.adaptiveThreshold(img_arr,
                                    255,
                                    cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY,
                                    5,
                                    20)
    return Image.fromarray(bin_arr)


# CELL_OPEN_H_MASK = numpy.ones((3, 5), dtype='uint8')
# CELL_OPEN_V_MASK = numpy.ones((5, 3), dtype='uint8')
# CELL_DIL_H_MASK = numpy.ones((1, 3), dtype='uint8')
# CELL_DIL_V_MASK = numpy.ones((3, 1), dtype='uint8')
CELL_OPEN_MASK = numpy.ones((5, 5), dtype='uint8')
CELL_ERODE_MASK = numpy.ones((3, 1), dtype='uint8')
def postprocess_segm_mask(mask):
    cell_mask = (mask[1] * 255).astype('uint8')

    for box in filter_boxes_by_overlap(get_all_boxes(mask[1])):
        y1, x1, y2, x2 = box
        if box_area(box) <= 15 or y2-y1 <= 9 or x2-x1 <= 9:
            continue
        box_mask = cell_mask[y1:y2+1, x1:x2+1]
        box_mask = cv2.erode(box_mask, CELL_ERODE_MASK)
        box_mask = cv2.morphologyEx(box_mask, cv2.MORPH_OPEN, CELL_OPEN_MASK)
        cell_mask[y1:y2+1, x1:x2+1] = box_mask

    result = mask.copy()
    result[1] = cell_mask.astype('float32') / 255
    return result


def extend_img_with_means(img):
    img = numpy.array(img)
    rows_mean = numpy.repeat(numpy.expand_dims(img.mean(0), 0),
                             img.shape[0],
                             axis=0)
    cols_mean = numpy.repeat(numpy.expand_dims(img.mean(1), 1),
                             img.shape[1],
                             axis=1)
    return numpy.stack([img, rows_mean, cols_mean])


def postprocess_densecrf(input_img, mask):
    input_img = numpy.array(input_img)
    d = dcrf.DenseCRF2D(mask.shape[0], mask.shape[1], 2)
    
    # account for predicted mask
    unary = numpy.zeros((2, mask.shape[0], mask.shape[1]),
                        dtype='float32')
    unary[0] = mask
    unary[1] = 1 - mask
    unary = -numpy.log(unary)
    d.setUnaryEnergy(unary.reshape((2, -1)))

    # account for adjacency
#     d.addPairwiseGaussian(sxy=(3, 3),
#                           compat=1,
#                           kernel=dcrf.DIAG_KERNEL,
#                           normalization=dcrf.NORMALIZE_SYMMETRIC)

    # account for pixel-level box-relatedness
    boxes = get_all_boxes(mask)
    px_row_boxes = numpy.zeros((mask.shape[0], mask.shape[1], len(boxes)), dtype='float32')
    px_col_boxes = numpy.zeros((mask.shape[0], mask.shape[1], len(boxes)), dtype='float32')
    for box_i, (y1, x1, y2, x2) in enumerate(boxes):
        px_row_boxes[y1:y2+1, :, box_i] = 1
        px_col_boxes[:, x1:x2+1, box_i] = 1

    px_row_energy = create_pairwise_bilateral(sdims=(1.0, 1.0),
                                              schan=numpy.ones(len(boxes), dtype='float32'),
                                              img=px_row_boxes,
                                              chdim=2)
    px_col_energy = create_pairwise_bilateral(sdims=(1.0, 1.0),
                                              schan=numpy.ones(len(boxes), dtype='float32'),
                                              img=px_col_boxes,
                                              chdim=2)
    d.addPairwiseEnergy(px_row_energy, compat=10)
    d.addPairwiseEnergy(px_col_energy, compat=10)

    pred = numpy.array(d.inference(50))
    pred = pred.reshape((2, mask.shape[0], mask.shape[1]))
    return pred[0]


_HIST_FILTER_BINS = numpy.arange(0, 1.0001, 30)
def make_hist_filter(xs, bins=_HIST_FILTER_BINS, min_val=3e-2, min_perc=0.25):
    sorted_xs = sorted(xs)
    min_pos = min_perc * len(sorted_xs)
    xs = [(x >= min_val) and (bisect.bisect_left(sorted_xs, x) >= min_pos)
          for x in xs]
    return numpy.array(xs)


def remove_false_in_line(arr, window=1):
    return numpy.array([arr[i:i+window+1].any() for i in range(len(arr))])


def filter_by_row_col_stat(src_img, mask, max_iter=3, min_diff=1, true_window_to_ignore=1):
    src_row_labels = remove_false_in_line(make_hist_filter(1 - src_img.mean(1)),
                                          window=true_window_to_ignore)
    src_col_labels = remove_false_in_line(make_hist_filter(1 - src_img.mean(0)),
                                          window=true_window_to_ignore)
    for _ in range(max_iter):
        row_labels = remove_false_in_line(make_hist_filter(mask.mean(1)),
                                          window=true_window_to_ignore)
        row_labels |= src_row_labels
        col_labels = remove_false_in_line(make_hist_filter(mask.mean(0)),
                                          window=true_window_to_ignore)
        col_labels |= src_col_labels

        new_mask = mask.copy()
        new_mask[~row_labels, :] = 0
        new_mask[:, ~col_labels] = 0
        if (mask - new_mask).sum() < min_diff:
            return new_mask
        mask = new_mask
    return mask


SPARSE_CLOSE_KERNEL = numpy.ones((1, 1), dtype='uint8')
def smooth_sparse_mask(mask):
    img = (mask * 255).astype('uint8')
    img = cv2.morphologyEx(img, cv2.MORPH_CLOSE, SPARSE_CLOSE_KERNEL)
    return img.astype('float32') / 255


def postprocess_using_superpixels(src_img, cells_mask):
    cells_mask_sparse = filter_by_row_col_stat(src_img, postprocess_segm_mask(cells_mask))
    orig_superpixels = make_superpixels(cells_mask_sparse)
    superpixels = filter_boxes_by_real_pixels(orig_superpixels, src_img)
    removed_boxes = set(orig_superpixels) - set(superpixels)
    for y1, x1, y2, x2 in removed_boxes:
        cells_mask_sparse[y1:y2+1, x1:x2+1] = 0
    return cells_mask_sparse


def postprocess_using_superpixels_rnn(src_img, cells_mask, model, cuda=True):
#     grid = make_multilevel_superpixels_from_img_and_mask(1 - src_img,
#                                                          cells_mask[1])
    grid = make_multilevel_superpixels_grid(cells_mask[1])
    grid_pred = model((npvar(src_img, cuda).unsqueeze(0).unsqueeze(0),
                       npvar(cells_mask, cuda).unsqueeze(0),
                       [grid]))
    return mask_from_superpixels(grid[0], grid[1], grid_pred.data.cpu().numpy()[0][0])


def parse_tables_from_images(images, det_model, segm_model, cell_model, sp_pp_model=None,
                             cuda=True, cell_batch_size=64, min_body_area=300,
                             body_padding=0, demo_prefix=None, display_demo=False,
                             box_content_getter=None, cell_content_filter=None, table_regions_by_image=None):
    tables_by_image = []
    grids_by_image = []

    images = [preprocess_page_image(img) for img in images]
    batch_det_output = get_segm_mask_batch(images, det_model, cuda)
    batch_det_output = fill_boxes_on_mask_batch(batch_det_output)
    for page_no, (src_image, det_mask) in enumerate(zip(images, batch_det_output)):
        cur_img_tables = []
        cur_img_grids = []

        if demo_prefix or display_demo:
            demo = arr_to_img(mask_to_img(det_mask))
            if display_demo:
                display(demo)
            if demo_prefix:
                demo.save(demo_prefix + '-det.png')
        det_boxes_by_channel = get_boxes_by_channel(det_mask)

        cur_image_captions = det_boxes_by_channel[DET_CAPTION_CHANNEL_I]
        if table_regions_by_image:
            cur_image_bodies = table_regions_by_image[page_no]
        else:
            cur_image_bodies = det_boxes_by_channel[DET_BODY_CHANNEL_I]
            cur_image_bodies.sort(key=box_area, reverse=True)
        for body_i, body in enumerate(cur_image_bodies):
            if box_area(body) < min_body_area:
                continue
            ext_body = (max(0, body[0]-body_padding),
                        max(0, body[1]-body_padding),
                        min(src_image.size[1], body[2]+body_padding),
                        min(src_image.size[0], body[3]+body_padding))
            cropped_image = src_image.crop((ext_body[1], ext_body[0], ext_body[3], ext_body[2]))
            src_image_for_segm = fit_image_size(cropped_image)
            src_image_for_segm_arr = numpy.array(src_image_for_segm).astype('float32') / 255.0
            segm_output = get_segm_mask_batch([src_image_for_segm], segm_model, cuda)[0]
#             return src_image_for_segm, segm_output
#             segm_output = postprocess_segm_mask(segm_output)
#             segm_output_cells = postprocess_using_superpixels(src_image_for_segm_arr,
#                                                               segm_output[1])
#             segm_output_cells = postprocess_segm_mask(segm_output_cells)
            if sp_pp_model:
                segm_output_cells = postprocess_using_superpixels_rnn(src_image_for_segm_arr,
                                                                      segm_output,
                                                                      sp_pp_model,
                                                                      cuda=cuda)
                segm_output[1] = segm_output_cells

            if demo_prefix or display_demo:
                demo = arr_to_img(mask_to_img(segm_output))
                if display_demo:
                    display(demo)
                if demo_prefix:
                    demo.save(demo_prefix + '-segm-{}.png'.format(body_i))

            # here should go a loop over all found bodies
            _, cells, grid, cell_idx, intracell_relations = table_grid_from_intracell_mask(segm_output,
                                                                                           input_image=src_image_for_segm_arr,
                                                                                           box_content_getter=(lambda box: box_content_getter(page_no,
                                                                                                                                              box_to_abs_coords(box, ext_body)))
                                                                                               if box_content_getter else None,
                                                                                           content_filter=cell_content_filter)
            output_cell_coords = box_lst_to_abs_coords(cells, ext_body)
            intracell_relation_keys = list(intracell_relations.keys())
            intracell_relation_classes = {}
            for batch_start in range(0, len(intracell_relation_keys), cell_batch_size):
                batch_rel_keys = intracell_relation_keys[batch_start:batch_start+cell_batch_size]
                batch_relations = [(cells[i1], cells[i2], intracell_relations[(i1, i2)])
                                   for (i1, i2) in batch_rel_keys]

                cell_src_img_inp = npvar(src_image_for_segm_arr, cuda).unsqueeze(0)
                cell_img_mask_inp = npvar(segm_output, cuda).unsqueeze(0)
                cell_batch_input = CellRelConvFCNClassifier.make_batch(cell_src_img_inp,
                                                                       cell_img_mask_inp,
                                                                       [batch_relations])
                cell_batch_output = cell_model(cell_batch_input).data.cpu().numpy()
                for rel_key, cur_cell_pair_out in zip(batch_rel_keys, cell_batch_output):
                    intracell_relation_classes[rel_key] = (intracell_relations[rel_key], cur_cell_pair_out)
            cur_img_grids.append((body, output_cell_coords, grid, cell_idx, intracell_relation_classes))
            cur_table = reconstruct_table_from_grid(body, output_cell_coords, grid, cell_idx, intracell_relation_classes)
            if len(cur_table[2]) > 1 and len(cur_table[3]) > 1:
                cur_img_tables.append(cur_table)

        tables_by_image.append(cur_img_tables)
        grids_by_image.append(cur_img_grids)
        
        if demo_prefix or display_demo:
            for table_i, table in enumerate(cur_img_tables):
                demo = make_demo_mask_from_parsed_table(src_image, table)
                if display_demo:
                    display(demo)
                if demo_prefix:
                    demo.save(demo_prefix + '-table-{}.png'.format(table_i))

    return tables_by_image, grids_by_image


def make_boxes_with_channels_from_parsed_table(table_info):
    body, cells, rows, cols = table_info
    boxes = [(1, body)]
    boxes.extend((2, b) for b in cells)
    boxes.extend((3, b) for b in rows)
    boxes.extend((4, b) for b in cols)
    return boxes


def make_demo_mask_from_parsed_table(img, table_info):
    boxes = make_boxes_with_channels_from_parsed_table(table_info)
    return make_demo_mask(img, boxes)


def make_intracell_rels_symmetric(intracell_relation_classes):
    src_keys = list(intracell_relation_classes.keys())
    sums = collections.defaultdict(lambda: numpy.zeros(2, dtype='float32'))
    norms = collections.defaultdict(float)
    for (c1, c2), probs in intracell_relation_classes.items():
        if c1 > c2:
            c1, c2 = c2, c1
        symmetric_key = (c1, c2)
        sums[symmetric_key] += probs
        norms[symmetric_key] += 1
    result = {}
    for key in src_keys:
        c1, c2 = key
        if c1 > c2:
            c1, c2 = c2, c1
        symmetric_key = (c1, c2)
        result[key] = sums[symmetric_key] / norms[symmetric_key]
    return result


def make_demo_mask_from_grid_info(img, grid_info, min_prob=0.5, cells_of_interest=None, print_probs=False):
    body, cells, grid, cell_idx, intracell_relation_classes = grid_info
    rel_mask = numpy.zeros((3, img.size[1], img.size[0]), dtype='uint8')
    for (c1, c2), probs in intracell_relation_classes.items():
        if not cells_of_interest is None:
            if not (c1 in cells_of_interest or c2 in cells_of_interest):
                continue
        b1 = cells[c1]
        b2 = cells[c2]
        for ch_i in range(probs.shape[0]):
            if probs[ch_i] < min_prob:
                continue
            if print_probs:
                print(c1, c2, ch_i, probs[ch_i])
            draw_intercell_mask(rel_mask[ch_i],
                                b1,
                                b2,
                                'same_row',
                                value=int(probs[ch_i]*255))
    rel_mask = mask_to_img(rel_mask.astype('float32') / 255.0)

    rel_mask_img = arr_to_img(rel_mask)
    cells_for_display = cells if cells_of_interest is None else [cells[i] for i in cells_of_interest]
    base_mask = make_demo_mask(img, [(1, body)] + [(2, c)
                                                   for c in cells_for_display])
    return Image.blend(rel_mask_img.convert('RGB'), base_mask, 0.5)


def get_either(m, k):
    r = m.get(k, None)
    if r is None:
        r = m.get((k[1], k[0]), None)
    return r

In [133]:
# test_img_id = val_det_image_ids[10]
# # test_img_filename = test_img_id + IN_IMG_SUFFIX
# test_img_filename = './data/dc/docs_pages_flat/2012-3.pdf_0001.png'
test_img_filename = './data/dc/docs_pages_flat/2003-16.pdf_0004.png'
# test_img_filename = './data/dc/docs_pages_flat/2002-5-mucin-gene-muc1-transfected-dendritic-cells-as-vaccine-results-of-a-phase-i-ii-clinical-trial.pdf_0001.png'
# test_img_filename = './data/dc/docs_pages_flat/2010-1.pdf_0002.png'
test_image = fit_image_size(load_image_opaque(test_img_filename, mode='L'))
# # test_image

In [134]:
# test_img_for_segm, test_segm_mask = parse_tables_from_images([test_image],
#                                                              det_test_net,
#                                                              cell_test_base_segm_model,
#                                                              cell_test_net,
#                                                              sp_pp_test_net,
#                                                              cuda=True,
#                                                              cell_batch_size=2,
#                                                              display_demo=False)
# test_img_for_segm = numpy.array(test_img_for_segm).astype('float32') / 255
# test_segm_mask = test_segm_mask[1]
# display(arr_to_img(test_img_for_segm))
# display(arr_to_img(test_segm_mask))
# # new_mask = postprocess_densecrf(test_img_for_segm, test_segm_mask)
# # arr_to_img(new_mask)
# # sparse_mask = postprocess_using_superpixels(test_img_for_segm, test_segm_mask)
# display(make_demo_mask(arr_to_img(test_segm_mask),
# #                        [(1, b) for b in make_boxes_from_grid(*make_multilevel_superpixels_from_img_and_mask(1 - (numpy.array(test_image).astype('float32') / 255.0),
# #                                                                                                             test_segm_mask))]
#                        [(1, b) for b in make_boxes_from_grid(*make_multilevel_superpixels_grid(test_segm_mask))]
# #                        [(1, b) for b in make_agg_superpixels_grid(test_segm_mask)]
#                       ))
# display(make_demo_mask(arr_to_img(test_segm_mask),
#                        [(1, b) for b in make_boxes_from_grid(*make_multilevel_superpixels_from_img_and_mask(1 - (numpy.array(test_image).astype('float32') / 255.0),
#                                                                                                             test_segm_mask))]
#                       ))

In [135]:
tables_by_image, grids_by_image = parse_tables_from_images([test_image],
                                                           det_test_net,
                                                           cell_test_base_segm_model,
                                                           cell_test_net,
                                                           cuda=True,
                                                           cell_batch_size=2,
                                                           display_demo=True)



In [73]:
# test_grid_ic_rels = grids_by_image[0][0][-1]
# print('test_grid_ic_rels', len(test_grid_ic_rels), numpy.array(list(test_grid_ic_rels.values())).mean(0))
# test_grid_ic_rels_sym = make_intracell_rels_symmetric(test_grid_ic_rels)
# print('test_grid_ic_rels_sym', len(test_grid_ic_rels_sym), numpy.array(list(test_grid_ic_rels_sym.values())).mean(0))

Process PDF files


In [138]:
def image_box_to_pdf_coords(box, cropbox):
    return convert_coords_to_pq(numpy.array(box) * PIXELS_TO_POINTS_FACTOR, cropbox)


def process_pdf(in_file, det_model, segm_model, cell_model, sp_pp_model=None, tmp_dir_prefix='/tmp',
                pages=None, table_regions_by_page=None, cuda=True, cell_batch_size=1,
                min_cols=2, min_rows=2, min_cells_with_text=0.5,
                demo_prefix=None, display_demo=False,
                cell_content_filter=None):
    result = []
    pdf = PdfMinerWrapper(in_file)
    pdf.load()
    pdf_basename = os.path.splitext(os.path.basename(in_file))[0]
    with tempfile.TemporaryDirectory(dir=tmp_dir_prefix) as wd:
        page_filenames = pdf_to_pages(in_file, wd, pages=pages)

        for page_fname in page_filenames:
            page_i = int(os.path.splitext(os.path.basename(page_fname))[0])
            cropbox = pdf.get_page(page_i)[1].cropbox
            if table_regions_by_page:
                table_regions_on_page = [[convert_coords_from_pq(box, cropbox) * POINTS_TO_PIXELS_FACTOR
                                          for box in table_regions_by_page.get(page_i, [])]]
            else:
                table_regions_on_page = None

            page_image = fit_image_size(load_image_opaque(page_fname, mode='L'))
            tables_by_image, grids_by_image = parse_tables_from_images([page_image],
                                                                       det_model,
                                                                       segm_model,
                                                                       cell_model,
                                                                       sp_pp_model,
                                                                       cuda=cuda,
                                                                       cell_batch_size=cell_batch_size,
                                                                       demo_prefix=(demo_prefix + '-{}'.format(page_i)) if demo_prefix else None,
                                                                       display_demo=display_demo,
                                                                       box_content_getter=lambda i, box: pdf.get_text(page_i+i,
                                                                                                                      [image_box_to_pdf_coords(box, cropbox)]),
                                                                       cell_content_filter=cell_content_filter,
                                                                       table_regions_by_image=table_regions_on_page)
            page_result = []
            for body, cells, rows, cols in tables_by_image[0]:
                if len(rows) < min_rows or len(cols) < min_cols:
                    continue
                cell_texts = [pdf.get_text(page_i, [image_box_to_pdf_coords(cell, cropbox)])
                              for cell in cells]
                cells_with_text = sum(1 for ct in cell_texts if len(ct.strip()) > 0)
                if cells_with_text / float(len(cells)) < min_cells_with_text:
                    continue
                page_result.append((image_box_to_pdf_coords(body, cropbox),
                                    [image_box_to_pdf_coords(cell, cropbox) for cell in cells],
                                    [image_box_to_pdf_coords(row, cropbox) for row in rows],
                                    [image_box_to_pdf_coords(col, cropbox) for col in cols],
                                    cell_texts))
#                 page_result.append((body,
#                                     cells,
#                                     rows,
#                                     cols,
#                                     cell_texts))
#                 print(cells[0])
#                 display(make_demo_mask_from_parsed_table(page_image, (body, cells, rows, cols)))
            result.append((page_i, page_result))
    return result

In [139]:
test_pdf_tables = process_pdf('./data/icdar2013-competition-dataset-with-gt/competition-dataset-us/us-019.pdf',
                              det_test_net,
                              cell_test_base_segm_model,
                              cell_test_net,
#                               pages=[2],
                              display_demo=True)


Export markup to ICDAR 2013 Table Competition format


In [144]:
ICDAR_GOOD_TXT_RE = re.compile(r'\w+')
def clean_text_for_icdar(t):
    return ''.join(ICDAR_GOOD_TXT_RE.findall(t)).upper()


def make_icdar_bb(root, bb):
    left, bottom, right, top = bb
    root.append(lxml.etree.Element('bounding-box',
                                   x1=str(int(math.floor(left))),
                                   y1=str(int(math.floor(bottom))),
                                   x2=str(int(math.ceil(right))),
                                   y2=str(int(math.ceil(top)))))


def table2icdar(parent, page_no, table, table_id, for_region=True):
    body, cells, rows, cols, cell_texts = table

    table_root = lxml.etree.Element('table',
                                    id=str(table_id + 1))
    parent.append(table_root)

    region = lxml.etree.Element('region',
                                **{'id': '1',
                                   'page': str(page_no+1),
#                                    'col-increment': '0',
#                                    'row-increment': '0'
                                  })
    table_root.append(region)

    if for_region:
        make_icdar_bb(region, body)
    else:
        cell2rows = group_by_centers(rows, cells)
        cell2cols = group_by_centers(cols, cells)
        for i, (cell, txt) in enumerate(zip(cells, cell_texts)):
            txt = clean_text_for_icdar(txt)
            if len(cell2cols[i]) == 0 or len(cell2rows[i]) == 0 or not txt:
#                 print('unlinked cell!!')
                continue
            cell_node = lxml.etree.Element('cell',
                                           **{'id': str(i),
                                              'start-col': str(min(cell2cols[i])),
                                              'end-col': str(max(cell2cols[i])),
                                              'start-row' : str(min(cell2rows[i])),
                                              'end-row' : str(max(cell2rows[i]))})
            make_icdar_bb(cell_node, cell)
            content = lxml.etree.Element('content')
            content.text = txt
            cell_node.append(content)
            region.append(cell_node)


def pdf_tables2icdar(outfile, tables, for_region=True):
    root = lxml.etree.Element('document',
                              filename=os.path.basename(outfile))
    table_id = 0
    for page_no, page_tables in tables:
        for table in page_tables:
            table2icdar(root, page_no, table, table_id, for_region=for_region)
            table_id += 1
    with open(outfile, 'w') as f:
        f.write(lxml.etree.tostring(root, pretty_print=True).decode('utf8'))


def load_xml(fname):
    with open(fname, 'r') as f:
        return lxml.etree.parse(f)


def icdar_get_bounding_box(node):
    return (float(node.attrib['x1']),
            float(node.attrib['y1']),
            float(node.attrib['x2']),
            float(node.attrib['y2']))


def icdar_get_regions_by_page(root):
    result = collections.defaultdict(list)
    for reg in root.xpath('//table/region'):
        result[int(reg.attrib['page']) - 1].append(icdar_get_bounding_box(reg.xpath('./bounding-box')[0]))
    return result


def icdar_cell_content_filter(txt):
    txt = txt.strip().lower()
    if len(txt) > 5 and len(set(txt)) == 1:
        return False
    return True


def process_pdf_icdar(infile, outdir, det_model, segm_model, cell_model, sp_pp_model):
    basename = os.path.splitext(os.path.basename(infile))[0]

#     reg_doc_tables = process_pdf(infile, det_model, segm_model, cell_model,
#                                  demo_prefix=os.path.join(outdir, basename + '-reg'))
#     pdf_tables2icdar(os.path.join(outdir, basename + '-reg-result.xml'),
#                      doc_tables,
#                      for_region=True)
    gold_table_regions_by_page = icdar_get_regions_by_page(
        load_xml(os.path.join(outdir, basename + '-reg.xml')))
    str_doc_tables = process_pdf(infile, det_model, segm_model, cell_model, sp_pp_model,
                                 demo_prefix=os.path.join(outdir, basename + '-str'),
                                 table_regions_by_page=gold_table_regions_by_page,
                                 cell_content_filter=icdar_cell_content_filter)
    pdf_tables2icdar(os.path.join(outdir, basename + '-str-result.xml'),
                     str_doc_tables,
                     for_region=False)


def process_pdf_files_icdar(infiles, outdir, det_model, segm_model, cell_model, sp_pp_model=None):
    for file in tqdm.tqdm(infiles):
        process_pdf_icdar(file, outdir, det_model, segm_model, cell_model, sp_pp_model=sp_pp_model)
        torch.cuda.empty_cache()

In [141]:
# %%prun -s cumulative

# process_pdf_icdar('./data/icdar2013-competition-dataset-with-gt/competition-dataset-us/us-001.pdf',
#                   './data/icdar2013-pred/',
#                   det_test_net,
#                   cell_test_base_segm_model,
#                   cell_test_net)

In [145]:
icdar_files = glob.glob('./data/icdar2013-competition-dataset-with-gt/*/*.pdf')
# icdar_files = ['./data/icdar2013-pred/us-034.pdf']
process_pdf_files_icdar(icdar_files,
                        './data/icdar2013-pred/',
                        det_test_net,
                        cell_test_base_segm_model,
                        cell_test_net,
#                         sp_pp_test_net
                       )


100%|██████████| 67/67 [29:28<00:00, 26.39s/it]  

Validate xml


In [84]:
class XmlValidator(object):
    def __init__(self, dtd_file):
        with open(dtd_file, 'r') as f:
            tree = lxml.etree.parse(f)
            self.schema = lxml.etree.XMLSchema(tree)

    def __call__(self, xml_file):
        with open(xml_file, 'r') as f:
            tree = lxml.etree.parse(f)
        is_valid = self.schema.validate(tree)
        if is_valid:
            errors = []
        else:
            errors = self.schema.error_log
        return is_valid, errors

reg_dtd = XmlValidator('./data/competition-entry-region-model.xsd')
str_dtd = XmlValidator('./data/competition-entry-structure-model.xsd')

In [85]:
# reg_dtd('./data/icdar2013-pred/us-025-reg-result.xml')

In [86]:
# str_dtd('./data/icdar2013-pred/us-001-str-result.xml')

Calculate metrics


In [146]:
def evaluate_reg_one_region(gold_elems, pred_elems):
    tp = len(gold_elems & pred_elems)
    fp = len(pred_elems - gold_elems)
    fn = len(gold_elems - pred_elems)
    return dict(complete=1 if fn == 0 else 0,
                pure=1 if fp == 0 else 0,
                precision=calc_precision(tp, fp, fn),
                recall=calc_recall(tp, fp, fn),
                f1=calc_f1(tp, fp, fn))


def icdar_get_region_contents(pdf, page_no, region):
    return { id(ch) for ch in pdf.get_boxes(page_no, [region]) }


def icdar_get_region_contents_by_page(pdf, regions_by_page):
    return { page_no : [icdar_get_region_contents(pdf, page_no, region)
                        for region in page_regions]
            for page_no, page_regions in regions_by_page.items() }


def icdar_evaluate_reg_one_doc(pdf, gold_tree, pred_tree):
    gold_regions_by_page = icdar_get_region_contents_by_page(pdf, icdar_get_regions_by_page(gold_tree))
    pred_regions_by_page = icdar_get_region_contents_by_page(pdf, icdar_get_regions_by_page(pred_tree))

    total_gold_regions = sum(len(page_regions) for page_regions in gold_regions_by_page.values())
    total_pred_regions = sum(len(page_regions) for page_regions in pred_regions_by_page.values())

    metrics = []
    for page in sorted(set(gold_regions_by_page) | set(pred_regions_by_page)):
        matched_pairs = match_boxes_by_biggest_overlap_1to1(gold_regions_by_page.get(page, []),
                                                            pred_regions_by_page.get(page, []))
        for gold_reg, pred_reg in matched_pairs:
            metrics.append(evaluate_reg_one_region(gold_reg, pred_reg))
    if len(metrics) > 0:
        metrics = pandas.DataFrame(metrics)
        completeness = metrics.complete.sum() / total_gold_regions
        purity = metrics.pure.sum() / total_pred_regions
        return {'reg completeness' : completeness,
                'reg purity' : purity,
                'reg cpf' : (2 * completeness * purity / (completeness + purity)) if (completeness + purity) > 1e-3 else 0,
                'reg precision' : metrics.precision.mean(),
                'reg recall' : metrics.recall.mean(),
                'reg f1' : metrics.f1.mean() }
    else:
        return {'reg completeness' : 0,
                'reg purity' : 0,
                'reg cpf' : 0,
                'reg precision' : 0,
                'reg recall' : 0,
                'reg f1' : 0 }


def icdar_parse_table(table_root):
    pages = set(int(p) for p in table_root.xpath('./region/@page'))
    assert len(pages) <= 1
    page = (next(iter(pages)) - 1) if pages else None
    cells = {}
    rows = collections.defaultdict(set)
    cols = collections.defaultdict(set)
    for i, cell_node in enumerate(table_root.xpath('./region/cell')):
        cell_attribs = { k : int(v) for k, v in cell_node.attrib.items() }
        cell_attribs['id'] = i
        if not 'end-col' in cell_attribs:
            cell_attribs['end-col'] = cell_attribs['start-col']
        if not 'end-row' in cell_attribs:
            cell_attribs['end-row'] = cell_attribs['start-row']
        cell_attribs['region'] = icdar_get_bounding_box(cell_node.xpath('./bounding-box')[0])
        cell_attribs['content'] = cell_node.xpath('./content/text()')[0]
        cells[cell_attribs['id']] = cell_attribs
        for ri in range(cell_attribs['start-row'], cell_attribs['end-row'] + 1):
            rows[ri].add(cell_attribs['id'])
        for ci in range(cell_attribs['start-col'], cell_attribs['end-col'] + 1):
            cols[ci].add(cell_attribs['id'])
    return page, cells, rows, cols


def icdar_parse_tables_str(root):
    return {int(table_node.attrib['id']) : icdar_parse_table(table_node)
            for table_node in root.xpath('//table')}


def icdar_str_table_to_triplets(table):
    page, cells, rows, cols = table
    result = set()
    cell2content = { c['id'] : clean_text_for_icdar(c['content'])
                    for c in cells.values() }
    for cur_cell in cells.values():
        cur_content = cell2content[cur_cell['id']]
        r0, r1 = cur_cell['start-row'], cur_cell['end-row']
        c0, c1 = cur_cell['start-col'], cur_cell['end-col']
        for ri in range(r0, r1 + 1):
            result.update((cur_content, '>', cell2content[neigh_i])
                          for neigh_i in rows[ri] & cols[c1+1])

        for ci in range(c0, c1 + 1):
            result.update((cur_content, 'v', cell2content[neigh_i])
                          for neigh_i in rows[r1+1] & cols[ci])
    return result


def icdar_compare_tables_str(pdf, gold_table, pred_table):
    gpage, gcells, grows, gcols = gold_table
    ppage, pcells, prows, pcols = pred_table
    assert gpage == ppage and not gpage is None

    # evaluate cells adjacency
    gtriplets = icdar_str_table_to_triplets(gold_table)
    ptriplets = icdar_str_table_to_triplets(pred_table)
    triplets_tp = len(gtriplets & ptriplets)
    triplets_fp = len(ptriplets - gtriplets)
    triplets_fn = len(gtriplets - ptriplets)
#     print('triplets fp', ptriplets - gtriplets)
#     print('triplets fn', gtriplets - ptriplets)

    # evaluate boxes against completeness/purity/prec/recall
    gold_regions = [{ id(ch) for ch in pdf.get_boxes(gpage, [c['region']]) }
                    for c in gcells.values()]
    pred_regions = [{ id(ch) for ch in pdf.get_boxes(ppage, [c['region']]) }
                    for c in pcells.values()]
    matched_cell_regions = match_boxes_by_biggest_overlap_1to1(gold_regions,
                                                               pred_regions)
    region_metrics = [evaluate_reg_one_region(gold_reg, pred_reg)
                      for gold_reg, pred_reg in matched_cell_regions]
    result = {'str box completeness' : 0,
              'str box purity' : 0,
              'str box cpf' : 0,
              'str box precision' : 0,
              'str box recall' : 0,
              'str box f1' : 0,
              'str rel precision' : calc_precision(triplets_tp, triplets_fp, triplets_fn),
              'str rel recall' : calc_recall(triplets_tp, triplets_fp, triplets_fn),
              'str rel f1' : calc_f1(triplets_tp, triplets_fp, triplets_fn)}
    if len(region_metrics) > 0:
        region_metrics = pandas.DataFrame(region_metrics)
        completeness = region_metrics.complete.sum() / float(len(gcells))
        purity = region_metrics.pure.sum() / float(len(pcells))
        result.update({'str box completeness' : completeness,
                       'str box purity' : purity,
                       'str box cpf' : (2 * completeness * purity / (completeness + purity)) if (completeness + purity) > 1e-3 else 0,
                       'str box precision' : region_metrics.precision.mean(),
                       'str box recall' : region_metrics.recall.mean(),
                       'str box f1' : region_metrics.f1.mean() })
    return result


def icdar_evaluate_str_one_doc(pdf, gold_tree, pred_tree):
    gold_tables = icdar_parse_tables_str(gold_tree)
    pred_tables = icdar_parse_tables_str(pred_tree)
    assert set(gold_tables.keys()) == set(pred_tables.keys()), (gold_tables.keys(), pred_tables.keys())
    metrics = [icdar_compare_tables_str(pdf,
                                        gold_tables[table_id],
                                        pred_tables[table_id])
               for table_id in gold_tables.keys()]
    return pandas.DataFrame(metrics).mean(axis=0)


def icdar_evaluate_single_file(pdf_file):
    pdf = PdfMinerWrapper(pdf_file)
    pdf.load()
    basepath = os.path.splitext(pdf_file)[0]
    result = {}
#     result.update(icdar_evaluate_reg_one_doc(pdf,
#                                              load_xml(basepath+'-reg.xml'),
#                                              load_xml(basepath+'-reg-result.xml')))
    result.update(icdar_evaluate_str_one_doc(pdf,
                                             load_xml(basepath+'-str.xml'),
                                             load_xml(basepath+'-str-result.xml')))
    return result


def icdar_evaluate_files(pdf_files):
    pdf_files = list(pdf_files)
    result = []
    idx = []
    for file in tqdm.tqdm(pdf_files):
        try:
            result.append(icdar_evaluate_single_file(file))
            idx.append(file)
        except:
            print('Could not process {} due to\n{}'.format(file, traceback.format_exc()))
    result = pandas.DataFrame(result)
    result['files'] = idx
    result.set_index('files')
    return result

In [147]:
icdar_files = glob.glob('./data/icdar2013-pred/*.pdf')
# icdar_files = ['./data/icdar2013-pred/us-034.pdf']
icdar_metrics = icdar_evaluate_files(icdar_files)
format_metrics_table_md(icdar_metrics)


 42%|████▏     | 28/67 [03:36<05:01,  7.74s/it] 
Could not process ./data/icdar2013-pred/eu-015.pdf due to
Traceback (most recent call last):
  File "<ipython-input-146-99fe0ba2b009>", line 178, in icdar_evaluate_files
    result.append(icdar_evaluate_single_file(file))
  File "<ipython-input-146-99fe0ba2b009>", line 168, in icdar_evaluate_single_file
    load_xml(basepath+'-str-result.xml')))
  File "<ipython-input-146-99fe0ba2b009>", line 150, in icdar_evaluate_str_one_doc
    assert set(gold_tables.keys()) == set(pred_tables.keys()), (gold_tables.keys(), pred_tables.keys())
AssertionError: (dict_keys([1, 2, 3, 4, 5]), dict_keys([1, 2, 3, 4]))

100%|██████████| 67/67 [05:41<00:00,  5.09s/it]
|       | str box completeness   | str box cpf   | str box f1   | str box precision   | str box purity   | str box recall   | str rel f1   | str rel precision   | str rel recall   |
|:------|:-----------------------|:--------------|:-------------|:--------------------|:-----------------|:-----------------|:-------------|:--------------------|:-----------------|
| value | 0.86±0.17              | 0.74±0.20     | 0.93±0.07    | 0.93±0.08           | 0.69±0.24        | 0.97±0.06        | 0.61±0.27    | 0.60±0.29           | 0.66±0.25        |


In [148]:
icdar_metrics.to_csv('./data/icdar2013_pred_var_dil.csv')

In [82]:
icdar_metrics = pandas.read_csv('./data/icdar2013_pred_pp_lstm3.csv',
                                index_col=[0]).set_index('files') \
    .sort_values('str rel f1')
icdar_metrics


Out[82]:
str box completeness str box cpf str box f1 str box precision str box purity str box recall str rel f1 str rel precision str rel recall
files
./data/icdar2013-pred/us-026.pdf 0.135802 0.103286 0.345562 0.287132 0.083333 0.944444 0.000000 0.000000 0.000000
./data/icdar2013-pred/us-001.pdf 0.508809 0.538079 0.767752 0.734736 0.571147 0.955063 0.129033 0.115769 0.155751
./data/icdar2013-pred/us-016.pdf 0.437500 0.534722 0.981958 0.981984 0.687500 0.983870 0.136364 0.136364 0.136364
./data/icdar2013-pred/us-015.pdf 0.387500 0.455423 0.873446 0.956396 0.553571 0.861044 0.177778 0.170928 0.186355
./data/icdar2013-pred/us-013.pdf 0.631579 0.673866 0.855348 0.907723 0.722222 0.904454 0.238095 0.172414 0.384615
./data/icdar2013-pred/us-032.pdf 0.647059 0.538776 0.884594 0.833094 0.461538 0.979701 0.242424 0.307692 0.200000
./data/icdar2013-pred/eu-022.pdf 0.549296 0.631422 0.908260 0.916943 0.742424 0.927780 0.260606 0.204762 0.358333
./data/icdar2013-pred/us-019.pdf 0.613479 0.689355 0.915810 0.899405 0.787684 0.976254 0.261995 0.247316 0.284917
./data/icdar2013-pred/us-018.pdf 0.483436 0.585935 0.855939 0.877207 0.790767 0.909010 0.281070 0.262829 0.308944
./data/icdar2013-pred/eu-003.pdf 0.686594 0.618801 0.848651 0.975498 0.574943 0.825021 0.312263 0.278660 0.362802
./data/icdar2013-pred/us-002.pdf 0.745816 0.797183 0.947027 0.955647 0.859598 0.961664 0.313473 0.246559 0.513291
./data/icdar2013-pred/us-007.pdf 0.576057 0.625632 0.877322 0.929141 0.687500 0.885297 0.325081 0.374973 0.287350
./data/icdar2013-pred/us-035a.pdf 0.694060 0.721002 0.871489 0.976471 0.759038 0.856879 0.346758 0.349977 0.407774
./data/icdar2013-pred/eu-001.pdf 0.579779 0.564091 0.906710 0.912213 0.564640 0.940632 0.409041 0.423637 0.402997
./data/icdar2013-pred/us-031a.pdf 0.891304 0.900110 0.954821 0.965783 0.909091 0.970528 0.420635 0.296089 0.726027
./data/icdar2013-pred/us-036.pdf 0.785714 0.661654 0.981378 0.986983 0.571429 0.978819 0.421053 0.421053 0.421053
./data/icdar2013-pred/us-011a.pdf 0.621429 0.702641 0.972325 0.985915 0.808571 0.961247 0.439850 0.439850 0.439850
./data/icdar2013-pred/us-037.pdf 0.825301 0.875669 0.947343 1.000000 0.932584 0.923665 0.475436 0.383632 0.625000
./data/icdar2013-pred/us-025.pdf 0.709474 0.789710 0.934845 0.947780 0.905263 0.957153 0.477753 0.460081 0.503666
./data/icdar2013-pred/us-040.pdf 0.894737 0.921541 0.988972 1.000000 0.950000 0.980263 0.489796 0.545455 0.444444
./data/icdar2013-pred/eu-009a.pdf 0.880000 0.684444 0.982387 0.976353 0.560000 0.989890 0.495575 0.350000 0.848485
./data/icdar2013-pred/us-009.pdf 0.740741 0.586907 0.888910 0.875361 0.485981 0.950627 0.496599 0.489933 0.503448
./data/icdar2013-pred/us-010.pdf 0.703704 0.737268 0.923461 0.996471 0.774194 0.894602 0.500000 0.466667 0.538462
./data/icdar2013-pred/us-012.pdf 0.785124 0.850360 0.943530 0.990831 0.927419 0.920110 0.514019 0.410448 0.687500
./data/icdar2013-pred/eu-018.pdf 0.730247 0.830158 0.978882 0.994919 0.973777 0.970858 0.518695 0.643746 0.436774
./data/icdar2013-pred/eu-016.pdf 0.832258 0.908451 0.974839 1.000000 1.000000 0.956895 0.548057 0.551440 0.544715
./data/icdar2013-pred/eu-021.pdf 0.855728 0.796487 0.948614 0.953195 0.745526 0.959221 0.558129 0.537060 0.581123
./data/icdar2013-pred/eu-017.pdf 0.785714 0.880000 0.970029 1.000000 1.000000 0.947982 0.558635 0.557447 0.559829
./data/icdar2013-pred/us-021.pdf 0.858607 0.871324 0.961156 0.988528 0.884524 0.950694 0.564787 0.488816 0.680672
./data/icdar2013-pred/us-017.pdf 0.800138 0.877476 0.973617 0.999764 0.983183 0.955515 0.565127 0.560780 0.573011
... ... ... ... ... ... ... ... ... ...
./data/icdar2013-pred/eu-004.pdf 0.792001 0.857680 0.961289 0.997741 0.941791 0.942308 0.617793 0.607520 0.637829
./data/icdar2013-pred/us-039.pdf 0.928571 0.962963 0.989796 1.000000 1.000000 0.982143 0.645161 0.833333 0.526316
./data/icdar2013-pred/eu-019.pdf 0.863158 0.926554 0.980059 1.000000 1.000000 0.966021 0.650000 0.698925 0.607477
./data/icdar2013-pred/us-004.pdf 0.921348 0.936878 0.987303 0.989320 0.952941 0.989804 0.661157 0.547945 0.833333
./data/icdar2013-pred/us-020.pdf 0.843459 0.867458 0.962114 0.997170 0.897978 0.951181 0.664209 0.654103 0.682826
./data/icdar2013-pred/us-038.pdf 0.875000 0.906883 0.982601 1.000000 0.941176 0.970881 0.666667 0.652174 0.681818
./data/icdar2013-pred/us-022.pdf 0.151515 0.262634 0.824475 1.000000 0.985075 0.720013 0.681416 0.693694 0.669565
./data/icdar2013-pred/us-029.pdf 0.888889 0.881890 0.972706 0.961095 0.875000 1.000000 0.681818 0.909091 0.545455
./data/icdar2013-pred/us-014.pdf 0.861111 0.854968 0.956648 0.986019 0.849624 0.942442 0.686111 0.588889 0.833333
./data/icdar2013-pred/us-028.pdf 0.857955 0.860126 0.960933 0.972928 0.868116 0.966260 0.688492 0.757576 0.639640
./data/icdar2013-pred/us-008.pdf 0.891304 0.826667 0.945695 0.957368 0.770833 0.964286 0.720339 0.696970 0.750000
./data/icdar2013-pred/us-027.pdf 0.865035 0.877834 0.960836 0.969816 0.894426 0.967681 0.736063 0.734050 0.738095
./data/icdar2013-pred/us-033.pdf 0.902778 0.913960 0.973394 0.964319 0.926340 0.995582 0.755177 0.762244 0.772727
./data/icdar2013-pred/eu-005.pdf 0.911616 0.927979 0.974457 0.987589 0.945000 0.974079 0.771739 0.764724 0.779066
./data/icdar2013-pred/eu-011.pdf 0.933333 0.965517 0.987963 1.000000 1.000000 0.980000 0.791667 0.791667 0.791667
./data/icdar2013-pred/us-023.pdf 0.927835 0.957715 0.994145 0.999778 0.989583 0.989542 0.792135 0.819767 0.766304
./data/icdar2013-pred/us-034.pdf 0.965517 0.969332 0.990100 0.999138 0.973197 0.985913 0.792482 0.779485 0.811207
./data/icdar2013-pred/eu-007.pdf 0.874459 0.899032 0.948726 0.989165 0.936656 0.946212 0.796478 0.788666 0.805950
./data/icdar2013-pred/eu-010.pdf 0.954545 0.876074 0.971799 0.953435 0.809524 1.000000 0.806452 0.806452 0.806452
./data/icdar2013-pred/eu-012.pdf 0.935549 0.944760 0.991283 0.991005 0.954359 0.993239 0.821730 0.861774 0.791654
./data/icdar2013-pred/eu-008.pdf 0.948276 0.875767 0.965819 0.961894 0.813559 0.984127 0.835979 0.831579 0.840426
./data/icdar2013-pred/us-005.pdf 1.000000 0.947368 0.998246 0.996552 0.900000 1.000000 0.846154 0.846154 0.846154
./data/icdar2013-pred/eu-027.pdf 0.814286 0.894148 0.995889 0.995690 0.991379 0.997691 0.865169 0.885057 0.846154
./data/icdar2013-pred/eu-006.pdf 0.960800 0.979507 0.994565 1.000000 1.000000 0.990461 0.879984 0.883587 0.876506
./data/icdar2013-pred/us-003.pdf 0.947368 0.794567 0.983777 0.971766 0.684211 0.997368 0.896552 0.896552 0.896552
./data/icdar2013-pred/eu-026.pdf 0.936111 0.954034 0.992098 1.000000 0.974359 0.988158 0.905069 0.959596 0.856537
./data/icdar2013-pred/eu-013.pdf 0.982759 0.974981 0.993557 0.998649 0.967708 0.991560 0.957471 0.965476 0.950000
./data/icdar2013-pred/eu-024.pdf 1.000000 0.947368 0.995436 0.991332 0.900000 1.000000 1.000000 1.000000 1.000000
./data/icdar2013-pred/us-006.pdf 1.000000 0.956522 0.997619 0.995370 0.916667 1.000000 1.000000 1.000000 1.000000
./data/icdar2013-pred/eu-023.pdf 1.000000 0.904110 0.993796 0.988133 0.825000 1.000000 1.000000 1.000000 1.000000

66 rows × 9 columns

Results

det_unet1 + det_unet1 + cell_fcn1_no_aug
reg completeness reg cpf reg f1 reg precision reg purity reg recall
value 0.90±0.26 0.34±0.43 0.89±0.23 0.87±0.25 0.32±0.43 0.95±0.22
gold regions + segm_unet3_step_aug_full_shrinked_reweighted_cells_only_dec_dil_out + cell_fcn1_no_aug_diag_with_unet1
str box completeness str box cpf str box f1 str box precision str box purity str box recall str rel f1 str rel precision str rel recall
value 0.78±0.22 0.72±0.23 0.92±0.10 0.91±0.12 0.72±0.25 0.98±0.05 0.58±0.32 0.57±0.33 0.62±0.31
gold_regions + segm_unet1_step_aug_full2_rc_sc_reg + cell_fcn1_no_aug_diag_with_unet1
str box completeness str box cpf str box f1 str box precision str box purity str box recall str rel f1 str rel precision str rel recall
value 0.88±0.15 0.83±0.17 0.96±0.05 0.97±0.05 0.82±0.20 0.97±0.05 0.74±0.22 0.72±0.25 0.80±0.18
gold_regions + segm_unet1_step_aug_full2_rc_sc_reg + cell_fcn1_no_aug_diag_with_unet1 + pp_lstm1
str box completeness str box cpf str box f1 str box precision str box purity str box recall str rel f1 str rel precision str rel recall
value 0.88±0.15 0.86±0.14 0.96±0.05 0.98±0.05 0.86±0.14 0.97±0.04 0.70±0.21 0.70±0.24 0.74±0.19
gold_regions + segm_unet1_step_aug_full2_rc_sc_reg + cell_fcn1_no_aug_diag_with_unet1 + pp_lstm3
str box completeness str box cpf str box f1 str box precision str box purity str box recall str rel f1 str rel precision str rel recall
value 0.79±0.19 0.80±0.17 0.94±0.09 0.96±0.10 0.83±0.17 0.96±0.05 0.59±0.23 0.58±0.25 0.61±0.23

Process images from a folder


In [ ]:
import svgwrite

In [ ]:
def color_tuple_to_html(c):
    return '#' + '{:02x}{:02x}{:02x}'.format(*c)


def extracted_tables_to_svg(page_img, tables, out_file):
    img_buf = io.StringIO()
    page_img.save(img_buf, format="png")
    img_str = base64.b64encode(img_buf.getvalue()).decode('ascii')
    img_data_url = 'data:image/png;base64,' + img_str

    dwg = svgwrite.Drawing(out_file,
                           size=page_img.size,
                           profile='tiny')
    dwg.add(dwg.image(img_data_url))

    for table_info in tables:
        for cls, (y1, x1, y2, x2) in make_boxes_with_channels_from_parsed_table(table_info):
            dwg.add(dwg.rect((x1, y1),
                             (x2 - x1, y2 - y1),
                             **{'fill' : 'none',
                                'stroke' : color_tuple_to_html(CLS_TO_COLOR[cls]),
                                'stroke-opacity' : 1}))

    dwg.save()

In [ ]:
play_files = list(glob.glob('./data/dc/docs_pages_flat/*.png'))

In [ ]:
play_files[:10]