LLM 推理加速 - KV Cache
背景
LLM 用于推理的时候就是不断基于前面的所有 token 生成下一个 token
假设现在已经生成了 $t$ 个 token,用 $x_{1:t}$ 表示。在下一轮,LLM 会生成 $x_{1:t+1}$,注意他们的前 $t$ 个 token 是一样的
$$ x_{1:t+1}=\text{LLM}(x_{1:t}) $$
再下一步也是相似的
$$ x_{1:t+2}=\text{LLM}(x_{1:t+1}) $$
概括来说,每一轮用上一轮的输出当成新的输入让 LLM 预测,一般这个过程会持续到输出达到提前设定的最大长度或者是 LLM 自己生成了特殊的结束 token
KV Cache 原理
只需要看 LLM 的连续两次前向传播推理计算就很好理解为什么说存在重复计算了
比如考虑下面这一步
$$ x_{1:t+1}=\text{LLM}(x_{1:t}) $$
LLM 的输入是 $x_{1:t}$,先只看最后一个 token $x_t$,它的 query 向量会和前面的每个 token 以及自己产生的 key 向量计算内积
$$ \mathbf q_{t}^T\mathbf k_{1},\mathbf q_{t}^T\mathbf k_{2},…,\mathbf q_{t}^T\mathbf k_{t} $$
然后看下一步
$$ x_{1:t+2}=\text{LLM}(x_{1:t+1}) $$
LLM 的输入是 $x_{1:t+1}$,看最后一个 token $x_{t+1}$,它的 query 向量也会和前面的每个 token 以及自己产生的 key 向量计算内积
$$ \mathbf q_{t+1}^T\mathbf k_{1},\mathbf q_{t+1}^T\mathbf k_{2},…,\mathbf q_{t+1}^T\mathbf k_{t+1} $$
此时考虑 $x_{t+1}$ 的前一个 token $x_t$,它也要经历类似的步骤
$$ \mathbf q_{t}^T\mathbf k_{1},\mathbf q_{t}^T\mathbf k_{2},…,\mathbf q_{t}^T\mathbf k_{t} $$
可以看到,这个计算完全和上一轮的计算重复了,对于在 $x_t$ 之前的 token 也是这个道理。我们需要重新计算得到 $x_{1:t}$ 的所有 key 向量和 value 向量,而这些向量的值其实是不会变的🧐
那么我们只需要把之前 token 的 key 向量和 value 向量都缓存起来,那么就没有必要重复计算了,这就是 KV Cache 的核心思想
在运用了 KV Cache 之后,除了第一轮以外的每一轮,我们都只需要关注输入的最后一个 token,用这个 token 计算得到它的 query, key, value 向量,然后拿着这三个新向量和之前所有缓存的 key/value 向量计算自注意力
我画了一个示意图来帮助理解开启 KV Cache 之后发生了什么,考虑 LLM 从无到有开始生成 4 个 token,其中蓝色的部分表示 KV Cache 里面缓存的值;红色则表示没有被缓存
KV Cache 的效率分析
KV Cache 加速推理的原理是:在自注意力层,本来每次要做矩阵乘法 $\mathbf Q\mathbf K^T$,现在因为 KV Cache 的存在,我们不需要整个 $\mathbf Q$ 和 $\mathbf K^T$ 做矩阵乘法,只需要每次输入的最后一个 token 的 query 向量 $\mathbf q$ 和 $\mathbf K$ 做向量 - 矩阵乘法,之后更新 KV Cache 缓存即可
采用了 KV Cache 的话 LLM 的推理过程可以看成 2 个阶段
- 第一次迭代的时候,此时 KV Cache 为空,所有的输入的 token 都需要为其计算 key, value, query 向量,其中 key 和 value 会被缓存起来
- 后续的每一次迭代,只需要为新的 token 计算 key、value、query,并更新 KV Cache
KV Cache 加速推理的代价是显存占用会变高,所以它是空间换时间的办法,关于开不开 KV Cache 的显存占用峰值的对比可以看 这里。我在这里放一个总结:
- 用 KV Cache -
2 * hidden_size * num_layers * decoder_length
2
:因为要存储 key 和 value 向量hidden_size
:key 和 value 向量的长度num_layers
:每一层都要缓存 key、value 向量decoder_length
:序列长度
- 不用 KV Cache -
2 * hidden_size * 1 * decoder_length
num_layers
倍。这里用“峰值”这个词是因为,KV Cache 开启之后显存是不断累积增加的;关闭的话每次都会重新计算。所以用“峰值”会更为准确一些KV Cache API
Huggingface 的 model.generate
API 有一个参数为 use_cache
,可以用来控制是否开启 KV Cache,这个选项是默认打开的1