In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
In [ ]:
#export
from exp.nb_09c import *
We start with PIL transforms to resize all our images to the same size. Then, once they are batched together, we can apply data augmentation to all of them at the same time on the GPU. We have already seen the basics of resizing and moving images to the GPU in notebook 08, but we'll look into it more deeply now.
In [ ]:
#export
make_rgb._order=0
In [ ]:
path = datasets.untar_data(datasets.URLs.IMAGENETTE)
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]
In [ ]:
def get_il(tfms): return ImageList.from_files(path, tfms=tfms)
In [ ]:
il = get_il(tfms)
In [ ]:
show_image(il[0])
In [ ]:
img = PIL.Image.open(il.items[0])
In [ ]:
img
Out[ ]:
In [ ]:
img.getpixel((1,1))
Out[ ]:
In [ ]:
import numpy as np
In [ ]:
%timeit -n 10 a = np.array(PIL.Image.open(il.items[0]))
Be careful with resampling methods: you can quickly lose some textures!
In [ ]:
img.resize((128,128), resample=PIL.Image.ANTIALIAS)
Out[ ]:
In [ ]:
img.resize((128,128), resample=PIL.Image.BILINEAR)
Out[ ]:
In [ ]:
img.resize((128,128), resample=PIL.Image.NEAREST)
Out[ ]:
In [ ]:
img.resize((256,256), resample=PIL.Image.BICUBIC).resize((128,128), resample=PIL.Image.NEAREST)
Out[ ]:
In [ ]:
%timeit img.resize((224,224), resample=PIL.Image.BICUBIC)
In [ ]:
%timeit img.resize((224,224), resample=PIL.Image.BILINEAR)
In [ ]:
%timeit -n 10 img.resize((224,224), resample=PIL.Image.NEAREST)
Flipping can be done very fast with PIL.
In [ ]:
#export
import random
In [ ]:
def pil_random_flip(x):
    return x.transpose(PIL.Image.FLIP_LEFT_RIGHT) if random.random()<0.5 else x
In [ ]:
il1 = get_il(tfms)
il1.items = [il1.items[0]]*64
dl = DataLoader(il1, 8)
In [ ]:
x = next(iter(dl))
Here are a couple of convenience functions to look at the images in a batch.
In [ ]:
#export
def show_image(im, ax=None, figsize=(3,3)):
    if ax is None: _,ax = plt.subplots(1, 1, figsize=figsize)
    ax.axis('off')
    ax.imshow(im.permute(1,2,0))

def show_batch(x, c=4, r=None, figsize=None):
    n = len(x)
    if r is None: r = int(math.ceil(n/c))
    if figsize is None: figsize=(c*3,r*3)
    fig,axes = plt.subplots(r,c, figsize=figsize)
    for xi,ax in zip(x,axes.flat): show_image(xi, ax)
Without data augmentation:
In [ ]:
show_batch(x)
With random flip:
In [ ]:
il1.tfms.append(pil_random_flip)
In [ ]:
x = next(iter(dl))
show_batch(x)
We can also make that transform a class, which makes it easier to set the value of the parameter p. As seen before, it also allows us to set the _order attribute.
In [ ]:
class PilRandomFlip(Transform):
    _order=11
    def __init__(self, p=0.5): self.p=p
    def __call__(self, x):
        return x.transpose(PIL.Image.FLIP_LEFT_RIGHT) if random.random()<self.p else x
In [ ]:
#export
class PilTransform(Transform): _order=11

class PilRandomFlip(PilTransform):
    def __init__(self, p=0.5): self.p=p
    def __call__(self, x):
        return x.transpose(PIL.Image.FLIP_LEFT_RIGHT) if random.random()<self.p else x
In [ ]:
del(il1.tfms[-1])
il1.tfms.append(PilRandomFlip(0.8))
In [ ]:
x = next(iter(dl))
show_batch(x)
PIL can also do the whole dihedral group of transformations (random horizontal flip, random vertical flip, and the rotations by multiples of 90 degrees) with the transpose method. Here are the codes of a few of those transformations:
In [ ]:
PIL.Image.FLIP_LEFT_RIGHT,PIL.Image.ROTATE_270,PIL.Image.TRANSVERSE
Out[ ]:
Be careful: img.transpose(0) is already a transformation, so doing nothing needs to be handled as a separate case. The transpose codes 0-6 then give us 7 different transformations.
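For reference, here is a quick way to list all seven transpose codes with their names (a small sketch; these constants all live on PIL.Image):
In [ ]:
names = ['FLIP_LEFT_RIGHT','FLIP_TOP_BOTTOM','ROTATE_90','ROTATE_180','ROTATE_270','TRANSPOSE','TRANSVERSE']
{getattr(PIL.Image, n): n for n in names}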
In [ ]:
img = PIL.Image.open(il.items[0])
img = img.resize((128,128), resample=PIL.Image.NEAREST)
_, axs = plt.subplots(2, 4, figsize=(12, 6))
for i,ax in enumerate(axs.flatten()):
    if i==0: ax.imshow(img)
    else:    ax.imshow(img.transpose(i-1))
    ax.axis('off')
And we can implement it like this:
In [ ]:
#export
class PilRandomDihedral(PilTransform):
    def __init__(self, p=0.75): self.p=p*7/8 #Little hack to get the 1/8 identity dihedral transform taken into account.
    def __call__(self, x):
        if random.random()>self.p: return x
        return x.transpose(random.randint(0,6))
In [ ]:
del(il1.tfms[-1])
il1.tfms.append(PilRandomDihedral())
In [ ]:
show_batch(next(iter(dl)))
In [ ]:
img = PIL.Image.open(il.items[0])
img.size
Out[ ]:
To crop an image with PIL we have to specify the top/left and bottom/right corners in the format (left, top, right, bottom). We won't crop directly at the final size: we first crop the section of the image we want, then resize it. In what follows, we call the size of that first crop the crop_size.
In [ ]:
img.crop((60,60,320,320)).resize((128,128), resample=PIL.Image.BILINEAR)
Out[ ]:
In [ ]:
cnr2 = (60,60,320,320)
resample = PIL.Image.BILINEAR
This is pretty fast in PIL:
In [ ]:
%timeit -n 10 img.crop(cnr2).resize((128,128), resample=resample)
Our time budget: aim for 5 minutes per epoch on ImageNet with 8 GPUs. There are 1.25M images in ImageNet, so per GPU per minute that's 1250000/8/5 == 31250 images, or about 520 per second. Assuming 4 CPU cores per GPU, we want ~130 images per second per core, so we should try to stay under ~8ms per image. Within that budget we have time to do more things. For instance, we can do the crop and resize in the same call to transform, which will give a smoother result.
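As a quick sanity check, here is that arithmetic spelled out (a sketch; the 8 GPUs and 4 cores per GPU are the assumptions above):
In [ ]:
n_images,n_gpus,mins_per_epoch,cores_per_gpu = 1250000,8,5,4
per_gpu_per_sec  = n_images/n_gpus/mins_per_epoch/60
per_core_per_sec = per_gpu_per_sec/cores_per_gpu
per_gpu_per_sec, per_core_per_sec, 1000/per_core_per_sec #images/s/GPU, images/s/core, ms budget per image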
In [ ]:
img.transform((128,128), PIL.Image.EXTENT, cnr2, resample=resample)
Out[ ]:
In [ ]:
%timeit -n 10 img.transform((128,128), PIL.Image.EXTENT, cnr2, resample=resample)
It's a little bit slower, but still fast enough for our purpose, so we will use this. We then define a general crop transform and two subclasses: one to crop at the center (for validation) and one to randomly crop. Each subclass only implements the way to get the four corners passed to PIL.
In [ ]:
#export
from random import randint
def process_sz(sz):
    sz = listify(sz)
    return tuple(sz if len(sz)==2 else [sz[0],sz[0]])

def default_crop_size(w,h): return [w,w] if w < h else [h,h]

class GeneralCrop(PilTransform):
    def __init__(self, size, crop_size=None, resample=PIL.Image.BILINEAR):
        self.resample,self.size = resample,process_sz(size)
        self.crop_size = None if crop_size is None else process_sz(crop_size)

    def default_crop_size(self, w,h): return default_crop_size(w,h)

    def __call__(self, x):
        csize = self.default_crop_size(*x.size) if self.crop_size is None else self.crop_size
        return x.transform(self.size, PIL.Image.EXTENT, self.get_corners(*x.size, *csize), resample=self.resample)

    def get_corners(self, w, h, wc, hc): return (0,0,w,h)

class CenterCrop(GeneralCrop):
    def __init__(self, size, scale=1.14, resample=PIL.Image.BILINEAR):
        super().__init__(size, resample=resample)
        self.scale = scale

    def default_crop_size(self, w,h): return [w/self.scale,h/self.scale]

    def get_corners(self, w, h, wc, hc):
        return ((w-wc)//2, (h-hc)//2, (w-wc)//2+wc, (h-hc)//2+hc)
In [ ]:
il1.tfms = [make_rgb, CenterCrop(128), to_byte_tensor, to_float_tensor]
In [ ]:
show_batch(next(iter(dl)))
This is the usual data augmentation used on ImageNet (introduced here): select a crop covering 8% to 100% of the image area, with an aspect ratio between 3/4 and 4/3, then resize it to the desired size. It combines some zoom and a bit of squishing at a very low computational cost.
In [ ]:
# export
class RandomResizedCrop(GeneralCrop):
    def __init__(self, size, scale=(0.08,1.0), ratio=(3./4., 4./3.), resample=PIL.Image.BILINEAR):
        super().__init__(size, resample=resample)
        self.scale,self.ratio = scale,ratio

    def get_corners(self, w, h, wc, hc):
        area = w*h
        #Tries 10 times to get a proper crop inside the image.
        for attempt in range(10):
            target_area = random.uniform(*self.scale) * area
            ratio = math.exp(random.uniform(math.log(self.ratio[0]), math.log(self.ratio[1])))
            new_w = int(round(math.sqrt(target_area * ratio)))
            new_h = int(round(math.sqrt(target_area / ratio)))
            if new_w <= w and new_h <= h:
                left = random.randint(0, w - new_w)
                top  = random.randint(0, h - new_h)
                return (left, top, left + new_w, top + new_h)

        # Fallback: squish to a valid ratio with a center crop.
        if   w/h < self.ratio[0]: size = (w, int(w/self.ratio[0]))
        elif w/h > self.ratio[1]: size = (int(h*self.ratio[1]), h)
        else:                     size = (w, h)
        return ((w-size[0])//2, (h-size[1])//2, (w+size[0])//2, (h+size[1])//2)
In [ ]:
il1.tfms = [make_rgb, RandomResizedCrop(128), to_byte_tensor, to_float_tensor]
In [ ]:
show_batch(next(iter(dl)))
To do perspective warping, we map the corners of the image to new points: for instance, if we want to tilt the image so that the top looks closer to us, the top/left corner needs to be shifted to the right and the top/right to the left. To avoid squishing, the bottom/left corner needs to be shifted to the left and the bottom/right corner to the right. For instance, if we have an image with corners at
(60,60), (60,280), (280,280), (280,60)
(top/left, bottom/left, bottom/right, top/right), then a warped version is
(90,60), (30,280), (310,280), (250,60)
PIL can do this for us, but it requires 8 coefficients that we need to calculate. The math isn't the most important thing here, as we've done it for you: PIL's perspective transform maps each output pixel (x, y) to the source location ((a*x + b*y + c)/(g*x + h*y + 1), (d*x + e*y + f)/(g*x + h*y + 1)), so writing this mapping out for the four corner correspondences gives a linear system of 8 equations in the 8 coefficients a,...,h. The equation solver is called torch.solve in PyTorch.
In [ ]:
# export
from torch import FloatTensor,LongTensor
def find_coeffs(orig_pts, targ_pts):
    matrix = []
    #The equations we'll need to solve.
    for p1, p2 in zip(targ_pts, orig_pts):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]])
    A = FloatTensor(matrix)
    B = FloatTensor(orig_pts).view(8, 1)
    #The 8 scalars we seek are the solution of AX = B.
    return list(torch.solve(B,A)[0][:,0])
In [ ]:
# export
def warp(img, size, src_coords, resample=PIL.Image.BILINEAR):
    w,h = size
    targ_coords = ((0,0),(0,h),(w,h),(w,0))
    c = find_coeffs(src_coords, targ_coords)
    res = img.transform(size, PIL.Image.PERSPECTIVE, list(c), resample=resample)
    return res
In [ ]:
targ = ((0,0),(0,128),(128,128),(128,0))
src = ((90,60),(30,280),(310,280),(250,60))
In [ ]:
c = find_coeffs(src, targ)
img.transform((128,128), PIL.Image.PERSPECTIVE, list(c), resample=resample)
Out[ ]:
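As a quick sanity check (a sketch using the src and targ points defined above), the solved coefficients should map each output corner (x, y) back to its source corner through PIL's projective formula:
In [ ]:
a,b,c0,d,e,f,g,h = c
for (x,y),(sx,sy) in zip(targ, src):
    den = g*x + h*y + 1
    print(round(((a*x + b*y + c0)/den).item()), round(((d*x + e*y + f)/den).item()), '->', sx, sy)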
In [ ]:
%timeit -n 10 warp(img, (128,128), src)
In [ ]:
%timeit -n 10 warp(img, (128,128), src, resample=PIL.Image.NEAREST)
In [ ]:
warp(img, (64,64), src, resample=PIL.Image.BICUBIC)
Out[ ]:
In [ ]:
warp(img, (64,64), src, resample=PIL.Image.NEAREST)
Out[ ]:
In [ ]:
# export
def uniform(a,b): return a + (b-a) * random.random()
We can add a transform to do this perspective warping automatically, combined with the random resize and crop.
In [ ]:
class PilTiltRandomCrop(PilTransform):
    def __init__(self, size, crop_size=None, magnitude=0., resample=PIL.Image.NEAREST):
        self.resample,self.size,self.magnitude = resample,process_sz(size),magnitude
        self.crop_size = None if crop_size is None else process_sz(crop_size)

    def __call__(self, x):
        csize = default_crop_size(*x.size) if self.crop_size is None else self.crop_size
        up_t,lr_t = uniform(-self.magnitude, self.magnitude),uniform(-self.magnitude, self.magnitude)
        left,top = randint(0,x.size[0]-csize[0]),randint(0,x.size[1]-csize[1])
        src_corners = tensor([[-up_t, -lr_t], [up_t, 1+lr_t], [1-up_t, 1-lr_t], [1+up_t, lr_t]])
        src_corners = src_corners * tensor(csize).float() + tensor([left,top]).float()
        src_corners = tuple([(int(o[0].item()), int(o[1].item())) for o in src_corners])
        return warp(x, self.size, src_corners, resample=self.resample)
In [ ]:
il1.tfms = [make_rgb, PilTiltRandomCrop(128, magnitude=0.1), to_byte_tensor, to_float_tensor]
In [ ]:
x = next(iter(dl))
show_batch(x)
The problem is that black padding appears as soon as the source corners fall outside of the image, so we have to limit the magnitude if we want to avoid that.
In [ ]:
# export
class PilTiltRandomCrop(PilTransform):
    def __init__(self, size, crop_size=None, magnitude=0., resample=PIL.Image.BILINEAR):
        self.resample,self.size,self.magnitude = resample,process_sz(size),magnitude
        self.crop_size = None if crop_size is None else process_sz(crop_size)

    def __call__(self, x):
        csize = default_crop_size(*x.size) if self.crop_size is None else self.crop_size
        left,top = randint(0,x.size[0]-csize[0]),randint(0,x.size[1]-csize[1])
        top_magn = min(self.magnitude, left/csize[0], (x.size[0]-left)/csize[0]-1)
        lr_magn  = min(self.magnitude, top /csize[1], (x.size[1]-top) /csize[1]-1)
        up_t,lr_t = uniform(-top_magn, top_magn),uniform(-lr_magn, lr_magn)
        src_corners = tensor([[-up_t, -lr_t], [up_t, 1+lr_t], [1-up_t, 1-lr_t], [1+up_t, lr_t]])
        src_corners = src_corners * tensor(csize).float() + tensor([left,top]).float()
        src_corners = tuple([(int(o[0].item()), int(o[1].item())) for o in src_corners])
        return warp(x, self.size, src_corners, resample=self.resample)
In [ ]:
il1.tfms = [make_rgb, PilTiltRandomCrop(128, 200, magnitude=0.2), to_byte_tensor, to_float_tensor]
In [ ]:
x = next(iter(dl))
show_batch(x)
In [ ]:
[(o._order,o) for o in sorted(tfms, key=operator.attrgetter('_order'))]
Out[ ]:
In [ ]:
#export
import numpy as np
def np_to_float(x): return torch.from_numpy(np.array(x, dtype=np.float32, copy=False)).permute(2,0,1).contiguous()/255.
np_to_float._order = 30
It is actually faster to combine to_byte_tensor and to_float_tensor into one transform using NumPy.
In [ ]:
%timeit -n 10 to_float_tensor(to_byte_tensor(img))
In [ ]:
%timeit -n 10 np_to_float(img)
You can write your own augmentations for your domain's data types and have them run on the GPU, using regular PyTorch tensor operations. The key is to apply them to a whole batch at a time; nearly all PyTorch operations can be done batch-wise. Here's an example for images.
Once we have resized our images so that we can batch them together, we can apply more data augmentation at the batch level. For the affine/coord transforms, we proceed like this:
1. generate a grid map of size batch x height x width x 2 that contains the coordinates of the output grid (the final size we want, which doesn't have to be the same as the current size of the batch)
2. apply the affine and/or coord transforms to that grid map
3. interpolate the pixel values of the output from the input images in the batch, according to the transformed grid map
For steps 1. and 3. there are PyTorch functions: F.affine_grid and F.grid_sample. F.affine_grid can even combine 1 and 2 if we just want to do an affine transformation.
In [ ]:
il1.tfms = [make_rgb, PilTiltRandomCrop(128, magnitude=0.2), to_byte_tensor, to_float_tensor]
In [ ]:
dl = DataLoader(il1, 64)
In [ ]:
x = next(iter(dl))
In [ ]:
from torch import FloatTensor
In [ ]:
def affine_grid_cpu(size):
    N, C, H, W = size
    grid = FloatTensor(N, H, W, 2)
    linear_points = torch.linspace(-1, 1, W) if W > 1 else tensor([-1])
    grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, :, 0])
    linear_points = torch.linspace(-1, 1, H) if H > 1 else tensor([-1])
    grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, :, 1])
    return grid
In [ ]:
grid = affine_grid_cpu(x.size())
In [ ]:
grid.shape
Out[ ]:
In [ ]:
grid[0,:5,:5]
Out[ ]:
In [ ]:
%timeit -n 10 grid = affine_grid_cpu(x.size())
Coords in the grid go from -1 to 1 (PyTorch convention).
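We can sanity-check that convention on the corners of the grid we just built: the top-left point should be at (-1,-1) and the bottom-right at (1,1).
In [ ]:
grid[0, 0, 0], grid[0, -1, -1] #(x,y) of the top-left and bottom-right grid points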
The PyTorch version is slower on the CPU but optimized to go very fast on the GPU:
In [ ]:
m = tensor([[1., 0., 0.], [0., 1., 0.]])
theta = m.expand(x.size(0), 2, 3)
In [ ]:
theta.shape
Out[ ]:
In [ ]:
%timeit -n 10 grid = F.affine_grid(theta, x.size())
In [ ]:
%timeit -n 10 grid = F.affine_grid(theta.cuda(), x.size())
So we write our own version that dispatches to our function on the CPU and uses PyTorch's on the GPU.
In [ ]:
def affine_grid(x, size):
    size = (size,size) if isinstance(size, int) else tuple(size)
    size = (x.size(0),x.size(1)) + size
    if x.device.type == 'cpu': return affine_grid_cpu(size)
    m = tensor([[1., 0., 0.], [0., 1., 0.]], device=x.device)
    return F.affine_grid(m.expand(x.size(0), 2, 3), size)
In [ ]:
grid = affine_grid(x, 128)
In 2D, an affine transformation has the form y = Ax + b where A is a 2x2 matrix and b a vector with 2 coordinates. It's usually represented by the 3x3 matrix

A[0,0]  A[0,1]  b[0]
A[1,0]  A[1,1]  b[1]
0       0       1

because then the composition of two affine transforms can be computed with the matrix product of their 3x3 representations.
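For instance, here is a quick numerical check (a sketch with arbitrary values: a 30-degree rotation followed by a translation by (10,-5)):
In [ ]:
theta = math.pi/6
rot = tensor([[math.cos(theta), -math.sin(theta), 0.],
              [math.sin(theta),  math.cos(theta), 0.],
              [0.,               0.,              1.]])
trn = tensor([[1., 0., 10.],
              [0., 1., -5.],
              [0., 0.,  1.]])
p = tensor([2., 3.])
#Apply the two transforms one after the other in y = Ax + b form...
y1 = rot[:2,:2] @ p + rot[:2,2]
y2 = trn[:2,:2] @ y1 + trn[:2,2]
#...then apply the composed 3x3 matrix once: the results match.
comp = trn @ rot
y2, comp[:2,:2] @ p + comp[:2,2]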
In [ ]:
from torch import stack,zeros_like,ones_like
The matrix for a rotation of angle theta is:

cos(theta)   -sin(theta)  0
sin(theta)    cos(theta)  0
0             0           1
Here we have to apply the inverse of a regular rotation (exercise: find out why!), so we use this matrix:

cos(theta)    sin(theta)  0
-sin(theta)   cos(theta)  0
0             0           1

Then we draw a different theta for each image in the batch, to return a batch of rotation matrices (size bs x 3 x 3).
In [ ]:
def rotation_matrix(thetas):
    thetas.mul_(math.pi/180) #Converts the angles from degrees to radians, in place.
    rows = [stack([thetas.cos(),             thetas.sin(),             torch.zeros_like(thetas)], dim=1),
            stack([-thetas.sin(),            thetas.cos(),             torch.zeros_like(thetas)], dim=1),
            stack([torch.zeros_like(thetas), torch.zeros_like(thetas), torch.ones_like(thetas)],  dim=1)]
    return stack(rows, dim=1)
In [ ]:
thetas = torch.empty(x.size(0)).uniform_(-30,30)
In [ ]:
thetas[:5]
Out[ ]:
In [ ]:
m = rotation_matrix(thetas)
In [ ]:
m.shape, m[:,None].shape, grid.shape
Out[ ]:
In [ ]:
grid.view(64,-1,2).shape
Out[ ]:
We have to apply our rotation to every point in the grid. The matrix a is given by the first two rows and first two columns of m, and the vector b by the first two coefficients of the last row (our points are row vectors here, so they're multiplied by the matrix on the right). Of course, we also have to deal with the fact that m is a batch of matrices.
In [ ]:
a = m[:,:2,:2]
b = m[:, 2:,:2]
tfm_grid = (grid.view(64,-1,2) @ a + b).view(64, 128, 128, 2)
We can also do this without the view by using broadcasting.
In [ ]:
%timeit -n 10 tfm_grid = grid @ m[:,None,:2,:2] + m[:,2,:2][:,None,None]
In [ ]:
%timeit -n 10 tfm_grid = torch.einsum('bijk,bkl->bijl', grid, m[:,:2,:2]) + m[:,2,:2][:,None,None]
In [ ]:
%timeit -n 10 tfm_grid = torch.matmul(grid, m[:,:2,:2].unsqueeze(1)) + m[:,2,:2][:,None,None]
In [ ]:
%timeit -n 10 tfm_grid = (torch.bmm(grid.view(64,-1,2), m[:,:2,:2]) + m[:,2,:2][:,None]).view(-1, 128, 128, 2)
And on the GPU:
In [ ]:
grid = grid.cuda()
m = m.cuda()
In [ ]:
%timeit -n 10 tfm_grid = grid @ m[:,None,:2,:2] + m[:,2,:2][:,None,None]
In [ ]:
%timeit -n 10 tfm_grid = torch.einsum('bijk,bkl->bijl', grid, m[:,:2,:2]) + m[:,2,:2][:,None,None]
In [ ]:
%timeit -n 10 tfm_grid = torch.matmul(grid, m[:,:2,:2].unsqueeze(1)) + m[:,2,:2][:,None,None]
In [ ]:
%timeit -n 10 tfm_grid = (torch.bmm(grid.view(64,-1,2), m[:,:2,:2]) + m[:,2,:2][:,None]).view(-1, 128, 128, 2)
Since bmm is consistently the fastest, we use it for the matrix multiplication.
In [ ]:
tfm_grid = torch.bmm(grid.view(64,-1,2), m[:,:2,:2]).view(-1, 128, 128, 2)
The interpolation that reads the pixel values at our transformed coordinates is done by grid_sample.
In [ ]:
tfm_x = F.grid_sample(x, tfm_grid.cpu())
In [ ]:
show_batch(tfm_x, r=2)
It takes a padding_mode argument ('zeros' by default; 'border' and 'reflection' are also available).
In [ ]:
tfm_x = F.grid_sample(x, tfm_grid.cpu(), padding_mode='reflection')
In [ ]:
show_batch(tfm_x, r=2)
Let's look at the speed now!
In [ ]:
def rotate_batch(x, size, degrees):
    grid = affine_grid(x, size)
    thetas = x.new(x.size(0)).uniform_(-degrees,degrees)
    m = rotation_matrix(thetas)
    tfm_grid = grid @ m[:,:2,:2].unsqueeze(1) + m[:,2,:2][:,None,None]
    return F.grid_sample(x, tfm_grid)
In [ ]:
show_batch(rotate_batch(x, 128, 30), r=2)
In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x, 128, 30)
In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x.cuda(), 128, 30)
Not bad for 64 rotations!
But we can be even faster!
In [ ]:
from torch import Tensor
In [ ]:
from torch.jit import script
@script
def rotate_batch(x:Tensor, size:int, degrees:float) -> Tensor:
    sz = (x.size(0),x.size(1)) + (size,size)
    idm = torch.zeros(2,3, device=x.device)
    idm[0,0] = 1.
    idm[1,1] = 1.
    grid = F.affine_grid(idm.expand(x.size(0), 2, 3), sz)
    thetas = torch.zeros(x.size(0), device=x.device).uniform_(-degrees,degrees)
    m = rotation_matrix(thetas)
    tfm_grid = torch.matmul(grid, m[:,:2,:2].unsqueeze(1)) + m[:,2,:2].unsqueeze(1).unsqueeze(2)
    return F.grid_sample(x, tfm_grid)
In [ ]:
m = tensor([[1., 0., 0.], [0., 1., 0.]], device=x.device)
In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x.cuda(), 128, 30)
The speed of this depends a lot on what card you have. On a V100 it is generally about 3x faster than non-JIT (as of April 2019), although PyTorch JIT is rapidly improving.
And it's even faster if we pass the rotation matrix directly to affine_grid:
In [ ]:
def rotate_batch(x, size, degrees):
    size = (size,size) if isinstance(size, int) else tuple(size)
    size = (x.size(0),x.size(1)) + size
    thetas = x.new(x.size(0)).uniform_(-degrees,degrees)
    m = rotation_matrix(thetas)
    grid = F.affine_grid(m[:,:2], size)
    return F.grid_sample(x.cuda(), grid)
In [ ]:
%timeit -n 10 tfm_x = rotate_batch(x.cuda(), 128, 30)
In [ ]:
!./notebook2script.py 10_augmentation.ipynb