In [1]:
import numpy as np
from matplotlib import colors as mcolors
import sys
sys.path.append('../../..')
from batchflow.opensets import Imagenette160
from batchflow import Pipeline, B, V, C
from batchflow.models.torch import ResNet34
from batchflow.models.metrics import ClassificationMetrics
from batchflow.research import Research, Option, Results, KV, RP, RI
from batchflow.utils import show_research, print_results
In [2]:
# Global constants
NUM_ITERS = 50000 # number of iterations to train each model for
N_REPS = 4 # number of times to repeat each model train
RESEARCH_NAME = 'research' # name of Research object
DEVICES = [3, 4, 5, 6, 7] # devices to use
WORKERS = len(DEVICES) # number of simultaneously trained models
TEST_FREQUENCY = 200 # evaluate test accuracy every 200 iterations
dataset = Imagenette160() # dataset to train models on
In [3]:
domain = Option('body', [
    # baseline: plain ResNet34 without self-attention
    KV({}, 'ResBlock'),
    # apply the chosen self-attention inside every encoder block
    KV({'encoder/blocks': {'attention_mode': 'se'}},
       'block_SE_default'),
    KV({'encoder/blocks': {'attention_mode': 'scse'}},
       'block_SCSE_default'),
    KV({'encoder/blocks': {'attention_mode': 'bam'}},
       'block_BAM_default'),
    KV({'encoder/blocks': {'attention_mode': 'cbam'}},
       'block_CBAM_default'),
    KV({'encoder/blocks': {'attention_mode': 'se', 'self_attention/ratio': 8}},
       'block_SE_8'),
    # apply the chosen self-attention once per stage, in the downsampling layer
    KV({'encoder/order': ['skip', 'block', 'd'],
        'encoder/downsample': {'layout': 'S', 'attention_mode': 'se'}},
       'stage_SE_default'),
    KV({'encoder/order': ['skip', 'block', 'd'],
        'encoder/downsample': {'layout': 'S', 'attention_mode': 'scse'}},
       'stage_SCSE_default'),
    KV({'encoder/order': ['skip', 'block', 'd'],
        'encoder/downsample': {'layout': 'S', 'attention_mode': 'bam'}},
       'stage_BAM_default'),
    KV({'encoder/order': ['skip', 'block', 'd'],
        'encoder/downsample': {'layout': 'S', 'attention_mode': 'cbam'}},
       'stage_CBAM_default'),
])
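In [ ]:
# Illustrative sketch, not part of the research run: each KV pair above couples
# a config update (merged into the model config) with a human-readable alias
# that labels the experiment in the results. E.g., the 'block_SE_8' variant
# enables SE attention with a reduction ratio of 8 in every encoder block:
update, alias = {'encoder/blocks': {'attention_mode': 'se', 'self_attention/ratio': 8}}, 'block_SE_8'
print(alias, '->', update)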
In [4]:
config = {
    'inputs/labels/classes': 10,
    'body': C('body'),        # filled per experiment from the 'body' domain option
    'head/layout': 'Vf',      # global average pooling followed by a dense layer
    'head/units': 10,
    'decay': dict(name='exp', gamma=0.1),
    'n_iters': 7500,
    'device': C('device'),    # filled by Research from DEVICES
}
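In [ ]:
# Hedged illustration: C('body') and C('device') are named placeholders that
# Research resolves anew for every experiment. For the baseline 'ResBlock'
# variant the substitution is equivalent to:
baseline_config = {**config, 'body': {}, 'device': DEVICES[0]}
print(baseline_config['body'], baseline_config['device'])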
In [5]:
# root part: data loading and preprocessing, shared across branches
train_root = (dataset.train.p
              .crop(shape=(160, 160), origin='center')
              .to_array(channels='first', dtype=np.float32)
              .multiply(multiplier=1/255)
              .run_later(64, n_epochs=None, drop_last=True,
                         shuffle=True, prefetch=5)
)

# branch part: model initialization and training, one instance per experiment
train_pipeline = (Pipeline()
                  .init_variable('loss')
                  .init_model('dynamic', ResNet34, 'my_model', config=config)
                  .train_model('my_model', B('images'), B('labels'),
                               fetches='loss', save_to=V('loss'))
)
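In [ ]:
# Sketch of how the two parts relate, assuming batchflow's usual `+` pipeline
# composition: outside of Research, the shared root and the per-model branch
# could be joined and run as a single pipeline. Commented out to avoid an
# accidental long training run.
# standalone = train_root + train_pipeline
# standalone.run(64, n_iters=100, bar=True)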
In [6]:
def acc(iteration, import_from):
    """Measure test-set accuracy of the model trained in the `import_from` pipeline.

    `iteration` is supplied by Research via RI() and is not used in the computation.
    """
    pipeline = (dataset.test.p
                .import_model('my_model', import_from)
                .init_variable('true', [])
                .update(V('true', mode='a'), B.labels)
                .init_variable('predictions', [])
                .crop(shape=(160, 160), origin='center')
                .to_array(channels='first', dtype=np.float32)
                .multiply(multiplier=1/255)
                .predict_model('my_model', B('images'), fetches='predictions',
                               save_to=V('predictions', mode='a'))
    )
    pipeline.run(128, n_epochs=1, drop_last=False, shuffle=True)
    pred = np.concatenate(pipeline.v('predictions'))
    true = np.concatenate(pipeline.v('true'))
    accuracy = ClassificationMetrics(true, pred, fmt='logits',
                                     num_classes=10, axis=1).accuracy()
    return accuracy
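In [ ]:
# Usage sketch (hypothetical, outside of Research): `import_from` may be any
# pipeline holding a trained 'my_model'. In the research below it is supplied
# automatically via RP('train_ppl'), and the iteration via RI():
# test_accuracy = acc(iteration=0, import_from=some_trained_pipeline)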
In [7]:
research = (Research()
            .init_domain(domain, n_reps=N_REPS)
            .add_pipeline(root=train_root, branch=train_pipeline, variables='loss',
                          name='train_ppl', logging=True)
            # run `acc` every TEST_FREQUENCY iterations; RI() passes the current
            # iteration, RP('train_ppl') passes the trained pipeline
            .add_callable(acc, returns='acc_val', name='acc_fn', execute=TEST_FREQUENCY,
                          iteration=RI(), import_from=RP('train_ppl'))
)
In [ ]:
# remove the results of a previous run, if any
!rm -rf {RESEARCH_NAME}
research.run(NUM_ITERS, name=RESEARCH_NAME,
             devices=DEVICES, workers=WORKERS,
             bar=True)
In [ ]:
%%time
results = Results(path=RESEARCH_NAME, concat_config=True)
# results = research.load_results(concat_config=True)
In [14]:
show_research(results.df, layout=['train_ppl/loss', 'acc_fn/acc_val'],
              average_repetitions=True, color=list(mcolors.TABLEAU_COLORS.keys()),
              log_scale=False, rolling_window=10)
In [12]:
print_results(results.df, 'acc_fn/acc_val', False, ascending=True, n_last=100)