In [1]:
import torch
import time
from src.architectures.jet_transforms.nmp.sparsegraphgen import sparse_topk, sparse
from ilya.src.utils.utils_pt import sparse_cat

In [6]:
k = 3
d = 13
bs = 3
n = 1000
reps = 100
S = torch.round(torch.rand(2,n,n)  + 1000)

In [7]:
S_topk = [sparse_topk(s, k) for s in S]
#S_topk = sparse_cat(S_topk,n,n)

In [8]:
def time_sparse_topk(reps, bs, n, k, d):
    import time
    t = time.time()
    for i in range(reps):
        S = torch.round(torch.rand(bs,n,n) + 100)
        sps = [sparse_topk(m, k) for m in S]
        vecs = [torch.bernoulli(0.1 * torch.ones(n,d)) for i in range(bs)]
        for i, (s, v) in enumerate(zip(sps, vecs)):
            outs = [sp.mm(v) for sp, v in zip(sps, vecs)]
        

    t = (time.time() - t) / reps
    print("{:.1f}".format(t))


  File "<ipython-input-8-11b5d2a38c42>", line 10
    outs = [sp.mm(v) for sp, v in zip(sps, vecs)]
       ^
IndentationError: expected an indented block

In [5]:
time_sparse_topk(reps,bs,n,k,d)


0.0

In [105]:
s = S_topk[0]
s.to_dense()


Out[105]:
    0  1001     0  1001     0  1001
    0     0  1001     0  1001  1001
 1001     0  1001  1001     0     0
 1000  1000     0     0     0  1001
 1001     0  1001  1001     0     0
 1001  1001     0  1001     0     0
 1001     0  1001  1001     0     0
[torch.FloatTensor of size (7,6)]

In [49]:
s.mm(vec.to_dense())


Out[49]:
   40     0     0     0     0     0     0     0     0     5     0     0     0
   35     0     0     0     0     0    44    44     0    44     0    44    44
    0     0     0     0     0     0     0     0     0    11     0     0     0
    0     0     0     0     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     0    27     0     0     0
   45     0     0     0     0     0    24    24     0    24     0    24    24
    0     0     0     0     0     0     0     0     0    37     0     0     0
[torch.FloatTensor of size (7,13)]

In [50]:



Out[50]:
   40    37     5     0     0     0     0
   35     0     0    44    28     0     0
    0     0    11     0    30    21     0
    0     0     0     0    11    22     4
    0    47    27     0     0     0    44
   45     0     0    24     0    18     0
    0     0    37     0    21     0    20
[torch.FloatTensor of size (7,7)]

In [158]:



0.6

In [147]:


In [149]:



0.6

In [82]:



  0   0   9  -6  28
 45   0  24   0  20
  0  33  38   0  41
 21   3   0   0  30
-31  39 -32   0   0
[torch.FloatTensor of size (5,5)]


In [81]:
print(sparse_topk(S,k).to_dense())


  0   0   9  -6  28
 45   0  24   0  20
  0  33  38   0  41
 21   3   0   0  30
-31  39 -32   0   0
[torch.FloatTensor of size (5,5)]


In [70]:


In [84]:
bs = 3
S = torch.round(torch.rand(bs, n,n) * 100 - 50)
print(S)


(0 ,.,.) = 
  11 -38  -2 -23  33
 -28  29 -29 -37 -45
 -37 -33  35 -27 -46
 -31   1 -30 -47 -12
  48 -25 -30  27 -26

(1 ,.,.) = 
  49  46  25 -22 -41
  44  38  -7 -49  25
 -44  28 -37 -41 -25
 -24 -11  23  -3  24
  12  -2  -9   3  -3

(2 ,.,.) = 
   8   3  34  -5  31
 -19  16  43 -45 -15
  11 -47 -12 -28 -19
   4  -8  27 -49  32
 -11  -4 -40   4  -1
[torch.FloatTensor of size (3,5,5)]


In [85]:
sparse = [sparse_topk(m, k) for m in S]

In [87]:
for s in sparse:
    print(s.to_dense())


 11   0  -2   0  33
-28  29 -29   0   0
  0 -33  35 -27   0
  0   1 -30   0 -12
 48 -25   0  27   0
[torch.FloatTensor of size (5,5)]


 49  46  25   0   0
 44  38   0   0  25
  0  28 -37   0 -25
  0   0  23  -3  24
 12  -2   0   3   0
[torch.FloatTensor of size (5,5)]


  8   0  34   0  31
  0  16  43   0 -15
 11   0 -12   0 -19
  4   0  27   0  32
  0  -4   0   4  -1
[torch.FloatTensor of size (5,5)]


In [89]:
sparse_catd = sparse_cat(sparse, n,n)

In [96]:
print(sparse_catd._indices())



Columns 0 to 12 
    0     0     0     0     0     0     0     0     0     0     0     0     0
    0     0     0     1     1     1     2     2     2     3     3     3     4
    0     2     4     0     1     2     1     2     3     1     2     4     0

Columns 13 to 25 
    0     0     1     1     1     1     1     1     1     1     1     1     1
    4     4     0     0     0     1     1     1     2     2     2     3     3
    1     3     0     1     2     0     1     4     1     2     4     2     3

Columns 26 to 38 
    1     1     1     1     2     2     2     2     2     2     2     2     2
    3     4     4     4     0     0     0     1     1     1     2     2     2
    4     0     1     3     0     2     4     1     2     4     0     2     4

Columns 39 to 44 
    2     2     2     2     2     2
    3     3     3     4     4     4
    0     2     4     1     3     4
[torch.LongTensor of size (3,45)]


In [99]:
for s in sparse:
    print(s._indices())



Columns 0 to 12 
    0     0     0     1     1     1     2     2     2     3     3     3     4
    4     0     2     1     0     2     2     3     1     1     4     2     0

Columns 13 to 14 
    4     4
    3     1
[torch.LongTensor of size (2,15)]



Columns 0 to 12 
    0     0     0     1     1     1     2     2     2     3     3     3     4
    0     1     2     0     1     4     1     4     2     4     2     3     0

Columns 13 to 14 
    4     4
    3     1
[torch.LongTensor of size (2,15)]



Columns 0 to 12 
    0     0     0     1     1     1     2     2     2     3     3     3     4
    2     4     0     2     1     4     0     2     4     4     2     0     3

Columns 13 to 14 
    4     4
    4     1
[torch.LongTensor of size (2,15)]


In [ ]: