``````

In [2]:

from collections import namedtuple
LayerParams = namedtuple(
"LayerParams",
[
"percent_on_k_winner",
"boost_strength",
"boost_strength_factor",
"k_inference_factor",
"local",
"weights_density",
],
defaults=[0.25, 1.4, 0.7, 1.0, False, 0.5],
)

``````
``````

In [3]:

lp = LayerParams(boost_strength=0.5)

``````
``````

In [6]:

def prt(*args):
print(args)
prt(*lp)

``````
``````

(0.25, 0.5, 0.7, 1.0, False, 0.5)

``````
``````

In [14]:

from itertools import combinations, permutations, combinations_with_replacement, product

``````
``````

In [20]:

for idx in  product(range(5), range(5)):
print(idx)

``````
``````

(0, 0)
(0, 1)
(0, 2)
(0, 3)
(0, 4)
(1, 0)
(1, 1)
(1, 2)
(1, 3)
(1, 4)
(2, 0)
(2, 1)
(2, 2)
(2, 3)
(2, 4)
(3, 0)
(3, 1)
(3, 2)
(3, 3)
(3, 4)
(4, 0)
(4, 1)
(4, 2)
(4, 3)
(4, 4)

``````
``````

In [32]:

import torch
t = torch.rand(5,5)
q = torch.ones(5,5)

``````
``````

In [25]:

t[idx]

``````
``````

Out[25]:

tensor(0.1917)

``````
``````

In [30]:

torch.eq(t,q)

``````
``````

Out[30]:

tensor([[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]])

``````
``````

In [33]:

torch.allclose(t,q)

``````
``````

Out[33]:

False

``````
``````

In [34]:

lowest_25_hebb = [(4,0), (2,1), (0,1)]
lowest_50_mag = [(2,0), (2,2), (4,2), (0,1), (3,2), (4,1)]

``````
``````

In [40]:

set(lowest_25_hebb).intersection(lowest_50_mag)

``````
``````

Out[40]:

{(0, 1)}

``````
``````

In [47]:

corr = torch.tensor(
[
[0.3201, 0.8318, 0.3382, 0.9734, 0.0985],
[0.0401, 0.8620, 0.0845, 0.3778, 0.3996],
[0.4954, 0.0092, 0.6713, 0.8594, 0.9487],
[0.8101, 0.0922, 0.2033, 0.7185, 0.4588],
[0.3897, 0.6865, 0.5072, 0.9749, 0.0597],
]
)

weight = torch.tensor(
[
[19, 2, -12, 0, 0],
[0, 0, 0, 0, 0],
[-10, 25, -8, 0, 0],
[21, -11, 7, 0, 0],
[-14, 18, -6, 0, 0],
]
)

``````
``````

In [48]:

W = corr.T * weight.float()

``````
``````

In [49]:

W

``````
``````

Out[49]:

tensor([[ 6.0819,  0.0802, -5.9448,  0.0000,  0.0000],
[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
[-3.3820,  2.1125, -5.3704,  0.0000,  0.0000],
[20.4414, -4.1558,  6.0158,  0.0000,  0.0000],
[-1.3790,  7.1928, -5.6922,  0.0000,  0.0000]])

``````

Random sampling k elements

``````

In [51]:

weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights, 2)

``````
``````

Out[51]:

tensor([2, 1])

``````
``````

In [58]:

torch.multinomial((W > 0).float(), 2)

``````
``````

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-58-7d7f93659649> in <module>
----> 1 torch.multinomial((W > 0).float(), 2)

RuntimeError: invalid argument 2: invalid multinomial distribution (sum of probabilities <= 0) at ../aten/src/TH/generic/THTensorRandom.cpp:374

``````
``````

In [59]:

Wp = (W>0)

``````
``````

In [72]:

samples = torch.nonzero(Wp, as_tuple=False)

``````
``````

In [73]:

samples

``````
``````

Out[73]:

tensor([[0, 0],
[0, 1],
[2, 1],
[3, 0],
[3, 2],
[4, 1]])

``````
``````

In [74]:

samples.shape

``````
``````

Out[74]:

torch.Size([6, 2])

``````
``````

In [75]:

list(samples)

``````
``````

Out[75]:

[tensor([0, 0]),
tensor([0, 1]),
tensor([2, 1]),
tensor([3, 0]),
tensor([3, 2]),
tensor([4, 1])]

``````
``````

In [86]:

idx = np.random.choice(range(len(samples)), 3)
print(samples)
print(idx)
selected = samples[idx]
selected

``````
``````

tensor([[0, 0],
[0, 1],
[2, 1],
[3, 0],
[3, 2],
[4, 1]])
[0 2 1]

Out[86]:

tensor([[0, 0],
[2, 1],
[0, 1]])

``````
``````

In [91]:

W[list(zip(*selected))] = True

``````
``````

In [92]:

W

``````
``````

Out[92]:

tensor([[ 1.0000,  1.0000, -5.9448,  0.0000,  0.0000],
[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
[-3.3820,  1.0000, -5.3704,  0.0000,  0.0000],
[20.4414, -4.1558,  6.0158,  0.0000,  0.0000],
[-1.3790,  7.1928, -5.6922,  0.0000,  0.0000]])

``````
``````

In [62]:

Wp

``````
``````

Out[62]:

tensor([[ True,  True, False, False, False],
[False, False, False, False, False],
[False,  True, False, False, False],
[ True, False,  True, False, False],
[False,  True, False, False, False]])

``````
``````

In [ ]:

``````