Localization

Use "localization" to learn a Cahn-Hilliard model.

Learn a Cahn-Hilliard model

Square domain with periodic boundary conditions:

$$ \dot{\phi} = \nabla^2 \left( \phi^3 - \phi \right) - \gamma \nabla^4 \phi $$
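
Equivalently, splitting out the chemical potential $\mu$:

$$ \dot{\phi} = \nabla^2 \mu, \qquad \mu = \phi^3 - \phi - \gamma \nabla^2 \phi $$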

What are we trying to do?

Create a mapping from $t_0$ to $t_{10}$ without computing the intermediate time steps. We want the following map:

$$ \phi[s](t=t_0) \rightarrow \phi[s](t=t_{10})$$

Localization

Use a regression for each discretized local state; the fitted influence coefficients act as a convolution kernel over the microstructure, as sketched below.
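
The idea behind the localization fit, as a minimal NumPy sketch (assuming a hat-function discretization of the local state; `discretize`, `fit_influence_coeffs`, and `predict_sketch` are illustrative helpers, not the pymks API):

import numpy as np

def discretize(X, n_states=5, domain=(-1, 1)):
    # Weight each voxel against n_states hat functions spanning `domain`.
    states = np.linspace(domain[0], domain[1], n_states)
    dx = states[1] - states[0]
    return np.maximum(1 - np.abs(X[..., None] - states) / dx, 0)

def fit_influence_coeffs(X, y, n_states=5):
    M = np.fft.fftn(discretize(X, n_states), axes=(1, 2))  # (n_samples, nx, ny, n_states)
    Y = np.fft.fftn(y, axes=(1, 2))                        # (n_samples, nx, ny)
    coeffs = np.zeros(M.shape[1:], dtype=complex)
    # With periodic boundaries the convolution decouples in Fourier space:
    # one small least-squares problem per spatial frequency.
    for i in range(M.shape[1]):
        for j in range(M.shape[2]):
            coeffs[i, j] = np.linalg.lstsq(M[:, i, j], Y[:, i, j], rcond=None)[0]
    return coeffs

def predict_sketch(X, coeffs, n_states=5):
    M = np.fft.fftn(discretize(X, n_states), axes=(1, 2))
    return np.fft.ifftn((M * coeffs).sum(axis=-1), axes=(1, 2)).real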

Create Samples


In [4]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
from pymks.datasets import make_cahn_hilliard

In [5]:
n_steps = 10
size = (151, 151)

# X holds the initial concentration fields; y holds the same fields
# after n_steps of simulated evolution.
X, y = make_cahn_hilliard(n_samples=10, size=size, dt=1., n_steps=n_steps)

In [6]:
print(X.shape)
print(y.shape)


(10, 151, 151)
(10, 151, 151)

In [7]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(X[0])
plt.colorbar()


Out[7]:
<matplotlib.colorbar.Colorbar at 0x7f974b07e2e8>

In [8]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y[0])
plt.colorbar()


Out[8]:
<matplotlib.colorbar.Colorbar at 0x7f974b0097b8>

Parallel

Generate more training samples in parallel with dask.


In [9]:
from dask import compute, delayed
import dask.multiprocessing

def make_data(seed):
    # Seed each task so the samples are independent across processes.
    np.random.seed(seed)
    return make_cahn_hilliard(n_samples=10, size=size, dt=1., n_steps=n_steps)

funcs = [delayed(make_data)(seed) for seed in range(30)]

out = compute(*funcs, get=dask.multiprocessing.get)
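
On newer dask releases the `get=` keyword has been removed in favor of `scheduler=`; the equivalent call is:

out = compute(*funcs, scheduler="processes")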

In [10]:
np.array(out).shape


Out[10]:
(30, 2, 10, 151, 151)

In [11]:
# Stack the 30 batches of 10 samples each into 300 samples.
X = np.array(out)[:, 0].reshape((300,) + size)
y = np.array(out)[:, 1].reshape((300,) + size)

Learning

Fit the localization model and check a prediction on a held-out sample.


In [12]:
from pymks import MKSLocalizationModel
from pymks.bases import PrimitiveBasis

basis = PrimitiveBasis(n_states=5, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)
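
As a quick illustration of what the primitive (hat-function) basis does to a single value of $\phi$ (a sketch of the idea, not the pymks internals):

states = np.linspace(-1, 1, 5)
dx = states[1] - states[0]
weights = np.maximum(1 - np.abs(0.3 - states) / dx, 0)
print(weights)  # [0.  0.  0.4  0.6  0. ] -- only the two states bracketing phi are active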

In [13]:
# Fit on all but the last sample; the last is held out for a visual check.
model.fit(X[:-1], y[:-1])

In [14]:
y_pred = model.predict(X[-1:])

In [15]:
# NBVAL_IGNORE_OUTPUT
plt.imshow(y_pred[0])
plt.colorbar()


Out[15]:
<matplotlib.colorbar.Colorbar at 0x7f97485d7400>

In [16]:
# NBVAL_IGNORE_OUTPUT
plt.imshow(y[-1])
plt.colorbar()


Out[16]:
<matplotlib.colorbar.Colorbar at 0x7f9748573550>

Train Test Split

Hold out a random subset of samples to measure generalization.


In [17]:
from sklearn.model_selection import train_test_split
from sklearn import metrics

# The default split holds out 25% of the data: 225 training and 75 test samples.
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [18]:
basis = PrimitiveBasis(n_states=5, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)

In [19]:
model.fit(X_train, y_train)

In [20]:
y_pred = model.predict(X_test)

In [21]:
# NBVAL_IGNORE_OUTPUT

metrics.mean_squared_error(y_pred.flatten(), y_test.flatten())


Out[21]:
2.1555880393661474e-06

In [22]:
# NBVAL_IGNORE_OUTPUT

print(y_pred[0][0][:10])
print(y_test[0][0][:10])


[-0.00110223  0.0045004   0.01059271  0.01693907  0.0232881   0.02937962
  0.03495207  0.03975089  0.04353786  0.04610152]
[  9.93730853e-05   5.86033997e-03   1.21063191e-02   1.85983116e-02
   2.50806271e-02   3.12879128e-02   3.69529907e-02   4.18156644e-02
   4.56325562e-02   4.81878694e-02]

In [23]:
y_pred.shape


Out[23]:
(75, 151, 151)

In [24]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y_pred[0])
plt.colorbar()


Out[24]:
<matplotlib.colorbar.Colorbar at 0x7f97484a2b70>

In [25]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y_test[0])
plt.colorbar()


Out[25]:
<matplotlib.colorbar.Colorbar at 0x7f974835fba8>

Scale Up

The influence coefficients form a convolution kernel, so a model calibrated on the (151, 151) domain can be resized to predict on a (1000, 1000) domain.


In [26]:
X_big, y_big = make_cahn_hilliard(n_samples=1, size=(1000, 1000), dt=1., n_steps=n_steps)

In [27]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y_big[0])
plt.colorbar()


Out[27]:
<matplotlib.colorbar.Colorbar at 0x7f974821b978>

In [28]:
basis = PrimitiveBasis(n_states=5, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)
model.fit(X, y)

# Resize the influence coefficients from (151, 151) up to (1000, 1000).
model.resize_coeff(y_big[0].shape)
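
One way to picture the resize (a sketch of the idea only, not the pymks internals): the real-space influence kernel is localized around the origin, so it can be zero-padded to the new shape and transformed back to frequency space.

import numpy as np

def resize_coeff_sketch(coeff, new_shape):
    # coeff: frequency-space influence coefficients, shape (nx, ny, n_states).
    kernel = np.fft.fftshift(np.fft.ifftn(coeff, axes=(0, 1)), axes=(0, 1))
    pad = [((n - o) // 2, n - o - (n - o) // 2)
           for o, n in zip(kernel.shape[:2], new_shape)] + [(0, 0)]
    kernel = np.pad(kernel, pad, mode="constant")
    return np.fft.fftn(np.fft.ifftshift(kernel, axes=(0, 1)), axes=(0, 1))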

In [29]:
y_big_pred = model.predict(X_big)

In [30]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y_big_pred[0])
plt.colorbar()


Out[30]:
<matplotlib.colorbar.Colorbar at 0x7f974810ff28>

In [31]:
# NBVAL_IGNORE_OUTPUT

metrics.mean_squared_error(y_big_pred.flatten(), y_big.flatten())


Out[31]:
2.1099181366034632e-06

In [32]:
# NBVAL_IGNORE_OUTPUT

%timeit make_cahn_hilliard(n_samples=1, size=(1000, 1000), dt=1., n_steps=n_steps)


2.53 s ± 29.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [33]:
# NBVAL_IGNORE_OUTPUT

%timeit model.predict(X_big)


218 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The calibrated model reproduces the ten simulation steps roughly an order of magnitude faster than running the simulation itself.

Multiple Steps

The model maps forward 10 steps, so applying it twice advances the field 20 steps.


In [34]:
X2, y2 = make_cahn_hilliard(n_samples=1, size=size, dt=1., n_steps=2 * n_steps)

basis = PrimitiveBasis(n_states=10, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)
model.fit(X, y)

In [35]:
# Two applications of the 10-step model advance the field 20 steps.
tmp = model.predict(X2)
y2_pred = model.predict(tmp)
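
The same trick extends to any multiple of the calibrated horizon. A small helper sketch (`advance` is not part of pymks):

def advance(model, X0, n_applications):
    # Repeatedly apply the learned 10-step map.
    out = X0
    for _ in range(n_applications):
        out = model.predict(out)
    return out

y2_pred = advance(model, X2, 2)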

In [36]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y2[0])
plt.colorbar()


Out[36]:
<matplotlib.colorbar.Colorbar at 0x7f97407faf98>

In [37]:
# NBVAL_IGNORE_OUTPUT

plt.imshow(y2_pred[0])
plt.colorbar()


Out[37]:
<matplotlib.colorbar.Colorbar at 0x7f974075b470>

In [38]:
# NBVAL_IGNORE_OUTPUT

metrics.mean_squared_error(y2_pred.flatten(), y2.flatten())


Out[38]:
0.00031351700544175379

Cross Validation

Use grid search to choose the basis and the number of local states.


In [39]:
from pymks.bases import LegendreBasis
# dask_searchcv provides a drop-in GridSearchCV that runs on a dask scheduler.
from dask_searchcv import GridSearchCV
from sklearn import metrics
mse = metrics.mean_squared_error

prim_basis = PrimitiveBasis(2, [-1, 1])
leg_basis = LegendreBasis(2, [-1, 1])

params_to_tune = {'n_states': [2, 3, 5, 8, 13],
                  'basis': [prim_basis, leg_basis]}
model = MKSLocalizationModel(prim_basis)

# Negate the MSE: the grid search maximizes the score.
score_func = metrics.make_scorer(lambda x, y: -mse(x.flatten(), y.flatten()))
gscv = GridSearchCV(model, params_to_tune, cv=5, scoring=score_func,
                    scheduler="multiprocessing", n_jobs=8)

In [40]:
?GridSearchCV

In [41]:
# NBVAL_SKIP

gscv.fit(X_train, y_train)


Out[41]:
GridSearchCV(cache_cv=True, cv=5, error_score='raise',
       estimator=MKSLocalizationModel(basis=<pymks.bases.primitive.PrimitiveBasis object at 0x7f97406b5908>,
           lstsq_rcond=2.2204460492503131e-12, n_jobs=None,
           n_states=array([0, 1])),
       iid=True, n_jobs=8,
       param_grid={'n_states': [2, 3, 5, 8, 13], 'basis': [<pymks.bases.primitive.PrimitiveBasis object at 0x7f97406b5908>, <pymks.bases.legendre.LegendreBasis object at 0x7f97406b5358>]},
       refit=True, return_train_score=True, scheduler='multiprocessing',
       scoring=make_scorer(<lambda>))

In [42]:
# NBVAL_SKIP

gscv.best_estimator_


Out[42]:
MKSLocalizationModel(basis=<pymks.bases.legendre.LegendreBasis object at 0x7f97406b5b00>,
           lstsq_rcond=2.2204460492503131e-12, n_jobs=None, n_states=5)

In [43]:
# NBVAL_SKIP

gscv.score(X_test, y_test)


Out[43]:
-1.0112404417022941e-06

In [44]:
# NBVAL_SKIP

gscv.cv_results_


Out[44]:
{'mean_fit_time': array([ 16.07892206,  21.18541507,  30.1707891 ,  41.39256469,
         83.10909031,  15.81350434,  21.09269637,  33.81124536,
         49.81040086,  87.33250256]),
 'mean_score_time': array([  2.85809254,   4.15486318,   5.9808266 ,  13.8916885 ,
         13.74197972,   2.91444176,   4.15067057,   6.35742498,
         11.50110714,  12.87102915]),
 'mean_test_score': array([ -2.22335420e-05,  -2.23956097e-05,  -2.04233343e-06,
         -1.05399388e-06,  -9.84174560e-07,  -2.22335420e-05,
         -2.23846502e-05,  -9.25497479e-07,  -9.41153306e-07,
         -9.68875898e-07]),
 'mean_train_score': array([ -2.19804594e-05,  -2.18662070e-05,  -1.95267014e-06,
         -9.71922160e-07,  -8.57395059e-07,  -2.19804594e-05,
         -2.18599368e-05,  -8.85200210e-07,  -8.69587803e-07,
         -8.44952268e-07]),
 'param_basis': masked_array(data = [<pymks.bases.primitive.PrimitiveBasis object at 0x7f97406cc4e0>
  <pymks.bases.primitive.PrimitiveBasis object at 0x7f97406cc4e0>
  <pymks.bases.primitive.PrimitiveBasis object at 0x7f97406cc4e0>
  <pymks.bases.primitive.PrimitiveBasis object at 0x7f97406cc4e0>
  <pymks.bases.primitive.PrimitiveBasis object at 0x7f97406cc4e0>
  <pymks.bases.legendre.LegendreBasis object at 0x7f97406b5400>
  <pymks.bases.legendre.LegendreBasis object at 0x7f97406b5400>
  <pymks.bases.legendre.LegendreBasis object at 0x7f97406b5400>
  <pymks.bases.legendre.LegendreBasis object at 0x7f97406b5400>
  <pymks.bases.legendre.LegendreBasis object at 0x7f97406b5400>],
              mask = [False False False False False False False False False False],
        fill_value = ?),
 'param_n_states': masked_array(data = [2 3 5 8 13 2 3 5 8 13],
              mask = [False False False False False False False False False False],
        fill_value = ?),
 'params': [{'basis': <pymks.bases.primitive.PrimitiveBasis at 0x7f97406cc4e0>,
   'n_states': 2},
  {'basis': <pymks.bases.primitive.PrimitiveBasis at 0x7f97406cc4e0>,
   'n_states': 3},
  {'basis': <pymks.bases.primitive.PrimitiveBasis at 0x7f97406cc4e0>,
   'n_states': 5},
  {'basis': <pymks.bases.primitive.PrimitiveBasis at 0x7f97406cc4e0>,
   'n_states': 8},
  {'basis': <pymks.bases.primitive.PrimitiveBasis at 0x7f97406cc4e0>,
   'n_states': 13},
  {'basis': <pymks.bases.legendre.LegendreBasis at 0x7f97406b5400>,
   'n_states': 2},
  {'basis': <pymks.bases.legendre.LegendreBasis at 0x7f97406b5400>,
   'n_states': 3},
  {'basis': <pymks.bases.legendre.LegendreBasis at 0x7f97406b5400>,
   'n_states': 5},
  {'basis': <pymks.bases.legendre.LegendreBasis at 0x7f97406b5400>,
   'n_states': 8},
  {'basis': <pymks.bases.legendre.LegendreBasis at 0x7f97406b5400>,
   'n_states': 13}],
 'rank_test_score': array([ 7, 10,  6,  5,  4,  8,  9,  1,  2,  3], dtype=int32),
 'split0_test_score': array([ -2.17990702e-05,  -2.20479674e-05,  -2.05045762e-06,
         -1.05527700e-06,  -9.88822233e-07,  -2.17990702e-05,
         -2.20278504e-05,  -9.17583702e-07,  -9.38035288e-07,
         -9.72806395e-07]),
 'split0_train_score': array([ -2.20844479e-05,  -2.19523656e-05,  -1.95022403e-06,
         -9.71683677e-07,  -8.56712647e-07,  -2.20844479e-05,
         -2.19491025e-05,  -8.86735694e-07,  -8.70338770e-07,
         -8.44264859e-07]),
 'split1_test_score': array([ -2.34110748e-05,  -2.35618389e-05,  -2.12264363e-06,
         -1.12772406e-06,  -1.04627272e-06,  -2.34110748e-05,
         -2.35538072e-05,  -9.96846984e-07,  -1.01022825e-06,
         -1.03247986e-06]),
 'split1_train_score': array([ -2.16900093e-05,  -2.15748778e-05,  -1.93267494e-06,
         -9.54241379e-07,  -8.42281730e-07,  -2.16900093e-05,
         -2.15691970e-05,  -8.67979867e-07,  -8.52824799e-07,
         -8.29814042e-07]),
 'split2_test_score': array([ -2.25688588e-05,  -2.27315482e-05,  -2.09692307e-06,
         -1.11153384e-06,  -1.04475329e-06,  -2.25688588e-05,
         -2.27197064e-05,  -9.84705383e-07,  -9.97969013e-07,
         -1.02785979e-06]),
 'split2_train_score': array([ -2.19002430e-05,  -2.17865300e-05,  -1.93993951e-06,
         -9.58131666e-07,  -8.43489608e-07,  -2.19002430e-05,
         -2.17777775e-05,  -8.71032933e-07,  -8.56231220e-07,
         -8.31121444e-07]),
 'split3_test_score': array([ -2.20531351e-05,  -2.22585789e-05,  -2.06339185e-06,
         -1.02436173e-06,  -9.56860637e-07,  -2.20531351e-05,
         -2.22282812e-05,  -8.97583596e-07,  -9.09654949e-07,
         -9.35430426e-07]),
 'split3_train_score': array([ -2.20304639e-05,  -2.19058398e-05,  -1.94749965e-06,
         -9.78678090e-07,  -8.63670957e-07,  -2.20304639e-05,
         -2.19039051e-05,  -8.91818595e-07,  -8.76712019e-07,
         -8.52303576e-07]),
 'split4_test_score': array([ -2.13355710e-05,  -2.13781152e-05,  -1.87825100e-06,
         -9.51072775e-07,  -8.84163918e-07,  -2.13355710e-05,
         -2.13936059e-05,  -8.30767730e-07,  -8.49879028e-07,
         -8.75803027e-07]),
 'split4_train_score': array([ -2.21971329e-05,  -2.21114216e-05,  -1.99301257e-06,
         -9.96875987e-07,  -8.80820354e-07,  -2.21971329e-05,
         -2.20997017e-05,  -9.08433960e-07,  -8.91832205e-07,
         -8.67257422e-07]),
 'std_fit_time': array([  0.30051438,   0.50655807,   0.4214792 ,   2.83382237,
         14.2735837 ,   0.2020092 ,   0.4517502 ,   0.83604798,
          6.61041831,  14.12084489]),
 'std_score_time': array([ 0.07885922,  0.20402628,  0.17787768,  2.47543572,  0.70477799,
         0.05275955,  0.21185631,  0.28278402,  2.16403836,  1.41425199]),
 'std_test_score': array([  7.10870614e-07,   7.27718731e-07,   8.58629854e-08,
          6.35915888e-08,   6.04908984e-08,   7.10870614e-07,
          7.22926979e-07,   6.06825606e-08,   5.88765405e-08,
          5.88577445e-08]),
 'std_train_score': array([  1.73814487e-07,   1.79078397e-07,   2.10812976e-08,
          1.53047749e-08,   1.42152826e-08,   1.73814487e-07,
          1.78129585e-07,   1.47188303e-08,   1.41811738e-08,
          1.39462667e-08])}