In [ ]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
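We use template matching to identify the occurrence of an image patch
(in this case, a sub-image centered on a single coin). Here, we
return a single match (the exact same coin), so the maximum value in the
`match_template` result corresponds to the coin location. The other coins
look similar, and thus have local maxima; if you expect multiple matches, you
should use a proper peak-finding function.
(This explanation is adapted from the scikit-image coins example. In this notebook the
templates are the images loaded below (`trojkat*.png`, `hak*.png` and their rotated
copies), and multiple shapes are found by repeatedly taking the best match and erasing it.)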
The `match_template` function uses fast, normalized cross-correlation [1]
to find instances of the template in the image. Note that the peaks in the
output of `match_template` correspond to the origin (i.e. the top-left corner) of
the template.
[1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light and Magic.
In [1]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import data
from skimage.feature import match_template
from skimage.io import imread, imsave
from skimage.transform import rotate, rescale
from skimage.draw import polygon_perimeter
from glob import glob
import mpld3
mpld3.enable_notebook()
In [2]:
def rotate_good(img, angle, resize=False):
    """Rotate a light-background image so the exposed corners stay light.

    ``skimage.transform.rotate`` fills pixels that map from outside the input
    with its ``cval`` (0 by default, i.e. black). Rotating the *inverted* image
    and inverting the result back makes that fill come out as 1 (white).
    Assumes `img` is a float image in [0, 1] with a white (1) background,
    as produced by ``imread(..., as_grey=True)`` -- TODO confirm for all inputs.

    Parameters
    ----------
    img : ndarray
        Image to rotate.
    angle : float
        Rotation angle in degrees, counter-clockwise (as in ``rotate``).
    resize : bool, optional
        Forwarded to ``rotate``. With the default False the output keeps
        `img`'s shape, so corners are cut off for non-multiples of 90 degrees,
        and content of non-square images is cropped even at +/-90 degrees.
        Pass True to enlarge the output so the whole rotated image fits.

    Returns
    -------
    ndarray
        The rotated image, with a white fill where there was no source data.
    """
    return 1 - rotate(1 - img, angle, resize=resize)
In [3]:
# Load the base templates (order = reversed(['trojkat-maly.png', 'trojkat.png',
# 'trojkat-duzy.png'])) and derive every variant the matcher will try.
base_templates = [imread(name, as_grey=True)
                  for name in ['trojkat-duzy.png', 'trojkat.png', 'trojkat-maly.png']]

# Axis-aligned variants: each base template at -90, 0 and 90 degrees (grouped
# by angle), followed by the two 'hak' templates loaded as-is.
triangles_straight = []
for angle in (-90, 0, 90):
    triangles_straight.extend(rotate_good(t, angle) for t in base_templates)
triangles_straight += [imread('hak.png', as_grey=True), imread('hak-maly.png', as_grey=True)]

# Diagonal variants, only tried by match() when no straight template matches.
triangles_skewed = [rotate_good(t, angle) for angle in (-45, 45) for t in base_templates]

# All templates; template ids returned by match() index into this list.
triangles = triangles_straight + triangles_skewed
In [4]:
def match(image, threshold=0.7):
    """Find the single best template match in `image` and erase it.

    Templates in the global `triangles_straight` are tried first. The global
    `triangles_skewed` templates are tried only when no straight template
    scores strictly above `threshold`. The returned template id is an index
    into the global `triangles` list (straight templates first, then skewed).

    Parameters
    ----------
    image : ndarray
        2-D greyscale image; presumably dark shapes on a white (1)
        background in [0, 1], as loaded with ``imread(..., as_grey=True)``
        -- TODO confirm.
    threshold : float, optional
        Score a match must strictly exceed (default 0.7, the value that was
        previously hard-coded).

    Returns
    -------
    tuple or None
        ``(new_image, result, tid, x, y)``, or ``None`` when no template
        exceeds `threshold`:

        - `new_image` is a copy of `image` with the winning template
          composited out (see ``_erase_template``).
        - `result` is the ``match_template`` response map of the winning
          template.
        - `tid` is the template's index into `triangles`.
        - ``(x, y)`` is the top-left corner of the match.
    """
    best = _best_template_match(image, triangles_straight, threshold)
    if best is None:
        best = _best_template_match(image, triangles_skewed, threshold)
        if best is not None:
            # Offset skewed ids so they index into `triangles`.
            result, tid, x, y = best
            best = (result, tid + len(triangles_straight), x, y)
    if best is None:
        return None
    result, tid, x, y = best
    return (_erase_template(image, triangles[tid], x, y), result, tid, x, y)


def _best_template_match(image, patterns, threshold):
    """Return ``(result, tid, x, y)`` for the best pattern scoring above threshold, else None."""
    best_score = threshold
    best = None
    for tid, pattern in enumerate(patterns):
        # match_template raises ValueError when the template is larger than
        # the image in any dimension; such a template cannot match anyway.
        if any(p > s for p, s in zip(pattern.shape, image.shape)):
            continue
        result = match_template(image, pattern)
        # Peak of the response = top-left corner of the best placement.
        y, x = np.unravel_index(np.argmax(result), result.shape)
        score = result[y, x]
        if score > best_score:
            best_score = score
            best = (result, tid, x, y)
    return best


def _erase_template(image, template, x, y):
    """Return a copy of `image` with `template` composited out at top-left (x, y)."""
    h, w = template.shape
    new_image = image.copy()
    window = image[y:y + h, x:x + w]
    # 1 - (1 - pixel) * template: where the template is 0 the pixel becomes 1
    # (white); where the template is 1 the pixel is left unchanged.
    new_image[y:y + h, x:x + w] = 1 - (1 - window) * template
    return new_image
In [5]:
def do_stuff(image, max_matches=None):
    """Repeatedly match and erase templates, then outline every match found.

    Calls ``match`` until it returns None (or `max_matches` matches have been
    collected). Each call erases its match from the image, so the next call
    can find a different shape.

    Parameters
    ----------
    image : ndarray
        2-D greyscale image to search.
    max_matches : int or None, optional
        Upper bound on the number of matches. Use it as a safety stop in case
        erasing a match does not push its score below the threshold. The
        default None means "until no match is found" (the original behavior).

    Returns
    -------
    ndarray
        A copy of the image returned by the last successful ``match`` call
        (or of the input if nothing matched). A black (0) rectangle outline
        is drawn for every match.
    """
    infos = []
    while max_matches is None or len(infos) < max_matches:
        res = match(image)
        if res is None:
            break
        image, _result, tid, x, y = res
        infos.append((tid, x, y))

    # NOTE(review): `image` is now the erased image from the last match() call,
    # so the outlines are drawn over what was left unmatched rather than over
    # the original input -- confirm this is intended.
    image2 = image.copy()
    for tid, x, y in infos:
        h, w = triangles[tid].shape
        # The outline runs along row y + h and column x + w, one pixel past the
        # template. For a match touching the bottom/right edge those indices
        # are out of bounds; `shape` makes polygon_perimeter drop them instead
        # of the assignment below raising IndexError.
        rr, cc = polygon_perimeter([y, y + h, y + h, y], [x, x, x + w, x + w],
                                   shape=image2.shape)
        image2[rr, cc] = 0
    return image2
In [ ]:
# Run the detector on every PNG in pics/. Sorting makes the processing order,
# and therefore the `image` left over for the next cell, reproducible.
file_names = sorted(glob('pics/*.png'))
print(file_names)
if not file_names:
    print('no PNG files found in pics/ -- the next cell needs `image` to be defined')
for file_name in file_names:
    print('processing %s' % file_name)
    image = imread(file_name, as_grey=True)
    image2 = do_stuff(image)
    # TODO: persist `image2` (e.g. with imsave). Do not write it back into
    # pics/ as '<name>_plot.png': the glob above would pick it up as an input
    # on the next run. Use a separate output directory.
In [ ]:
# Step through detections one at a time: each run of this cell finds the next
# best match in `image`, plots it, and rebinds `image` to the copy with that
# match erased, so the following run finds a different shape.
searched_image = image
match_res = match(searched_image)
if match_res is None:
    print('no template scores above the match threshold')
else:
    image, result, tid, x, y = match_res
    triangle = triangles[tid]
    print(tid)

    fig = plt.figure(figsize=(8, 3))
    ax1 = plt.subplot(1, 3, 1)
    # NOTE(review): adjustable='box-forced' is deprecated in newer matplotlib
    # releases; kept for compatibility with the Python 2-era version used here.
    ax2 = plt.subplot(1, 3, 2, adjustable='box-forced')
    ax3 = plt.subplot(1, 3, 3, sharex=ax2, sharey=ax2, adjustable='box-forced')

    ax1.imshow(triangle, cmap=plt.cm.gray)
    ax1.set_axis_off()
    ax1.set_title('template')

    # Show the image the template was matched against (before erasure) so it
    # lines up with the `match_template` response map in ax3.
    ax2.imshow(searched_image, cmap=plt.cm.gray)
    ax2.set_axis_off()
    ax2.set_title('image')
    # highlight matched region
    h, w = triangle.shape
    rect = plt.Rectangle((x, y), w, h, edgecolor='r', facecolor='none')
    ax2.add_patch(rect)

    ax3.imshow(result)
    ax3.set_axis_off()
    ax3.set_title('`match_template`\nresult')
    # highlight matched region
    ax3.plot(x, y, 'o', markeredgecolor='r', markerfacecolor='none', markersize=10)

    plt.show()
In [ ]: