In [1]:
import torch
import numpy as np
import itertools
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
data = np.zeros((3, 6), dtype=np.float32)
s = [4, 1, 3]
mask = np.tile(np.arange(6), reps=(3, 1)) < np.tile(s, (6, 1)).T
data[mask] = 0
data[~mask] = 1
data


Out[2]:
array([[0., 0., 0., 0., 1., 1.],
       [0., 1., 1., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)

In [3]:
annot = np.tile(np.arange(6), reps=(3, 1))
annot = np.arange(1, 19).reshape(3, 6)
fig, ax = plt.subplots(1, 2, sharey=True, figsize=(12, 3))

rev_annot = []
rev_data = []
for i in range(annot.shape[0]):
    sp = s[i]
    rev_annot.append(
        np.concatenate((annot[i, sp:], annot[i, :sp]))
    )
    rev_data.append(
        np.concatenate((data[i, sp:], data[i, :sp]))
    )
rev_annot = np.stack(rev_annot)
rev_data = np.stack(rev_data)

sns.heatmap(data, ax=ax[0], cbar=False, linewidths=.5, cmap='Set3',
            annot=annot, xticklabels=False, yticklabels=False,
            annot_kws={'fontsize': 'xx-large'})

sns.heatmap(rev_data, ax=ax[1], cbar=False, linewidths=.5, cmap='Set3',
            annot=rev_annot, xticklabels=False, yticklabels=False,
            annot_kws={'fontsize': 'xx-large'})

fig.savefig("random_shift_rows.png", dpi=100)


Numpy


In [4]:
def shift_rows_numpy1(data, splits):
    shifted = []
    for i in range(X.shape[0]):
        sp = splits[i]
        shifted.append(
            np.concatenate((data[i, sp:], data[i, :sp]))
        )
    return np.stack(shifted)

In [5]:
def shift_rows_numpy2(data, splits):
    shifted = []
    for i in range(X.shape[0]):
        sp = splits[i]
        if i == 0:
            shifted = np.concatenate((data[i, sp:], data[i, :sp])).reshape(1, -1)
        else:
            shifted = np.concatenate((
                shifted,
                np.concatenate((data[i, sp:], data[i, :sp])).reshape(1, -1)
            ), 0)
    return shifted

In [6]:
def shift_rows_numpy3(data, splits):
    N, M = data.shape[:2]
    mask = np.tile(np.arange(M), (N, 1)) < np.tile(splits, (M, 1)).T
    inv_mask = np.tile(np.arange(M), (N, 1)) >= M - np.tile(splits, (M, 1)).T
    mask_i, mask_j = np.where(mask)
    inv_i, inv_j = np.where(inv_mask)
    shifted = np.zeros_like(data)
    shifted[inv_i, inv_j] = data[mask_i, mask_j]
    mask_i, mask_j = np.where(~mask)
    inv_i, inv_j = np.where(~inv_mask)
    shifted[inv_i, inv_j] = data[mask_i, mask_j]
    return shifted

In [7]:
N = 1000
M = 2000
X = np.random.random(size=(N, M))
splits = np.random.randint(M, size=N)

In [8]:
%%timeit
shift_rows_numpy1(X, splits)


7.39 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [9]:
%%timeit
shift_rows_numpy2(X, splits)


777 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [10]:
%%timeit
shift_rows_numpy3(X, splits)


101 ms ± 375 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [11]:
np.array_equal(shift_rows_numpy1(X, splits), shift_rows_numpy2(X, splits))


Out[11]:
True

In [12]:
np.array_equal(shift_rows_numpy3(X, splits), shift_rows_numpy2(X, splits))


Out[12]:
True

Pytorch on CPU


In [13]:
def shift_rows_pytorch1(data, splits):
    shifted = []
    for i, row in enumerate(data):
        shifted.append(
            torch.cat((row[splits[i]:], row[:splits[i]]))
        )
    return torch.stack(shifted)

In [30]:
def shift_rows_pytorch2(X, splits):
    for i, row in enumerate(X):
        if i == 0:
            shifted = torch.cat((row[splits[i]:], row[:splits[i]])).view(1, -1)
        else:
            shifted = torch.cat((
                shifted,
                torch.cat((row[splits[i]:], row[:splits[i]])).view(1, -1)
            ))
    return shifted

In [31]:
def shift_rows_pytorch3(data, splits):
    N, M = data.shape[:2]
    all_idx = torch.arange(M, out=torch.LongTensor()).repeat((N, 1))
    mask = all_idx < splits.repeat((M, 1)).t()
    inv_mask = all_idx >= M - splits.repeat((M, 1)).t()
    
    idx = mask.nonzero()
    i = idx[:, 0]
    j = idx[:, 1]
    inv_idx = inv_mask.nonzero()
    inv_i = inv_idx[:, 0]
    inv_j = inv_idx[:, 1]
    
    shifted = torch.zeros_like(data)
    shifted[inv_i, inv_j] = data[i, j]
    
    idx = (mask == 0).nonzero()
    i = idx[:, 0]
    j = idx[:, 1]
    inv_idx = (inv_mask == 0).nonzero()
    inv_i = inv_idx[:, 0]
    inv_j = inv_idx[:, 1]
    shifted[inv_i, inv_j] = data[i, j]
    
    return shifted

In [32]:
def shift_rows_pytorch4(X, splits):
    N = X.size(0)
    M = X.size(1)
    all_idx = torch.arange(M, out=torch.LongTensor()).repeat((N, 1))
    all_j = torch.arange(N, out=torch.LongTensor()).repeat((M, 1)).t()
    mask = all_idx < splits.repeat((M, 1)).t()
    X2 = torch.zeros_like(X)
    i = all_j[~mask]
    j = all_idx[~mask]
    flipped_mask = mask[:, list(range(M-1, -1, -1))]
    X2[~flipped_mask] = X[i, j]
    i = all_j[mask]
    j = all_idx[mask]
    X2[flipped_mask] = X[i, j]
    return X2

In [33]:
N = 1000
M = 2000
splits = torch.from_numpy(np.random.randint(M, size=N))
X = torch.FloatTensor(N, M).random_(200, 500).cuda()

In [34]:
%%timeit
shift_rows_pytorch1(X, splits)


22.8 ms ± 53.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [35]:
%%timeit
shift_rows_pytorch2(X, splits)


56 ms ± 99.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [36]:
%%timeit
shift_rows_pytorch3(X, splits)


176 ms ± 330 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [37]:
%%timeit
shift_rows_pytorch4(X, splits)


192 ms ± 501 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Pytorch on GPU


In [38]:
N = 1000
M = 2000
splits = torch.from_numpy(np.random.randint(M, size=N))
X = torch.FloatTensor(N, M).random_(200, 500).cuda()
if torch.cuda.is_available():
    X = X.cuda()
else:
    print("CUDA unavailable")

In [39]:
%%timeit
shift_rows_pytorch1(X, splits)


22.8 ms ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [40]:
%%timeit
shift_rows_pytorch2(X, splits)


55.7 ms ± 82.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [41]:
%%timeit
shift_rows_pytorch3(X, splits)


176 ms ± 235 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [42]:
%%timeit
shift_rows_pytorch4(X, splits)


195 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [43]:
(shift_rows_pytorch1(X, splits) != shift_rows_pytorch2(X, splits)).nonzero()


Out[43]:
tensor([], dtype=torch.int64, device='cuda:0')

In [44]:
(shift_rows_pytorch1(X, splits) != shift_rows_pytorch3(X, splits)).nonzero()


Out[44]:
tensor([], dtype=torch.int64, device='cuda:0')

In [45]:
(shift_rows_pytorch1(X, splits) != shift_rows_pytorch4(X, splits)).nonzero()


Out[45]:
tensor([], dtype=torch.int64, device='cuda:0')