In [1]:
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
In [2]:
# Select matplotlib's `notebook` backend for the figures drawn in later cells.
# NOTE(review): this backend is not available in JupyterLab / Notebook 7 —
# there, `%matplotlib widget` (needs ipympl) is the interactive option; confirm target environment.
%matplotlib notebook
In [4]:
def f(x, y, a=2, b=1):
    """Evaluate the elliptic paraboloid z = a*x**2 + b*y**2.

    With the default coefficients this is z = 2*x**2 + y**2. Only element-wise
    arithmetic is used, so `x` and `y` may be scalars or equally shaped NumPy
    arrays (e.g. the output of `np.meshgrid`).

    Parameters
    ----------
    x, y : scalar or array-like
        Coordinates at which to evaluate the surface.
    a, b : scalar, optional
        Coefficients of the x**2 and y**2 terms (defaults 2 and 1).

    Returns
    -------
    scalar or ndarray
        The surface height at (x, y).
    """
    return a * x**2 + b * y**2
In [6]:
# Sample each axis uniformly on [-1, 1] with 100 points, expand the two
# 1-D samples into 2-D coordinate grids, and evaluate f over those grids.
x = np.linspace(start=-1, stop=1, num=100)
y = np.linspace(start=-1, stop=1, num=100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
In [11]:
# Draw z = f(x, y) as a translucent surface and project its filled contours
# onto a plane at the bottom of the plot.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.tick_params(labelsize=8)

ax.plot_surface(X, Y, Z, rstride=10, cstride=10, alpha=0.3)

# `Z.min` is a method and must be called to get the floor height. Without an
# `offset`, 3D contourf places each filled band at its own z level (stacked
# through the surface) rather than projecting them onto a single plane.
z_floor = Z.min()
ax.contourf(X, Y, Z, zdir='z', offset=z_floor, cmap=cm.coolwarm)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
# Keep the projection plane at the bottom edge of the z-axis.
ax.set_zlim(z_floor, Z.max())
ax.set_title('$z = 2x^2 + y^2$')
plt.show()
Out[11]: