In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
See this notebook for a walk-through of the full transformer implementation.
The transformer architecture uses stacked attention layers in place of CNNs or RNNs. This makes it easy to learn long-range dependencies, but it contains no built-in information about the relative positions of items in a sequence.
To give the model access to this information, the transformer adds a position encoding to the input.
This encoding is a vector of sines and cosines at each position, where each sine-cosine pair rotates at a different frequency.
Nearby locations will have similar position-encoding vectors.
In [0]:
import numpy as np
import matplotlib.pyplot as plt
The angle rates range from 1 [rads/step] to min_rate [rads/step] over the vector depth. Formula for the angle rate at depth $d$:
$$angle\_rate_d = (min\_rate)^{d / d_{max}}$$
In [0]:
num_positions = 50
depth = 512
min_rate = 1/10000
assert depth%2 == 0, "Depth must be even."
angle_rate_exponents = np.linspace(0,1,depth//2)
angle_rates = min_rate**(angle_rate_exponents)
The resulting exponent goes from 0 to 1, causing the angle_rates to drop exponentially from 1 to min_rate.
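As a quick sanity check, the endpoints of angle_rates should match the formula exactly:
In [0]:
# min_rate**0 == 1 at the first pair; min_rate**1 == min_rate at the last.
assert np.isclose(angle_rates[0], 1.0)
assert np.isclose(angle_rates[-1], min_rate)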
In [0]:
plt.semilogy(angle_rates)
plt.xlabel('Depth')
plt.ylabel('Angle rate [rads/step]')
Broadcasting a multiply over angle rates and positions gives a map of the position encoding angles as a function of depth.
In [0]:
positions = np.arange(num_positions)
angle_rads = (positions[:, np.newaxis])*angle_rates[np.newaxis, :]
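The broadcast yields one angle per (position, frequency) pair; a quick shape check:
In [0]:
# One row per position, one column per sine-cosine pair.
assert angle_rads.shape == (num_positions, depth//2)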
In [0]:
plt.figure(figsize = (14,8))
plt.pcolormesh(
    # Convert to degrees, and wrap around at 360
    angle_rads*(180/np.pi) % 360,
    # Use a cyclical colormap so that color(0) == color(360)
    cmap='hsv', vmin=0, vmax=360)
plt.xlim([0,len(angle_rates)])
plt.ylabel('Position')
plt.xlabel('Depth')
bar = plt.colorbar(label='Angle [deg]')
bar.set_ticks(np.linspace(0,360,6+1))
Raw angles are not a good model input: left unwrapped they're unbounded, and wrapped they're discontinuous. So take the sine and cosine of each angle:
In [0]:
sines = np.sin(angle_rads)
cosines = np.cos(angle_rads)
pos_encoding = np.concatenate([sines, cosines], axis=-1)
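Since $\sin^2\theta + \cos^2\theta = 1$ for each pair, every position-encoding vector has the same squared norm, depth/2:
In [0]:
# Each of the depth//2 sine-cosine pairs contributes exactly 1,
# so every row has squared norm depth/2 == 256.
assert np.allclose(np.sum(pos_encoding**2, axis=1), depth/2)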
In [0]:
plt.figure(figsize=(14,8))
plt.pcolormesh(pos_encoding,
# Use a diverging colormap so it's clear where zero is.
cmap='RdBu', vmin=-1, vmax=1)
plt.xlim([0,depth])
plt.ylabel('Position')
plt.xlabel('Depth')
plt.colorbar()
Nearby locations have similar position-encoding vectors.
To demonstrate this, compare one position's encoding (here, position 20) with each of the others:
In [0]:
pos_encoding_at_20 = pos_encoding[20]
dots = np.dot(pos_encoding,pos_encoding_at_20)
SSE = np.sum((pos_encoding - pos_encoding_at_20)**2, axis=1)
Regardless of how you compare the vectors, they are most similar at position 20, and clearly diverge as you move away:
In [0]:
plt.figure(figsize=(10,8))
plt.subplot(2,1,1)
plt.plot(dots)
plt.ylabel('Dot product')
plt.subplot(2,1,2)
plt.plot(SSE)
plt.ylabel('SSE')
plt.xlabel('Position')
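As a quick numeric check, the dot product peaks at position 20, where it equals the shared squared norm depth/2, and the SSE bottoms out at zero there:
In [0]:
assert dots.argmax() == 20
assert np.isclose(dots[20], depth/2)  # a vector's dot product with itself
assert SSE.argmin() == 20
assert np.isclose(SSE[20], 0.0)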
The paper explains, at the end of Section 3.5, that the encoding for any relative position can be written as a linear function of the current position's encoding.
To demonstrate, this section builds the matrix that calculates these relative position encodings.
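For a single sine-cosine pair with angle rate $\omega$, this linear function is just a rotation: the angle-sum identities, written in matrix form, give
$$\begin{bmatrix} \sin((p+\Delta)\omega) \\ \cos((p+\Delta)\omega) \end{bmatrix} = \begin{bmatrix} \cos(\Delta\omega) & \sin(\Delta\omega) \\ -\sin(\Delta\omega) & \cos(\Delta\omega) \end{bmatrix} \begin{bmatrix} \sin(p\omega) \\ \cos(p\omega) \end{bmatrix}$$
The matrix below stacks one such rotation per depth pair, arranged to match the [sines, cosines] concatenation used above.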
In [0]:
def transition_matrix(position_delta, angle_rates=angle_rates):
  # Implement as a matrix multiply:
  #   sin(a+b) = sin(a)*cos(b) + cos(a)*sin(b)
  #   cos(a+b) = cos(a)*cos(b) - sin(a)*sin(b)
  # b
  angle_delta = position_delta*angle_rates
  # sin(b), cos(b)
  sin_delta = np.sin(angle_delta)
  cos_delta = np.cos(angle_delta)
  I = np.eye(len(angle_rates))
  # sin(a+b) = sin(a)*cos(b) + cos(a)*sin(b)
  update_sin = np.concatenate([I*cos_delta, I*sin_delta], axis=0)
  # cos(a+b) = cos(a)*cos(b) - sin(a)*sin(b)
  update_cos = np.concatenate([-I*sin_delta, I*cos_delta], axis=0)
  return np.concatenate([update_sin, update_cos], axis=-1)
For example, create the matrix that calculates the position encoding 10 steps back from the current position encoding:
In [0]:
position_delta = -10
update = transition_matrix(position_delta)
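As a single-vector illustration (positions 20 and 10 chosen arbitrarily), applying update to the encoding at position 20 should reproduce the encoding at position 10:
In [0]:
# pos_encoding[20] @ update is the encoding for position 20 - 10 == 10.
assert np.allclose(np.dot(pos_encoding[20], update), pos_encoding[10])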
Applying this matrix to each position-encoding vector gives the position-encoding vector from 10 steps back, resulting in a shifted position-encoding map:
In [0]:
plt.figure(figsize=(14,8))
plt.pcolormesh(np.dot(pos_encoding,update), cmap='RdBu', vmin=-1, vmax=1)
plt.xlim([0,depth])
plt.ylabel('Position')
plt.xlabel('Depth')
This is accurate to numerical precision.
In [0]:
errors = np.dot(pos_encoding,update)[10:] - pos_encoding[:-10]
abs(errors).max()