In [1]:
import os
setup_script = os.path.join(os.environ['ENV_JUPYTER_SETUPS_DIR'], 'setup_sci_env_basic.py')
%run $setup_script
In [2]:
import itertools
from matplotlib.mlab import griddata as mlab_griddata
from scipy.interpolate import griddata as scipy_griddata
In [3]:
load_mpl_style('single_plot.mplstyle')
ignore_warnings()
Function that we use for demonstration
In [4]:
def saddle(x,y):
return x**2 - y**2
Original $x,y$ values:
In [5]:
x = np.linspace(-10.0, 10.0, 10)
y = np.linspace(-10.0, 10.0, 10)
In [6]:
X,Y = np.meshgrid(x,y)
In [7]:
X_flat = X.flatten()
Y_flat = Y.flatten()
In [8]:
X_flat.shape
Out[8]:
Calculating $z$ values:
In [9]:
Z = saddle(X,Y)
Z_flat = Z.flatten()
In [10]:
f,a = plt.subplots()
a.scatter(X_flat, Y_flat, c=Z_flat, s=20)
Out[10]:
In [11]:
x.shape
Out[11]:
In [12]:
Z.shape
Out[12]:
In [32]:
f,a = plt.subplots()
a.contourf(x, y, Z, s=20);
f.suptitle('Using the original grid')
Out[32]:
In [31]:
X.shape
Out[31]:
You can also pass the meshgrids as an argument
In [30]:
f,a = plt.subplots()
a.contourf(X, Y, Z, s=20);
f.suptitle('Using the original grid')
Out[30]:
Interpolation specification:
In [14]:
x_num, xmin, xmax = 100, X.min(), X.max()
y_num, ymin, ymax = 100, Y.min(), Y.max()
Original points
In [15]:
X,Y = np.meshgrid(x,y)
X_flat = X.flatten()
Y_flat = Y.flatten()
In [16]:
X_flat.shape
Out[16]:
Interpolation at points:
In [17]:
xi = np.linspace( xmin, xmax, x_num, endpoint=True )
yi = np.linspace( ymin, ymax, y_num, endpoint=True )
In [18]:
zi_mlab_linear = mlab_griddata(X_flat, Y_flat, Z_flat, xi, yi, interp='linear')
In [19]:
zi_mlab_linear.shape
Out[19]:
Original points:
In [20]:
xy_points = np.array(list(itertools.product(*[x,y])))
In [21]:
z_values = saddle(xy_points[:,0], xy_points[:,1])
In [22]:
xy_points.shape
Out[22]:
In [23]:
z_values.shape
Out[23]:
Interpolating at points:
In [24]:
xi = np.linspace( xmin, xmax, x_num, endpoint=True )
yi = np.linspace( ymin, ymax, y_num, endpoint=True )
Xi,Yi = np.meshgrid(xi,yi)
In [25]:
Xi.shape
Out[25]:
Interpolation:
In [26]:
zi_scipy_nearest = scipy_griddata(xy_points, z_values, (Xi, Yi), method='nearest')
zi_scipy_linear = scipy_griddata(xy_points, z_values, (Xi, Yi), method='linear')
zi_scipy_cubic = scipy_griddata(xy_points, z_values, (Xi, Yi), method='cubic')
In [27]:
zi_scipy_linear.shape
Out[27]:
Scatter plot
In [28]:
f,a = plt.subplots()
a.scatter(Xi.flatten(), Yi.flatten(), c='b', s=5, label='interpolated pts.')
a.scatter(X_flat, Y_flat, c='k', s=20, label='original pts.')
plt.legend();
Contours
In [29]:
f,axes = plt.subplots(nrows=3, ncols=2, figsize=(12,12))
a = axes.flatten()
a[0].contourf(x, y, Z)
a[1].contourf(xi, yi, zi_mlab_linear)
a[2].contourf(xi, yi, zi_scipy_nearest)
a[3].contourf(xi, yi, zi_scipy_linear)
a[4].contourf(xi, yi, zi_scipy_cubic)
a[0].set_title('Original grid, 10x10')
a[1].set_title('mlab linear interpolation, 100x100')
a[2].set_title('scipy nearest neighbour interpolation, 100x100')
a[3].set_title('scipy linear interpolation, 100x100')
a[4].set_title('scipy cubic interpolation, 100x100')
f.delaxes(a[5])