# In [ ]:
"""Make a 3D scatter plot of the synapses.

Currently the program only looks at data points whose synapse count is
greater than the mean synapse count, and then randomly samples from that
subset, because the 3D grapher runs very slowly when there are too many
data points.

-Jay Miller
"""
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
def check_condition(row):
    """Tell whether the last entry of ``row`` is something other than zero.

    Args:
        row: an indexable sequence; only ``row[-1]`` is inspected.

    Returns:
        ``False`` when ``row[-1] == 0``, ``True`` otherwise.
    """
    # `not` always yields a plain Python bool, even for NumPy scalars.
    return not (row[-1] == 0)
def synapse_filt(row, avg):
    """Tell whether the last entry of ``row`` is strictly greater than ``avg``.

    Args:
        row: an indexable sequence; only ``row[-1]`` is inspected.
        avg: the threshold that ``row[-1]`` is compared against.

    Returns:
        ``True`` when ``row[-1] > avg``, ``False`` otherwise (including when
        the two are equal).
    """
    return bool(row[-1] > avg)
# Load the table.  The script treats the last column as the synapse count
# and columns 1-3 as the x, y, z coordinates.
csv = np.genfromtxt('output.csv', delimiter=",")
samples = 5000  # upper bound on the number of points handed to the plotter

# Keep only the rows whose last column is not zero.  The boolean mask applies
# the same test as check_condition() without a Python call per row.
nonzero_rows = csv[csv[:, -1] != 0, :]

# Mean synapse count of the remaining rows.  Row 0 is left out of the mean,
# exactly as before -- presumably it is a header line that genfromtxt parsed
# as NaN; TODO confirm against output.csv.
avg_synapse = np.mean(nonzero_rows[1:, -1])
print(avg_synapse)

# Only look at data points where the number of synapses is greater than avg
# (the same test as synapse_filt(); NaN entries compare False and drop out).
filter_avg_synapse = nonzero_rows[:, -1] > avg_synapse
nonzero_filtered = nonzero_rows[filter_avg_synapse, :]
xyz_only = nonzero_filtered[:, [1, 2, 3]]
if len(xyz_only) == 0:
    # Fail with a readable message instead of NumPy's zero-size reduction
    # error further down.
    raise ValueError("no rows with a synapse count above the mean; "
                     "nothing to plot")

# Randomly sample at most `samples` rows.  The permutation covers every row
# index starting at 0, so the first filtered row can be drawn too.
perm = np.random.permutation(len(xyz_only))
xyz_only = xyz_only[perm[:samples]]

# Get the range for graphing: per-column minimum and maximum of x, y, z.
x_min, y_min, z_min = np.amin(xyz_only, axis=0)
x_max, y_max, z_max = np.amax(xyz_only, axis=0)

# Following code adapted from
# https://www.getdatajoy.com/examples/python-plots/3d-scatter-plot
fig = plt.figure()
# add_subplot(..., projection='3d') is used instead of
# fig.gca(projection='3d'), which recent Matplotlib releases reject.
ax = fig.add_subplot(111, projection='3d')
ax.set_title('3D Scatter Plot')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_zlim(z_min, z_max)
ax.view_init()
# NOTE(review): `dist` appears to be deprecated in newer Matplotlib; confirm
# whether the camera distance still needs to be set this way.
ax.dist = 12  # distance
ax.scatter(
    xyz_only[:, 0], xyz_only[:, 1], xyz_only[:, 2],  # data
    color='purple',  # marker colour
    marker='o',  # marker shape
    s=30  # marker size
)
plt.show()  # render the plot