In [1]:
%pylab inline


Populating the interactive namespace from numpy and matplotlib

In [1]:


In [27]:
import re

#takw raw csv from google docs
with open('responses_v2.csv') as f:
    rawdata = f.readlines()
#make each row into a list
rawdata = rawdata[0].split('\r')
datalist = []
for i,d in enumerate(rawdata):
    datalist.append(d.split(','))
    
## Get subjects
subjects = set()
for i in rawdata[1:]:
    s = i.split(',')
    subjects.add(s[1])
print subjects
    
#probmap = {.1:1,.25:2,.75:3,.9:4}     original version
probmap = {.5:1,.6:2,.7:3,.8:4,.9:5}
    
class Experiment():
    def __init__(self,subject_name):
        self.subject = subject_name
        self.all_data = []
    def get_data(self,data):
        #get data associated with the subject name
        trial_data = []
        for line in data:
            s = line.split(',')
            if self.subject == s[1]:
                # trial_num, probsame, response_num(not count 1), time_taken, response (1=S, 0=D), correct/incorrect (1/0), 
                trial_split = s[2].split('\t')
                trial_num = int((trial_split[1].split('-'))[-1])
                probsame = probmap[float((trial_split[2].split('-'))[-1])]
               # print 'hello',trial_num, probsame                
                pattern = r'response-\d+\t+[SD]+\t+[SD]+\t+\d+'
                matches = re.findall(pattern,s[2])
                # 'response-1\tS\tS\t10239'  ## response, target, actual, time
                for index, response in zip(range(len(matches)), matches):
                    if index > 0:
                        response_num = int((response.split('\t'))[0].split('-')[-1])
                        correct = 1 if ((response.split('\t'))[1] == (response.split('\t'))[2]) else 0
                        #print index
                        time_taken =  int((response.split('\t'))[-1]) - int((matches[index-1].split('\t'))[-1])
                        #print time_taken
                        response_key = 1 if (response.split('\t'))[2] == 'S' else 0
                        #print response
                        #TARGET response means switch
                        switch = 1 if (response.split('\t'))[1] != (matches[index-1].split('\t'))[1] else 0
                        target = 1 if (response.split('\t'))[1] == 'S' else 0
                        trial_data.append([trial_num, probsame, response_num, time_taken, response_key, correct, switch, target])
        self.all_data.append(array(trial_data,dtype='int16'))
        
data_dict = {}
columns=["trial_num", "prob_same", "response_num", 'time_taken', 'response', 'correct', 'switch', 'target']
for subject in ['3337264','Nate!'] :
    e = Experiment(subject)
    e.get_data(rawdata)
    with open(subject + '_formatted.csv','w') as outfile:
        writer = csv.writer(outfile)
        writer.writerows(columns)
        writer.writerows(e.all_data[0])


set(['test33', 'yy', 'thomas', '3337264', 'DOMINIC_MRI', 'Nate!', 'Calvin', 'Dominicmorning1', 'will', 'YY', 'Josh Lynch', 'test', 'CalvinLBS', 'test1', 'Tom', 'y', 'Dominic', 'paul', 'nicole', 'serguei'])
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split 3337264 3337264
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!
split Nate! Nate!

Questions to ask

  1. prob_same vs average response time
  2. correct/incorrect vs sequence (like a bar graph w/ probe trials), also as a function of predictability
  3. Look at prob correct on switch trials as a function of the length of the non-switch sequence preceeding it

In [327]:
## 1. prob_same vs average response time

data_switch = data[logical_and(data[:,6] > 0.5,  data[:,3]  < 2000),:]
data_same = data[logical_and(data[:,6]< 0.5 , data[:,3]  < 2000),:]
#
close()
fig1 = figure(1);
figsize(10,10)
#.1 means arrows mostly switch - most responses will be D
#.9 means arrows mostly stay the same - most responses will be S

for i,index in zip([1, 2,3 , 4],range(1,5)):
    subplot(220 + index)
    hist(data_switch[abs(data_switch[:,1] - i) < .01,3], 20, color='b',alpha=0.5)
    hist(data_same[abs(data_same[:,1] - i) < .01,3], 20, color='r',alpha=0.5)
    axis([0, 2000, 0, 350])
    title('response times ' + str(i))
    legend(['switch','same'])



In [328]:
#mean difference seems to grow as a function 
fig2 = figure(2)
means_switch = []
means_same = []

for i,index in zip([1, 2, 3, 4],range(1,5)):
    plot(i, mean(data_switch[abs(data_switch[:,1] - i) < .01,3]), 'bx',markersize=20)
    means_switch.append(mean(data_switch[abs(data_switch[:,1] - i) < .01,3]))
    plot(i,mean(data_same[abs(data_same[:,1] - i) < .01,3]),'rx',markersize=20)
    means_same.append(mean(data_same[abs(data_same[:,1] - i) < .01,3]))
    axis([0, 5, 0, 1000])
    xlabel('prob of D response')
    ylabel('mean response time')
    legend(['response change','response the same'])

t = [1, 2, 3, 4] #[.1 .25 .75 .9]
(ar,br)=polyfit(t,means_switch,1)
xr=polyval([ar,br],t)
plot(t,xr)

(ar,br)=polyfit(t,means_same,1)
xr=polyval([ar,br],t)
plot(t,xr,'r')


Out[328]:
[<matplotlib.lines.Line2D at 0x109b9b610>]

In [451]:
#2. correct/incorrect vs sequence (like a bar graph w/ probe trials), also as a function of predictability
#for each SWITCH response AFTER the 2nd and before the 2nd to last, look 2 back and 2 forward


def plot_data(subject, color):
    error_count_1 = array([0, 0, 0, 0, 0, 0, 0])
    error_count_2 = array([0, 0, 0, 0, 0, 0, 0])
    error_count_3 = array([0, 0, 0, 0, 0, 0, 0])
    error_count_4 = array([0, 0, 0, 0, 0, 0, 0])
    total_counts = array([0, 0, 0, 0])
    probs = [1, 2, 3, 4]
    columns=["trial_num", "prob_same", "response_num", 'time_taken', 'response', 'correct', 'switch', 'target']
    data = data_dict[subject]
    set_printoptions(threshold=nan)
    for prob_index, prob in zip(range(4), probs):
        data_prob = data[data[:,1] == prob,:] #dealing with JUST the data from this probability
        trials = list(set(data_prob[:,0]))
        trials.sort()
        #for each trial
        for trial in trials:
            data_trial = data_prob[ data_prob[:,0] == trial,:]
            num_responses = data_trial.shape[0]
            for response in range(4,num_responses-2):
                total_counts[prob_index] = total_counts[prob_index] + 1
                if data_trial[response,6] == 1: #if it's a switch trial
                    #print response, 'is a switch trial'
                    #print data_trial[response-2:response+3,:]
                    for offset in range(-4,3): #look in the surroundings
                        #count up the number of errors - 1-correct gives 1 for error, 0 for non-error
                        #print offset, offset+4, response+offset
                        if prob_index == 0:
                            error_count_1[offset+4] += 0 if data_trial[response + offset,5] else 1
                        elif prob_index == 1:
                            error_count_2[offset+4] += 0 if data_trial[response + offset,5] else 1
                        elif prob_index == 2:
                            error_count_3[offset+4] += 0 if data_trial[response + offset,5] else 1                          
                        elif prob_index == 3:
                            error_count_4[offset+4] += 0 if data_trial[response + offset,5] else 1
    counts = [error_count_1,error_count_2,error_count_3,error_count_4]

    figsize(8,8)
    for i in range(4):
        count_proportion = array(counts[i],dtype='float') / array(total_counts[i],dtype='float')
        print count_proportion
        subplot(220 + 1 +  i)
        bar(array([-4, -3, -2, -1, 0, 1, 2]) - .3,count_proportion,color=color,alpha=0.5)
        axis([-5, 4, 0, .05])

#     print total_counts
#     print error_count_1
#     print error_count_2
#     print error_count_3
#     print error_count_4
fig, ax = subplots();
plot_data('CalvinLBS','r')
plot_data('Calvin','b')
ax.set_title('hello')
title('red=low, blue=normal')
#TODO: make sure these are normalized by the NUMBER of trials and responses

#TODO: I feel like I'm probably one trial off, possibly in recording the data.  
#This would make LOTS of sense if -1 was the same as 0 here.  DOUBLE CHECK!!


[ 0.00616333  0.00770416  0.00231125  0.02157165  0.02234206  0.00385208
  0.0046225 ]
[ 0.01594802  0.01890136  0.01712936  0.02953337  0.03662138  0.01653869
  0.01712936]
[ 0.02580645  0.02177419  0.02016129  0.04032258  0.03951613  0.0233871
  0.01693548]
[ 0.01192843  0.0139165   0.01093439  0.03777336  0.0417495   0.01192843
  0.01093439]
[ 0.00468604  0.00281162  0.00468604  0.01218369  0.01499531  0.00749766
  0.00562324]
[ 0.01810585  0.01810585  0.01671309  0.02785515  0.03064067  0.01949861
  0.01532033]
[ 0.01112565  0.0117801   0.01112565  0.02552356  0.02683246  0.01308901
  0.0117801 ]
[ 0.0036855   0.00614251  0.00552826  0.02395577  0.02457002  0.01044226
  0.00675676]
Out[451]:
<matplotlib.text.Text at 0x118980ad0>

Observations

  • .9 is the same between the LBS and normal

  • error is lots higher in the moderate conditions

  • for some reason the .1 has higher SWITCH but not higher SAME errors - in general, switch errors seem to be higher in the .9 and .1 conditions. This is potentially really interesting!!


In [414]:
# 3. Look at prob correct on switch trials as a function of the length of the non-switch sequence preceeding it

#get counts
#this holds the counts, so we can look up by [probsame][previousNoSwitchTrials][correct]
class Multidict(dict):
    """Implementation of perl's autovivification feature."""

    def __getitem__(self, item):
        try:
            return dict.__getitem__(self, item)
        except KeyError:
            value = self[item] = type(self)()
            return value

counts = Multidict()

probs = [1, 2, 3, 4]
columns=["trial_num", "prob_same", "response_num", 'time_taken', 'response', 'correct', 'switch', 'target']

set_printoptions(threshold=nan)
trials = list(set(data[:,0]))
for trial in trials:
    data_trial = data[ data[:,0] == trial,:]
    num_responses = data_trial.shape[0]
    #print data_trial
    for response in range(10,num_responses):
        if data_trial[response,6] == 1: #if it's a switch trial
            #print response, 'is a switch trial'
            #look back up to 10 and count the number of subsequent NON-switches BEFORE this trial
            count = 0
            for lookback in range(-1,-11,-1):
                if data_trial[response + lookback,6] == 0: #if not switch
                    count += 1 #add to count
                else:
                    break #stop the for-loop, stop counting
            #add to database
            try:
                counts[data_trial[response,1]][count][data_trial[response,5]] = counts[data_trial[response,1]][count][data_trial[response,5]] +  1
            except: 
                counts[data_trial[response,1]][count][data_trial[response,5]] = 1

In [420]:
fig = figure();

colors = ['r','g','b','k']
#for each prob
for p in probs:
    #for each distance back
    counts_correct = zeros((1,10))
    counts_incorrect = zeros((1,10))
    for dist in range(1,11):
        #print counts_correct
    #get the correct and incorrect counts
        #print p, dist, counts[p][dist][1], counts[p][dist][0]
        #try adding to everything less than it
        for i in range(1,dist+1):
#             counts_correct[0][dist-1] += counts[p][dist][1]
#             counts_incorrect[0][dist-1] += counts[p][dist][0]
            if counts[p][i-1][0] == {}: counts[p][i-1][0] = 0
            if counts[p][i-1][1] == {}: counts[p][i-1][1] = 0    

            counts_correct[0][i-1] += counts[p][i][1]
            counts_incorrect[0][i-1] += counts[p][i][0]


    print counts_correct, counts_incorrect
    prob_error = counts_incorrect / (counts_incorrect + counts_correct)
    
    plot(range(1,11),prob_error[0],colors[p-1])
    print (counts_incorrect + counts_correct)
xlabel('number of non-switches proceeding the switch')
ylabel('proportion of errors')

#TODO question: why are there so many with 10-back but not with 9-back? 10-back is "10 or more" 
#TODO I don't think it's counting the incorrect properly - they should be strictly decreasing.


[[ 120.   72.   48.   21.    6.   20.   28.    6.    6.   32.]] [[ 10.   9.  16.  14.   0.   5.   0.   0.   2.   3.]]
[[ 130.   81.   64.   35.    6.   25.   28.    6.    8.   35.]]
[[ 610.  180.   72.   28.   48.   15.   24.   12.    0.    8.]] [[ 30.   9.  32.   7.   0.   5.   0.   0.   2.   0.]]
[[ 640.  189.  104.   35.   48.   20.   24.   12.    2.    8.]]
[[ 770.  414.  312.  147.   48.   60.   32.   21.    6.   12.]] [[ 40.  36.  32.   0.  36.  30.   4.   3.   0.   3.]]
[[ 810.  450.  344.  147.   84.   90.   36.   24.    6.   15.]]
[[ 180.   72.   48.   28.   60.   10.   12.    9.    4.   42.]] [[ 30.  18.  16.  21.  12.   5.   4.   0.   0.  13.]]
[[ 210.   90.   64.   49.   72.   15.   16.    9.    4.   55.]]
Out[420]:
<matplotlib.text.Text at 0x10ab9e590>

Comments

  • Red and Black are the extrme cases - they are more similar than green and blue.

  • 1-4 seems to be a strong trend. I wonder what happens with 5. In any case, PART of it seems linear.

  • TODO: need to fix the counting - for some reason the incorrect is not monotonically decreasing where it SHOULD be


In [ ]: