Perform AnkleSLIP continuous predictions

last edits:

  • Feb. 7th, 2014 MM: forked from AnkleSLIP main notebook
  • Mar. 12, 2014 MM: changed "factor" computation to out-of-sample factor computation (not yet ready - bug present)
  • May 22, 2014 MM: changed figure format
  • Jul 22, 2014 MM: changed figure annotations

Step 1: Init notebook

to content


In [2]:
# change to the notebook directory: first to the home directory, then into mmnotebooks
%cd
%cd mmnotebooks


/home/moritz
/qnap/pylibs_users/mm/mmnotebooks

In [3]:
# --- run required cells automatically? :) (notebook must be stored for this!)
import mutils.io as mio
import mutils.misc as mi

ws5subject = 3

if not 'ws5' in locals():
    ws5 = mio.saveable()
    
ws5.store_figs = True

nbfile = 'AnkleSLIP - find minimal model- Version 3.ipynb'

# load basic config
mi.run_nbcells(nbfile, ['0'])
conf.subject = ws5subject
conf.quiet = True
conf.startup_n_full_maps = 2
conf.startup_compute_PCA = False

mi.run_nbcells(nbfile, 
               ['0.1','1', '3', '3.1', '3.2',]) #'4', '4.2a', '4.2b', '4.2c'])

print "\ndone"


_saveables                 list (10)        list of types that can be stored
cslip_forceZeroRef         bool  True       reference values for controlled SLIP maps must be zero or not
dt_medfilter               bool  False      
dt_window                  int  30          
exclude_IC_from_factors    bool  False      
n_factors                  int  5           how many (optimal) factors to select of the full kinematic state
normalize_m                bool  True       
po_average_over_IC         bool  True       average over IC's and T,ymin (alt: parameters) for reference SLIP
quiet                      bool  False      
select_ankle_SLIP          bool  True       
startup_compute_PCA        bool  False      
startup_compute_full_maps  bool  True       
startup_n_full_maps        int  20          
subject                    int  2           
ttype                      int  1           

done

Step 2: create forward models

content


In [4]:
# create continuous model
# exclude first apex value because "augmented SLIP" is not defined for this

# these files contain the content of the corresponding cells.
# set this to "True" to re-run everything




ws5.skip_mdls = [4, 5, 7] # models *not* to recompute! (1-6; can be empty)
ws5.subj = conf.subject
ws5.out_of_sample = True
ws5.out_of_sample_doc = 'perform out of sample prediction (for CoM)'
ws5.nboot = 100
ws5.nboot_doc = 'number of bootstrap repetitions for predictions'
ws5.use_cached_fullmdl = False # use cached file or not

ws5.nps = 50

print "building models ..."
ws1.k.selection = ['com_x', 'com_y', 'com_z']
odat_c = ws1.k.make_1D(nps=ws5.nps, phases_list=ws1.dataset_full.all_phases_r)[:, 2*ws5.nps:]
ws5.odat = fda.dt_movingavg(odat_c, conf.dt_window, conf.dt_medfilter)[1:, :]
ws5.odat_doc = 'com z, vx, vy, vz with ' + str(ws5.nps) + ' nps'

pr = mod(hstack(ws1.dataset_full.all_phases_r), 2*pi)
pr = fda.dt_movingavg(pr[:, newaxis], conf.dt_window, conf.dt_medfilter)


# create floquet models: the input data sets of the prediction models.
# All of them drop the first step ([1:, :]) so that their rows match ws5.odat.
if 1 not in ws5.skip_mdls:
    ws5.idat1 = ws1.dataset_full.n_kin_r[1:, :]
    ws5.idat1_doc = 'model 1: full dataset at (right) apex'


if 2 not in ws5.skip_mdls:
    # NOTE(review): the model-2 prediction recomputes the factors out-of-sample
    # from ws1.dataset.s_kin_r and does not read ws5.idat2.
    ws5.idat2 = ws1.reddat_r[1:, :]
    addstr = 'including IC' if not conf.exclude_IC_from_factors else 'excluding IC'
    ws5.idat2_doc = 'model 2: CoM + factors ' + addstr

if 3 not in ws5.skip_mdls:
    # NOTE: changes the shared kinematic selection of ws1.k
    # TODO: add hip rotation indicator 'r_sia_y - l_sia_y'?
    ws1.k.selection = ['com_x', 'com_y', 'com_z', 'r_anl_y - com_y',
                       'r_anl_z - com_z', 'r_anl_x - com_x',]
    tmp_dsr = build_dataset(ws1.k, ws1.SlipData, dt_window=conf.dt_window, dt_median=conf.dt_medfilter)
    ws5.idat3 = fda.dt_movingavg(tmp_dsr.all_kin_r, conf.dt_window, conf.dt_medfilter)[1:, :]
    ws5.idat3_doc = 'model 3: CoM + ankles '

if 4 not in ws5.skip_mdls:
    # NOTE(review): `dpr` (phase velocity) is not defined in this notebook; it
    # presumably comes from the cells executed via mi.run_nbcells -- confirm.
    # NOTE(review): `pr` was already detrended above and is detrended again here.
    ws5.idat4 = fda.dt_movingavg(hstack([ws1.dataset_full.all_IC_r[1:, :],
                                         pr[1:, :], # phase mod 2pi at apex
                                         hstack(dpr)[1:, newaxis], # phase velocity at apex
                                         ]), conf.dt_window, conf.dt_medfilter)
    ws5.idat4_doc = 'model 4: CoM + phi + vphi'

if 5 not in ws5.skip_mdls:
    # SLIP parameters of the preceding (left) step: rows [:-1] pair with IC rows [1:]
    ws5.idat5 = fda.dt_movingavg(hstack([ws1.dataset_full.all_IC_r[1:, :], ws1.dataset_full.s_param_l[:-1, :],
                                         pr[1:, :]]), conf.dt_window, conf.dt_medfilter)
    ws5.idat5_doc = 'model 5: augmented SLIP (exact apex state + phase info)'


if 6 not in ws5.skip_mdls:
    ws5.idat6 = fda.dt_movingavg(hstack([ws1.dataset_full.all_IC_r[1:, :], ws1.dataset_full.s_param_l[:-1, :]]),
                                 conf.dt_window, conf.dt_medfilter)
    ws5.idat6_doc = 'model 6: augmented SLIP'

if 7 not in ws5.skip_mdls:
    # like model 3, plus the 'cvii_y - sacr_y' marker difference
    # NOTE: changes the shared kinematic selection of ws1.k
    ws1.k.selection = ['com_x', 'com_y', 'com_z', 'r_anl_y - com_y',
                       'r_anl_z - com_z', 'r_anl_x - com_x', 'cvii_y - sacr_y']
    tmp_dsr = build_dataset(ws1.k, ws1.SlipData, dt_window=conf.dt_window, dt_median=conf.dt_medfilter)
    ws5.idat7 = fda.dt_movingavg(tmp_dsr.all_kin_r, conf.dt_window, conf.dt_medfilter)[1:, :]
    ws5.idat7_doc = 'model 7: CoM + ankles + cvii_y - sacr_y'
    

print "performing prediction tests"
print "*" * ws5.nps

print "model 1:"
if 1 in ws5.skip_mdls:
    print "skipped"
else:
    res = None
    if ws5.use_cached_fullmdl:
        try:
            res = mio.mload('tmp/vred_s{}fm.list'.format(conf.subject))
            print "loaded mdl from file (median detrended!)"
        except IOError:
            pass
        
    if res == None: # not loaded; either not found or cache disabled
        res = []
        for rep in range(ws5.nps):
            res.append(vstack(st.predTest(ws5.idat1, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
                                          nboot=ws5.nboot)))
            sys.stdout.write('.')
        print " done"
        mio.msave('tmp/vred_s{}fm.list'.format(conf.subject), res)
        
    ws5.pred_mdl1 = res
    ws5.pred_mdl1_doc = ''.join(['list of rem. var. of CoM using mdl1 for prediction; for each of ',
                                 str(ws5.nps), ' phases of gait cyc.'])

print "model 2:"
if 2 in ws5.skip_mdls:
    print "skipped"
else:
    # this code was developed and tested in another cell, that's why it looks strange at first
    idat = ws1.dataset.s_kin_r[1:, :]
    odat = ws5.odat
    
    arrv = []
    cnt = 0
    for sct in range(ws5.nps): # section to predict
        rrv = []
        cnt += 1
        print 'sec: ', cnt, '/', ws5.nps, 
        for rep in range(ws5.nboot):
            sys.stdout.write('.')
            l_odat = odat[:, sct::ws5.nps]
            bs_idx = randint(0, idat.shape[0], idat.shape[0])
            o_idx = fda.otheridx(bs_idx, idat.shape[0])
            
            
            facs_r = st.find_factors(idat[bs_idx, :].T, ws1.dataset.s_param_r[bs_idx + 1, :].T, k=conf.n_factors)
            fscore_r = dot(facs_r.T, idat.T).T
            
            facmdl = hstack([fscore_r, ws1.dataset_full.all_IC_rc[1:, :]])
            facmdl = fda.dt_movingavg(facmdl, conf.dt_window, conf.dt_medfilter)
    
            
            A = dot(l_odat[bs_idx, :].T, pinv(facmdl[bs_idx, :].T,rcond=1e-8))
           
            pred = dot(A, facmdl[o_idx, :].T).T
            
            rrv.append(var(l_odat[o_idx, :] - pred, axis=0) / var(l_odat[o_idx, :], axis=0))
        
        arrv.append(vstack(rrv))
        print '\n',
        
    print '\ndone'
    
    
    # old (original) code
    #res = []
    #for rep in range(ws5.nps):
   # 
   #     res.append(vstack(st.predTest(ws5.idat2, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
   #                                   nboot=ws5.nboot)))
   #     sys.stdout.write('.')
    #print " done"
    #ws5.pred_mdl2 = res
    ws5.pred_mdl2 = arrv
    ws5.pred_mdl2_doc = ''.join(['list of rem. var. of CoM using mdl2 for prediction; for each of ',
                                 str(ws5.nps), ' phases of gait cyc.'])

    

print "model 3:"
if 3 in ws5.skip_mdls:
    print "skipped"
else:
    res = []
    for rep in range(ws5.nps):
        res.append(vstack(st.predTest(ws5.idat3, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
                                      nboot=ws5.nboot)))
        sys.stdout.write('.')
    print " done"
    ws5.pred_mdl3 = res
    ws5.pred_mdl3_doc = ''.join(['list of rem. var. of CoM using mdl3 for prediction; for each of ',
                                 str(ws5.nps), ' phases of gait cyc.'])

print "model 4:"
if 4 in ws5.skip_mdls:
    print "skipped"
else:
    res = []
    for rep in range(ws5.nps):
        res.append(vstack(st.predTest(ws5.idat4, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
                                      nboot=ws5.nboot)))
        sys.stdout.write('.')
    print " done"
    ws5.pred_mdl4 = res
    ws5.pred_mdl4_doc = ''.join(['list of rem. var. of CoM using mdl4 for prediction; for each of ',
                                  str(ws5.nps),' phases of gait cyc.' ])

print "model 5:"
if 5 in ws5.skip_mdls:
    print "skipped"
else:
    res = []
    for rep in range(ws5.nps):
        res.append(vstack(st.predTest(ws5.idat5, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
                                      nboot=ws5.nboot)))
        sys.stdout.write('.')
    print " done"
    ws5.pred_mdl5 = res
    ws5.pred_mdl5_doc = ''.join(['list of rem. var. of CoM using mdl5 for prediction; for each of ',
                                 str(ws5.nps), ' phases of gait cyc.'])

print "model 6:"
if 6 in ws5.skip_mdls:
    print "skipped"
else:
    res = []
    for rep in range(ws5.nps):
        res.append(vstack(st.predTest(ws5.idat6, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
                                      nboot=ws5.nboot)))
        sys.stdout.write('.')
    print " done"
    ws5.pred_mdl6 = res
    ws5.pred_mdl6_doc = ''.join(['list of rem. var. of CoM using mdl5 for prediction; for each of ' ,
                                 str(ws5.nps), ' phases of gait cyc.'])


print "model 7:"
if 7 in ws5.skip_mdls:
    print "skipped"
else:
    res = []
    for rep in range(ws5.nps):
        res.append(vstack(st.predTest(ws5.idat7, ws5.odat[:,rep::ws5.nps], out_of_sample=ws5.out_of_sample,
                                      nboot=ws5.nboot)))
        sys.stdout.write('.')
    print " done"
    ws5.pred_mdl7 = res
    ws5.pred_mdl7_doc = ''.join(['list of rem. var. of CoM using mdl7 for prediction; for each of ',
                                  str(ws5.nps),' phases of gait cyc.' ])


building models ...
performing prediction tests
**************************************************
model 1:
.................................................. done
model 2:
sec:  1 / 50.................................................................................................... 
sec:  2 / 50.................................................................................................... 
sec:  3 / 50.................................................................................................... 
sec:  4 / 50.................................................................................................... 
sec:  5 / 50.................................................................................................... 
sec:  6 / 50.................................................................................................... 
sec:  7 / 50.................................................................................................... 
sec:  8 / 50.................................................................................................... 
sec:  9 / 50.................................................................................................... 
sec:  10 / 50.................................................................................................... 
sec:  11 / 50.................................................................................................... 
sec:  12 / 50.................................................................................................... 
sec:  13 / 50.................................................................................................... 
sec:  14 / 50.................................................................................................... 
sec:  15 / 50.................................................................................................... 
sec:  16 / 50.................................................................................................... 
sec:  17 / 50.................................................................................................... 
sec:  18 / 50.................................................................................................... 
sec:  19 / 50.................................................................................................... 
sec:  20 / 50.................................................................................................... 
sec:  21 / 50.................................................................................................... 
sec:  22 / 50.................................................................................................... 
sec:  23 / 50.................................................................................................... 
sec:  24 / 50.................................................................................................... 
sec:  25 / 50.................................................................................................... 
sec:  26 / 50.................................................................................................... 
sec:  27 / 50.................................................................................................... 
sec:  28 / 50.................................................................................................... 
sec:  29 / 50.................................................................................................... 
sec:  30 / 50.................................................................................................... 
sec:  31 / 50.................................................................................................... 
sec:  32 / 50.................................................................................................... 
sec:  33 / 50.................................................................................................... 
sec:  34 / 50.................................................................................................... 
sec:  35 / 50.................................................................................................... 
sec:  36 / 50.................................................................................................... 
sec:  37 / 50.................................................................................................... 
sec:  38 / 50.................................................................................................... 
sec:  39 / 50.................................................................................................... 
sec:  40 / 50.................................................................................................... 
sec:  41 / 50.................................................................................................... 
sec:  42 / 50.................................................................................................... 
sec:  43 / 50.................................................................................................... 
sec:  44 / 50.................................................................................................... 
sec:  45 / 50.................................................................................................... 
sec:  46 / 50.................................................................................................... 
sec:  47 / 50.................................................................................................... 
sec:  48 / 50.................................................................................................... 
sec:  49 / 50.................................................................................................... 
sec:  50 / 50.................................................................................................... 

done
model 3:
.................................................. done
model 4:
skipped
model 5:
skipped
model 6:
.................................................. done
model 7:
skipped

Step 3: visualize

content


In [5]:
# --- visualize

use_steps_as_xticks = True  # x-axis labels: steps ahead (True) or gait phase in rad (False)

def plotband(x, y, ups, downs, color='#0067ea', alpha=.25, **kwargs):
    """
    plots a line surrounded by a ribbon of the same color, but semi-transparent
    
    :args:
        x (N-by-1): positions on horizontal axis
        y (N-by-1): corresponding vertical values of the (center) line
        ups (N-by-1): upper edge of the ribbon
        downs (N-by-1): lower edge of the ribbon
        color: any matplotlib color specification (name, hex string, RGB(A) tuple)
        alpha: transparency of the ribbon
        **kwargs: passed on to "plot" (i.e. they only affect the center line)
        
    :returns:
        [line, patch] returns of underlying "plot" and "fill_between" function
    """
    # pass the color as keyword: as a positional "fmt" argument only strings
    # work -- a non-string color (e.g. an RGB tuple) would be parsed as data
    pt1 = plot(x, y, color=color, **kwargs)
    pt2 = fill_between(x, ups, downs, color='None', facecolor=color, lw=0, alpha=alpha)
    return [pt1, pt2]




import matplotlib.font_manager as fmt
FM = fmt.FontManager()
#fnt = FM.findfont('Times New Roman') # for ProcRSoc B
fnt = FM.findfont('Arial') # for NComm
fnt_s = FM.findfont('Symbol') # for NComm

fp = fmt.FontProperties(fname=fnt)
fp.set_size(9.0)

fp_s = fmt.FontProperties(fname=fnt_s)
fp_s.set_size(9.0)


# make font in math the same as in text
import matplotlib as mpl
mpl.rcParams['mathtext.default'] = 'regular' # this is the key

import mutils.plotting as mplt
colors_ = mplt.colorset_distinct

means = []
stds = []

for mdl in arange(7) + 1:
    if mdl not in ws5.skip_mdls:
        code = 'vstack([mean(x, axis=0) for x in ws5.pred_mdl{}])'.format(mdl)
        means.append(eval(code))
        code = 'vstack([std(x, axis=0) for x in ws5.pred_mdl{}])'.format(mdl)
        stds.append(eval(code))
        
#fig = figure(figsize=(6.83,2)) # ProcRSoc B
#fig = figure(figsize=(18./2.56, 1.5)) # NComm
fig = figure(figsize=(18./2.56, 1.25)) # NComm


phase = linspace(0,2*pi, 50, endpoint=False)

titles = ['CoM height'.format(conf.subject),
          'CoM lat. velocity'.format(conf.subject),
          'CoM horiz. velocity'.format(conf.subject),]

fmt_m = {'linestyle' : '-', 'marker' : '', 'linewidth' : 1.5}
fmt_s = {'linestyle' : '--', 'linewidth' : .75}
mdl_lbl = ['full state', 'Factor-SLIP state', 'Ankle-SLIP state', 'Augmented-SLIP state' ] #'ankle* SLIP']

all_ax = []

for dim in range(3):
    all_ax.append(subplot(1, 3, dim+1))
    for nr, (m, s) in enumerate(zip(means, stds)):
        plotband(phase, m[:, dim], m[:, dim] + s[:, dim],  m[:, dim] - s[:, dim], color=colors_[nr])
        #plot(phase, m[:, dim], color=colors_[nr], **fmt_m)
        #plot(phase, m[:, dim] + s[:, dim], color=colors_[nr], **fmt_s)
        #plot(phase, m[:, dim] + s[:, dim], styles_[nr])
     
    #title(titles[dim], fontproperties=fp)
    gca().set_xticks([0, 3.14, 6.28])
    if use_steps_as_xticks:
        gca().set_xticklabels(['0', '1', '2'], fontproperties=fp)
    else:
        gca().set_xticklabels(['0', r'$\pi$', r'$2 \pi$'], fontproperties=fp_s)
    gca().set_yticks(arange(6) * .2)

    
    
    ylim(0, 1.1)
    plot([0,6.28], [1, 1], 'k--')
    plot([pi, pi],[0, 1], 'k--')
    grid(True, ls='-', lw=.1)
    if dim == 1:
        #xlabel('phase [rad]', fontproperties=fp)
        pass
    if dim == 0:
        if use_steps_as_xticks:
            gca().set_yticklabels(arange(6) * .2, fontproperties=fp)
            gca().text(pi/2, -0.2, 'steps ahead', fontproperties=fp, ha='center', va='bottom')
            gca().arrow(pi/2-1, -0.25, 2, 0, clip_on=False ) # width=
        else:
            gca().text(pi/2, -0.2, 'phase [rad]', fontproperties=fp, ha='center', va='bottom')
            gca().set_yticklabels(arange(6) * .2, fontproperties=fp_s)
            gca().arrow(pi/2-1, -0.25, 2, 0, clip_on=False ) # width=        
        ylabel('rrv', fontproperties=fp)       
                
    else:
        gca().set_yticklabels([])

lax = axes([.025,.92,.95, .065], frameon=False)

#lax = axes([.025,.86,.95, .125], frameon=False)
#lax.text(2.32, 0.45, 'prediction quality for different affine models'.format(conf.subject),
#         va='center', ha='left', fontproperties=fp)
xpos = array([2.35, 3.1, 4.35, 5.55, 5.95]) - 1.5
for mdl in range(4):
    lax.plot([xpos[mdl], xpos[mdl] + .15], [-1, -1], '-', color=colors_[mdl], **fmt_m)
    lax.text(xpos[mdl] + .2 , -1, mdl_lbl[mdl], va='center', fontproperties=fp) # va: equivalent to verticalalignment
lax.set_ylim([-1,0])
lax.set_xlim(.5, 5.5)

lax.set_xticks([])
lax.set_yticks([])

for ax, lbl in zip(all_ax, ['CoM height', 'CoM lat. velocity', 'CoM horiz. velocity']):
    xl = ax.get_xlim()
    yl = ax.get_ylim()
    x_ = xl[0] + .55
    y_ = yl[0] + .9*(yl[1] - yl[0]) - 0.025
    ax.text(x_, y_, lbl, fontproperties=fp, bbox=dict(facecolor='#ffffff', edgecolor='k', pad=3.0, lw=.1), va='center')


all_ax[1].arrow(.3,0.95,0,-0.8, head_width=0.15, head_length=0.1, lw=1, )
all_ax[1].text(.8,0.45, 'better\n prediction', ha='center', va='center', rotation=90, fontproperties=fp,
               bbox=dict(facecolor='#ffffff', edgecolor='None', pad=0, lw=.1))
    
subplots_adjust(left=.075, right=.975, wspace=.08,bottom=.2, top=.85)
#subplots_adjust(left=.075, right=.975, wspace=.08,bottom=.2, top=.775)

if ws5.store_figs:
    savefig('img/fig_fmdl2_pred_subj{}.pdf'.format(conf.subject))
    print 'stored as: img/fig_fmdl2_pred_subj{}.pdf'.format(conf.subject)
#    savefig('img/fig_fmdl_pred_subj{}.svg'.format(conf.subject))

pass # suppress output of last function call


stored as: img/fig_fmdl2_pred_subj3.pdf
/usr/lib/pymodules/python2.7/matplotlib/font_manager.py:1236: UserWarning: findfont: Font family ['Symbol'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))
/usr/lib/pymodules/python2.7/matplotlib/font_manager.py:1246: UserWarning: findfont: Could not match :family=Bitstream Vera Sans:style=normal:variant=normal:weight=normal:stretch=normal:size=10. Returning /usr/share/matplotlib/mpl-data/fonts/ttf/cmb10.ttf
  UserWarning)

In [5]:



filename: img/fig_fmdl2_pred_subj7.pdf

In [ ]: