Monte-Carlo Tree Search

We assume you've read the documentation on Serial Value Iteration. If you haven't, read that first before continuing here.

The Monte-Carlo tree search (MCTS) algorithm relies on the same problem definition framework as the value iteration algorithms.

Like value iteration, MCTS keeps an internal approximation of the value function and uses it to choose actions.

Unlike value iteration, however, MCTS is an online algorithm: the MCTS policy may start off poor, but it improves the more it interacts with the MDP simulator/environment.

The main advantage of MCTS is that it gives a good approximation of the state-action utility function without requiring an expensive, value iteration-style sweep over the entire state space. We recommend it for problems with large state and/or action spaces.

Note, however, that a key assumption is that both the state space and the action space are finite.
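To make the idea concrete, below is a simplified, self-contained sketch of the UCT-style search that underlies an MCTS query. This is not PLite's implementation: all function names (toystep, query, simulate!, ucb, bestindex) and constants (exploration weight, search depth, iteration count) are illustrative choices, and a full implementation would typically also use a rollout policy to evaluate newly expanded nodes rather than recursing with UCB1 all the way to a fixed depth.

# Illustrative sketch only -- not PLite's internals. MCTS repeatedly simulates
# trajectories from the queried state, scores actions with UCB1, and backs up
# the discounted returns into running averages Q keyed by (state, action).

const Actions  = ["W", "E", "stop"]
const Discount = 0.95
const Cexplore = 2.0   # UCB1 exploration weight (assumed value)
const MaxDepth = 10    # simulation horizon (assumed value)

# index of the largest score
function bestindex(scores)
  best = 1
  for i in 2:length(scores)
    if scores[i] > scores[best]
      best = i
    end
  end
  return best
end

# UCB1: exploit the current estimate, but boost rarely tried actions
function ucb(Q, N, s, a)
  nvisits = sum([get(N, (s, b), 0) for b in Actions])
  return get(Q, (s, a), 0.0) +
         Cexplore * sqrt(log(nvisits + 1) / (get(N, (s, a), 0) + 1))
end

# one simulation: select by UCB1, sample a transition, recurse, then back up
# the discounted return into the running average for (s, a)
function simulate!(Q, N, step, s, depth)
  depth == 0 && return 0.0
  a = Actions[bestindex([ucb(Q, N, s, b) for b in Actions])]
  sp, r = step(s, a)
  q = r + Discount * simulate!(Q, N, step, sp, depth - 1)
  N[(s, a)] = get(N, (s, a), 0) + 1
  Q[(s, a)] = get(Q, (s, a), 0.0) + (q - get(Q, (s, a), 0.0)) / N[(s, a)]
  return q
end

# a policy query runs many simulations from the queried state, growing the
# (Q, N) statistics, and then acts greedily with respect to the estimates
function query(Q, N, step, s; niter=1000)
  for i in 1:niter
    simulate!(Q, N, step, s, MaxDepth)
  end
  return Actions[bestindex([get(Q, (s, a), 0.0) for a in Actions])]
end

# toy generative model (a deterministic simplification of the example below):
# x moves on a 0:20:100 grid, and stopping in the goal region pays 1
function toystep(s, a)
  x, goal = s
  if a == "E"
    x = min(x + 20.0, 100.0)
  elseif a == "W"
    x = max(x - 20.0, 0.0)
  end
  g = abs(x - 50.0) < 20.0 ? "yes" : "no"
  r = (goal == "yes" && a == "stop") ? 1.0 : 0.0
  return (x, g), r
end

Q, N = Dict{Any,Float64}(), Dict{Any,Int}()
query(Q, N, toystep, (20.0, "no"))   # should favor "E" from x = 20

The key point is that all of this work happens per query, which is what makes MCTS online: the statistics only cover states the search has actually visited.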

Solver definition

The syntax for using a serial MCTS solver is similar to that of the serial value iteration solver. We still need to discretize continuous variables, since our solver implements MCTS for finite state and action spaces. Otherwise, the only difference is initializing a different solver type.


In [1]:
push!(LOAD_PATH, "../src")
using PLite

# constants
const MinX = 0
const MaxX = 100
const StepX = 20

# mdp definition
mdp = MDP()

statevariable!(mdp, "x", MinX, MaxX)  # continuous
statevariable!(mdp, "goal", ["no", "yes"])  # discrete

actionvariable!(mdp, "move", ["W", "E", "stop"])  # discrete

# the goal region is any x within StepX of the midpoint of the domain
function isgoal(x::Float64)
  if abs(x - MaxX / 2) < StepX
    return "yes"
  else
    return "no"
  end
end

# returns the possible next states as a list of ([x, goal], probability) pairs
function mytransition(x::Float64, goal::AbstractString, move::AbstractString)
  if isgoal(x) == "yes" && goal == "yes"
    return [([x, isgoal(x)], 1.0)]  # the goal state is absorbing
  end

  if move == "E"
    if x >= MaxX
      return [
        ([x, isgoal(x)], 0.9),
        ([x - StepX, isgoal(x - StepX)], 0.1)]
    elseif x <= MinX
      return [
        ([x, isgoal(x)], 0.2),
        ([x + StepX, isgoal(x + StepX)], 0.8)]
    else
      return [
        ([x, isgoal(x)], 0.1),
        ([x - StepX, isgoal(x - StepX)], 0.1),
        ([x + StepX, isgoal(x + StepX)], 0.8)]
    end
  elseif move == "W"
    if x >= MaxX
      return [
        ([x, isgoal(x)], 0.1),
        ([x - StepX, isgoal(x - StepX)], 0.9)]
    elseif x <= MinX
      return [
        ([x, isgoal(x)], 0.9),
        ([x + StepX, isgoal(x + StepX)], 0.1)]
    else
      return [
        ([x, isgoal(x)], 0.1),
        ([x - StepX, isgoal(x - StepX)], 0.8),
        ([x + StepX, isgoal(x + StepX)], 0.1)]
    end
  elseif move == "stop"
    return [([x, isgoal(x)], 1.0)]
  end
end

# reward of 1 for stopping once the goal flag is set; 0 otherwise
function myreward(x::Float64, goal::AbstractString, move::AbstractString)
  if goal == "yes" && move == "stop"
    return 1
  else
    return 0
  end
end

transition!(mdp, ["x", "goal", "move"], mytransition)
reward!(mdp, ["x", "goal", "move"], myreward)


Out[1]:
PLite.LazyFunc(false,ASCIIString["x","goal","move"],myreward)

We define the solver as follows, and then generate the policy using the same syntax as in the value iteration algorithms.


In [2]:
# solver options
solver = SerialMCTS()
discretize_statevariable!(solver, "x", StepX)

# generate results
solution = solve(mdp, solver)
policy = getpolicy(mdp, solution)


INFO: mdp and monte-carlo tree search solver passed basic checks
Out[2]:
policy (generic function with 1 method)

Online solution

As mentioned, the policy generally improves as it receives more queries. MCTS grows an internal tree that keeps track of the approximate value function for the states it has seen. For example, after the query


In [3]:
stateq = (20.0, "no")
actionq = policy(stateq...)


Out[3]:
1-element Array{Any,1}:
 "E"

we see that the tree has grown, and the resulting state-action value function approximation agrees with intuition (higher values for better actions at a given state):


In [4]:
actions = ["W", "E", "stop"]
for entry in solution.tree
    println("state: ", entry[1])
    println("value: ")
    for iaction in 1:length(actions)
        println("\taction: ", actions[iaction], ", value: ", entry[2].qval[iaction]) 
    end
    println()
end


state: Any[20.0,"no"]
value: 
	action: W, value: 3.6718786612417675
	action: E, value: 15.849808314832028
	action: stop, value: 0.0

state: Any[40.0,"yes"]
value: 
	action: W, value: 0.0
	action: E, value: 0.0
	action: stop, value: 9.009182612454815

state: Any[0.0,"no"]
value: 
	action: W, value: 0.0
	action: E, value: 3.671252470380541
	action: stop, value: 0.0
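
Because the solver is online, each additional query runs more simulations and grows the tree around the states it visits, so these estimates keep improving. As a rough check, we can query the policy from a state that does not appear in the output above and confirm that new entries show up in solution.tree. This assumes solution.tree behaves like the standard collection iterated above, and the exact sizes and values will vary from run to run since MCTS is randomized:

println("tree size before: ", length(solution.tree))  # assumes length is supported
policy(80.0, "no")                                     # query a state not seen above
println("tree size after:  ", length(solution.tree))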