Transformer

#

I coded this Transformer from scratch for learning. It is based on The Annotated Transformer by Harvard NLP, which uses PyTorch.

I tested it on a toy problem, so no data loading or tokenization code is needed.

🧸 The toy problem is to reverse a given sequence whilst replacing every even repetition of a digit with a special token (X). For example,

input = 0 1 5 9 0 3 5 2 5
input after replacing even repetitions: 0 1 5 9 X 3 X 2 5
reversed = 5 2 X 3 X 9 5 1 0
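
A minimal pure-Python sketch of this rule (toy_target is a hypothetical helper for illustration, not part of the model code):

def toy_target(seq, repeat_token='X'):
    counts = {}
    out = []
    for d in seq:
        counts[d] = counts.get(d, 0) + 1
        # every even (2nd, 4th, ...) occurrence of a digit becomes the special token
        out.append(repeat_token if counts[d] % 2 == 0 else d)
    return out[::-1]

print(toy_target([0, 1, 5, 9, 0, 3, 5, 2, 5]))  # [5, 2, 'X', 3, 'X', 9, 5, 1, 0]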

🎫 If someone reading this has any questions or comments please find me on Twitter, @vpj.

import math

import numpy as np
import tensorflow as tf
#

Layer Normalization

#

Calculate the mean and standard deviation

The mean and variance are calculated along the last dimension.

def get_mean_std(x: tf.Tensor):
#
    mean = tf.reduce_mean(x, axis=-1, keepdims=True)
    squared = tf.square(x - mean)
    variance = tf.reduce_mean(squared, axis=-1, keepdims=True)
    std = tf.sqrt(variance)

    return mean, std
#

Layer normalization

def layer_norm(layer: tf.Tensor):
#
    with tf.variable_scope("norm"):
        scale = tf.get_variable("scale", shape=layer.shape[-1], dtype=tf.float32,
                                initializer=tf.ones_initializer())
        base = tf.get_variable("base", shape=layer.shape[-1], dtype=tf.float32,
                               initializer=tf.zeros_initializer())
        mean, std = get_mean_std(layer)
#

Normalize

        norm = (layer - mean) / (std + 1e-6)
#

Adjust by learned scale and base

        return norm * scale + base
#

👀 Attention

#

Scaled Dot-Product Attention

The inputs query $Q$, key $K$ and value $V$ have form [batches, heads, sequence, features]. $d_k$ is the number of features; i.e. size of the last axis.

def attention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, *,
              mask: tf.Tensor,
              keep_prob: float):
#
#

$d_k$ is the number of features

    d_k = query.shape[-1].value
#

Calculate attention scores $\frac{Q K^T}{\sqrt{d_k}}$
We need the dot-product of each query vector along the sequence with each key vector along the sequence. We do a matrix multiplication of the query with the transpose (last 2 axes) of the key. The last two axes of the resultant tensor will be a matrix $S_{i,j} = Q_i \cdot K_j$ where $i$ and $j$ are positions along the sequence.

    scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2]))
    scores = scores / tf.constant(math.sqrt(d_k))
#

scores has form [batches, heads, sequence, sequence]; in the last two dimensions, each row gives the attention scores for one position along the sequence. mask has form [batches, heads, sequence, sequence]. We set scores to -1e9 wherever mask is 0, so that after the $\mathop{softmax}$ the attention to those positions is effectively zero.

    mask_add = ((scores * 0) - 1e9) * (tf.constant(1.) - mask)
    scores = scores * mask + mask_add
#

The $(i, j)$ entry of the attention matrix gives the attention from the $i^{th}$ position to the $j^{th}$ position.

    attn = tf.nn.softmax(scores, axis=-1)
#

Add a dropout layer to improve generalization

    attn = tf.nn.dropout(attn, keep_prob)
#

$\mathop{softmax}\Bigg(\frac{Q K^T}{\sqrt{d_k}}\Bigg)V$

    return tf.matmul(attn, value), attn
#
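
The following NumPy sketch mirrors the same computation and is handy for checking shapes outside the TensorFlow graph; np_attention is illustrative only and skips dropout.

def np_attention(q, k, v, mask):
    # q, k, v have form [batches, heads, sequence, d_k];
    # mask broadcasts to [batches, heads, sequence, sequence]
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    scores = scores * mask - 1e9 * (1 - mask)
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn @ v

q = k = v = np.random.rand(2, 8, 10, 16)
out = np_attention(q, k, v, np.tril(np.ones((10, 10))))  # out.shape == (2, 8, 10, 16)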

Prepare for multi-head attention

This prepares query $Q$, key $K$ or value $V$ (each has form [batches, sequence, features]): it applies a linear transformation and splits the features into multiple heads.

def prepare_for_multi_head_attention(x: tf.Tensor, heads: int, name: str):
#
#

$d_{model}$ is the number of features

    n_batches, seq_len, d_model = x.shape
#

$d_k$ is the number of features per head

    assert d_model % heads == 0
    d_k = d_model // heads
#

apply linear transformations

    x = tf.layers.dense(x, units=d_model, name=name)
#

split into multiple heads

    x = tf.reshape(x, shape=[n_batches, seq_len, heads, d_k])
#

transpose from [batches, sequence, heads, features] to [batches, heads, sequence, features]

    x = tf.transpose(x, perm=[0, 2, 1, 3])

    return x
#
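
For example, with the configuration used in train() below (a batch of 32 sequences of length 10, d_model = 128 and 8 heads), the shapes are

x: [32, 10, 128] → linear → [32, 10, 128] → reshape → [32, 10, 8, 16] → transpose → [32, 8, 10, 16]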

Multi-Head Attention

The inputs query $Q$, key $K$ and value $V$ have form [batches, sequence, features].

def multi_head_attention(query: tf.Tensor, key: tf.Tensor, value: tf.Tensor, *,
                         mask: tf.Tensor,
                         heads: int,
                         keep_prob: float):
#
    with tf.variable_scope("multi_head"):
#

$d_{model}$ is the number of features

        n_batches, seq_len, d_model = query.shape
#

Apply linear transformations and split into multiple heads. The resulting tensors have form [batches, heads, sequence, features].

        query = prepare_for_multi_head_attention(query, heads, "query")
        key = prepare_for_multi_head_attention(key, heads, "key")
        value = prepare_for_multi_head_attention(value, heads, "value")
#

mask has form [batches, sequence, sequence]; we add a heads dimension of size 1 so that it broadcasts to [batches, heads, sequence, sequence]

        mask = tf.expand_dims(mask, axis=1)
#

calculate output from attention layer

        out, _ = attention(query, key, value, mask=mask, keep_prob=keep_prob)
#

transform back from [batches, heads, sequence, features] to [batches, sequence, heads, features]

        out = tf.transpose(out, perm=[0, 2, 1, 3])
#

reshape to [batches, sequence, features]

        out = tf.reshape(out, shape=[n_batches, seq_len, d_model])
#

pass through a linear layer

        return tf.layers.dense(out, units=d_model, name="attention")
#

Position-wise Feed-Forward Networks

def feed_forward(x: tf.Tensor,
                 d_model: int, d_ff: int, keep_prob: float):
#
    with tf.variable_scope("feed_forward"):
        hidden = tf.layers.dense(x, units=d_ff, name="hidden")
        hidden = tf.nn.relu(hidden)
        hidden = tf.nn.dropout(hidden, keep_prob=keep_prob)
        return tf.layers.dense(hidden, units=d_model, name="out")
#

Encoder

#

Encoder Layer

This is a single encoder layer. The encoder consists of multiple such layers.

x has the form [batches, sequence, features]

def encoder_layer(x: tf.Tensor, *,
                  mask: tf.Tensor, index: int, heads: int,
                  keep_prob: float, d_ff: int):
#
#

$d_{model}$ is the number of features

    d_model = x.shape[-1]
#

Attention

    with tf.variable_scope(f"attention_{index}"):
        attention_out = multi_head_attention(x, x, x,
                                             mask=mask, heads=heads, keep_prob=keep_prob)
#

add a residual connection

        added = x + tf.nn.dropout(attention_out, keep_prob)
#

normalize

        x = layer_norm(added)
#

Feed-forward

    with tf.variable_scope(f"ff_{index}"):
        ff_out = feed_forward(x, d_model, d_ff, keep_prob)
#

add a residual connection

        added = x + tf.nn.dropout(ff_out, keep_prob)
#

normalize

        return layer_norm(added)
#

Encoder

The encoder consists of n_layers encoder layers.

def encoder(x: tf.Tensor, *,
            mask: tf.Tensor,
            n_layers: int,
            heads: int, keep_prob: float, d_ff: int):
#
    with tf.variable_scope("encoder"):
        for i in range(n_layers):
            x = encoder_layer(x,
                              mask=mask, index=i,
                              heads=heads, keep_prob=keep_prob, d_ff=d_ff)

        return x
#

Decoder

#

Decoder Layer

This is a single decoder layer. The decoder consists of multiple such layers.

encoding is the final output from the encoder. It has the form [batches, sequence, features]. enc_mask is the mask for encoding, of the form [batches, sequence, sequence].

x is the previous output from the decoder; during training we supply the true outputs, shifted right and starting with the start-of-sequence token. It has the form [batches, sequence, features]. mask is the mask for x, of the form [batches, sequence, sequence].

def decoder_layer(encoding: tf.Tensor, x: tf.Tensor, *,
                  enc_mask: tf.Tensor, mask: tf.Tensor,
                  index: int, heads: int, keep_prob: float, d_ff: int):
#
#

$d_{model}$ is the number of features

    d_model = encoding.shape[-1]
#

Self-attention to x

    with tf.variable_scope(f"{index}_self_attention"):
        attention_out = multi_head_attention(x, x, x,
                                             mask=mask, heads=heads, keep_prob=keep_prob)
#

add a residual connection

        added = x + tf.nn.dropout(attention_out, keep_prob=keep_prob)
#

normalize

        x = layer_norm(added)
#

Attention to encoding, the output from the encoder

    with tf.variable_scope(f"{index}_encoding_attention"):
        attention_out = multi_head_attention(x, encoding, encoding,
                                             mask=enc_mask, heads=heads, keep_prob=keep_prob)
#

add a residual connection

        added = x + tf.nn.dropout(attention_out, keep_prob=keep_prob)
#

normalize

        x = layer_norm(added)
#

Feed-forward

    with tf.variable_scope(f"{index}_ff"):
        ff_out = feed_forward(x, d_model, d_ff, keep_prob)
#

add a residual connection

        added = x + tf.nn.dropout(ff_out, keep_prob)
#

normalize

        return layer_norm(added)
#

Decoder

The decoder consists of n_layers decoder layers.

def decoder(encoding: tf.Tensor, x: tf.Tensor, *,
            enc_mask: tf.Tensor, mask: tf.Tensor,
            n_layers: int,
            heads: int, keep_prob: float, d_ff: int):
#
    with tf.variable_scope("decoder"):
        for i in range(n_layers):
            x = decoder_layer(encoding, x,
                              enc_mask=enc_mask, mask=mask, index=i,
                              heads=heads, keep_prob=keep_prob, d_ff=d_ff)

        return x
#

Embeddings

#

Word Embeddings

We use a table lookup to get embeddings. The table is a trainable variable, so the embeddings are learned during training.

def get_embeddings(input_ids: tf.Tensor, output_ids: tf.Tensor,
                   vocab_size: int, d_model: int):
#
#

Embedding table

    word_embeddings = tf.get_variable("word_embeddings",
                                      shape=[vocab_size, d_model],
                                      dtype=tf.float32,
                                      initializer=tf.initializers.random_normal())
#

Embeddings of inputs, for the encoder

    in_emb = tf.nn.embedding_lookup(word_embeddings, input_ids)
#

Embeddings of outputs, for the decoder

    out_emb = tf.nn.embedding_lookup(word_embeddings, output_ids)

    return word_embeddings, in_emb, out_emb
#

∿ Positional Encodings

The positional encoding encodes the position along the sequence into a set of $d_{model}$ features,

${PE}_{p,2i} = \sin\Bigg(\frac{p}{10000^{2i/d_{model}}}\Bigg) \qquad {PE}_{p,2i+1} = \cos\Bigg(\frac{p}{10000^{2i/d_{model}}}\Bigg)$

where $p$ is the position along the sequence and $i$ is the index of the feature.

def generate_positional_encodings(d_model: int, max_len: int = 5000):
#
#

Empty $PE$ matrix

    encodings = np.zeros((max_len, d_model), dtype=float)
#

$p$

    position = np.arange(0, max_len).reshape((max_len, 1))
#

$2i$

    two_i = np.arange(0, d_model, 2)
#

$10000^{-\frac{2i}{d_{model}}}$

    div_term = np.exp(-math.log(10000.0) * two_i / d_model)
#

${PE}_{p,2i}$

    encodings[:, 0::2] = np.sin(position * div_term)
#

${PE}_{p,2i + 1}$

    encodings[:, 1::2] = np.cos(position * div_term)
#

convert to a TensorFlow tensor from NumPy

    return tf.constant(encodings.reshape((1, max_len, d_model)),
                       dtype=tf.float32, name="positional_encodings")
#
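
A quick sanity check (not part of the model): at position $p = 0$ the encoding alternates $\sin(0) = 0$ and $\cos(0) = 1$.

pe = generate_positional_encodings(d_model=8, max_len=4)
with tf.Session() as session:
    print(session.run(pe)[0, 0])  # [0. 1. 0. 1. 0. 1. 0. 1.]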

Prepare Embeddings

Add positional encodings, apply dropout, and normalize the embeddings before the encoder and decoder stages.

def prepare_embeddings(x: tf.Tensor, *,
                       positional_encodings: tf.Tensor,
                       keep_prob: float, is_input: bool):
#
    name = "prepare_input" if is_input else "prepare_output"
    with tf.variable_scope(name):
        _, seq_len, _ = x.shape
#

add positional encodings

        x = x + positional_encodings[:, :seq_len, :]
#

drop out

        x = tf.nn.dropout(x, keep_prob)
#

normalize

        return layer_norm(x)
#

Generator

Get the final outputs by sending the output of the decoder through a linear layer and a log-softmax activation.

def generator(x: tf.Tensor, *, vocab_size: int):
#
    res = tf.layers.dense(x, units=vocab_size, name="generator")
    return tf.nn.log_softmax(res, axis=-1)
#

Label smoothing loss

This prevents the model from becoming overconfident about its predictions. Another alternative would be to add a small entropy loss.

Here, instead of setting the target probability of the expected token to 1 and of the others to 0, we set the target log-probabilities to 1 - smoothing for the expected token and to smoothing / (vocab_size - 1) for the others.

def label_smoothing_loss(results: tf.Tensor, expected: tf.Tensor, *,
                         vocab_size: int, smoothing: float):
#
    results = tf.reshape(results, shape=(-1, vocab_size))
    expected = tf.reshape(expected, shape=[-1])

    confidence = 1 - smoothing
    smoothing = smoothing / (vocab_size - 1)
#

set the log-probabilities

    expected = tf.one_hot(expected, depth=vocab_size) * (confidence - smoothing)
    expected += smoothing
#

KL-Divergence

    results = tf.distributions.Categorical(logits=results)
    expected = tf.distributions.Categorical(logits=expected)
    return tf.reduce_mean(tf.distributions.kl_divergence(results, expected))
#
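
As a worked check of the construction (not part of the training code; the expected token id 3 is an arbitrary choice): with vocab_size = 12 and smoothing = 0.1, the target vector used for the Categorical above has 0.9 at the expected token and 0.1 / 11 ≈ 0.009 everywhere else.

vocab_size, smoothing = 12, 0.1
confidence = 1 - smoothing                # 0.9
smoothing = smoothing / (vocab_size - 1)  # ≈ 0.00909
target = np.full(vocab_size, smoothing)
target[3] = confidence                    # assume the expected token id is 3
# target ≈ [0.009, 0.009, 0.009, 0.9, 0.009, ...]; sums to 1 (up to floating-point rounding)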

Generate Data

This generates training data for our toy problem.

We use vocab_size - 2 digits; vocab_size - 2 is used as the special token that replaces even repetitions and vocab_size - 1 is used as the special token that indicates the start of the sequence for the decoder.

def generate_data(batch_size: int, seq_len: int, vocab_size: int):
#
    start_token = vocab_size - 1
    repeat_token = vocab_size - 2
    vocab_size -= 2

    inputs = np.random.randint(0, vocab_size, size=(batch_size, seq_len))
#

reverse

    outputs = np.zeros((batch_size, seq_len + 1), dtype=int)
    outputs[:, 1:] = np.flip(inputs, 1)
#

the first output supplied to the decoder is the start-of-sequence token

    outputs[:, 0] = start_token

    for i in range(batch_size):
        v = np.zeros(vocab_size, dtype=bool)
        for j in range(seq_len):
            word = inputs[i, j]
#

replace every even occurrence of a digit with repeat_token (at the corresponding position in the reversed output)

            if v[word]:
                v[word] = False
                outputs[i][seq_len - j] = repeat_token
            else:
                v[word] = True

    return inputs, outputs
#
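
A hedged usage sketch (the exact digits depend on the random state, so only shapes and structure are shown):

inputs, outputs = generate_data(batch_size=1, seq_len=5, vocab_size=12)
# inputs.shape == (1, 5) and outputs.shape == (1, 6)
# outputs[:, 0] is the start token (11); outputs[:, 1:] is inputs reversed,
# with every even repetition of a digit replaced by the repeat token (10)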

Learning rate

The learning rate varies during training: it increases linearly for the first warm_up steps and then decays proportionally to the inverse square root of the step number,

$lr = d_{model}^{-0.5} \cdot \min\Big(step^{-0.5},\; step \cdot warm\_up^{-1.5}\Big)$

def noam_learning_rate(step: int, warm_up: float, d_model: int):
#
    return (d_model ** -.5) * min(step ** -.5, step * warm_up ** -1.5)
#
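
For example, with the warm_up = 400 and d_model = 128 used in train() below, the schedule peaks at warm_up steps (a quick numeric check, not part of the training code):

for step in [100, 400, 1600]:
    print(step, noam_learning_rate(step, warm_up=400, d_model=128))
# 100  → ≈ 0.0011 (still warming up linearly)
# 400  → ≈ 0.0044 (peak, at warm_up steps)
# 1600 → ≈ 0.0022 (decaying as step^-0.5)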

⬕ Mask out subsequent positions in the output

Otherwise, during training, the decoder would get access to the true outputs at future positions.

def output_subsequent_mask(seq_len: int):
#
    mask = np.zeros((seq_len, seq_len), dtype=float)
#

Set mask[i, j] = 1 for j ≤ i, so that position i can attend only to positions up to and including i; mask[i, j] stays 0 for j > i.

    for i in range(seq_len):
        for j in range(i + 1):
            mask[i, j] = 1.

    return mask
#
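
For example, for a sequence of length 4 the mask is lower triangular, equivalent to np.tril(np.ones((4, 4))) (a quick check, not part of the training code):

print(output_subsequent_mask(4))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]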

🏋️ Train

def train():
#
#

Constants

    seq_length = 10
#

digits 0 to 9; 10 is the special token ('X') that replaces even repetitions and 11 is the special token ('S') that indicates the start of the sequence for the decoder

    vocab_size = 10 + 1 + 1
    vocab_str = [f"{i}" for i in range(10)]
    vocab_str += ['X', 'S']
#
    batch_size = 32  # 12000
    d_model = 128  # 512
    heads = 8
    keep_prob = 0.9
    n_layers = 2  # 6
    d_ff = 256  # 2048
#

Positional Encodings

    positional_encodings = generate_positional_encodings(d_model)
#

Placeholders

    inputs = tf.placeholder(dtype=tf.int32,
                            shape=(batch_size, seq_length), name="input")
    outputs = tf.placeholder(dtype=tf.int32,
                             shape=(batch_size, seq_length), name="output")
    expected = tf.placeholder(dtype=tf.int32,
                              shape=(batch_size, seq_length), name="expected")
    inputs_mask = tf.placeholder(dtype=tf.float32,
                                 shape=(1, 1, seq_length),
                                 name="input_mask")
    output_mask = tf.placeholder(dtype=tf.float32,
                                 shape=(1, seq_length, seq_length),
                                 name="output_mask")
#

Learning rate

    learning_rate = tf.placeholder(dtype=tf.float32, name="learning_rate")
#

Create TensorFlow graph

    w_embed, input_embeddings, output_embeddings = get_embeddings(inputs, outputs, vocab_size,
                                                                  d_model)
    input_embeddings = prepare_embeddings(input_embeddings,
                                          positional_encodings=positional_encodings,
                                          keep_prob=keep_prob,
                                          is_input=True)
    output_embeddings = prepare_embeddings(output_embeddings,
                                           positional_encodings=positional_encodings,
                                           keep_prob=keep_prob,
                                           is_input=False)

    encoding = encoder(input_embeddings, mask=inputs_mask, n_layers=n_layers, heads=heads,
                       keep_prob=keep_prob, d_ff=d_ff)
    decoding = decoder(encoding, output_embeddings,
                       enc_mask=inputs_mask, mask=output_mask,
                       n_layers=n_layers, heads=heads, keep_prob=keep_prob, d_ff=d_ff)
    log_results = generator(decoding, vocab_size=vocab_size)
    results = tf.exp(log_results)
#

Loss

    loss = label_smoothing_loss(log_results, expected, vocab_size=vocab_size, smoothing=0.0)
#

Optimizer

    adam = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1e-5)
    params = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, params), 5.)
    grads_and_vars = list(zip(grads, params))
    train_op = adam.apply_gradients(grads_and_vars, name="apply_gradients")
#

Training

    warm_up = 400
    batch_in_mask = np.ones((1, 1, seq_length), dtype=float)
    batch_out_mask = output_subsequent_mask(seq_length)
    batch_out_mask = batch_out_mask.reshape(1, seq_length, seq_length)
#
    def __print_seq(seq):
        return ' '.join([vocab_str[i] for i in seq])

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
#

Training loop

        for i in range(100_000):
            lr = noam_learning_rate(i + 1, warm_up, d_model)
#

generate new data

            batch_in, batch_out = generate_data(batch_size, seq_length, vocab_size)
#

Train

            _, batch_loss, batch_res = session.run([train_op, loss, results],
                                                   feed_dict={
                                                       learning_rate: lr,
                                                       inputs: batch_in,
                                                       outputs: batch_out[:, :-1],
                                                       expected: batch_out[:, 1:],
                                                       inputs_mask: batch_in_mask,
                                                       output_mask: batch_out_mask
                                                   })
#

Log

            if i % 100 == 0:
                print(f"step={i}\tloss={batch_loss: .6f}")
                print(f"inp=  {__print_seq(batch_in[0])}")
                print(f"exp={__print_seq(batch_out[0])}")
                print(f"res=  {__print_seq(np.argmax(batch_res[0], -1))}")


if __name__ == '__main__':
    train()