In [1]:
import numba as nb
import numpy as np

# naive convolution: slide each filter over every spatial position of every sample
def conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    for i in range(n):
        for j in range(out_h):
            for p in range(out_w):
                window = x[i, ..., j:j+filter_height, p:p+filter_width]
                for q in range(n_filters):
                    rs[i, q, j, p] += np.sum(w[q] * window)

# the exact same kernel, compiled by numba in nopython mode
@nb.jit(nopython=True)
def jit_conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    for i in range(n):
        for j in range(out_h):
            for p in range(out_w):
                window = x[i, ..., j:j+filter_height, p:p+filter_width]
                for q in range(n_filters):
                    rs[i, q, j, p] += np.sum(w[q] * window)

# allocate the output and dispatch to the chosen kernel
def conv(x, w, kernel, args):
    n, n_filters = args[0], args[4]
    out_h, out_w = args[-2:]
    rs = np.zeros([n, n_filters, out_h, out_w], dtype=np.float32)
    kernel(x, w, rs, *args)
    return rs

# im2col-style convolution via stride tricks (hence the cs231n_ prefix):
# view x as a matrix whose columns are flattened windows, then do one big matmul
def cs231n_conv(x, w, args):
    n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w = args
    shape = (n_channels, filter_height, filter_width, n, out_h, out_w)
    strides = (height * width, width, 1, n_channels * height * width, width, 1)
    strides = x.itemsize * np.asarray(strides)
    x_cols = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides).reshape(
        n_channels * filter_height * filter_width, n * out_h * out_w)
    return w.reshape(n_filters, -1).dot(x_cols).reshape(n_filters, n, out_h, out_w).transpose(1, 0, 2, 3)
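# A quick sanity check of the stride trick on a toy input (our sketch, not part of
# the original): with an all-ones 3 x 3 filter, each output pixel should be just
# the sum of its 3 x 3 window.
toy = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
toy_w = np.ones((1, 1, 3, 3), dtype=np.float32)
toy_args = (1, 1, 4, 4, 1, 3, 3, 2, 2)
print(cs231n_conv(toy, toy_w, toy_args))  # expected: [[[[45. 54.] [81. 90.]]]]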
# 64 input images of shape 3 x 28 x 28 (simulating MNIST)
x = np.random.randn(64, 3, 28, 28).astype(np.float32)
# 16 kernels of size 5 x 5
w = np.random.randn(16, x.shape[1], 5, 5).astype(np.float32)
n, n_channels, height, width = x.shape
n_filters, _, filter_height, filter_width = w.shape
out_h = height - filter_height + 1
out_w = width - filter_width + 1
args = (n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w)
print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, conv_kernel, args)).ravel()))
print(np.linalg.norm((cs231n_conv(x, w, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
print(np.linalg.norm((conv(x, w, conv_kernel, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
%timeit conv(x, w, conv_kernel, args)
%timeit conv(x, w, jit_conv_kernel, args)
%timeit cs231n_conv(x, w, args)
An assert on np.allclose would fail here, which is why the norms are printed instead. In fact, merely switching the arrays' dtype from float64 to float32 already costs a lot of precision, simply because a convolution involves so many accumulated operations.
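To get a feel for how much the float32 accumulation alone accounts for, here is a tiny sketch (our own illustration, not part of the original benchmark): summing the same million random values in float32 and in float64 typically already gives results that differ in the last few significant digits.

v = np.random.randn(1000000)
# accumulate the same values in float32 vs. float64 and compare
print(v.sum(dtype=np.float32), v.sum(dtype=np.float64))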
In [2]:
# replace the window-wise np.sum with fully explicit scalar loops
@nb.jit(nopython=True)
def jit_conv_kernel2(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    for i in range(n):
        for j in range(out_h):
            for p in range(out_w):
                for q in range(n_filters):
                    for r in range(n_channels):
                        for s in range(filter_height):
                            for t in range(filter_width):
                                rs[i, q, j, p] += x[i, r, j+s, p+t] * w[q, r, s, t]
assert np.allclose(conv(x, w, jit_conv_kernel, args), conv(x, w, jit_conv_kernel2, args))
%timeit conv(x, w, jit_conv_kernel, args)
%timeit conv(x, w, jit_conv_kernel2, args)
%timeit cs231n_conv(x, w, args)
One big difference between programming for the jit and programming in pure numpy is: don't be afraid of for loops. In fact, as a general rule, the more the code "looks like C", the faster it tends to run.
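As a toy illustration of this rule of thumb (our sketch, with hypothetical row_sums helpers, not from the benchmark above): both kernels below compile in nopython mode, but the second hands the compiler a plain C-style scalar loop instead of a per-row np.sum call, which is exactly the transformation applied in jit_conv_kernel2.

@nb.jit(nopython=True)
def row_sums(a, out):
    # per-row np.sum, analogous to the window-wise np.sum in jit_conv_kernel
    for i in range(a.shape[0]):
        out[i] = np.sum(a[i])

@nb.jit(nopython=True)
def row_sums_c_style(a, out):
    # explicit scalar accumulation, analogous to jit_conv_kernel2
    for i in range(a.shape[0]):
        s = 0.
        for j in range(a.shape[1]):
            s += a[i, j]
        out[i] = s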
In [3]:
# naive max pooling (stride 1): take the window-wise np.max
def max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)

# the same kernel, compiled by numba
@nb.jit(nopython=True)
def jit_max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)

# again, make it "look like C": track the running maximum with scalar loops
@nb.jit(nopython=True)
def jit_max_pool_kernel2(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    _max = x[i, j, p, q]
                    for r in range(pool_height):
                        for s in range(pool_width):
                            _tmp = x[i, j, p+r, q+s]
                            if _tmp > _max:
                                _max = _tmp
                    rs[i, j, p, q] += _max

def max_pool(x, kernel, args):
    n, n_channels = args[:2]
    out_h, out_w = args[-2:]
    # pooling keeps the channel count: one output map per input channel
    rs = np.zeros([n, n_channels, out_h, out_w], dtype=np.float32)
    kernel(x, rs, *args)
    return rs
pool_height, pool_width = 2, 2
n, n_channels, height, width = x.shape
out_h = height - pool_height + 1
out_w = width - pool_width + 1
args = (n, n_channels, pool_height, pool_width, out_h, out_w)
# unlike convolution, max pooling only selects existing elements (no accumulated
# arithmetic), so even in float32 the results match exactly and np.allclose passes
assert np.allclose(max_pool(x, max_pool_kernel, args), max_pool(x, jit_max_pool_kernel, args))
assert np.allclose(max_pool(x, jit_max_pool_kernel, args), max_pool(x, jit_max_pool_kernel2, args))
%timeit max_pool(x, max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel2, args)