Quickstart

This is a simple tutorial to get started with Cogitare's main functionalities.

In this tutorial, we will write a Convolutional Neural Network (CNN) to classify handwritten digits (MNIST).

Model

We start by defining our CNN model.

When developing a model with Cogitare, your model must extend the cogitare.Model class. This class provides the Model interface, which allows you to train and evaluate the model efficiently.

To implement a model, you must extend the cogitare.Model class and implement the forward() and loss() methods. The forward() method receives the batch, so it is where you implement the forward pass through the network and return the network's output. The loss() method receives the output of forward() and the batch obtained from the iterator, applies a loss function, and computes and returns the loss.

The Model interface iterates over the dataset and, for each batch, executes the forward, loss, and backward steps.
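
As a reference, the sketch below is a minimal illustration of that contract (the layer and loss function are placeholders only; the actual CNN used in this tutorial is defined in the next cells):

# minimal sketch of the cogitare.Model contract (illustrative only)
from cogitare import Model
import torch.nn as nn
import torch.nn.functional as F


class SkeletonModel(Model):

    def __init__(self):
        super(SkeletonModel, self).__init__()
        self.linear = nn.Linear(784, 10)

    def forward(self, batch):
        # the batch is whatever the data iterator yields; here we assume an (input, expected) tuple
        input, _ = batch
        return self.linear(input)

    def loss(self, output, batch):
        # compare the forward() output with the expected labels and return the loss
        _, expected = batch
        return F.cross_entropy(output, expected)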


In [1]:
# adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py
from cogitare import Model
from cogitare import utils
from cogitare.data import DataSet, AsyncDataLoader
from cogitare.plugins import EarlyStopping
from cogitare.metrics.classification import accuracy
import cogitare

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim

from sklearn.datasets import fetch_mldata

import numpy as np

CUDA = True


cogitare.utils.set_cuda(CUDA)

In [2]:
class CNN(Model):
    
    def __init__(self):
        super(CNN, self).__init__()
        
        # define the model
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
    
    def forward(self, batch):
        # in this sample, each batch will be a tuple containing (input_batch, expected_batch)
        # in forward we are only interested in the input, so we can ignore the second item of the tuple
        input, _ = batch
        
        # batch X flat tensor -> batch X 1 channel (gray) X width X height
        input = input.view(32, 1, 28, 28)
        
        # pass the data in the net
        x = F.relu(F.max_pool2d(self.conv1(input), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        # return the model output
        return F.log_softmax(x, dim=1)
    
    def loss(self, output, batch):
        # in this sample, each batch will be a tuple containing (input_batch, expected_batch)
        # in loss we are only interested in the expected labels, so we can ignore the first item of the tuple
        _, expected = batch
        
        return F.nll_loss(output, expected)

The model class is simple; it only requires the forward and loss methods. By default, Cogitare will backward the loss returned by the loss() method and optimize the model parameters. If you want to disable Cogitare's backward and optimization steps, just return None from the loss method. If you return None, you are responsible for backwarding the loss and optimizing the parameters yourself.
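
For example, a hypothetical sketch of taking over those steps yourself could look like the following (my_optimizer is an assumed attribute that you would have to attach to the model; it is not part of the CNN defined above):

# hypothetical sketch: handling backward/optimization manually by returning None from loss()
class ManualStepCNN(CNN):

    def loss(self, output, batch):
        _, expected = batch
        loss = F.nll_loss(output, expected)

        # do the backward pass and the optimization step by hand
        # (self.my_optimizer is an assumed attribute, set up elsewhere)
        self.my_optimizer.zero_grad()
        loss.backward()
        self.my_optimizer.step()

        # returning None tells Cogitare to skip its own backward and optimization steps
        return None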

Data Loading

In this step, we will load the data from the sklearn package.


In [3]:
mnist = fetch_mldata('MNIST original')
mnist.data = (mnist.data / 255).astype(np.float32)

Cogitare provides a toolbox to load and pre-process data for your models. In this introduction, we use the DataSet and the AsyncDataLoader as examples.

The DataSet is responsible for iterating over multiple data iterators (in our case, we'll have two data iterators: the input samples and the expected samples).


In [4]:
# as input, the DataSet expects a list of iterators. In our case, the first iterator holds the input
# data and the second iterator holds the target data

# also, we set the batch size to 32 and enable shuffling

# drop the last batch if its size is different from 32
data = DataSet([mnist.data, mnist.target.astype(int)], batch_size=32, shuffle=True, drop_last=True)

# then, we split our dataset into training and validation sets, with a ratio of 0.8
data_train, data_validation = data.split(0.8)

Notice that Cogitare accepts any iterator as input. Instead of using our DataSet, you can use mnist.data itself, PyTorch's data loaders, or any other object that acts as an iterator.
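
For instance, a purely illustrative sketch of a hand-rolled batch iterator over the raw sklearn arrays could look like this (note that a plain generator is exhausted after one pass, so for multi-epoch training you would want a re-iterable object such as the DataSet above):

# illustrative sketch: any iterable of (input_batch, expected_batch) pairs can be fed to the model
def simple_batches(data, target, batch_size=32):
    # yield full batches only, mirroring drop_last=True
    for start in range(0, len(data) - batch_size + 1, batch_size):
        end = start + batch_size
        yield data[start:end], target[start:end].astype(int)


# usage example:
# for input_batch, expected_batch in simple_batches(mnist.data, mnist.target):
#     ...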

In some cases, we can increase performance by loading the data with multiple threads/processes, or by pre-loading batches before they are requested by the model.

With the AsyncDataLoader, we can load N batches ahead of the model execution, in parallel. We present this technique in this sample because it can increase performance in a wide range of models (whenever the data loading or pre-processing is slower than the model execution).


In [5]:
def pre_process(batch):
    input, expected = batch
    
    # the data is a numpy.ndarray (loaded from sklearn), so we need to convert it to a Variable
    input = utils.to_variable(input, dtype=torch.FloatTensor)  # converts to a torch Variable of FloatTensor
    expected = utils.to_variable(expected, dtype=torch.LongTensor)  # converts to a torch Variable of LongTensor
    return input, expected


# we wrap our data_train and data_validation iterators with the async data loader.
# each loader will load 16 batches ahead of the model execution, using 8 workers (8 threads, in this case).
# each batch is pre-processed in parallel by the pre_process function, which also loads the data
# onto the GPU
data_train = AsyncDataLoader(data_train, buffer_size=16, mode='threaded', workers=8, on_batch_loaded=pre_process)
data_validation = AsyncDataLoader(data_validation, buffer_size=16, mode='threaded', workers=8, on_batch_loaded=pre_process)

To cache the async buffer before training, we can call:


In [6]:
data_train.cache()
data_validation.cache()

Let's see what the data looks like:


In [7]:
next(data_train)


Out[7]:
(Variable containing:
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0000  0.0000
 [torch.cuda.FloatTensor of size 32x784 (GPU 0)], Variable containing:
  6
  0
  5
  8
  1
  7
  3
  2
  3
  5
  2
  6
  2
  7
  2
  5
  8
  1
  3
  8
  8
  4
  4
  0
  9
  0
  2
  6
  6
  6
  6
  2
 [torch.cuda.LongTensor of size 32 (GPU 0)])

Training

Now, we can train our model.

First, let's create the model instance and add the default plugins to watch the training status. The default plugins include:

  • Progress bar per batch and epoch
  • Plot training and validation losses (if validation_dataset is present)
  • Log training loss

In [8]:
model = CNN()
model.register_default_plugins()

Besides that, we may want to add some extra plugins, such as EarlyStopping: if the model does not decrease the loss after N epochs, training stops and the best model is used.

To add the early stopping algorithm, you can use:


In [9]:
early = EarlyStopping(max_tries=10, path='/tmp/model.pt')
# after 10 epochs without the loss decreasing, training stops and the best model is saved at /tmp/model.pt

# the plugin will execute in the end of each epoch
model.register_plugin(early, 'on_end_epoch')

Also, a common technique is to clip the gradients during training. To do so, you can use:


In [10]:
model.register_plugin(lambda *args, **kw: clip_grad_norm(model.parameters(), 1.0), 'before_step')
# will execute the clip_grad_norm before each optimization step

Now, we define the optimizer and then start the model training:


In [11]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

if CUDA:
    model = model.cuda()
model.learn(data_train, optimizer, data_validation, max_epochs=100)


2018-02-02 20:59:23 sprawl cogitare.core.model[2443] INFO Model: 

CNN(
  (conv1): Conv2d (1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5)
  (fc1): Linear(in_features=320, out_features=50)
  (fc2): Linear(in_features=50, out_features=10)
)

2018-02-02 20:59:23 sprawl cogitare.core.model[2443] INFO Training data: 

DataSet with:
    containers: [
        TensorHolder with 1750x32 samples
	TensorHolder with 1750x32 samples
    ],
    batch size: 32


2018-02-02 20:59:23 sprawl cogitare.core.model[2443] INFO Number of trainable parameters: 21,840
2018-02-02 20:59:23 sprawl cogitare.core.model[2443] INFO Number of non-trainable parameters: 0
2018-02-02 20:59:23 sprawl cogitare.core.model[2443] INFO Total number of parameters: 21,840
2018-02-02 20:59:23 sprawl cogitare.core.model[2443] INFO Starting the training ...
2018-02-02 20:59:30 sprawl [CNN][2443] INFO [CNN] Loss: 0.547499 | 7 seconds
2018-02-02 20:59:37 sprawl [CNN][2443] INFO [CNN] Loss: 0.262575 | 13 seconds
2018-02-02 20:59:43 sprawl [CNN][2443] INFO [CNN] Loss: 0.221933 | 20 seconds
2018-02-02 20:59:49 sprawl [CNN][2443] INFO [CNN] Loss: 0.195854 | 26 seconds
2018-02-02 20:59:56 sprawl [CNN][2443] INFO [CNN] Loss: 0.178861 | 32 seconds
2018-02-02 21:00:02 sprawl [CNN][2443] INFO [CNN] Loss: 0.176302 | 39 seconds
2018-02-02 21:00:09 sprawl [CNN][2443] INFO [CNN] Loss: 0.164552 | 45 seconds
2018-02-02 21:00:15 sprawl [CNN][2443] INFO [CNN] Loss: 0.156181 | 52 seconds
2018-02-02 21:00:22 sprawl [CNN][2443] INFO [CNN] Loss: 0.151165 | 58 seconds
2018-02-02 21:00:28 sprawl [CNN][2443] INFO [CNN] Loss: 0.149550 | 1 minutes 5 seconds
2018-02-02 21:00:35 sprawl [CNN][2443] INFO [CNN] Loss: 0.143826 | 1 minutes 11 seconds
2018-02-02 21:00:41 sprawl [CNN][2443] INFO [CNN] Loss: 0.137406 | 1 minutes 18 seconds
2018-02-02 21:00:48 sprawl [CNN][2443] INFO [CNN] Loss: 0.135855 | 1 minutes 24 seconds
2018-02-02 21:00:54 sprawl [CNN][2443] INFO [CNN] Loss: 0.137567 | 1 minutes 30 seconds
2018-02-02 21:01:00 sprawl [CNN][2443] INFO [CNN] Loss: 0.130688 | 1 minutes 37 seconds
2018-02-02 21:01:07 sprawl [CNN][2443] INFO [CNN] Loss: 0.124864 | 1 minutes 43 seconds
2018-02-02 21:01:13 sprawl [CNN][2443] INFO [CNN] Loss: 0.127796 | 1 minutes 50 seconds
2018-02-02 21:01:20 sprawl [CNN][2443] INFO [CNN] Loss: 0.127432 | 1 minutes 56 seconds
2018-02-02 21:01:26 sprawl [CNN][2443] INFO [CNN] Loss: 0.124323 | 2 minutes 2 seconds
2018-02-02 21:01:32 sprawl [CNN][2443] INFO [CNN] Loss: 0.121641 | 2 minutes 9 seconds
2018-02-02 21:01:39 sprawl [CNN][2443] INFO [CNN] Loss: 0.124505 | 2 minutes 15 seconds
2018-02-02 21:01:45 sprawl [CNN][2443] INFO [CNN] Loss: 0.123743 | 2 minutes 21 seconds
2018-02-02 21:01:51 sprawl [CNN][2443] INFO [CNN] Loss: 0.127134 | 2 minutes 28 seconds
2018-02-02 21:01:58 sprawl [CNN][2443] INFO [CNN] Loss: 0.122418 | 2 minutes 34 seconds
2018-02-02 21:02:04 sprawl [CNN][2443] INFO [CNN] Loss: 0.119646 | 2 minutes 40 seconds
2018-02-02 21:02:04 sprawl cogitare.core.model[2443] INFO Training stopped
2018-02-02 21:02:04 sprawl cogitare.core.model[2443] INFO Training finished

Stopping training after 10 tries. Best score 0.0909
Model restored from: /tmp/model.pt
Out[11]:
False

To check the model loss and accuracy on the validation dataset:


In [12]:
def model_accuracy(output, data):
    _, indices = torch.max(output, 1)
    
    return accuracy(indices, data[1])

# evaluate the model loss and accuracy over the validation dataset
metrics = model.evaluate_with_metrics(data_validation, {'loss': model.metric_loss, 'accuracy': model_accuracy})

# metrics is a dict mapping each metric name ('loss' and 'accuracy', in this sample) to a list of per-batch values.
# since we have one measurement per batch, we take the mean to get a value for the full dataset:

metrics_mean = {'loss': 0, 'accuracy': 0}
for loss, acc in zip(metrics['loss'], metrics['accuracy']):
    metrics_mean['loss'] += loss
    metrics_mean['accuracy'] += acc.item()

qtd = len(metrics['loss'])

print('Loss: {}'.format(metrics_mean['loss'] / qtd))
print('Accuracy: {}'.format(metrics_mean['accuracy'] / qtd))


Loss: 0.10143917564566948
Accuracy: 0.9846252860411899

One of the advantages of Cogitare is its plug-and-play API, which lets you easily add and remove functionalities. With this sample, we trained a model with a training progress bar, loss plotting, early stopping, gradient clipping, and model evaluation, all with little effort.