Update domain in Research

Sometimes the domain of parameters has to change while a Research is running. The update_domain method makes that possible.

We start with some useful imports and constant definitions.


In [1]:
import sys
import os
import shutil

import numpy as np
import matplotlib
%matplotlib inline

os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [2]:
sys.path.append('../../..')

from batchflow import Pipeline, B, C, V, D, L
from batchflow.opensets import CIFAR10
from batchflow.models.torch import VGG7, VGG16, ResNet18
from batchflow.research import Research, Option, Results, PrintLogger, RP, RR

In [3]:
BATCH_SIZE = 64

ds = CIFAR10()

In [4]:
def clear_previous_results(res_name):
    if os.path.exists(res_name):
        shutil.rmtree(res_name)

Let us solve the following problem: we train three models (VGG7, VGG16 and ResNet18) for one epoch each, choose the one with the highest test accuracy, and then train it for 2 epochs. To do this, we define pipelines in which 'model' and 'n_epochs' are configurable.


In [5]:
model_config = {
    'inputs/images/shape': B('image_shape'),
    'inputs/labels/classes': D('num_classes'),
    'inputs/labels/name': 'targets',
    'initial_block/inputs': 'images'
}

In [6]:
train_pipeline = (Pipeline()
            .set_dataset(C('dataset'))
            .init_variable('loss')
            .init_model('dynamic', C('model'), 'conv', config=model_config)
            .to_array(dtype='float32')
            .train_model('conv', B('images'), B('labels'),
                         fetches='loss', save_to=V('loss', mode='w'))
            .run_later(batch_size=BATCH_SIZE, n_epochs=C('n_epochs'))
)

test_pipeline = (Pipeline()
                 .init_variable('predictions')
                 .init_variable('metrics')
                 .import_model('conv', C('import_from'))
                 .to_array(dtype='float32')
                 .predict_model('conv', B('images'),
                                fetches='predictions', save_to=V('predictions'))
                 .gather_metrics('class', targets=B('labels'), predictions=V('predictions'), 
                                fmt='logits', axis=-1, save_to=V('metrics', mode='a'))
                 .run_later(batch_size=BATCH_SIZE, n_epochs=1)) << CIFAR10().test

First, define the initial domain.


In [7]:
domain = Option('model', [VGG7, VGG16, ResNet18]) * Option('n_epochs', [1])

To update the domain, we define a function that returns a new domain, or None if the domain should not be updated. In our case, the function update_domain accepts the research results as a pandas.DataFrame, takes the model with the highest accuracy and creates a new domain with that model and n_epochs=2.


In [8]:
def update_domain(results):
    # idxmax() returns an index label, not a position, so select the row by position instead
    best_model = results.iloc[results['accuracy'].values.argmax()].model
    domain = Option('model', [best_model]) * Option('n_epochs', [2])
    return domain
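
To see the selection step in isolation, the sketch below applies the same idea to a hand-made pandas.DataFrame. The column names mirror the ones used in this notebook, but the rows and accuracy values are invented for illustration, models are represented by plain strings rather than classes, and the helper name select_best_model is hypothetical. The row is picked by position, so the result does not depend on how the frame is indexed.

```python
import pandas as pd

# Hypothetical results: the values are made up and models are plain strings here.
# The index labels are deliberately duplicated to show that positions are used.
toy_results = pd.DataFrame({
    'model': ['VGG7', 'VGG16', 'ResNet18'],
    'n_epochs': [1, 1, 1],
    'accuracy': [0.41, 0.52, 0.47],
}, index=[0, 0, 0])

def select_best_model(results):
    """Return the 'model' value of the row with the highest accuracy."""
    best_position = results['accuracy'].to_numpy().argmax()
    return results.iloc[best_position]['model']

best = select_best_model(toy_results)
print(best)  # VGG16
```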

We pass the update function to the research through the update_domain method. The each parameter defines how often the function is applied: with each='last', it is applied once the current domain is exhausted. The n_updates parameter sets the number of domain updates. All other parameters are passed as keyword arguments to the update function.

One can also define a callable that updates each config. For example, we can use a dataset of 1000 items for the first stage of the experiment and the whole train set for the second.


In [9]:
def create_dataset(config):
    update = config['update']
    if update == 0:
        dataset = ds.train.create_subset(np.arange(1000))
    else:
        dataset = ds.train
    return {'dataset': dataset}

In [10]:
research = (Research()
            .add_pipeline(train_pipeline, variables='loss', name='train_ppl')
            .add_pipeline(test_pipeline, run=False, name='test_ppl',
                          import_from=RP('train_ppl'), execute='last')
            .get_metrics(pipeline='test_ppl', metrics_var='metrics', metrics_name='accuracy',
                         returns='accuracy', execute='last')
            .init_domain(domain)
            .update_config(create_dataset)
            .update_domain(update_domain, each='last', n_updates=1,
                           results=RR(names='test_ppl_metrics', use_alias=False).df))

res_name = 'dynamic_domain_research'
clear_previous_results(res_name)

research.run(n_iters=None, name=res_name, bar=True)


Research dynamic_domain_research is starting...
Domain updated: 1: : 4it [04:06, 61.67s/it]                       
Out[10]:
<batchflow.research.research.Research at 0x7f595776fcc0>

The resulting pandas.DataFrame has an 'update' column holding the number of domain updates that preceded the current config.


In [11]:
acc = research.load_results(names='test_ppl_metrics', update=1).df
print('Best model:    ', acc.model.values[0])
print('Final accuracy:', acc.accuracy.values[0])


Best model:     VGG16
Final accuracy: 0.546875

Updates stop after n_updates updates or when the update function returns None.
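
As an illustration of the second stopping condition, the sketch below shows a decision that an update function could make before building a new domain. The helper name next_model_or_none and the good_enough threshold are hypothetical and not part of batchflow, and the results are again hand-made pandas.DataFrame objects with invented values.

```python
import pandas as pd

def next_model_or_none(results, good_enough=0.9):
    """Return the best model to continue with, or None to signal 'stop updating'.

    Hypothetical helper: the threshold and the stopping rule are assumptions
    made for this example only.
    """
    best_position = results['accuracy'].to_numpy().argmax()
    best_row = results.iloc[best_position]
    if best_row['accuracy'] >= good_enough:
        return None  # already good enough, no further stage is needed
    return best_row['model']

# Made-up results with models given as plain strings
weak = pd.DataFrame({'model': ['VGG7', 'VGG16'], 'accuracy': [0.41, 0.52]})
strong = pd.DataFrame({'model': ['VGG7', 'VGG16'], 'accuracy': [0.41, 0.95]})

print(next_model_or_none(weak))    # VGG16
print(next_model_or_none(strong))  # None
```

Inside an update function, a None result would be returned as is, and any other result would be wrapped into a new domain with Option, in the same way as in update_domain in this notebook.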