This notebook includes a pybind11 implementation of Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) and compares its performance against a pure-Python implementation on the Pima Indians Diabetes dataset.


In [3]:
import os
# Fetch the Eigen headers once; wrap.cpp below lists './eigen' in its
# include_dirs, so the extension cannot build without this checkout.
# NOTE(review): this clones an unpinned third-party mirror of Eigen —
# consider pinning a tag/commit so rebuilds are reproducible.
if not os.path.exists('./eigen'):
    ! git clone https://github.com/RLovelett/eigen.git

In [4]:
import cppimport
import numpy as np
import matplotlib.pyplot as plt
import sghmc

In [1]:
%%file wrap.cpp
<%
cfg['compiler_args'] = ['-std=c++11']
cfg['include_dirs'] = ['./eigen']
setup_pybind11(cfg)
%>

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/eigen.h>
#include <pybind11/stl.h>
#include <Eigen/Cholesky>

#include <algorithm>
#include <iostream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <unordered_set>
#include <vector>

namespace py = pybind11;

Eigen::VectorXd logistic(Eigen::VectorXd x) {
    // Element-wise sigmoid: returns 1 / (1 + e^{-x_i}) for every entry of x.
    const Eigen::ArrayXd expNeg = (-x.array()).exp();
    return (1.0 / (1.0 + expNeg)).matrix();
}

Eigen::VectorXd gd(Eigen::MatrixXd X, Eigen::VectorXd y, Eigen::VectorXd beta, double alpha, int niter) {
    // Fixed-step batch gradient method for logistic regression.
    //
    // Each of the `niter` iterations moves `beta` by `alpha` times
    // X^T (y - logistic(X beta)) / n, i.e. the gradient of the mean
    // log-likelihood (equivalently, descent on the negative log-likelihood).
    // Returns the final coefficient vector.
    const int nobs = X.rows();
    const Eigen::MatrixXd XT = X.transpose();  // reused every iteration

    for (int step = 0; step < niter; ++step) {
        const Eigen::VectorXd residual = y - logistic(X * beta);
        const Eigen::VectorXd gradient = XT * residual / nobs;
        beta += alpha * gradient;
    }
    return beta;
}
   
Eigen::MatrixXd mvnorm(Eigen::VectorXd mu, Eigen::MatrixXd Sigma, int n) {
    /*
    Draws n samples from the multivariate normal N(mu, Sigma).

    Returns an n x p matrix (p = mu.size()) whose rows are mu + L z, where L
    is the lower Cholesky factor of Sigma and z holds p independent
    standard-normal draws. Sigma is assumed positive definite; the success of
    the LLT factorisation is not checked. A new engine seeded from
    std::random_device is created on every call.
    */
    std::default_random_engine engine(std::random_device{}());
    std::normal_distribution<double> stdnormal(0, 1);

    const Eigen::MatrixXd L(Sigma.llt().matrixL());
    const int dim = mu.size();
    Eigen::MatrixXd draws(n, dim);

    for (int row = 0; row < n; ++row) {
        Eigen::VectorXd z(dim);
        // Fill z in index order with independent N(0, 1) draws.
        std::generate(z.data(), z.data() + dim,
                      [&]() { return stdnormal(engine); });
        draws.row(row) = mu + L * z;
    }
    return draws;
}
    

std::unordered_set<int> pickSet(int N, int k, std::mt19937& gen)
{
    // Draws k distinct integers uniformly from [0, N-1] (Floyd's algorithm),
    // consuming randomness from the caller-supplied engine `gen`.
    // Adapted from http://stackoverflow.com/questions/28287138/c-randomly-sample-k-numbers-from-range-0n-1-n-k-without-replacement/28287837
    //
    // Returns an empty set when k <= 0.
    // Throws std::invalid_argument when k > N: the loop would otherwise build
    // uniform_int_distribution(0, d) with d < 0, which is undefined behaviour.
    if (k > N) {
        throw std::invalid_argument("pickSet: k must not exceed N");
    }

    std::unordered_set<int> sample;
    for (int d = N - k; d < N; d++) {
        // Draw from the caller's engine so successive calls give different
        // samples (a default-constructed local engine would restart from the
        // same fixed state every call and always return the same set).
        const int t = std::uniform_int_distribution<int>(0, d)(gen);
        // If t was already chosen, take d instead; d has never been a
        // candidate before this iteration, so it is always new.
        if (!sample.insert(t).second) {
            sample.insert(d);
        }
    }
    return sample;
}


std::vector<int> pick(int N, int k) {
    // Returns k distinct integers drawn uniformly from the 0-based range
    // [0, N-1], in random order, so they can be used directly as row indices.
    // Adapted from http://stackoverflow.com/questions/28287138/c-randomly-sample-k-numbers-from-range-0n-1-n-k-without-replacement/28287837
    //
    // One engine per thread, seeded once from std::random_device, instead of
    // constructing and re-seeding a std::mt19937 on every call (this function
    // is called once per stochastic-gradient evaluation).
    static thread_local std::mt19937 gen{std::random_device{}()};

    const std::unordered_set<int> elems = pickSet(N, k, gen);

    // unordered_set iteration order is implementation-defined, so shuffle to
    // get a genuinely random order.
    std::vector<int> result(elems.begin(), elems.end());
    std::shuffle(result.begin(), result.end(), gen);
    return result;
}
    
    

Eigen::VectorXd stogradU_logistic(Eigen::VectorXd theta, Eigen::VectorXd Y, Eigen::MatrixXd X, int nbatch, double phi) {
    // Minibatch stochastic gradient for logistic regression.
    //
    // Draws nbatch distinct rows (Xb, Yb) of (X, Y) and returns
    //   -(n / nbatch) * Xb^T (Yb - logistic(Xb theta)) + phi * theta,
    // the rescaled minibatch estimate of the gradient of the negative
    // log-likelihood plus an L2 penalty (phi / 2) * ||theta||^2.
    //
    // Throws std::invalid_argument unless 1 <= nbatch <= X.rows().
    const int n = X.rows();
    const int p = X.cols();
    if (nbatch <= 0 || nbatch > n) {
        throw std::invalid_argument("stogradU_logistic: nbatch must be in [1, X.rows()]");
    }

    // pick() returns 0-based indices in [0, n-1]; use them as-is (subtracting
    // 1 would read row -1 and never reach row n-1).
    const std::vector<int> rows = pick(n, nbatch);

    Eigen::MatrixXd Xsamp(nbatch, p);
    Eigen::VectorXd Ysamp(nbatch);
    for (int i = 0; i < nbatch; i++) {
        Xsamp.row(i) = X.row(rows[i]);
        Ysamp(i) = Y(rows[i]);
    }

    const Eigen::VectorXd epsilon = Ysamp - logistic(Xsamp * theta);

    // Rescale in floating point: int n / int nbatch truncates
    // (e.g. 768 / 100 == 7 instead of 7.68).
    const double scale = static_cast<double>(n) / nbatch;
    const Eigen::VectorXd grad = scale * (Xsamp.transpose() * epsilon) - phi * theta;

    return -grad;
}
        
Eigen::VectorXd sghmc_opt(Eigen::VectorXd Y, Eigen::MatrixXd X, Eigen::MatrixXd M, Eigen::MatrixXd Minv, double eps, int m, Eigen::VectorXd theta, Eigen::MatrixXd C, Eigen::MatrixXd B, Eigen::MatrixXd D, double phi, int nbatch) {
    // One SGHMC trajectory starting from `theta`.
    //
    // Draws momentum r ~ N(0, M), then repeats m times:
    //   theta <- theta + eps * Minv * r
    //   r     <- r - eps * gradU(theta) - eps * C * Minv * r + N(0, D)
    // where gradU is the minibatch gradient from stogradU_logistic. Returns the
    // final theta; there is no Metropolis-Hastings accept/reject step.
    //
    // Minv is expected to be the inverse of M, and D the noise covariance,
    // both precomputed by the caller (sghmc_opt_run forms D from C and B).
    // B itself is not used here; it is kept for interface compatibility.
    const int p = X.cols();
    const Eigen::VectorXd mu = Eigen::VectorXd::Zero(p);

    // Resample the momentum at the start of every trajectory.
    Eigen::VectorXd r = mvnorm(mu, M, 1).row(0);

    // Loop-invariant products, computed once per trajectory.
    const Eigen::MatrixXd epsMinv = eps * Minv;
    const Eigen::MatrixXd epsCMinv = eps * C * Minv;

    for (int i = 0; i < m; i++) {
        theta += epsMinv * r;

        const Eigen::VectorXd sgrad = stogradU_logistic(theta, Y, X, nbatch, phi);
        const Eigen::VectorXd noise = mvnorm(mu, D, 1).row(0);
        r = r - eps * sgrad - epsCMinv * r + noise;
    }

    return theta;
}

Eigen::MatrixXd sghmc_opt_run(Eigen::VectorXd Y, Eigen::MatrixXd X, Eigen::MatrixXd M, double eps, int m, Eigen::VectorXd theta, Eigen::MatrixXd C, Eigen::MatrixXd V, double phi, int nsample, int nbatch) {
    // Runs nsample SGHMC trajectories of m steps each (via sghmc_opt).
    // Theta is chained from one trajectory to the next. Returns an
    // nsample x p matrix whose i-th row is theta after trajectory i.
    //
    // Precomputed once for all trajectories:
    //   Minv = M^{-1}            (via Cholesky; M must be symmetric positive
    //                             definite, as sghmc_opt also factorises it
    //                             with LLT when drawing momentum through mvnorm)
    //   B    = 0.5 * eps * V
    //   D    = 2 * (C - B) * eps (covariance of the momentum noise drawn in
    //                             sghmc_opt)
    const int p = X.cols();

    // Minv must be the actual inverse; copying M is only correct for M == I.
    const Eigen::MatrixXd Minv =
        M.llt().solve(Eigen::MatrixXd::Identity(M.rows(), M.cols()));
    const Eigen::MatrixXd B = 0.5 * V * eps;
    const Eigen::MatrixXd D = 2 * (C - B) * eps;

    Eigen::MatrixXd samples(nsample, p);
    for (int i = 0; i < nsample; i++) {
        theta = sghmc_opt(Y, X, M, Minv, eps, m, theta, C, B, D, phi, nbatch);
        samples.row(i) = theta;
    }

    return samples;
}
        
PYBIND11_PLUGIN(wrap) {
    // Python module `wrap`: exposes the C++ SGHMC helpers defined above.
    py::module m("wrap", "pybind11 example plugin");

    m.def("gd", &gd, "The gradient descent function.");
    m.def("logistic", &logistic, "The logistic function.");
    m.def("mvnorm", &mvnorm, "Random multivariate normal function");
    m.def("sghmc_opt", &sghmc_opt, "SGHMC");
    m.def("stogradU_logistic", &stogradU_logistic, "Logistic stochastic gradient");
    m.def("sghmc_opt_run", &sghmc_opt_run, "Wrapper for sghmc");
    // NOTE(review): pickSet's third parameter is std::mt19937&, which is not a
    // type registered with pybind11, so Python callers presumably cannot supply
    // it. The binding is kept so the module's attribute set is unchanged.
    m.def("pickSet", &pickSet, "Random sampling helper");
    m.def("pick", &pick, "Random sampling");

    return m.ptr();
}


Overwriting wrap.cpp

In [5]:
# Compile wrap.cpp (written by the %%file cell above) and import it as `funcs`.
# force_rebuild() presumably makes cppimport recompile even when a cached
# build exists, so edits to wrap.cpp are always picked up — confirm against
# the installed cppimport version.
cppimport.force_rebuild() 
funcs = cppimport.imp("wrap")

In [40]:
### Load data and set parameters

# Raw Pima data: predictors in columns 0-7, response Y in column 8
# (presumably the 0/1 outcome used by the logistic model — confirm).
pima = np.genfromtxt('pima-indians-diabetes.data', delimiter=',')
X = np.hstack([np.ones((len(pima), 1)), pima[:, :8]])  # intercept column + 8 predictors
Y = pima[:, 8]

# Standardise each predictor to mean 0 / std 1 and drop the intercept column;
# the SGHMC runs below fit the scaled model without an intercept.
Xs = (X[:, 1:] - X[:, 1:].mean(axis=0)) / X[:, 1:].std(axis=0)

n, q = Xs.shape

# SGHMC - Scaled (no intercept)
nsample, m, eps = 1000, 20, 0.002   # trajectories (= rows returned), steps per trajectory, step size
nbatch = 100                        # minibatch size for the stochastic gradient
phi = 5                             # coefficient of the phi * theta term in the gradient
theta = np.zeros(q)                 # starting point
C = np.eye(q)                       # C term in the momentum update (-eps * C * Minv * r)
V = np.zeros((q, q))                # enters as B = 0.5 * eps * V in sghmc_opt_run
M = np.eye(q)                       # momentum covariance: r is drawn from N(0, M)

pybind11 implementation

prun:


In [41]:
# Profile one full pybind11 SGHMC run (nsample trajectories of m steps) and
# dump the stats to work_pybind11.prof; -q suppresses the inline report.
%prun -q -D work_pybind11.prof  funcs.sghmc_opt_run(Y, Xs, M, eps, m, np.zeros(q), C, V, phi, nsample, nbatch)


 
*** Profile stats marshalled to file 'work_pybind11.prof'. 

In [42]:
import pstats

# Print the pybind11 profile, sorted by internal time then cumulative time.
p = pstats.Stats('work_pybind11.prof')
p.sort_stats('time', 'cumulative')
p.print_stats();


Mon May  1 19:07:35 2017    work_pybind11.prof

         5 function calls in 0.744 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.743    0.743    0.743    0.743 {built-in method wrap.sghmc_opt_run}
        1    0.000    0.000    0.744    0.744 {built-in method builtins.exec}
        1    0.000    0.000    0.743    0.743 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {built-in method numpy.core.multiarray.zeros}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}


timeit:


In [58]:
# Wall-clock time of a full pybind11 SGHMC run (10 loops per repeat, best of 3).
%timeit -n10 -r3 funcs.sghmc_opt_run(Y, Xs, M, eps, m, np.zeros(q), C, V, phi, nsample, nbatch)


10 loops, best of 3: 693 ms per loop

In [56]:
# Time a single C++ minibatch-gradient evaluation, for comparison with the
# pure-Python sghmc.stogradU_logistic timing in the Python section.
%timeit -n1000 -r10 funcs.stogradU_logistic(theta, Y, Xs, nbatch, phi)


1000 loops, best of 10: 45.6 µs per loop

Pure-Python implementation (`sghmc` module)

prun:


In [46]:
# Profile one full pure-Python SGHMC run with the same settings as the
# pybind11 profile and dump the stats to work_python.prof.
%prun -q -D work_python.prof sghmc.run_sghmc(Y, Xs, sghmc.U_logistic, sghmc.stogradU_logistic, M, eps, m, np.zeros(q), C, V, phi, nsample, nbatch)


 
*** Profile stats marshalled to file 'work_python.prof'. 

In [47]:
import pstats

# Load the pure-Python profile; sort_stats returns the same Stats object.
p = pstats.Stats('work_python.prof').sort_stats('time', 'cumulative')
p.print_stats();


Mon May  1 19:08:06 2017    work_python.prof

         1050191 function calls (1050190 primitive calls) in 5.865 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    20000    0.988    0.000    1.395    0.000 {method 'choice' of 'mtrand.RandomState' objects}
    21000    0.929    0.000    2.657    0.000 {method 'multivariate_normal' of 'mtrand.RandomState' objects}
    20000    0.801    0.000    2.559    0.000 /home/jovyan/work/CokerAmitaiSGHMC/logistic_regression/sghmc.py:122(stogradU_logistic)
    21000    0.760    0.000    1.217    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:1254(svd)
     1000    0.450    0.000    5.859    0.006 /home/jovyan/work/CokerAmitaiSGHMC/logistic_regression/sghmc.py:145(sghmc)
    64000    0.395    0.000    0.395    0.000 {method 'reduce' of 'numpy.ufunc' objects}
    20000    0.280    0.000    0.280    0.000 /home/jovyan/work/CokerAmitaiSGHMC/logistic_regression/sghmc.py:3(logistic)
    20000    0.148    0.000    0.334    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/fromnumeric.py:2433(prod)
    63001    0.108    0.000    0.108    0.000 {method 'astype' of 'numpy.ndarray' objects}
     1000    0.101    0.000    0.118    0.000 /home/jovyan/work/CokerAmitaiSGHMC/logistic_regression/sghmc.py:9(U_logistic)
    23000    0.088    0.000    0.271    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/fromnumeric.py:1743(sum)
    20000    0.083    0.000    0.083    0.000 {built-in method numpy.core.multiarray.arange}
    21001    0.071    0.000    0.120    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:139(_commonType)
    21003    0.062    0.000    0.062    0.000 {built-in method numpy.core.multiarray.zeros}
    40000    0.061    0.000    0.061    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/_internal.py:227(__init__)
    21000    0.046    0.000    0.191    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/fromnumeric.py:1900(any)
    21006    0.042    0.000    0.042    0.000 {built-in method builtins.hasattr}
    21001    0.040    0.000    0.086    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:106(_makearray)
    21001    0.037    0.000    0.037    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:101(get_linalg_error_extobj)
    42001    0.031    0.000    0.047    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:124(_realType)
    21001    0.029    0.000    0.038    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/numeric.py:414(asarray)
    21000    0.027    0.000    0.027    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:219(_assertNoEmpty2d)
    21000    0.025    0.000    0.035    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/numeric.py:484(asanyarray)
    42002    0.024    0.000    0.033    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:111(isComplexType)
    21001    0.023    0.000    0.031    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:198(_assertRankAtLeast2)
    21000    0.022    0.000    0.110    0.000 {method 'any' of 'numpy.ndarray' objects}
    21002    0.022    0.000    0.064    0.000 <frozen importlib._bootstrap>:996(_handle_fromlist)
    42001    0.019    0.000    0.019    0.000 {built-in method numpy.core.multiarray.array}
    23000    0.017    0.000    0.166    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/_methods.py:31(_sum)
    23002    0.017    0.000    0.017    0.000 {built-in method builtins.isinstance}
    42001    0.016    0.000    0.016    0.000 {method 'get' of 'dict' objects}
    63003    0.015    0.000    0.015    0.000 {built-in method builtins.issubclass}
    20000    0.014    0.000    0.186    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/_methods.py:34(_prod)
    21000    0.013    0.000    0.088    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/_methods.py:37(_any)
    20000    0.012    0.000    0.012    0.000 {method 'ravel' of 'numpy.ndarray' objects}
    40000    0.012    0.000    0.012    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/_internal.py:252(get_data)
    42001    0.009    0.000    0.009    0.000 {method '__array_prepare__' of 'numpy.ndarray' objects}
    21007    0.008    0.000    0.008    0.000 {built-in method builtins.getattr}
    21005    0.008    0.000    0.008    0.000 {built-in method builtins.len}
        1    0.004    0.004    5.865    5.865 /home/jovyan/work/CokerAmitaiSGHMC/logistic_regression/sghmc.py:185(run_sghmc)
        1    0.001    0.001    0.002    0.002 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:458(inv)
     1000    0.001    0.000    0.002    0.000 /opt/conda/lib/python3.5/site-packages/numpy/lib/type_check.py:432(asscalar)
     1000    0.001    0.000    0.001    0.000 {method 'item' of 'numpy.ndarray' objects}
        3    0.000    0.000    0.000    0.000 {built-in method posix.stat}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:816(get_data)
      2/1    0.000    0.000    5.865    5.865 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method marshal.loads}
        2    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:246(cache_from_source)
        1    0.000    0.000    0.000    0.000 /opt/conda/lib/python3.5/site-packages/numpy/dual.py:12(<module>)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:1215(find_spec)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:879(_find_spec)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:729(get_code)
        1    0.000    0.000    5.865    5.865 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:659(_load_unlocked)
        6    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:50(_path_join)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:1101(_get_spec)
        6    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:52(<listcomp>)
        1    0.000    0.000    0.001    0.001 <frozen importlib._bootstrap>:966(_find_and_load)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:510(_init_module_attrs)
        1    0.000    0.000    0.001    0.001 <frozen importlib._bootstrap>:939(_find_and_load_unlocked)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:474(_compile_bytecode)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:419(_validate_bytecode_header)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:176(_get_module_lock)
        1    0.000    0.000    0.000    0.000 {method 'read' of '_io.FileIO' objects}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:659(exec_module)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:570(module_from_spec)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:321(__exit__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:74(__init__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:342(_get_cached)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:163(__enter__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:1210(_get_spec)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:513(spec_from_file_location)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:94(acquire)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:119(release)
        1    0.000    0.000    0.000    0.000 /opt/conda/lib/python3.5/site-packages/numpy/linalg/linalg.py:209(_assertNdSquareness)
        8    0.000    0.000    0.000    0.000 {method 'rpartition' of 'str' objects}
        3    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:68(_path_stat)
        2    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:406(cached)
        2    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:45(_r_long)
       14    0.000    0.000    0.000    0.000 {method 'rstrip' of 'str' objects}
        8    0.000    0.000    0.000    0.000 {method 'join' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.setattr}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:87(_path_isfile)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:78(_path_is_mode_type)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:826(path_stats)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:789(find_spec)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.any}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:35(_new_module)
        4    0.000    0.000    0.000    0.000 {method 'format' of 'str' objects}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:314(__enter__)
        6    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:366(_verbose_message)
        3    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:852(__enter__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:1133(find_spec)
        2    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:56(_path_split)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:382(_check_name_wrapper)
        3    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:856(__exit__)
        2    0.000    0.000    0.000    0.000 {built-in method from_bytes}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:372(__init__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:786(__init__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:159(__init__)
        4    0.000    0.000    0.000    0.000 {built-in method _imp.release_lock}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:310(__init__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:214(_call_with_frames_removed)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:170(__exit__)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:419(parent)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:427(has_location)
        4    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:324(<genexpr>)
        1    0.000    0.000    0.000    0.000 {built-in method _imp.is_frozen}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:716(find_spec)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 {built-in method _imp._fix_co_filename}
        1    0.000    0.000    0.000    0.000 {method 'endswith' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.max}
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:190(cb)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:811(get_filename)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:1064(_path_importer_cache)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:34(_relax_case)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:225(_verbose_message)
        1    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap_external>:656(create_module)
        1    0.000    0.000    0.000    0.000 {built-in method builtins.min}
        2    0.000    0.000    0.000    0.000 {built-in method _thread.allocate_lock}
        2    0.000    0.000    0.000    0.000 {built-in method _thread.get_ident}
        3    0.000    0.000    0.000    0.000 {built-in method _imp.acquire_lock}


timeit:


In [54]:
# Time a single pure-Python minibatch-gradient evaluation.
%timeit -n1000 -r10 sghmc.stogradU_logistic(theta, Y, Xs, nbatch, phi)


1000 loops, best of 10: 97.9 µs per loop

In [55]:
# Wall-clock time of a full pure-Python SGHMC run.
# NOTE(review): this cell requests -n5, but the saved output reports
# "10 loops", so the output predates an edit to this cell — re-run it
# before quoting the timing.
%timeit -n5 -r3 sghmc.run_sghmc(Y, Xs, sghmc.U_logistic, sghmc.stogradU_logistic, M, eps, m, np.zeros(q), C, V, phi, nsample, nbatch)


10 loops, best of 3: 5.12 s per loop

In [ ]: