A simple bandit algorithm

Initialize, for a = 1 to k:

$$Q(a) \leftarrow 0$$ $$N(a) \leftarrow 0$$

Repeat forever:

$$A \leftarrow \begin{cases} argmax_a Q(a), & \text{with probability } 1 - \epsilon \text{ (breaking ties randomly)}\\ \text{a random action}, & \text{with probability } \epsilon \\ \end{cases}\\ R \leftarrow bandit(A)\\ N(A) \leftarrow N(A) + 1\\ Q(A) \leftarrow Q(A) + \cfrac{1}{N(A)}[R - Q(A)]$$

In [1]:
type Action = Int
type Reward = Int
type NthTry = Int
type Estimate = Double
type Epsilon = Double

trait SimpleBanditAlgorithm extends Serializable {

  def epoch = (1 to 2000)
  def eps: Epsilon = 0.1
  def nActions = 10

  import scala.util.Random
  def bandit(a: Action): Reward = a match {
    case 0 => 5 //Random.nextInt(20)
    case 5 => 98 //Random.nextInt(maxReward)
    case 7 => 14
    case 9 => 70 //Random.nextInt(30)  
    case _ => 60 //Random.nextInt(maxReward)
  }

  def randomAction: Action = Random.nextInt(nActions)
  def distribution(eps: Epsilon): Boolean = Random.nextDouble() < eps

  def argmax(values: Map[Action, Estimate]): Action = if (values.nonEmpty){
    values.iterator.maxBy(_._2)._1
  } else randomAction
    
  final case class State(n: NthTry = 0, accR: List[Reward] = List.empty, q: Map[Action, Estimate] = Map.empty)

  def step(state: State): State = {
    val action = 
      if (distribution(eps)) randomAction else argmax(state.q)
    val reward = bandit(action)
    val nUpdate = state.n + 1
    val previousEstimate = state.q.getOrElse(action, 0.0)
    val updatedEstimate = 
      previousEstimate + (reward - previousEstimate) / nUpdate
    val qUpdate = state.q.updated(action, updatedEstimate)
    State(nUpdate, reward :: state.accR, qUpdate)
  }

  val results = epoch
    .scanLeft(State())((state, _) => step(state))

}

def algByEps(epsilon: Epsilon) = new SimpleBanditAlgorithm {
  override def eps = epsilon
}


Out[1]:
defined type Action
defined type Reward
defined type NthTry
defined type Estimate
defined type Epsilon
defined trait SimpleBanditAlgorithm
defined function algByEps

The average reward for different $\epsilon$:


In [2]:
//----------------------Vegas---------------------------------

import $ivy.`org.vegas-viz::vegas:0.3.6`

import vegas._
import vegas.render.HTMLRenderer._

implicit val displayer: String => Unit = publish.html(_)


val vegasPlotData = for {
  eps <- Seq(0.1, 0.01, 0)
  s <- algByEps(eps).results
} yield Map(
  "eps" -> eps.toString, 
  "step" -> s.n, 
  "average reward" -> s.accR.sum / (s.accR.size + 1)
) 

def drawPlotWithVegas(): Unit = Vegas("Plot")
  .withData(vegasPlotData)
  .encodeX("step", Quant)
  .encodeY("average reward", Quant)
  .encodeDetailFields(Field(field = "eps", dataType = Nominal))
  .encodeColor(
       field = "eps",
       dataType = Nominal,
       legend = vegas.Legend(orient = "left", title = "epsilon"))
  .mark(vegas.Line)
  .show


Out[2]:
import $ivy.$                           


import vegas._

import vegas.render.HTMLRenderer._


displayer: String => Unit = <function1>
vegasPlotData: Seq[Map[String, Any]] = List(
  Map("eps" -> 0.1, "step" -> 0, "average reward" -> 0),
  Map("eps" -> 0.1, "step" -> 1, "average reward" -> 7),
  Map("eps" -> 0.1, "step" -> 2, "average reward" -> 9),
  Map("eps" -> 0.1, "step" -> 3, "average reward" -> 10),
  Map("eps" -> 0.1, "step" -> 4, "average reward" -> 11),
  Map("eps" -> 0.1, "step" -> 5, "average reward" -> 11),
  Map("eps" -> 0.1, "step" -> 6, "average reward" -> 12),
  Map("eps" -> 0.1, "step" -> 7, "average reward" -> 12),
  Map("eps" -> 0.1, "step" -> 8, "average reward" -> 12),
  Map("eps" -> 0.1, "step" -> 9, "average reward" -> 12),
  Map("eps" -> 0.1, "step" -> 10, "average reward" -> 12),
...
defined function drawPlotWithVegas

In [3]:
drawPlotWithVegas()


Same Plot rendered with Plotly library:


In [4]:
//------------------------Plotly---------------------------

import $ivy.`org.plotly-scala::plotly-jupyter-scala:0.3.0`

import plotly._
import plotly.element._
import plotly.layout._
import plotly.JupyterScala._

plotly.JupyterScala.init()

val plot = Seq(0.1, 0.01, 0) map { eps =>
  val (x, y) = algByEps(eps).results.map(s => s.n -> s.accR.sum / (s.accR.size + 1)).unzip
  Scatter(x, y, name = eps.toString)  
}

def drawPlot(): Unit = plot.plot()


Out[4]:
import $ivy.$                                             


import plotly._

import plotly.element._

import plotly.layout._

import plotly.JupyterScala._


plot: Seq[plotly.Scatter] = List(
  Scatter(
    Some(
      Doubles(
        Vector(
          0.0,
          1.0,
          2.0,
          3.0,
          4.0,
          5.0,
          6.0,
...
defined function drawPlot

In [5]:
drawPlot()


Spark example:


In [6]:
import $ivy.`org.slf4j:slf4j-nop:1.7.12` // for cleaner logs
import $ivy.`org.apache.spark::spark-sql:2.0.2` // adjust spark version - spark >= 1.6 should be fine, possibly >= 1.3 too
import $ivy.`org.jupyter-scala::spark:0.4.0-RC3` // JupyterSparkContext-s (SparkContext aware of the jupyter-scala kernel)

import org.apache.spark._
import org.apache.spark.sql._
import jupyter.spark._

@transient val sparkConf = new SparkConf().
  setAppName("SBTB").
  setMaster("local")

@transient val sc = new JupyterSparkContext(sparkConf)

val alg2 = algByEps(0.01) 
import alg2._

val metrics = sc.collectionAccumulator[State]
def withMetric(m: State) = { metrics.add(m); m}

def epochAggregation(state: State, stepNum: Int) = withMetric(step(state))

type Strategy = Map[Action, Estimate]
def aggregateStrategies(st1: Strategy, st2: Strategy) = (st1 ++ st2).toIterator.map{//should be monoid 
  case (action, estimate) => action -> scala.math.max(st1.getOrElse(action, 0.0), st2.getOrElse(action, 0.0))
}.toMap

def overallAggregation(s1: State, s2: State) = State(s1.n + s2.n, s1.accR ++ s2.accR, aggregateStrategies(s1.q, s2.q))

sc.makeRDD(1 to 10000).aggregate(State())(epochAggregation, overallAggregation)


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/Users/dmytrokondratiuk/.coursier/cache/v1/https/repo1.maven.org/maven2/org/slf4j/slf4j-nop/1.7.12/slf4j-nop-1.7.12.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/Users/dmytrokondratiuk/.coursier/cache/v1/https/repo1.maven.org/maven2/org/slf4j/slf4j-log4j12/1.7.16/slf4j-log4j12-1.7.16.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.helpers.NOPLoggerFactory]
log4j:WARN No appenders could be found for logger (io.netty.util.internal.logging.InternalLoggerFactory).
log4j:WARN Please initialize the log4j system properly.
log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.
Out[6]:
import $ivy.$                            // for cleaner logs

import $ivy.$                                   // adjust spark version - spark >= 1.6 should be fine, possibly >= 1.3 too

import $ivy.$                                    // JupyterSparkContext-s (SparkContext aware of the jupyter-scala kernel)


import org.apache.spark._

import org.apache.spark.sql._

import jupyter.spark._


sparkConf: org.apache.spark.SparkConf = org.apache.spark.SparkConf@60584f36
sc: jupyter.spark.JupyterSparkContext = jupyter.spark.JupyterSparkContext@53cd366f
alg2: AnyRef with SimpleBanditAlgorithm = $sess.cmd0Wrapper$Helper$$anon$1@66e955ea
import alg2._


metrics: org.apache.spark.util.CollectionAccumulator[alg2.State] = CollectionAccumulator(id: 0, name: None, value: [State(1,List(14),Map(7 -> 14.0)), State(2,List(14, 14),Map(7 -> 14.0)), State(3,List(14, 14, 14),Map(7 -> 14.0)), State(4,List(14, 14, 14, 14),Map(7 -> 14.0)), State(5,List(14, 14, 14, 14, 14),Map(7 -> 14.0)), State(6,List(14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(7,List(14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(8,List(14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(9,List(14, 14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(10,List(14, 14, 14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(11,List(14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(12,List(14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(13,List(14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(14,List(14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14),Map(7 -> 14.0)), State(15,List(14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, ...
defined function withMetric
defined function epochAggregation
defined type Strategy
defined function aggregateStrategies
defined function overallAggregation
res5_16: alg2.State = State(
  10000,
  List(
    14,
    14,
    14,
    14,
    14,
    14,
    14,
    14,
    14,
...