This page explains why the LSTM was first proposed and what its core features are, illustrated with some examples.
It is based on the paper: Hochreiter and Schmidhuber. 1997. Long Short-Term Memory.
The core feature of the LSTM unit as first proposed is the constant-error carrousel (CEC), which solves the vanishing gradient problem of standard RNNs.
A CEC is a neural network unit consisting of a single neuron with a self-loop whose weight is fixed to 1.0, which ensures constant error flow during backpropagation.
Now let's see an example of the CEC at work. We will use a CEC for a very simple task: recognizing whether the current character is inside a bracketed expression; for simplicity, the opening bracket itself is considered to be inside and the closing bracket outside. This task is solvable only by a network that can store memory, since to recognize whether a character is inside a bracketed expression, we need to know that there is an opening bracket to its left that does not yet have a matching closing bracket.
The input characters come from the alphabet $\{`a`, `b`, `(`, `)`\}$ with the following 2-dimensional embedding:
$$ \begin{eqnarray} emb(`a`) &=& (1, 1) \nonumber\\ emb(`b`) &=& (-1, -1) \nonumber\\ emb(`(`) &=& (1, 0) \nonumber\\ emb(`)`) &=& (0, 1) \nonumber \end{eqnarray} $$For this task, we define a very simple network with two input units, one CEC unit, and one output unit with sigmoid activation ($\sigma(x) = \frac{1}{1 + e^{-x}}$).
For this task, we define the loss function as the cross-entropy (CE) between the target and the predicted output:
$$ \begin{eqnarray} \mathrm{CE}(x, y) = - (x\log(y) + (1-x)\log(1-y)) \nonumber\\ \mathrm{Loss}(\hat{o}_t, o_t) = \mathrm{CE}(\hat{o}_t, o_t) - \mathrm{CE}(\hat{o}_t, \hat{o}_t) \end{eqnarray} $$where $\hat{o}_t$ and $o_t$ denote the target value (gold standard) and the output value (network prediction) at time step $t$, respectively. The first term is the cross-entropy between the target and the output, and the second term is the entropy of the target itself. Note that the second term is a constant; it only serves to make the minimum achievable loss 0 (perfect output).
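As a quick sanity check of this definition (a standalone sketch, not one of the notebook's cells), the loss is 0 when the output equals the target and positive otherwise:

import math
def CE(x, y):
    return -(x*math.log(y) + (1-x)*math.log(1-y))
def Loss(target, output):
    # cross-entropy minus the entropy of the target, so a perfect output scores 0
    return CE(target, output) - CE(target, target)
sigma = lambda x: 1.0/(1+math.exp(-x))
target = sigma(1)                 # "inside a bracketed expression"
print(Loss(target, target))       # 0.0 (perfect output)
print(Loss(target, 0.5))          # > 0 (imperfect output)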
More specifically, we have:
$$ \begin{equation} o_t = \sigma(w_3*s_t) \end{equation} $$where $s_t$ is the output of the CEC unit (a.k.a. the memory), which depends on the previous value of the memory $s_{t-1}$ and on the inputs $x_{t,1}$ and $x_{t,2}$ (the first and second dimensions of the input at time step $t$):
$$ \begin{equation} s_t = \underbrace{w_s * s_{t-1}}_\text{previous value} + \underbrace{w_1 * x_{t,1} + w_2 * x_{t,2}}_\text{input} \end{equation} $$where $w_s$ is the weight of the self-loop, which is fixed to 1.0. To make it clear why it should be 1.0, the derivation below does not assume $w_s = 1.0$.
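For intuition, here is a hand trace of this recurrence on `'a(a)a'` (a sketch: the weights $w_1 = 1, w_2 = -1$ are one solution under the embedding above, not values the notebook has learned yet). With them, `a`/`b` contribute 0 and `(`/`)` contribute $+1$/$-1$, so the memory tracks the bracket count:

import math
emb = {'a': (1, 1), 'b': (-1, -1), '(': (1, 0), ')': (0, 1)}
w1, w2, w3, ws = 1.0, -1.0, 1.0, 1.0
s = 0.0
for ch in 'a(a)a':
    x1, x2 = emb[ch]
    s = ws*s + w1*x1 + w2*x2              # CEC update
    o = 1.0/(1 + math.exp(-w3*s))         # sigmoid output
    print(ch, s, round(o, 3))             # memory: 0, 1, 1, 0, 0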
In [14]:
import math
from IPython.display import Markdown, display
def printmd(string):
display(Markdown(string))
# Embedding
embedding = {}
embedding['a'] = (1.0, 1)
embedding['b'] = (-1, -1)
embedding['('] = (1, 0)
embedding[')'] = (0, 1)
# embedding['a'] = (-1, 0)
# embedding['b'] = (-0.5, 0)
# embedding['('] = (1, 1)
# embedding[')'] = (1, -1)
# Weights
w1=1.0
w2=1.0
w3=1.0
ws=1.0
memory_history = [0]
output_history = [0]
def sigmoid(x):
return 1.0/(1+math.exp(-x))
def gold(seq):
result = [0]
bracket_count = 0
for char in seq:
if char == '(':
bracket_count += 1
if char == ')':
bracket_count -= 1
result.append(sigmoid(bracket_count))
return result
def activate_memory(x1, x2):
prev_memory = memory_history[-1]
memory_history.append(ws*prev_memory + w1*x1 + w2*x2)
return memory_history[-1]
def activate_output(h):
output_history.append(sigmoid(w3*h))
return output_history[-1]
def predict(seq):
for char in seq:
activate_output(activate_memory(*embedding[char]))
result = output_history[:]
return result
def reset():
global memory_history, output_history
memory_history = [0]
output_history = [0]
def loss(gold_seq, pred_seq):
result = 0.0
per_position_loss = []
for idx, (corr, pred) in enumerate(zip(gold_seq, pred_seq)):
cur_loss = -(corr*math.log(pred) + (1-corr)*math.log(1-pred))
cur_loss -= -(corr*math.log(corr) + (1-corr)*math.log(1-corr))
result += cur_loss
per_position_loss.append(cur_loss)
return result, per_position_loss
def print_list(lst):
'''A convenience method to print a list of real numbers'''
as_str = ['{:+.3f}'.format(num) for num in lst]
print('[{}]'.format(', '.join(as_str)))
In [15]:
# See typical values of sigmoid
for i in range(5):
print('sigmoid({}) = {}'.format(i, sigmoid(i)))
Now let's check the function that calculates the target values. Basically, we want it to output $\sigma(0)$ or $\sigma(1)$ when the current character is outside or inside a bracketed expression, respectively.
In [16]:
gold('a(a)a')[1:] # The first element is dummy
Out[16]:
This is $\sigma(0), \sigma(1), \sigma(1), \sigma(0), \sigma(0)$, which is exactly what we expect. So far so good.
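The same function also handles nesting, which we will need for the last experiment; a quick standalone check:

gold('((a))')[1:]   # bracket counts 1, 2, 2, 1, 0, i.e. sigmoid(1), sigmoid(2), sigmoid(2), sigmoid(1), sigmoid(0)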
In [17]:
test_seq = 'ab(ab)ab'
reset()
w1 = 1.0
w2 = 1.0
w3 = 1.0
result = predict(test_seq)
correct = gold(test_seq)
print('Output: ', end='')
print_list(result[1:])
print('Target: ', end='')
print_list(correct[1:])
print('Loss : {:.3f}'.format(loss(correct[1:], result[1:])[0]))
We see that the loss is non-zero and that some values are predicted incorrectly.
Next we work through the gradient calculation so that we can update the weights to reduce the loss.
To do the weight update, we need to calculate the partial derivative of the loss function with respect to each weight. We have three weight parameters $w_1, w_2$, and $w_3$, so we need to compute three partial derivatives.
For ease of notation, we denote $\mathrm{Loss}_t = \mathrm{Loss}(\hat{o}_t, o_t)$ as the loss at time step $t$ and $\mathrm{Loss} = \sum_t \mathrm{Loss}_t$ as the total loss over one sequence.
Remember that our objective is to reduce the total loss.
$$ \begin{eqnarray} \frac{\partial\mathrm{Loss}}{\partial w_i} & = & \sum_t\frac{\partial \mathrm{Loss}_t}{\partial w_i} \\ & = & \sum_t\frac{\partial \mathrm{Loss}_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial w_i} \qquad \text{(by chain rule)} \\ \end{eqnarray} $$for $w_3$, we can already compute the gradient here, which is:
$$ \require{cancel} \begin{eqnarray} \frac{\partial\mathrm{Loss}}{\partial w_3} & = & \sum_t\frac{\partial \mathrm{Loss}_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial w_3} \\ & = & \sum_t\underbrace{\frac{o_t - \hat{o}_t}{\cancel{o_t(1-o_t)}}}_{=\frac{\partial \mathrm{Loss}_t}{\partial o_t}} \cdot \underbrace{s_t \cdot \cancel{o_t(1-o_t)}}_{=\frac{\partial o_t}{\partial w_3}} \\ & = & \sum_t(o_t-\hat{o}_t)s_t \end{eqnarray} $$for $w_1$ and $w_2$, we have:
$$ \begin{eqnarray} \frac{\partial\mathrm{Loss}}{\partial w_i} & = & \sum_t\frac{\partial \mathrm{Loss}_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial w_i} \\ & = & \sum_t \frac{o_t - \hat{o}_t}{o_t(1-o_t)} \cdot \frac{\partial o_t}{\partial s_t} \cdot \frac{\partial s_t}{\partial w_i} \\ & = & \sum_t \frac{o_t - \hat{o}_t}{\cancel{o_t(1-o_t)}} \cdot w_3\cdot \cancel{o_t(1-o_t)} \cdot \frac{\partial s_t}{\partial w_i} \\ & = & \sum_t (o_t - \hat{o}_t)w_3 \cdot \frac{\partial s_t}{\partial w_i} \\ & = & \sum_t (o_t - \hat{o}_t)w_3 \cdot \left(w_s\cdot\frac{\partial s_{t-1}}{\partial w_i} + x_{t,i}\right) \\ & = & \sum_t (o_t - \hat{o}_t)w_3 \cdot \left({w_s}^2\cdot\frac{\partial s_{t-2}}{\partial w_i} + w_s\cdot x_{t-1,i} + x_{t,i}\right) \\ & & \ldots \\ & = & \sum_t (o_t - \hat{o}_t)w_3 \cdot \left(\sum_{t'\leq t} {w_s}^{t-t'}x_{t',i}\right) \\ \end{eqnarray} $$We see that the gradient with respect to $w_1$ and $w_2$ contains the factor ${w_s}^{t-t'}$, where $t-t'$ can be as large as the input sequence length. So if $w_s \neq 1.0$, the contribution of inputs far in the past either vanishes (when $|w_s| < 1$) or blows up (when $|w_s| > 1$) as the input sequence gets longer; fixing $w_s = 1.0$ keeps the error flow constant.
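To get a feel for how quickly the factor ${w_s}^{t-t'}$ matters, here is a small standalone illustration (not part of the derivation): the contribution of an input 50 steps in the past is scaled by ${w_s}^{50}$.

for ws_demo in (0.9, 1.0, 1.1):
    print('ws = {}: ws**50 = {:.6f}'.format(ws_demo, ws_demo**50))
# 0.9**50 is about 0.005 (the gradient from early steps vanishes)
# 1.0**50 is exactly 1.0  (the CEC: constant error flow)
# 1.1**50 is about 117    (the gradient blows up)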
In [18]:
def dLdw1(test_seq, gold_seq, pred_seq, state_seq, info):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw1 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * w3
        cur_dell *= sum(ws**(time_step-step)*embedding[test_seq[step-1]][0] for step in range(1, time_step+1))  # ws**(time_step-step) matches the w_s^(t-t') factor in the derivation
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
def dLdw2(test_seq, gold_seq, pred_seq, state_seq, info):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw2 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * w3
        cur_dell *= sum(ws**(time_step-step)*embedding[test_seq[step-1]][1] for step in range(1, time_step+1))  # ws**(time_step-step) matches the w_s^(t-t') factor in the derivation
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
def dLdw3(test_seq, gold_seq, pred_seq, state_seq, info):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw3 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * state_seq[time_step]
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
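Before using these gradients for learning, we can sanity-check one of them numerically (a sketch that reuses predict, gold, loss, and dLdw1 defined above, together with the module-level weights they read); the analytic value and the central finite-difference estimate should agree closely:

def total_loss(seq):
    reset()
    pred = predict(seq)
    corr = gold(seq)
    value, _ = loss(corr[1:], pred[1:])
    reset()
    return value

check_seq = 'ab(ab)ab'
w1, w2, w3, ws = 1.0, 1.0, 1.0, 1.0

# Analytic gradient from dLdw1 (the last argument is just its HTML scratch buffer)
reset()
pred = predict(check_seq)
corr = gold(check_seq)
analytic = dLdw1(check_seq, corr, pred, memory_history, [''])
reset()

# Central finite difference with respect to w1
eps = 1e-5
w1 += eps
loss_plus = total_loss(check_seq)
w1 -= 2*eps
loss_minus = total_loss(check_seq)
w1 += eps    # restore w1
numeric = (loss_plus - loss_minus) / (2*eps)
print('analytic dL/dw1 = {:+.6f}, numeric dL/dw1 = {:+.6f}'.format(analytic, numeric))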
Now we define an experiment function which takes the initial values of all the weights, the learning rate, and the maximum number of iterations. We also want to experiment with fixing the weight $w_3$ (i.e., not learning it).
The code below prints the total loss, the per-time-step loss, the output, target, and memory at each time step, and the per-time-step gradient contribution for each learned parameter.
In [19]:
def experiment(test_seq, _w1=1.0, _w2=1.0, _w3=1.0, alpha=1e-1, max_iter=250, fixed_w3=True):
global w1, w2, w3
reset()
w1 = _w1
w2 = _w2
w3 = _w3
correct = gold(test_seq)
print('w1={:+.3f}, w2={:+.3f}, w3={:+.3f}'.format(w1, w2, w3))
for iter_num in range(max_iter):
result = predict(test_seq)
if iter_num < 15 or (iter_num % 50 == 49):
printmd('<div style="font-weight:bold">Iteration {}</div>'.format(iter_num))
print('Output: ', end='')
print_list(result[1:])
print('Target: ', end='')
print_list(correct[1:])
print('Memory: ', end='')
print_list(memory_history[1:])
total_loss, per_position_loss = loss(correct[1:], result[1:])
info = ['', iter_num]
info[0] = ('<div>Loss: <span style="font-weight:bold">{:.5f}</span>' +
'= <span style="font-family:monaco; font-size:12px">').format(total_loss)
for idx, per_pos_loss in enumerate(per_position_loss):
info[0] += '{}{:.3f}'.format(' + ' if idx > 0 else '', per_pos_loss)
info[0] += '</span></div>'
# printmd(loss_str)
w1 -= alpha * dLdw1(test_seq, correct, result, memory_history, info)
w2 -= alpha * dLdw2(test_seq, correct, result, memory_history, info)
if not fixed_w3:
w3 -= alpha * dLdw3(test_seq, correct, result, memory_history, info)
if iter_num < 15 or (iter_num % 50 == 49):
printmd(info[0])
print('w1={:+.3f}, w2={:+.3f}, w3={:+.3f}'.format(w1, w2, w3))
print()
reset()
return w1, w2, w3
In [20]:
embedding['a'] = (1.0, 1)
embedding['b'] = (-1, -1)
embedding['('] = (1, 0)
embedding[')'] = (0, 1)
w1, w2, w3 = experiment('ab(ab)bb', _w1=1.0, _w2=1.0, max_iter=250, alpha=1e-1, fixed_w3=True)
printmd('## Test on longer sequence')
experiment('aabba(aba)bab', _w1=w1, _w2=w2, _w3=w3, alpha=1e-2, max_iter=100)
Out[20]:
We saw in the experiment above that there are conflicting updates: at some positions in the sequence the gradient is positive, while at others it is negative. The original paper explains that this is caused by the same weights into the memory cell having to update the memory at some positions (when we see brackets, in this case) and to retain it at others (when we see any other character).
Another core feature of the LSTM, designed to resolve this issue, is its gates: an input gate and an output gate that control the flow of information through the memory cell.
In the following, we try adding an input gate, which the network should learn to open (value = 1) only when it sees an opening or closing bracket. So basically the input gate tells the network which inputs are relevant and which are not.
Note: below we implement two versions of the input gate: a linear gate with sigmoid, and a bilinear gate with a bias. The weights $w_4$ and $w_5$ have different interpretations depending on which gate is chosen. The bilinear gate was added because, with this embedding, the linear gate cannot express the gating we need (see the quick check below).
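As a quick illustration of the difference between the two gate forms (a standalone sketch; $w_4 = 1, w_5 = -1$ are the values later referred to as the correctly learned gate):

emb = {'a': (1, 1), 'b': (-1, -1), '(': (1, 0), ')': (0, 1)}
w4_true, w5_true = 1.0, -1.0
for ch, (x1, x2) in emb.items():
    print('{}: bilinear gate = {:+.1f}'.format(ch, w4_true + w5_true*x1*x2))
# 0 for 'a' and 'b' (block the input), 1 for '(' and ')' (let it through).
# No linear gate sigmoid(w4*x1 + w5*x2) can do the same with this embedding:
# closing on 'a' needs w4 + w5 << 0, while closing on 'b' needs -(w4 + w5) << 0.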
In [21]:
w4 = 1.0
w5 = 1.0
input_history = [0]
gate_history = [0]
def reset_gated():
global memory_history, output_history, input_history, gate_history
memory_history = [0]
output_history = [0]
input_history = [0]
gate_history = [0]
def activate_input(x1, x2):
result = (w1*x1+w2*x2)
input_history.append(result)
return result
def activate_gate(x1, x2, bilinear_gate=True):
if bilinear_gate:
result = w4 + w5*x1*x2 # Bilinear gate
else:
result = sigmoid(w4*x1+w5*x2) # The true linear gate
gate_history.append(result)
return result
def dLdw1_gated(test_seq, gold_seq, pred_seq, state_seq, input_seq, gate_seq, info, bilinear_gate=True):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw1 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * w3
cur_dell *= sum(embedding[test_seq[step-1]][0]*gate_seq[step] for step in range(1, time_step+1))
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
def dLdw2_gated(test_seq, gold_seq, pred_seq, state_seq, input_seq, gate_seq, info, bilinear_gate=True):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw2 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * w3
cur_dell *= sum(embedding[test_seq[step-1]][1]*gate_seq[step] for step in range(1, time_step+1))
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
def dLdw4_gated(test_seq, gold_seq, pred_seq, state_seq, input_seq, gate_seq, info, bilinear_gate=True):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw4 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * w3
if bilinear_gate:
cur_dell *= sum(input_seq[step] for step in range(1, time_step+1))
else:
cur_dell *= sum(embedding[test_seq[step-1]][0]*gate_seq[step]*input_seq[step]*(1-gate_seq[step])
for step in range(1,time_step+1))
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
def dLdw5_gated(test_seq, gold_seq, pred_seq, state_seq, input_seq, gate_seq, info, bilinear_gate=True):
result = 0.0
grad_str = '<div style="font-family:monaco; font-size:12px">dL/dw5 = '
for time_step in range(1, len(gold_seq)):
cur_dell = (pred_seq[time_step] - gold_seq[time_step]) * w3
if bilinear_gate:
cur_dell *= sum(embedding[test_seq[step-1]][0]*embedding[test_seq[step-1]][1]*input_seq[step]
for step in range(1, time_step+1))
else:
cur_dell *= sum(embedding[test_seq[step-1]][1]*gate_seq[step]*input_seq[step]*(1-gate_seq[step])
for step in range(1,time_step+1))
if cur_dell < 0:
color = 'red'
else:
color = 'blue'
grad_str += '{}<span style="color:{}">{:+.3f}</span>'.format(' + ' if time_step > 1 else '', color, cur_dell)
result += cur_dell
grad_str += ' = <span style="color:{}; text-decoration:underline">{:+.3f}</span></div>'.format(
'red' if result < 0 else 'blue', result)
# printmd(grad_str)
info[0] += grad_str
return result
def activate_memory_gated():
memory_history.append(ws*memory_history[-1] + input_history[-1]*gate_history[-1])
return memory_history[-1]
def predict_gated(seq):
for char in seq:
activate_input(*embedding[char])
activate_gate(*embedding[char])
activate_output(activate_memory_gated())
result = output_history[:]
return result
def experiment_gated(test_seq, _w1=1.0, _w2=1.0, _w3=1.0, _w4=1.0, _w5=1.0, alpha=1e-1, max_iter=750,
bilinear_gate=True, fixed_w3=True, fixed_w4=False, fixed_w5=False):
global w1, w2, w3, w4, w5
reset_gated()
w1 = _w1
w2 = _w2
w3 = _w3
w4 = _w4
w5 = _w5
correct = gold(test_seq)
print('w1={:+.3f}, w2={:+.3f}, w3={:+.3f}, w4={:+.3f}, w5={:+.3f}'.format(w1, w2, w3, w4, w5))
for iter_num in range(max_iter):
result = predict_gated(test_seq)
if iter_num < 15 or (iter_num % 50 == 49):
printmd('<div style="font-weight:bold">Iteration {}</div>'.format(iter_num))
print('Output: ', end='')
print_list(result[1:])
print('Target: ', end='')
print_list(correct[1:])
print('Memory: ', end='')
print_list(memory_history[1:])
print('Input : ', end='')
print_list(input_history[1:])
print('Gate : ', end='')
print_list(gate_history[1:])
total_loss, per_position_loss = loss(correct[1:], result[1:])
info = ['', iter_num]
info[0] = ('<div>Loss: <span style="font-weight:bold">{:.5f}</span>' +
'= <span style="font-family:monaco">').format(total_loss)
for idx, per_pos_loss in enumerate(per_position_loss):
info[0] += '{}{:.3f}'.format(' + ' if idx > 0 else '', per_pos_loss)
info[0] += '</span></div>'
# printmd(loss_str)
w1 -= alpha * dLdw1_gated(test_seq, correct, result, memory_history, input_history, gate_history,
info, bilinear_gate)
w2 -= alpha * dLdw2_gated(test_seq, correct, result, memory_history, input_history, gate_history,
info, bilinear_gate)
if not fixed_w3:
            w3 -= alpha * dLdw3(test_seq, correct, result, memory_history, info)
if not fixed_w4:
w4 -= alpha * dLdw4_gated(test_seq, correct, result, memory_history, input_history, gate_history,
info, bilinear_gate)
if not fixed_w5:
w5 -= alpha * dLdw5_gated(test_seq, correct, result, memory_history, input_history, gate_history,
info, bilinear_gate)
if iter_num < 15 or (iter_num % 50 == 49):
printmd(info[0])
print('w1={:+.3f}, w2={:+.3f}, w3={:+.3f}, w4={:+.3f}, w5={:+.3f}'.format(w1, w2, w3, w4, w5))
print()
reset_gated()
return w1, w2, w3, w4, w5
In [22]:
embedding['a'] = (1.0, 1)
embedding['b'] = (-1, -1)
embedding['('] = (1, 0)
embedding[')'] = (0, 1)
experiment_gated('ab(ab)bb', _w1=1.0, _w2=1.0, _w4=1.0, _w5=1.0, alpha=1e-1, max_iter=250, fixed_w3=True)
Out[22]:
We see that after adding the input gate (assuming the gate can exhibit the same behavior as the true input gate, which is why we use the bilinear gate here), the network reaches the optimum (loss = 0.0) faster (after iteration 199) than the one without an input gate (only after iteration 249), even though there are more parameters to learn with the input gate (two more: $w_4$ and $w_5$) and the initial loss is higher (due to the initially incorrect gate values).
We also see that the gate learned is not the true gate that we want. This is because the input embedding is ideal: the inputs are separable even without an input gate.
Now let's experiment with a noisy embedding, for which the true function cannot be learned without an input gate.
In [23]:
import random
a_1 = 1.0 + 0.2*(random.random()-0.5)
a_2 = 1.0/a_1
b_1 = -1.0 + 0.2*(random.random()-0.5)
b_2 = 1.0/b_1
embedding['a'] = (a_1, a_2)
embedding['b'] = (b_1, b_2)
embedding['('] = (1, 0)
embedding[')'] = (0, 1)
from pprint import pprint
pprint(embedding)
Here we construct the embedding so that the non-bracket characters carry noise that should be ignored.
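To see why a single pair $(w_1, w_2)$ can no longer handle this embedding, here is a quick check (it reuses `a_1`, `a_2`, `b_1`, `b_2` from the cell above): for a character $(c_1, c_2)$ to contribute 0 to the memory we need $w_1 c_1 + w_2 c_2 = 0$, i.e. $w_2/w_1 = -c_1/c_2$, and the noisy embedding forces different ratios for `a` and `b`.

print('w2/w1 needed to zero out a:', -a_1/a_2)
print('w2/w1 needed to zero out b:', -b_1/b_2)
# The two ratios differ, so no single (w1, w2) can ignore both 'a' and 'b'
# while still counting the brackets; the input gate has to do that job.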
Let's see how the two models perform in this case.
In [24]:
embedding['a'] = (a_1, a_2)
embedding['b'] = (b_1, b_2)
embedding['('] = (1, 0)
embedding[')'] = (0, 1)
experiment('ab(ab)bb', _w1=1.0, _w2=1.0, alpha=1e-1, max_iter=250, fixed_w3=True)
Out[24]:
In [25]:
embedding['a'] = (a_1, a_2)
embedding['b'] = (b_1, b_2)
embedding['('] = (1, 0)
embedding[')'] = (0, 1)
experiment_gated('ab(ab)bb', _w1=1.0, _w2=1.0, _w4=1.0, _w5=1.0, alpha=1e-1, max_iter=250, fixed_w3=True)
Out[25]:
Now we see that the input gate is closer to the true gate: it tries to ignore irrelevant inputs by pushing their gate values closer to 0. Although it is still far from the true gate (the irrelevant inputs still get positive gate values), it has a clear impact on the loss, which reaches a value an order of magnitude lower. And if we run more iterations, the gate is eventually learned correctly ($w_4 = 1.0$, $w_5 = -1.0$).
Notice that in the network without an input gate, the overall gradient at the end is zero, but the gradient at each position in the sequence is not, and its magnitude is not small, meaning the network has settled at a non-optimal point where the per-position gradients cancel out. In the gated version, the gradient approaches zero at every position.
In [28]:
# Trying nested brackets
experiment_gated('ab(aaa(bab)b)')
Out[28]: