Weight Tying in Language Models: A Technique for Parameter Efficiency

> In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation. - Attention Is All You Need, Section 3.4 (Embeddings and Softmax)1

The sentence above is an excerpt from the Attention Is All You Need paper. The technique is actually called weight tying, and it was introduced in the paper Using the Output Embedding to Improve Language Models2. In this short post, I will explain its rationale and implementation.

```mermaid
flowchart LR
    wte(input embedding)
    lm_head(output embedding)
    wte --> ... --> lm_head
```

In the design of language models, there are typically two matrices; a short code sketch of how they fit together follows the list below.

  • the input embedding (denoted as $\mathbf U$): transforms a token ID into a token embedding; usually implemented as nn.Embedding in PyTorch.
  • the output embedding (denoted as $\mathbf V$): transforms a token embedding into scores (logits) over the vocabulary, which the softmax turns into a probability distribution; usually implemented as nn.Linear in PyTorch.
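
As a rough sketch of where the two matrices sit in the forward pass (the names wte and lm_head mirror the flowchart above; the transformer blocks in between are elided, and the sizes are made up purely for illustration):

```python
import torch
import torch.nn as nn

vocab_size, hidden_dim = 10, 4                            # toy sizes, for illustration only

wte = nn.Embedding(vocab_size, hidden_dim)                # input embedding U
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)   # output embedding V

token_ids = torch.tensor([1, 3, 5])                       # a tiny input sequence
x = wte(token_ids)                                        # (3, hidden_dim) token embeddings
# ... transformer blocks would transform x here ...
logits = lm_head(x)                                       # (3, vocab_size) scores per token
probs = logits.softmax(dim=-1)                            # probability distribution over the vocab
```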

The authors argue that we have similar expectations for these two matrices1.

  • For $\mathbf U$, we hope that semantically similar tokens have similar token embeddings.
  • For $\mathbf V$, we hope that interchangeable tokens receive similar scores.

In addition, $\mathbf U$ and $\mathbf V$ have the same shape: both are stored as (vocabulary size) × (hidden dimension) matrices. This raises an intuitive question: can we share the same weight matrix for the input embedding and the output embedding?

The answer is yes, and you can find the details of the authors’ experiments here2.

In PyTorch, $\mathbf U$ is implemented as an nn.Embedding layer, and $\mathbf V$ as an nn.Linear layer:

```python
import torch.nn as nn

vocab_size, hidden_dim = 3, 4  # toy sizes

U = nn.Embedding(vocab_size, hidden_dim)            # input embedding
V = nn.Linear(hidden_dim, vocab_size, bias=False)   # output embedding
```
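
To see the "same size" observation from above concretely, both weight tensors are stored with shape (vocab_size, hidden_dim):

```python
print(U.weight.shape)  # torch.Size([3, 4]): (num_embeddings, embedding_dim)
print(V.weight.shape)  # torch.Size([3, 4]): (out_features, in_features)
```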

It’s trivial to share the matrix between $\mathbf U$ and $\mathbf V$ like this:

```python
U.weight = V.weight  # both layers now reference one shared parameter
```
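
A quick sanity check (my own, not from the papers) that the two layers now share a single parameter rather than holding copies; note that the direction of the assignment decides whose initialization survives, here $\mathbf V$'s:

```python
import torch

assert U.weight is V.weight  # the very same Parameter object

# an update made through one module is visible through the other
with torch.no_grad():
    V.weight[0, 0] = 42.0
print(U.weight[0, 0].item())  # 42.0
```

In full models, this assignment usually lives in the model's __init__, tying the token embedding (e.g. a wte module) to the output projection (e.g. an lm_head module).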

The benefits of sharing the matrix between the input embedding $\mathbf U$ and the output embedding $\mathbf V$ are clear: it reduces the number of parameters while maintaining, or even improving, performance (including lower perplexity)2.
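
As a back-of-the-envelope sketch of the savings, using GPT-2-small-like sizes (a vocabulary of 50257 and a hidden dimension of 768) purely for illustration:

```python
import torch.nn as nn

vocab_size, hidden_dim = 50257, 768  # GPT-2-small-like sizes, for illustration

U = nn.Embedding(vocab_size, hidden_dim)
V = nn.Linear(hidden_dim, vocab_size, bias=False)

untied = U.weight.numel() + V.weight.numel()  # two separate matrices
U.weight = V.weight                           # tie them
tied = U.weight.numel()                       # a single shared matrix

print(untied - tied)  # 38597376 = vocab_size * hidden_dim parameters saved
```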