Linear Regression Using SVRGModule on YearPredictionMSD Dataset

In this notebook, a linear regression model will be fit on the YearPredictionMSD dataset, where the task is to predict the release year of a song from its audio features. The dataset has 90 features and over 400,000 samples; it is downsampled to 5,000 samples in this experiment.


In [16]:
import mxnet as mx
from sklearn.datasets import load_svmlight_file
import numpy as np
import json
import tempfile
import os
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule

Read Data

The first step is to get the training features and labels and normalize the data. In this example, we will use 5000 data samples.


In [ ]:
# Download data file
# from subprocess import call
# YearPredictionMSD dataset: https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd
# call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])
# call(['bzip2', '-d', 'YearPredictionMSD.bz2'])

In [30]:
feature_dim = 90
print("Reading data from disk...")
train_features, train_labels = load_svmlight_file('YearPredictionMSD', n_features=feature_dim, dtype=np.float32)
train_features = train_features.todense()

# normalize the data: subtract means and divide by standard deviations
label_mean = train_labels.mean()
label_std = np.sqrt(np.square(train_labels - label_mean).mean())
feature_means = train_features.mean(axis=0)
feature_stds = np.sqrt(np.square(train_features - feature_means).mean(axis=0))

train_features = (train_features - feature_means) / feature_stds
train_labels = (train_labels - label_mean) / label_std

train_features = train_features[-5000:]
train_labels = train_labels[-5000:]


Reading data from disk...

Create Linear Regression Network


In [19]:
def create_lin_reg_network(batch_size=100):
    train_iter = mx.io.NDArrayIter(train_features, train_labels, batch_size=batch_size, shuffle=True,
                                   data_name='data', label_name='label')
    # Linear regression as a single fully-connected layer built from an explicit
    # weight and bias, followed by an L2 regression loss
    data = mx.sym.Variable("data")
    label = mx.sym.Variable("label")
    weight = mx.sym.Variable("fc_weight", shape=(1, 90))
    net = mx.sym.dot(data, weight.transpose())
    bias = mx.sym.Variable("fc_bias", shape=(1,), wd_mult=0.0, lr_mult=10.0)
    net = mx.sym.broadcast_plus(net, bias)
    net = mx.sym.LinearRegressionOutput(data=net, label=label)

    return train_iter, net

SVRGModule with SVRG Optimization

In this example, we will use the intermediate-level API of SVRGModule and dump the MSE per epoch to a JSON file for plotting.
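
To make the role of update_freq and update_full_grads in the training loop below concrete, here is a minimal sketch of the SVRG update rule itself, not the SVRGModule internals; grad_fn is a hypothetical helper that returns the gradient of the loss over the given data slice at the given weights, and w, X, y are assumed to be NumPy arrays.


In [ ]:
def svrg_sketch(w, X, y, grad_fn, lr=0.025, update_freq=2, num_epoch=4, batch_size=100):
    """Minimal sketch of the SVRG update rule; w, X, y are NumPy arrays."""
    n = X.shape[0]
    for epoch in range(num_epoch):
        if epoch % update_freq == 0:
            # Take a snapshot of the weights and compute the full gradient there
            # (this is the role of SVRGModule.update_full_grads)
            w_snapshot = w.copy()
            full_grad = grad_fn(w_snapshot, X, y)
        for start in range(0, n, batch_size):
            Xb, yb = X[start:start + batch_size], y[start:start + batch_size]
            # Variance-reduced gradient: mini-batch gradient at w, corrected by the
            # same mini-batch gradient at the snapshot plus the full snapshot gradient
            g = grad_fn(w, Xb, yb) - grad_fn(w_snapshot, Xb, yb) + full_grad
            w = w - lr * g
    return w

Because the correction term cancels much of the mini-batch gradient noise, the step size does not have to be annealed toward zero, which is why a comparatively large, constant learning rate works for SVRG later in this notebook.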


In [24]:
def train_svrg_lin_reg(num_epoch=100, batch_size=100, update_freq=2, output='svrg_lr.json', 
                       optimizer_params=None):

    di, net = create_lin_reg_network(batch_size=batch_size)
    
    # Create an SVRGModule; update_freq controls how often (in epochs) the full
    # gradient at the snapshot weights is recomputed
    mod = SVRGModule(symbol=net, context=mx.cpu(0), data_names=['data'], label_names=['label'], update_freq=update_freq)
    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
    mod.init_params(initializer=mx.init.Zero(), allow_missing=False, force_init=False, allow_extra=False)
    mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=optimizer_params)
    metrics = mx.metric.create("mse")
    
    results = {}
    for e in range(num_epoch):
        results[e] = {}
        metrics.reset()
        if e % mod.update_freq == 0:
            # Recompute the full gradient at the snapshot weights every
            # update_freq epochs, as required by SVRG
            mod.update_full_grads(di)
        di.reset()
        for batch in di:
            mod.forward_backward(data_batch=batch)
            mod.update()
            mod.update_metric(metrics, batch.label)
        results[e]["mse"] = metrics.get()[1]
   
    with open(output, 'w') as f:
        json.dump(results, f, indent=4, sort_keys=True)

Module with SGD Optimization


In [25]:
def train_sgd_lin_reg(num_epoch=100, batch_size=100, output='sgd_lr.json',
                      optimizer_params=None):
    
    di, net = create_lin_reg_network(batch_size=batch_size)
    
    #Create a standard module
    mod = mx.mod.Module(symbol=net, context=mx.cpu(0), data_names=['data'], label_names=['label'])
    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
    mod.init_params(initializer=mx.init.Zero(), allow_missing=False, force_init=False, allow_extra=False)
    mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=optimizer_params)
    metrics = mx.metric.create("mse")
    
    results = {}
    for e in range(num_epoch):
        results[e] = {}
        metrics.reset()
        di.reset()
        for batch in di:
            mod.forward_backward(data_batch=batch)
            mod.update()
            mod.update_metric(metrics, batch.label)
        results[e]["mse"] = metrics.get()[1]
    with open(output, 'w') as f:
        json.dump(results, f, indent=4, sort_keys=True)

In [11]:
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd

Training Loss over 100 Epochs Using lr_scheduler

When a large learning rate is used with SGD, the training loss drops quickly at first but then oscillates above the minimum and never converges. With a small learning rate, SGD eventually reaches the minimum, but only after many iterations. A common practice is learning rate scheduling: start with a large learning rate and gradually decrease it.
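
As a quick illustration (plain arithmetic, independent of MXNet), FactorScheduler(step=10, factor=0.99) multiplies the learning rate by 0.99 every 10 optimizer updates, so after num_update steps the learning rate is approximately base_lr * factor ** (num_update // step), up to boundary details:


In [ ]:
base_lr, step, factor = 0.01, 10, 0.99   # 0.01 is the sgd optimizer's default learning rate
for num_update in (10, 100, 1000, 5000):
    # approximate learning rate after num_update optimizer updates
    print(num_update, base_lr * factor ** (num_update // step))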


In [31]:
train_svrg_lin_reg(optimizer_params={'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=10, factor=0.99)})
train_sgd_lin_reg(optimizer_params={'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=10, factor=0.99)})

In [32]:
# Plot training loss over epochs
color = sns.color_palette()
dataplot3 = {"svrg_mse": [], "sgd_mse": []}
with open('sgd_lr.json') as sgd_data, open('svrg_lr.json') as svrg_data:
    sgd = json.load(sgd_data)
    svrg = json.load(svrg_data)
    for epoch in range(100):
        dataplot3["svrg_mse"].append(svrg[str(epoch)]["mse"])
        dataplot3["sgd_mse"].append(sgd[str(epoch)]["mse"])

x3 = list(range(100))
plt.figure(figsize=(20, 12))
plt.title("Training Loss Over Epochs")
sns.pointplot(x3, dataplot3['svrg_mse'], color=color[9])
sns.pointplot(x3, dataplot3['sgd_mse'], color=color[8])
color_patch1 = mpatches.Patch(color=color[9], label="svrg_mse")
color_patch2 = mpatches.Patch(color=color[8], label="sgd_mse")
plt.legend(handles=[color_patch1, color_patch2])
plt.ylabel('Training Loss', fontsize=12)
plt.xlabel('Epochs', fontsize=12)


Out[32]:
Text(0.5,0,'Epochs')

Training Loss Comparison with SGD Using Fixed Learning Rates

Choosing learning rates of 0.001, 0.0025, and 0.005 for SGD and a relatively large learning rate of 0.025 for SVRG, we can see that the SVRG loss decreases faster and more smoothly than SGD. The learning rate for SVRG does not need to decay to zero, which means we can start with a larger learning rate.


In [33]:
train_svrg_lin_reg(output="svrg_0.025.json", optimizer_params=(('learning_rate', 0.025),))
train_sgd_lin_reg(output="sgd_0.001.json", optimizer_params=(("learning_rate", 0.001),))
train_sgd_lin_reg(output="sgd_0.0025.json", optimizer_params=(("learning_rate", 0.0025),))
train_sgd_lin_reg(output="sgd_0.005.json", optimizer_params=(("learning_rate", 0.005),))

In [34]:
# Plot training loss over epochs
color = sns.color_palette()
dataplot3 = {"svrg_mse": [], "sgd_mse_lr_0.001": [], "sgd_mse_lr_0.0025": [], "sgd_mse_lr_0.005": []}
with open('sgd_0.001.json') as sgd_data, open('svrg_0.025.json') as svrg_data, open('sgd_0.0025.json') as sgd_data_2, open('sgd_0.005.json') as sgd_data_3:
    sgd = json.load(sgd_data)
    svrg = json.load(svrg_data)
    sgd_lr = json.load(sgd_data_2)
    sgd_lr_2 = json.load(sgd_data_3)
    for epoch in range(100):
        dataplot3["svrg_mse"].append(svrg[str(epoch)]["mse"])
        dataplot3["sgd_mse_lr_0.001"].append(sgd[str(epoch)]["mse"])
        dataplot3["sgd_mse_lr_0.0025"].append(sgd_lr[str(epoch)]["mse"])
        dataplot3["sgd_mse_lr_0.005"].append(sgd_lr_2[str(epoch)]["mse"])

x3 = list(range(100))
plt.figure(figsize=(20, 12))
plt.title("Training Loss Over Epochs")
sns.pointplot(x3, dataplot3['svrg_mse'], color=color[9])
sns.pointplot(x3, dataplot3['sgd_mse_lr_0.001'], color=color[8])
sns.pointplot(x3, dataplot3['sgd_mse_lr_0.0025'], color=color[3])
sns.pointplot(x3, dataplot3['sgd_mse_lr_0.005'], color=color[7])
color_patch1 = mpatches.Patch(color=color[9], label="svrg_mse_0.025")
color_patch2 = mpatches.Patch(color=color[8], label="sgd_mse_lr_0.001")
color_patch3 = mpatches.Patch(color=color[3], label="sgd_mse_lr_0.0025")
color_patch4 = mpatches.Patch(color=color[7], label="sgd_mse_lr_0.005")
plt.legend(handles=[color_patch1, color_patch2, color_patch3, color_patch4])
plt.ylabel('Training Loss', fontsize=12)
plt.xlabel('Epochs', fontsize=12)


Out[34]:
Text(0.5,0,'Epochs')