In [2]:
%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt

In [3]:
def plot_func(f, label="", start=-4, end=4, rate=1000):
    """Plot ``f(x)`` sampled at ``rate`` evenly spaced points in [start, end].

    The figure is selected by the label "f vs x"; if that figure is still
    open, it is reactivated, so repeated calls overlay their curves on the
    same axes.

    Parameters
    ----------
    f : callable
        Vectorised function: it is called once with the whole NumPy array of
        x values and must return an array of the same length.
    label : str
        Legend entry for this curve. A legend is drawn only when non-empty.
    start, end : float
        Inclusive bounds of the sampled x-range.
    rate : int
        Number of sample points (passed to ``np.linspace``).
    """
    plt.figure("f vs x")
    plt.xlabel("x")
    plt.ylabel("f")
    x = np.linspace(start, end, rate)
    plt.plot(x, f(x), label=label)
    # Without a legend the label would never be visible. Only build one when
    # a label exists, so matplotlib does not warn about unlabeled artists.
    if label:
        plt.legend()

In [4]:
# Sanity check of plot_func: draw the straight line y = 2x + 1 over its default x-range.
plot_func(f=lambda x: 1 + 2 * x)



In [5]:
def plot_mesh(W, f, sz=100, start=0, end=11):
    """Draw contour lines of ``f(W, point)`` over a square 2-D grid.

    Each grid point (x, y) is lifted to the length-3 vector ``[1, x, y]``
    (constant 1 first), scored with ``f(W, point)``, and the resulting scalar
    field is drawn in blue with ``plt.contour`` on the current figure.

    Parameters
    ----------
    W : object
        Passed unchanged as the first argument of ``f``.
    f : callable
        ``f(W, point) -> scalar``, where ``point`` is ``np.array([1, x, y])``.
        It is called once per grid point, so it need not be vectorised.
    sz : int
        Number of samples along each axis (the grid is ``sz x sz``).
    start, end : float
        Inclusive bounds of the plotted range on both axes. The defaults
        reproduce the original fixed range [0, 11].
    """
    x = np.linspace(start, end, sz)
    y = np.linspace(start, end, sz)
    # Default 'xy' indexing: xv[i, j] == x[j] and yv[i, j] == y[i].
    xv, yv = np.meshgrid(x, y)
    # Shape (sz, sz, 3); each cell holds [1, x, y]. This is a vectorised
    # replacement for the former Python-2-only `xrange` double loop.
    grid = np.stack([np.ones_like(xv), xv, yv], axis=-1)

    # f is an arbitrary per-point callable, so apply it along the last axis.
    predict = np.apply_along_axis(lambda pt: f(W, pt), 2, grid)
    # predict[i, j] belongs to (x[j], y[i]), which is the (len(y), len(x))
    # layout plt.contour expects for Z.
    plt.contour(x, y, predict, colors="b")

In [10]:
import warnings

# Silence only FutureWarning, as intended. A blanket "ignore" would also hide
# genuine problems in every later cell of this kernel session.
warnings.filterwarnings("ignore", category=FutureWarning)


def f(W, x):
    """Return the sign of the linear score ``W.T . x`` (-1, 0 or +1)."""
    return np.sign(W.T.dot(x))


# plot_mesh scores points [1, x, y], so the boundary drawn is
# 0.1 - 0.2*x + 0.1*y = 0, i.e. the line y = 2x - 1.
W = np.array([0.1, -0.2, 0.1])
plot_mesh(W, f)



In [ ]: