In [1]:
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np

In [2]:
%matplotlib notebook

In [4]:
def f(x,y):
    return 2*x**2 + y**2

In [6]:
x = np.linspace( -1 , 1 , 100 )
y = np.linspace( -1 , 1 , 100 )
X,Y= np.meshgrid( x , y)

Z = f(X,Y)

In [11]:
fig = plt.figure() #figsize=(15, 10))
ax = fig.add_subplot( 111, projection='3d')
ax.tick_params(labelsize=8)
#        ax.view_init(azim=azim, elev=elev)
ax.plot_surface(X, Y, Z, rstride=10, cstride=10, alpha=0.3)

ax.contourf(X, Y, Z, zdir='z')#, offset=Z.min, cmap=cm.coolwarm)
#    ax.contourf(x, y, z, zdir='x', offset=x_min, cmap=cm.coolwarm)
#        if j == 0 or j == 1:
#            ax.contourf(x, y, z, zdir='y', offset=y_max, cmap=cm.coolwarm)
#        elif j == 2 or j == 3:
#            ax.contourf(x, y, z, zdir='y', offset=y_min, cmap=cm.coolwarm)

#    ax.set_xlabel('X')
#    ax.set_xlim(x_min, x_max)
#    ax.set_ylabel('Y')
#    ax.set_ylim(y_min, y_max)
#    ax.set_zlabel('Z')
#    ax.set_zlim(z_min, z_max)

    #plt.savefig(file_path, dpi=80)
    #plt.close()


Out[11]:
<matplotlib.contour.QuadContourSet at 0x7ffb5a40a2b0>