In [1]:
import pulse2percept as p2p
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
# Evaluate p2p.retina.jansonius2009 once for each of 501 evenly spaced values
# in [-180, 180] (presumably starting angles in degrees -- confirm against the
# pulse2percept docs). Each result is later indexed as an (n, 2) array of x/y.
bundle_angles = np.linspace(-180, 180, 501)
axon_bundles = p2p.utils.parfor(p2p.retina.jansonius2009, bundle_angles)
In [3]:
# Overlay every bundle in one figure: column 0 is drawn as x, column 1 as y.
plt.figure(figsize=(10, 6))
for bundle in axon_bundles:
    bundle_x, bundle_y = bundle[:, 0], bundle[:, 1]
    plt.plot(bundle_x, bundle_y)
In [4]:
def find_closest_axon(pos_xy, axon_bundles):
    """Return the part of the nearest bundle between a point and the bundle start.

    Parameters
    ----------
    pos_xy : pair of float
        (x, y) location of the neuron / pixel.
    axon_bundles : sequence of ndarray
        Each entry is indexed as ``bundle[:, 0]`` (x) and ``bundle[:, 1]`` (y).

    Returns
    -------
    ndarray
        Rows of the nearest bundle in reversed order: the first row is the
        sample closest to ``pos_xy``, the last row is row 0 of the bundle
        (the optic-disc end, according to the original comment).

    Notes
    -----
    The original slice ``pos_id:0:-1`` stopped *before* row 0, so the first
    sample was always dropped and an empty array came back whenever the
    closest sample was row 0. ``pos_id::-1`` includes row 0.
    """
    xneuron, yneuron = pos_xy
    # Squared distance from the point to every sample of every bundle; kept
    # so the winning bundle's distances need not be recomputed.
    sq_dists = [(ax[:, 0] - xneuron) ** 2 + (ax[:, 1] - yneuron) ** 2
                for ax in axon_bundles]
    # Nearest bundle = the one whose closest sample is closest overall
    axon_id = np.argmin([d.min() for d in sq_dists])
    # Position of that closest sample on the winning bundle
    pos_id = np.argmin(sq_dists[axon_id])
    # Walk back from `pos_id` to (and including) the first sample
    return axon_bundles[axon_id][pos_id::-1, :]
In [5]:
def assign_axons(xg, yg, axon_bundles, engine='joblib', scheduler='threading', n_jobs=-1):
    """Find the closest axon for every grid location.

    Parameters
    ----------
    xg, yg : ndarray
        Grid coordinates (e.g. from ``np.meshgrid``); both are flattened and
        paired element-wise into (x, y) locations.
    axon_bundles : sequence of ndarray
        Passed unchanged to ``find_closest_axon`` for every location.
    engine, scheduler, n_jobs
        Forwarded to ``p2p.utils.parfor``. The original accepted these
        arguments but silently ignored them.

    Returns
    -------
    Whatever ``p2p.utils.parfor`` returns: one ``find_closest_axon`` result
    per grid location, in raveled grid order.
    """
    # One neuron per pixel: pair up the flattened x and y coordinates
    pos_xy = np.column_stack((xg.ravel(), yg.ravel()))
    return p2p.utils.parfor(find_closest_axon, pos_xy, func_args=[axon_bundles],
                            engine=engine, scheduler=scheduler, n_jobs=n_jobs)
In [ ]:
# Regular 101 x 101 grid covering [-10, 10] along both axes
grid_axis = np.linspace(-10, 10, 101)
xg, yg = np.meshgrid(grid_axis, grid_axis, indexing='xy')
# Spacing between neighbouring columns, reported as-is and after dva2ret
grid_step = xg[0, 1] - xg[0, 0]
print('grid step: %f dva, %f um' % (grid_step, p2p.retina.dva2ret(grid_step)))
In [ ]:
# One axon per grid location, in the raveled order of xg / yg
axons = assign_axons(xg=xg, yg=yg, axon_bundles=axon_bundles)
In [ ]:
# Plot a reproducible random sample of (at most) 50 axons together with the
# grid location each one belongs to, then overlay the Argus I electrodes.
plt.figure(figsize=(10, 8))
n_axons = np.minimum(50, len(axons))
idx_axons = np.arange(len(axons))
np.random.seed(42)
np.random.shuffle(idx_axons)
idx_axons = idx_axons[:n_axons]
x_flat, y_flat = xg.ravel(), yg.ravel()
# Index the list directly: the axons have different lengths, so wrapping them
# in `np.array(axons)` builds a ragged array (an error on recent NumPy).
for idx in idx_axons:
    ax, x, y = axons[idx], x_flat[idx], y_flat[idx]
    plt.plot(ax[:, 0], ax[:, 1])
    plt.plot(x, y, 's', markersize=8, alpha=0.5)
    if len(ax) > 0:
        # Mark the first sample of the axon (the end nearest the grid location)
        plt.plot(ax[0, 0], ax[0, 1], 'o')
for e in p2p.implants.ArgusI():
    plt.plot(p2p.retina.ret2dva(e.x_center),
             p2p.retina.ret2dva(e.y_center), 'ok', markersize=30, alpha=0.4)
plt.axis('equal');
In words: For every axon, there is a function that describes how sensitive the local tissue is. This is a function of the distance from the soma. Possibilities are:
This function needs to be multiplied with the current spread. Then what do you do?
In [ ]:
def axon_sensitivity(dist, rule='decay', decay_const=3.0):
    """Sensitivity of an axon as a function of distance from the soma.

    Parameters
    ----------
    dist : float or ndarray
        Distance(s) at which to evaluate the sensitivity.
    rule : str
        Case-insensitive. 'decay' gives ``exp(-dist / decay_const)``.
        'fried' gives a Gaussian bump (centre ``ret2dva(50)``, width
        ``ret2dva(20)``, height 0.7) on a 0.3 plateau, with a linear fall-off
        of 0.001 per unit distance and a penalty for distances below the bump
        centre; the result is clipped at zero.
    decay_const : float
        Length constant, used by the 'decay' rule only.

    Raises
    ------
    ValueError
        If `rule` is neither 'decay' nor 'fried'.
    """
    rule_key = rule.lower()
    if rule_key not in ('decay', 'fried'):
        raise ValueError('Unknown rule "%s"' % rule)

    if rule_key == 'decay':
        return np.exp(-dist / decay_const)

    # 'fried' rule
    bump_center = p2p.retina.ret2dva(50.0)
    bump_width = p2p.retina.ret2dva(20.0)
    bump = 0.7 * np.exp(-(dist - bump_center) ** 2 / (2 * bump_width ** 2))
    baseline = 0.3
    # Positive only for points closer to the soma than the bump centre
    near_soma_penalty = np.maximum(bump_center - dist, 0)
    return np.maximum(0, bump - 0.001 * dist + baseline - near_soma_penalty)
In [ ]:
# Compare the two sensitivity rules over distances from 0 to ret2dva(1000).
plt.figure(figsize=(12, 5))
plt.subplot(121)
dist = np.linspace(0, p2p.retina.ret2dva(1000), 1000)
for decay_const in [0.01, 0.1, 1.0, 2.0, 10.0]:
    # Raw string: '\l' is not a valid escape sequence in a normal string literal
    plt.plot(dist, axon_sensitivity(dist, rule='decay', decay_const=decay_const),
             linewidth=3, label=r'$\lambda$=' + str(decay_const))
plt.legend()
# `dist` is a plain (not squared) distance, so the unit is dva rather than dva^2
plt.xlabel('dist (dva)')
plt.title('Decay rule')
plt.subplot(122)
plt.plot(dist, axon_sensitivity(dist, rule='fried'), linewidth=3)
plt.xlabel('dist (dva)')
plt.title('Fried rule');
In [ ]:
# Toy "current spread": an isotropic Gaussian centred on (5, 5) on the grid
std = 1.0
sq_dist_to_center = (xg - 5) ** 2 + (yg - 5) ** 2
cs = np.exp(-sq_dist_to_center / (2 * std ** 2))
# Reverse the row order for display
plt.imshow(cs[::-1])
In [ ]:
from scipy.spatial import cKDTree
# KD-tree over all grid nodes, one (x, y) row per node in raveled grid order,
# so that arbitrary positions can be snapped to their nearest grid index.
pos_xy = np.stack((xg.ravel(), yg.ravel()), axis=1)
tree = cKDTree(pos_xy)
In [ ]:
# Pick the axon belonging to the grid node nearest to (2, 5)
_, plot_axon = tree.query((2, 5))
print('idx_plot: ', plot_axon)
axon = axons[plot_axon]
# Grid node nearest to the first sample of the axon
_, idx_neuron = tree.query(axon[0, :])
# Keep only axon samples that fall inside the grid's bounding box
axon_x, axon_y = axon[:, 0], axon[:, 1]
idx_valid = (axon_x >= xg.min()) & (axon_x <= xg.max())
idx_valid &= (axon_y >= yg.min()) & (axon_y <= yg.max())
# Snap those samples to their nearest grid nodes
_, idx_cs = tree.query(axon[idx_valid, :])
# Drop repeated nodes while keeping the order in which they are first visited
_, first_seen = np.unique(idx_cs, return_index=True)
idx_cs = idx_cs[np.sort(first_seen)]
# Prepend the neuron's own node so distances can be measured from it
idx_dist = np.insert(idx_cs, 0, idx_neuron, axis=0)
idx_cs, idx_dist
In [ ]:
# Euclidean length of each step between consecutive grid nodes along the axon
step_dx = np.diff(xg.ravel()[idx_dist])
step_dy = np.diff(yg.ravel()[idx_dist])
dist = np.sqrt(step_dx ** 2 + step_dy ** 2)
dist
In [ ]:
# Running path length from the neuron along the axon. `dist` holds Euclidean
# step lengths, so the sum is in degrees, not degrees squared.
plt.plot(np.cumsum(dist))
plt.ylabel('dist (deg)')
plt.xlabel('axon segment');
In [ ]:
# Sensitivity ('decay' rule) evaluated at the running path length of each segment
path_length = np.cumsum(dist)
segment_sensitivity = axon_sensitivity(path_length, rule='decay')
plt.plot(segment_sensitivity)
plt.xlabel('axon segment')
plt.ylabel('sensitivity')
In [ ]:
# Current-spread value at each grid node the axon passes through
cs_along_axon = cs.ravel()[idx_cs]
plt.plot(cs_along_axon)
plt.xlabel('axon segment')
plt.ylabel('electric field "current spread"')
In [ ]:
# Effective current per segment = sensitivity at that path length x local current spread
segment_sens = axon_sensitivity(np.cumsum(dist), rule='decay')
axon_weights = segment_sens * cs.ravel()[idx_cs]
plt.plot(axon_weights)
plt.xlabel('axon segment')
plt.ylabel('effective current')
In [ ]:
# The two candidate summaries of the per-segment effective current
np.mean(axon_weights), np.max(axon_weights)
In [ ]:
def distance_from_soma(axon, tree, xg, yg):
    """Map an axon onto grid nodes and measure path length from its first sample.

    Parameters
    ----------
    axon : ndarray
        Indexed as ``axon[:, 0]`` (x) and ``axon[:, 1]`` (y); row 0 is treated
        as the soma location.
    tree : KD-tree exposing ``query``
        Built over the raveled grid nodes, so that returned indices address
        ``xg.ravel()`` / ``yg.ravel()``.
    xg, yg : ndarray
        Grid coordinates.

    Returns
    -------
    idx_cs : ndarray of int
        Unique grid-node indices visited by the axon, in order of first visit.
    dist : ndarray of float
        Cumulative path length from the soma's grid node to each entry of
        ``idx_cs``, measured along the snapped grid nodes.

    If no axon sample lies inside the grid's bounding box, the sentinel
    ``(0, np.inf)`` is returned instead.

    Notes
    -----
    The original returned ``sqrt(cumsum(dx**2 + dy**2))``, which is not a path
    length. The per-step Euclidean lengths are now summed, matching the
    step-by-step derivation elsewhere in this notebook. An unreachable second
    emptiness check was removed.
    """
    # Consider only samples within the grid's bounding box
    in_x = (axon[:, 0] >= xg.min()) & (axon[:, 0] <= xg.max())
    in_y = (axon[:, 1] >= yg.min()) & (axon[:, 1] <= yg.max())
    idx_valid = in_x & in_y
    if not np.any(idx_valid):
        return 0, np.inf

    # Snap the remaining samples to their nearest grid nodes
    _, idx_cs = tree.query(axon[idx_valid, :])
    # Drop duplicates, preserving the order of first visit
    _, first_seen = np.unique(idx_cs, return_index=True)
    idx_cs = idx_cs[np.sort(first_seen)]

    # For the distance calculation, prepend the grid node of the soma
    _, idx_neuron = tree.query(axon[0, :])
    idx_dist = np.insert(idx_cs, 0, idx_neuron, axis=0)

    # Path length = running sum of the Euclidean step lengths
    xdiff = np.diff(xg.ravel()[idx_dist])
    ydiff = np.diff(yg.ravel()[idx_dist])
    dist = np.cumsum(np.sqrt(xdiff ** 2 + ydiff ** 2))
    return idx_cs, dist
In [ ]:
# Run distance_from_soma on every axon; the tree and grid are shared extra arguments
shared_args = [tree, xg, yg]
axons_dist = p2p.utils.parfor(distance_from_soma, axons, func_args=shared_args)
In [ ]:
def get_axon_contribution(axon_dist, cs, sensitivity_rule='fried',
                          activation_rule='max', min_contribution=0.01):
    """Collapse the effective current along one axon into a single value.

    Parameters
    ----------
    axon_dist : tuple
        ``(idx_cs, dist)`` as produced by ``distance_from_soma``: grid indices
        into ``cs.ravel()`` and the matching distances. The scalar sentinel
        ``(0, np.inf)`` is also accepted.
    cs : ndarray
        Current-spread map; indexed through ``cs.ravel()``.
    sensitivity_rule : str
        Passed to ``axon_sensitivity`` as `rule`.
    activation_rule : {'max', 'mean'}
        How the per-segment effective currents are summarised.
    min_contribution : float
        Contributions below this threshold are discarded.

    Returns
    -------
    None if the contribution is below `min_contribution`, otherwise
    ``(idx_neuron, axon_contribution)``, where ``idx_neuron`` is the first
    entry of ``idx_cs`` as a scalar index.

    Raises
    ------
    ValueError
        If `activation_rule` is not 'max' or 'mean'.

    Notes
    -----
    ``idx_cs`` is normalised with ``np.atleast_1d``. The original called
    ``len(idx_cs)``, which raises TypeError for the scalar sentinel when
    `min_contribution` is zero or negative, and it returned a one-element
    array instead of a scalar index when ``idx_cs`` had a single entry.
    """
    idx_cs, dist = axon_dist
    idx_cs = np.atleast_1d(idx_cs)
    # Effective current: distance-dependent sensitivity x local current spread
    axon_weights = axon_sensitivity(dist, rule=sensitivity_rule) * cs.ravel()[idx_cs]
    if activation_rule == 'max':
        axon_contribution = axon_weights.max()
    elif activation_rule == 'mean':
        axon_contribution = axon_weights.mean()
    else:
        raise ValueError('Unknown activation rule "%s"' % activation_rule)

    if axon_contribution < min_contribution:
        return None
    # The first visited grid node stands in for the neuron's own pixel
    return idx_cs[0], axon_contribution
In [ ]:
# One panel per (sensitivity rule, activation rule) combination, showing the
# effective current assigned to each neuron's pixel.
sensitivity_rules = ['decay', 'fried']
activity_rules = ['mean', 'max']
n_rows, n_cols = len(sensitivity_rules), len(activity_rules)
plt.figure(figsize=(14, 8))
for row, sens_rule in enumerate(sensitivity_rules):
    for col, act_rule in enumerate(activity_rules):
        # Subplot indices are 1-based and run row by row
        idx_plot = row * n_cols + col + 1
        rule_kwargs = {'sensitivity_rule': sens_rule, 'activation_rule': act_rule}
        contrib = p2p.utils.parfor(get_axon_contribution, axons_dist, func_args=[cs],
                                   func_kwargs=rule_kwargs, engine='joblib')
        plt.subplot(n_rows, n_cols, idx_plot)
        # Axons below the contribution threshold come back as None; drop them
        px_contrib = [entry for entry in contrib if entry]
        ecs = np.zeros_like(cs)
        ecs_flat = ecs.ravel()
        for i, e in px_contrib:
            ecs_flat[i] = e
        plt.imshow(np.flipud(ecs))
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.title('%s, %s rule' % (sens_rule, act_rule))
In [ ]:
# Flat index of the strongest pixel in the last effective-current map computed above
ecs.argmax()
In [ ]: