In [12]:
from __future__ import print_function
import numpy as np
The naive approach to calculating the edit distance starts by defining the three allowed operations: $\{ \text{substitution}, \text{insert}, \text{delete}\}$ (substituting a character with itself is simply a match). We track our position in $s_1$ with the index $i$ and our position in $s_2$ with the index $j$.
In [294]:
def edit_distance(s1, s2, i, j, match_cost, insert_cost, delete_cost):
    """Return the cost of editing ``s1[:i + 1]`` into ``s2[:j + 1]``.

    Brute-force recursion with no caching: every call branches three ways,
    so the running time is exponential and only tiny inputs are practical.

    Parameters
    ----------
    s1, s2 : str
        Source and target strings.
    i, j : int
        Index of the last character of ``s1`` / ``s2`` still to be
        processed.  The initial call passes ``len(s1) - 1`` and
        ``len(s2) - 1``; ``-1`` means the string is exhausted.
    match_cost : callable(c1, c2)
        Cost of aligning ``c1`` with ``c2`` (match or substitution).
    insert_cost : callable(c)
        Cost of inserting the ``s2`` character ``c``.
    delete_cost : callable(c)
        Cost of deleting the ``s1`` character ``c``.

    Returns
    -------
    The minimum total cost over all sequences of the three operations.
    """
    # Index 0 is a real character, so "exhausted" is i < 0 / j < 0.  Testing
    # for == 0 here would skip the first character of each string.
    if i < 0:
        # Nothing left in s1: the rest of s2 can only be inserted.
        return sum(insert_cost(c) for c in s2[:j + 1])
    if j < 0:
        # Nothing left in s2: the rest of s1 can only be deleted.
        return sum(delete_cost(c) for c in s1[:i + 1])

    match_op = (
        edit_distance(s1, s2, i - 1, j - 1, match_cost, insert_cost, delete_cost)
        + match_cost(s1[i], s2[j])
    )
    insert_op = (
        edit_distance(s1, s2, i, j - 1, match_cost, insert_cost, delete_cost)
        + insert_cost(s2[j])
    )
    delete_op = (
        edit_distance(s1, s2, i - 1, j, match_cost, insert_cost, delete_cost)
        + delete_cost(s1[i])
    )
    return min(match_op, insert_op, delete_op)
# Demo of the recursive version: substitutions between different characters,
# insertions and deletions all cost 1; aligning equal characters costs 0.
# The recursion is started from the last index of each string.
s1, s2 = "abc", "xaxbxc"
print(edit_distance(s1, s2, len(s1) - 1, len(s2) - 1,
                    lambda c1, c2: int(c1 != c2),
                    lambda c: 1,
                    lambda c: 1))
The function can only be used on very small strings because it branches three ways, once per operation, at every step. Even along the shortest branch (where $i$ and $j$ are both decremented) the recursion is $\min(|s_1|, |s_2|)$ levels deep, so the running time is exponential, on the order of $O(3^n)$. However, we can observe that many of the calculations are duplicated: the branches of the call tree do not share information with one another, yet they necessarily process overlapping portions of the strings. Because the function is deterministic and, for fixed strings, its result depends only on $i$ and $j$, there are only $|s_1| \cdot |s_2|$ distinct calls it can make, so with $m = |s_1|$ and $n = |s_2|$ the work can be reduced to $O(mn)$.
We can use dynamic programming to exploit this. The function still behaves in a similar fashion, except that we now keep two matrices. The cost matrix keeps a running tally of the total cost to reach $(i, j)$. The parent matrix records which operation was selected to reach $(i, j)$. In both matrices, it is assumed that the minimum-cost operation was always selected.
One thing to note is that the matrices are indexed with an offset of +1 relative to the previous code. This is because index $0$ along each axis of the matrices refers to the position just before the first character of the corresponding string (the empty prefix).
In [297]:
def calc_distance(cost, parent, s1, s2, match_cost, insert_cost, delete_cost):
    """Fill the edit-distance tables for ``s1`` and ``s2`` in place.

    ``cost`` and ``parent`` are indexed as ``[i, j]`` with ``i`` running over
    ``0..len(s1)`` and ``j`` over ``0..len(s2)``.  Row 0 and column 0 are
    only read here, so the caller must have filled in the boundary values
    beforehand; every other cell is overwritten.

    Parameters
    ----------
    cost, parent : 2-D arrays supporting ``a[i, j]`` indexing
        Running minimum cost, and the operation chosen for each cell
        (0 = match/substitute, 1 = insert, 2 = delete).
    s1, s2 : str
        Source and target strings.
    match_cost : callable(c1, c2)
        Cost of aligning the ``s1`` character with the ``s2`` character.
    insert_cost : callable(c)
        Cost of inserting the ``s2`` character ``c``.
    delete_cost : callable(c)
        Cost of deleting the ``s1`` character ``c``.

    Returns
    -------
    tuple
        The same ``(cost, parent)`` objects that were passed in.
    """
    for i, c1 in enumerate(s1):
        for j, c2 in enumerate(s2):
            ops = [
                # diagonal move: consume one character of each string
                cost[i, j] + match_cost(c1, c2),
                # move along s2 only: its character c2 is the one inserted
                cost[i + 1, j] + insert_cost(c2),
                # move along s1 only: its character c1 is the one deleted
                cost[i, j + 1] + delete_cost(c1),
            ]
            # np.argmin returns the first minimum, so ties prefer
            # match, then insert, then delete.
            idx = np.argmin(ops)
            cost[i + 1, j + 1] = ops[idx]
            parent[i + 1, j + 1] = idx
    return (cost, parent)
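Our new function `calc_distance` doesn't directly return the final desired value; it only fills in the two matrices. We can then use `reconstruct_path` on the parent matrix, given a starting location, to recover the sequence of operations that was chosen.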
In [287]:
def reconstruct_path(parent, s1, s2, i, j, on_match=None, on_insert=None, on_delete=None):
    """Replay the operations recorded in ``parent`` that lead to cell (i, j).

    The table is walked backwards from ``(i, j)`` until a cell holding ``-1``
    (or any value other than 0, 1, 2) is reached; the callbacks are then
    invoked in forward order, i.e. from the start of the strings towards
    ``(i, j)``.

    Parameters
    ----------
    parent : 2-D array supporting ``parent[i, j]``
        Operation codes: 0 = match/substitute (came from ``(i-1, j-1)``),
        1 = insert (came from ``(i, j-1)``), 2 = delete (came from
        ``(i-1, j)``), -1 = start of the path.
    s1, s2 : str
        The strings the table was built for.
    i, j : int
        Cell at which the path ends.
    on_match, on_insert, on_delete : callable(c1, c2) or None
        Called for each step of the matching kind with ``s1[row - 1]`` and
        ``s2[col - 1]`` of the step's cell.  On row 0 / column 0 there is no
        such character, so ``None`` is passed in its place.

    Returns
    -------
    None
    """
    callbacks = (on_match, on_insert, on_delete)
    # (row delta, column delta) back to the predecessor cell for each code.
    moves = ((1, 1), (0, 1), (1, 0))

    # Walk back iteratively so that long paths cannot exhaust the recursion
    # limit.  The i/j >= 0 guard stops a malformed table from wrapping round
    # through negative indices.
    steps = []
    while i >= 0 and j >= 0:
        op = parent[i, j]
        if op not in (0, 1, 2):
            break
        steps.append((int(op), i, j))
        i -= moves[op][0]
        j -= moves[op][1]

    for op, row, col in reversed(steps):
        callback = callbacks[op]
        if callback:
            # Index -1 would silently wrap to the last character (or raise
            # on an empty string), so boundary cells report None instead.
            c1 = s1[row - 1] if row > 0 else None
            c2 = s2[col - 1] if col > 0 else None
            callback(c1, c2)
In [298]:
def edit_distance(s1, s2):
    """Return the unit-cost edit distance from ``s1`` to ``s2`` and its script.

    Substitutions, insertions and deletions each cost 1; aligning two equal
    characters costs 0.

    Returns
    -------
    tuple
        ``(cost, ops)`` where ``cost`` is the value in the bottom-right cell
        of the cost table and ``ops`` is a list of one-letter codes in
        left-to-right order: ``"M"`` match, ``"S"`` substitution,
        ``"I"`` insertion, ``"D"`` deletion.
    """
    rows, cols = len(s1) + 1, len(s2) + 1
    # A 64-bit table: an 8-bit one cannot hold distances above 255, which
    # breaks as soon as either string is longer than that.
    cost = np.zeros((rows, cols), dtype=np.int64)
    parent = np.zeros((rows, cols), dtype=np.int8)
    parent[0, 0] = -1  # start of every path

    # Column 0: s2 is empty, so the first i characters of s1 are deleted.
    cost[1:, 0] = np.arange(1, rows)
    parent[1:, 0] = 2

    # Row 0: s1 is empty, so the first j characters of s2 are inserted.
    cost[0, 1:] = np.arange(1, cols)
    parent[0, 1:] = 1

    cost, parent = calc_distance(cost, parent, s1, s2,
                                 lambda c1, c2: 0 if c1 == c2 else 1,
                                 lambda c: 1,
                                 lambda c: 1)

    result = []
    reconstruct_path(parent, s1, s2, len(s1), len(s2),
                     lambda c1, c2: result.append("M" if c1 == c2 else "S"),
                     lambda c1, c2: result.append("I"),
                     lambda c1, c2: result.append("D"))
    return cost[len(s1), len(s2)], result
# Demo: align two phrases, print the distance, then check the recovered edit
# script against the expected string of operation codes.
s1, s2 = "thou shalt not", "you should not"
cost, result = edit_distance(s1, s2)
result = "".join(result)
print(cost)
correct = "DSMMMMMISMSMMMM"
print(correct == result)
In [280]:
def match_substring(s1, s2):
    """Find the best approximate occurrence of pattern ``s1`` inside ``s2``.

    Unit costs are used: 1 per substitution, insertion or deletion, and 0
    for aligning equal characters.  Unlike a plain edit distance, the match
    may start at any position of ``s2`` for free (row 0 of the cost table is
    all zeros) and may end at any position of ``s2`` (the cheapest cell of
    the last row is chosen).

    Returns
    -------
    tuple
        ``(cost, count)`` where ``cost`` is the cost of the cheapest
        alignment of all of ``s1`` against some stretch of ``s2``, and
        ``count`` is the number of match-or-substitute steps on that
        alignment's path.
    """
    rows, cols = len(s1) + 1, len(s2) + 1
    # A 64-bit table: an 8-bit one cannot hold costs above 255.
    cost = np.zeros((rows, cols), dtype=np.int64)
    parent = np.zeros((rows, cols), dtype=np.int8)

    # Row 0: no pattern character consumed yet.  Skipping a prefix of s2 is
    # free, and -1 marks each of these cells as a possible start of the
    # path.  Charging insertions here would pin the match to the beginning
    # of s2, which is prefix matching rather than substring matching.
    cost[0, :] = 0
    parent[0, :] = -1

    # Column 0: nothing of s2 consumed, so pattern characters are deleted.
    cost[1:, 0] = np.arange(1, rows)
    parent[1:, 0] = 2

    cost, parent = calc_distance(cost, parent, s1, s2,
                                 lambda c1, c2: 0 if c1 == c2 else 1,
                                 lambda c: 1,
                                 lambda c: 1)

    # The whole pattern must be consumed (last row); where the match ends in
    # s2 is free, so take the cheapest column (the first one on ties).
    i = len(s1)
    j = np.argmin(cost[i, :])

    # A one-element list serves as a counter the nested function can mutate.
    count = [0]

    def inc(c1, c2):
        count[0] += 1

    reconstruct_path(parent, s1, s2, i, j, inc)
    return cost[i, j], count[0]
# Demo: run the substring matcher, print its cost, and check that the number
# of match steps on the recovered path equals the length of s2.
s1, s2 = "!test!", "test"
cost, count = match_substring(s1, s2)
print("cost: ", cost)
print(len(s2) == count)
In [281]:
def longest_subsequence(s1, s2):
    """Return a longest common subsequence of ``s1`` and ``s2`` as a string.

    The edit-distance tables are built with insertions and deletions costing
    1 and substitutions costing 10.  A substitution is therefore always more
    expensive than a delete plus an insert (2), so the cheapest path consists
    of exact matches, insertions and deletions only; the matched characters,
    in order, form the common subsequence.
    """
    rows, cols = len(s1) + 1, len(s2) + 1
    # A 64-bit table: with a substitution penalty of 10 an 8-bit table would
    # overflow quickly, and it cannot hold boundary values above 255 at all.
    cost = np.zeros((rows, cols), dtype=np.int64)
    parent = np.zeros((rows, cols), dtype=np.int8)
    parent[0, 0] = -1  # start of every path

    # Column 0: delete the first i characters of s1.
    cost[1:, 0] = np.arange(1, rows)
    parent[1:, 0] = 2

    # Row 0: insert the first j characters of s2.
    cost[0, 1:] = np.arange(1, cols)
    parent[0, 1:] = 1

    cost, parent = calc_distance(cost, parent, s1, s2,
                                 lambda c1, c2: 0 if c1 == c2 else 10,
                                 lambda c: 1,
                                 lambda c: 1)

    result = []

    def keep_match(c1, c2):
        # The callback fires for every diagonal step; only a real match
        # contributes a character to the subsequence.
        if c1 == c2:
            result.append(c1)

    reconstruct_path(parent, s1, s2, len(s1), len(s2), keep_match)
    return "".join(result)
# Demo: print the common subsequence found for the two words.
s1, s2 = " democrat", "republican"
result = longest_subsequence(s1, s2)
print(result)
In [316]:
def _calc_distance(s1, s2, i, j, match_cost, insert_cost, delete_cost, cache=None):
    """Memoised front end for ``edit_distance``.

    Looks up the sub-problem ``(i, j)`` in ``cache`` and only calls
    ``edit_distance`` when it has not been solved yet.  ``cache`` maps
    ``(i, j)`` to the computed cost; a fresh dict is used when it is None.
    """
    if cache is None:
        cache = {}
    key = (i, j)
    if key not in cache:
        cache[key] = edit_distance(s1, s2, i, j,
                                   match_cost, insert_cost, delete_cost, cache)
    return cache[key]


def edit_distance(s1, s2, i, j, match_cost, insert_cost, delete_cost, cache=None):
    """Return the cost of editing ``s1[:i]`` into ``s2[:j]`` with memoisation.

    Parameters
    ----------
    s1, s2 : str
        Source and target strings.
    i, j : int
        Number of leading characters of ``s1`` / ``s2`` still to be
        processed; the initial call passes ``len(s1)`` and ``len(s2)``.
    match_cost : callable(c1, c2)
        Cost of aligning ``c1`` with ``c2`` (match or substitution).
    insert_cost : callable(c)
        Cost of inserting the ``s2`` character ``c``.
    delete_cost : callable(c)
        Cost of deleting the ``s1`` character ``c``.
    cache : dict or None
        Memo table mapping ``(i, j)`` to a cost.  Entries are keyed by the
        indices only, so a cache must never be reused for different strings
        or cost functions.  When None, a new dict is created for this call.

    Returns
    -------
    The minimum total cost over all sequences of the three operations.
    """
    # A default of {} would be created once and shared by every call, so
    # results for one pair of strings would leak into the next.
    if cache is None:
        cache = {}

    # s1 ran out: the remaining characters of s2 can only be inserted
    if i == 0:
        return sum(insert_cost(c) for c in s2[:j])
    # s2 ran out: the remaining characters of s1 can only be deleted
    if j == 0:
        return sum(delete_cost(c) for c in s1[:i])

    match_op = (
        _calc_distance(s1, s2, i - 1, j - 1, match_cost, insert_cost, delete_cost, cache)
        + match_cost(s1[i - 1], s2[j - 1])
    )
    insert_op = (
        _calc_distance(s1, s2, i, j - 1, match_cost, insert_cost, delete_cost, cache)
        + insert_cost(s2[j - 1])
    )
    delete_op = (
        _calc_distance(s1, s2, i - 1, j, match_cost, insert_cost, delete_cost, cache)
        + delete_cost(s1[i - 1])
    )
    return min(match_op, insert_op, delete_op)
# Demo of the memoised version with unit costs; here i and j are passed as
# the full lengths of the two strings.
s1, s2 = "you should not", "thou shalt not"
print(edit_distance(s1, s2, len(s1), len(s2),
                    lambda c1, c2: int(c1 != c2),
                    lambda c: 1,
                    lambda c: 1))