Aside from tuning the solver parameters (c, k, alpha), MCTS currently offers several means of incorporating domain knowledge. The following solver parameters control the planner's behavior:
- estimate_value determines how the value is estimated at the leaf nodes (this is usually done using a rollout simulation).
- init_N and init_Q determine how N(s,a) and Q(s,a) are initialized when a new node is created.
- next_action determines which new actions are tried in double progressive widening.

There are three ways of specifying these parameters: 1) with constant values, 2) with functions, and 3) with custom objects.
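As a rough preview of the three styles (this is only a sketch, assuming MCTS is loaded as in the next cell; the numeric values are arbitrary, and complete, runnable examples of each style appear in the cells below):

# 1) constant values
MCTSSolver(init_Q=11.73, init_N=3)

# 2) functions -- the expected signatures are f(mdp, s, a) for init_N and init_Q,
#    f(mdp, s, depth) for estimate_value, and f(mdp, s, snode) for next_action
DPWSolver(init_Q=(mdp, s, a) -> 0.0)

# 3) a custom object: define MCTS.init_Q(obj, mdp, s, a), MCTS.estimate_value(obj, mdp, s, depth),
#    etc. for your type, then pass the object itself, e.g. DPWSolver(init_Q=obj, estimate_value=obj)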
In [19]:
using MCTS
using POMDPs
using POMDPModels
using Random
mdp = LegacyGridWorld();
In [20]:
solver = MCTSSolver(n_iterations=3, depth=4,
                    init_N=3,
                    init_Q=11.73)
policy = solve(solver, mdp)
action(policy, GridWorldState(1,1))
println("State-Action Nodes")
tree = policy.tree
for sn in MCTS.state_nodes(tree)
    for san in MCTS.children(sn)
        println("s:$(MCTS.state(sn)), a:$(action(san)) Q:$(MCTS.q(san)) N:$(MCTS.n(san))")
    end
end
init_N, init_Q, estimate_value, and next_action can also be functions. The following cell defines such functions, and the cell after it passes them to a DPWSolver. Note: the ? below is part of the ternary operator.
In [21]:
special_Q(mdp, s, a) = s == GridWorldState(1,2) ? 11.73 : 0.0
special_N(mdp, s, a) = s == GridWorldState(1,2) ? 3 : 0

function manhattan_value(mdp, s, depth) # depth is the solver `depth` parameter less the number of timesteps that have already passed (it can be ignored in many cases)
    m_dist = abs(s.x-9)+abs(s.y-3)
    val = 10.0/m_dist
    println("Set value for $s to $val") # this is not necessary - just shows that it's working later
    return val
end

function up_priority(mdp, s, snode) # snode is the state node of type DPWStateNode
    if haskey(snode.tree.a_lookup, (snode.index, GridWorldAction(:up))) # :up is already in the tree
        return GridWorldAction(rand([:left, :down, :right])) # add a random other action
    else
        return GridWorldAction(:up)
    end
end;
In [22]:
solver = DPWSolver(n_iterations=8, depth=4,
                   init_N=special_N, init_Q=special_Q,
                   estimate_value=manhattan_value,
                   next_action=up_priority)
policy = solve(solver, mdp)
action(policy, GridWorldState(1,1))
println("State-Action Nodes:")
tree = policy.tree
for i in 1:length(tree.total_n)
    for j in tree.children[i]
        println("s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])")
    end
end
There are many cases where functions are not suitable, for example when the solver needs to be serialized. In such cases, an arbitrary object may be passed to the solver to encode the behavior. The same object can be passed to multiple solver parameters to govern all of their behavior. See the docstring for the solver for more information on which functions will be called on the object(s); in the cells below these are MCTS.init_Q, MCTS.init_N, MCTS.estimate_value, and MCTS.next_action, each taking the object as its first argument. The following code does exactly the same thing as the function-based code above:
In [23]:
mutable struct MyHeuristic
    target_state::GridWorldState
    special_state::GridWorldState
    special_Q::Float64
    special_N::Int
    priority_action::GridWorldAction
    rng::AbstractRNG
end;
In [24]:
MCTS.init_Q(h::MyHeuristic, mdp::LegacyGridWorld, s, a) = s == h.special_state ? h.special_Q : 0.0
MCTS.init_N(h::MyHeuristic, mdp::LegacyGridWorld, s, a) = s == h.special_state ? h.special_N : 0
function MCTS.estimate_value(h::MyHeuristic, mdp::LegacyGridWorld, s, depth::Int)
    targ = h.target_state
    m_dist = abs(s.x-targ.x)+abs(s.y-targ.y)
    val = 10.0/m_dist
    println("Set value for $s to $val") # this is not necessary - just shows that it's working later
    return val
end

function MCTS.next_action(h::MyHeuristic, mdp::LegacyGridWorld, s, snode::DPWStateNode)
    if haskey(snode.tree.a_lookup, (snode.index, h.priority_action))
        return GridWorldAction(rand(h.rng, [:up, :left, :down, :right])) # add a random other action
    else
        return h.priority_action
    end
end;
In [25]:
heur = MyHeuristic(GridWorldState(9,3), GridWorldState(1,2), 11.73, 3, GridWorldAction(:up), Random.GLOBAL_RNG)
solver = DPWSolver(n_iterations=8, depth=4,
                   init_N=heur, init_Q=heur,
                   estimate_value=heur,
                   next_action=heur)
policy = solve(solver, mdp)
action(policy, GridWorldState(1,1))
println("State-Action Nodes:")
tree = policy.tree
for i in 1:length(tree.total_n)
    for j in tree.children[i]
        println("s:$(tree.s_labels[i]), a:$(tree.a_labels[j]), Q:$(tree.q[j]) N:$(tree.n[j])")
    end
end
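Finally, estimate_value can be a full rollout simulation with a custom rollout policy: wrapping a Policy object in a RolloutEstimator tells the solver to estimate the value at leaf nodes by simulating that policy. The cells below define a simple SeekTarget policy that always moves toward a target state and pass it to the solver as the rollout policy.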
In [26]:
mutable struct SeekTarget <: Policy
    target::GridWorldState
end
In [27]:
function POMDPs.action(p::SeekTarget, s::GridWorldState, a::GridWorldAction=GridWorldAction(:up))
    if p.target.x > s.x
        return GridWorldAction(:right)
    elseif p.target.x < s.x
        return GridWorldAction(:left)
    elseif p.target.y > s.y
        return GridWorldAction(:up)
    else
        return GridWorldAction(:down)
    end
end
In [28]:
solver = MCTSSolver(n_iterations=5, depth=20,
                    estimate_value=RolloutEstimator(SeekTarget(GridWorldState(9,3))))
policy = solve(solver, mdp)
action(policy, GridWorldState(5,1))
println("State-Action Nodes")
tree = policy.tree
for sn in MCTS.state_nodes(tree)
    for san in MCTS.children(sn)
        println("s:$(MCTS.state(sn)), a:$(action(san)) Q:$(MCTS.q(san)) N:$(MCTS.n(san))")
    end
end