In [ ]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
In [ ]:
#export
from exp.nb_04 import *
In [ ]:
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
nh,bs = 50,512
c = y_train.max().item()+1
loss_func = F.cross_entropy
In [ ]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)
In [ ]:
#export
def create_learner(model_func, loss_func, data):
    return Learner(*model_func(data), loss_func, data)
In [ ]:
learn = create_learner(get_model, loss_func, data)
run = Runner([AvgStatsCallback([accuracy])])
run.fit(3, learn)
In [ ]:
learn = create_learner(partial(get_model, lr=0.3), loss_func, data)
run = Runner([AvgStatsCallback([accuracy])])
run.fit(3, learn)
In [ ]:
#export
def get_model_func(lr=0.5): return partial(get_model, lr=lr)
We define two new callbacks: the Recorder to keep track of the loss and our scheduled learning rate, and a ParamScheduler that can schedule any hyperparameter as long as it's registered in the state_dict of the optimizer.
In [ ]:
#export
class Recorder(Callback):
    def begin_fit(self): self.lrs,self.losses = [],[]

    def after_batch(self):
        if not self.in_train: return
        self.lrs.append(self.opt.param_groups[-1]['lr'])
        self.losses.append(self.loss.detach().cpu())

    def plot_lr  (self): plt.plot(self.lrs)
    def plot_loss(self): plt.plot(self.losses)
class ParamScheduler(Callback):
    _order=1
    def __init__(self, pname, sched_func): self.pname,self.sched_func = pname,sched_func

    def set_param(self):
        for pg in self.opt.param_groups:
            # n_epochs/epochs is the fraction of training completed so far (goes from 0 to 1)
            pg[self.pname] = self.sched_func(self.n_epochs/self.epochs)

    def begin_batch(self):
        if self.in_train: self.set_param()
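ParamScheduler works because the optimizer exposes its hyperparameters as plain dictionary entries in its param_groups. Here is a minimal sketch of that mechanism using a vanilla PyTorch SGD optimizer (just for illustration, not the Learner from this notebook):
In [ ]:
import torch
opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.5)
opt.param_groups[0]['lr']        # -> 0.5, the entry ParamScheduler reads and overwrites
opt.param_groups[0]['lr'] = 0.1  # this assignment is exactly what set_param does for each group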
Let's start with a simple linear schedule going from start to end. It returns a function that takes a pos argument (going from 0 to 1) such that this function goes from start (at pos=0) to end (at pos=1) in a linear fashion.
In [ ]:
def sched_lin(start, end):
    def _inner(start, end, pos): return start + pos*(end-start)
    return partial(_inner, start, end)
We can refactor this with a decorator.
In [ ]:
#export
def annealer(f):
    def _inner(start, end): return partial(f, start, end)
    return _inner

@annealer
def sched_lin(start, end, pos): return start + pos*(end-start)
In [ ]:
# shift-tab works too, in Jupyter!
# sched_lin()
In [ ]:
f = sched_lin(1,2)
f(0.3)
Out[ ]:
1.3
And here are other scheduler functions:
In [ ]:
#export
@annealer
def sched_cos(start, end, pos): return start + (1 + math.cos(math.pi*(1-pos))) * (end-start) / 2
@annealer
def sched_no(start, end, pos): return start
@annealer
def sched_exp(start, end, pos): return start * (end/start) ** pos
def cos_1cycle_anneal(start, high, end):
    return [sched_cos(start, high), sched_cos(high, end)]

# This monkey-patch is there to be able to plot tensors
torch.Tensor.ndim = property(lambda x: len(x.shape))
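As a quick sanity check (these particular values are picked just for illustration, not from the original notebook), a cosine annealer should sit at start when pos=0, cross the midpoint at pos=0.5, and reach end at pos=1:
In [ ]:
f = sched_cos(1., 0.)
f(0.), f(0.5), f(1.)  # approximately (1.0, 0.5, 0.0)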
In [ ]:
annealings = "NO LINEAR COS EXP".split()
a = torch.arange(0, 100)
p = torch.linspace(0.01,1,100)
fns = [sched_no, sched_lin, sched_cos, sched_exp]
for fn, t in zip(fns, annealings):
    f = fn(2, 1e-2)
    plt.plot(a, [f(o) for o in p], label=t)
plt.legend();
In practice, we'll often want to combine different schedulers; the following function does that: it uses scheds[i] for pcts[i] of the training.
In [ ]:
#export
def combine_scheds(pcts, scheds):
    assert sum(pcts) == 1.
    pcts = tensor([0] + listify(pcts))
    assert torch.all(pcts >= 0)
    pcts = torch.cumsum(pcts, 0)
    def _inner(pos):
        idx = (pos >= pcts).nonzero().max()
        if idx == len(pcts)-1: idx -= 1  # pos == 1. would otherwise index past the last scheduler
        actual_pos = (pos-pcts[idx]) / (pcts[idx+1]-pcts[idx])
        return scheds[idx](actual_pos)
    return _inner
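To trace the indexing logic: with pcts=[0.3, 0.7] the cumulative boundaries are [0, 0.3, 1.0], so a position of pos=0.65 gives idx=1 and the second scheduler is queried at (0.65-0.3)/(1.0-0.3) = 0.5, i.e. halfway through its own schedule.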
Here is an example: use 30% of the budget to go from 0.3 to 0.6 following a cosine, then the last 70% of the budget to go from 0.6 to 0.2, still following a cosine.
In [ ]:
sched = combine_scheds([0.3, 0.7], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])
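As an aside, the cos_1cycle_anneal helper defined earlier builds exactly this kind of pair of cosine schedules, so (a sketch assuming the same 30%/70% split and the same values) the line above could also be written as:
In [ ]:
sched = combine_scheds([0.3, 0.7], cos_1cycle_anneal(0.3, 0.6, 0.2))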
In [ ]:
plt.plot(a, [sched(o) for o in p])
We can use it for training quite easily...
In [ ]:
cbfs = [Recorder,
        partial(AvgStatsCallback,accuracy),
        partial(ParamScheduler, 'lr', sched)]
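Note that cbfs is a list of callables that build the callbacks (cb_funcs), rather than already-built instances; the Runner instantiates them itself and keeps each one around by name, which is how run.recorder can be used below.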
In [ ]:
learn = create_learner(get_model_func(0.3), loss_func, data)
run = Runner(cb_funcs=cbfs)
In [ ]:
run.fit(3, learn)
... then check with our recorder if the learning rate followed the right schedule.
In [ ]:
run.recorder.plot_lr()
In [ ]:
run.recorder.plot_loss()
In [ ]:
!./notebook2script.py 05_anneal.ipynb
In [ ]: