Transformer architecture variation: RMSNorm

It’s been 8 years since the famous transformer architecture was first proposed. You might have noticed some modifications to the original design - for instance, most large language models (LLMs) now use RMSNorm [1] instead of LayerNorm. Today I will briefly introduce RMSNorm, but first, let’s recap LayerNorm.

$$ \mathbf y=\frac{\mathbf x-E[\mathbf x]}{\sqrt{Var(\mathbf x)+\epsilon}}*\gamma+\beta $$

The equation above shows how LayerNorm works. If we ignore the learnable affine parameters ($\gamma, \beta$), LayerNorm’s behavior becomes intuitive: it transforms each input $\mathbf x$ into a feature vector with zero mean and unit standard deviation.
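
As a quick sanity check, here is a minimal sketch (assuming PyTorch is available) that passes a shifted and scaled random input through nn.LayerNorm with elementwise_affine=False, so $\gamma$ and $\beta$ do not interfere; the per-row mean and standard deviation of the output should be roughly 0 and 1.

python

import torch
import torch.nn as nn

# Disable the affine parameters so only the normalization itself is visible.
layer_norm = nn.LayerNorm(8, elementwise_affine=False)
x = torch.randn(2, 8) * 5 + 3           # input with arbitrary scale and shift
y = layer_norm(x)
print(y.mean(dim=-1))                   # each entry is close to 0
print(y.std(dim=-1, unbiased=False))    # each entry is close to 1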

Why does LayerNorm work? The most common explanations are:

  • re-centering: the mean $E[\mathbf x]$ is always subtracted from the input vector $\mathbf x$. The benefit: even if $\mathbf x$ carries a shift noise, its values remain centered around zero.
  • re-scaling: after mean subtraction, $\mathbf x$ is divided by $\sqrt{Var(\mathbf x)+\epsilon}$. The benefit: the output representation remains intact when the input is randomly scaled.

The two benefits can be demonstrated with the following PyTorch code.

python

import torch


def re_centering(x):
    # Subtracting the mean removes any constant shift added to x.
    return x - x.mean(dim=-1, keepdim=True)


def re_scaling(x):
    # Dividing by the standard deviation removes any constant scaling of x.
    return x / (x.std(dim=-1, keepdim=True) + 1e-5)


x = torch.arange(4).float()
print(x, re_centering(x + 10000))
# tensor([0., 1., 2., 3.]) tensor([-1.5000, -0.5000,  0.5000,  1.5000])
print(x, re_scaling(x * 10000))
# tensor([0., 1., 2., 3.]) tensor([0.0000, 0.7746, 1.5492, 2.3238])

The RMSNorm authors argue that re-scaling, not re-centering, is LayerNorm’s key benefit [1]. Based on this insight, they proposed RMSNorm in the following form:

$$ \mathbf y=\frac{\mathbf x}{\sqrt{\frac{1}{n}\sum_ix_i^2+\epsilon}}*\gamma $$
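
To make the formula concrete, here is a small sketch (with $\gamma=1$, i.e. no learnable scale) that evaluates it directly on a toy vector; the printed numbers are approximate.

python

import torch

x = torch.tensor([1., 2., 3., 4.])
rms = torch.sqrt(x.pow(2).mean() + 1e-5)  # sqrt(30 / 4) ≈ 2.7386
print(x / rms)
# tensor([0.3651, 0.7303, 1.0954, 1.4606])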

Compared to LayerNorm, RMSNorm has three key differences:

  • The numerator skips the mean-centering step (see the check after this list).
  • The denominator uses $\frac{1}{n}\sum_ix_i^2$ instead of $Var(\mathbf x)$.
  • The bias $\beta$ is removed, leaving only $\gamma$ as the learnable scale parameter.
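
The first difference has a simple consequence: when the input already has zero mean, the variance reduces to the mean of squares, so LayerNorm and RMSNorm produce the same output. A minimal sketch to check this (affine parameters disabled, same $\epsilon$ for both):

python

import torch
import torch.nn as nn

x = torch.randn(4, 8)
x = x - x.mean(dim=-1, keepdim=True)     # force zero mean per row
layer_norm = nn.LayerNorm(8, eps=1e-5, elementwise_affine=False)
rms_norm = nn.RMSNorm(8, eps=1e-5, elementwise_affine=False)
print(torch.allclose(layer_norm(x), rms_norm(x), atol=1e-6))  # True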

From these key differences, we can identify the main advantages of RMSNorm:

  • Fewer parameters to maintain: only $\gamma$ needs to be learned (see the parameter count after this list).
  • Less computation: there is no mean to calculate.
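
The parameter saving is easy to verify by counting the learnable parameters of the two PyTorch modules for the same feature dimension (the hidden size below is arbitrary):

python

import torch.nn as nn

d = 4096                       # an arbitrary hidden size for illustration
layer_norm = nn.LayerNorm(d)   # learns gamma and beta: 2 * d parameters
rms_norm = nn.RMSNorm(d)       # learns gamma only: d parameters
print(sum(p.numel() for p in layer_norm.parameters()))  # 8192
print(sum(p.numel() for p in rms_norm.parameters()))    # 4096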

What’s more, RMSNorm performs comparably to LayerNorm. More details can be found in the paper [1].

PyTorch’s nn.RMSNorm implementation includes the following key parameters:

  • normalized_shape: a list or tuple giving the trailing dimensions over which the root mean square (RMS) is computed.
  • eps: the value of $\epsilon$.
  • elementwise_affine: whether to enable the learnable scale parameter $\gamma$.

python

>>> import torch
>>> import torch.nn as nn
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)

Writing RMSNorm from scratch serves as excellent PyTorch practice. Below is my implementation.

python

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(
        self,
        normalized_shape: list | tuple,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            # Learnable scale gamma, initialized to ones.
            self.gamma = nn.Parameter(torch.ones(normalized_shape))
        else:
            self.register_parameter("gamma", None)

    def forward(self, x: torch.Tensor):
        # Compute the RMS over the trailing dimensions given by normalized_shape.
        dims = tuple(range(-len(self.normalized_shape), 0))
        x = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + self.eps)
        return x if self.gamma is None else x * self.gamma
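
To gain some confidence in the implementation, we can compare its output with PyTorch’s built-in nn.RMSNorm on a random input (a quick check, assuming the class above is already defined):

python

torch.manual_seed(0)
x = torch.randn(2, 2, 3)
custom = RMSNorm([2, 3])
builtin = nn.RMSNorm([2, 3], eps=1e-5)
print(torch.allclose(custom(x), builtin(x), atol=1e-6))  # True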