In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
In [ ]:
#export
from exp.nb_11 import *
In [ ]:
path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)
In [ ]:
size = 128
bs = 64
tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]
val_tfms = [make_rgb, CenterCrop(size), np_to_float]
il = ImageList.from_files(path, tfms=tfms)
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
ll.valid.x.tfms = val_tfms
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)
In [ ]:
len(il)
In [ ]:
loss_func = LabelSmoothingCrossEntropy()
opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)
In [ ]:
def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    phases = create_phases(pct_start)
    sched_lr  = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]
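As a quick sanity check (assuming, as in the earlier notebooks, that combine_scheds returns a function of overall training progress in [0,1)), we can evaluate the learning-rate schedule at a few points and confirm it warms up from lr/10 to lr before annealing towards lr/1e5. This cell is illustrative and not part of the original notebook:
In [ ]:
# hypothetical check: the combined schedule should start at lr/10,
# peak at lr when the warmup phase ends, and decay towards lr/1e5
lr = 3e-3
sched = combine_scheds(create_phases(0.3), cos_1cycle_anneal(lr/10., lr, lr/1e5))
[sched(t) for t in (0., 0.3, 0.99)]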
In [ ]:
lr = 3e-3
pct_start = 0.5
cbsched = sched_1cycle(lr, pct_start)
In [ ]:
learn.fit(40, cbsched)
In [ ]:
st = learn.model.state_dict()
In [ ]:
type(st)
In [ ]:
', '.join(st.keys())
In [ ]:
st['10.bias']
In [ ]:
mdl_path = path/'models'
mdl_path.mkdir(exist_ok=True)
It's also possible to save the whole model, including the architecture, but it gets quite fiddly and we don't recommend it. Instead, just save the parameters, and recreate the model directly.
In [ ]:
torch.save(st, mdl_path/'iw5')
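For contrast, here is a minimal sketch of the whole-model approach warned against above (the 'iw5_full' filename is just for illustration): torch.save would pickle the entire model object, so torch.load later needs the exact same class definitions importable, which is what makes it fiddly.
In [ ]:
# not recommended: pickles architecture + weights together; loading requires
# xresnet18 (and any custom layers) to be importable under the same names
torch.save(learn.model, mdl_path/'iw5_full')
loaded = torch.load(mdl_path/'iw5_full')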
In [ ]:
pets = datasets.untar_data(datasets.URLs.PETS)
In [ ]:
pets.ls()
In [ ]:
pets_path = pets/'images'
In [ ]:
il = ImageList.from_files(pets_path, tfms=tfms)
In [ ]:
il
In [ ]:
#export
def random_splitter(fn, p_valid): return random.random() < p_valid
In [ ]:
random.seed(42)
In [ ]:
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))
In [ ]:
sd
In [ ]:
n = il.items[0].name; n
In [ ]:
re.findall(r'^(.*)_\d+\.jpg$', n)[0]
In [ ]:
def pet_labeler(fn): return re.findall(r'^(.*)_\d+\.jpg$', fn.name)[0]
In [ ]:
proc = CategoryProcessor()
In [ ]:
ll = label_by_func(sd, pet_labeler, proc_y=proc)
In [ ]:
', '.join(proc.vocab)
In [ ]:
ll.valid.x.tfms = val_tfms
In [ ]:
c_out = len(proc.vocab)
In [ ]:
data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)
In [ ]:
learn.fit(5, cbsched)
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
In [ ]:
st = torch.load(mdl_path/'iw5')
In [ ]:
m = learn.model
In [ ]:
m.load_state_dict(st)
In [ ]:
# cut the model just before its AdaptiveAvgPool2d layer, keeping only the body
cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))
m_cut = m[:cut]
In [ ]:
xb,yb = get_batch(data.valid_dl, learn)
In [ ]:
pred = m_cut(xb)
In [ ]:
pred.shape
In [ ]:
ni = pred.shape[1]
In [ ]:
#export
class AdaptiveConcatPool2d(nn.Module):
    def __init__(self, sz=1):
        super().__init__()
        self.output_size = sz
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
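Since AdaptiveConcatPool2d concatenates max-pooled and average-pooled features along the channel dimension, its output has twice as many channels as its input, which is why the new head below takes ni*2 inputs. A quick shape check on the pred we computed above:
In [ ]:
# concat pooling doubles the channels: (bs, ni, h, w) -> (bs, ni*2, 1, 1)
AdaptiveConcatPool2d()(pred).shape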
In [ ]:
nh = 40
m_new = nn.Sequential(
    m_cut, AdaptiveConcatPool2d(), Flatten(),
    nn.Linear(ni*2, data.c_out))
In [ ]:
learn.model = m_new
In [ ]:
learn.fit(5, cbsched)
In [ ]:
def adapt_model(learn, data):
    cut = next(i for i,o in enumerate(learn.model.children())
               if isinstance(o,nn.AdaptiveAvgPool2d))
    m_cut = learn.model[:cut]
    xb,yb = get_batch(data.valid_dl, learn)
    pred = m_cut(xb)
    ni = pred.shape[1]
    m_new = nn.Sequential(
        m_cut, AdaptiveConcatPool2d(), Flatten(),
        nn.Linear(ni*2, data.c_out))
    learn.model = m_new
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
In [ ]:
adapt_model(learn, data)
In [ ]:
# freeze the pretrained body (everything except the new head)
for p in learn.model[0].parameters(): p.requires_grad_(False)
In [ ]:
learn.fit(3, sched_1cycle(1e-2, 0.5))
In [ ]:
# unfreeze everything for fine-tuning
for p in learn.model[0].parameters(): p.requires_grad_(True)
In [ ]:
learn.fit(5, cbsched, reset_opt=True)
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)
In [ ]:
def apply_mod(m, f):
    f(m)
    for l in m.children(): apply_mod(l, f)

def set_grad(m, b):
    # leave Linear and BatchNorm layers trainable even when freezing
    if isinstance(m, (nn.Linear,nn.BatchNorm2d)): return
    if hasattr(m, 'weight'):
        for p in m.parameters(): p.requires_grad_(b)
In [ ]:
apply_mod(learn.model, partial(set_grad, b=False))
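A quick check (not in the original notebook) that the freeze behaved as intended: the only parameters left trainable should be the ones set_grad skips, i.e. those belonging to Linear and BatchNorm layers.
In [ ]:
# count trainable vs. total parameters after freezing
sum(p.requires_grad for p in learn.model.parameters()), len(list(learn.model.parameters()))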
In [ ]:
learn.fit(3, sched_1cycle(1e-2, 0.5))
In [ ]:
apply_mod(learn.model, partial(set_grad, b=True))
In [ ]:
learn.fit(5, cbsched, reset_opt=True)
PyTorch already has an `apply` method we can use:
In [ ]:
learn.model.apply(partial(set_grad, b=False));
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
In [ ]:
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)
In [ ]:
def bn_splitter(m):
    # all BatchNorm params plus the head go in g2; the body's other weights in g1
    def _bn_splitter(l, g1, g2):
        if isinstance(l, nn.BatchNorm2d): g2 += l.parameters()
        elif hasattr(l, 'weight'): g1 += l.parameters()
        for ll in l.children(): _bn_splitter(ll, g1, g2)

    g1,g2 = [],[]
    _bn_splitter(m[0], g1, g2)
    g2 += m[1:].parameters()
    return g1,g2
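Before relying on the splitter, a short sanity check (again, not in the original notebook) that every BatchNorm parameter in the body really lands in the second group, since _bn_splitter routes them there explicitly:
In [ ]:
g1,g2 = bn_splitter(learn.model)
bn_ids = {id(p) for l in learn.model[0].modules()
          if isinstance(l, nn.BatchNorm2d) for p in l.parameters()}
assert bn_ids <= {id(p) for p in g2}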
In [ ]:
a,b = bn_splitter(learn.model)
In [ ]:
test_eq(len(a)+len(b), len(list(learn.model.parameters())))
In [ ]:
Learner.ALL_CBS
In [ ]:
#export
from types import SimpleNamespace
cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})
In [ ]:
cb_types.after_backward
In [ ]:
#export
class DebugCallback(Callback):
    _order = 999
    def __init__(self, cb_name, f=None): self.cb_name,self.f = cb_name,f
    def __call__(self, cb_name):
        if cb_name==self.cb_name:
            if self.f: self.f(self.run)
            else:      set_trace()
In [ ]:
#export
def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    phases = create_phases(pct_start)
    sched_lr  = [combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
                 for lr in lrs]
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]
In [ ]:
disc_lr_sched = sched_1cycle([0,3e-2], 0.5)
In [ ]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func,
                    c_out=10, norm=norm_imagenette, splitter=bn_splitter)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)
In [ ]:
def _print_det(o):
    print(len(o.opt.param_groups), o.opt.hypers)
    raise CancelTrainException()

learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])
In [ ]:
learn.fit(3, disc_lr_sched)
In [ ]:
disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)
In [ ]:
learn.fit(5, disc_lr_sched)
In [ ]:
!./notebook2script.py 11a_transfer_learning.ipynb