Weight Tying in Language Models: A Technique for Parameter Efficiency
Intro
"In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation." (Attention Is All You Need, Section 3.4, Embeddings and Softmax)1
The sentence above is an excerpt from the Attention Is All You Need paper. The technique it describes is called weight tying, and it was introduced in the paper Using the Output Embedding to Improve Language Models2. In this short post, I will explain its rationale and implementation.
What is weight tying
```mermaid
flowchart LR
    wte(input embedding)
    lm_head(output embedding)
    wte --> ... --> lm_head
```
In the design of language models, there are typically two matrices.
- the input embedding (denoted as $\mathbf U$): maps a token ID to a token embedding; usually implemented as `nn.Embedding` in PyTorch.
- the output embedding (denoted as $\mathbf V$): maps a token embedding to a score (logit) for every token in the vocabulary, which the softmax then turns into a probability distribution; usually implemented as `nn.Linear` in PyTorch (a minimal sketch of both follows this list).
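To make their roles concrete, here is a minimal sketch of where the two matrices sit in a forward pass. The sizes are made up for illustration, and the transformer blocks in between are omitted.

```python
import torch
import torch.nn as nn

vocab_size, hidden_dim = 10, 8  # toy sizes, for illustration only

U = nn.Embedding(vocab_size, hidden_dim)           # input embedding
V = nn.Linear(hidden_dim, vocab_size, bias=False)  # output embedding

token_ids = torch.tensor([[1, 4, 2]])  # (batch=1, seq_len=3)
h = U(token_ids)                       # (1, 3, hidden_dim)
# ... a real model runs its transformer blocks here ...
logits = V(h)                          # (1, 3, vocab_size)
probs = logits.softmax(dim=-1)         # probability distribution over the vocab
```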
The authors of the weight tying paper argue that we have similar expectations for these two matrices2:
- For $\mathbf U$, we hope that semantically similar tokens have similar token embeddings.
- For $\mathbf V$, we hope that tokens which are interchangeable in a given context receive similar scores.
In addition, $\mathbf U$ and $\mathbf V$ have the same size. This raises an intuitive question: can we share the same weight matrix for the input embedding and the output embedding?
The answer is yes, and you can find details of the authors’ experiments here2.
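One subtlety behind the "same size" observation: `nn.Linear` stores its weight as an (out_features, in_features) matrix, so in PyTorch both weight matrices end up with shape (vocab_size, hidden_dim). A quick check, again with toy sizes:

```python
import torch.nn as nn

vocab_size, hidden_dim = 10, 8
U = nn.Embedding(vocab_size, hidden_dim)
V = nn.Linear(hidden_dim, vocab_size, bias=False)

print(U.weight.shape)  # torch.Size([10, 8])
print(V.weight.shape)  # torch.Size([10, 8])
```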
The implementation of weight tying
In PyTorch, $\mathbf U$ is implemented as an `nn.Embedding` layer and $\mathbf V$ as an `nn.Linear` layer:
```python
import torch.nn as nn

vocab_size, hidden_dim = 3, 4
U = nn.Embedding(vocab_size, hidden_dim)           # weight shape: (3, 4)
V = nn.Linear(hidden_dim, vocab_size, bias=False)  # weight shape: (3, 4)
```
It’s then trivial to share the matrix between $\mathbf U$ and $\mathbf V$:

```python
U.weight = V.weight
```
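In a full model, the tying usually happens inside the module's `__init__`. Below is a minimal, hypothetical sketch (the class and attribute names are my own, loosely following common GPT-style naming; the transformer blocks are omitted), together with a check that the parameter really is shared:

```python
import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    """Minimal sketch of weight tying; everything except the tying line is placeholder."""

    def __init__(self, vocab_size: int, hidden_dim: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_dim)               # U
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)  # V
        self.lm_head.weight = self.wte.weight                         # weight tying

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        h = self.wte(token_ids)   # a real model would run transformer blocks here
        return self.lm_head(h)    # logits over the vocabulary

model = TinyTiedLM(vocab_size=3, hidden_dim=4)

# The two modules hold the very same Parameter object ...
assert model.wte.weight is model.lm_head.weight
# ... and model.parameters() counts the shared matrix only once.
print(sum(p.numel() for p in model.parameters()))  # 12, not 24
```

Because the two attributes point to the same `nn.Parameter`, gradients from both the input and the output side accumulate into one matrix, and the optimizer updates it once.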
Wrap-up
The benefit of sharing the matrix between the input embedding $\mathbf U$ and the output embedding $\mathbf V$ is clear: it reduces the number of parameters while maintaining, or even improving, performance, with the authors reporting lower perplexity2.
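To get a feel for the scale of the savings, here is a back-of-the-envelope calculation assuming GPT-2-small-like dimensions (vocab_size = 50257, hidden_dim = 768; these numbers are for illustration and are not taken from the paper):

```python
# Rough parameter savings from tying, assuming a GPT-2-small-like configuration.
vocab_size, hidden_dim = 50257, 768
saved = vocab_size * hidden_dim
print(f"{saved:,} parameters saved by sharing one matrix")  # 38,597,376
```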