In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pulse2percept as p2p
from scipy.spatial import cKDTree
%matplotlib inline


2017-08-21 23:44:57,526 [pulse2percept] [INFO] Welcome to pulse2percept

In [2]:
# Evaluate p2p.retina.jansonius2009 once for each of 501 evenly spaced values
# between -180 and 180 (presumably starting angles in degrees -- confirm against
# the pulse2percept docs). Each result is later indexed as an (n, 2) array of
# x/y coordinates.
axon_bundles = p2p.utils.parfor(p2p.retina.jansonius2009, np.linspace(-180, 180, 501))

In [3]:
# Draw every axon bundle as its own line: column 0 holds x, column 1 holds y.
plt.figure(figsize=(10, 6))
for ax in axon_bundles:
    plt.plot(ax[:, 0], ax[:, 1])



In [4]:
def find_closest_axon(pos_xy, axon_bundles):
    """Return the part of the nearest axon bundle that leads from a point to row 0.

    The bundle whose sampled points come closest to ``pos_xy`` is selected, the
    closest sample on it is located, and the samples from there back to the
    first row of the bundle are returned (row 0 is taken to be the optic-disc
    end of the bundle, as in the original comment of this notebook).

    Parameters
    ----------
    pos_xy : sequence of two floats
        (x, y) location of the neuron.
    axon_bundles : sequence of ndarray
        Each entry is an (n, 2) array of x/y samples along one bundle.

    Returns
    -------
    ndarray, shape (m, 2)
        Samples ordered from the one closest to ``pos_xy`` down to row 0 of the
        bundle, both ends included. Always contains at least one row.

    Raises
    ------
    ValueError
        If ``axon_bundles`` is empty.
    """
    if len(axon_bundles) == 0:
        raise ValueError('axon_bundles must not be empty')
    xneuron, yneuron = pos_xy

    # Squared Euclidean distance from the neuron to every sample of every
    # bundle; computed once and reused for both lookups below.
    sq_dists = [(ax[:, 0] - xneuron) ** 2 + (ax[:, 1] - yneuron) ** 2
                for ax in axon_bundles]

    # Nearest bundle, then nearest sample on that bundle
    axon_id = np.argmin([d.min() for d in sq_dists])
    pos_id = np.argmin(sq_dists[axon_id])

    # Walk back from `pos_id` to row 0. The stop index is left open on purpose:
    # `pos_id:0:-1` would drop row 0 and return an empty array for pos_id == 0.
    return axon_bundles[axon_id][pos_id::-1, :]

In [5]:
def assign_axons(xg, yg, axon_bundles, engine='joblib', scheduler='threading', n_jobs=-1):
    """Find the closest axon for every location of a coordinate grid.

    One neuron is assumed at every grid location. For each (x, y) pair,
    ``find_closest_axon`` is evaluated through ``p2p.utils.parfor``.

    Parameters
    ----------
    xg, yg : ndarray
        Coordinate grids of equal shape (e.g. from ``np.meshgrid``). They are
        flattened and paired element-wise into (x, y) locations.
    axon_bundles : sequence of ndarray
        Passed unchanged to ``find_closest_axon``.
    engine, scheduler, n_jobs
        Forwarded to ``p2p.utils.parfor`` to select how the work is executed.
        The defaults mirror the values this notebook has always advertised.

    Returns
    -------
    The value returned by ``p2p.utils.parfor``: one ``find_closest_axon`` result
    per grid location, in the order of ``xg.ravel()``.
    """
    pos_xy = np.column_stack((xg.ravel(), yg.ravel()))
    # The execution options used to be accepted but silently dropped; pass them on.
    return p2p.utils.parfor(find_closest_axon, pos_xy, func_args=[axon_bundles],
                            engine=engine, scheduler=scheduler, n_jobs=n_jobs)

In [ ]:
# 101 x 101 grid spanning [-10, 10] on both axes. With indexing='xy', x varies
# along columns and y along rows.
xg, yg = np.meshgrid(np.linspace(-10, 10, 101), np.linspace(-10, 10, 101), indexing='xy')
# Report the spacing between neighboring columns, once as-is and once passed
# through dva2ret (presumably degrees of visual angle -> retinal microns, as
# the printed units suggest).
print('grid step: %f dva, %f um' % (xg[0, 1] - xg[0, 0], p2p.retina.dva2ret(xg[0, 1] - xg[0, 0])))


grid step: 0.200000 dva, 53.613641 um

In [ ]:
# One entry per grid location (in xg.ravel() order): the closest axon segment.
axons = assign_axons(xg, yg, axon_bundles)

In [ ]:
# Plot a reproducible random subset of the assigned axons, the grid location
# each one belongs to, and the Argus I electrodes on top.
plt.figure(figsize=(10, 8))
n_axons = min(50, len(axons))

idx_axons = np.arange(len(axons))
np.random.seed(42)
np.random.shuffle(idx_axons)
idx_axons = idx_axons[:n_axons]
# Index the list of axons directly: the axons have different lengths, and
# np.array() on such a ragged list raises on recent NumPy versions.
axons_subset = [axons[i] for i in idx_axons]
for ax, x, y in zip(axons_subset, xg.ravel()[idx_axons], yg.ravel()[idx_axons]):
    plt.plot(ax[:, 0], ax[:, 1])
    # Square: the grid location; circle: the first sample of its axon
    plt.plot(x, y, 's', markersize=8, alpha=0.5)
    if len(ax) > 0:
        plt.plot(ax[0, 0], ax[0, 1], 'o')
# Electrode centers are passed through ret2dva before plotting (presumably
# stored in retinal microns -- confirm against the pulse2percept docs).
for e in p2p.implants.ArgusI():
    plt.plot(p2p.retina.ret2dva(e.x_center),
             p2p.retina.ret2dva(e.y_center), 'ok', markersize=30, alpha=0.4)
plt.axis('equal');

In words: For every axon, there is a function that describes how sensitive the local tissue is. This is a function of the distance from the soma. Possibilities are:

  • drops exponentially with distance
  • is the Fried function

This function needs to be multiplied with the current spread. Then what do you do?

  • You could sum all these values: this is the contribution of this axon at the pixel location of the soma. This is basically the dot product that we implement now.
  • You could take the effective current to be the max of this element-wise product.
  • You could walk along the axon from the optic disc to the soma. The axon is on if we reach some threshold at any point. This is binary. It's kinda weird, but isn't this how things work?

In [ ]:
def axon_sensitivity(dist, rule='decay', decay_const=3.0):
    """Sensitivity of an axon as a function of distance from the soma.

    Parameters
    ----------
    dist : float or ndarray
        Distance(s) at which to evaluate the sensitivity.
    rule : str
        Case-insensitive name of the sensitivity profile:

        - 'decay': ``exp(-dist / decay_const)``
        - 'fried': a Gaussian bump of height 0.7 riding on a 0.3 plateau with
          a slow linear fall-off (0.001 per unit distance), suppressed between
          the soma and the center of the bump and clipped at zero.
    decay_const : float
        Length constant of the 'decay' rule; ignored by the 'fried' rule.

    Returns
    -------
    float or ndarray
        Sensitivity value(s), matching the shape of ``dist``.

    Raises
    ------
    ValueError
        If ``rule`` is neither 'decay' nor 'fried'.
    """
    rule_name = rule.lower()

    if rule_name == 'decay':
        return np.exp(-dist / decay_const)

    if rule_name != 'fried':
        raise ValueError('Unknown rule "%s"' % rule)

    # Center and width of the bump are given as 50.0 and 20.0 and passed
    # through ret2dva (presumably microns -> degrees of visual angle).
    mu_gauss = p2p.retina.ret2dva(50.0)
    std_gauss = p2p.retina.ret2dva(20.0)
    plateau = 0.3
    bell = 0.7 * np.exp(-(dist - mu_gauss) ** 2 / (2 * std_gauss ** 2))
    # Penalty that grows as we approach the soma from the center of the bump
    soma = np.maximum(mu_gauss - dist, 0)
    return np.maximum(0, bell - 0.001 * dist + plateau - soma)

In [ ]:
# Compare the two sensitivity rules side by side over the same distance range.
plt.figure(figsize=(12, 5))
plt.subplot(121)
# `dist` is a linear distance: 1000 samples from 0 to ret2dva(1000)
# (presumably 1000 um expressed in degrees of visual angle).
dist = np.linspace(0, p2p.retina.ret2dva(1000), 1000)
for decay_const in [0.01, 0.1, 1.0, 2.0, 10.0]:
    # Raw string: '\l' is not a valid escape sequence in a regular string
    plt.plot(dist, axon_sensitivity(dist, rule='decay', decay_const=decay_const),
             linewidth=3, label=r'$\lambda$=' + str(decay_const))
plt.legend()
plt.xlabel('dist (dva)')
plt.title('Decay rule')

plt.subplot(122)
plt.plot(dist, axon_sensitivity(dist, rule='fried'), linewidth=3)
plt.xlabel('dist (dva)')
plt.title('Fried rule');

In [ ]:
# Toy "current spread": an isotropic 2-D Gaussian centered on (5, 5) with
# standard deviation `std`, evaluated at every grid location.
std = 1.0
cs = np.exp(-((xg - 5) ** 2 + (yg - 5) ** 2) / (2 * std ** 2))
# Rows of `cs` run along increasing y; flip them so larger y is drawn at the top.
plt.imshow(np.flipud(cs))

In [ ]:
# Index all grid locations (in xg.ravel() order) with a KD-tree so that the
# grid pixel nearest to an arbitrary (x, y) point can be looked up quickly.
# cKDTree is imported in the import cell at the top of the notebook.
pos_xy = np.column_stack((xg.ravel(), yg.ravel()))
tree = cKDTree(pos_xy)

In [ ]:
# Walk through the pixel mapping for a single example: the axon that belongs
# to the grid pixel nearest to (2, 5).
_, plot_axon = tree.query((2, 5))
print('idx_plot: ', plot_axon)
axon = axons[plot_axon]

# Grid pixel nearest to the first axon sample (treated as the soma)
_, idx_neuron = tree.query(axon[0, :])

# Restrict the axon to samples that fall inside the grid
x_in_grid = (axon[:, 0] >= xg.min()) & (axon[:, 0] <= xg.max())
y_in_grid = (axon[:, 1] >= yg.min()) & (axon[:, 1] <= yg.max())
idx_valid = x_in_grid & y_in_grid

# Snap the remaining samples to their nearest grid pixel
_, idx_cs = tree.query(axon[idx_valid, :])

# Several samples can land on the same pixel: keep only the first visit of
# each pixel, preserving the order along the axon
_, idx_cs_unique = np.unique(idx_cs, return_index=True)
idx_cs = idx_cs[np.sort(idx_cs_unique)]

# Prepend the soma pixel so that distances can be measured from it
idx_dist = np.insert(idx_cs, 0, idx_neuron, axis=0)
idx_cs, idx_dist

In [ ]:
# Euclidean length of each step between consecutive pixels in `idx_dist`
# (soma pixel first), i.e. one value per entry of `idx_cs`.
dist = np.sqrt(np.diff(xg.ravel()[idx_dist]) ** 2 + np.diff(yg.ravel()[idx_dist]) ** 2)
dist

In [ ]:
# Path length from the soma to each pixel along the axon. This is a sum of
# linear segment lengths, so the unit is degrees, not degrees squared.
plt.plot(np.cumsum(dist))
plt.ylabel('dist (deg)')
plt.xlabel('axon segment');

In [ ]:
# Sensitivity of the example axon at each pixel, using the 'decay' rule on the
# cumulative distance from the soma.
plt.plot(axon_sensitivity(np.cumsum(dist), rule='decay'))
plt.xlabel('axon segment')
plt.ylabel('sensitivity')

In [ ]:
# Current spread sampled at the grid pixels the example axon passes through.
plt.plot(cs.ravel()[idx_cs])
plt.xlabel('axon segment')
plt.ylabel('electric field "current spread"')

In [ ]:
# Effective current per pixel: element-wise product of the axon's sensitivity
# and the current spread at that pixel.
axon_weights = axon_sensitivity(np.cumsum(dist), rule='decay') * cs.ravel()[idx_cs]
plt.plot(axon_weights)
plt.xlabel('axon segment')
plt.ylabel('effective current')

In [ ]:
# The two candidate ways of collapsing the per-pixel values into one number.
axon_weights.mean(), axon_weights.max()

In [ ]:
def distance_from_soma(axon, tree, xg, yg):
    """Map an axon onto grid pixels and measure their path distance from the soma.

    Parameters
    ----------
    axon : ndarray, shape (n, 2)
        x/y samples along the axon; row 0 is treated as the soma.
    tree : KD-tree
        Tree built over ``np.column_stack((xg.ravel(), yg.ravel()))``; only its
        ``query`` method is used.
    xg, yg : ndarray
        Coordinate grids whose flattened order matches the tree.

    Returns
    -------
    idx_cs : ndarray of int
        Flat grid indices of the pixels the axon passes through, in order of
        first visit along the axon.
    dist : ndarray of float
        For each entry of ``idx_cs``, the accumulated length of the pixel-to-
        pixel path starting at the soma pixel.

    If no axon sample lies inside the grid, the sentinel ``(0, np.inf)`` is
    returned instead.
    """
    # Consider only axon samples within the grid
    idx_valid = (axon[:, 0] >= xg.min()) & (axon[:, 0] <= xg.max())
    idx_valid &= (axon[:, 1] >= yg.min()) & (axon[:, 1] <= yg.max())
    if not idx_valid.any():
        # Nothing to map; avoid querying the tree with an empty array
        return 0, np.inf

    # For these, find the nearest grid pixel
    _, idx_cs = tree.query(axon[idx_valid, :])

    # Drop repeated pixels, keeping the order of first visit
    _, first_visit = np.unique(idx_cs, return_index=True)
    idx_cs = idx_cs[np.sort(first_visit)]

    # For the distance calculation, start the path at the soma pixel
    _, idx_neuron = tree.query(axon[0, :])
    idx_dist = np.insert(idx_cs, 0, idx_neuron, axis=0)

    # Path length = running sum of the individual segment lengths. Take the
    # square root per segment *before* accumulating; sqrt(cumsum(squares))
    # would underestimate the distance and disagree with the step-by-step
    # computation earlier in this notebook.
    xdiff = np.diff(xg.ravel()[idx_dist])
    ydiff = np.diff(yg.ravel()[idx_dist])
    dist = np.cumsum(np.sqrt(xdiff ** 2 + ydiff ** 2))

    return idx_cs, dist

In [ ]:
# One distance_from_soma result per axon, in the same order as `axons`.
axons_dist = p2p.utils.parfor(distance_from_soma, axons, func_args=[tree, xg, yg])

In [ ]:
def get_axon_contribution(axon_dist, cs, sensitivity_rule='fried',
                          activation_rule='max', min_contribution=0.01):
    """Collapse the effective current along one axon into a single value.

    Parameters
    ----------
    axon_dist : tuple (idx_cs, dist)
        Output of ``distance_from_soma`` for one axon: flat grid indices of the
        pixels along the axon and their distances from the soma. The sentinel
        ``(0, np.inf)`` marks an axon that never enters the grid.
    cs : ndarray
        Current spread on the grid; indexed through ``cs.ravel()``.
    sensitivity_rule : str
        Passed to ``axon_sensitivity`` as ``rule``.
    activation_rule : {'max', 'mean'}
        How the per-pixel effective currents are combined.
    min_contribution : float
        Contributions below this value are discarded.

    Returns
    -------
    None or tuple (idx_neuron, axon_contribution)
        ``None`` if the axon has no pixels on the grid or its contribution is
        below ``min_contribution``. Otherwise the flat grid index of the first
        pixel of the axon (always a scalar) and the combined value.

    Raises
    ------
    ValueError
        If ``activation_rule`` is not 'max' or 'mean'.
    """
    idx_cs, dist = axon_dist
    # Normalize to 1-D so scalar (sentinel) and array inputs share one code path
    idx_cs = np.atleast_1d(idx_cs)

    # Find effective current: sensitivity at each pixel times current spread
    axon_weights = np.atleast_1d(
        axon_sensitivity(dist, rule=sensitivity_rule) * cs.ravel()[idx_cs]
    )

    if activation_rule == 'max':
        axon_contribution = axon_weights.max()
    elif activation_rule == 'mean':
        axon_contribution = axon_weights.mean()
    else:
        raise ValueError('Unknown activation rule "%s"' % activation_rule)

    # An infinite distance marks an axon without any pixel on the grid; it must
    # not be attributed to a pixel, whatever the threshold is.
    if np.all(np.isinf(dist)) or axon_contribution < min_contribution:
        return None

    # Always hand back a scalar index, also for single-pixel axons
    return idx_cs[0], axon_contribution

In [ ]:
# Effective current map for every pairing of sensitivity rule (one per row of
# subplots) and activation rule (one per column).
sensitivity_rules = ['decay', 'fried']
activity_rules = ['mean', 'max']

idx_plot = 1
plt.figure(figsize=(14, 8))
for sens_rule in sensitivity_rules:
    for act_rule in activity_rules:
        rule_kwargs = {'sensitivity_rule': sens_rule, 'activation_rule': act_rule}
        contrib = p2p.utils.parfor(get_axon_contribution, axons_dist, func_args=[cs],
                                   func_kwargs=rule_kwargs, engine='joblib')

        # Discarded axons come back as None; keep only (pixel, value) pairs
        px_contrib = [c for c in contrib if c]

        # Scatter the values into an image shaped like the current spread
        ecs = np.zeros_like(cs)
        ecs_flat = ecs.ravel()
        for i, e in px_contrib:
            ecs_flat[i] = e

        plt.subplot(len(sensitivity_rules), len(activity_rules), idx_plot)
        plt.imshow(np.flipud(ecs))
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.title('%s, %s rule' % (sens_rule, act_rule))
        idx_plot += 1

In [ ]:
# Flat index of the strongest pixel in the last map computed above (`ecs` is
# overwritten on every loop iteration, so this is the final rule combination).
np.argmax(ecs)

In [ ]: