In [348]:
import numpy as np
import torch
In [26]:
# 5x5 integer tensor holding the values 1..25 in random order
M = torch.tensor(np.random.permutation(np.arange(1, 26)).reshape(5, 5))
In [27]:
M
Out[27]:
In [18]:
# `contiguous()` returns a tensor with the same values stored contiguously in
# memory (the tensor itself when it is already contiguous).
# It was tried here because of some torch error whose exact cause was not
# recorded - presumably a `.view()` on a non-contiguous tensor; TODO confirm.
N = M.contiguous()
In [36]:
# Value below which (inclusive) the lowest `tau` fraction of M's entries fall
tau = 0.2
kth = int(M.numel() * tau)
# kthvalue returns the k-th smallest entry directly, no explicit sort needed
theta, _ = torch.kthvalue(M.reshape(-1), k=kth)
theta
Out[36]:
In [37]:
M > theta
Out[37]:
The synapses that should stay are the ones with the highest correlation. The weight keep mask, on the other hand, already makes sense: it keeps the entries with the highest weight.
In [30]:
M[M>theta]
Out[30]:
In [19]:
# I'm only keeping the top 20%
# what is the effect when I add the weight mask?
Out[19]:
In [32]:
# Weight mask: keep every entry at or above the zeta-th fraction (by rank)
# of the positive weights; k is clamped to at least 1 so kthvalue is valid.
weight, zeta = M, 0.3
weight_pos = weight[weight.gt(0)]
pos_threshold, _ = torch.kthvalue(
    weight_pos, k=max(1, int(zeta * weight_pos.numel()))
)
weight_keep_mask = weight.ge(pos_threshold)
weight_keep_mask
Out[32]:
In [34]:
M[weight_keep_mask]
Out[34]:
In [183]:
# fresh 5x5 integer tensor holding the values 1..25 in random order
M = torch.tensor(np.random.permutation(np.arange(1, 26)).reshape(5, 5))
In [184]:
# Use M as the weight matrix and zero out row 1 and columns 3-4 so that some
# synapses are inactive. `weight` aliases M, so M is modified as well.
weight = M
corr = torch.rand(5, 5)
weight[1] = 0
weight[:, 3:5] = 0
# number of active (non-zero) synapses
num_params = int((weight != 0).sum())
corr, weight, num_params
Out[184]:
In [350]:
def prune(weight, num_params, corr, tau, idx=0, hebbian_grow=True):
    """Compute keep/grow masks for one Hebbian prune-and-regrow step.

    Active synapses are the non-zero entries of ``weight``. The ``tau``
    fraction of active synapses with the lowest ``corr`` is dropped, then
    inactive synapses are switched on to move the number of active synapses
    back towards ``num_params``.

    Args:
        weight: weight tensor; entries equal to 0 are treated as inactive.
        num_params: target number of active synapses after regrowth.
        corr: score tensor with the same shape as ``weight`` (it is indexed
            with masks derived from ``weight``).
        tau: fraction of the active synapses to prune.
        idx: unused; kept so existing callers keep working.
        hebbian_grow: if True, grow the inactive synapses with the highest
            ``corr``; otherwise grow inactive synapses at random.

    Returns:
        Tuple ``(new_mask, keep_mask, add_mask)`` of boolean tensors shaped
        like ``weight``, with ``new_mask = keep_mask | add_mask``.
    """
    with torch.no_grad():
        active_synapses = weight != 0
        nonactive_synapses = weight == 0
        total_active = torch.sum(active_synapses).item()
        total_nonactive = torch.sum(nonactive_synapses).item()

        # ----------- HEBBIAN PRUNING ----------------
        # number of active synapses to drop, ranked by correlation
        kth = int(tau * total_active)
        print("total active: ", total_active)
        print("kth: ", kth)
        if kth == 0:
            # nothing to prune: keep every active synapse
            hebbian_keep_mask = active_synapses
        elif kth >= total_active:
            # every active synapse is pruned. The mask must be boolean (and
            # on the same device) so it can be summed and OR-ed below.
            hebbian_keep_mask = torch.zeros_like(active_synapses)
        else:
            keep_threshold, _ = torch.kthvalue(corr[active_synapses], kth)
            print(keep_threshold)
            # keep the active synapses strictly above the threshold
            hebbian_keep_mask = (corr > keep_threshold) & active_synapses
        keep_mask = hebbian_keep_mask

        # ----------- GROWTH ----------------
        # number of synapses to re-add to get back to num_params
        num_add = max(num_params - torch.sum(keep_mask).item(), 0)
        if hebbian_grow:
            # rank (from the bottom) of the threshold among inactive synapses
            kth = total_nonactive - num_add
            if kth <= 0:
                # asked for at least as many as exist (or none are inactive):
                # every inactive synapse is grown
                add_mask = nonactive_synapses.clone()
            else:
                add_threshold, _ = torch.kthvalue(corr[nonactive_synapses], kth)
                # strict comparison: ties with the threshold are not grown
                add_mask = (corr > add_threshold) & nonactive_synapses
        else:
            # probability of adding is 1 or lower; the extra 1 avoids 0 / 0
            p_add = num_add / max(total_nonactive, num_add, 1)
            # sample with the weight's own shape so the masks line up
            random_sample = torch.rand(weight.shape, device=weight.device) < p_add
            add_mask = random_sample & nonactive_synapses

        # calculate the new mask
        new_mask = keep_mask | add_mask
    # the separate masks are returned to track kept vs. added connections
    return new_mask, keep_mask, add_mask
In [355]:
new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=0.25)
In [352]:
weight
Out[352]:
In [353]:
corr
Out[353]:
In [157]:
# no pruning happened
keep_mask, torch.sum(keep_mask)
Out[157]:
In [158]:
# adding almost all of the new items
add_mask, torch.sum(add_mask)
Out[158]:
In [159]:
new_mask, torch.sum(new_mask)
Out[159]:
In [171]:
# 5x5 integer tensor holding the values 1..25 in random order
N = torch.tensor(np.random.permutation(np.arange(1, 26)).reshape(5, 5))
N
Out[171]:
In [178]:
# k=5: the 5th smallest entry of the flattened tensor, returned with its index
torch.kthvalue(N.view(-1), 5)
Out[178]:
In [179]:
# k=25 equals the number of entries, so this returns the largest entry
torch.kthvalue(N.view(-1), 25)
Out[179]:
In [185]:
# Use M as the weight matrix and zero out row 1 and columns 3-4 so that some
# synapses are inactive. `weight` aliases M, so M is modified as well.
weight = M
corr = torch.rand(5, 5)
weight[1] = 0
weight[:, 3:5] = 0
# number of active (non-zero) synapses
num_params = int((weight != 0).sum())
corr, weight, num_params
Out[185]:
In [360]:
def prune(weight, num_params, corr, tau, zeta, idx=0, hebbian_grow=True):
    """Compute keep/grow masks combining Hebbian and weight-based pruning.

    Active synapses are the non-zero entries of ``weight``. A synapse is kept
    when either enabled criterion keeps it:

    * Hebbian (``tau``): drop the ``tau`` fraction of active synapses with the
      lowest ``corr``.
    * Weight (``zeta``): among positive weights keep those strictly above the
      ``int(zeta * n_pos)``-th smallest; among negative weights keep the
      ``int((1 - zeta) * n_neg)`` most negative ones.

    Inactive synapses are then switched on to move the number of active
    synapses back towards ``num_params``.

    Args:
        weight: weight tensor; entries equal to 0 are treated as inactive.
        num_params: target number of active synapses after regrowth.
        corr: score tensor with the same shape as ``weight``.
        tau: fraction of active synapses pruned by correlation, or None.
        zeta: fraction of weights pruned by magnitude, or None.
        idx: unused; kept so existing callers keep working.
        hebbian_grow: if True, grow the inactive synapses with the highest
            ``corr``; otherwise grow inactive synapses at random.

    Returns:
        Tuple ``(new_mask, keep_mask, add_mask)`` of boolean tensors shaped
        like ``weight``, with ``new_mask = keep_mask | add_mask``.
    """
    with torch.no_grad():
        active_synapses = weight != 0
        nonactive_synapses = weight == 0
        total_active = torch.sum(active_synapses).item()
        total_nonactive = torch.sum(nonactive_synapses).item()

        # ----------- HEBBIAN PRUNING ----------------
        if tau is not None:
            # number of active synapses to drop, ranked by correlation
            kth = int(tau * total_active)
            print("total active: ", total_active)
            print("kth: ", kth)
            if kth == 0:
                # nothing to prune: keep every active synapse
                hebbian_keep_mask = active_synapses
            elif kth >= total_active:
                # every active synapse is pruned. The mask must be boolean
                # (and on the same device) so it can be OR-ed below.
                hebbian_keep_mask = torch.zeros_like(active_synapses)
            else:
                keep_threshold, _ = torch.kthvalue(corr[active_synapses], kth)
                print(keep_threshold)
                # keep the active synapses strictly above the threshold
                hebbian_keep_mask = (corr > keep_threshold) & active_synapses

        # ----------- WEIGHT PRUNING ----------------
        if zeta is not None:
            # positive side: drop the pos_kth smallest positive weights
            weight_pos = weight[weight > 0]
            pos_kth = int(zeta * len(weight_pos))
            if pos_kth == 0:
                # keep every positive weight. The threshold is 0 (not -1) so
                # fractional negative weights are never let through here.
                pos_threshold = 0
            else:
                pos_threshold, _ = torch.kthvalue(weight_pos, pos_kth)
            print(pos_kth, pos_threshold)
            # negative side: keep the neg_kth most negative weights
            weight_neg = weight[weight < 0]
            neg_kth = int((1 - zeta) * len(weight_neg))
            if neg_kth == 0:
                # no negative weight survives (or there are none at all)
                neg_threshold = None
                keep_negative = torch.zeros_like(active_synapses)
            else:
                neg_threshold, _ = torch.kthvalue(weight_neg, neg_kth)
                keep_negative = weight <= neg_threshold
            print(neg_kth, neg_threshold)
            partial_weight_mask = (weight > pos_threshold) | keep_negative
            weight_mask = partial_weight_mask & active_synapses

        # ----------- COMBINE HEBBIAN AND WEIGHT ----------------
        # NOTE: truthiness is used on purpose here, so tau=0 / zeta=0 act
        # like None (criterion disabled) when the masks are combined.
        if tau and zeta:
            keep_mask = hebbian_keep_mask | weight_mask
        elif tau:
            keep_mask = hebbian_keep_mask
        elif zeta:
            keep_mask = weight_mask
        else:
            keep_mask = active_synapses

        # ----------- GROWTH ----------------
        # calculate number of params removed to be readded
        num_add = max(num_params - torch.sum(keep_mask).item(), 0)
        print(num_add)
        if hebbian_grow:
            # rank (from the bottom) of the threshold among inactive synapses
            kth = total_nonactive - num_add
            if kth <= 0:
                # asked for at least as many as exist (or none are inactive):
                # every inactive synapse is grown
                add_mask = nonactive_synapses.clone()
            else:
                add_threshold, _ = torch.kthvalue(corr[nonactive_synapses], kth)
                # strict comparison: ties with the threshold are not grown
                add_mask = (corr > add_threshold) & nonactive_synapses
        else:
            # probability of adding is 1 or lower; the extra 1 avoids 0 / 0
            p_add = num_add / max(total_nonactive, num_add, 1)
            print(p_add)
            # sample with the weight's own shape so the masks line up
            random_sample = torch.rand(weight.shape, device=weight.device) < p_add
            add_mask = random_sample & nonactive_synapses

        # calculate the new mask
        new_mask = keep_mask | add_mask
    # the separate masks are returned to track kept vs. added connections
    return new_mask, keep_mask, add_mask
In [357]:
# Use M as the weight matrix and zero out row 1 and columns 3-4 so that some
# synapses are inactive. `weight` aliases M, so M is modified as well.
weight = M
corr = torch.rand(5, 5)
weight[1] = 0
weight[:, 3:5] = 0
# number of active (non-zero) synapses
num_params = int((weight != 0).sum())
corr, weight, num_params
Out[357]:
In [387]:
# Fixed coactivation matrix, so the pruning result can be checked by hand
corr = torch.tensor(
    [
        [0.3201, 0.8318, 0.3382, 0.9734, 0.0985],
        [0.0401, 0.8620, 0.0845, 0.3778, 0.3996],
        [0.4954, 0.0092, 0.6713, 0.8594, 0.9487],
        [0.8101, 0.0922, 0.2033, 0.7185, 0.4588],
        [0.3897, 0.6865, 0.5072, 0.9749, 0.0597],
    ]
)
# Fixed weight matrix: row 1 and columns 3-4 are zero (inactive synapses)
weight = torch.tensor(
    [
        [19, 21, -12, 0, 0],
        [0, 0, 0, 0, 0],
        [-10, 25, -8, 0, 0],
        [2, -11, 7, 0, 0],
        [-14, 18, -6, 0, 0],
    ]
)
# number of active (non-zero) synapses
num_params = int((weight != 0).sum())
In [388]:
new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=0.25, zeta=0.50)
In [389]:
keep_mask
Out[389]:
In [386]:
add_mask
Out[386]:
In [379]:
new_mask
Out[379]:
In [336]:
def prune(weight, num_params, corr, tau, zeta, idx=0, hebbian_grow=True):
    """Compute keep/grow masks; variant that favours LOW ``corr`` values.

    Active synapses are the non-zero entries of ``weight``. A synapse is kept
    when either enabled criterion keeps it:

    * Hebbian (``tau``): keep the ``int((1 - tau) * n_active)`` active
      synapses with the lowest ``corr`` (values at or below the threshold).
    * Weight (``zeta``): among positive weights keep those strictly above the
      ``int(zeta * n_pos)``-th smallest; among negative weights keep the
      ``int((1 - zeta) * n_neg)`` most negative ones.

    Inactive synapses are then switched on to move the number of active
    synapses back towards ``num_params``.

    Args:
        weight: weight tensor; entries equal to 0 are treated as inactive.
        num_params: target number of active synapses after regrowth.
        corr: score tensor with the same shape as ``weight``.
        tau: fraction of active synapses pruned by correlation, or None.
        zeta: fraction of weights pruned by magnitude, or None.
        idx: unused; kept so existing callers keep working.
        hebbian_grow: if True, grow the inactive synapses with the lowest
            ``corr``; otherwise grow inactive synapses at random.

    Returns:
        Tuple ``(new_mask, keep_mask, add_mask)`` of boolean tensors shaped
        like ``weight``, with ``new_mask = keep_mask | add_mask``.
    """
    with torch.no_grad():
        active_synapses = weight != 0
        nonactive_synapses = weight == 0
        total_active = torch.sum(active_synapses).item()
        total_nonactive = torch.sum(nonactive_synapses).item()
        # all-False mask with the same shape and device as the weight
        empty_mask = torch.zeros_like(active_synapses)

        # ----------- HEBBIAN PRUNING ----------------
        if tau is not None:
            # number of active synapses to keep, ranked by correlation
            kth = int((1 - tau) * total_active)
            print("total active: ", total_active)
            print("kth: ", kth)
            if kth == 0:
                # nothing survives
                hebbian_keep_mask = empty_mask
            elif kth >= total_active:
                # everything survives; also guards kthvalue against k > n
                hebbian_keep_mask = active_synapses
            else:
                keep_threshold, _ = torch.kthvalue(corr[active_synapses], kth)
                print(keep_threshold)
                # keep the active synapses at or below the threshold
                hebbian_keep_mask = (corr <= keep_threshold) & active_synapses
            print("hebbian_keep_mask", hebbian_keep_mask)

        # ----------- WEIGHT PRUNING ----------------
        if zeta is not None:
            # positive side: drop the pos_kth smallest positive weights
            weight_pos = weight[weight > 0]
            pos_kth = int(zeta * len(weight_pos))
            if pos_kth == 0:
                # keep every positive weight (or there are none). The
                # threshold is 0 (not -1) so fractional negative weights are
                # never let through by this comparison.
                pos_threshold = 0
            else:
                pos_threshold, _ = torch.kthvalue(weight_pos, pos_kth)
            # negative side: keep the neg_kth most negative weights
            weight_neg = weight[weight < 0]
            neg_kth = int((1 - zeta) * len(weight_neg))
            if neg_kth == 0:
                # no negative weight survives (or there are none at all)
                keep_negative = empty_mask
            else:
                neg_threshold, _ = torch.kthvalue(weight_neg, neg_kth)
                keep_negative = weight <= neg_threshold
            partial_weight_mask = (weight > pos_threshold) | keep_negative
            weight_mask = partial_weight_mask & active_synapses
            print("weight_mask", weight_mask)

        # ----------- COMBINE HEBBIAN AND WEIGHT ----------------
        # NOTE: truthiness is used on purpose here, so tau=0 / zeta=0 act
        # like None (criterion disabled) when the masks are combined.
        if tau and zeta:
            keep_mask = hebbian_keep_mask | weight_mask
        elif tau:
            keep_mask = hebbian_keep_mask
        elif zeta:
            keep_mask = weight_mask
        else:
            keep_mask = active_synapses

        # ----------- GROWTH ----------------
        # calculate number of params removed to be readded
        num_add = max(num_params - torch.sum(keep_mask).item(), 0)
        print(num_add)
        if hebbian_grow:
            # cannot grow more synapses than there are inactive ones
            kth = min(num_add, total_nonactive)
            if kth > 0:
                add_threshold, _ = torch.kthvalue(corr[nonactive_synapses], kth)
                # grow the inactive synapses at or below the threshold
                add_mask = (corr <= add_threshold) & nonactive_synapses
            else:
                # nothing to add (or nowhere to add it)
                add_mask = empty_mask
        else:
            # probability of adding is 1 or lower; the extra 1 avoids 0 / 0
            p_add = num_add / max(total_nonactive, num_add, 1)
            print(p_add)
            # sample with the weight's own shape so the masks line up
            random_sample = torch.rand(weight.shape, device=weight.device) < p_add
            add_mask = random_sample & nonactive_synapses

        # calculate the new mask
        new_mask = keep_mask | add_mask
    # the separate masks are returned to track kept vs. added connections
    return new_mask, keep_mask, add_mask
In [337]:
# Use M as the weight matrix and zero out row 1 and columns 3-4 so that some
# synapses are inactive. `weight` aliases M, so M is modified as well.
weight = M
corr = torch.rand(5, 5)
weight[1] = 0
weight[:, 3:5] = 0
# number of active (non-zero) synapses
num_params = int((weight != 0).sum())
corr, weight, num_params
Out[337]:
In [343]:
# The four on/off combinations of the two pruning criteria (tau = Hebbian,
# zeta = weight). Uncomment exactly one call; only the last one, with both
# criteria disabled, is currently active.
# NOTE(review): which `prune` definition is bound depends on the order the
# definition cells were executed in - confirm before comparing outputs.
# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=1, zeta=1)
# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=1, zeta=None)
# new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=None, zeta=1)
new_mask, keep_mask, add_mask = prune(weight, num_params, corr, tau=None, zeta=None)
In [344]:
keep_mask
Out[344]:
In [345]:
add_mask
Out[345]:
In [346]:
new_mask
Out[346]:
In [ ]: