We have:

$$ w_{ij} \leftarrow w_{ij} + \alpha \delta \frac{\partial Q_w}{\partial w_{ij}} $$

where:

$$ \delta = r_t + \gamma Q_w(s_{t+1}, a_{t+1}) - Q_w(s_t, a_t) $$
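
For example, with some made-up values $r_t = 0.5$, $\gamma = 0.9$, $Q_w(s_{t+1}, a_{t+1}) = 6.0$ and $Q_w(s_t, a_t) = 3.0$, we'd get:

$$ \delta = 0.5 + 0.9 \cdot 6.0 - 3.0 = 2.9 $$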

Imagine we have already calculated $\delta$, and PyTorch gives us a way of calculating $\frac{\partial Q_w}{\partial w}$, using autograd. So we can simply do:

Q = q_estimator(state)  # forward pass: Q_w(s_t, a_t)
Q.backward(-delta)      # backprop -delta as the gradient flowing into Q

(we make delta negative because the optimizer assumes the backpropagated gradient is the gradient of a loss with respect to the parameters, a loss that we want to minimize, and it will therefore move the parameters in the opposite direction to that gradient)
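
To make this concrete, here's a minimal sketch (with a made-up scalar linear estimator and a made-up $\delta$; plain tensors with requires_grad are used rather than Variables) showing that after backward(-delta), an SGD step moves $w$ in the direction of $\delta \frac{\partial Q_w}{\partial w}$:

import torch
from torch import optim

# toy linear Q estimator: Q(s) = w * s (hypothetical values)
w = torch.tensor([3.0], requires_grad=True)
s = torch.tensor([1.0])
optimizer = optim.SGD([w], lr=0.1)

q = w * s                      # Q_w(s_t, a_t)
delta = torch.tensor([0.5])    # pretend we computed this TD error elsewhere
q.backward(-delta)             # w.grad is now -delta * dQ/dw = -0.5
optimizer.step()               # w <- w - lr * w.grad = 3.0 + 0.1 * 0.5 = 3.05
print(w)                       # w has moved in the direction of delta * dQ/dw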

If we wanted to express this as a loss instead, we'd need:

$$ \frac{\partial loss}{\partial w} = - \delta \frac{\partial Q_w}{\partial w} \tag{1} $$

(with a minus sign, since we will try to decrease the loss, but we want $w$ to move in the direction of $\delta \frac{\partial Q_w}{\partial w}$)

But we have:

$$ loss = crit(Q_w(input)) $$

So, we can chain rule this:

$$ \frac{\partial loss}{\partial w} = \frac{\partial crit}{\partial Q_w}\frac{\partial Q_w}{\partial w} $$

Comparing this to equation (1), we see that they both have $\frac{\partial Q_w}{\partial w}$, and so we require that:

$$ \frac{\partial crit}{\partial Q_w} = - \delta $$

If we integrate this we get:

$$ crit = - \int \delta \,dQ_w \tag{2} $$

But $\delta$ is actually a function of $Q_w$, i.e.:

$$ \delta = r_t + \gamma Q_w(s_{t+1}, a_{t+1}) - Q_w(s_t, a_t) $$

I think we're going to hold $Q_w(s_{t+1}, a_{t+1})$ constant, and update the weights so that $Q_w(s_t, a_t)$ approaches the target $r_t + \gamma Q_w(s_{t+1}, a_{t+1})$. Let's write $Q_w(s_{t+1}, a_{t+1})$ as $Q_{w'}(s_{t+1}, a_{t+1})$ to emphasize that we are treating it as a constant (in PyTorch terms, this is the value we would detach from the graph). Substituting this back into equation (2), we get:

$$ loss = \int \left( Q_w(s_t, a_t) - r_t - \gamma Q_{w'}(s_{t+1}, a_{t+1}) \right) dQ_w $$
$$ = \frac{1}{2} Q_w(s_t, a_t)^2 - \left( r_t + \gamma Q_{w'}(s_{t+1}, a_{t+1}) \right) Q_w(s_t, a_t) + C $$

(where $C$ is an arbitrary constant of integration)

Completing the square, and absorbing the $\frac{1}{2}\left(r_t + \gamma Q_{w'}(s_{t+1}, a_{t+1})\right)^2$ term, which doesn't depend on $w$, into the constant $C$:

$$ = \frac{1}{2} \left( Q_w(s_t, a_t) - \left( r_t + \gamma Q_{w'}(s_{t+1}, a_{t+1}) \right) \right)^2 + C $$

... which is half the MSE Loss:

$$ = \frac{1}{2} \text{MSELoss}\!\left(Q_w(s_t, a_t),\; r_t + \gamma Q_{w'}(s_{t+1}, a_{t+1})\right) + C $$

Let's try this with some PyTorch, to test it empirically on one example:


In [26]:
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
import torch.nn.functional as F


def calc_q(s):
    # toy linear Q estimator: Q(s) = w * s, with w freshly initialized to 3.0
    w = Variable(torch.FloatTensor([3.0]), requires_grad=True)
    q = w * Variable(s)
    return w, q


s = torch.FloatTensor([1.0])
target_q = torch.FloatTensor([1.3])

# method 1: backprop half the MSE between Q and the (constant) target
w, q = calc_q(s)
crit = nn.MSELoss()
loss = crit(q, Variable(target_q)) / 2
loss.backward()
print('mse(q, target_q)/2 w.grad', w.grad.data[0])

# method 2: backprop -delta directly as the upstream gradient of Q
w, q = calc_q(s)
delta = Variable(target_q) - q
q.backward(-delta)
print('q.backward(-delta) w.grad', w.grad.data[0])


mse(q, target_q)/2 w.grad 1.7000000476837158
q.backward(-delta) w.grad 1.7000000476837158

=> they match :)
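
As a further check, here's a sketch with made-up numbers that uses a full TD target $r_t + \gamma Q_{w'}(s_{t+1}, a_{t+1})$, with detach() holding the next-state value constant (playing the role of $Q_{w'}$); the two gradients should again match:

import torch
import torch.nn.functional as F

gamma = 0.9
r = torch.tensor([0.5])
s, s_next = torch.tensor([1.0]), torch.tensor([2.0])

def fresh_w():
    # fresh weight for each experiment, so the two gradients are comparable
    return torch.tensor([3.0], requires_grad=True)

# method 1: half the MSE against the TD target, next-state Q detached
w = fresh_w()
q, q_next = w * s, w * s_next
target = r + gamma * q_next.detach()
loss = F.mse_loss(q, target) / 2
loss.backward()
print('mse/2 w.grad', w.grad)

# method 2: backprop -delta directly
w = fresh_w()
q, q_next = w * s, w * s_next
delta = r + gamma * q_next.detach() - q
q.backward(-delta.detach())
print('q.backward(-delta) w.grad', w.grad)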