In [2]:
%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
In [3]:
def plot_func(f, label="", start=-4, end=4, rate=1000):
    """Plot y = f(x) on the shared figure named "f vs x".

    Parameters
    ----------
    f : callable
        Function of one argument. It is called once with a 1-D numpy array
        of sample points and may return either an array of the same length
        or a scalar (a scalar is broadcast to a constant curve).
    label : str, optional
        Curve label. When non-empty, a legend is drawn so the label is shown.
    start, end : float, optional
        Inclusive x range to sample.
    rate : int, optional
        Number of evenly spaced sample points (a count, not a per-unit rate).
    """
    # plt.figure with a string identifier re-activates an existing figure of
    # that name, so repeated calls overlay their curves on the same axes.
    plt.figure("f vs x")
    plt.ylabel("f")
    plt.xlabel("x")
    x = np.linspace(start, end, rate)
    # Broadcasting lets constant functions (e.g. `lambda x: 3`) plot as well;
    # array results of matching length pass through unchanged.
    y = np.broadcast_to(f(x), x.shape)
    plt.plot(x, y, label=label)
    if label:
        plt.legend()
In [4]:
plot_func(lambda x: 1 + 2 * x);  # straight line y = 2x + 1 over the default range
In [5]:
def plot_mesh(W, f, sz=100, lo=0, hi=11):
    """Draw contour lines of ``f(W, p)`` evaluated over a square 2-D grid.

    Each grid point (x, y) is lifted to the homogeneous vector ``[1, x, y]``
    and passed to ``f`` together with ``W``. For a sign-based classifier the
    contours therefore trace its decision boundary.

    Parameters
    ----------
    W : array_like
        Parameters forwarded unchanged as the first argument of ``f``.
    f : callable
        ``f(W, p)`` with ``p`` a length-3 float array ``[1, x, y]``; must
        return a scalar. It is called once per grid point, so it need not be
        vectorised.
    sz : int, optional
        Number of samples along each axis.
    lo, hi : float, optional
        Inclusive bounds of the sampled range, shared by both axes
        (defaults reproduce the original 0..11 window).
    """
    x = np.linspace(lo, hi, sz)
    y = np.linspace(lo, hi, sz)
    xv, yv = np.meshgrid(x, y)
    # grid[r, c] == [1, x[c], y[r]]: same row/column layout as xv/yv, which
    # is the (len(y), len(x)) layout plt.contour expects for Z. The leading 1
    # lets W[0] act as a bias term if f computes a dot product.
    grid = np.stack([np.ones_like(xv), xv, yv], axis=-1)
    predict = np.apply_along_axis(lambda p: f(W, p), 2, grid)
    plt.contour(x, y, predict, colors="b")
In [10]:
import warnings

# Suppress only FutureWarning, as intended; a bare "ignore" would also hide
# RuntimeWarning and every other category for the rest of the kernel session.
warnings.filterwarnings("ignore", category=FutureWarning)


def f(W, x):
    """Linear classifier: return the sign (-1, 0 or +1) of the dot product W·x."""
    return np.sign(W.dot(x))


# Weights for homogeneous inputs [1, x, y]: the boundary f = 0 is the line
# 0.1 - 0.2x + 0.1y = 0, i.e. y = 2x - 1.
W = np.array([0.1, -0.2, 0.1])
plot_mesh(W, f)
In [ ]: