In [1]:
import torch
import time
from src.architectures.jet_transforms.nmp.sparsegraphgen import sparse_topk, sparse
from ilya.src.utils.utils_pt import sparse_cat
In [6]:
# Experiment / benchmark configuration.
SEED = 0
# Seed torch's global RNG so every torch.rand / torch.bernoulli below is reproducible under Run All.
torch.manual_seed(SEED)

k = 3       # second argument passed to sparse_topk
d = 13      # number of columns of the dense right-hand-side matrices
bs = 3      # number of matrices per batch
n = 1000    # side length of each square matrix
reps = 100  # timing repetitions averaged by time_sparse_topk

# NOTE(review): round(rand + 1000) can only produce 1000.0 or 1001.0, so nearly every
# row is full of ties; if varied scores were intended this was probably `* 1000` — confirm.
S = torch.round(torch.rand(2,n,n) + 1000)
In [7]:
# Sparsify every matrix of the batch independently: one sparse_topk call per 2-D slice of S.
# (Concatenating the results with sparse_cat(S_topk, n, n) is intentionally left disabled here.)
S_topk = list(map(lambda mat: sparse_topk(mat, k), S))
In [8]:
def time_sparse_topk(reps, bs, n, k, d):
    """Benchmark top-k sparsification followed by sparse x dense products.

    Each of the `reps` repetitions draws a fresh batch of `bs` random (n, n)
    matrices, sparsifies each one with `sparse_topk(m, k)`, and multiplies each
    sparse matrix by its own random 0/1 (n, d) dense matrix. The timed region
    includes data generation, sparsification and the products.

    Args:
        reps: number of repetitions to average over; must be positive.
        bs: number of matrices per repetition.
        n: side length of each square matrix.
        k: value passed as the second argument to `sparse_topk`.
        d: number of columns of each dense right-hand-side matrix.

    Returns:
        Mean wall-clock seconds per repetition (also printed in milliseconds).

    Raises:
        ValueError: if `reps` is not positive (the mean would be undefined).
    """
    if reps <= 0:
        raise ValueError("reps must be positive, got {}".format(reps))

    # perf_counter is the monotonic, high-resolution clock meant for benchmarks.
    start = time.perf_counter()
    for _ in range(reps):
        # NOTE(review): round(rand + 100) only yields 100.0 or 101.0 — confirm this is intended.
        S = torch.round(torch.rand(bs, n, n) + 100)
        sps = [sparse_topk(m, k) for m in S]
        vecs = [torch.bernoulli(0.1 * torch.ones(n, d)) for _ in range(bs)]
        # Exactly one product per (sparse matrix, dense matrix) pair. The previous
        # nested loop recomputed the whole batch once per pair (bs**2 products).
        for sp, vec in zip(sps, vecs):
            sp.mm(vec)
    per_rep = (time.perf_counter() - start) / reps

    # Seconds with one decimal rounded typical per-rep times to 0.0; report ms instead.
    print("{:.1f} ms per rep".format(per_rep * 1e3))
    return per_rep
In [5]:
# Run the benchmark with the configured sizes; trailing `;` suppresses the cell's Out display.
time_sparse_topk(reps=reps, bs=bs, n=n, k=k, d=d);
In [105]:
# Keep the first sparsified matrix as `s` (used by the next cell) and display it densely.
s = S_topk[0]
S_topk[0].to_dense()
Out[105]:
In [49]:
# `vec` was never defined in this notebook (it leaked from a deleted cell), so a fresh
# kernel raised NameError here. Build a random 0/1 (n, d) dense right-hand side, drawn
# the same way time_sparse_topk draws its vectors, and multiply the sparse `s` by it.
vec = torch.bernoulli(0.1 * torch.ones(n, d))
s.mm(vec)
Out[49]:
In [50]:
Out[50]:
In [158]:
In [147]:
In [149]:
In [82]:
In [81]:
# Every other call site passes sparse_topk a single 2-D matrix, but S here is the
# (2, n, n) batch from the config cell; sparsify its first matrix only.
# NOTE(review): if sparse_topk is meant to accept a batched 3-D tensor, revert to `S`.
print(sparse_topk(S[0], k).to_dense())
In [70]:
In [84]:
# Replace S with a fresh batch of integer scores in [-50, 50] (signed, unlike the config cell).
bs = 3
S = (torch.rand(bs, n, n) * 100 - 50).round()
print(S)
In [85]:
sparse = [sparse_topk(m, k) for m in S]
In [87]:
for s in sparse:
print(s.to_dense())
In [89]:
sparse_catd = sparse_cat(sparse, n,n)
In [96]:
print(sparse_catd._indices())
In [99]:
for s in sparse:
print(s._indices())
In [ ]: