PyMB Example

Import Module


In [1]:
import sys
sys.path.append('../..')  # make a local PyMB checkout importable (assumes the notebook sits two levels below it)
import PyMB
from PyMB import magic  # enable %%PyMB cell magic
import numpy as np


Define the model

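$$ Y \sim \mathcal{N}(\hat{Y}, \sigma) $$

$$ \hat{Y} = \alpha + \beta x $$

Here σ is the standard deviation of the noise, estimated on the log scale as logSigma in the template, and β corresponds to the parameter Beta.

See the TMB tutorial for more information on writing custom models.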


In [2]:
%%PyMB LinearRegression
// DATA
    DATA_VECTOR(Y);
    DATA_VECTOR(x);

// PARAMETERS
    PARAMETER(alpha);
    PARAMETER(Beta);
    PARAMETER(logSigma);

// MODEL
    vector<Type> Y_hat = alpha + Beta*x;
    REPORT(Y_hat);
    Type nll = -sum(dnorm(Y, Y_hat, exp(logSigma), true));
    return nll;


Created model LinearRegression.
Using tmb_tmp/LinearRegression.cpp.
Compiled in 24.3s.
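The template returns the negative log-likelihood (nll) of Y under the normal model above. As a cross-check outside TMB, the NumPy sketch below computes the same quantity; the function name `linear_regression_nll` is hypothetical and not part of PyMB, but its dictionary keys mirror the DATA_VECTOR and PARAMETER names in the template.

```python
import numpy as np

def linear_regression_nll(params, data):
    """Negative log-likelihood of the LinearRegression template, written in NumPy.

    params: dict with 'alpha', 'Beta', 'logSigma' (same names as the PARAMETER macros)
    data:   dict with 'Y' and 'x' (same names as the DATA_VECTOR macros)
    """
    y_hat = params['alpha'] + params['Beta'] * np.asarray(data['x'], dtype=float)
    sigma = np.exp(params['logSigma'])  # standard deviation, kept positive via exp
    resid = np.asarray(data['Y'], dtype=float) - y_hat
    # log N(y | y_hat, sigma) = -0.5*log(2*pi) - log(sigma) - 0.5*(resid/sigma)**2
    log_dens = -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * (resid / sigma) ** 2
    return -np.sum(log_dens)

# A perfect fit with sigma = 1 leaves only the normalizing constants.
nll = linear_regression_nll({'alpha': 0.0, 'Beta': 1.0, 'logSigma': 0.0},
                            {'x': np.arange(10), 'Y': np.arange(10)})
print(nll)  # 5 * log(2*pi) ≈ 9.1894
```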

Simulate data


In [3]:
LinearRegression.data = {
    'x': np.arange(10),
    'Y': np.random.normal(np.arange(10))
}
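The cell above draws Y ~ N(x, 1), so the true values are alpha = 0, Beta = 1 and sigma = 1 (logSigma = 0). Because no seed is set, the printed numbers later in the notebook change on every run. A sketch of a reproducible variant follows, assuming NumPy 1.17 or later for `default_rng`; the seed and the TRUE_* names are arbitrary illustrative choices, not part of the original example.

```python
import numpy as np

# Illustrative "true" values implied by np.random.normal(np.arange(10)):
# Y ~ N(x, 1), i.e. alpha = 0, Beta = 1, sigma = 1.
TRUE_ALPHA, TRUE_BETA, TRUE_SIGMA = 0.0, 1.0, 1.0

rng = np.random.default_rng(seed=42)  # fixed seed so reruns produce the same data
x = np.arange(10)
Y = rng.normal(loc=TRUE_ALPHA + TRUE_BETA * x, scale=TRUE_SIGMA)

data = {'x': x, 'Y': Y}  # could be assigned to LinearRegression.data as above
```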

Set initial parameter values


In [4]:
LinearRegression.init = {
    'alpha': 0.,
    'Beta': 0.,
    'logSigma': 0.
}

Fit the model

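The model likelihood is integrated with respect to the parameters listed in `random` (here `alpha` and `Beta`) using TMB's Laplace approximation, and the resulting marginal likelihood is maximized over the remaining parameter, `logSigma`. See here for more information.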
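The NumPy sketch below illustrates the Laplace approximation for this model: for a given logSigma, find the mode of (alpha, Beta), then correct the nll at the mode with half the log-determinant of the Hessian. It is an illustration of the technique, not PyMB's or TMB's code, and the helper names are hypothetical. Because the model is Gaussian and linear in (alpha, Beta), the approximation is exact here, and the optimum has the closed form σ̂² = RSS / (n − 2), which the test checks.

```python
import numpy as np

def laplace_marginal_nll(log_sigma, x, y):
    """Laplace approximation to -log of the integral over (alpha, Beta) of exp(-nll).

    (alpha, Beta) are integrated with no prior (flat), as in optimize(random=['alpha', 'Beta']).
    """
    X = np.column_stack([np.ones_like(x, dtype=float), x])  # design matrix for alpha + Beta*x
    sigma2 = np.exp(2 * log_sigma)
    # Inner optimization: the mode of (alpha, Beta) is the least-squares fit for any sigma.
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ coef
    n, k = X.shape
    nll_at_mode = 0.5 * n * np.log(2 * np.pi * sigma2) + resid @ resid / (2 * sigma2)
    # Hessian of the nll with respect to (alpha, Beta) at the mode.
    hessian = X.T @ X / sigma2
    _, logdet = np.linalg.slogdet(hessian)
    # Laplace: -log int exp(-f) ≈ f(mode) + 0.5*log det H - (k/2)*log(2*pi)
    return nll_at_mode + 0.5 * logdet - 0.5 * k * np.log(2 * np.pi)

def golden_section_min(f, a, b, tol=1e-10):
    """Minimize a unimodal scalar function on [a, b] by golden-section search."""
    inv_phi = (np.sqrt(5) - 1) / 2
    while b - a > tol:
        c = b - inv_phi * (b - a)
        d = a + inv_phi * (b - a)
        if f(c) < f(d):
            b = d
        else:
            a = c
    return 0.5 * (a + b)

rng = np.random.default_rng(0)
x = np.arange(10, dtype=float)
y = rng.normal(x)
log_sigma_hat = golden_section_min(lambda s: laplace_marginal_nll(s, x, y), -5.0, 5.0)
```

In this sketch the inner mode does not depend on sigma, so the estimates of alpha and Beta coincide with ordinary least squares.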


In [5]:
LinearRegression.optimize(random=['alpha','Beta'])


Matching hessian patterns... Done

Model optimization complete in 0.1s.


--------------------------------------------------------------------------------


Simulated 100 draws in 0.2s.

alpha:
	mean	[ 0.31609895]
	sd	[ 0.73192141]
	draws	[[-0.72558336 -0.71637844 ...,  1.25651001 -0.43803027]]
	shape	(1, 100)
Beta:
	mean	[ 0.88899844]
	sd	[ 0.13710144]
	draws	[[ 1.09499943  1.00378394 ...,  0.55369137  0.89730876]]
	shape	(1, 100)
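One common way to produce draws like these, and an assumption here rather than something taken from PyMB's source, is a normal approximation: sample (alpha, Beta) from a multivariate normal centred at the estimates, with covariance equal to the inverse Hessian of the nll. The sketch uses freshly simulated data, so its numbers will not match the output above.

```python
import numpy as np

rng = np.random.default_rng(1)
x = np.arange(10, dtype=float)
y = rng.normal(x)

X = np.column_stack([np.ones_like(x), x])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)    # point estimates of (alpha, Beta)
resid = y - X @ coef
sigma2 = resid @ resid / (len(y) - X.shape[1])  # marginal-likelihood estimate of sigma^2
cov = sigma2 * np.linalg.inv(X.T @ X)           # inverse Hessian of the nll in (alpha, Beta)
sd = np.sqrt(np.diag(cov))

# 100 joint draws, one parameter per row, mirroring the (1, 100) 'draws' arrays above.
draws = rng.multivariate_normal(coef, cov, size=100).T  # shape (2, 100)
alpha_draws, beta_draws = draws[0:1], draws[1:2]
```

For x = 0, ..., 9 the ratio sd(alpha) / sd(Beta) under this approximation is sqrt(285 / 10) ≈ 5.34 whatever the data, which agrees with the ratio of the sd values printed above (0.7319 / 0.1371).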

Extract fitted values


In [6]:
print(LinearRegression.report('Y_hat'))


[ 0.31609895  1.20509738  2.09409582  2.98309425  3.87209269  4.76109113
  5.65008956  6.539088    7.42808643  8.31708487]
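The reported Y_hat is alpha + Beta * x evaluated at the estimates. Recomputing it from the rounded means printed by `optimize()` reproduces the report to about seven decimal places; the values below are copied from that output.

```python
import numpy as np

# Means printed by optimize() above.
alpha_mean = 0.31609895
beta_mean = 0.88899844

x = np.arange(10)
y_hat = alpha_mean + beta_mean * x  # same formula as Y_hat in the template
print(y_hat)
```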

Examine joint density


In [7]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# One column of draws per parameter (dict.items() works on Python 3; iteritems() was Python 2 only)
df = pd.DataFrame({k: v['draws'][0] for k, v in LinearRegression.parameters.items()})
g = sns.PairGrid(df, diag_sharey=False)
g.map_lower(sns.kdeplot, cmap='Blues_d')
g.map_upper(plt.scatter)
g.map_diag(sns.kdeplot, lw=3)


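The pair plot should show a strong negative correlation between alpha and Beta: since every x is non-negative, a steeper slope must be offset by a lower intercept. Assuming the draws come from a normal approximation with covariance proportional to (XᵀX)⁻¹, where X is the design matrix with a column of ones and a column of x, the correlation depends only on x and can be computed directly. This is a sketch under that assumption.

```python
import numpy as np

x = np.arange(10, dtype=float)
X = np.column_stack([np.ones_like(x), x])
xtx_inv = np.linalg.inv(X.T @ X)  # covariance of (alpha, Beta) up to the factor sigma^2

corr = xtx_inv[0, 1] / np.sqrt(xtx_inv[0, 0] * xtx_inv[1, 1])
print(round(corr, 3))  # -0.843
```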