Incorporating Domain Knowledge

Aside from tuning the standard solver parameters (the exploration constant c and the progressive widening parameters k and alpha), MCTS currently offers several ways of incorporating domain knowledge. The following solver parameters control the planner's behavior:

  • estimate_value determines how the value is estimated at the leaf nodes (this is usually done using a rollout simulation).
  • init_N and init_Q determine how N(s,a) and Q(s,a) are initialized when a new node is created.
  • next_action determines which new actions are tried in double progressive widening (DPWSolver only).

There are three ways of specifying these parameters: 1) with constant values, 2) with functions, and 3) with custom objects.
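
As an aside before covering those three approaches: the standard parameters c, k, and alpha are set directly as solver keyword arguments rather than through any of them. The sketch below assumes the keyword names used by MCTS.jl's DPWSolver (exploration_constant for c, and k_action, alpha_action, k_state, alpha_state for the widening parameters); check the solver docstring for the exact names in your version.

using MCTS

# Rough sketch: setting the standard exploration/widening parameters directly.
#   exploration_constant - the UCB exploration constant c
#   k_action, alpha_action - action progressive widening (at most k*N^alpha children)
#   k_state, alpha_state   - state progressive widening
tuned = DPWSolver(depth=10,
                  exploration_constant=5.0,
                  k_action=10.0, alpha_action=0.5,
                  k_state=10.0, alpha_state=0.2)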


In [19]:
using MCTS
using POMDPs
using POMDPModels
using Random
mdp = LegacyGridWorld();

Constant Values

init_N, init_Q, and estimate_value can be set with constant values (though this is usually a bad idea for estimate_value); next_action cannot be specified in this way. For example, the following code initializes all new N values to 3 and all new Q values to 11.73.


In [20]:
solver = MCTSSolver(n_iterations=3, depth=4,
                    init_N=3,
                    init_Q=11.73)
policy = solve(solver, mdp)
action(policy, GridWorldState(1,1))
println("State-Action Nodes")
tree = policy.tree
for sn in MCTS.state_nodes(tree)
    for san in MCTS.children(sn)
        println("s:$(MCTS.state(sn)), a:$(action(san)) Q:$(MCTS.q(san)) N:$(MCTS.n(san))")
    end
end


State-Action Nodes
s:GridWorldState(1, 1, false), a:up Q:8.7975 N:4
s:GridWorldState(1, 1, false), a:down Q:5.027142857142857 N:7
s:GridWorldState(1, 1, false), a:left Q:5.027142857142857 N:7
s:GridWorldState(1, 1, false), a:right Q:11.73 N:3
s:GridWorldState(2, 1, false), a:up Q:11.73 N:3
s:GridWorldState(2, 1, false), a:down Q:11.73 N:3
s:GridWorldState(2, 1, false), a:left Q:11.73 N:3
s:GridWorldState(2, 1, false), a:right Q:11.73 N:3
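
For completeness, estimate_value can be given a constant in the same way. The following is only a rough sketch of what that would look like; it is the "bad idea" mentioned above, because a constant leaf estimate ignores the state entirely.

# Rough sketch (usually a bad idea): use 0.0 as the value estimate at every leaf node.
zero_value_solver = MCTSSolver(n_iterations=3, depth=4,
                               estimate_value=0.0)
zero_value_policy = solve(zero_value_solver, mdp)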

Functions

init_N, init_Q, estimate_value, and next_action can also be functions. The following code will

  • initialize Q to 0.0 everywhere except state [1,2] where it will be 11.73
  • initialize N to 0 everywhere except state [1,2] where it will be 3
  • estimate the value as 10 divided by the Manhattan distance to state [9,3]
  • always choose action "up" first in double progressive widening

Note: the ? and : in the definitions below form Julia's ternary operator.


In [21]:
special_Q(mdp, s, a) = s == GridWorldState(1,2) ? 11.73 : 0.0
special_N(mdp, s, a) = s == GridWorldState(1,2) ? 3 : 0

function manhattan_value(mdp, s, depth) # depth is the solver `depth` parameter less the number of timesteps that have already passed (it can be ignored in many cases)
    m_dist = abs(s.x-9)+abs(s.y-3)
    val = 10.0/m_dist
    println("Set value for $s to $val") # this is not necessary - just shows that it's working later
    return val
end

function up_priority(mdp, s, snode) # snode is the state node of type DPWStateNode
    if haskey(snode.tree.a_lookup, (snode.index, :up)) # "up" is already there
        return GridWorldAction(rand([:left, :down, :right])) # add a random action
    else
        return GridWorldAction(:up)
    end
end;

In [22]:
solver = DPWSolver(n_iterations=8, depth=4,
                   init_N=special_N, init_Q=special_Q,
                   estimate_value=manhattan_value,
                   next_action=up_priority)
policy = solve(solver, mdp)
action(policy, GridWorldState(1,1))
println("State-Action Nodes:")
tree = policy.tree
for i in 1:length(tree.total_n)
    for j in tree.children[i]
        println("s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])")
    end
end


Set value for GridWorldState(2, 1, false) to 1.1111111111111112
Set value for GridWorldState(1, 2, false) to 1.1111111111111112
Set value for GridWorldState(2, 2, false) to 1.25
Set value for GridWorldState(3, 1, false) to 1.25
Set value for GridWorldState(3, 2, false) to 1.4285714285714286
Set value for GridWorldState(1, 3, false) to 1.25
State-Action Nodes:
s:GridWorldState(1, 1, false), a:up, Q:1.0964380704365078 N:5
s:GridWorldState(1, 1, false), a:down, Q:0.0 N:4
s:GridWorldState(1, 1, false), a:right, Q:0.7520833333333334 N:3
s:GridWorldState(1, 1, false), a:left, Q:0.0 N:2
s:GridWorldState(2, 1, false), a:up, Q:1.1875 N:2
s:GridWorldState(2, 1, false), a:left, Q:0.0 N:1
s:GridWorldState(2, 1, false), a:right, Q:1.2892857142857144 N:1
s:GridWorldState(1, 2, false), a:up, Q:7.4898437499999995 N:5
s:GridWorldState(1, 2, false), a:down, Q:11.73 N:3
s:GridWorldState(3, 1, false), a:up, Q:1.3571428571428572 N:1

Objects

There are many cases where functions are not suitable, for example, when the solver needs to be serialized. In such cases, arbitrary objects may be passed to the solver to encode the behavior. The same object can be passed to multiple solver parameters to govern all of their behavior. See the docstring for the solver for more information on which functions will be called on the object(s). The following code does exactly the same thing as the function-based code above:


In [23]:
mutable struct MyHeuristic
    target_state::GridWorldState
    special_state::GridWorldState
    special_Q::Float64
    special_N::Int
    priority_action::GridWorldAction
    rng::AbstractRNG
end;

In [24]:
MCTS.init_Q(h::MyHeuristic, mdp::LegacyGridWorld, s, a) = s == h.special_state ? h.special_Q : 0.0
MCTS.init_N(h::MyHeuristic, mdp::LegacyGridWorld, s, a) = s == h.special_state ? h.special_N : 0

function MCTS.estimate_value(h::MyHeuristic, mdp::LegacyGridWorld, s, depth::Int)
    targ = h.target_state
    m_dist = abs(s.x-targ.x)+abs(s.y-targ.y)
    val = 10.0/m_dist
    println("Set value for $s to $val") # this is not necessary - just shows that it's working later
    return val
end

function MCTS.next_action(h::MyHeuristic, mdp::LegacyGridWorld, s, snode::DPWStateNode)
    if haskey(snode.tree.a_lookup, (snode.index, h.priority_action))
        return GridWorldAction(rand(h.rng, [:up, :left, :down, :right])) # add a random action (note that :up may be drawn again here)
    else
        return h.priority_action
    end
end;

In [25]:
heur = MyHeuristic(GridWorldState(9,3), GridWorldState(1,2), 11.73, 3, GridWorldAction(:up), Random.GLOBAL_RNG)
solver = DPWSolver(n_iterations=8, depth=4,
                   init_N=heur, init_Q=heur,
                   estimate_value=heur,
                   next_action=heur)
policy = solve(solver, mdp)
action(policy, GridWorldState(1,1))
println("State-Action Nodes:")
tree = policy.tree
for i in 1:length(tree.total_n)
    for j in tree.children[i]
        println("s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])")
    end
end


Set value for GridWorldState(1, 2, false) to 1.1111111111111112
Set value for GridWorldState(1, 3, false) to 1.25
Set value for GridWorldState(2, 1, false) to 1.1111111111111112
Set value for GridWorldState(1, 4, false) to 1.1111111111111112
Set value for GridWorldState(2, 3, false) to 1.4285714285714286
State-Action Nodes:
s:GridWorldState(1, 1, false), a:up, Q:1.0758547371031746 N:5
s:GridWorldState(1, 1, false), a:left, Q:0.0 N:4
s:GridWorldState(1, 1, false), a:down, Q:0.0 N:4
s:GridWorldState(1, 1, false), a:right, Q:0.5277777777777778 N:2
s:GridWorldState(1, 2, false), a:up, Q:5.677326034580498 N:7
s:GridWorldState(1, 2, false), a:left, Q:5.864999999999999 N:6
s:GridWorldState(1, 3, false), a:up, Q:1.2063492063492065 N:2

Rollouts

The most common way to estimate the value of a state node is with rollout simulations. This can be done with an arbitrary policy or solver by passing a RolloutEstimator object as the estimate_value parameter. The following code does this with a policy that moves towards state [9,3].


In [26]:
mutable struct SeekTarget <: Policy
    target::GridWorldState
end

In [27]:
function POMDPs.action(p::SeekTarget, s::GridWorldState, a::GridWorldAction=GridWorldAction(:up))
    if p.target.x > s.x
        return GridWorldAction(:right)
    elseif p.target.x < s.x
        return GridWorldAction(:left)
    elseif p.target.y > s.y
        return GridWorldAction(:up)
    else
        return GridWorldAction(:down)
    end
end
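
Before plugging this policy into the solver, it can be queried directly as a quick check; from (5,1) the target column 9 lies to the right, so it should choose right.

# Quick check of the rollout policy on its own (expected: GridWorldAction(:right)).
action(SeekTarget(GridWorldState(9,3)), GridWorldState(5,1))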

In [28]:
solver = MCTSSolver(n_iterations=5, depth=20,
                    estimate_value=RolloutEstimator(SeekTarget(GridWorldState(9,3))))
policy = solve(solver, mdp)
action(policy, GridWorldState(5,1))
println("State-Action Nodes")
tree = policy.tree
for sn in MCTS.state_nodes(tree)
    for san in MCTS.children(sn)
        println("s:$(MCTS.state(sn)), a:$(action(san)) Q:$(MCTS.q(san)) N:$(MCTS.n(san))")
    end
end


State-Action Nodes
s:GridWorldState(5, 1, false), a:up Q:6.018902594758494 N:2
s:GridWorldState(5, 1, false), a:down Q:0.0 N:20
s:GridWorldState(5, 1, false), a:left Q:5.987369392383786 N:1
s:GridWorldState(5, 1, false), a:right Q:6.634204312890622 N:1
s:GridWorldState(5, 2, false), a:up Q:6.983372960937498 N:1
s:GridWorldState(5, 2, false), a:down Q:0.0 N:0
s:GridWorldState(5, 2, false), a:left Q:0.0 N:0
s:GridWorldState(5, 2, false), a:right Q:0.0 N:0
s:GridWorldState(5, 3, false), a:up Q:0.0 N:0
s:GridWorldState(5, 3, false), a:down Q:0.0 N:0
s:GridWorldState(5, 3, false), a:left Q:0.0 N:0
s:GridWorldState(5, 3, false), a:right Q:0.0 N:0
s:GridWorldState(4, 1, false), a:up Q:0.0 N:0
s:GridWorldState(4, 1, false), a:down Q:0.0 N:0
s:GridWorldState(4, 1, false), a:left Q:0.0 N:0
s:GridWorldState(4, 1, false), a:right Q:0.0 N:0
s:GridWorldState(6, 1, false), a:up Q:0.0 N:0
s:GridWorldState(6, 1, false), a:down Q:0.0 N:0
s:GridWorldState(6, 1, false), a:left Q:0.0 N:0
s:GridWorldState(6, 1, false), a:right Q:0.0 N:0
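
A hand-written policy is not required: since RolloutEstimator also accepts a Solver (solve is called on it internally to produce the rollout policy), something like a uniformly random rollout policy can be used instead. The following is a minimal sketch, assuming RandomSolver is available from the POMDPPolicies package.

using POMDPPolicies # assumed to provide RandomSolver

# Rough sketch: estimate leaf values with rollouts from a uniformly random policy.
random_rollout_solver = MCTSSolver(n_iterations=5, depth=20,
                                   estimate_value=RolloutEstimator(RandomSolver()))
random_rollout_policy = solve(random_rollout_solver, mdp)
action(random_rollout_policy, GridWorldState(5,1))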
