Spatial Gaussian Process inference in PyMC3

This is the first step in modelling Species occurrence. The good news is that MCMC works, The bad one is that it's computationally intense.


In [10]:
# Load Biospytial modules and etc.
%matplotlib inline
import sys
sys.path.append('/apps/external_plugins/spystats/')
#import django
#django.setup()
import pandas as pd
import matplotlib.pyplot as plt
## Use the ggplot style
plt.style.use('ggplot')
import numpy as np

In [11]:
## Model Specification
import pymc3 as pm
from spystats import tools

Simulated gaussian data


In [12]:
sigma=3.5
range_a=10.13
kappa=3.0/2.0
#ls = 0.2
#tau = 2.0
cov = sigma * pm.gp.cov.Matern32(2, range_a,active_dims=[0,1])

In [13]:
n = 10
grid = tools.createGrid(grid_sizex=n,grid_sizey=n,minx=0,miny=0,maxx=50,maxy=50)

In [14]:
K = cov(grid[['Lon','Lat']].values).eval()
sample = pm.MvNormal.dist(mu=np.zeros(K.shape[0]), cov=K).random(size=1)
grid['Z'] = sample

In [15]:
plt.figure(figsize=(14,4))
plt.imshow(grid.Z.values.reshape(n,n),interpolation=None)


Out[15]:
<matplotlib.image.AxesImage at 0x7fcd948daf50>

In [16]:
print("sigma: %s, phi: %s"%(sigma,range_a))


sigma: 3.5, phi: 10.13

In [18]:
## Analysis, GP only one parameter to fit
# The variational method is much beter.
with pm.Model() as model:
    
    #sigma = 1.0
    sigma = pm.Uniform('sigma',0,4)
    phi = pm.Normal('phi',mu=8,sd=3)
#    phi = pm.Uniform('phi',5,10)

    cov = sigma * pm.gp.cov.Matern32(2,phi,active_dims=[0,1])
    K = cov(grid[['Lon','Lat']].values)
    y_obs = pm.MvNormal('y_obs',mu=np.zeros(n*n),cov=K,observed=grid.Z)
    
    #gp = pm.gp.Latent(cov_func=cov,observed=sample)
    # Use elliptical slice sampling
    #ess_step = pm.EllipticalSlice(vars=[f_sample], prior_cov=K)
    #ess_Step = pm.HamiltonianMC()
    #%time trace = pm.sample(5000)
    ## Variational
    %time results = pm.fit()


Average Loss = 139.51: 100%|██████████| 10000/10000 [00:28<00:00, 348.57it/s]
CPU times: user 1min 36s, sys: 2.14 s, total: 1min 39s
Wall time: 30.6 s

Diagnostics

For one parameter it took around 1.3 minutes For two parameters it took 4min 27 seconds


In [19]:
from pymc3 import find_MAP
map_estimate = find_MAP(model=model)


logp = -139.58, ||grad|| = 2.6669e-05: 100%|██████████| 12/12 [00:00<00:00, 341.72it/s]  

Simulated Poisson data with latent Gaussian Field


In [20]:
np.random.seed(1234)

sigma=3.5
range_a=10.13
kappa=3.0/2.0
#ls = 0.2
alpha = 0.0
cov = sigma * pm.gp.cov.Matern32(2, range_a,active_dims=[0,1])
n = 20
grid = tools.createGrid(grid_sizex=n,grid_sizey=n,minx=0,miny=0,maxx=20,maxy=20)
K = cov(grid[['Lon','Lat']].values).eval()
pfield = pm.MvNormal.dist(mu=np.zeros(K.shape[0]), cov=K).random(size=1)

poiss_data = np.exp(alpha + pfield)

grid['Z'] = poiss_data
#grid['Z'] = pfield
plt.figure(figsize=(14,4))
plt.imshow(grid.Z.values.reshape(n,n),interpolation=None)
plt.colorbar()
print("sigma: %s, phi: %s"%(sigma,range_a))


sigma: 3.5, phi: 10.13

In [21]:
## Analysis, GP only one parameter to fit
# The variational method is much beter.
from pymc3.variational.callbacks import CheckParametersConvergence

with pm.Model() as model:
    sigma=3.5
    range_a=10.13
    
    
    #sigma = pm.Uniform('sigma',0,4)
    #phi = pm.HalfNormal('phi',mu=8,sd=3)
    #phi = pm.Uniform('phi',6,12)
    phi = pm.Uniform('phi',5,15)
    cov = sigma * pm.gp.cov.Matern32(2,phi,active_dims=[0,1])
    #K = cov(grid[['Lon','Lat']].values)
    #phiprint = tt.printing.Print('phi')(phi)
    
    ## The latent function
    gp = pm.gp.Latent(cov_func=cov)
    
    ## I don't know why this
    f = gp.prior("latent_field", X=grid[['Lon','Lat']].values,reparameterize=True)
    
    #f_print = tt.printing.Print('latent_field')(f)
    
    y_obs = pm.Poisson('y_obs',mu=f,observed=grid.Z)
    
    #y_obs = pm.MvNormal('y_obs',mu=np.zeros(n*n),cov=K,observed=grid.Z)
    
    #gp = pm.gp.Latent(cov_func=cov,observed=sample)
    # Use elliptical slice sampling
    #ess_step = pm.EllipticalSlice(vars=[f_sample], prior_cov=K)
    #step = pm.HamiltonianMC()
    #step = pm.Metropolis()
    #%time trace = pm.sample(5000,step)#,tune=0,chains=1)
    ## Variational
    
    %time mean_field = pm.fit(method='advi', callbacks=[CheckParametersConvergence()])


Average Loss = inf: 100%|██████████| 10000/10000 [03:30<00:00, 47.47it/s]
CPU times: user 12min 55s, sys: 12.9 s, total: 13min 8s
Wall time: 3min 32s

ESsta dando un monton de inf en averafe lost


In [22]:
# pm.traceplot(trace)

In [23]:
#for RV in model.basic_RVs:
#    print(RV.name, RV.logp(model.test_point))

In [24]:
from pymc3 import find_MAP
map_estimate = find_MAP(model=model)
map_estimate


logp = -inf, ||grad|| = 214.35: 100%|██████████| 3/3 [00:00<00:00, 52.60it/s]
Out[24]:
{'latent_field': array([-0.3834217 , -0.41007374, -0.43583168, -0.4601648 , -0.4825283 ,
        -0.50238856, -0.51925151, -0.53269178, -0.54237913, -0.54809885,
        -0.54976291, -0.54741049, -0.54119765, -0.53137821, -0.51827874,
        -0.502271  , -0.48374474, -0.46308261, -0.44063796, -0.41671505,
        -0.42857215, -0.45936469, -0.48921593, -0.51749078, -0.54353027,
        -0.56668294, -0.58634113, -0.6019791 , -0.61318854, -0.61970607,
        -0.6214286 , -0.61841367, -0.61086516, -0.59910711, -0.58355045,
        -0.56465742, -0.54290813, -0.51877131, -0.49268106, -0.46502698,
        -0.4771009 , -0.51254917, -0.54702953, -0.57978574, -0.61002385,
        -0.63695061, -0.65981963, -0.67798131, -0.69093038, -0.69834355,
        -0.70010052, -0.69628413, -0.68716007, -0.67314079, -0.65474072,
        -0.63253056, -0.60709641, -0.57900781, -0.54879816, -0.51696759,
        -0.52875988, -0.56939683, -0.60906844, -0.64687973, -0.68187894,
        -0.71310384, -0.73963986, -0.76068513, -0.77561442, -0.78403094,
        -0.78579536, -0.78102558, -0.77006775, -0.753446  , -0.73180203,
        -0.70583633, -0.67625935, -0.64375753, -0.60897749, -0.57253092,
        -0.58313849, -0.62949584, -0.67492884, -0.71838623, -0.75873437,
        -0.7948123 , -0.82550359, -0.84982063, -0.8669905 , -0.87652665,
        -0.87826937, -0.87238496, -0.85932416, -0.83975158, -0.81446405,
        -0.7843152 , -0.75015737, -0.71280566, -0.67302445, -0.63153252,
        -0.63963485, -0.69221609, -0.74396115, -0.79364583, -0.83993079,
        -0.8814253 , -0.91677457, -0.94476712, -0.9644499 , -0.97522767,
        -0.97691963, -0.96975766, -0.95432684, -0.93146613, -0.90215791,
        -0.86743175, -0.82829585, -0.78569948, -0.74052295, -0.6935867 ,
        -0.69743107, -0.75667547, -0.81522527, -0.8716685 , -0.92443791,
        -0.97188341, -1.01237531, -1.04443797, -1.06690267, -1.07904518,
        -1.08066365, -1.07207541, -1.05403386, -1.02758897, -0.99393856,
        -0.95430764, -0.90986892, -0.86170482, -0.81080429, -0.75808186,
        -0.75547714, -0.82171494, -0.88745426, -0.95108237, -1.01078467,
        -1.06462644, -1.11067218, -1.14714583, -1.17262985, -1.18625499,
        -1.18779906, -1.17768205, -1.15685855, -1.12662312, -1.08841327,
        -1.04366293, -0.99371195, -0.93976864, -0.8829174 , -0.82415544,
        -0.81248903, -0.88588981, -0.95903608, -1.0301022 , -1.09700844,
        -1.15751575, -1.20936452, -1.25045977, -1.27911844, -1.2943307 ,
        -1.29585551, -1.28421356, -1.26057872, -1.22650731, -1.18368892,
        -1.13377349, -1.07826602, -1.01848873, -0.95560251, -0.89067041,
        -0.86696701, -0.94748525, -1.0280252 , -1.10653457, -1.18064915,
        -1.2478083 , -1.30542234, -1.35109707, -1.38292589, -1.39980673,
        -1.40148172, -1.38852773, -1.36229871, -1.32459367, -1.27735606,
        -1.22245627, -1.16156166, -1.09609463, -1.02727033, -0.95619404,
        -0.91723968, -1.00456366, -1.09219427, -1.17783608, -1.25881674,
        -1.33223649, -1.39518823, -1.44504731, -1.47984353, -1.49844044,
        -1.50056596, -1.48679995, -1.45852013, -1.4177019 , -1.36652861,
        -1.30709644, -1.24123268, -1.17044664, -1.09598822, -1.0189791 ,
        -0.96153778, -1.05504857, -1.14913202, -1.24123244, -1.32834348,
        -1.40722566, -1.47472324, -1.52816716, -1.5656089 , -1.58592347,
        -1.58882795, -1.57486514, -1.54535455, -1.50226434, -1.4479516 ,
        -1.38472874, -1.31457455, -1.2390684 , -1.15948115, -1.07694563,
        -0.9981003 , -1.09684859, -1.19638911, -1.29389613, -1.38601768,
        -1.46922797, -1.54030971, -1.59664061, -1.63633758, -1.65831217,
        -1.66227568, -1.64871695, -1.61885662, -1.57456188, -1.51817949,
        -1.45218142, -1.37865899, -1.29924057, -1.2152088 , -1.12772698,
        -1.02531203, -1.12802344, -1.23167333, -1.33318691, -1.42889708,
        -1.51515882, -1.58878491, -1.64724434, -1.68874315, -1.71224616,
        -1.71746594, -1.70483347, -1.67545431, -1.63104488, -1.57383329,
        -1.50627834, -1.43049519, -1.34814291, -1.2605404 , -1.16892443,
        -1.04186134, -1.14698901, -1.25310715, -1.35696214, -1.45471005,
        -1.54267491, -1.61775005, -1.67752061, -1.72029704, -1.74510821,
        -1.75167517, -1.74037523, -1.71219823, -1.66869185, -1.61188312,
        -1.54404709, -1.46720822, -1.383026  , -1.29290178, -1.19825954,
        -1.04688492, -1.15271213, -1.25949614, -1.36390892, -1.46206281,
        -1.55032394, -1.62569797, -1.68590131, -1.72935556, -1.75515707,
        -1.76303848, -1.75332723, -1.72690052, -1.68513041, -1.62979463,
        -1.56286715, -1.48620588, -1.40144219, -1.31008571, -1.21381421,
        -1.04006119, -1.14481706, -1.2504259 , -1.35358957, -1.45048928,
        -1.53760713, -1.61209312, -1.67180393, -1.71527075, -1.74165279,
        -1.75069002, -1.74265706, -1.71831461, -1.67884841, -1.62577228,
        -1.56076598, -1.48549435, -1.40154091, -1.31050949, -1.21427032,
        -1.02164367, -1.12360942, -1.22626008, -1.32642291, -1.42044931,
        -1.5050096 , -1.57743066, -1.63571442, -1.6784939 , -1.70497992,
        -1.71490973, -1.70849759, -1.6863829 , -1.6495664 , -1.59932215,
        -1.53708351, -1.46434069, -1.38261679, -1.2935607 , -1.1991372 ,
        -0.99244655, -1.09005747, -1.18812065, -1.28367348, -1.37333427,
        -1.45402937, -1.52328902, -1.57926012, -1.6206658 , -1.64675464,
        -1.65724983, -1.65229927, -1.63242332, -1.59845467, -1.5514667 ,
        -1.49269697, -1.42349177, -1.34530734, -1.25978784, -1.16890262,
        -0.95378601, -1.04573643, -1.13784652, -1.22743088, -1.31146848,
        -1.38719639, -1.45236142, -1.50525023, -1.54466449, -1.569879  ,
        -1.58059522, -1.57689391, -1.55918607, -1.5281605 , -1.48472872,
        -1.42997515, -1.36512895, -1.29157618, -1.21091932, -1.12506583]),
 'latent_field_rotated_': array([-0.20494749, -0.10095638, -0.07905667, -0.07580585, -0.07476589,
        -0.07234559, -0.06930893, -0.06557659, -0.06135392, -0.05679548,
        -0.05207112, -0.04732661, -0.0426775 , -0.03820181, -0.03394069,
        -0.02990269, -0.02606947, -0.02240043, -0.01883474, -0.01529026,
        -0.1524021 , -0.06731612, -0.07468189, -0.07543851, -0.0777946 ,
        -0.07855393, -0.07839125, -0.0770924 , -0.07478364, -0.07157796,
        -0.06765783, -0.06322154, -0.05846621, -0.05356574, -0.04865757,
        -0.04385315, -0.03920251, -0.03488334, -0.03073877, -0.0289584 ,
        -0.08196834, -0.04380591, -0.05198339, -0.05407075, -0.05768883,
        -0.0601566 , -0.06187716, -0.06257812, -0.06226409, -0.06097348,
        -0.05883516, -0.05602586, -0.05274346, -0.04917997, -0.04550224,
        -0.04187779, -0.03844754, -0.03567951, -0.03388969, -0.03450034,
        -0.07549382, -0.04041733, -0.04939809, -0.05203178, -0.05621666,
        -0.0592529 , -0.06148518, -0.06259721, -0.06256401, -0.06142302,
        -0.05932925, -0.05650403, -0.05319491, -0.04963718, -0.04603131,
        -0.04257153, -0.03942811, -0.03710138, -0.03584006, -0.03651759,
        -0.07741419, -0.04288697, -0.05224779, -0.05554156, -0.06033132,
        -0.06390916, -0.06653767, -0.06784932, -0.06779203, -0.06641879,
        -0.06393702, -0.06064099, -0.05684841, -0.0528433 , -0.04884744,
        -0.04505551, -0.04162486, -0.03905525, -0.03758411, -0.03810103,
        -0.07829439, -0.04485248, -0.0545332 , -0.05854348, -0.06396671,
        -0.06813076, -0.07119061, -0.07271434, -0.07261209, -0.07095851,
        -0.06803011, -0.06421622, -0.0599229 , -0.05548677, -0.05114331,
        -0.04707624, -0.043418  , -0.04065068, -0.03898145, -0.03929501,
        -0.07792118, -0.04622849, -0.05608995, -0.06081812, -0.06685529,
        -0.07161358, -0.07512219, -0.07686753, -0.07671013, -0.07475582,
        -0.07136042, -0.06702551, -0.06226498, -0.05746403, -0.05285574,
        -0.0485991 , -0.04478743, -0.04186533, -0.03999763, -0.04006557,
        -0.07597667, -0.04671387, -0.05658838, -0.06196443, -0.06851482,
        -0.07379484, -0.0777125 , -0.07966218, -0.07944842, -0.07723986,
        -0.07344654, -0.06866878, -0.06355752, -0.05852437, -0.05378088,
        -0.04945883, -0.04560439, -0.04260374, -0.04055684, -0.04031915,
        -0.07237884, -0.0461642 , -0.05584601, -0.0617139 , -0.06855936,
        -0.07413825, -0.07828469, -0.08034239, -0.08006642, -0.07780343,
        -0.0738728 , -0.06885952, -0.06362342, -0.05853456, -0.05380256,
        -0.04955165, -0.04577845, -0.04278799, -0.04058279, -0.03995149,
        -0.06715175, -0.0445    , -0.05377308, -0.05990921, -0.06671818,
        -0.07217768, -0.07608877, -0.07795095, -0.07765217, -0.07577966,
        -0.0723036 , -0.0675128 , -0.06242755, -0.05742752, -0.05283808,
        -0.0487815 , -0.04520722, -0.04231654, -0.03998046, -0.03883516,
        -0.0604601 , -0.04173981, -0.05041982, -0.05659094, -0.06300666,
        -0.06785047, -0.0709469 , -0.07215064, -0.07215994, -0.07098108,
        -0.06854363, -0.0647727 , -0.05996689, -0.05519208, -0.05082483,
        -0.04706693, -0.04377374, -0.0410345 , -0.03859492, -0.03681317,
        -0.052581  , -0.03795681, -0.04591476, -0.05191578, -0.05763035,
        -0.06152919, -0.06370612, -0.06484938, -0.0650707 , -0.06438871,
        -0.062751  , -0.0600587 , -0.05626533, -0.05177071, -0.04773291,
        -0.04435677, -0.04137886, -0.03873919, -0.03612414, -0.03365744,
        -0.04387162, -0.03322592, -0.04035997, -0.04597342, -0.0507089 ,
        -0.05356581, -0.05544946, -0.05650033, -0.05683618, -0.05649136,
        -0.05543684, -0.05358695, -0.0508481 , -0.04728798, -0.04357812,
        -0.04073138, -0.03801762, -0.0353799 , -0.03239432, -0.02911553,
        -0.03483696, -0.02773845, -0.03392948, -0.03899014, -0.04261998,
        -0.04501021, -0.04657076, -0.04746728, -0.04782586, -0.04769797,
        -0.04707682, -0.04589667, -0.04404916, -0.04148532, -0.0386194 ,
        -0.03624781, -0.0336891 , -0.03100492, -0.02777698, -0.02379231,
        -0.02609093, -0.02191012, -0.02701169, -0.03140952, -0.03433961,
        -0.03620604, -0.03739294, -0.03807637, -0.03837947, -0.03836164,
        -0.03803702, -0.03737625, -0.03630617, -0.03473215, -0.03275415,
        -0.03079089, -0.02832872, -0.02550893, -0.0221603 , -0.01826185,
        -0.01817713, -0.01625148, -0.02029448, -0.02393898, -0.02613643,
        -0.02745715, -0.02825734, -0.02870348, -0.0289004 , -0.0288997 ,
        -0.0287207 , -0.02835784, -0.027784  , -0.0269515 , -0.02578982,
        -0.02421586, -0.02212435, -0.01953038, -0.01637254, -0.01297872,
        -0.01143394, -0.01101474, -0.01404649, -0.01687404, -0.01834726,
        -0.01915075, -0.0196022 , -0.01983593, -0.01992574, -0.01990414,
        -0.01978117, -0.01955033, -0.01919038, -0.01866466, -0.01791704,
        -0.01687472, -0.01543594, -0.01355939, -0.01113939, -0.00838121,
        -0.00616378, -0.00649324, -0.00857165, -0.01055798, -0.01136989,
        -0.01174511, -0.01193571, -0.0120232 , -0.01204682, -0.01202089,
        -0.01194826, -0.01182317, -0.0116318 , -0.01135068, -0.01094265,
        -0.01035493, -0.00950578, -0.00832748, -0.00669003, -0.00467902,
        -0.00256191, -0.00300083, -0.00419494, -0.00536893, -0.00565432,
        -0.00574979, -0.00579273, -0.00580659, -0.00580326, -0.00578525,
        -0.00575197, -0.00569985, -0.00562215, -0.00550771, -0.00533822,
        -0.00508544, -0.00470068, -0.0041258 , -0.0032451 , -0.00204349,
        -0.00062085, -0.00081911, -0.0012532 , -0.00171723, -0.00171474,
        -0.00171277, -0.00171016, -0.0017067 , -0.001702  , -0.00169554,
        -0.00168653, -0.0016738 , -0.00165547, -0.00162856, -0.00158804,
        -0.00152557, -0.00142517, -0.00126255, -0.00098199, -0.00053731]),
 'phi': array(10.0),
 'phi_interval__': array(0.0)}

In [87]:
plt.imshow(map_estimate['latent_field'].reshape(20,20))


Out[87]:
<matplotlib.image.AxesImage at 0x7fa2c84d8a50>

In [26]:
pm.plot_posterior(mean_field.sample(10), color='LightSeaGreen');