In [4]:
import keras.backend as K
import numpy as np
# Fix the NumPy seed so the random inputs (and therefore the outputs printed
# by the comparison below) are reproducible across kernel restarts.
np.random.seed(0)
# 5D tensors in the 'th' layout that the conv3d_* functions default to:
#   x:      (samples, input_depth, dim1, dim2, dim3)
#   kernel: (out_depth, input_depth, k1, k2, k3)
x = K.variable(np.random.rand(1, 2, 4, 4, 4))
kernel = K.variable(np.random.rand(2, 2, 3, 3, 3))

from theano import tensor as T
from theano.tensor.nnet import conv3d2d
from theano.sandbox.cuda import dnn


def conv3d_a(x, kernel, strides=(1, 1, 1),
             border_mode='valid', dim_ordering='th',
             volume_shape=None, filter_shape=None):
    '''3D convolution built on Theano's ``conv3d2d.conv3d`` helper.

    This path goes through ``theano.tensor.nnet.conv3d2d``; it does not call
    the cuDNN 3D convolution op directly.

    # Arguments
        x: 5D input tensor.
            'th': (samples, input_depth, conv_dim1, conv_dim2, conv_dim3)
            'tf': (samples, conv_dim1, conv_dim2, conv_dim3, input_depth)
        kernel: 5D kernel tensor.
            'th': (out_depth, input_depth, kernel_dim1, kernel_dim2, kernel_dim3)
            'tf': (kernel_dim1, kernel_dim2, kernel_dim3, input_depth, out_depth)
        strides: sequence of 3 ints. Strides other than 1 are emulated by
            slicing the stride-1 output. Must be all ones for "same".
        border_mode: string, "same" or "valid".
        dim_ordering: "th" or "tf".
        volume_shape, filter_shape: accepted for signature compatibility;
            ignored by this implementation.

    # Returns
        5D tensor in the same ``dim_ordering`` layout as ``x``.

    # Raises
        ValueError: unknown ``dim_ordering`` or ``border_mode``.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise ValueError('Unknown dim_ordering ' + str(dim_ordering))

    if border_mode not in {'same', 'valid'}:
        raise ValueError('Invalid border mode: ' + str(border_mode))

    # Accept any 3-sequence (e.g. a list) so the tuple comparisons below work.
    strides = tuple(strides)

    if dim_ordering == 'tf':
        # Move the channel axis to position 1 so the rest of the function
        # can work purely in 'th' layout.
        x = x.dimshuffle((0, 4, 1, 2, 3))
        kernel = kernel.dimshuffle((4, 3, 0, 1, 2))

    if border_mode == 'same':
        # Pad explicitly, then run a 'valid' convolution.
        assert strides == (1, 1, 1), '"same" border_mode requires unit strides'
        pads = [kernel.shape[axis] - 1 for axis in (2, 3, 4)]
        padded_shape = (x.shape[0], x.shape[1],
                        x.shape[2] + pads[0],
                        x.shape[3] + pads[1],
                        x.shape[4] + pads[2])
        # Build the zero buffer in x's dtype; T.zeros would otherwise use floatX.
        padded = T.zeros(padded_shape, dtype=x.dtype)
        # (pad // 2) zeros go before x on each spatial axis and the rest after,
        # so for even kernel sizes the extra zero lands at the end.
        interior = (slice(None), slice(None)) + tuple(
            slice(pad // 2, x.shape[axis] + pad // 2)
            for axis, pad in zip((2, 3, 4), pads))
        x = T.set_subtensor(padded[interior], x)
        border_mode = 'valid'

    # conv3d2d takes signals as (batch, time, channels, rows, cols) and filters
    # as (out, time, in, rows, cols), hence the swap of axes 1 and 2.
    conv_out = conv3d2d.conv3d(signals=x.dimshuffle(0, 2, 1, 3, 4),
                               filters=kernel.dimshuffle(0, 2, 1, 3, 4),
                               border_mode=(border_mode,) * 3)
    conv_out = conv_out.dimshuffle(0, 2, 1, 3, 4)

    # support strides by manually slicing the output
    if strides != (1, 1, 1):
        conv_out = conv_out[:, :, ::strides[0], ::strides[1], ::strides[2]]

    if dim_ordering == 'tf':
        conv_out = conv_out.dimshuffle((0, 2, 3, 4, 1))

    return conv_out


def conv3d_b(x, kernel, strides=(1, 1, 1),
             border_mode='valid', dim_ordering='th',
             volume_shape=None, filter_shape=None):
    '''3D convolution via ``theano.sandbox.cuda.dnn.dnn_conv3d``.

    ``dnn_conv3d`` is called unconditionally, with no non-cuDNN fallback.

    # Arguments
        x: 5D input tensor.
            'th': (samples, input_depth, conv_dim1, conv_dim2, conv_dim3)
            'tf': (samples, conv_dim1, conv_dim2, conv_dim3, input_depth)
        kernel: 5D kernel tensor.
            'th': (out_depth, input_depth, kernel_dim1, kernel_dim2, kernel_dim3)
            'tf': (kernel_dim1, kernel_dim2, kernel_dim3, input_depth, out_depth)
        strides: sequence of 3 ints, emulated by slicing the stride-1 output.
        border_mode: string, "same" or "valid".
        dim_ordering: "th" or "tf".
        volume_shape: accepted for signature compatibility; ignored.
        filter_shape: static kernel shape in ``dim_ordering`` layout. It is
            needed for "same" because cuDNN takes integer padding. If omitted,
            it is read from ``kernel.get_value()`` when ``kernel`` has one.

    # Returns
        5D tensor in the same ``dim_ordering`` layout as ``x``.

    # Raises
        ValueError: unknown ``dim_ordering``/``border_mode``, or "same" with
            no way to determine the kernel shape.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise ValueError('Unknown dim_ordering ' + str(dim_ordering))

    if border_mode not in {'same', 'valid'}:
        raise ValueError('Invalid border mode: ' + str(border_mode))

    # Accept any 3-sequence (e.g. a list) so the tuple comparison below works.
    strides = tuple(strides)

    if filter_shape is None and hasattr(kernel, 'get_value'):
        # A shared variable knows its concrete shape; read it before any
        # dimshuffle, so it is still in dim_ordering layout.
        filter_shape = kernel.get_value(borrow=True).shape

    if dim_ordering == 'tf':
        # Move the channel axis to position 1 so the rest of the function
        # can work purely in 'th' layout.
        x = x.dimshuffle((0, 4, 1, 2, 3))
        kernel = kernel.dimshuffle((4, 3, 0, 1, 2))
        if filter_shape is not None:
            filter_shape = (filter_shape[4], filter_shape[3],
                            filter_shape[0], filter_shape[1], filter_shape[2])

    crop = (0, 0, 0)
    if border_mode == 'same':
        if filter_shape is None:
            raise ValueError('border_mode="same" requires filter_shape when '
                             'the kernel shape cannot be read from `kernel`')
        spatial = tuple(int(size) for size in filter_shape[-3:])
        # cuDNN pads symmetrically by size // 2 on each side.
        border_mode = tuple(size // 2 for size in spatial)
        # For even sizes, symmetric padding yields one extra output along that
        # axis. Dropping the first one keeps the spatial size and matches the
        # ((size-1)//2 before, size//2 after) zero-padding used in conv3d_a.
        crop = tuple(1 - size % 2 for size in spatial)

    conv_out = dnn.dnn_conv3d(x, kernel, border_mode=border_mode)
    if any(crop):
        conv_out = conv_out[:, :, crop[0]:, crop[1]:, crop[2]:]

    # support strides by manually slicing the output
    if strides != (1, 1, 1):
        conv_out = conv_out[:, :, ::strides[0], ::strides[1], ::strides[2]]

    if dim_ordering == 'tf':
        conv_out = conv_out.dimshuffle((0, 2, 3, 4, 1))

    return conv_out

# Run both implementations on the same (x, kernel) and print their outputs.
print('Mode conv3d2d')
output = conv3d_a(x, kernel, border_mode='same')
out_conv3d2d = output.eval()
print(out_conv3d2d)

print('Mode dnn.conv3d')
# Take the kernel shape from the kernel itself rather than repeating it by
# hand, so it cannot drift from the kernel defined above.
output = conv3d_b(x, kernel, border_mode='same',
                  filter_shape=kernel.get_value(borrow=True).shape)
out_dnn = output.eval()
print(out_dnn)

# Summarise the discrepancy numerically instead of comparing printouts by eye.
print('max abs difference: %g' % np.abs(out_conv3d2d - out_dnn).max())

In [5]:
import keras.backend as K
import numpy as np
# NOTE(review): this rebinds `x`, replacing the 5D input from the first cell
# with a 3D tensor. Re-running the conv3d comparison after this cell would
# pick up this value -- consider a distinct name.
x = K.variable(np.random.random((4, 4, 4)))

In [8]:



[[[ 0.19932584  0.85690832  0.13299252  0.4397842 ]
  [ 0.94352257  0.98625582  0.83453041  0.99154365]
  [ 0.86132693  0.60111326  0.52935213  0.67306691]
  [ 0.56666249  0.82525438  0.82012677  0.52550912]]

 [[ 0.5816378   0.70255286  0.37985691  0.43352047]
  [ 0.38271949  0.74984944  0.08021767  0.4320659 ]
  [ 0.70190269  0.14683472  0.38364509  0.87995744]
  [ 0.65994072  0.73156655  0.57552189  0.22063945]]

 [[ 0.24744232  0.38457248  0.01875688  0.41870102]
  [ 0.99824792  0.9314158   0.35870582  0.24894296]
  [ 0.41254723  0.17246178  0.21148604  0.9790445 ]
  [ 0.6010167   0.7606715   0.88511825  0.23119785]]

 [[ 0.48503515  0.84920651  0.0529296   0.00618725]
  [ 0.15415038  0.08498009  0.39229175  0.06770945]
  [ 0.36610326  0.0097663   0.87874281  0.38494775]
  [ 0.09965402  0.48656422  0.4601936   0.55113703]]]

In [ ]: