Sometimes one needs to change the domain of parameters during Research execution. The update_domain method helps to do that.
We start with some useful imports and constant definitions.
In [1]:
import sys
import os
import shutil
import numpy as np
import matplotlib
%matplotlib inline
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
In [2]:
sys.path.append('../../..')
from batchflow import Pipeline, B, C, V, D, L
from batchflow.opensets import CIFAR10
from batchflow.models.torch import VGG7, VGG16, ResNet18
from batchflow.research import Research, Option, Results, PrintLogger, RP, RR
In [3]:
BATCH_SIZE=64
ds = CIFAR10()
In [4]:
def clear_previous_results(res_name):
    # Remove the results directory left over from an earlier run, if any.
    if os.path.exists(res_name):
        shutil.rmtree(res_name)
Let us solve the following problem: we train three models (VGG7, VGG16 and ResNet18) for one epoch, choose the model with the highest test accuracy, and then train that model for 2 epochs. We define pipelines in which 'model' and 'n_epochs' are config parameters.
In [5]:
model_config = {
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': D('num_classes'),
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images'
}
In [6]:
train_pipeline = (Pipeline()
    .set_dataset(C('dataset'))
    .init_variable('loss')
    .init_model('dynamic', C('model'), 'conv', config=model_config)
    .to_array(dtype='float32')
    .train_model('conv', B('images'), B('labels'),
                 fetches='loss', save_to=V('loss', mode='w'))
    .run_later(batch_size=BATCH_SIZE, n_epochs=C('n_epochs'))
)
test_pipeline = (Pipeline()
    .init_variable('predictions')
    .init_variable('metrics')
    .import_model('conv', C('import_from'))
    .to_array(dtype='float32')
    .predict_model('conv', B('images'),
                   fetches='predictions', save_to=V('predictions'))
    .gather_metrics('class', targets=B('labels'), predictions=V('predictions'),
                    fmt='logits', axis=-1, save_to=V('metrics', mode='a'))
    .run_later(batch_size=BATCH_SIZE, n_epochs=1)) << CIFAR10().test
First, define the initial domain.
In [7]:
domain = Option('model', [VGG7, VGG16, ResNet18]) * Option('n_epochs', [1])
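Multiplying two Option objects combines their values into a grid of configs. As a rough plain-Python analogy (not batchflow's implementation; the model classes are replaced by strings only to keep the sketch self-contained), the initial domain above enumerates the same combinations as itertools.product:

```python
from itertools import product

# Stand-ins for the model classes; strings keep the sketch dependency-free.
models = ['VGG7', 'VGG16', 'ResNet18']
n_epochs = [1]

# Every (model, n_epochs) pair becomes one config of the domain.
configs = [{'model': m, 'n_epochs': e} for m, e in product(models, n_epochs)]

for config in configs:
    print(config)
```

With three models and a single n_epochs value, the first stage consists of three configs.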
To update the domain, we define a function that returns a new domain, or None if the domain should not be updated. In our case, the function update_domain accepts research results as a pandas.DataFrame, takes the model with the highest accuracy and creates a new domain with that model and n_epochs=2.
In [8]:
def update_domain(results):
    # Pick the row with the highest accuracy and keep only its model.
    best_model = results.iloc[results['accuracy'].idxmax()].model
    domain = Option('model', [best_model]) * Option('n_epochs', [2])
    return domain
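To see what the selection step does in isolation, here is a small sketch on a toy DataFrame. The accuracy values are made up for illustration and are not results of this research. Note that idxmax returns an index label; with the default integer index used here, label and position coincide, so .loc and .iloc pick the same row.

```python
import pandas as pd

# Hypothetical accuracies, made up for illustration only (not real results).
toy_results = pd.DataFrame({
    'model': ['VGG7', 'VGG16', 'ResNet18'],
    'accuracy': [0.41, 0.38, 0.45],
})

# Label of the row with the highest accuracy.
best_label = toy_results['accuracy'].idxmax()

# Look the row up by label and read its model name.
best_model = toy_results.loc[best_label, 'model']
print(best_model)
```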
We pass the update function to the research through the update_domain method. The each parameter defines how often the function is applied. If each='last', the update function is applied when the current domain is exhausted. The n_updates parameter defines the number of domain updates. All other parameters are passed as keyword arguments to the update function (here, results).
One can also define a callable to update each config. For example, we can use a dataset of size 1000 for the first stage of the experiment and the whole training set for the second.
In [9]:
def create_dataset(config):
    # 'update' is the number of domain updates made before this config.
    update = config['update']
    if update == 0:
        dataset = ds.train.create_subset(np.arange(1000))
    else:
        dataset = ds.train
    return {'dataset': dataset}
In [10]:
research = (Research()
    .add_pipeline(train_pipeline, variables='loss', name='train_ppl')
    .add_pipeline(test_pipeline, run=False, name='test_ppl',
                  import_from=RP('train_ppl'), execute='last')
    .get_metrics(pipeline='test_ppl', metrics_var='metrics', metrics_name='accuracy',
                 returns='accuracy', execute='last')
    .init_domain(domain)
    .update_config(create_dataset)
    .update_domain(update_domain, each='last', n_updates=1,
                   results=RR(names='test_ppl_metrics', use_alias=False).df))

res_name = 'dynamic_domain_research'
clear_previous_results(res_name)
research.run(n_iters=None, name=res_name, bar=True)
The resulting pandas.DataFrame will have an 'update' column with the number of domain updates performed before the current config was generated.
In [11]:
acc = research.load_results(names='test_ppl_metrics', update=1).df
print('Best model: ', acc.model.values[0])
print('Final accuracy:', acc.accuracy.values[0])
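As a rough picture of what the 'update' column looks like, the sketch below builds a toy DataFrame by hand and selects the second-stage rows with a plain pandas filter. All values are invented for illustration, and the assumption that update=1 in load_results acts like this row filter is ours, not something confirmed by the library documentation.

```python
import pandas as pd

# Made-up rows for illustration only; real values come from the research run.
toy = pd.DataFrame({
    'model': ['VGG7', 'VGG16', 'ResNet18', 'ResNet18'],
    'n_epochs': [1, 1, 1, 2],
    'update': [0, 0, 0, 1],
    'accuracy': [0.41, 0.38, 0.45, 0.52],
})

# Rows with update == 0 belong to the initial domain,
# rows with update == 1 to the domain created by the first update.
second_stage = toy[toy['update'] == 1]
print(second_stage)
```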
Updates stop after n_updates updates or when the update function returns None.
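The stopping rule can be illustrated with a toy loop in plain Python. This is only a sketch of the described behaviour with each='last', not batchflow's actual code; run_with_updates and keep_first_config are hypothetical names introduced for the example.

```python
def run_with_updates(initial_domain, update_fn, n_updates):
    """Toy loop: exhaust the current domain, then ask update_fn for a new one."""
    history = []              # (update number, config) pairs in execution order
    domain = initial_domain
    update = 0
    while True:
        for config in domain:
            history.append((update, config))
        if update >= n_updates:
            break             # the update budget is spent
        domain = update_fn(history)
        if domain is None:
            break             # the update function declined to continue
        update += 1
    return history


def keep_first_config(history):
    # Hypothetical rule: rerun the very first config with more epochs.
    _, first = history[0]
    return [{**first, 'n_epochs': 2}]


history = run_with_updates(
    [{'model': 'VGG7', 'n_epochs': 1}, {'model': 'VGG16', 'n_epochs': 1}],
    keep_first_config,
    n_updates=1,
)
print(history)
```

With n_updates=1 the loop runs the two initial configs, performs one update, runs the single new config and stops. An update function that returns None stops the loop right after the initial domain.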