S1 - Sample-based implementation of Blahut-Arimoto iteration

This notebook is part of the supplementary material for:
Genewein T., Leibfried F., Grau-Moya J., Braun D.A. (2015) Bounded rationality, abstraction and hierarchical decision-making: an information-theoretic optimality principle, Frontiers in Robotics and AI.

More information on how to run the notebook can be found in the accompanying github repository, where you can also find updated versions of the code and notebooks.

This notebook is mentioned in Section 2.3. Due to time and space limitations, the results of this notebook are not in the paper.

Disclaimer

This notebook provides a proof-of-concept implementation of a naive sample-based Blahut-Arimoto iteration scheme. Neither the code nor the notebook has been particularly polished or tested. There is a short theory section at the beginning of the notebook, but most of the explanations are brief and mixed into the code as comments.

Free energy rejection sampling

The solution to a free energy variational problem (Section 2.1 in the paper) has the form of a Boltzmann distribution $$p(y) = \frac{1}{Z}p_0(y)e^{\beta U(y)},$$ where $Z=\sum_y p_0(y)e^{\beta U(y)}$ denotes the partition sum, $p_0(y)$ is a prior distribution and $U(y)$ is the utility function. The inverse temperature $\beta$ can be interpreted as a resource parameter and it governs how far the posterior $p(y)$ can deviate from the prior (measured as a KL-divergence) - see the paper Section 2.1 for details.
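
As a concrete illustration of the formula above (using made-up numbers, not one of the paper's examples), the Boltzmann posterior can be computed directly for a small discrete set of options:


In [ ]:
#toy example with made-up values: compute p(y) = (1/Z) p_0(y) exp(β U(y))
p0 = [0.25, 0.25, 0.25, 0.25]   #uniform prior p_0(y) over four options
U  = [1.0, 2.0, 0.5, 1.5]       #utility U(y) of each option
β  = 2.0                        #inverse temperature (resource parameter)

unnorm = [p0[i]*exp(β*U[i]) for i in 1:length(p0)]  #p_0(y) exp(β U(y))
Z = sum(unnorm)                                     #partition sum
p = unnorm/Z                                        #Boltzmann posterior p(y)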

For a decision-maker, it suffices to obtain a sample from $p(y)$ and act according to that sample, rather than computing the full distribution $p(y)$. A simple scheme to sample from $p(y)$ is given by rejection sampling.

Rejection sampling
Goal: get a sample $y$ from the distribution $f(y)$. Draw from a uniform distribution $u\sim \mathcal{U}(0,1)$ and from a proposal distribution $y\sim g(y)$. If $u < \frac{f(y)}{M g(y)}$, accept the sample as a sample from $f(y)$, otherwise reject the sample and repeat. $M$ is a constant that ensures that $M g(y) \geq f(y)~\forall y$. Note that rejection sampling also works for sampling from an unnormalized distribution as long as $M$ is chosen accordingly.

For the free-energy problem, we want a sample from $p(y)\propto f(y) = p_0(y)e^{\beta U(y)}$. We choose $g(y)=p_0(y)$ and set $M=e^{\beta U_{max}}$, where $U_{max}=\max_y U(y)$.

Finally, we get the following rejection sampling scheme (a minimal code sketch follows the list):

  • draw from a uniform distribution $u\sim \mathcal{U}(0,1)$
  • draw from the proposal distribution $y\sim p_0(y)$ (the prior)
    • if $u < \frac{\exp(\beta U(y))}{\exp(\beta U_\mathrm{max})}$ accept the sample as a sample from the posterior $p(y)$
    • otherwise reject the sample (and re-sample).
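
A minimal sketch of this scheme, assuming a discrete prior given as a probability vector (the batched implementation actually used in this notebook, rej_samp_const, is defined further below):


In [ ]:
using Distributions

#minimal sketch: draw a single sample from p(y) ∝ p_0(y) exp(β U(y)) by rejection sampling
#p0 is the prior as a probability vector, U the utility vector, β the inverse temperature
function free_energy_sample(p0::Vector, U::Vector, β::Number)
    Umax = maximum(U)
    prior = Categorical(p0)                 #proposal distribution g(y) = p_0(y)
    while true
        y = rand(prior)                     #y ~ p_0(y)
        u = rand()                          #u ~ U(0,1)
        if u < exp(β*U[y])/exp(β*Umax)      #acceptance test with M = exp(β U_max)
            return y                        #accepted sample from the posterior p(y)
        end
    end
end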

Rate distortion rejection sampling

The solution to the rate distortion problem looks very similar to the Boltzmann distribution in the free-energy case. However, there is one crucial difference: in the free-energy case, the prior is an arbitrary distribution - in the rate distortion case, the prior is replaced by the marginal distribution, which leads to a set of self-consistent equations $$\begin{align} p^*(a|w)&=\frac{1}{Z(w)}p(a)e^{\beta U(a,w)} \\ p(a)&=\sum_w p(w)p(a|w) \end{align}$$ After convergence of the Blahut-Arimoto iterations, the marginal $p(a)$ can just be treated like a prior and the rejection sampling scheme described above can straightforwardly be used. However, when initializing with an arbitrary marginal distribution $\hat{p}(a)$, the iterations must be performed in a sample-based manner until convergence.
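
For reference, in the discrete case these self-consistent equations can also be solved exactly by alternating the two updates until convergence; this is what the BAiterations routine from RateDistortionDecisionMaking.jl (used further below for the analytical results) does. A condensed, hypothetical re-implementation of that fixed-point iteration might look as follows:


In [ ]:
#condensed sketch of the exact (tabular) Blahut-Arimoto fixed-point iteration
#U_pre[a,w] are precomputed utilities, pw is p(w), pa the initial marginal, niter the number of iterations
function BA_exact_sketch(pa::Vector, pw::Vector, U_pre::Matrix, β::Number, niter::Integer)
    card_a, card_w = size(U_pre)
    pagw = zeros(card_a, card_w)
    for n in 1:niter
        #p*(a|w) ∝ p(a) exp(β U(a,w)), normalized separately for each world state w
        for w in 1:card_w
            unnorm = [pa[a]*exp(β*U_pre[a,w]) for a in 1:card_a]
            pagw[:,w] = unnorm/sum(unnorm)
        end
        #p(a) = Σ_w p(w) p(a|w)
        pa = pagw*pw
    end
    return pagw, pa
end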

Here, we perform the sample-based iterations in a naive and very straightforward way: we represent $\hat{p}(a)$ simply through counters (a categorical distribution). Then we do the following:

  1. Draw a number of samples (a batch) from $\hat{p}^*(a|w)=\frac{1}{Z(w)}\hat{p}(a)e^{\beta U(a,w)}$ using the rejection sampling scheme.
  2. Update $\hat{p}(a)$ with the accepted samples obtained in step 1. There are different possibilities for the update step
    1. Simply increase the counters for each accepted $a$ and re-normalize (no-forgetting)
    2. Reset the counters for $a$ and use only the last batch of accepted samples to empirically estimate $p(a)$ (full-forgetting)
    3. Use an exponentially decaying window over the last samples to update the empirical estimate of $p(a)$ (not implemented in this notebook; a possible sketch is given after this list).
    4. Use a parametric model $p_\theta(a)$ and then perform some moment-matching or use a gradient-based update rule to adjust the parameters $\theta$ (not implemented in this notebook).
  3. Repeat until convergence (or here for simplicity: for a fixed number of steps)
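
Option 3 (exponential forgetting) is not implemented in this notebook; a possible sketch, following the same interface as the counter-based update function defined below, could look like this (the decay rate λ is a made-up default):


In [ ]:
#possible sketch of an exponentially decaying window over past samples (option 3 above, not used below)
#λ=1 recovers no-forgetting, λ=0 full-forgetting; in practice a small floor on the counters
#may be needed to keep every entry of the proposal distribution non-zero
function update_marginal_ctrs_expdecay(sampled_indices::Vector, marginal_ctrs::Vector; λ::Real=0.9)
    p_ctrs = λ*marginal_ctrs                #decay the old counts
    for idx in sampled_indices
        p_ctrs[round(Int,idx)] += 1         #add the new batch of accepted samples
    end
    p_sampled = p_ctrs/sum(p_ctrs)          #normalize to get the updated marginal
    return p_sampled, p_ctrs
end

(Such a function could be plugged into BAsampling below via its update_marg_func keyword argument.)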

Additionally, this notebook allows for some burn-in time: after a certain number of iterations of 1. and 2. (i.e. after the "burn-in"), the counters for the conditional $\hat{p}(a|w)$ are reset, so that the final conditional estimate is based only on samples drawn after the marginal has had time to converge. This naive scheme seems to work, but it is unclear how to choose the batch-size (the number of samples from $\hat{p}^*(a|w)$ to draw before performing an update step on $\hat{p}(a)$), how to set the burn-in phase, etc.

In the notebook below, you can try different batch-sizes and different burn-in times and you can compare full-forgetting against no forgetting (i.e. no resetting of the counters).


In [ ]:
#only run this once
include("RateDistortionDecisionMaking.jl")

Load the taxonomy example as a testbed


In [ ]:
#load taxonomy example
using RateDistortionDecisionMaking, DataFrames, Gadfly, Distributions

#set up taxonomy example
include("TaxonomyExample.jl")
w_vec, w_strings, a_vec, a_strings, p_w, U = setuptaxonomy()

#pre-compute utilities, find maxima
U_pre, Umax = setuputilityarrays(a_vec,w_vec,U)

#initialize p(a) uniformly
num_acts = length(a_vec)
pa_init = ones(num_acts)/num_acts;

Set up functions for sampling and run on the example from above


In [ ]:
#Performs rejection sampling with a constant (scaled uniform) envelope
#using a softmax acceptance-rejection criterion.
#prop_dist .. proposal distribution (must be an instance of type ::Distribution)
#nsamps ..... desired number of samples (scalar)
#maxsteps ... maximum number of acceptance-rejection steps (scalar, must be ≧ nsamps)
#β .......... softmax parameter
#lh ......... likelihood value (vector of length N)
#maxlh ...... maximum value that the likelihood can take (scalar)
function rej_samp_const(prop_dist::Distribution, nsamps::Integer, maxsteps::Integer, β::Number, lh::Vector, maxlh::Number)    
    #initialize
    samps = zeros(nsamps)
    acc_cnt = 0  #acceptance-counter
    if(maxsteps < nsamps)
        maxsteps = nsamps
    end
    
    k=0 #use this to make sure that k is still available after the loop
    for k in 1:maxsteps
        u=rand(1) #sample from uniform between (0,1)
        index = rand(prop_dist) #sample from proposal
        
        ratio = exp(β*lh[index])/exp(β*maxlh)
        if u[1]<ratio #explicit indexing is needed to get a scalar, since the comparison cannot handle arrays
            #if we enter here, accept the sample                       
            acc_cnt = acc_cnt + 1     
            samps[acc_cnt] = index
            
            if(acc_cnt == nsamps)
                #we have enough samples, exit loop
                break
            end
        end
    end
    
    if(k==maxsteps)
        warn("[RejSampConst] Maximum number of steps reached - number of samples is potentially lower than nsamps!\n")
    end
    
    #store all accepted samples (this can be less than nsamps if maxsteps is too low or acceptance-rate is low)
    samples = samps[1:acc_cnt]
    
    #compute acceptance ratio
    acc_ratio = acc_cnt/k
    
    return samples, acc_ratio
end
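
As an example (assuming the setup cells above have been run), a small batch of conditional samples for the first world state could be drawn like this; the call mirrors how BAsampling below uses the function, with β=1.2 as an arbitrary choice:


In [ ]:
#example call of rej_samp_const (β=1.2 is an arbitrary choice)
samps, acc_ratio = rej_samp_const(Categorical(pa_init), 10, 1000, 1.2, U_pre[:,1], Umax[1])
println("accepted samples: ", samps, "  acceptance ratio: ", acc_ratio)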

In [ ]:
#marginal is simply represented by counters (i.e. by frequencies)
function init_marginal_representation_ctrs(pa_init::Vector)
    return pa_init
end

#this updates the marginal over actions p(a) using a counter-representation
#this function counts the number of times each action-index occurs in sampled_indices
#these counts are then added to the current marginal_ctrs. Optionally, the counters are reset
#before adding the new samples (=hard forgetting).
function update_marginal_ctrs(sampled_indices::Vector, marginal_ctrs::Vector; reset_ctrs::Bool=false)    
    #TODO: perhaps replace hard-resetting with an exponential decay?
    
    p_ctrs = marginal_ctrs
    card_p = length(p_ctrs)
    
    #reset counters for marginal? (make sure every entry is non-zero!)
    if reset_ctrs
        p_ctrs = ones(card_p)/card_p
    end

    #update marginal counters using a histogram to do the counting (bin-borders have to be set manually!)
    e,p_counts = hist(sampled_indices,0.5:1:(card_p+0.5))         
    p_ctrs = p_ctrs + p_counts
    
    #normalize to get the updated marginal
    p_sampled = p_ctrs / sum(p_ctrs)
    
    return p_sampled, p_ctrs  #return the probability-vector, but also the representation of the marginal (as counts)
end
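
#Example (hypothetical values): updating a uniform 3-action marginal with the sampled
#action-indices [1.0, 1.0, 3.0] via
#  p_new, ctrs_new = update_marginal_ctrs([1.0, 1.0, 3.0], ones(3)/3)
#yields p_new ≈ [0.58, 0.08, 0.33], i.e. probability mass shifts towards actions 1 and 3.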



#function for BA sampling
#burnin_ratio specifies the ratio of outer iterations that will not count
#towards computation of the final marginal distribution (counters will be blocked)
#reset_marginal_ctrs specifies whether the marginal is computed with the samples of the last
#iteration only (=hard forgetting by resetting counters) or whether the marginal
#is computed with all samples of all iterations (=no forgetting)
function BAsampling(pa_init::Vector, β::Number, U_pre::Matrix, Umax::Vector, pw::Vector, 
                    nsteps_marginalupdate::Integer, nsteps_conditionalupdate::Integer;
                    burnin_ratio::Real=0.7, max_rejsamp_steps::Integer=200,
                    compute_performance::Bool=false, performance_as_dataframe::Bool=false,
                    performance_per_iteration::Bool=false,
                    init_marg_func::Function=init_marginal_representation_ctrs,
                    update_marg_func::Function=update_marginal_ctrs, update_func_args...)
    
    #compute cardinality, check size of U_pre
    card_a = length(pa_init)
    card_w = length(pw)
    if size(U_pre) != (card_a, card_w)
        error("Size mismatch of U_pre and pa_init or pw!")
    end
    
    #check that burnin_ratio is really a ratio
    if (burnin_ratio < 0) || (burnin_ratio > 1)
        error("burnin_ratio must be a number between 0 and 1.")
    end
    
    #if performance measures don't need to be returned, don't compute them per iteration
    if compute_performance==false
        performance_per_iteration = false
    end 
    #preallocate if necessary
    if performance_per_iteration 
        I_i = zeros(nsteps_marginalupdate)
        Ha_i = zeros(nsteps_marginalupdate)
        Hagw_i = zeros(nsteps_marginalupdate)
        EU_i = zeros(nsteps_marginalupdate)
        RDobj_i = zeros(nsteps_marginalupdate)
    end
    
    #initialize sampling distributions
    pw_dist = Categorical(pw) #distribution over world states p(w)
    pagw_ctrs = ones(card_a, card_w) #counters for conditional distribution   
    pa_sampled = pa_init #marginal distribution
    
    #initialize the marginal representation
    pa_ctrs = init_marg_func(pa_init)

    burnin_triggered=false
    #outer loop - in each iteration the marginal is updated
    iter=0
    for iter in 1:nsteps_marginalupdate        
        a_samples = zeros(nsteps_conditionalupdate)  #this will hold the samples from p(a|w) during inner loop
        
        #inner loop - in each step a sample is drawn from the conditional and stored
        #for the batch-update of the marginal
        for j in 1:nsteps_conditionalupdate
            #draw a w sample
            w_samp = rand(pw_dist)

            #draw a sample from p(a|w) using the current estimate of p(a) as proposal distribution using rejection sampling
            agw_samp, acc_ratio = rej_samp_const(Categorical(pa_sampled), 1, max_rejsamp_steps, β, U_pre[:,w_samp], Umax[w_samp])
            a_samples[j] = agw_samp[1]
            

            #update conditional counters
            pagw_ctrs[round(Int,agw_samp[1]), w_samp] += 1
        end
        
        #very simple burn-in: simply reset counters
        if (iter >(nsteps_marginalupdate)*burnin_ratio) && (!burnin_triggered)
            burnin_triggered = true
            pagw_ctrs = ones(card_a, card_w)                                
        end

        #update marginal with samples drawn in inner loop
        pa_sampled, pa_ctrs = update_marg_func(a_samples, pa_ctrs; update_func_args...)       
        
        
        #compute entropic quantities (if requested with additional parameter)
        if performance_per_iteration
            #compute sample-based conditional p(a|w)
            pagw_sampled = zeros(card_a, card_w)
            for i in 1:card_w
                pagw_sampled[:,i] = pagw_ctrs[:,i] / sum(pagw_ctrs[:,i])
            end
            I_i[iter], Ha_i[iter], Hagw_i[iter], EU_i[iter], RDobj_i[iter] = analyzeBAsolution(pw, pa_sampled, pagw_sampled, U_pre, β)
        end
    end

    #compute conditionals using the sample-counts of the previous inner loops
    #the burn-in parameter specifies how many of the inner loops are discarded
    pagw_sampled = zeros(card_a, card_w)
    for i in 1:card_w
        pagw_sampled[:,i] = pagw_ctrs[:,i] / sum(pagw_ctrs[:,i])
    end

    #return results
    if compute_performance == false
        return pagw_sampled, pa_sampled
    else            
        if performance_per_iteration == false
            #compute performance measures for final solution
            I, Ha, Hagw, EU, RDobj = analyzeBAsolution(pw, pa_sampled, pagw_sampled, U_pre, β)
        else
            #"cut" valid results from preallocated vector
            I = I_i[1:iter]
            Ha = Ha_i[1:iter]
            Hagw = Hagw_i[1:iter]
            EU = EU_i[1:iter]
            RDobj = RDobj_i[1:iter]
        end

        #if needed, transform to data frame
        if performance_as_dataframe == false
            return pagw_sampled, pa_sampled, I, Ha, Hagw, EU, RDobj
        else
            performance_df = performancemeasures2DataFrame(I, Ha, Hagw, EU, RDobj)
            return pagw_sampled, pa_sampled, performance_df 
        end
    end
    
end

In [ ]:
#example call; also plot the evolution of the performance measures
maxiter = 10000
β = 1.2
nsteps_marg = 500
nsteps_cond = 750
pagw_s,pa_s,perf = BAsampling(pa_init, β, U_pre, Umax, p_w, nsteps_marg, nsteps_cond,
                              burnin_ratio=0.7, max_rejsamp_steps=500, reset_ctrs=false,
                              compute_performance=true, performance_as_dataframe=true, performance_per_iteration=true)

plt_cond = visualizeBAconditional(pagw_s,a_vec,w_vec,a_strings,w_strings)

#instead of using a range of β-values (as for the standard-performance plot), 
#use a vector indicating the iteration
niter = size(perf,1)
plt_perf_entropy, plt_perf_utility, plt_rateutility = plotperformancemeasures(perf,collect(1:niter),
                                                      suppress_vis=true, xlabel_perf="Iteration")

#TODO: somehow the "Iteration" label above doesn't seem to work!

display(vstack(plt_perf_entropy, plt_perf_utility))
display(plt_cond)

[Interact] Change the parameters in the code-cell above to explore the sampling scheme and its solutions

Compare sampling-solutions against analytical results

Below, we average over several sampling runs at different inverse temperatures $\beta$ to compare the sample-based solutions (with and without forgetting) against the analytical solutions.


In [ ]:
#compute theoretical result for rate-utility curve
ε = 0.0001 #convergence criterion for BAiterations
maxiter = 10000
β_sweep = collect(0.01:0.05:3)
nβ = length(β_sweep)

#preallocate
I = zeros(nβ)
Ha = zeros(nβ)
Hagw = zeros(nβ)
EU = zeros(nβ)
RDobj = zeros(nβ)

#sweep through β values and perform Blahut-Arimoto iterations for each value
for i=1:nβ
    pagw, pa, I[i], Ha[i], Hagw[i], EU[i], RDobj[i] = BAiterations(pa_init, β_sweep[i], U_pre, p_w, ε, maxiter,compute_performance=true)  
end

#show rate-utility curve (shaded region is theoretically infeasible)
perf_res_analytical = performancemeasures2DataFrame(I, Ha, Hagw, EU, RDobj);  
plot_perf_entropy, plot_perf_util, plot_rateutility = plotperformancemeasures(perf_res_analytical, β_sweep, suppress_vis=true);
display(plot_rateutility)

In [ ]:
#run the sampling for different temperatures and repeat each run n-times
#then plot these results against the rate-utility curve (based on the closed-form solutions)
βrange_samp = [0.1, 0.25, 0.5, 0.8, 1.2, 1.4, 1.6, 2]
#βrange_samp = [1.2, 2]

nruns = 10; #number of runs per β point

nsteps_marg = 500
nsteps_cond = 750
burnin_ratio = 0.8
max_rejsamp_steps=500 #maximum number of steps for sampling from the conditional


nconditions = size(βrange_samp,1)*nruns
I_sampled = zeros(2*nconditions)
Ha_sampled = zeros(2*nconditions)
EU_sampled = zeros(2*nconditions)
βval = zeros(2*nconditions)
ResetCtrs = falses(2*nconditions)

#first run with reset_ctrs to false
reset_ctrs = false #if true, the ctrs for the marginal are reset in each iteration (="hard" forgetting)
for b in 1:nconditions
    println("BA Sampling, run $b of $(2*nconditions)")
    β = βrange_samp[round(Int,ceil(b/nruns))]
    
    pagw_s, pa_s, I, Ha, Hagw, EU, RDobj = BAsampling(pa_init, β, U_pre, Umax, p_w, nsteps_marg, nsteps_cond,
                                               reset_ctrs=reset_ctrs, burnin_ratio=burnin_ratio,
                                               max_rejsamp_steps=max_rejsamp_steps, compute_performance=true)
    
    I_sampled[b] = I
    Ha_sampled[b] = Ha
    EU_sampled[b] = EU
    βval[b] = β
    ResetCtrs[b] = reset_ctrs
end

#second run with reset_ctrs to true
reset_ctrs = true #if true, the ctrs for the marginal are reset in each iteration (="hard" forgetting)
for b in 1:nconditions
    println("BA Sampling, run $(nconditions+b) of $(2*nconditions)")
    β = βrange_samp[round(Int,ceil(b/nruns))]
    
    pagw_s, pa_s, I, Ha, Hagw, EU, RDobj = BAsampling(pa_init, β, U_pre, Umax, p_w, nsteps_marg, nsteps_cond,
                                               reset_ctrs=reset_ctrs, burnin_ratio=burnin_ratio,
                                               max_rejsamp_steps=max_rejsamp_steps, compute_performance=true)
    
    I_sampled[nconditions+b] = I
    Ha_sampled[nconditions+b] = Ha
    EU_sampled[nconditions+b] = EU
    βval[nconditions+b] = β
    ResetCtrs[nconditions+b] = reset_ctrs
end

#wrap data in DataFrame for convenient plotting
res_sampled = DataFrame(β=βval, I_aw=I_sampled, H_a=Ha_sampled, E_U=EU_sampled, Forgetting=ResetCtrs)

In [ ]:
#compute theoretical result for same set of temperatures
ε = 0.0001 #convergence criterion for BAiterations
maxiter = 10000
nβ = length(βrange_samp)

#preallocate
I = zeros(nβ)
Ha = zeros(nβ)
Hagw = zeros(nβ)
EU = zeros(nβ)
RDobj = zeros(nβ)

#sweep through β values and perform Blahut-Arimoto iterations for each value
for i=1:nβ
    pagw, pa, I[i], Ha[i], Hagw[i], EU[i], RDobj[i] = BAiterations(pa_init, βrange_samp[i], U_pre, p_w, ε, maxiter,compute_performance=true)  
end

#show rate-utility curve (shaded region is theoretically infeasible)
perf_res_analytical_samp = performancemeasures2DataFrame(I, Ha, Hagw, EU, RDobj)
perf_res_analytical_samp[:β] = βrange_samp;

In [ ]:
#plot the solutions from the sampling runs into the (analytical) rate-utility plot
plot(Guide.ylabel("E[U]"), Guide.xlabel("I(A;W) [bits]"),
     Guide.title("Rate-Utility curve (dots show sampling solutions)"), BAtheme(), BAdiscretecolorscale(2),
     layer(res_sampled,y="E_U",x="I_aw",Geom.point,color="Forgetting"),
     layer(perf_res_analytical_samp,y="E_U",x="I_aw",Geom.point),
     layer(perf_res_analytical,y="E_U",x="I_aw",Geom.line),
     layer(perf_res_analytical,y="E_U",x="I_aw",ymin="E_U",ymax=ones(size(perf_res_analytical,1))*maximum(perf_res_analytical[:E_U]),
     Geom.ribbon,BAtheme(default_color=colorant"green"))
    )

In [ ]:
#compute the mean for each group of points (that were produced with the same beta and same forgetting setting)
res_samp_aggregated = aggregate(res_sampled, [:β,:Forgetting], mean)

#plot the mean-solutions from the sampling runs into the (analytical) rate-utility plot
plt_samp = plot(Guide.ylabel("E[U]"), Guide.xlabel("I(A;W) [bits]"),
     Guide.title("Rate-Utility curve (dots show mean sampling solutions)"), BAtheme(), BAdiscretecolorscale(2),
     layer(res_samp_aggregated,y="E_U_mean",x="I_aw_mean",Geom.point,Geom.line(preserve_order=true),color="Forgetting"),
     layer(perf_res_analytical_samp,y="E_U",x="I_aw",Geom.point),
     layer(perf_res_analytical,y="E_U",x="I_aw",Geom.line),
     layer(perf_res_analytical,y="E_U",x="I_aw",ymin="E_U",ymax=ones(size(perf_res_analytical,1))*maximum(perf_res_analytical[:E_U]),
     Geom.ribbon,BAtheme(default_color=colorant"green"))
    )

In [ ]:
#store the plots
#draw(SVG("Figures/RateUtilityCurve.svg", 8.5cm, 7cm), plot_rateutility)

#draw(SVG("Figures/SamplingCond.svg", 8.5cm, 9cm), plt_cond)
#draw(SVG("Figures/BASampling.svg", 13cm, 11cm), plt_samp)

#plot_samp = vstack(plt_samp, plt_cond)
#draw(SVG("Figures/SamplingResults.svg", 18cm,16cm),plot_samp)

In [ ]:
#try changing the number of inner-/outer-loop iterations
#try changing the burn-in ratio
#try soft-forgetting (with exponential-decay window)

#forgetting seems to do better than no forgetting (in terms of being closer to the rate-utility curve),
#but it also seems that the points with forgetting tend to have a lower I(A;W) (and also a lower E[U]) - even though
#the temperatures are the same.