In [ ]:
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import datajoint as dj
from pipeline import rf, pre, psy

In [ ]:
# Entity-relationship diagram of the `rf` schema, expanded by one level (`+1`;
# presumably one level downstream in DataJoint's ERD algebra — confirm).
(dj.ERD(rf)+1).draw()

In [ ]:
# Entity-relationship diagram of the `pre` schema.
dj.ERD(pre).draw()

In [ ]:
# Show the attribute heading (columns) of the AverageFrame table.
pre.AverageFrame().heading

In [ ]:
# Select the average frame of one scan field: animal 8623, slice 3, scan 6, channel 1.
a = (
    pre.AverageFrame()
    & 'animal_id=8623'
    & 'slice=3'
    & 'scan_idx=6'
    & 'channel=1'
)

In [ ]:
# Preview the contents of the ScanInfo table.
pre.ScanInfo()

In [ ]:
# Render the selected average frame with the grayscale colormap.
plt.imshow(a.fetch1['frame'], cmap='gray')

In [ ]:
# Traces joined with their scan depth, restricted to entries matching
# Sync x MovingNoise, depth > 300 and slice 3.
# NOTE(review): `a` is reassigned further down before this value is used.
a = (
    (pre.Trace() * rf.Scan().proj('depth'))
    & (rf.Sync() * psy.MovingNoise())
    & 'depth>300'
    & 'slice=3'
)

In [ ]:
@pre.schema
class Contrast(dj.Computed):
    """Toy computed table: max/min intensity ratio of each AverageFrame entry.

    The stored value is a made-up quantity used to demonstrate
    ``populate``, ``progress`` and ``drop``.
    """
    definition = """
    -> AverageFrame
    --- 
    contrast : double   # fake quantity
    """

    def _make_tuples(self, key):
        """Compute and insert the contrast for the AverageFrame row selected by `key`.

        Args:
            key: primary-key dict of one AverageFrame entry (supplied by populate()).

        Raises:
            ValueError: if the frame's minimum is 0. The ratio would be infinite,
                and an infinite value cannot be stored in a MySQL ``double`` column.
        """
        frame = (pre.AverageFrame() & key).fetch1['frame']
        lowest = frame.min()
        if lowest == 0:
            raise ValueError('frame minimum is 0 for key %r; contrast is undefined' % (key,))
        # NOTE(review): a negative minimum gives a negative "contrast" — acceptable for a fake quantity.
        # Build a fresh row instead of mutating the key dict handed in by populate().
        self.insert1(dict(key, contrast=float(frame.max() / lowest)))
        # One dot per computed entry as a lightweight progress indicator.
        print('.', end='', flush=True)

In [ ]:
# Diagram of the neighbourhood of Contrast. In DataJoint's ERD algebra `-n` / `+n`
# presumably expand n levels upstream / downstream — confirm; this alternates
# one step up and one step down twice.
(dj.ERD(Contrast)-1+1-1+1).draw()

In [ ]:
# Compute Contrast only for keys matching the restrictions animal_id=8623 and
# channel=1 (multiple restrictions are presumably AND-ed by populate — confirm).
Contrast().populate('animal_id=8623', 'channel=1')

In [ ]:
# Report populate progress for Contrast over all keys (trailing `;` hides the return value).
Contrast().progress();

In [ ]:
# Same progress report, limited to the keys populated in the cell above.
Contrast().progress('animal_id=8623', 'channel=1');

In [ ]:
# Display the contents of the Contrast table.
Contrast()

In [ ]:
# Distribution of all fetched contrast values.
# NOTE(review): sns.distplot is deprecated since seaborn 0.11; on newer seaborn
# consider sns.histplot(..., kde=True, stat='density').
sns.distplot(Contrast().fetch['contrast'])

In [ ]:
# Contrast against scan depth with a linear fit.
# x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only
# (positional use raises TypeError), and keywords work on older versions too.
sns.lmplot(x='depth', y='contrast',
           data=pd.DataFrame((Contrast() * rf.Scan()).fetch()))

In [ ]:
# Destructive: drops the Contrast table and its contents from the database
# (presumably after an interactive confirmation — confirm). Re-run the class
# definition cell to recreate it.
Contrast().drop()

In [ ]:
# Preview the Trace table.
pre.Trace()

In [ ]:
# Join of the Trace, Spikes and Sync tables.
pre.Trace()*pre.Spikes()*rf.Sync()

In [ ]:
# Same join restricted to MovingNoise entries, scans with depth > 300, slice 3
# and spike_inference=3 (`*` binds tighter than `&`, so the join happens first).
pre.Trace()*pre.Spikes()*rf.Sync() & psy.MovingNoise() & (rf.Scan() & 'depth>300') & 'slice=3' & 'spike_inference=3'

In [ ]:
# Keep the restricted join from the previous cell as `a` and show its heading.
a = (
    (pre.Trace() * pre.Spikes() * rf.Sync())
    & psy.MovingNoise()
    & (rf.Scan() & 'depth>300')
    & 'slice=3'
    & 'spike_inference=3'
)
a.heading

In [ ]:
# Presumably shows the SQL query DataJoint builds for relation `a`.
a.make_sql()

In [ ]:
# Narrow `a` to a single cell (mask 13); re-running is harmless since the restriction is idempotent.
a &= dict(mask_id=13, spike_inference=3)
mask = pre.AverageFrame() * pre.SegmentMask() & a & 'channel=1'
frame, mask_px, mask_w = mask.fetch1['frame', 'mask_pixels', 'mask_weights']

# Paint the mask weights into an all-NaN image so that only the cell's pixels
# are drawn on top of the frame. NaN pixels are not coloured.
# np.nan is used because np.NaN was removed in NumPy 2.0.
img = np.full(frame.size, np.nan)
img[mask_px.astype(int).squeeze()] = mask_w.squeeze()
# mask_pixels are treated as column-major linear indices. order='F' gives the same
# image as the former reshape(...).T for square frames, and keeps the overlay the
# same shape as `frame` for non-square ones.
# NOTE(review): if mask_pixels are MATLAB-style 1-based indices they need `- 1` — confirm.
img = img.reshape(frame.shape, order='F')

# Plot the sqrt-compressed frame with the mask overlaid. Colormaps are given by
# name because plt.cm.get_cmap was removed in Matplotlib 3.9.
with sns.axes_style('white'):
    fig, ax = plt.subplots()
ax.imshow(np.sqrt(frame), cmap='gray')
ax.imshow(img, alpha=1, cmap='magma')

In [ ]:
# Concatenate the calcium trace(s) of the current selection with np.hstack
# and write them to CSV without the index column.
traces = np.hstack(list(a.fetch['ca_trace']))
traces_table = pd.DataFrame(traces)
traces_table.to_csv('my_traces.csv', index=False)

In [ ]:
%matplotlib notebook

# fetch trace and trial data
times, traces, spikes = a.fetch1['frame_times', 'ca_trace','spike_trace']
trial_times = (psy.Trial() * a & 'trial_idx between first_trial and last_trial').fetch['flip_times']

# plot traces against time
trial_times = np.asarray([r[0,0] for r in trial_times])
with sns.axes_style('whitegrid'):
    fig, ax = plt.subplots(2,1, sharex=True)
ax[0].plot(times[0,::3]- times[0,0], traces, label='Ca Trace')
ax[1].plot(times[0,::3]- times[0,0], spikes, label='Spike Rate')
ax[0].set_ylabel('Fluorescence')
ax[1].set_ylabel('inferred spike rate')
ax[1].set_xlabel('time [s]')
ax[1].plot(trial_times - times[0,0], 0*trial_times+4,'h',color=sns.xkcd_rgb['greenblue'],ms=5)
sns.despine(fig)
fig.tight_layout()

In [ ]: