TD(λ)

Bounded random walk with TD(λ), following Learning to Predict by the Methods of Temporal Differences (Sutton, 1988)

Random walk generation


In [10]:
import random
from collections import defaultdict
from functools import partial

def generate_random_walk():
    """
    Generate a single bounded random walk: start at D, step to a random
    neighbour, and stop on reaching a terminal state (A or G).
    :return: list of states (strings)
    """

    terminal_states = {'A', 'G'}
    choices = {
        'B': ['A', 'C'],
        'C': ['B', 'D'],
        'D': ['C', 'E'],
        'E': ['D', 'F'],
        'F': ['E', 'G']
    }

    # Start at state D
    sequence = ['D']

    while sequence[-1] not in terminal_states:
        next_state = random.choice(choices[sequence[-1]])
        sequence.append(next_state)

    return sequence

In [15]:
def test_random_walk():
    random.seed()
    walks = [generate_random_walk() for _ in range(1000)]
    all_first = all(walk[0] == 'D' for walk in walks)
    all_last = all(walk[-1] in {'A', 'G'} for walk in walks)
    all_len = all(len(walk) >= 4 for walk in walks)
    all_middle = all(all(w not in {'A', 'G'} for w in walk[1:-1]) for walk in walks)
    actual = all([all_first, all_last, all_len, all_middle])
    expected = True
    assert expected == actual
    print('✓ All tests successful')

In [16]:
test_random_walk()


✓ All tests successful

Equation 4
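
The process_walk function below implements equation 4 from the paper, the TD(λ) weight update for time step t:

$$\Delta w_t = \alpha \left( P_{t+1} - P_t \right) \sum_{k=1}^{t} \lambda^{\,t-k} \nabla_w P_k$$

With the tabular (one-per-state) representation used here, the gradient ∇_w P_k is just an indicator for the state visited at step k, so each previously visited state receives α (P_{t+1} − P_t) λ^(t−k).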


In [19]:
def process_walk(walk, weights, alpha, lambda_, terminal_states):
    """
    Calculate the change in weights (dw) for a single walk
    :param walk: sequence of states
    :param weights: current weights (dict of state -> prediction)
    :param alpha: learning rate
    :param lambda_: trace-decay parameter of TD(λ)
    :param terminal_states: dict mapping terminal state -> reward
    :return: dw (dict)
    """
    
    def _clip(x, max_=1, min_=0):
        return max(min_, min(x, max_))
    
    def _p(w, state):
        # P_t
        return _clip(w[state])

    def _p_next(w, state):
        # P_t+1
        if state in terminal_states:
            return terminal_states[state]
        else:
            return _clip(_p(w, state))
        
    def _add_dictionaries(a, b):
        """
        Adds two dictionaries and returns a new dict 
        :param a: first dict
        :param b: second dict
        :return: new dictionary with common keys added and non-common keys from both dictionaries present
        """
        return {k: a.get(k, 0) + b.get(k, 0) for k in set(a) | set(b)}
    
    def _get_delta_w(walk, t_index, error, alpha, lambda_):
        """
        Calculate the delta_w vector for time step t (equation 4)
        :param walk: list of states (strings)
        :param t_index: index t of the current state in the walk
        :param error: P_t+1 - P_t
        :param alpha: learning rate
        :param lambda_: trace-decay parameter of TD(λ)
        :return: delta_w (dict)
        """
        # Vector of delta_w's (by state)
        states_vector = defaultdict(float)

        # Sum over k (0-based here; k=1..t in the paper) of lambda ** (t_index - k),
        # crediting the error back to each previously visited state (equation 4)
        for k in range(t_index + 1):
            state_at_k = walk[k]
            lambda_to_the_power = lambda_ ** (t_index - k)
            if lambda_to_the_power > 0:
                states_vector[state_at_k] += alpha * error * lambda_to_the_power

        # Equation 4: return a plain dict of per-state updates
        return dict(states_vector)
    
    # Δw
    dw = defaultdict(float)

    # Go all the way up to (not including) the terminal state
    for t_index in range(len(walk) - 1):
        state_current = walk[t_index]
        state_next = walk[t_index + 1]
        p_t = _p(weights, state_current)
        p_t_plus1 = _p_next(weights, state_next)
        error = p_t_plus1 - p_t
        dw_t = _get_delta_w(walk, t_index, error, alpha, lambda_)

        # Add the current state t's Δw_t to the overall Δw
        dw = _add_dictionaries(dw, dw_t)

    return dw

Setup


In [24]:
weight_initial = 0.5
terminal_states = {'A': 0, 'G': 1}
weights_initial = {state: weight_initial for state in 'BCDEF'}
weights_initial


Out[24]:
{'B': 0.5, 'C': 0.5, 'D': 0.5, 'E': 0.5, 'F': 0.5}

Weight updates


In [25]:
# Sample walks
walk_right = ['D', 'E', 'F', 'G']
walk_left  = ['D', 'C', 'B', 'A']

In [51]:
process_walk_partial = partial(process_walk,
                               weights=weights_initial,
                               terminal_states=terminal_states)

In [34]:
# With λ=0 only the current state is updated at each step; here the only
# non-zero error is the final transition into G, so only F changes
process_walk_partial(walk=walk_right, alpha=0.01, lambda_=0)


Out[34]:
{'D': 0.0, 'E': 0.0, 'F': 0.005}

In [35]:
# With λ=0.5 the update is also propagated back to earlier states,
# halved at each step (checked by hand below)
process_walk_partial(walk=walk_right, alpha=0.01, lambda_=0.5)


Out[35]:
{'D': 0.00125, 'E': 0.0025, 'F': 0.005}
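
Checking these numbers against equation 4 by hand: with every prediction initialised to 0.5, the only non-zero error in this walk is the final F→G transition, where P_t+1 − P_t = 1 − 0.5 = 0.5. So Δw_F = 0.01 · 0.5 = 0.005, Δw_E = 0.01 · 0.5 · 0.5 = 0.0025 and Δw_D = 0.01 · 0.5 · 0.5² = 0.00125, matching the output above.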

In [36]:
# With λ=1 the update is propagated back equally to every state in the walk
process_walk_partial(walk=walk_right, alpha=0.01, lambda_=1)


Out[36]:
{'D': 0.005, 'E': 0.005, 'F': 0.005}

In [52]:
# Alpha is the learning rate: it scales the size of each weight update
process_walk_partial(walk=walk_right, alpha=0.1, lambda_=0)


Out[52]:
{'D': 0.0, 'E': 0.0, 'F': 0.05}

In [41]:
# Walking to the left lowers the weights, since the reward for
# landing in terminal state A is 0
process_walk_partial(walk=walk_left, alpha=0.01, lambda_=0)


Out[41]:
{'B': -0.005, 'C': 0.0, 'D': 0.0}

In [42]:
# λ=0.5 propagates the (negative) weight change back, halved at each step
process_walk_partial(walk=walk_left, alpha=0.1, lambda_=0.5)


Out[42]:
{'B': -0.05, 'C': -0.025, 'D': -0.0125}

In [53]:
# λ=1 propagates the weight change back with equal weight for all steps
process_walk_partial(walk=walk_left, alpha=0.1, lambda_=1)


Out[53]:
{'B': -0.05, 'C': -0.05, 'D': -0.05}

In [48]:
# With λ>0 repeated states become problematic, especially with a large alpha
process_walk_partial(
    walk=['D', 'E', 'D', 'E', 'D', 'E', 'D', 'E', 'D', 'E', 'F', 'G'],
    alpha=0.1,
    lambda_=1
)


Out[48]:
{'D': 0.25, 'E': 0.25, 'F': 0.05}
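
To see why: only the final F→G step has a non-zero error (0.5), and with λ=1 that error is credited in full to every earlier visit. D and E each appear five times before F, so Δw_D = Δw_E = 0.1 · 0.5 · 5 = 0.25, while Δw_F = 0.1 · 0.5 = 0.05.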

In [49]:
# The updates for D and E now exceed 1, which would push the predicted
# probabilities above 1; that doesn't make sense for a probability
process_walk_partial(
    walk=['D', 'E', 'D', 'E', 'D', 'E', 'D', 'E', 'D', 'E', 'F', 'G'],
    alpha=0.5,
    lambda_=1
)


Out[49]:
{'D': 1.25, 'E': 1.25, 'F': 0.25}
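
The cells above only compute Δw without applying it. As a rough illustration of how these updates move the weights towards the true predictions, here is a minimal sketch (not part of the original cells) that applies the Δw from process_walk after each generated walk; the function name train and the hyper-parameters n_walks, alpha and lambda_ are illustrative choices, not the paper's experimental settings.

In [ ]:
def train(n_walks=2000, alpha=0.05, lambda_=0.3):
    """Generate walks and apply the accumulated Δw after each one."""
    weights = dict(weights_initial)  # every non-terminal state starts at 0.5
    for _ in range(n_walks):
        walk = generate_random_walk()
        dw = process_walk(walk, weights, alpha, lambda_, terminal_states)
        # Apply this walk's accumulated update to the weight table
        for state, delta in dw.items():
            weights[state] += delta
    return weights

# The true probabilities of ending in G from B..F are 1/6, 2/6, 3/6, 4/6, 5/6,
# so the learned weights should come out roughly near those values
train()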