In [348]:
import numpy as np
import torch

In [26]:
# Toy 5x5 tensor holding a random permutation of 1..25.
# NOTE(review): no seed is set, so re-running produces a different M than Out[27].
M = np.arange(1,26)
np.random.shuffle(M)
M = torch.tensor(M.reshape(5,5))

In [27]:
M


Out[27]:
tensor([[ 8, 11, 17, 13, 14],
        [ 7, 10, 19,  5, 12],
        [15,  1, 23,  9, 20],
        [ 4,  2, 25, 16,  3],
        [21, 18, 22,  6, 24]])

In [18]:
# what is contiguous doing anyway?
# Tensor.contiguous() returns the tensor itself when its memory layout is already
# contiguous, otherwise a contiguous copy; Tensor.view() (used below) needs a
# compatible layout.
# some error brought up in torch - don't remember now exactly the cause
# NOTE(review): presumably a .view() on a non-contiguous tensor — confirm.
# NOTE(review): N is not used before it is reassigned in In[171]; execution count 18 < 26
# also means this ran against an earlier M.
N  = M.contiguous()

In [36]:
# theta is the k-th smallest entry of M, with k = int(0.2 * 25) = 5
tau = 0.2
kth = int(tau * np.prod(M.shape))
# torch.kthvalue returns the k-th smallest value (k is 1-based) directly, so no sort is needed
theta, _ = torch.kthvalue(M.view(-1), kth)
theta


Out[36]:
tensor(5)

In [37]:
# entries strictly above theta: with theta = 5 this keeps the 20 largest of 25 (top 80%)
M > theta


Out[37]:
tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True, False,  True],
        [ True, False,  True,  True,  True],
        [False, False,  True,  True, False],
        [ True,  True,  True,  True,  True]])

The entries that should stay are the ones with the highest correlation. The weight keep mask makes sense for the same reason: keep the entries with the highest weight.


In [30]:
# values of M strictly above theta
# NOTE(review): Out[30] (5 values, all > 20) predates In[36]; with theta = 5 this cell
# now returns 20 values — re-run to refresh.
M[M>theta]


Out[30]:
tensor([23, 25, 21, 22, 24])

In [19]:
# I'm only keeping the top 20%
# NOTE(review): with tau = 0.2 and `M > theta`, the top 80% is kept, not the top 20% (see Out[37]).
# what is the effect when I add the weight mask?
# NOTE(review): this cell contains only comments, so Out[19] is stale output from
# earlier cell contents — clear it.


Out[19]:
(torch.Size([10, 10]), torch.Size([10, 10]))

In [32]:
# calculate weight mask
# Keeps weights >= the k-th smallest positive weight, k = max(int(zeta * n_pos), 1).
# Here k = int(0.3 * 25) = 7, and `>=` keeps the 7th value itself, so only the 6 smallest
# entries (1..6) are dropped (see Out[32]).
# NOTE: `weight = M` binds the same tensor object, it is not a copy.
weight = M
zeta = 0.3
weight_pos = weight[weight > 0]
pos_threshold, _ = torch.kthvalue(
    weight_pos, max(int(zeta * len(weight_pos)), 1)
)
weight_keep_mask = (weight >= pos_threshold)
weight_keep_mask


Out[32]:
tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True, False,  True],
        [ True, False,  True,  True,  True],
        [False, False,  True,  True, False],
        [ True,  True,  True, False,  True]])

In [34]:
# values that survive the weight mask (19 of 25)
M[weight_keep_mask]


Out[34]:
tensor([ 8, 11, 17, 13, 14,  7, 10, 19, 12, 15, 23,  9, 20, 25, 16, 21, 18, 22,
        24])

Testing Hebbian pruning only


In [183]:
# Fresh random permutation for the Hebbian-only test (unseeded, so not reproducible).
M = np.arange(1,26)
np.random.shuffle(M)
M = torch.tensor(M.reshape(5,5))

In [184]:
# NOTE: `weight = M` aliases M, so the in-place zeroing below also modifies M itself.
weight = M
# random "correlation" matrix with the same shape as weight (unseeded)
corr = torch.rand((5,5))
# zero row 1 and columns 3-4 so that some synapses are inactive (weight == 0)
weight[1, :] = 0
weight[:, 3] = 0
weight[:, 4] = 0
# number of active synapses; passed to prune as the target count after regrowth
num_params = torch.sum(weight != 0).item()
corr, weight, num_params


Out[184]:
(tensor([[0.3201, 0.8318, 0.3382, 0.9734, 0.0985],
         [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],
         [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],
         [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],
         [0.3897, 0.6865, 0.5072, 0.9749, 0.0597]]),
 tensor([[19, 21, 12,  0,  0],
         [ 0,  0,  0,  0,  0],
         [10, 25,  8,  0,  0],
         [ 2, 11,  7,  0,  0],
         [14, 18,  6,  0,  0]]),
 12)

In [350]:
def prune(weight, num_params, corr, tau, idx=0, hebbian_grow=True, verbose=True):
    """Hebbian prune-and-grow step for a single weight tensor.

    A synapse is active when ``weight != 0``. The ``int(tau * total_active)``
    active synapses with the lowest correlation are pruned: only active
    synapses with ``corr`` strictly above that threshold are kept, so ties
    with the threshold are pruned too. Growth then re-activates enough
    currently-inactive synapses to bring the total back up to ``num_params``.
    It picks either the highest-correlation candidates (``hebbian_grow=True``)
    or candidates sampled uniformly at random.

    Args:
        weight: weight tensor; zero entries are inactive synapses.
        num_params: target number of synapses after growth.
        corr: correlation tensor with the same shape as ``weight``.
        tau: fraction of active synapses to prune by correlation.
        idx: unused; kept for call compatibility.
        hebbian_grow: grow by correlation if True, otherwise at random.
        verbose: print intermediate values (the original debug output).

    Returns:
        ``(new_mask, keep_mask, add_mask)``: boolean tensors shaped like
        ``weight``. They hold the union, the surviving active synapses and the
        newly grown synapses.
    """
    with torch.no_grad():
        active_synapses = weight != 0
        nonactive_synapses = weight == 0
        total_active = int(active_synapses.sum().item())
        total_nonactive = int(nonactive_synapses.sum().item())

        # ----------- HEBBIAN PRUNING ----------------
        # number of lowest-correlation active synapses to prune
        kth = int(tau * total_active)
        if verbose:
            print("total active: ", total_active)
            print("kth: ", kth)
        if kth == 0:
            # nothing to prune: keep every active synapse
            hebbian_keep_mask = active_synapses
        elif kth >= total_active:
            # prune every active synapse; bool (not float) so `|` below works,
            # and kthvalue would fail for kth > total_active
            hebbian_keep_mask = torch.zeros_like(active_synapses)
        else:
            keep_threshold, _ = torch.kthvalue(corr[active_synapses], kth)
            if verbose:
                print(keep_threshold)
            # keep active synapses strictly above the threshold
            hebbian_keep_mask = (corr > keep_threshold) & active_synapses

        keep_mask = hebbian_keep_mask
        num_add = max(num_params - int(keep_mask.sum().item()), 0)

        # ----------- GROWTH ----------------
        if hebbian_grow:
            if num_add == 0:
                add_mask = torch.zeros_like(nonactive_synapses)
            elif num_add >= total_nonactive:
                # not enough candidates: grow every inactive synapse
                add_mask = nonactive_synapses.clone()
            else:
                # threshold leaving exactly num_add candidates strictly above it (absent ties)
                add_threshold, _ = torch.kthvalue(
                    corr[nonactive_synapses], total_nonactive - num_add
                )
                add_mask = (corr > add_threshold) & nonactive_synapses
        else:
            # probability of adding is at most 1; the extra 1 avoids 0 / 0
            p_add = num_add / max(total_nonactive, num_add, 1)
            # sample with the mask's shape (a flat tensor would not broadcast)
            random_sample = torch.rand(weight.shape, device=weight.device) < p_add
            add_mask = random_sample & nonactive_synapses

        new_mask = keep_mask | add_mask

    return new_mask, keep_mask, add_mask

In [355]:
# Prune 25% of active synapses by correlation, then regrow to num_params.
# The printed lines match the In[350] version of prune (the only one without a num_add print).
new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=0.25)


total active:  12
kth:  3
tensor(0.2033)

In [352]:
# NOTE(review): Out[352] contains negative weights that the weight built in In[184] does
# not have (and execution count 352 < 355); it is stale output — re-run to refresh.
weight


Out[352]:
tensor([[ 19,  21, -12,   0,   0],
        [  0,   0,   0,   0,   0],
        [ 10,  25,   8,   0,   0],
        [  2, -11,   7,   0,   0],
        [ 14,  18,  -6,   0,   0]])

In [353]:
# correlation matrix (same values as Out[184])
corr


Out[353]:
tensor([[0.3201, 0.8318, 0.3382, 0.9734, 0.0985],
        [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],
        [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],
        [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],
        [0.3897, 0.6865, 0.5072, 0.9749, 0.0597]])

In [157]:
# surviving active synapses: Out[157] keeps 9 of 12, so pruning did happen
# NOTE(review): Out[157] (execution count 157) predates In[355] and does not match the
# corr shown above (row 0 would be all True); re-run to refresh.
keep_mask, torch.sum(keep_mask)


Out[157]:
(tensor([[False,  True,  True, False, False],
         [False, False, False, False, False],
         [ True,  True,  True, False, False],
         [ True, False,  True, False, False],
         [ True, False,  True, False, False]]), tensor(9))

In [158]:
# grown synapses: 3 added, restoring the total to num_params = 12 (see Out[159])
add_mask, torch.sum(add_mask)


Out[158]:
(tensor([[False, False, False, False, False],
         [False, False, False,  True, False],
         [False, False, False, False, False],
         [False, False, False,  True, False],
         [False, False, False, False,  True]]), tensor(3))

In [159]:
# union of kept and grown synapses
new_mask, torch.sum(new_mask)


Out[159]:
(tensor([[False,  True,  True, False, False],
         [False, False, False,  True, False],
         [ True,  True,  True, False, False],
         [ True, False,  True,  True, False],
         [ True, False,  True, False,  True]]), tensor(12))

Testing Hebbian + weight-magnitude pruning


In [171]:
# Separate toy tensor N for checking torch.kthvalue semantics (unseeded shuffle).
N = np.arange(1,26)
np.random.shuffle(N)
N = torch.tensor(N.reshape(5,5))
N


Out[171]:
tensor([[21,  9, 12,  7,  6],
        [14, 10, 17,  2,  8],
        [ 5, 25, 13, 11,  1],
        [ 4,  3, 24, 15, 16],
        [20, 18, 23, 22, 19]])

In [178]:
# 5th smallest element of the flattened N (value 5, flat index 10)
# NOTE(review): Out[178] shows two results, which this single call cannot produce —
# stale output from an earlier version of the cell.
torch.kthvalue(N.view(-1), 5)


Out[178]:
(torch.return_types.kthvalue(
 values=tensor(5),
 indices=tensor(10)), torch.return_types.kthvalue(
 values=tensor(25),
 indices=tensor(11)))

In [179]:
# k is 1-based: k = numel (25) returns the maximum
torch.kthvalue(N.view(-1), 25)


Out[179]:
torch.return_types.kthvalue(
values=tensor(25),
indices=tensor(11))

In [185]:
# Repeats In[184] with a new random corr.
# NOTE: `weight = M` aliases M, so the zeroing below mutates M in place again.
weight = M
corr = torch.rand((5,5))
weight[1, :] = 0
weight[:, 3] = 0
weight[:, 4] = 0
num_params = torch.sum(weight != 0).item()
corr, weight, num_params


Out[185]:
(tensor([[0.9703, 0.9620, 0.5798, 0.8359, 0.0570],
         [0.6940, 0.1080, 0.2981, 0.9239, 0.3559],
         [0.6024, 0.7168, 0.1934, 0.4220, 0.1958],
         [0.7452, 0.7896, 0.7346, 0.5306, 0.1022],
         [0.3658, 0.7152, 0.4189, 0.8674, 0.2408]]),
 tensor([[19, 21, 12,  0,  0],
         [ 0,  0,  0,  0,  0],
         [10, 25,  8,  0,  0],
         [ 2, 11,  7,  0,  0],
         [14, 18,  6,  0,  0]]),
 12)

In [360]:
def prune(weight, num_params, corr, tau, zeta, idx=0, hebbian_grow=True, verbose=True):
    """Prune by correlation and/or weight, then grow back to ``num_params``.

    A synapse is active when ``weight != 0``. Two optional keep criteria are
    computed over the active synapses. They are OR-ed, so a synapse survives
    if either criterion keeps it:

    * Hebbian (``tau``): prune the ``int(tau * total_active)`` active synapses
      with the lowest correlation, keeping those with ``corr`` strictly above
      that threshold.
    * Weight (``zeta``): among positive weights, prune the smallest
      ``int(zeta * n_pos)``. Among negative weights, keep only the
      ``int((1 - zeta) * n_neg)`` most negative ones.

    Pass ``None`` to disable a criterion. With both disabled, every active
    synapse is kept. Growth then re-activates enough inactive synapses to
    reach ``num_params``. It picks the highest-correlation candidates
    (``hebbian_grow=True``) or samples them uniformly at random.

    Args:
        weight: weight tensor; zero entries are inactive synapses.
        num_params: target number of synapses after growth.
        corr: correlation tensor with the same shape as ``weight``.
        tau: fraction of active synapses to prune by correlation, or None.
        zeta: fraction of weights (per sign) to prune by weight, or None.
        idx: unused; kept for call compatibility.
        hebbian_grow: grow by correlation if True, otherwise at random.
        verbose: print intermediate values (the original debug output).

    Returns:
        ``(new_mask, keep_mask, add_mask)``: boolean tensors shaped like
        ``weight``.
    """
    with torch.no_grad():
        active_synapses = weight != 0
        nonactive_synapses = weight == 0
        total_active = int(active_synapses.sum().item())
        total_nonactive = int(nonactive_synapses.sum().item())

        # ----------- HEBBIAN PRUNING ----------------
        hebbian_keep_mask = None
        if tau is not None:
            # number of lowest-correlation active synapses to prune
            kth = int(tau * total_active)
            if verbose:
                print("total active: ", total_active)
                print("kth: ", kth)
            if kth == 0:
                # nothing to prune: keep every active synapse
                hebbian_keep_mask = active_synapses
            elif kth >= total_active:
                # prune every active synapse; bool (not float) so `|` below works
                hebbian_keep_mask = torch.zeros_like(active_synapses)
            else:
                keep_threshold, _ = torch.kthvalue(corr[active_synapses], kth)
                if verbose:
                    print(keep_threshold)
                # strict ">" also prunes synapses tied with the threshold
                hebbian_keep_mask = (corr > keep_threshold) & active_synapses

        # ----------- WEIGHT PRUNING ----------------
        weight_mask = None
        if zeta is not None:
            positive = weight > 0
            weight_pos = weight[positive]
            pos_kth = int(zeta * weight_pos.numel())
            if pos_kth == 0:
                # prune no positive weights; restricting to positives keeps
                # small negative float weights from slipping through
                pos_threshold = None
                keep_pos = positive
            else:
                pos_threshold, _ = torch.kthvalue(weight_pos, pos_kth)
                keep_pos = weight > pos_threshold
            if verbose:
                print(pos_kth, pos_threshold)

            negative = weight < 0
            weight_neg = weight[negative]
            # number of most-negative weights to KEEP
            neg_kth = int((1 - zeta) * weight_neg.numel())
            if neg_kth == 0:
                # keep no negative weights (e.g. zeta == 1, or no negatives at all)
                neg_threshold = None
                keep_neg = torch.zeros_like(active_synapses)
            else:
                neg_threshold, _ = torch.kthvalue(weight_neg, neg_kth)
                keep_neg = weight <= neg_threshold
            if verbose:
                print(neg_kth, neg_threshold)

            weight_mask = (keep_pos | keep_neg) & active_synapses

        # ----------- COMBINE HEBBIAN AND WEIGHT ----------------
        # test against None, not truthiness, so tau=0 / zeta=0 still take part
        if hebbian_keep_mask is not None and weight_mask is not None:
            keep_mask = hebbian_keep_mask | weight_mask
        elif hebbian_keep_mask is not None:
            keep_mask = hebbian_keep_mask
        elif weight_mask is not None:
            keep_mask = weight_mask
        else:
            keep_mask = active_synapses

        # ----------- GROWTH ----------------
        # number of removed synapses to re-add
        num_add = max(num_params - int(keep_mask.sum().item()), 0)
        if verbose:
            print(num_add)
        if hebbian_grow:
            if num_add == 0:
                add_mask = torch.zeros_like(nonactive_synapses)
            elif num_add >= total_nonactive:
                # not enough candidates: grow every inactive synapse
                add_mask = nonactive_synapses.clone()
            else:
                # threshold leaving exactly num_add candidates strictly above it (absent ties)
                add_threshold, _ = torch.kthvalue(
                    corr[nonactive_synapses], total_nonactive - num_add
                )
                add_mask = (corr > add_threshold) & nonactive_synapses
        else:
            # probability of adding is at most 1; the extra 1 avoids 0 / 0
            p_add = num_add / max(total_nonactive, num_add, 1)
            if verbose:
                print(p_add)
            # sample with the mask's shape (a flat tensor would not broadcast)
            random_sample = torch.rand(weight.shape, device=weight.device) < p_add
            add_mask = random_sample & nonactive_synapses

        new_mask = keep_mask | add_mask

    return new_mask, keep_mask, add_mask

In [357]:
# NOTE: `weight = M` aliases M, so the zeroing below mutates M in place.
# This setup is overwritten by the fixed inputs in In[387].
weight = M
corr = torch.rand((5,5))
weight[1, :] = 0
weight[:, 3] = 0
weight[:, 4] = 0
num_params = torch.sum(weight != 0).item()
corr, weight, num_params


Out[357]:
(tensor([[0.7310, 0.6574, 0.8158, 0.6538, 0.5781],
         [0.7679, 0.7877, 0.4430, 0.9151, 0.9181],
         [0.0374, 0.2257, 0.6370, 0.9288, 0.2408],
         [0.5740, 0.7109, 0.2618, 0.2762, 0.1792],
         [0.2557, 0.9166, 0.7356, 0.7074, 0.0236]]),
 tensor([[19, 21, 12,  0,  0],
         [ 0,  0,  0,  0,  0],
         [10, 25,  8,  0,  0],
         [ 2, 11,  7,  0,  0],
         [14, 18,  6,  0,  0]]),
 12)

In [387]:
# Fixed inputs so the prune output below is reproducible. corr repeats the values of
# Out[184]; weight adds negative entries to exercise the negative-weight branch.
# coactivation matrix
corr = torch.tensor([ [0.3201, 0.8318, 0.3382, 0.9734, 0.0985],
                [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],
                [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],
                [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],
                [0.3897, 0.6865, 0.5072, 0.9749, 0.0597]])
# weight matrix
weight = torch.tensor([
                 [19, 21, -12,  0,  0],
                 [ 0,  0,  0,  0,  0],
                 [-10, 25, -8,  0,  0],
                 [ 2, -11,  7,  0,  0],
                 [-14, 18,  -6,  0,  0]])
num_params = torch.sum(weight != 0).item()

In [388]:
# tau=0.25 prunes by correlation and zeta=0.5 by weight; the printed "pos_kth threshold" /
# "neg_kth threshold" lines show this ran the In[360] version of prune.
new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=0.25, zeta=0.50)


total active:  12
kth:  3
tensor(0.2033)
3 tensor(18)
3 tensor(-11)
1

In [389]:
# union of the Hebbian and weight keep masks: 11 of 12 active synapses survive
keep_mask


Out[389]:
tensor([[ True,  True,  True, False, False],
        [False, False, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False]])

In [386]:
# NOTE(review): Out[386] predates In[388] and is stale; with the inputs above num_add = 1
# and only (4, 3) (the highest inactive corr, 0.9749) is grown — re-run to refresh.
add_mask


Out[386]:
tensor([[False, False, False,  True, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False,  True, False]])

In [379]:
# kept + grown synapses (12 in total, back to num_params)
new_mask


Out[379]:
tensor([[ True,  True,  True, False, False],
        [False, False, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True,  True, False]])

Flipping the logic: keep (and grow) the lowest-correlation synapses instead of the highest


In [336]:
def prune(weight, num_params, corr, tau, zeta, idx=0, hebbian_grow=True, verbose=True):
    """Flipped-logic prune/grow step: keeps and grows the LOWEST-correlation synapses.

    A synapse is active when ``weight != 0``. Two optional keep criteria are
    computed over the active synapses. They are OR-ed, so a synapse survives
    if either criterion keeps it:

    * Hebbian (``tau``): keep the ``int((1 - tau) * total_active)`` active
      synapses with the lowest correlation (``corr <= threshold``). ``tau=1``
      keeps none and ``tau=0`` keeps all.
    * Weight (``zeta``): among positive weights, prune the smallest
      ``int(zeta * n_pos)``. Among negative weights, keep only the
      ``int((1 - zeta) * n_neg)`` most negative ones.

    Pass ``None`` to disable a criterion. With both disabled, every active
    synapse is kept. Growth then re-activates ``num_params - kept`` inactive
    synapses. It picks those with the lowest correlation
    (``hebbian_grow=True``) or samples them uniformly at random.

    Args:
        weight: weight tensor; zero entries are inactive synapses.
        num_params: target number of synapses after growth.
        corr: correlation tensor with the same shape as ``weight``.
        tau: fraction of active synapses to prune by correlation, or None.
        zeta: fraction of weights (per sign) to prune by weight, or None.
        idx: unused; kept for call compatibility.
        hebbian_grow: grow by correlation if True, otherwise at random.
        verbose: print intermediate values (the original debug output).

    Returns:
        ``(new_mask, keep_mask, add_mask)``: boolean tensors shaped like
        ``weight``.
    """
    with torch.no_grad():
        active_synapses = weight != 0
        nonactive_synapses = weight == 0
        total_active = int(active_synapses.sum().item())
        total_nonactive = int(nonactive_synapses.sum().item())

        # ----------- HEBBIAN PRUNING ----------------
        hebbian_keep_mask = None
        if tau is not None:
            # number of lowest-correlation active synapses to KEEP
            kth = int((1 - tau) * total_active)
            if verbose:
                print("total active: ", total_active)
                print("kth: ", kth)
            if kth == 0:
                # keep nothing (e.g. tau == 1)
                hebbian_keep_mask = torch.zeros_like(active_synapses)
            elif kth >= total_active:
                # keep every active synapse (e.g. tau == 0)
                hebbian_keep_mask = active_synapses
            else:
                keep_threshold, _ = torch.kthvalue(corr[active_synapses], kth)
                if verbose:
                    print(keep_threshold)
                # keep active synapses at or below the threshold (ties are all kept)
                hebbian_keep_mask = (corr <= keep_threshold) & active_synapses
            if verbose:
                print("hebbian_keep_mask",  hebbian_keep_mask)

        # ----------- WEIGHT PRUNING ----------------
        weight_mask = None
        if zeta is not None:
            positive = weight > 0
            weight_pos = weight[positive]
            pos_kth = int(zeta * weight_pos.numel())
            if pos_kth == 0:
                # prune no positive weights (zeta == 0, or no positives). Restricting
                # to positives stops negative float weights in (-1, 0) from leaking in.
                keep_pos = positive
            else:
                pos_threshold, _ = torch.kthvalue(weight_pos, pos_kth)
                keep_pos = weight > pos_threshold

            negative = weight < 0
            weight_neg = weight[negative]
            # number of most-negative weights to KEEP
            neg_kth = int((1 - zeta) * weight_neg.numel())
            if neg_kth == 0:
                # keep no negative weights (zeta == 1, or no negatives)
                keep_neg = torch.zeros_like(active_synapses)
            else:
                neg_threshold, _ = torch.kthvalue(weight_neg, neg_kth)
                keep_neg = weight <= neg_threshold

            weight_mask = (keep_pos | keep_neg) & active_synapses
            if verbose:
                print("weight_mask", weight_mask)

        # ----------- COMBINE HEBBIAN AND WEIGHT ----------------
        # test against None, not truthiness, so tau=0 / zeta=0 still take part
        if hebbian_keep_mask is not None and weight_mask is not None:
            keep_mask = hebbian_keep_mask | weight_mask
        elif hebbian_keep_mask is not None:
            keep_mask = hebbian_keep_mask
        elif weight_mask is not None:
            keep_mask = weight_mask
        else:
            keep_mask = active_synapses

        # ----------- GROWTH ----------------
        # number of removed synapses to re-add
        num_add = max(num_params - int(keep_mask.sum().item()), 0)
        if verbose:
            print(num_add)
        if hebbian_grow:
            if num_add == 0:
                # nothing to add
                add_mask = torch.zeros_like(nonactive_synapses)
            elif num_add >= total_nonactive:
                # not enough candidates (kthvalue would fail): grow every inactive synapse
                add_mask = nonactive_synapses.clone()
            else:
                # grow the num_add lowest-correlation inactive synapses (ties may add more)
                add_threshold, _ = torch.kthvalue(corr[nonactive_synapses], num_add)
                add_mask = (corr <= add_threshold) & nonactive_synapses
        else:
            # probability of adding is at most 1; the extra 1 avoids 0 / 0
            p_add = num_add / max(total_nonactive, num_add, 1)
            if verbose:
                print(p_add)
            # sample with the mask's shape (a flat tensor would not broadcast)
            random_sample = torch.rand(weight.shape, device=weight.device) < p_add
            add_mask = random_sample & nonactive_synapses

        new_mask = keep_mask | add_mask

    return new_mask, keep_mask, add_mask

In [337]:
# NOTE: `weight = M` aliases M, so the zeroing below mutates M in place.
weight = M
corr = torch.rand((5,5))
weight[1, :] = 0
weight[:, 3] = 0
weight[:, 4] = 0
num_params = torch.sum(weight != 0).item()
corr, weight, num_params


Out[337]:
(tensor([[0.2943, 0.4472, 0.6258, 0.5547, 0.6841],
         [0.7235, 0.9685, 0.5549, 0.5836, 0.1360],
         [0.7461, 0.2407, 0.3790, 0.2589, 0.1135],
         [0.6667, 0.4287, 0.4017, 0.2251, 0.2324],
         [0.2342, 0.6125, 0.4358, 0.9662, 0.5876]]),
 tensor([[19, 21, 12,  0,  0],
         [ 0,  0,  0,  0,  0],
         [10, 25,  8,  0,  0],
         [ 2, 11,  7,  0,  0],
         [14, 18,  6,  0,  0]]),
 12)

In [343]:
# Edge-case calls; only the last one is active. With tau=None and zeta=None nothing is
# pruned, so keep_mask equals the active synapses and num_add (printed below) is 0.
# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=1, zeta=1)
# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=1, zeta=None)
# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=None, zeta=1)
new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=None, zeta=None)


0

In [344]:
# identical to the active synapses (weight != 0)
keep_mask


Out[344]:
tensor([[ True,  True,  True, False, False],
        [False, False, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False]])

In [345]:
# nothing grown because num_add == 0
add_mask


Out[345]:
tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

In [346]:
# kept + grown synapses (same as keep_mask here)
new_mask


Out[346]:
tensor([[ True,  True,  True, False, False],
        [False, False, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False]])

In [ ]: