In [1]:
import numba as nb
import numpy as np

def conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    # Naive convolution: Python loops over images and output positions,
    # vectorizing only the per-window multiply-and-sum.
    for i in range(n):
        for j in range(out_h):
            for p in range(out_w):
                window = x[i, ..., j:j+filter_height, p:p+filter_width]
                for q in range(n_filters):
                    rs[i, q, j, p] += np.sum(w[q] * window)

# The same kernel compiled to machine code by Numba (nopython=True forbids
# falling back to the Python interpreter).
@nb.jit(nopython=True)
def jit_conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    for i in range(n):
        for j in range(out_h):
            for p in range(out_w):
                window = x[i, ..., j:j+filter_height, p:p+filter_width]
                for q in range(n_filters):
                    rs[i, q, j, p] += np.sum(w[q] * window)

def conv(x, w, kernel, args):
    n, n_filters = args[0], args[4]
    out_h, out_w = args[-2:]
    rs = np.zeros([n, n_filters, out_h, out_w], dtype=np.float32)
    kernel(x, w, rs, *args)
    return rs

def cs231n_conv(x, w, args):
    # Vectorized convolution via the stride trick (CS231n-style im2col): build a
    # zero-copy view of every filter-sized window, flatten the windows into
    # columns, and reduce the whole convolution to a single matrix product.
    n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w = args
    shape = (n_channels, filter_height, filter_width, n, out_h, out_w)
    strides = (height * width, width, 1, n_channels * height * width, width, 1)
    strides = x.itemsize * np.asarray(strides)  # element strides -> byte strides
    x_cols = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides).reshape(
        n_channels * filter_height * filter_width, n * out_h * out_w)
    return w.reshape(n_filters, -1).dot(x_cols).reshape(n_filters, n, out_h, out_w).transpose(1, 0, 2, 3)

# 64 input images of shape 3 x 28 x 28 (simulating MNIST)
x = np.random.randn(64, 3, 28, 28).astype(np.float32)
# 16 kernels of size 5 x 5
w = np.random.randn(16, x.shape[1], 5, 5).astype(np.float32)

n, n_channels, height, width = x.shape
n_filters, _, filter_height, filter_width = w.shape
out_h = height - filter_height + 1
out_w = width - filter_width + 1
args = (n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w)

print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, conv_kernel, args)).ravel()))
print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
print(np.linalg.norm((conv(x, w, conv_kernel, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
%timeit conv(x, w, conv_kernel, args)
%timeit conv(x, w, jit_conv_kernel, args)
%timeit cs231n_conv(x, w, args)


0.00114815
0.000736224
0.00113975
3.69 s ± 161 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
The slowest run took 4.78 times longer than the fastest. This could mean that an intermediate result is being cached.
1.05 s ± 465 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
18.5 ms ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
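  • The warning about the slowest run taking 4.78 times longer is expected here: the first call to jit_conv_kernel triggers Numba's JIT compilation, and that one-off compile time is included in the first timing run.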
  • Note: using np.allclose here would not pass an assert. In fact, merely switching the arrays' dtype from float64 to float32 already costs a fair amount of precision, simply because a convolution involves so many arithmetic operations.
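  • As a minimal sketch of that precision point (not part of the original benchmark; the exact error depends on the random seed), compare one window's multiply-and-sum in float64 vs. float32:

a = np.random.randn(3, 5, 5)  # one conv window, float64
b = np.random.randn(3, 5, 5)  # one kernel, float64
exact = np.sum(a * b)         # float64 reference
approx = np.sum(a.astype(np.float32) * b.astype(np.float32))
print(abs(exact - approx))    # typically ~1e-7; such errors accumulate over every window and filter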

In [2]:
# Fully scalar version: no temporary window arrays and no np.sum calls,
# just seven nested loops that Numba compiles to tight machine code.
@nb.jit(nopython=True)
def jit_conv_kernel2(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    for i in range(n):
        for j in range(out_h):
            for p in range(out_w):
                for q in range(n_filters):
                    for r in range(n_channels):
                        for s in range(filter_height):
                            for t in range(filter_width):
                                rs[i, q, j, p] += x[i, r, j+s, p+t] * w[q, r, s, t]
                                
assert np.allclose(conv(x, w, jit_conv_kernel, args), conv(x, w, jit_conv_kernel2, args), atol=1e-4)  # compare the two jit kernels; tolerance relaxed for float32 (see the note above)
%timeit conv(x, w, jit_conv_kernel, args)
%timeit conv(x, w, jit_conv_kernel2, args)
%timeit cs231n_conv(x, w, args)


288 ms ± 6.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
71.2 ms ± 5.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.7 ms ± 62.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
  • As you can see, one big difference between programming with jit and with pure numpy is that you should not be afraid of for loops; in fact, as a rule of thumb, the more the code "looks like C", the faster it tends to run.
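  • A natural follow-up (a sketch, not benchmarked in the original) is to let Numba parallelize the outermost loop with prange; this is safe here because each image i writes a disjoint slice of rs:

@nb.jit(nopython=True, parallel=True)
def jit_conv_kernel3(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    # usage: conv(x, w, jit_conv_kernel3, args)
    for i in nb.prange(n):  # parallel across images; no two iterations touch the same rs[i] slice
        for j in range(out_h):
            for p in range(out_w):
                for q in range(n_filters):
                    for r in range(n_channels):
                        for s in range(filter_height):
                            for t in range(filter_width):
                                rs[i, q, j, p] += x[i, r, j+s, p+t] * w[q, r, s, t]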

In [3]:
def max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)

@nb.jit(nopython=True)
def jit_max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)
                    
# Scalar version: track a running maximum instead of slicing out a window
# and calling np.max on it.
@nb.jit(nopython=True)
def jit_max_pool_kernel2(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    _max = x[i, j, p, q]
                    for r in range(pool_height):
                        for s in range(pool_width):
                            _tmp = x[i, j, p+r, q+s]
                            if _tmp > _max:
                                _max = _tmp
                    rs[i, j, p, q] += _max

def max_pool(x, kernel, args):
    n, n_channels = args[:2]
    out_h, out_w = args[-2:]
    rs = np.zeros([n, n_channels, out_h, out_w], dtype=np.float32)  # pooling preserves the channel count
    kernel(x, rs, *args)
    return rs

pool_height, pool_width = 2, 2
n, n_channels, height, width = x.shape
out_h = height - pool_height + 1
out_w = width - pool_width + 1
args = (n, n_channels, pool_height, pool_width, out_h, out_w)

assert np.allclose(max_pool(x, max_pool_kernel, args), max_pool(x, jit_max_pool_kernel, args))
assert np.allclose(max_pool(x, jit_max_pool_kernel, args), max_pool(x, jit_max_pool_kernel2, args))
%timeit max_pool(x, max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel2, args)


696 ms ± 56.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8.68 ms ± 92.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.54 ms ± 59.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
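  • For comparison, the same stride trick used in cs231n_conv also applies to max pooling. A numpy-only sketch (not benchmarked above, assuming the same stride-1 "valid" pooling as in this cell):

def strided_max_pool(x, pool_height, pool_width):
    n, n_channels, height, width = x.shape
    out_h = height - pool_height + 1
    out_w = width - pool_width + 1
    sn, sc, sh, sw = x.strides  # byte strides of the input
    # windows[i, j, p, q] is the pool_height x pool_width window anchored at (p, q)
    windows = np.lib.stride_tricks.as_strided(
        x, shape=(n, n_channels, out_h, out_w, pool_height, pool_width),
        strides=(sn, sc, sh, sw, sh, sw))
    return windows.max(axis=(4, 5))

assert np.allclose(strided_max_pool(x, pool_height, pool_width), max_pool(x, jit_max_pool_kernel2, args))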