Data block API foundations


In [ ]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [ ]:
#export
from exp.nb_07a import *

In [ ]:
datasets.URLs.IMAGENETTE_160


Out[ ]:
'https://s3.amazonaws.com/fast-ai-imageclas/imagenette-160'

Image ItemList

Previously we were reading the whole MNIST dataset into RAM at once, loading it from a pickle file. We can't do that for datasets larger than our RAM capacity, so instead we leave the images on disk and only grab the ones we need for each mini-batch as we use them.

Let's use the imagenette dataset and build the data blocks we need along the way.

Get images


In [ ]:
path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)
path


Out[ ]:
PosixPath('/home/ubuntu/.fastai/data/imagenette-160')

To be able to look at what's inside a directory from a notebook, we add the .ls method to Path with a monkey-patch.


In [ ]:
#export
import PIL,os,mimetypes
Path.ls = lambda x: list(x.iterdir())

In [ ]:
path.ls()


Out[ ]:
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/models')]

In [ ]:
(path/'val').ls()


Out[ ]:
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03445777'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03425413'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n01440764'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03028079'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n02979186'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03394916'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n02102040'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03417042'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03000684')]

Let's have a look inside a class folder (the first class is tench):


In [ ]:
path_tench = path/'val'/'n01440764'

In [ ]:
img_fn = path_tench.ls()[0]
img_fn


Out[ ]:
PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n01440764/ILSVRC2012_val_00017995.JPEG')

In [ ]:
img = PIL.Image.open(img_fn)
img


Out[ ]:
(image output: the tench photograph)

In [ ]:
plt.imshow(img)


Out[ ]:
<matplotlib.image.AxesImage at 0x7fb0cf4b5320>

In [ ]:
import numpy
imga = numpy.array(img)

In [ ]:
imga.shape


Out[ ]:
(160, 213, 3)

In [ ]:
imga[:10,:10,0]


Out[ ]:
array([[1, 1, 1, 1, ..., 1, 1, 1, 1],
       [1, 1, 1, 1, ..., 1, 1, 1, 1],
       [1, 1, 1, 1, ..., 1, 1, 1, 1],
       [1, 1, 1, 1, ..., 1, 1, 1, 1],
       ...,
       [1, 1, 1, 1, ..., 1, 1, 1, 1],
       [1, 1, 1, 1, ..., 1, 1, 1, 1],
       [1, 1, 1, 1, ..., 1, 1, 1, 1],
       [1, 1, 1, 1, ..., 1, 1, 1, 1]], dtype=uint8)

Just in case there are other files in the directory (models, texts...), we want to keep only the images. Let's not write the list of extensions out by hand, but instead use what's already on our computer: the MIME types database.


In [ ]:
#export
image_extensions = set(k for k,v in mimetypes.types_map.items() if v.startswith('image/'))

In [ ]:
' '.join(image_extensions)


Out[ ]:
'.svg .tif .ras .png .ppm .tiff .gif .jpe .xbm .pnm .jpg .jpeg .xpm .bmp .xwd .ief .rgb .ico .pgm .pbm'

In [ ]:
#export
def setify(o): return o if isinstance(o,set) else set(listify(o))

In [ ]:
test_eq(setify('aa'), {'aa'})
test_eq(setify(['aa',1]), {'aa',1})
test_eq(setify(None), set())
test_eq(setify(1), {1})
test_eq(setify({1}), {1})

Now let's walk through the directories and grab all the images. The first, private, function grabs all the images inside a given directory; the second one walks (potentially recursively) through all the folders in path.


In [ ]:
#export
def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res

In [ ]:
t = [o.name for o in os.scandir(path_tench)]
t = _get_files(path, t, extensions=image_extensions)
t[:3]


Out[ ]:
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/ILSVRC2012_val_00017995.JPEG'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/ILSVRC2012_val_00009379.JPEG'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/ILSVRC2012_val_00003014.JPEG')]

In [ ]:
#export
def get_files(path, extensions=None, recurse=False, include=None):
    path = Path(path)
    extensions = setify(extensions)
    extensions = {e.lower() for e in extensions}
    if recurse:
        res = []
        for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
            if include is not None and i==0: d[:] = [o for o in d if o in include]
            else:                            d[:] = [o for o in d if not o.startswith('.')]
            res += _get_files(p, f, extensions)
        return res
    else:
        f = [o.name for o in os.scandir(path) if o.is_file()]
        return _get_files(path, f, extensions)

In [ ]:
get_files(path_tench, image_extensions)[:3]


Out[ ]:
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n01440764/ILSVRC2012_val_00017995.JPEG'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n01440764/ILSVRC2012_val_00009379.JPEG'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n01440764/ILSVRC2012_val_00003014.JPEG')]

We need the recurse argument when we start from path, since the pictures are two levels down in the directory tree.


In [ ]:
get_files(path, image_extensions, recurse=True)[:3]


Out[ ]:
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00016387.JPEG'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00034544.JPEG'),
 PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009593.JPEG')]

In [ ]:
all_fns = get_files(path, image_extensions, recurse=True)
len(all_fns)


Out[ ]:
13394

Imagenet is 100 times bigger than imagenette, so we need this to be fast.


In [ ]:
%timeit -n 10 get_files(path, image_extensions, recurse=True)


72.6 ms ± 134 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Prepare for modeling

What we need to do:

  • Get files
  • Split validation set
    • random%, folder name, csv, ...
  • Label:
    • folder name, file name/re, csv, ...
  • Transform per image (optional)
  • Transform to tensor
  • DataLoader
  • Transform per batch (optional)
  • DataBunch
  • Add test set (optional)

Get files

We use the ListContainer class from notebook 06 to store our objects in an ItemList. The get method will need to be subclassed to explain how to access an element (open an image, for instance), then the private _get method lets us apply any additional transforms to it.

new will be used in conjunction with __getitem__ (which works for one index or a list of indices) to create training and validation sets from a single stream when we split the data.


In [ ]:
#export
def compose(x, funcs, *args, order_key='_order', **kwargs):
    # Apply each function in `funcs` to `x`, sorted by its `_order` attribute (default 0)
    key = lambda o: getattr(o, order_key, 0)
    for f in sorted(listify(funcs), key=key): x = f(x, **kwargs)
    return x

class ItemList(ListContainer):
    def __init__(self, items, path='.', tfms=None):
        super().__init__(items)
        self.path,self.tfms = Path(path),tfms

    def __repr__(self): return f'{super().__repr__()}\nPath: {self.path}'
    
    def new(self, items, cls=None):
        if cls is None: cls=self.__class__
        return cls(items, self.path, tfms=self.tfms)
    
    def  get(self, i): return i
    def _get(self, i): return compose(self.get(i), self.tfms)
    
    def __getitem__(self, idx):
        res = super().__getitem__(idx)
        if isinstance(res,list): return [self._get(o) for o in res]
        return self._get(res)

class ImageList(ItemList):
    @classmethod
    def from_files(cls, path, extensions=None, recurse=True, include=None, **kwargs):
        if extensions is None: extensions = image_extensions
        return cls(get_files(path, extensions, recurse=recurse, include=include), path, **kwargs)
    
    def get(self, fn): return PIL.Image.open(fn)

Transforms aren't only used for data augmentation. To allow total flexibility, ImageList returns the raw PIL image. The first thing to do is to convert it to 'RGB' (or something else).

Transforms only need to be functions that take an element of the ItemList and transform it. If they need state, they can be defined as a class. Having them as a class also allows us to define an _order attribute (default 0) that is used to sort the transforms.


In [ ]:
#export
class Transform(): _order=0

class MakeRGB(Transform):
    def __call__(self, item): return item.convert('RGB')

def make_rgb(item): return item.convert('RGB')
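
As a quick sanity check of that ordering logic (a minimal sketch with made-up transform names, not part of the exported library), compose applies transforms sorted by _order regardless of the order we pass them in:


In [ ]:
def add_suffix(item): return item + '-late'
add_suffix._order = 10

def add_prefix(item): return 'early-' + item
add_prefix._order = 5

# Passed out of order, but applied by _order: prefix (5) first, suffix (10) second
assert compose('x', [add_suffix, add_prefix]) == 'early-x-late'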

In [ ]:
il = ImageList.from_files(path, tfms=make_rgb)

In [ ]:
il


Out[ ]:
ImageList (13394 items)
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00016387.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00034544.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009593.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00029149.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00037770.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009370.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00031268.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00047147.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00031035.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00020698.JPEG')...]
Path: /home/ubuntu/.fastai/data/imagenette-160

In [ ]:
img = il[0]; img


Out[ ]:
(image output: the first image, after the make_rgb transform)

We can also index with a range or a list of integers:


In [ ]:
il[:1]


Out[ ]:
[<PIL.Image.Image image mode=RGB size=160x200 at 0x7FB0CF42D048>]

Split validation set

Here, we need to split the files between those in the folder train and those in the folder val.


In [ ]:
fn = il.items[0]; fn


Out[ ]:
PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00016387.JPEG')

Since our filenames are Path objects, we can find the directory of the file with .parent. Here we need to go up two levels, since the immediate parent folders are the class names.


In [ ]:
fn.parent.parent.name


Out[ ]:
'val'

In [ ]:
#export
def grandparent_splitter(fn, valid_name='valid', train_name='train'):
    gp = fn.parent.parent.name
    # True -> validation, False -> training, None -> excluded from both
    return True if gp==valid_name else False if gp==train_name else None

def split_by_func(items, f):
    mask = [f(o) for o in items]
    # `None` values will be filtered out
    f = [o for o,m in zip(items,mask) if m==False]
    t = [o for o,m in zip(items,mask) if m==True ]
    return f,t

In [ ]:
splitter = partial(grandparent_splitter, valid_name='val')

In [ ]:
%time train,valid = split_by_func(il, splitter)


CPU times: user 38.4 ms, sys: 169 µs, total: 38.6 ms
Wall time: 38.2 ms

In [ ]:
len(train),len(valid)


Out[ ]:
(12894, 500)

Now that we can split our data, let's create the class that will contain it. It just needs two ItemLists to be initialized, and we create a shortcut for all unknown attributes by trying to grab them from the train ItemList.


In [ ]:
#export
class SplitData():
    def __init__(self, train, valid): self.train,self.valid = train,valid
        
    def __getattr__(self,k): return getattr(self.train,k)
    #This is needed if we want to pickle SplitData and be able to load it back without recursion errors
    def __setstate__(self,data:Any): self.__dict__.update(data) 
    
    @classmethod
    def split_by_func(cls, il, f):
        lists = map(il.new, split_by_func(il.items, f))
        return cls(*lists)

    def __repr__(self): return f'{self.__class__.__name__}\nTrain: {self.train}\nValid: {self.valid}\n'

In [ ]:
sd = SplitData.split_by_func(il, splitter); sd


Out[ ]:
SplitData
Train: ImageList (12894 items)
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_9403.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_6402.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_4446.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_22655.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_29390.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_17004.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_8837.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_19451.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_12883.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_13476.JPEG')...]
Path: /home/ubuntu/.fastai/data/imagenette-160
Valid: ImageList (500 items)
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00016387.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00034544.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009593.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00029149.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00037770.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009370.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00031268.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00047147.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00031035.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00020698.JPEG')...]
Path: /home/ubuntu/.fastai/data/imagenette-160

Labeling

Labeling has to be done after splitting, because it uses information from the training set that must then be applied in the same way to the validation set, via a Processor.

A Processor is a transformation that is applied to all the inputs once at initialization, with some state computed on the training set that is then applied without modification on the validation set (and maybe the test set or at inference time on a single item). For instance, it could be processing texts to tokenize, then numericalize them. In that case we want the validation set to be numericalized with exactly the same vocabulary as the training set.

Another example is in tabular data, where we fill missing values with (for instance) the median computed on the training set. That statistic is stored in the inner state of the Processor and applied on the validation set.
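
As a sketch of that tabular example (hypothetical; it just mirrors the Processor API we define below and isn't used anywhere else in this notebook), such a processor computes the median on the first set it sees, stores it, and reuses it afterwards:


In [ ]:
class MedianFiller():
    def __init__(self): self.median = None

    def __call__(self, items):
        # State is computed on the first call (the training set)...
        if self.median is None:
            vals = sorted(o for o in items if o is not None)
            self.median = vals[len(vals)//2]
        # ...then applied unchanged to the validation set
        return [self.median if o is None else o for o in items]

mf = MedianFiller()
assert mf([1, None, 3, 2]) == [1, 2, 3, 2]  # median 2 computed here
assert mf([None, 5])       == [2, 5]        # and reused here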

In our case, we want to convert label strings to numbers in a consistent and reproducible way. So we create a list of possible labels in the training set, and then convert our labels to numbers based on this vocab.


In [ ]:
#export
from collections import OrderedDict

def uniqueify(x, sort=False):
    res = list(OrderedDict.fromkeys(x).keys())
    if sort: res.sort()
    return res
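
A couple of quick checks, in the same style as the setify tests above:


In [ ]:
test_eq(uniqueify([1,2,2,1,3]), [1,2,3])
test_eq(uniqueify([2,1,2], sort=True), [1,2])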

First, let's define the processor. Below, we'll also give our labeled data an obj method that recovers the unprocessed items: for instance, a processed label will be an index between 0 and the number of classes - 1, and the corresponding obj will be the name of the class. The first is what the model needs for training, but the second is better for displaying the objects.


In [ ]:
#export
class Processor(): 
    def process(self, items): return items

class CategoryProcessor(Processor):
    def __init__(self): self.vocab=None
    
    def __call__(self, items):
        #The vocab is defined on the first use.
        if self.vocab is None:
            self.vocab = uniqueify(items)
            self.otoi  = {v:k for k,v in enumerate(self.vocab)}
        return [self.proc1(o) for o in items]
    def proc1(self, item):  return self.otoi[item]
    
    def deprocess(self, idxs):
        assert self.vocab is not None
        return [self.deproc1(idx) for idx in idxs]
    def deproc1(self, idx): return self.vocab[idx]
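
To see that statefulness in action, here is a quick check (with hypothetical label strings): the vocab is frozen on the first call, so later calls, and deprocess, reuse it.


In [ ]:
cp = CategoryProcessor()
assert cp(['cat','dog','cat']) == [0, 1, 0]  # first call defines the vocab
assert cp(['dog']) == [1]                    # later calls reuse it
assert cp.deprocess([1, 0]) == ['dog', 'cat']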

Here we label according to the folders of the images, so simply fn.parent.name. We label the training set first with a newly created CategoryProcessor so that it computes its inner vocab on that set. Then we label the validation set using the same processor, which means it uses the same vocab. The end result is another SplitData object.


In [ ]:
#export
def parent_labeler(fn): return fn.parent.name

def _label_by_func(ds, f, cls=ItemList): return cls([f(o) for o in ds.items], path=ds.path)

#This is slightly different from what was seen during the lesson;
#   we'll discuss the changes in lesson 11
class LabeledData():
    def process(self, il, proc): return il.new(compose(il.items, proc))

    def __init__(self, x, y, proc_x=None, proc_y=None):
        self.x,self.y = self.process(x, proc_x),self.process(y, proc_y)
        self.proc_x,self.proc_y = proc_x,proc_y
        
    def __repr__(self): return f'{self.__class__.__name__}\nx: {self.x}\ny: {self.y}\n'
    def __getitem__(self,idx): return self.x[idx],self.y[idx]
    def __len__(self): return len(self.x)
    
    def x_obj(self, idx): return self.obj(self.x, idx, self.proc_x)
    def y_obj(self, idx): return self.obj(self.y, idx, self.proc_y)
    
    def obj(self, items, idx, procs):
        # A single item if idx is an int (or a 0-dim LongTensor), a collection otherwise
        isint = isinstance(idx, int) or (isinstance(idx,torch.LongTensor) and not idx.ndim)
        item = items[idx]
        for proc in reversed(listify(procs)):
            item = proc.deproc1(item) if isint else proc.deprocess(item)
        return item

    @classmethod
    def label_by_func(cls, il, f, proc_x=None, proc_y=None):
        return cls(il, _label_by_func(il, f), proc_x=proc_x, proc_y=proc_y)

def label_by_func(sd, f, proc_x=None, proc_y=None):
    train = LabeledData.label_by_func(sd.train, f, proc_x=proc_x, proc_y=proc_y)
    valid = LabeledData.label_by_func(sd.valid, f, proc_x=proc_x, proc_y=proc_y)
    return SplitData(train,valid)

In [ ]:
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())

In [ ]:
assert ll.train.proc_y is ll.valid.proc_y

In [ ]:
ll.train.y


Out[ ]:
ItemList (12894 items)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0...]
Path: /home/ubuntu/.fastai/data/imagenette-160

In [ ]:
ll.train.y.items[0], ll.train.y_obj(0), ll.train.y_obj(slice(2))


Out[ ]:
(0, 'n03888257', ['n03888257', 'n03888257'])

In [ ]:
ll


Out[ ]:
SplitData
Train: LabeledData
x: ImageList (12894 items)
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_9403.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_6402.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_4446.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_22655.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_29390.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_17004.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_8837.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_19451.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_12883.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/train/n03888257/n03888257_13476.JPEG')...]
Path: /home/ubuntu/.fastai/data/imagenette-160
y: ItemList (12894 items)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0...]
Path: /home/ubuntu/.fastai/data/imagenette-160

Valid: LabeledData
x: ImageList (500 items)
[PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00016387.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00034544.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009593.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00029149.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00037770.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00009370.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00031268.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00047147.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00031035.JPEG'), PosixPath('/home/ubuntu/.fastai/data/imagenette-160/val/n03888257/ILSVRC2012_val_00020698.JPEG')...]
Path: /home/ubuntu/.fastai/data/imagenette-160
y: ItemList (500 items)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0...]
Path: /home/ubuntu/.fastai/data/imagenette-160

Transform to tensor


In [ ]:
ll.train[0]


Out[ ]:
(<PIL.Image.Image image mode=RGB size=160x243 at 0x7FB0CF0488D0>, 0)

In [ ]:
ll.train[0][0]


Out[ ]:
(image output: the first training image)

To be able to put all our images in a batch, we need them to all have the same size. We can do this easily in PIL.


In [ ]:
ll.train[0][0].resize((128,128))


Out[ ]:
(image output: the same image resized to 128x128)

The first transform resizes to a given size; then we convert the image to a byte tensor before converting it to float and dividing by 255. We will investigate data augmentation transforms at length in notebook 10.


In [ ]:
#export
class ResizeFixed(Transform):
    _order=10
    def __init__(self,size):
        if isinstance(size,int): size=(size,size)
        self.size = size
        
    def __call__(self, item): return item.resize(self.size, PIL.Image.BILINEAR)

def to_byte_tensor(item):
    res = torch.ByteTensor(torch.ByteStorage.from_buffer(item.tobytes()))
    w,h = item.size
    # PIL gives us H,W,C; permute to the C,H,W layout PyTorch expects
    return res.view(h,w,-1).permute(2,0,1)
to_byte_tensor._order=20

def to_float_tensor(item): return item.float().div_(255.)
to_float_tensor._order=30

In [ ]:
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

il = ImageList.from_files(path, tfms=tfms)
sd = SplitData.split_by_func(il, splitter)
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())

Here is a little convenience function to show an image from the corresponding tensor.


In [ ]:
#export
def show_image(im, figsize=(3,3)):
    plt.figure(figsize=figsize)
    plt.axis('off')
    plt.imshow(im.permute(1,2,0)) # back to H,W,C for matplotlib

In [ ]:
x,y = ll.train[0]
x.shape


Out[ ]:
torch.Size([3, 128, 128])

In [ ]:
show_image(x)


Modeling

DataBunch

Now we are ready to put our datasets together in a DataBunch.


In [ ]:
bs=64

In [ ]:
train_dl,valid_dl = get_dls(ll.train,ll.valid,bs, num_workers=4)

In [ ]:
x,y = next(iter(train_dl))

In [ ]:
x.shape


Out[ ]:
torch.Size([64, 3, 128, 128])

We can still see the images in a batch and get the corresponding classes.


In [ ]:
show_image(x[0])
ll.train.proc_y.vocab[y[0]]


Out[ ]:
'n03888257'

In [ ]:
y


Out[ ]:
tensor([1, 6, 2, 2, 6, 0, 0, 2, 3, 1, 0, 9, 8, 6, 2, 7, 4, 9, 4, 0, 1, 6, 8, 4,
        7, 7, 0, 1, 0, 0, 3, 6, 1, 1, 8, 0, 3, 8, 1, 8, 8, 2, 1, 0, 3, 5, 3, 5,
        5, 5, 5, 9, 7, 1, 4, 6, 6, 6, 1, 9, 2, 9, 4, 0])

We change our DataBunch a little to add a few attributes: c_in (for channels in) and c_out (for channels out) instead of just c. This will help when we need to build our model.


In [ ]:
#export
class DataBunch():
    def __init__(self, train_dl, valid_dl, c_in=None, c_out=None):
        self.train_dl,self.valid_dl,self.c_in,self.c_out = train_dl,valid_dl,c_in,c_out

    @property
    def train_ds(self): return self.train_dl.dataset

    @property
    def valid_ds(self): return self.valid_dl.dataset

Then we define a function that goes directly from the SplitData to a DataBunch.


In [ ]:
#export
def databunchify(sd, bs, c_in=None, c_out=None, **kwargs):
    dls = get_dls(sd.train, sd.valid, bs, **kwargs)
    return DataBunch(*dls, c_in=c_in, c_out=c_out)

SplitData.to_databunch = databunchify

This gives us the full summary of how to grab our data and put it in a DataBunch:


In [ ]:
path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

il = ImageList.from_files(path, tfms=tfms)
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=4)

Model


In [ ]:
cbfs = [partial(AvgStatsCallback,accuracy),
        CudaCallback]

We will normalize with the statistics from a batch.


In [ ]:
m,s = x.mean((0,2,3)).cuda(),x.std((0,2,3)).cuda()
m,s


Out[ ]:
(tensor([0.4768, 0.4785, 0.4628], device='cuda:0'),
 tensor([0.2538, 0.2485, 0.2760], device='cuda:0'))

In [ ]:
#export
def normalize_chan(x, mean, std):
    return (x-mean[...,None,None]) / std[...,None,None]

_m = tensor([0.47, 0.48, 0.45])
_s = tensor([0.29, 0.28, 0.30])
norm_imagenette = partial(normalize_chan, mean=_m.cuda(), std=_s.cuda())
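
As a quick sanity check (assuming the batch x from above is still around, and a GPU as in the rest of this notebook), the normalized batch should have per-channel mean near 0 and standard deviation near 1:


In [ ]:
xn = normalize_chan(x.cuda(), m, s)
xn.mean((0,2,3)), xn.std((0,2,3))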

In [ ]:
cbfs.append(partial(BatchTransformXCallback, norm_imagenette))

In [ ]:
nfs = [64,64,128,256]

We build our model using the tricks from Bag of Tricks for Image Classification with Convolutional Neural Networks; in particular, we don't start with a big 7x7 conv but with three 3x3 convs, and we don't go directly from 3 channels to 64 but increase the number of channels progressively.


In [ ]:
#export
import math
def prev_pow_2(x): return 2**math.floor(math.log2(x))

def get_cnn_layers(data, nfs, layer, **kwargs):
    def f(ni, nf, stride=2): return layer(ni, nf, 3, stride=stride, **kwargs)
    l1 = data.c_in
    l2 = prev_pow_2(l1*3*3)
    layers =  [f(l1  , l2  , stride=1),
               f(l2  , l2*2, stride=2),
               f(l2*2, l2*4, stride=2)]
    nfs = [l2*4] + nfs
    layers += [f(nfs[i], nfs[i+1]) for i in range(len(nfs)-1)]
    layers += [nn.AdaptiveAvgPool2d(1), Lambda(flatten), 
               nn.Linear(nfs[-1], data.c_out)]
    return layers

def get_cnn_model(data, nfs, layer, **kwargs):
    return nn.Sequential(*get_cnn_layers(data, nfs, layer, **kwargs))

def get_learn_run(nfs, data, lr, layer, cbs=None, opt_func=None, **kwargs):
    model = get_cnn_model(data, nfs, layer, **kwargs)
    init_cnn(model)
    return get_runner(model, data, lr=lr, cbs=cbs, opt_func=opt_func)
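
For instance, with c_in=3 the first conv layer sees 3*3*3 = 27 numbers per output pixel, and prev_pow_2 rounds that down to a power of two; that's why the stem in the summary below goes 3→16→32→64:


In [ ]:
assert prev_pow_2(3*3*3) == 16  # so the stem is 3 -> 16 -> 32 -> 64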

In [ ]:
sched = combine_scheds([0.3,0.7], cos_1cycle_anneal(0.1,0.3,0.05))

In [ ]:
learn,run = get_learn_run(nfs, data, 0.2, conv_layer, cbs=cbfs+[
    partial(ParamScheduler, 'lr', sched)
])

Let's have a look at our model using Hooks. We print the layers and the shapes of their outputs.


In [ ]:
#export
def model_summary(run, learn, data, find_all=False):
    xb,yb = get_batch(data.valid_dl, run)
    device = next(learn.model.parameters()).device # Model may not be on the GPU yet
    xb,yb = xb.to(device),yb.to(device)
    mods = find_modules(learn.model, is_lin_layer) if find_all else learn.model.children()
    f = lambda hook,mod,inp,out: print(f"{mod}\n{out.shape}\n")
    with Hooks(mods, f) as hooks: learn.model(xb)

In [ ]:
model_summary(run, learn, data)


Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 16, 128, 128])

Sequential(
  (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 32, 64, 64])

Sequential(
  (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 64, 32, 32])

Sequential(
  (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 64, 16, 16])

Sequential(
  (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 64, 8, 8])

Sequential(
  (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 128, 4, 4])

Sequential(
  (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): GeneralRelu()
  (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 256, 2, 2])

AdaptiveAvgPool2d(output_size=1)
torch.Size([128, 256, 1, 1])

Lambda()
torch.Size([128, 256])

Linear(in_features=256, out_features=10, bias=True)
torch.Size([128, 10])

And we can train the model:


In [ ]:
%time run.fit(5, learn)


train: [1.7975745138242594, tensor(0.3771, device='cuda:0')]
valid: [1.950084228515625, tensor(0.3640, device='cuda:0')]
train: [1.331341733558244, tensor(0.5549, device='cuda:0')]
valid: [1.182614013671875, tensor(0.6160, device='cuda:0')]
train: [1.0004353405653792, tensor(0.6729, device='cuda:0')]
valid: [0.9452028198242187, tensor(0.6740, device='cuda:0')]
train: [0.744675257750698, tensor(0.7583, device='cuda:0')]
valid: [0.8292762451171874, tensor(0.7360, device='cuda:0')]
train: [0.5341721137253761, tensor(0.8359, device='cuda:0')]
valid: [0.798895751953125, tensor(0.7360, device='cuda:0')]
CPU times: user 25.6 s, sys: 10.7 s, total: 36.4 s
Wall time: 1min 7s

The leaderboard at the time this notebook was written shows ~85% accuracy after 5 epochs at 128px, so we're definitely on the right track!

Export


In [ ]:
!python notebook2script.py 08_data_block.ipynb


Converted 08_data_block.ipynb to exp/nb_08.py
