A simple bandit algorithm
Initialize, for $a = 1$ to $k$:
$$Q(a) \leftarrow 0 \qquad N(a) \leftarrow 0$$
Repeat forever:
$$A \leftarrow \begin{cases} \operatorname{argmax}_a Q(a), & \text{with probability } 1 - \epsilon \text{ (breaking ties randomly)}\\ \text{a random action}, & \text{with probability } \epsilon \end{cases}$$
$$R \leftarrow \text{bandit}(A)$$
$$N(A) \leftarrow N(A) + 1$$
$$Q(A) \leftarrow Q(A) + \frac{1}{N(A)}\bigl[R - Q(A)\bigr]$$
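The last update rule is just the sample average computed incrementally: if $Q_n$ is the average of the first $n-1$ rewards received for an action and $R_n$ is the $n$-th, then
$$Q_{n+1} = \frac{1}{n}\sum_{i=1}^{n} R_i = \frac{1}{n}\Bigl(R_n + (n-1)\,Q_n\Bigr) = Q_n + \frac{1}{n}\bigl[R_n - Q_n\bigr],$$
so the estimates can be maintained with $O(1)$ memory per action instead of storing every reward.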
In [1]:
type Action = Int
type Reward = Int
type NthTry = Int
type Estimate = Double
type Epsilon = Double

trait SimpleBanditAlgorithm extends Serializable {
  def epoch = 1 to 2000   // steps per run
  def eps: Epsilon = 0.1  // exploration probability
  def nActions = 10       // k, the number of arms

  import scala.util.Random

  // Deterministic reward per arm keeps runs reproducible; action 5 is optimal.
  // Swap in the commented-out Random calls (and define maxReward) for a stochastic bandit.
  def bandit(a: Action): Reward = a match {
    case 0 => 5  //Random.nextInt(20)
    case 5 => 98 //Random.nextInt(maxReward)
    case 7 => 14
    case 9 => 70 //Random.nextInt(30)
    case _ => 60 //Random.nextInt(maxReward)
  }

  def randomAction: Action = Random.nextInt(nActions)

  // true with probability eps, i.e. "explore this step"
  def distribution(eps: Epsilon): Boolean = Random.nextDouble() < eps

  // Greedy action; falls back to a random action while there are no estimates yet.
  // NB: maxBy breaks ties by map order rather than randomly, a small deviation from the pseudocode.
  def argmax(values: Map[Action, Estimate]): Action =
    if (values.nonEmpty) values.iterator.maxBy(_._2)._1
    else randomAction

  final case class State(
    n: NthTry = 0,                           // total steps taken so far
    accR: List[Reward] = List.empty,         // rewards, most recent first
    q: Map[Action, Estimate] = Map.empty,    // Q(a): estimated value per action
    counts: Map[Action, NthTry] = Map.empty) // N(a): pulls per action

  def step(state: State): State = {
    val action =
      if (distribution(eps)) randomAction else argmax(state.q)
    val reward = bandit(action)
    val nA = state.counts.getOrElse(action, 0) + 1 // N(A) <- N(A) + 1
    val previousEstimate = state.q.getOrElse(action, 0.0)
    val updatedEstimate =                          // Q(A) <- Q(A) + (1/N(A))[R - Q(A)]
      previousEstimate + (reward - previousEstimate) / nA
    State(state.n + 1, reward :: state.accR,
      state.q.updated(action, updatedEstimate), state.counts.updated(action, nA))
  }

  val results = epoch
    .scanLeft(State())((state, _) => step(state))
}
def algByEps(epsilon: Epsilon) = new SimpleBanditAlgorithm {
  override def eps = epsilon
}
Out[1]:
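A quick sanity check (a sketch, not one of the notebook's original cells): after 2000 steps the value estimates should single out action 5, whose fixed reward of 98 is the largest.

val finalState = algByEps(0.1).results.last
// Top three estimates, best first; expect action 5 with an estimate near 98
finalState.q.toSeq.sortBy(-_._2).take(3).foreach(println)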
The average reward for different $\epsilon$: fully greedy ($\epsilon = 0$) locks onto whichever arm it happens to sample first, while $\epsilon = 0.1$ explores enough to find the optimal arm quickly.
In [2]:
//----------------------Vegas---------------------------------
import $ivy.`org.vegas-viz::vegas:0.3.6`

import vegas._
import vegas.render.HTMLRenderer._

implicit val displayer: String => Unit = publish.html(_)

val vegasPlotData = for {
  eps <- Seq(0.1, 0.01, 0)
  s <- algByEps(eps).results
  if s.n > 0 // skip the initial empty State produced by scanLeft
} yield Map(
  "eps" -> eps.toString,
  "step" -> s.n,
  "average reward" -> s.accR.sum.toDouble / s.n // running mean; toDouble avoids integer division
)

def drawPlotWithVegas(): Unit = Vegas("Average reward per step")
  .withData(vegasPlotData)
  .encodeX("step", Quant)
  .encodeY("average reward", Quant)
  .encodeDetailFields(Field(field = "eps", dataType = Nominal))
  .encodeColor(
    field = "eps",
    dataType = Nominal,
    legend = vegas.Legend(orient = "left", title = "epsilon"))
  .mark(vegas.Line)
  .show
Out[2]:
In [3]:
drawPlotWithVegas()
The same plot rendered with the Plotly library:
In [4]:
//------------------------Plotly---------------------------
import $ivy.`org.plotly-scala::plotly-jupyter-scala:0.3.0`

import plotly._
import plotly.element._
import plotly.layout._
import plotly.JupyterScala._

plotly.JupyterScala.init()

val plot = Seq(0.1, 0.01, 0) map { eps =>
  val (x, y) = algByEps(eps).results
    .filter(_.n > 0) // drop the initial empty State
    .map(s => s.n -> s.accR.sum.toDouble / s.n) // running average reward
    .unzip
  Scatter(x, y, name = eps.toString)
}

def drawPlot(): Unit = plot.plot()
Out[4]:
In [5]:
drawPlot()
Spark example: the run is distributed with `RDD.aggregate`, where each partition folds bandit steps locally (`epochAggregation`) and the resulting partial `State`s are then merged (`overallAggregation`):
In [6]:
import $ivy.`org.slf4j:slf4j-nop:1.7.12` // for cleaner logs
import $ivy.`org.apache.spark::spark-sql:2.0.2` // adjust spark version - spark >= 1.6 should be fine, possibly >= 1.3 too
import $ivy.`org.jupyter-scala::spark:0.4.0-RC3` // JupyterSparkContext-s (SparkContext aware of the jupyter-scala kernel)

import org.apache.spark._
import org.apache.spark.sql._
import jupyter.spark._

@transient val sparkConf = new SparkConf().
  setAppName("SBTB").
  setMaster("local")
@transient val sc = new JupyterSparkContext(sparkConf)

val alg2 = algByEps(0.01)
import alg2._

// Collect every intermediate State on the driver for later inspection
val metrics = sc.collectionAccumulator[State]
def withMetric(m: State) = { metrics.add(m); m }

// Per-partition step: the RDD element (a step index) only triggers one bandit step
def epochAggregation(state: State, stepNum: Int) = withMetric(step(state))

type Strategy = Map[Action, Estimate]

// Combine two partitions' strategies by keeping the larger estimate per action
// (an ad-hoc commutative merge; ideally this would be a proper monoid)
def aggregateStrategies(st1: Strategy, st2: Strategy): Strategy =
  (st1.keySet ++ st2.keySet).map { action =>
    action -> math.max(st1.getOrElse(action, 0.0), st2.getOrElse(action, 0.0))
  }.toMap

def sumCounts(c1: Map[Action, NthTry], c2: Map[Action, NthTry]) =
  (c1.keySet ++ c2.keySet).map(a => a -> (c1.getOrElse(a, 0) + c2.getOrElse(a, 0))).toMap

def overallAggregation(s1: State, s2: State) =
  State(s1.n + s2.n, s1.accR ++ s2.accR, aggregateStrategies(s1.q, s2.q), sumCounts(s1.counts, s2.counts))

sc.makeRDD(1 to 10000).aggregate(State())(epochAggregation, overallAggregation)
Out[6]:
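To inspect the per-step metrics gathered by the accumulator back on the driver, something like the following works (a sketch; the accumulator holds a java.util.List, so it needs converting):

import scala.collection.JavaConverters._
// every recorded State, in no guaranteed order across partitions
val recorded = metrics.value.asScala
// each State's most recent reward is the head of its accR list
val meanReward = recorded.flatMap(_.accR.headOption).sum.toDouble / recorded.size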