In [1]:
#%matplotlib inline

P3. Computational Photography

Week #2 Seam Carving - Deletion


In [2]:
from skimage.io import imread
from skimage import color, data, restoration
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image, ImageDraw
import numpy as np
from scipy.misc import imresize
from scipy import ndimage as nd

In [3]:
def norm(img):
    '''
    image uint8 normalize function

    Linearly rescales the values of img so that the minimum maps to 0 and the
    maximum maps to 255 (the top of the uint8 range). The result is a float
    array; it is not cast to uint8 here.
    @img: numeric array-like image
    @return: array of the same shape with values in [0, 255]; all zeros when
             img is constant (there is no range to stretch)
    '''
    lowest = np.min(img)
    span = np.max(img) - lowest
    if span == 0:
        # A flat image would otherwise divide by zero and yield NaNs.
        return np.zeros(np.shape(img))
    # 255.0 (not 256) keeps the maximum inside the uint8 range and forces
    # floating-point division even for integer input.
    return np.real((img - lowest) * 255.0 / span)

In [4]:
RGB = plt.imread('./img/agbar.png')
scale_factor=2
# Floor division keeps the requested output size integral whatever the
# division mode is; plain `/` produces floats under true division.
RGB = imresize(RGB, ( RGB.shape[0]//scale_factor, RGB.shape[1]//scale_factor,3),interp='bilinear').astype('float')
# Grayscale copy; this is the input of the gradient / energy computations below.
img = color.rgb2gray(RGB)

#Visualization
fig = plt.figure(1) 
# Left panel: the resized colour image.
plt.subplot(1,2,1)
plt.imshow(RGB.astype('uint8'))
plt.title('Original')
plt.axis('off')
# Right panel: the grayscale version, rescaled with norm() for display.
plt.subplot(1,2,2)
plt.imshow(norm(img), cmap="gray")
plt.title('GrayScale')
plt.axis('off')
plt.gcf().set_size_inches((10,10))
fig.tight_layout()
plt.show()

Gradient Magnitude and Energy Matrix


In [5]:
# * * * * * * * * * * * * * * *
# * Gradient Magnitude Matrix *
# * * * * * * * * * * * * * * *
def gradient(img):
    '''
    Return the gradient magnitude of a 2-D image.
    @img: 2-D numeric array
    @return: array of the same shape holding, per pixel, the Euclidean norm of
             the two partial derivatives given by np.gradient
    '''
    d_axis0, d_axis1 = np.gradient(img)
    squared_norm = np.square(d_axis0) + np.square(d_axis1)
    return np.sqrt(squared_norm)

# * * * * * * * * *
# * Energy Matrix *
# * * * * * * * * *
def compute_energy(im):
    '''
    Build the cumulative minimum-energy matrix for vertical seams.

    This implements the dynamic programming step of the seam-find algorithm.
    For an m*n picture this takes O(m*n) time. Row 0 of the result is row 0 of
    im; every other entry is im[r, c] plus the smallest of the (up to three)
    entries of the previous result row at columns c-1, c and c+1.
    @im: 2-D array of per-pixel energies (e.g. a gradient magnitude matrix)
    @return: float array with the same shape as im
    '''
    im = np.asarray(im)
    n_rows, n_cols = im.shape  # also enforces a 2-D input
    cost = np.zeros(im.shape)
    if n_rows == 0:
        return cost

    cost[0] = im[0] #first row for energy matrix as the original gradient magnitude

    for row in range(1, n_rows):
        above = cost[row - 1]
        best = above.copy()
        # Upper-left neighbour exists for every column but the first.
        best[1:] = np.minimum(best[1:], above[:-1])
        # Upper-right neighbour exists for every column but the last; this
        # includes column n-2, which must look at column n-1.
        best[:-1] = np.minimum(best[:-1], above[1:])
        cost[row] = im[row] + best
    return cost

In [6]:
grad = gradient(img)
energy = compute_energy(grad)

#Visualization
fig = plt.figure(2) 
# Both panels share the same layout: normalized grayscale image, title, no axes.
for slot, (panel, caption) in enumerate([(grad, 'Gradient Magnitude'),
                                         (energy, 'Energy')], 1):
    plt.subplot(1, 2, slot)
    plt.imshow(norm(panel), cmap="gray")
    plt.title(caption)
    plt.axis('off')
plt.gcf().set_size_inches((10,10))
fig.tight_layout()
plt.show()

Finding and Removing Vertical Seams


In [7]:
def find_vertical_seam(im, mask):
    '''
    Takes a grayscale img and returns the lowest energy vertical seam as a list of pixels (2-tuples).
    This implements the dynamic programming seam-find algorithm. For an m*n picture, this algorithm
    takes O(m*n) time
    @im: a grayscale image (2-D array)
    @mask: selection mask handed to apply_mask together with the gradient of im
    @return: (path, done). path lists (row, column) tuples from the bottom row
             up to row 0. done is True, and path is empty, when the smallest
             cumulative cost in the bottom row is not negative, i.e. no
             selected (negative-energy) pixel is left to carve.
    '''
    assert len(im.shape) == 2

    n_rows, n_cols = im.shape
    grad = apply_mask(gradient(im), mask)
    cost = compute_energy(grad)

    bottom = n_rows - 1
    # np.argmin returns the first minimum, like a left-to-right strict '<' scan.
    col = int(np.argmin(cost[bottom]))

    # if there is still a negative value in the last row of energy mat -> the selected area deletion is not finished
    if cost[bottom, col] >= 0:
        return [], True

    path = [(bottom, col)]
    for row in range(bottom - 1, -1, -1):
        # Choose the cheapest in-bounds neighbour in the row above directly,
        # instead of testing `cost - grad == cost[neighbour]`, which can fail
        # after floating-point rounding. Candidates are visited right, centre,
        # left and only a strictly smaller cost replaces the current choice,
        # so ties resolve in that order.
        best = None
        for candidate in (col + 1, col, col - 1):
            if candidate < 0 or candidate >= n_cols:
                continue
            if best is None or cost[row, candidate] < cost[row, best]:
                best = candidate
        col = best
        path.append((row, col))
    #print "Reconstruction Complete."
    return path, False


def mark_seam(img, path, mark_as='red'):
    '''
    Paints every pixel of a seam in one solid colour for easy visual checking.
    The image is modified in place and the same array is returned.
    @img: an input img (3-D array, colour channels last)
    @path: the seam, an iterable of pixel index tuples
    @mark_as: colour name: 'red', 'green', 'blue', 'black' or 'white'
    '''
    palette = {
        'red':   (255,0,0),
        'green': (0,255,0),
        'blue':  (0,0,255),
        'white': (255,255,255),
        'black': (0,0,0),
    }
    assert mark_as in ['red','green','blue','black','white']
    assert len(img.shape) == 3

    # The colour does not change along the seam, so look it up once.
    paint = palette[mark_as]
    for pixel in path:
        img[pixel] = paint

    return img


def delete_vertical_seam (img, path):
    '''
    Deletes the pixels in a vertical path from img
    @img: an input img (2-D, or 3-D with channels last)
    @path: pixels to delete in a vertical path, as (row, column) tuples;
           exactly one pixel per row is expected
    @return: new float array with one column less than img; img is not modified
    @raises ValueError: when some row does not contain exactly one path pixel
    '''
    #print "Deleting Vertical Seam..."
    img = np.asarray(img)
    n_rows, n_cols = img.shape[:2]

    # Boolean map of the pixels that survive the deletion.
    keep = np.ones((n_rows, n_cols), dtype=bool)
    for row, col in set(path):
        # Entries outside the image cannot match any pixel and are ignored.
        if 0 <= row < n_rows and 0 <= col < n_cols:
            keep[row, col] = False

    # Every row must lose exactly one pixel, otherwise the result is not rectangular.
    if np.any(keep.sum(axis=1) != n_cols - 1):
        raise ValueError("path must contain exactly one pixel in every image row")

    # Boolean indexing walks the image in row-major order, so the surviving
    # pixels of each row stay together and keep their left-to-right order.
    carved_shape = (n_rows, n_cols - 1) + img.shape[2:]
    #print "Deletion Complete."
    return img[keep].reshape(carved_shape).astype(float)

interactive_only.py


In [8]:
class get_mouse_click():
    """
    Mouse interaction interface for selecting a polygonal region of an image.

    The constructor shows the image and connects the mouse callbacks. Each
    left click inside the axes adds a polygon vertex; a click with any other
    button closes the polygon by repeating the first vertex and closes the
    figure. The vertices are collected in `self.points` as (x, y) tuples in
    data coordinates.
    """
    def __init__(self, img):
        height, width = img.shape[:2]
        self.figure = plt.imshow(img, extent=(0, width, height, 0))
        plt.gray()
        plt.title('select the object to remove')
        plt.xlabel('Select sets of  points with left mouse button,\n'
                   'click right button to close the polygon.')
        plt.connect('button_press_event', self.button_press)
        plt.connect('motion_notify_event', self.mouse_move)

        self.img = np.atleast_3d(img)
        self.points = []
        self.centre = np.array([(width - 1)/2., (height - 1)/2.])

        self.height = height
        self.width = width

        self.make_cursorline()
        self.figure.axes.set_autoscale_on(False)

        plt.show()
        plt.close()

    def make_cursorline(self):
        """
        Create the line artist that previews the polygon being drawn.
        """
        self.cursorline, = plt.plot([0],[0],'r:+',
                                    linewidth=2,markersize=15,markeredgecolor='b')

    def button_press(self,event):
        """
        Register mouse clicks.

        Left button: store the clicked data coordinate as a new vertex.
        Any other button: close the polygon and the figure, returning the
        vertex list. Returns None in every other case.
        """
        if event.button == 1:
            # xdata/ydata are compared against None explicitly: a plain
            # truthiness test would also discard a click at coordinate 0.0.
            if event.xdata is None or event.ydata is None:
                return None
            self.points.append((event.xdata,event.ydata))
            print("Coordinate entered: (%f,%f)" % (event.xdata, event.ydata))

            # NOTE(review): appending to Axes.lines may not be supported by
            # every matplotlib version - confirm against the installed one.
            plt.gca().lines.append(self.cursorline)
            self.make_cursorline()
            return None

        if not self.points:
            # Nothing to close yet; indexing points[0] would raise IndexError.
            return None

        # Repeat the first vertex so the polygon is explicitly closed.
        first_x, first_y = self.points[0][0], self.points[0][1]
        self.points.append((first_x, first_y))
        plt.close()
        return self.points


    def mouse_move(self,event):
        """
        Handle cursor drawing.

        Redraws the preview line through every stored vertex and on to the
        current mouse position.
        """
        n_points = len(self.points)
        pts = np.zeros((n_points+1,2))
        if n_points > 0:
            # The line follows all the vertices clicked so far
            pts[:n_points] = self.points
        # The last point of the line follows the mouse cursor
        pts[n_points:] = [event.xdata,event.ydata]
        self.cursorline.set_data(pts[:,0], pts[:,1])
        plt.draw()


def compute_mask(width,height,polygon):
    '''
    Rasterize a polygon into a selection mask.
    @width: mask width in pixels (number of columns)
    @height: mask height in pixels (number of rows)
    @polygon: polygon vertices, passed unchanged to PIL's ImageDraw.polygon
    @return: float array of shape (height, width) holding -100. on the pixels
             drawn for the polygon (outline and fill) and 1. everywhere else
    '''
    canvas = Image.new('L', (width, height), 0)
    ImageDraw.Draw(canvas).polygon(polygon, outline=1, fill=1)
    selected = np.array(canvas) == 1
    # -100 is the exact marker value apply_mask() tests for.
    return np.where(selected, -100., 1.)

Area Selection


In [9]:
# * * * * * * * * * * * * *
# * Area Mask application *
# * * * * * * * * * * * * *
def apply_mask(img, mask):
    '''
    Copy img and stamp the marker value -100 wherever mask equals -100.
    @img: array to mask (e.g. a gradient magnitude matrix); it is not modified
    @mask: 2-D array in which the value -100 marks the selected pixels
    @return: masked copy of img
    '''
    masked = img.copy()
    mask = np.asarray(mask)

    selected = mask == -100
    n_rows, n_cols = mask.shape[:2]
    # The slice is a view on `masked`, so assigning through it edits the copy.
    # Restricting it to the mask extent mirrors iterating over the mask indices.
    masked[:n_rows, :n_cols][selected] = -100
    return masked

# * * * * * * * * * * * * *
# * Seam Carving Deletion *
# * * * * * * * * * * * * *
def seam_carving_deletion(RGB, mask, verbose=False):
    '''
    Remove the masked area of an image by repeatedly carving vertical seams.

    Each iteration finds a seam with find_vertical_seam, paints it blue on the
    working image, records an imshow artist of that frame, and then deletes the
    seam from both the working image and the mask. The loop stops as soon as
    find_vertical_seam reports that it is finished.
    @RGB: colour image; it is copied and the original is left untouched
    @mask: selection mask matching the image (see compute_mask)
    @verbose: when True, print the iteration number and new shape per seam
    @return: (carved image, list of single-artist frames, one per carved seam)
    '''
    seams = []
    rgb = RGB.copy()
    iteration = 0

    while True:
        gray = color.rgb2gray(rgb)
        path, end = find_vertical_seam(gray, mask)
        if end:
            break

        rgb = mark_seam(rgb, path, 'blue')
        seams.append([plt.imshow(rgb.astype('uint8'))])

        # Image and mask are carved with the same seam so they stay aligned.
        rgb = delete_vertical_seam(rgb, path)
        mask = delete_vertical_seam(mask, path)

        iteration += 1
        if verbose:
            # Single-argument form: prints the same text under Python 2 and 3.
            print("iteration %d \tseam carved... %s" % (iteration, rgb.shape))

    return rgb, seams

In [15]:
# NOTE: img.shape is (rows, columns), so `width` holds the row count and
# `height` the column count here. compute_mask(width, height, polygon) is
# called below with the two swapped, which hands it (columns, rows).
width, height = img.shape
# GET AREA TO DELETE FROM IMAGE
# Opens the interactive figure; the clicked vertices end up in rdi.points.
rdi = get_mouse_click(RGB.astype('uint8'))
polygon = rdi.points

# COMPUTE MASK FROM POLYGON
mask = compute_mask(height, width, polygon)


Coordinate entered: (131.250000,80.841121)
Coordinate entered: (141.374611,87.071651)
Coordinate entered: (152.278037,133.021807)
Coordinate entered: (153.056854,148.598131)
Coordinate entered: (153.056854,184.423676)
Coordinate entered: (153.056854,223.364486)
Coordinate entered: (112.558411,221.806854)
Coordinate entered: (111.779595,140.031153)
Coordinate entered: (112.558411,123.676012)
Coordinate entered: (118.788941,93.302181)
Coordinate entered: (125.019470,82.398754)

Execution Test and Results


In [16]:
# Figure that receives the imshow artists created inside seam_carving_deletion.
fig = plt.figure(1)
# Third argument is verbose=True: one progress line is printed per carved seam.
result, seams = seam_carving_deletion(RGB, mask, True)


iteration 1 	seam carved... (400, 299, 3)
iteration 2 	seam carved... (400, 298, 3)
iteration 3 	seam carved... (400, 297, 3)
iteration 4 	seam carved... (400, 296, 3)
iteration 5 	seam carved... (400, 295, 3)
iteration 6 	seam carved... (400, 294, 3)
iteration 7 	seam carved... (400, 293, 3)
iteration 8 	seam carved... (400, 292, 3)
iteration 9 	seam carved... (400, 291, 3)
iteration 10 	seam carved... (400, 290, 3)
iteration 11 	seam carved... (400, 289, 3)
iteration 12 	seam carved... (400, 288, 3)
iteration 13 	seam carved... (400, 287, 3)
iteration 14 	seam carved... (400, 286, 3)
iteration 15 	seam carved... (400, 285, 3)
iteration 16 	seam carved... (400, 284, 3)
iteration 17 	seam carved... (400, 283, 3)
iteration 18 	seam carved... (400, 282, 3)
iteration 19 	seam carved... (400, 281, 3)
iteration 20 	seam carved... (400, 280, 3)
iteration 21 	seam carved... (400, 279, 3)
iteration 22 	seam carved... (400, 278, 3)
iteration 23 	seam carved... (400, 277, 3)
iteration 24 	seam carved... (400, 276, 3)
iteration 25 	seam carved... (400, 275, 3)
iteration 26 	seam carved... (400, 274, 3)
iteration 27 	seam carved... (400, 273, 3)
iteration 28 	seam carved... (400, 272, 3)
iteration 29 	seam carved... (400, 271, 3)
iteration 30 	seam carved... (400, 270, 3)
iteration 31 	seam carved... (400, 269, 3)
iteration 32 	seam carved... (400, 268, 3)
iteration 33 	seam carved... (400, 267, 3)
iteration 34 	seam carved... (400, 266, 3)
iteration 35 	seam carved... (400, 265, 3)
iteration 36 	seam carved... (400, 264, 3)
iteration 37 	seam carved... (400, 263, 3)
iteration 38 	seam carved... (400, 262, 3)
iteration 39 	seam carved... (400, 261, 3)
iteration 40 	seam carved... (400, 260, 3)
iteration 41 	seam carved... (400, 259, 3)
iteration 42 	seam carved... (400, 258, 3)
iteration 43 	seam carved... (400, 257, 3)

Comment out the `%matplotlib inline` cell and then choose "Run All" to show the animation.


In [20]:
#run animation
# `seams` holds one single-artist frame per carved seam (built by
# seam_carving_deletion). `fig` is whatever figure was last bound to that name:
# plt.figure(1) from the run cell when the notebook is executed top to bottom.
ani = animation.ArtistAnimation(fig, seams, interval=100, blit=True, repeat_delay=1000);
plt.show();

In [18]:
#Visualization
fig = plt.figure(4) 
# Side-by-side comparison: input image on the left, carved result on the right.
for slot, (panel, caption) in enumerate([(RGB, 'Original'),
                                         (result, 'Seam Carving Deletion')], 1):
    plt.subplot(1, 2, slot)
    plt.imshow(panel.astype('uint8'))
    plt.title(caption)
    plt.axis('off')
plt.gcf().set_size_inches((10,10))
fig.tight_layout()
plt.show();

In [ ]: