Data augmentation


In [ ]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [ ]:
#export
from exp.nb_09c import *

PIL transforms

We start with PIL transforms to resize all our images to the same size. Then, once the images are collated into a batch, we can apply data augmentation to all of them at once on the GPU. We already saw the basics of resizing images and moving them to the GPU in notebook 08, but now we'll look at it in more detail.

View images


In [ ]:
#export
make_rgb._order=0

In [ ]:
path = datasets.untar_data(datasets.URLs.IMAGENETTE)
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

In [ ]:
def get_il(tfms): return ImageList.from_files(path, tfms=tfms)

In [ ]:
il = get_il(tfms)

In [ ]:
show_image(il[0])



In [ ]:
img = PIL.Image.open(il.items[0])

In [ ]:
img


Out[ ]:

In [ ]:
img.getpixel((1,1))


Out[ ]:
(28, 34, 32)

In [ ]:
import numpy as np

In [ ]:
%timeit -n 10 a = np.array(PIL.Image.open(il.items[0]))


22.1 ms ± 307 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Be careful with resampling methods: you can quickly lose some textures!


In [ ]:
img.resize((128,128), resample=PIL.Image.ANTIALIAS)


Out[ ]:

In [ ]:
img.resize((128,128), resample=PIL.Image.BILINEAR)


Out[ ]:

In [ ]:
img.resize((128,128), resample=PIL.Image.NEAREST)


Out[ ]:

In [ ]:
img.resize((256,256), resample=PIL.Image.BICUBIC).resize((128,128), resample=PIL.Image.NEAREST)


Out[ ]:

In [ ]:
%timeit img.resize((224,224), resample=PIL.Image.BICUBIC)


12.9 ms ± 30.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [ ]:
%timeit img.resize((224,224), resample=PIL.Image.BILINEAR)


7.81 ms ± 2.86 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [ ]:
%timeit -n 10 img.resize((224,224), resample=PIL.Image.NEAREST)


90.7 µs ± 4.61 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Flip

Flipping can be done very quickly with PIL.


In [ ]:
#export
import random

In [ ]:
def pil_random_flip(x):
    return x.transpose(PIL.Image.FLIP_LEFT_RIGHT) if random.random()<0.5 else x

In [ ]:
il1 = get_il(tfms)
il1.items = [il1.items[0]]*64
dl = DataLoader(il1, 8)

In [ ]:
x = next(iter(dl))

Here is a convenience function to look at images in a batch.


In [ ]:
#export
def show_image(im, ax=None, figsize=(3,3)):
    if ax is None: _,ax = plt.subplots(1, 1, figsize=figsize)
    ax.axis('off')
    ax.imshow(im.permute(1,2,0))

def show_batch(x, c=4, r=None, figsize=None):
    n = len(x)
    if r is None: r = int(math.ceil(n/c))
    if figsize is None: figsize=(c*3,r*3)
    fig,axes = plt.subplots(r,c, figsize=figsize)
    for xi,ax in zip(x,axes.flat): show_image(xi, ax)

Without data augmentation:


In [ ]:
show_batch(x)


With random flip:


In [ ]:
il1.tfms.append(pil_random_flip)

In [ ]:
x = next(iter(dl))
show_batch(x)


We can also make that transform a class so it's easier to set the value of the parameter p. As seen before, it also allows us to set the _order attribute.


In [ ]:
class PilRandomFlip(Transform):
    _order=11
    def __init__(self, p=0.5): self.p=p
    def __call__(self, x):
        return x.transpose(PIL.Image.FLIP_LEFT_RIGHT) if random.random()<self.p else x

In [ ]:
#export
class PilTransform(Transform): _order=11

class PilRandomFlip(PilTransform):
    def __init__(self, p=0.5): self.p=p
    def __call__(self, x):
        return x.transpose(PIL.Image.FLIP_LEFT_RIGHT) if random.random()<self.p else x

In [ ]:
del(il1.tfms[-1])
il1.tfms.append(PilRandomFlip(0.8))

In [ ]:
x = next(iter(dl))
show_batch(x)


PIL can also apply the whole dihedral group of transformations (horizontal flip, vertical flip, rotations by multiples of 90 degrees and the two diagonal transpositions) with the transpose method. Here are the codes of a few of those transformations:


In [ ]:
PIL.Image.FLIP_LEFT_RIGHT,PIL.Image.ROTATE_270,PIL.Image.TRANSVERSE


Out[ ]:
(0, 4, 6)

Be careful: img.transpose(0) is already a transformation, so doing nothing has to be handled as a separate case. That gives us 8 possibilities in total: the 7 transpose codes plus the unchanged image.
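For reference, here is a quick sanity check that the seven transpose codes are the integers 0 to 6, in this order (consistent with the (0, 4, 6) output above):

names = ['FLIP_LEFT_RIGHT','FLIP_TOP_BOTTOM','ROTATE_90','ROTATE_180',
         'ROTATE_270','TRANSPOSE','TRANSVERSE']
#Each constant should equal its index in the list.
for i,name in enumerate(names): assert getattr(PIL.Image, name) == i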


In [ ]:
img = PIL.Image.open(il.items[0])
img = img.resize((128,128), resample=PIL.Image.NEAREST)
_, axs = plt.subplots(2, 4, figsize=(12, 6))
for i,ax in enumerate(axs.flatten()):
    if i==0: ax.imshow(img)
    else:    ax.imshow(img.transpose(i-1))
    ax.axis('off')


And we can implement it like this:


In [ ]:
#export
class PilRandomDihedral(PilTransform):
    def __init__(self, p=0.75): self.p=p*7/8 #Little hack to get the 1/8 identity dihedral transform taken into account.
    def __call__(self, x):
        if random.random()>self.p: return x
        return x.transpose(random.randint(0,6))

In [ ]:
del(il1.tfms[-1])
il1.tfms.append(PilRandomDihedral())

In [ ]:
show_batch(next(iter(dl)))


Random crop


In [ ]:
img = PIL.Image.open(il.items[0])
img.size


Out[ ]:
(1442, 828)

To crop an image with PIL we have to specify the top/left and bottom/right corners in the format (left, top, right, bottom). Rather than cropping directly at the final size, we first crop the section of the image we want, then resize it. In what follows, we call the size of that first crop the crop_size.


In [ ]:
img.crop((60,60,320,320)).resize((128,128), resample=PIL.Image.BILINEAR)


Out[ ]:

In [ ]:
cnr2 = (60,60,320,320)
resample = PIL.Image.BILINEAR

This is pretty fast in PIL:


In [ ]:
%timeit -n 10 img.crop(cnr2).resize((128,128), resample=resample)


818 µs ± 26.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Our time budget: aim for 5 minutes per epoch for ImageNet on 8 GPUs. There are 1.25M images in ImageNet, so per GPU that's 1250000/8/5 == 31250 images per minute, or about 520 per second. Assuming 4 CPU cores per GPU, each core needs to handle ~130 images per second, so we should try to stay under 10ms per image.
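A quick check of that arithmetic (plain Python, nothing from the notebook needed):

n_images, n_gpus, mins = 1_250_000, 8, 5
per_gpu_per_min  = n_images / n_gpus / mins   #31250 images per GPU per minute
per_sec          = per_gpu_per_min / 60       #~520 images per second per GPU
per_core_per_sec = per_sec / 4                #~130 per core, so budget <10ms per image

That leaves us time to do more per image. For instance, we can do the crop and resize in a single call to transform, which gives a smoother result.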


In [ ]:
img.transform((128,128), PIL.Image.EXTENT, cnr2, resample=resample)


Out[ ]:

In [ ]:
%timeit -n 10 img.transform((128,128), PIL.Image.EXTENT, cnr2, resample=resample)


751 µs ± 3.05 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

It's comparable in speed (here even slightly faster) and gives that smoother result, so we will use it. We then define a general crop transform and two subclasses: one that crops at the center (for validation) and one that crops randomly. Each subclass only has to implement the way the four corners passed to PIL are computed.


In [ ]:
#export
from random import randint

def process_sz(sz):
    sz = listify(sz)
    return tuple(sz if len(sz)==2 else [sz[0],sz[0]])

def default_crop_size(w,h): return [w,w] if w < h else [h,h]

class GeneralCrop(PilTransform):
    def __init__(self, size, crop_size=None, resample=PIL.Image.BILINEAR): 
        self.resample,self.size = resample,process_sz(size)
        self.crop_size = None if crop_size is None else process_sz(crop_size)
        
    def default_crop_size(self, w,h): return default_crop_size(w,h)

    def __call__(self, x):
        csize = self.default_crop_size(*x.size) if self.crop_size is None else self.crop_size
        return x.transform(self.size, PIL.Image.EXTENT, self.get_corners(*x.size, *csize), resample=self.resample)
    
    def get_corners(self, w, h): return (0,0,w,h)

class CenterCrop(GeneralCrop):
    def __init__(self, size, scale=1.14, resample=PIL.Image.BILINEAR):
        super().__init__(size, resample=resample)
        self.scale = scale
        
    def default_crop_size(self, w,h): return [w/self.scale,h/self.scale]
    
    def get_corners(self, w, h, wc, hc):
        return ((w-wc)//2, (h-hc)//2, (w-wc)//2+wc, (h-hc)//2+hc)

In [ ]:
il1.tfms = [make_rgb, CenterCrop(128), to_byte_tensor, to_float_tensor]

In [ ]:
show_batch(next(iter(dl)))


RandomResizedCrop

This is the usual data augmentation used on ImageNet: select a crop covering 8% to 100% of the image area, with an aspect ratio between 3/4 and 4/3, then resize that crop to the desired size. It combines some zoom and a bit of squishing at a very low computational cost.


In [ ]:
# export
class RandomResizedCrop(GeneralCrop):
    def __init__(self, size, scale=(0.08,1.0), ratio=(3./4., 4./3.), resample=PIL.Image.BILINEAR):
        super().__init__(size, resample=resample)
        self.scale,self.ratio = scale,ratio
    
    def get_corners(self, w, h, wc, hc):
        area = w*h
        #Try 10 times to get a proper crop inside the image.
        for attempt in range(10):
            #Draw a fresh target area each attempt; reusing `area` would shrink it cumulatively.
            targ_area = random.uniform(*self.scale) * area
            ratio = math.exp(random.uniform(math.log(self.ratio[0]), math.log(self.ratio[1])))
            new_w = int(round(math.sqrt(targ_area * ratio)))
            new_h = int(round(math.sqrt(targ_area / ratio)))
            if new_w <= w and new_h <= h:
                left = random.randint(0, w - new_w)
                top  = random.randint(0, h - new_h)
                return (left, top, left + new_w, top + new_h)
        
        # Fallback to squish
        if   w/h < self.ratio[0]: size = (w, int(w/self.ratio[0]))
        elif w/h > self.ratio[1]: size = (int(h*self.ratio[1]), h)
        else:                     size = (w, h)
        return ((w-size[0])//2, (h-size[1])//2, (w+size[0])//2, (h+size[1])//2)

In [ ]:
il1.tfms = [make_rgb, RandomResizedCrop(128), to_byte_tensor, to_float_tensor]

In [ ]:
show_batch(next(iter(dl)))


Perspective warping

To do perspective warping, we map the corners of the image to new points: for instance, if we want to tilt the image so that the top looks closer to us, the top/left corner needs to be shifted to the right and the top/right to the left. To avoid squishing, the bottom/left corner needs to be shifted to the left and the bottom/right corner to the right. For instance, if we have an image with corners in:

(60,60,60,280,280,280,280,60)

(top/left, bottom/left, bottom/right, top/right) then a warped version is:

(90,60,30,280,310,280,250,60)

PIL can do this for us, but it requires eight coefficients that we need to compute ourselves. The math isn't the most important part here, as we've done it for you: we have to solve a linear system of eight equations, spelled out below. In PyTorch, the linear-system solver is called torch.solve.
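Concretely, PIL's PERSPECTIVE transform maps each point (x, y) of the output to the source point

(x_src, y_src) = ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1))

Each of the four corner correspondences gives two linear equations in the eight unknowns (a, b, c, d, e, f, g, h), so stacking them yields the 8x8 system AX = B that find_coeffs builds below.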


In [ ]:
# export
from torch import FloatTensor,LongTensor

def find_coeffs(orig_pts, targ_pts):
    matrix = []
    #The equations we'll need to solve.
    for p1, p2 in zip(targ_pts, orig_pts):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]])

    A = FloatTensor(matrix)
    B = FloatTensor(orig_pts).view(8, 1)
    #The 8 scalars we seek are the solution of AX = B.
    #(In recent PyTorch versions, torch.solve is replaced by torch.linalg.solve(A, B).)
    return list(torch.solve(B,A)[0][:,0])

In [ ]:
# export
def warp(img, size, src_coords, resample=PIL.Image.BILINEAR):
    w,h = size
    targ_coords = ((0,0),(0,h),(w,h),(w,0))
    c = find_coeffs(src_coords,targ_coords)
    res = img.transform(size, PIL.Image.PERSPECTIVE, list(c), resample=resample)
    return res

In [ ]:
targ = ((0,0),(0,128),(128,128),(128,0))
src  = ((90,60),(30,280),(310,280),(250,60))

In [ ]:
c = find_coeffs(src, targ)
img.transform((128,128), PIL.Image.PERSPECTIVE, list(c), resample=resample)


Out[ ]:

In [ ]:
%timeit -n 10 warp(img, (128,128), src)


990 µs ± 9.37 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 warp(img, (128,128), src, resample=PIL.Image.NEAREST)


338 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
warp(img, (64,64), src, resample=PIL.Image.BICUBIC)


Out[ ]:

In [ ]:
warp(img, (64,64), src, resample=PIL.Image.NEAREST)


Out[ ]:

In [ ]:
# export
def uniform(a,b): return a + (b-a) * random.random()

We can now write a transform that combines this perspective warping with the random resize and crop.


In [ ]:
class PilTiltRandomCrop(PilTransform):
    def __init__(self, size, crop_size=None, magnitude=0., resample=PIL.Image.NEAREST): 
        self.resample,self.size,self.magnitude = resample,process_sz(size),magnitude
        self.crop_size = None if crop_size is None else process_sz(crop_size)
        
    def __call__(self, x):
        csize = default_crop_size(*x.size) if self.crop_size is None else self.crop_size
        up_t,lr_t = uniform(-self.magnitude, self.magnitude),uniform(-self.magnitude, self.magnitude)
        left,top = randint(0,x.size[0]-csize[0]),randint(0,x.size[1]-csize[1])
        src_corners = tensor([[-up_t, -lr_t], [up_t, 1+lr_t], [1-up_t, 1-lr_t], [1+up_t, lr_t]])
        src_corners = src_corners * tensor(csize).float() + tensor([left,top]).float()
        src_corners = tuple([(int(o[0].item()), int(o[1].item())) for o in src_corners])
        return warp(x, self.size, src_corners, resample=self.resample)

In [ ]:
il1.tfms = [make_rgb, PilTiltRandomCrop(128, magnitude=0.1), to_byte_tensor, to_float_tensor]

In [ ]:
x = next(iter(dl))
show_batch(x)


The problem is that black padding appears as soon as the source corners fall outside the original image, so we have to limit the magnitude if we want to avoid it.


In [ ]:
# export
class PilTiltRandomCrop(PilTransform):
    def __init__(self, size, crop_size=None, magnitude=0., resample=PIL.Image.BILINEAR): 
        self.resample,self.size,self.magnitude = resample,process_sz(size),magnitude
        self.crop_size = None if crop_size is None else process_sz(crop_size)
        
    def __call__(self, x):
        csize = default_crop_size(*x.size) if self.crop_size is None else self.crop_size
        left,top = randint(0,x.size[0]-csize[0]),randint(0,x.size[1]-csize[1])
        top_magn = min(self.magnitude, left/csize[0], (x.size[0]-left)/csize[0]-1)
        lr_magn  = min(self.magnitude, top /csize[1], (x.size[1]-top) /csize[1]-1)
        up_t,lr_t = uniform(-top_magn, top_magn),uniform(-lr_magn, lr_magn)
        src_corners = tensor([[-up_t, -lr_t], [up_t, 1+lr_t], [1-up_t, 1-lr_t], [1+up_t, lr_t]])
        src_corners = src_corners * tensor(csize).float() + tensor([left,top]).float()
        src_corners = tuple([(int(o[0].item()), int(o[1].item())) for o in src_corners])
        return warp(x, self.size, src_corners, resample=self.resample)

In [ ]:
il1.tfms = [make_rgb, PilTiltRandomCrop(128, 200, magnitude=0.2), to_byte_tensor, to_float_tensor]

In [ ]:
x = next(iter(dl))
show_batch(x)


Faster tensor creation


In [ ]:
[(o._order,o) for o in sorted(tfms, key=operator.attrgetter('_order'))]


Out[ ]:
[(0, <function exp.nb_08.make_rgb(item)>),
 (10, <exp.nb_08.ResizeFixed at 0x7f995b605ac8>),
 (11, <__main__.PilRandomDihedral at 0x7f99023e3dd8>),
 (20, <function exp.nb_08.to_byte_tensor(item)>),
 (30, <function exp.nb_08.to_float_tensor(item)>)]

In [ ]:
#export
import numpy as np

def np_to_float(x): return torch.from_numpy(np.array(x, dtype=np.float32, copy=False)).permute(2,0,1).contiguous()/255.
np_to_float._order = 30

It is actually faster to combine to_byte_tensor and to_float_tensor into a single transform using NumPy, as np_to_float above does.


In [ ]:
%timeit -n 10 to_float_tensor(to_byte_tensor(img))


2.16 ms ± 18.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 np_to_float(img)


1.38 ms ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Batch data augmentation

You can write your own augmentation for your domain's data types, and have them run on the GPU, by using regular PyTorch tensor operations. Here's an example for images. The key is to do them on a whole batch at a time. Nearly all PyTorch operations can be done batch-wise.

Once we have resized our images so that we can batch them together, we can apply more data augmentation on a batch level. For the affine/coord transforms, we proceed like this:

  1. generate a grid map of the size of our batch (bs x height x width x 2) that contains the coordinates of a grid of size height x width (this will be the final size of the image, and doesn't have to be the same as the current size in the batch)
  2. apply the affine transforms (which are matrix multiplications) and the coord transforms to that grid map
  3. interpolate the values of the final pixels we want from the initial images in the batch, according to the transformed grid map

For steps 1 and 3 there are PyTorch functions: F.affine_grid and F.grid_sample. F.affine_grid can even combine steps 1 and 2 if we just want to do an affine transformation.
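Here is a minimal sketch of the whole pipeline with an identity affine transform (assuming x is a float batch of shape bs x channels x height x width, like the ones produced by our DataLoader):

theta = torch.zeros(x.size(0), 2, 3)    #one 2x3 affine matrix per image in the batch
theta[:,0,0] = theta[:,1,1] = 1.        #start from the identity transform
grid  = F.affine_grid(theta, x.size())  #steps 1 and 2 in a single call
tfm_x = F.grid_sample(x, grid)          #step 3: interpolate from the original images

With the identity matrix this just reproduces the batch; the rest of this section builds up each step and makes the transform random.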

Step 1: generate the grid


In [ ]:
il1.tfms = [make_rgb, PilTiltRandomCrop(128, magnitude=0.2), to_byte_tensor, to_float_tensor]

In [ ]:
dl = DataLoader(il1, 64)

In [ ]:
x = next(iter(dl))

In [ ]:
from torch import FloatTensor

In [ ]:
def affine_grid_cpu(size):
    N, C, H, W = size
    grid = FloatTensor(N, H, W, 2)
    linear_points = torch.linspace(-1, 1, W) if W > 1 else tensor([-1])
    grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, :, 0])
    linear_points = torch.linspace(-1, 1, H) if H > 1 else tensor([-1])
    grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, :, 1])
    return grid

In [ ]:
grid = affine_grid_cpu(x.size())

In [ ]:
grid.shape


Out[ ]:
torch.Size([64, 128, 128, 2])

In [ ]:
grid[0,:5,:5]


Out[ ]:
tensor([[[-1.0000, -1.0000],
         [-0.9843, -1.0000],
         [-0.9685, -1.0000],
         [-0.9528, -1.0000],
         [-0.9370, -1.0000]],

        [[-1.0000, -0.9843],
         [-0.9843, -0.9843],
         [-0.9685, -0.9843],
         [-0.9528, -0.9843],
         [-0.9370, -0.9843]],

        [[-1.0000, -0.9685],
         [-0.9843, -0.9685],
         [-0.9685, -0.9685],
         [-0.9528, -0.9685],
         [-0.9370, -0.9685]],

        [[-1.0000, -0.9528],
         [-0.9843, -0.9528],
         [-0.9685, -0.9528],
         [-0.9528, -0.9528],
         [-0.9370, -0.9528]],

        [[-1.0000, -0.9370],
         [-0.9843, -0.9370],
         [-0.9685, -0.9370],
         [-0.9528, -0.9370],
         [-0.9370, -0.9370]]])

In [ ]:
%timeit -n 10 grid = affine_grid_cpu(x.size())


5.78 ms ± 397 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Coords in the grid go from -1 to 1 (PyTorch convention).

The PyTorch version is slower on the CPU but optimized to run very fast on the GPU:


In [ ]:
m = tensor([[1., 0., 0.], [0., 1., 0.]])
theta = m.expand(x.size(0), 2, 3)

In [ ]:
theta.shape


Out[ ]:
torch.Size([64, 2, 3])

In [ ]:
%timeit -n 10 grid = F.affine_grid(theta, x.size())


10.1 ms ± 545 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 grid = F.affine_grid(theta.cuda(), x.size())


The slowest run took 25.04 times longer than the fastest. This could mean that an intermediate result is being cached.
351 µs ± 665 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So we write our own version that dispatches to our function on the CPU and to PyTorch's on the GPU.


In [ ]:
def affine_grid(x, size):
    size = (size,size) if isinstance(size, int) else tuple(size)
    size = (x.size(0),x.size(1)) + size
    if x.device.type == 'cpu': return affine_grid_cpu(size) 
    m = tensor([[1., 0., 0.], [0., 1., 0.]], device=x.device)
    return F.affine_grid(m.expand(x.size(0), 2, 3), size)

In [ ]:
grid = affine_grid(x, 128)

Step 2: Affine multiplication

In 2D an affine transformation has the form y = Ax + b where A is a 2x2 matrix and b a vector with 2 coordinates. It's usually represented by the 3x3 matrix

A[0,0]  A[0,1]  b[0]
A[1,0]  A[1,1]  b[1]
   0       0     1

because then the composition of two affine transforms can be computed with the matrix product of their 3x3 representations.
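For example, composing a rotation by 30 degrees with a translation (made-up values, using the tensor and math imports we already have) is just a product of the two 3x3 representations:

th  = math.pi/6
rot = tensor([[math.cos(th), -math.sin(th), 0.],
              [math.sin(th),  math.cos(th), 0.],
              [0.,            0.,           1.]])
trn = tensor([[1., 0., 0.2],
              [0., 1., 0.1],
              [0., 0., 1. ]])
comp = trn @ rot  #rotate first, then translate: the result is again a 3x3 affine matrix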


In [ ]:
from torch import stack,zeros_like,ones_like

The matrix for a rotation that has an angle of theta is:

cos(theta) -sin(theta) 0
sin(theta)  cos(theta) 0
0           0          1

Here we have to apply the inverse of a regular rotation (exercise: find out why!), so we use this matrix:

 cos(theta) sin(theta) 0
-sin(theta) cos(theta) 0
 0          0          1

Then we draw a different theta for each image in the batch, to return a whole batch of rotation matrices (size bs x 3 x 3).


In [ ]:
def rotation_matrix(thetas):
    thetas.mul_(math.pi/180) #Convert the angles from degrees to radians (in place).
    rows = [stack([thetas.cos(),             thetas.sin(),             torch.zeros_like(thetas)], dim=1),
            stack([-thetas.sin(),            thetas.cos(),             torch.zeros_like(thetas)], dim=1),
            stack([torch.zeros_like(thetas), torch.zeros_like(thetas), torch.ones_like(thetas)], dim=1)]
    return stack(rows, dim=1)

In [ ]:
thetas = torch.empty(x.size(0)).uniform_(-30,30)

In [ ]:
thetas[:5]


Out[ ]:
tensor([-20.3453,  25.1311, -15.2565,  -2.8207,  25.0947])

In [ ]:
m = rotation_matrix(thetas)

In [ ]:
m.shape, m[:,None].shape, grid.shape


Out[ ]:
(torch.Size([64, 3, 3]),
 torch.Size([64, 1, 3, 3]),
 torch.Size([64, 128, 128, 2]))

In [ ]:
grid.view(64,-1,2).shape


Out[ ]:
torch.Size([64, 16384, 2])

We have to apply our rotation to every point in the grid. The matrix A is given by the first two rows and columns of m, and the vector b by the first two coefficients of its last row (the grid points are row vectors here, so they multiply the matrix from the left). Of course, we also have to deal with the fact that m is a batch of matrices.


In [ ]:
a = m[:,:2,:2]
b = m[:, 2:,:2]
tfm_grid = (grid.view(64,-1,2) @ a + b).view(64, 128, 128, 2)

We can also do this without the view by using broadcasting.


In [ ]:
%timeit -n 10 tfm_grid = grid @ m[:,None,:2,:2] + m[:,2,:2][:,None,None]


22.9 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_grid = torch.einsum('bijk,bkl->bijl', grid, m[:,:2,:2]) + m[:,2,:2][:,None,None]


11.5 ms ± 539 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_grid = torch.matmul(grid, m[:,:2,:2].unsqueeze(1)) + m[:,2,:2][:,None,None]


22.9 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_grid = (torch.bmm(grid.view(64,-1,2), m[:,:2,:2]) + m[:,2,:2][:,None]).view(-1, 128, 128, 2)


11.1 ms ± 331 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

And on the GPU


In [ ]:
grid = grid.cuda()
m = m.cuda()

In [ ]:
%timeit -n 10 tfm_grid = grid @ m[:,None,:2,:2] + m[:,2,:2][:,None,None]


166 µs ± 117 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_grid = torch.einsum('bijk,bkl->bijl', grid, m[:,:2,:2]) + m[:,2,:2][:,None,None]


138 µs ± 5.96 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_grid = torch.matmul(grid, m[:,:2,:2].unsqueeze(1)) + m[:,2,:2][:,None,None]


123 µs ± 8.41 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_grid = (torch.bmm(grid.view(64,-1,2), m[:,:2,:2]) + m[:,2,:2][:,None]).view(-1, 128, 128, 2)


74.9 µs ± 3.76 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Step 3: interpolate

Since bmm is the fastest on both CPU and GPU, we'll use it for the matrix multiplication.


In [ ]:
tfm_grid = torch.bmm(grid.view(64,-1,2), m[:,:2,:2]).view(-1, 128, 128, 2)

The interpolation that reads the pixel values back at those transformed coordinates is done by grid_sample.


In [ ]:
tfm_x = F.grid_sample(x, tfm_grid.cpu())

In [ ]:
show_batch(tfm_x, r=2)


It takes a padding_mode argument ('zeros', 'border' or 'reflection') that controls what happens for points that fall outside the input image.


In [ ]:
tfm_x = F.grid_sample(x, tfm_grid.cpu(), padding_mode='reflection')

In [ ]:
show_batch(tfm_x, r=2)


Timing

Let's look at the speed now!


In [ ]:
def rotate_batch(x, size, degrees):
    grid = affine_grid(x, size)
    thetas = x.new(x.size(0)).uniform_(-degrees,degrees)
    m = rotation_matrix(thetas)
    tfm_grid = grid @ m[:,:2,:2].unsqueeze(1) + m[:,2,:2][:,None,None]
    return F.grid_sample(x, tfm_grid)

In [ ]:
show_batch(rotate_batch(x, 128, 30), r=2)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x, 128, 30)


45.6 ms ± 491 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x.cuda(), 128, 30)


2.65 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Not bad for 64 rotations!

JIT version

But we can be even faster!


In [ ]:
from torch import Tensor

In [ ]:
from torch.jit import script

@script
def rotate_batch(x:Tensor, size:int, degrees:float) -> Tensor:
    sz = (x.size(0),x.size(1)) + (size,size)
    idm = torch.zeros(2,3, device=x.device)
    idm[0,0] = 1.
    idm[1,1] = 1.
    grid = F.affine_grid(idm.expand(x.size(0), 2, 3), sz)
    thetas = torch.zeros(x.size(0), device=x.device).uniform_(-degrees,degrees)
    m = rotation_matrix(thetas)
    tfm_grid = torch.matmul(grid, m[:,:2,:2].unsqueeze(1)) + m[:,2,:2].unsqueeze(1).unsqueeze(2)
    return F.grid_sample(x, tfm_grid)

In [ ]:
m = tensor([[1., 0., 0.], [0., 1., 0.]], device=x.device)

In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x.cuda(), 128, 30)


2.6 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The speed of this depends a lot on what card you have. On a V100 it is generally about 3x faster than non-JIT (as of April 2019), although PyTorch JIT is rapidly improving.

Affine multiplication with affine_grid

And it's even faster if we pass the rotation matrix directly to affine_grid.


In [ ]:
def rotate_batch(x, size, degrees):
    size = (size,size) if isinstance(size, int) else tuple(size)
    size = (x.size(0),x.size(1)) + size
    thetas = x.new(x.size(0)).uniform_(-degrees,degrees)
    m = rotation_matrix(thetas)
    grid = F.affine_grid(m[:,:2], size)
    return F.grid_sample(x.cuda(), grid)

In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x.cuda(), 128, 30)


1.92 ms ± 50.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Export


In [ ]:
!./notebook2script.py 10_augmentation.ipynb
