This is a simple example showing how to use MXNet for multi-task learning.
The network is jointly going to learn whether a number is odd or even and to actually recognize the digit.
For example:
- 1: "1" and odd
- 2: "2" and even
- etc.
In this example we don't expect the two tasks to contribute much to each other, but multi-task learning has been successfully applied in other domains, for example image captioning. In A Multi-task Learning Approach for Image Captioning by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo and Yu Qiao, the authors train a network to jointly classify images and generate text captions.
In [16]:
import logging
import random
import time
import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import gluon, nd, autograd
import numpy as np
In [99]:
batch_size = 128
epochs = 5
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
lr = 0.01
In [3]:
train_dataset = gluon.data.vision.MNIST(train=True)
test_dataset = gluon.data.vision.MNIST(train=False)
In [4]:
def transform(x, y):
    x = x.transpose((2,0,1)).astype('float32')/255.  # HWC -> CHW, scaled to [0, 1]
    y1 = y      # digit label
    y2 = y % 2  # odd or even
    return x, np.float32(y1), np.float32(y2)
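As a quick sanity check (a small addition to the original flow), we can run the transform on a single raw sample and confirm that it returns a channel-first float image together with the two labels:
In [ ]:
x, y1, y2 = transform(*train_dataset[0])
print(x.shape, x.dtype)  # (1, 28, 28) float32
print(y1, y2)            # the digit and its parity (0 = even, 1 = odd)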
We apply the transform to the original datasets
In [5]:
train_dataset_t = train_dataset.transform(transform)
test_dataset_t = test_dataset.transform(transform)
We load the transformed datasets into DataLoaders
In [6]:
train_data = gluon.data.DataLoader(train_dataset_t, shuffle=True, last_batch='rollover', batch_size=batch_size, num_workers=5)
test_data = gluon.data.DataLoader(test_dataset_t, shuffle=False, last_batch='rollover', batch_size=batch_size, num_workers=5)
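To double-check what the loaders yield (a quick inspection we add here), we can pull a single batch and look at the shapes:
In [ ]:
for data, label_digit, label_odd_even in train_data:
    # expected: (128, 1, 28, 28) (128,) (128,)
    print(data.shape, label_digit.shape, label_odd_even.shape)
    break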
In [7]:
print("Input shape: {}, Target Labels: {}".format(train_dataset[0][0].shape, train_dataset_t[0][1:]))
In [135]:
class MultiTaskNetwork(gluon.HybridBlock):

    def __init__(self):
        super(MultiTaskNetwork, self).__init__()

        self.shared = gluon.nn.HybridSequential()
        with self.shared.name_scope():
            self.shared.add(
                gluon.nn.Dense(128, activation='relu'),
                gluon.nn.Dense(64, activation='relu'),
                gluon.nn.Dense(10, activation='relu')
            )
        self.output1 = gluon.nn.Dense(10) # Digit recognition
        self.output2 = gluon.nn.Dense(1)  # Odd or even

    def hybrid_forward(self, F, x):
        y = self.shared(x)         # shared backbone features
        output1 = self.output1(y)  # digit logits
        output2 = self.output2(y)  # parity logit
        return output1, output2
We can use two different losses, one for each output
In [136]:
loss_digits = gluon.loss.SoftmaxCELoss()
loss_odd_even = gluon.loss.SigmoidBCELoss()
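Note that both losses expect raw network outputs: SoftmaxCELoss applies the softmax to the digit logits internally, and SigmoidBCELoss, with its default from_sigmoid=False, applies the sigmoid itself. As a minimal check, a raw logit of 0 corresponds to a sigmoid probability of 0.5, so the binary cross-entropy against label 1 should be -log(0.5) ≈ 0.6931:
In [ ]:
print(loss_odd_even(nd.zeros((1, 1)), nd.ones((1, 1))))  # ≈ [0.6931]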
We create and initialize the network
In [137]:
mx.random.seed(42)
random.seed(42)
In [138]:
net = MultiTaskNetwork()
In [139]:
net.initialize(mx.init.Xavier(), ctx=ctx)
net.hybridize() # hybridize for speed
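Before wiring up the trainer, a quick forward pass on a dummy batch (a sanity check we add here) confirms the two output shapes; the Dense layers flatten the 1x28x28 input automatically:
In [ ]:
out_digit, out_parity = net(nd.ones((1, 1, 28, 28), ctx=ctx))
print(out_digit.shape, out_parity.shape)  # (1, 10) (1, 1)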
In [140]:
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':lr})
In [141]:
def evaluate_accuracy(net, data_iterator):
    acc_digits = mx.metric.Accuracy(name='digits')
    acc_odd_even = mx.metric.Accuracy(name='odd_even')

    for i, (data, label_digit, label_odd_even) in enumerate(data_iterator):
        data = data.as_in_context(ctx)
        label_digit = label_digit.as_in_context(ctx)
        label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)

        output_digit, output_odd_even = net(data)

        acc_digits.update(label_digit, output_digit.softmax())
        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)
    return acc_digits.get(), acc_odd_even.get()
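Running it once on the untrained network gives a useful baseline; with random weights we would expect roughly chance-level accuracy, about 0.1 for the digits and about 0.5 for odd/even:
In [ ]:
print(evaluate_accuracy(net, test_data))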
We need to balance the contribution of each loss to the overall training, and we do so by tuning an alpha parameter within [0, 1]: the digit loss is weighted by (1 - alpha) and the odd/even loss by alpha.
In [142]:
alpha = 0.5 # Weighting factor between the two losses
In [143]:
for e in range(epochs):
    # Accuracies for each task
    acc_digits = mx.metric.Accuracy(name='digits')
    acc_odd_even = mx.metric.Accuracy(name='odd_even')
    # Accumulated losses
    l_digits_ = 0.
    l_odd_even_ = 0.

    for i, (data, label_digit, label_odd_even) in enumerate(train_data):
        data = data.as_in_context(ctx)
        label_digit = label_digit.as_in_context(ctx)
        label_odd_even = label_odd_even.as_in_context(ctx).reshape(-1,1)

        with autograd.record():
            output_digit, output_odd_even = net(data)
            l_digits = loss_digits(output_digit, label_digit)
            l_odd_even = loss_odd_even(output_odd_even, label_odd_even)
            # Combine the losses of the two tasks
            l_combined = (1-alpha)*l_digits + alpha*l_odd_even

        l_combined.backward()
        trainer.step(data.shape[0])

        l_digits_ += l_digits.mean()
        l_odd_even_ += l_odd_even.mean()
        acc_digits.update(label_digit, output_digit.softmax())
        acc_odd_even.update(label_odd_even, output_odd_even.sigmoid() > 0.5)

    print("Epoch [{}], Acc Digits {:.4f} Loss Digits {:.4f}".format(
        e, acc_digits.get()[1], l_digits_.asscalar()/(i+1)))
    print("Epoch [{}], Acc Odd/Even {:.4f} Loss Odd/Even {:.4f}".format(
        e, acc_odd_even.get()[1], l_odd_even_.asscalar()/(i+1)))
    print("Epoch [{}], Testing Accuracies {}".format(e, evaluate_accuracy(net, test_data)))
In [144]:
def get_random_data():
    # random.randint is inclusive on both ends, so exclude len(test_dataset)
    idx = random.randint(0, len(test_dataset) - 1)

    img = test_dataset[idx][0]
    data, _, _ = test_dataset_t[idx]
    data = data.as_in_context(ctx).expand_dims(axis=0)

    plt.imshow(img.squeeze().asnumpy(), cmap='gray')

    return data
In [152]:
data = get_random_data()
digit, odd_even = net(data)
digit = digit.argmax(axis=1)[0].asnumpy()
odd_even = (odd_even.sigmoid()[0] > 0.5).asnumpy()
print("Predicted digit: {}, odd: {}".format(digit, odd_even))