In [1]:
import torch
import numpy as np
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
In [2]:
s = [4, 1, 3]
# Row i is 0 in its first s[i] columns and 1 everywhere after; a broadcast
# comparison replaces the original double np.tile construction.
mask = np.arange(6)[None, :] < np.asarray(s)[:, None]
data = np.where(mask, 0, 1).astype(np.float32)
data
Out[2]:
In [3]:
# Visualize the row shift: left panel shows the original rows (annotated
# 1..18), right panel shows each row i rolled left by s[i].
# NOTE: removed a dead `annot = np.tile(...)` assignment that was
# immediately overwritten by the line below.
annot = np.arange(1, 19).reshape(3, 6)
fig, ax = plt.subplots(1, 2, sharey=True, figsize=(12, 3))
rev_annot = []
rev_data = []
for i in range(annot.shape[0]):
    sp = s[i]
    # Left-roll row i by sp: tail segment first, then the head segment.
    rev_annot.append(
        np.concatenate((annot[i, sp:], annot[i, :sp]))
    )
    rev_data.append(
        np.concatenate((data[i, sp:], data[i, :sp]))
    )
rev_annot = np.stack(rev_annot)
rev_data = np.stack(rev_data)
sns.heatmap(data, ax=ax[0], cbar=False, linewidths=.5, cmap='Set3',
            annot=annot, xticklabels=False, yticklabels=False,
            annot_kws={'fontsize': 'xx-large'})
sns.heatmap(rev_data, ax=ax[1], cbar=False, linewidths=.5, cmap='Set3',
            annot=rev_annot, xticklabels=False, yticklabels=False,
            annot_kws={'fontsize': 'xx-large'})
fig.savefig("random_shift_rows.png", dpi=100)
In [4]:
def shift_rows_numpy1(data, splits):
    """Left-roll each row of ``data`` by its own offset.

    Row ``i`` of the result is ``data[i]`` rotated left by ``splits[i]``
    (the first ``splits[i]`` entries move to the end of the row).

    Parameters
    ----------
    data : 2-D array of shape (N, M)
    splits : length-N sequence of ints in [0, M]

    Returns
    -------
    ndarray of shape (N, M)
    """
    shifted = []
    # Bug fix: the loop bound was `X.shape[0]`, which relied on a leaked
    # global `X` from another cell and broke for inputs of other sizes.
    for i in range(data.shape[0]):
        sp = splits[i]
        shifted.append(
            np.concatenate((data[i, sp:], data[i, :sp]))
        )
    return np.stack(shifted)
In [5]:
def shift_rows_numpy2(data, splits):
    """Left-roll each row of ``data`` by ``splits[i]``, building the result
    by repeated ``np.concatenate`` (benchmark variant of shift_rows_numpy1;
    intentionally quadratic to compare against the list+stack approach).

    Parameters
    ----------
    data : 2-D array of shape (N, M)
    splits : length-N sequence of ints in [0, M]

    Returns
    -------
    ndarray of shape (N, M)
    """
    shifted = []
    # Bug fix: the loop bound was `X.shape[0]` — a leaked global from the
    # benchmark cell — instead of the actual argument's row count.
    for i in range(data.shape[0]):
        sp = splits[i]
        if i == 0:
            shifted = np.concatenate((data[i, sp:], data[i, :sp])).reshape(1, -1)
        else:
            shifted = np.concatenate((
                shifted,
                np.concatenate((data[i, sp:], data[i, :sp])).reshape(1, -1)
            ), 0)
    return shifted
In [6]:
def shift_rows_numpy3(data, splits):
    """Left-roll each row of ``data`` by ``splits[i]`` without a Python loop.

    Two boolean masks describe source and destination slots: the first
    ``splits[i]`` entries of row i (head) land in the last ``splits[i]``
    positions (tail), and the remaining entries fill the front. Boolean
    indexing flattens both masks in row-major order with matching per-row
    counts, so the pairing is element-exact.
    """
    n_rows, n_cols = data.shape[:2]
    col_idx = np.arange(n_cols)[None, :]
    split_col = np.asarray(splits)[:, None]
    head = col_idx < split_col                 # source: first sp entries per row
    tail = col_idx >= n_cols - split_col       # destination: last sp slots per row
    shifted = np.zeros_like(data)
    shifted[tail] = data[head]
    shifted[~tail] = data[~head]
    return shifted
In [7]:
# Benchmark fixture: 1000 rows x 2000 columns with one random split per row.
N = 1000
M = 2000
X = np.random.rand(N, M)
splits = np.random.randint(0, M, size=N)
In [8]:
%%timeit
# Benchmark numpy variant 1: Python loop + np.stack.
shift_rows_numpy1(X, splits)
In [9]:
%%timeit
# Benchmark numpy variant 2: incremental np.concatenate (quadratic growth).
shift_rows_numpy2(X, splits)
In [10]:
%%timeit
# Benchmark numpy variant 3: fully vectorized boolean-mask version.
shift_rows_numpy3(X, splits)
In [11]:
# Sanity check: loop and incremental-concat variants must agree exactly.
np.array_equal(shift_rows_numpy1(X, splits), shift_rows_numpy2(X, splits))
Out[11]:
In [12]:
# Sanity check: vectorized variant must match the reference loop result.
np.array_equal(shift_rows_numpy3(X, splits), shift_rows_numpy2(X, splits))
Out[12]:
In [13]:
def shift_rows_pytorch1(data, splits):
    """Left-roll each row of ``data`` by its offset in ``splits``.

    Simple reference implementation: concatenate each row's tail and head,
    then stack the rolled rows into a new tensor.
    """
    rolled = [
        torch.cat((row[sp:], row[:sp]))
        for row, sp in zip(data, splits)
    ]
    return torch.stack(rolled)
In [30]:
def shift_rows_pytorch2(X, splits):
    """Left-roll each row of ``X`` by ``splits[i]``, growing the result with
    repeated torch.cat (benchmark variant of shift_rows_pytorch1)."""
    for idx, row in enumerate(X):
        sp = splits[idx]
        rolled = torch.cat((row[sp:], row[:sp])).view(1, -1)
        if idx == 0:
            shifted = rolled
        else:
            shifted = torch.cat((shifted, rolled))
    return shifted
In [31]:
def shift_rows_pytorch3(data, splits):
    """Left-roll each row of ``data`` by ``splits[i]`` via index scatter.

    Two boolean masks mark source and destination slots: the head of each
    row (first sp entries) lands in the row's tail slots, and the rest fill
    the front. ``nonzero`` enumerates both masks in row-major order with
    matching per-row counts, so the index pairing is element-exact.
    """
    n, m = data.shape[:2]
    cols = torch.arange(m).repeat(n, 1)        # column index at every position
    split_col = splits.repeat(m, 1).t()        # each row's split, broadcast
    head = cols < split_col                    # source: first sp entries
    tail = cols >= m - split_col               # destination: last sp slots
    shifted = torch.zeros_like(data)
    # First pass moves the head into the tail; second pass (negated masks,
    # kept as `== 0` like the original) moves the remainder to the front.
    for src_mask, dst_mask in ((head, tail), (head == 0, tail == 0)):
        src = src_mask.nonzero()
        dst = dst_mask.nonzero()
        shifted[dst[:, 0], dst[:, 1]] = data[src[:, 0], src[:, 1]]
    return shifted
In [32]:
def shift_rows_pytorch4(X, splits):
    """Left-roll each row of ``X`` by ``splits[i]`` using boolean-mask
    assignment.

    The destination mask is obtained by reversing the columns of the source
    mask: row i's "first sp slots" become its "last sp slots". Boolean
    indexing enumerates positions in row-major order with equal per-row
    counts on both sides, so each element lands in its rolled position.
    """
    n, m = X.shape[:2]
    col_of = torch.arange(m, out=torch.LongTensor()).repeat((n, 1))
    row_of = torch.arange(n, out=torch.LongTensor()).repeat((m, 1)).t()
    head = col_of < splits.repeat((m, 1)).t()
    # Column reversal maps the head mask onto the tail slots of each row.
    tail = head[:, list(range(m - 1, -1, -1))]
    shifted = torch.zeros_like(X)
    shifted[~tail] = X[row_of[~head], col_of[~head]]   # body moves to the front
    shifted[tail] = X[row_of[head], col_of[head]]      # head moves to the back
    return shifted
In [33]:
# Benchmark fixture: integer-valued float tensor in [200, 500).
N = 1000
M = 2000
splits = torch.from_numpy(np.random.randint(M, size=N))
X = torch.FloatTensor(N, M).random_(200, 500)
# Bug fix: the original called .cuda() unconditionally, which raises on
# CPU-only machines; only transfer when a GPU is actually available.
if torch.cuda.is_available():
    X = X.cuda()
In [34]:
%%timeit
# Benchmark torch variant 1: Python loop + torch.stack.
shift_rows_pytorch1(X, splits)
In [35]:
%%timeit
# Benchmark torch variant 2: incremental torch.cat (quadratic growth).
shift_rows_pytorch2(X, splits)
In [36]:
%%timeit
# Benchmark torch variant 3: nonzero-index scatter version.
shift_rows_pytorch3(X, splits)
In [37]:
%%timeit
# Benchmark torch variant 4: boolean-mask assignment version.
shift_rows_pytorch4(X, splits)
In [38]:
N = 1000
M = 2000
splits = torch.from_numpy(np.random.randint(M, size=N))
X = torch.FloatTensor(N, M).random_(200, 500).cuda()
if torch.cuda.is_available():
X = X.cuda()
else:
print("CUDA unavailable")
In [39]:
%%timeit
# Second timing round for torch variant 1 (fixture re-created above).
shift_rows_pytorch1(X, splits)
In [40]:
%%timeit
# Second timing round for torch variant 2.
shift_rows_pytorch2(X, splits)
In [41]:
%%timeit
# Second timing round for torch variant 3.
shift_rows_pytorch3(X, splits)
In [42]:
%%timeit
# Second timing round for torch variant 4.
shift_rows_pytorch4(X, splits)
In [43]:
# Expect an empty index tensor: variants 1 and 2 agree at every position.
(shift_rows_pytorch1(X, splits) != shift_rows_pytorch2(X, splits)).nonzero()
Out[43]:
In [44]:
# Expect an empty index tensor: variants 1 and 3 agree at every position.
(shift_rows_pytorch1(X, splits) != shift_rows_pytorch3(X, splits)).nonzero()
Out[44]:
In [45]:
# Expect an empty index tensor: variants 1 and 4 agree at every position.
(shift_rows_pytorch1(X, splits) != shift_rows_pytorch4(X, splits)).nonzero()
Out[45]: