Transformer architecture variation: RMSNorm
Intro
It’s been 8 years since the famous transformer architecture was first proposed. You might have noticed some modifications to the original design - for instance, most large language models (LLMs) now use RMSNorm[1] instead of LayerNorm. Today I will briefly introduce RMSNorm, but first, let’s recap LayerNorm.
LayerNorm Recap
$$ \mathbf y=\frac{\mathbf x-E[\mathbf x]}{\sqrt{Var(\mathbf x)+\epsilon}}*\gamma+\beta $$
The equation above shows how LayerNorm works. If we ignore the learnable affine parameters ($\gamma, \beta$), LayerNorm’s behavior becomes intuitive: it transforms each input $\mathbf x$ into a feature vector with zero mean and unit standard deviation.
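As a quick sanity check (a minimal sketch, assuming an arbitrary feature size of 8), we can pass random inputs through nn.LayerNorm and verify that each row of the output has roughly zero mean and unit standard deviation:

import torch
import torch.nn as nn

torch.manual_seed(0)
layer_norm = nn.LayerNorm(8)          # normalize over the last dimension of size 8
x = torch.randn(4, 8) * 3.0 + 5.0     # inputs with an arbitrary mean and scale
y = layer_norm(x)
print(y.mean(dim=-1))                 # each value is approximately 0
print(y.std(dim=-1, unbiased=False))  # each value is approximately 1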
Why does LayerNorm work? The most common explanations are:
- re-centering: the input vector $\mathbf x$ always subtracts its mean $E[\mathbf x]$. The benefit: even if $\mathbf x$ carries a constant shift (shift noise), its values remain centered around zero.
- re-scaling: after mean subtraction, $\mathbf x$ is divided by $\sqrt{Var(\mathbf x)+\epsilon}$. The benefit: the output representation stays intact even when the input is randomly scaled.
The two benefits can be demonstrated with the following PyTorch code.
import torch

def re_centering(x):
    # Subtracting the mean removes any constant shift from x
    return x - x.mean(dim=-1)

def re_scaling(x):
    # Dividing by the standard deviation cancels any constant scaling of x
    return x / (x.std(dim=-1) + 1e-5)

x = torch.arange(4).float()
print(x, re_centering(x + 10000))
# tensor([0., 1., 2., 3.]) tensor([-1.5000, -0.5000, 0.5000, 1.5000])
print(x, re_scaling(x * 10000))
# tensor([0., 1., 2., 3.]) tensor([0.0000, 0.7746, 1.5492, 2.3238])
RMSNorm
The RMSNorm authors argue that re-scaling, not re-centering, is LayerNorm’s key benefit[1]. Based on this insight, they proposed RMSNorm in the following form:
$$ \mathbf y=\frac{\mathbf x}{\sqrt{\frac{1}{n}\sum_ix_i^2+\epsilon}}*\gamma $$
Compared to LayerNorm, RMSNorm has three key differences:
- The numerator skips the mean-centering step.
- The denominator uses $\frac{1}{n}\sum_ix_i^2$ instead of $Var(\mathbf x)$.
- The $\beta$ term is removed, leaving only $\gamma$ as the learnable scale parameter.
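To make the formula concrete, here is a small numeric check (a sketch, assuming a recent PyTorch release that ships nn.RMSNorm, i.e. 2.4 or later): we apply the equation above by hand and compare it with the built-in module, whose $\gamma$ is initialized to ones.

import torch
import torch.nn as nn

x = torch.randn(2, 6)
eps = 1e-6

# RMSNorm by hand: divide x by the root mean square of its last dimension
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

rms_norm = nn.RMSNorm(6, eps=eps)  # gamma starts as all ones
print(torch.allclose(manual, rms_norm(x), atol=1e-6))  # True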
Why RMSNorm
From these key differences, we can identify the main advantages of RMSNorm:
- Fewer parameters to maintain: only $\gamma$ needs to be learned.
- Less computation, since the mean no longer needs to be computed.
What’s more, the performance of RMSNorm and LayerNorm is comparable; more details can be found in the paper[1].
PyTorch API
PyTorch’s nn.RMSNorm implementation includes the following key parameters:
- normalized_shape: a list or tuple giving the trailing dimensions over which the root mean square (RMS) is computed.
- eps: the value of $\epsilon$.
- elementwise_affine: whether to enable the learnable scale parameter $\gamma$.
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)
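To see what normalized_shape=[2, 3] means in practice, a quick check (a sketch under the same setup as the snippet above) confirms that the root mean square of the output, taken jointly over the last two dimensions, is approximately 1:

import torch
import torch.nn as nn

rms_norm = nn.RMSNorm([2, 3])
input = torch.randn(2, 2, 3)
output = rms_norm(input)

# RMS over the normalized (last two) dimensions is ~1 for every leading index
print(output.pow(2).mean(dim=(-2, -1)).sqrt())  # tensor close to [1., 1.]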
RMSNorm from Scratch
Writing RMSNorm from scratch serves as excellent PyTorch practice. Below is my implementation.
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(
        self,
        normalized_shape: list | tuple,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # RMS is computed over the trailing dimensions given by normalized_shape
        self.dims = tuple(range(-len(normalized_shape), 0))
        if self.elementwise_affine:
            self.gamma = nn.Parameter(torch.ones(normalized_shape))
        else:
            self.register_parameter("gamma", None)

    def forward(self, x: torch.Tensor):
        # Divide by the root mean square of the normalized dimensions
        x = x * torch.rsqrt(self.eps + x.pow(2).mean(dim=self.dims, keepdim=True))
        return x if self.gamma is None else x * self.gamma
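As a quick correctness check (again assuming PyTorch 2.4 or later so that nn.RMSNorm is available), we can compare this from-scratch module against the built-in one on the same input; both start with $\gamma$ equal to ones, so their outputs should match closely.

torch.manual_seed(0)
x = torch.randn(2, 4, 8)

mine = RMSNorm([8], eps=1e-6)  # the class defined above
ref = nn.RMSNorm(8, eps=1e-6)  # PyTorch's built-in module

print(torch.allclose(mine(x), ref(x), atol=1e-6))  # True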