In [107]:
import numpy as np
P = 9; T = 8; X = 7; Z = 17; D = 5
data = np.arange(P*T*Z*D).reshape(P,T,Z,D)
ind = np.random.randint(0, Z, size=(P,T,X))

In [108]:
result = data.reshape(-1, D)[ind.reshape(-1, X)+np.arange(P*T)[..., None], :].reshape(P,T,X,D)
print(result[0, 7])


[[ 35  36  37  38  39]
 [ 40  41  42  43  44]
 [100 101 102 103 104]
 [ 80  81  82  83  84]
 [ 65  66  67  68  69]
 [110 111 112 113 114]
 [ 80  81  82  83  84]]

In [114]:
data[6, 7][ind[6, 7]]


Out[114]:
array([[4685, 4686, 4687, 4688, 4689],
       [4715, 4716, 4717, 4718, 4719],
       [4710, 4711, 4712, 4713, 4714],
       [4755, 4756, 4757, 4758, 4759],
       [4690, 4691, 4692, 4693, 4694],
       [4695, 4696, 4697, 4698, 4699],
       [4740, 4741, 4742, 4743, 4744]])

In [111]:
result = data[np.arange(P, dtype=int)[:, None, None, None], np.arange(T, dtype=int)[None, :, None, None], ind[..., None], np.arange(D, dtype=int)[None, None, None, :]]

In [113]:
print(result[6, 7])


[[4685 4686 4687 4688 4689]
 [4715 4716 4717 4718 4719]
 [4710 4711 4712 4713 4714]
 [4755 4756 4757 4758 4759]
 [4690 4691 4692 4693 4694]
 [4695 4696 4697 4698 4699]
 [4740 4741 4742 4743 4744]]

In [ ]: