Transformer 架构变化:RMSNorm 指南
引言
从 2017 年 Transformer 架构被提出以来,到现在 2025 已经 8 年过去了,Transformer 架构已经发生了很多变化。比如,现如今越来越多的大模型采用的是 RMSNorm1 而不是 LayerNorm。今天这篇文章就是对 RMSNorm 的一个简单介绍,在了解 RMSNorm 之前,我们不妨先回顾一下什么是 LayerNorm
LayerNorm 回顾
$$ \mathbf y=\frac{\mathbf x-E[\mathbf x]}{\sqrt{Var(\mathbf x)+\epsilon}}*\gamma+\beta $$
上面是 LayerNorm 的公式,如果我们忽略放缩因子 $\gamma,\beta$ 不看,LayerNorm 做的事情很好理解:将每一个样本的特征向量 $x$ 转变为均值为 0,标准差为 1 的特征向量
为什么 LayerNorm 是有用的呢?之前流行的解释是
- re-centering:输入 $\mathbf x$ 总是会减去均值 $E[\mathbf x]$。好处是如果输入 $\mathbf x$ 发生了整体的偏移(Shift Noise)也没事,输入 $\mathbf x$ 始终会在 0 的附近
- re-scaling:减去均值之后总是会除以 $\sqrt{Var(\mathbf x)+\epsilon}$。好处是如果输入 $\mathbf x$ 被成比例放缩,也没有影响
可以写个简单的 PyTorch 代码验证一下
import torch
def re_centering(x):
return x - x.mean(dim=-1)
def re_scaling(x):
return x / (x.std(dim=-1) + 1e-5)
x = torch.arange(4).float()
print(x, re_centering(x + 10000))
# tensor([0., 1., 2., 3.]) tensor([-1.5000, -0.5000, 0.5000, 1.5000])
print(x, re_scaling(x * 10000))
# tensor([0., 1., 2., 3.]) tensor([0.0000, 0.7746, 1.5492, 2.3238])
RMSNorm
RMSNorm 认为 LayerNorm 的价值在于 re-scaling 特性,跟 re-centering 倒是关系不大1,所以在设计 RMSNorm 的时候作者只考虑如何做 re-scaling。下面是 RMSNorm 的公式
$$ \mathbf y=\frac{\mathbf x}{\sqrt{\frac{1}{n}\sum_ix_i^2+\epsilon}}*\gamma $$
和 LayerNorm 对比,主要的几个差异如下
- 分子不需要减去 $E[\mathbf x]$
- 分母从 $Var(\mathbf x)$ 变成了 $\frac{1}{n}\sum_ix_i^2$
- 只需要维护 $\gamma$ 参数,不需要维护 $\beta$
RMSNorm 的好处
通过上面观察到的几点差异,我们可以看出 RMSNorm 的一个显而易见的好处:
- 需要维护的参数更少了,只有 $\gamma$
- 计算量也减少了,因为不用计算输入 $\mathbf x$ 的均值 $E[\mathbf x]$(注意 $Var(\mathbf x)$ 的计算也需要均值)
当然,最重要的是,RMSNorm 的效果还真就挺好的,跟 LayerNorm 也差不了多少,具体的实验细节和结果可以参考原论文1
PyTorch API
PyTorch 提供的 nn.RMSNorm
实现有如下的几个参数
normalized_shape
:表示用于计算 RMS 基于的输入张量的末尾维度eps
:为了数值稳定,加上的一个很小的值element_affine
:是否要启用可学习参数 $\gamma$?
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)
RMSNorm from Scratch
手写 RMSNorm 的难度不是很大,下面我写的代码可以作为参考
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
def __init__(
self,
normalized_shape: list | tuple,
eps: float = 1e-5,
element_affine: bool = True,
):
super().__init__()
self.eps = eps
self.element_affine = element_affine
if self.element_affine:
self.gamma = nn.Parameter(torch.ones(normalized_shape))
else:
self.register_parameter("gamma", None)
def forward(self, x: torch.Tensor):
x = x * torch.rsqrt(self.eps + x.pow(2).mean(dim=-1, keepdim=True))
return x if self.gamma is None else x * self.gamma