LLM 推理加速 - KV Cache

LLM 用于推理的时候就是不断基于前面的所有 token 生成下一个 token

假设现在已经生成了 $t$ 个 token,用 $x_{1:t}$ 表示。在下一轮,LLM 会生成 $x_{1:t+1}$,注意他们的前 $t$ 个 token 是一样的

$$ x_{1:t+1}=\text{LLM}(x_{1:t}) $$

再下一步也是相似的

$$ x_{1:t+2}=\text{LLM}(x_{1:t+1}) $$

概括来说,每一轮用上一轮的输出当成新的输入让 LLM 预测,一般这个过程会持续到输出达到提前设定的最大长度或者是 LLM 自己生成了特殊的结束 token

信息
LLM 的推理过程很好理解,但是这个简单的实现存在一个问题——存在不少的重复计算导致计算效率不是很高🧐

只需要看 LLM 的连续两次前向传播推理计算就很好理解为什么说存在重复计算了

比如考虑下面这一步

$$ x_{1:t+1}=\text{LLM}(x_{1:t}) $$

LLM 的输入是 $x_{1:t}$,先只看最后一个 token $x_t$,它的 query 向量会和前面的每个 token 以及自己产生的 key 向量计算内积

$$ \mathbf q_{t}^T\mathbf k_{1},\mathbf q_{t}^T\mathbf k_{2},…,\mathbf q_{t}^T\mathbf k_{t} $$

然后看下一步

$$ x_{1:t+2}=\text{LLM}(x_{1:t+1}) $$

LLM 的输入是 $x_{1:t+1}$,看最后一个 token $x_{t+1}$,它的 query 向量也会和前面的每个 token 以及自己产生的 key 向量计算内积

$$ \mathbf q_{t+1}^T\mathbf k_{1},\mathbf q_{t+1}^T\mathbf k_{2},…,\mathbf q_{t+1}^T\mathbf k_{t+1} $$

此时考虑 $x_{t+1}$ 的前一个 token $x_t$,它也要经历类似的步骤

$$ \mathbf q_{t}^T\mathbf k_{1},\mathbf q_{t}^T\mathbf k_{2},…,\mathbf q_{t}^T\mathbf k_{t} $$

可以看到,这个计算完全和上一轮的计算重复了,对于在 $x_t$ 之前的 token 也是这个道理。我们需要重新计算得到 $x_{1:t}$ 的所有 key 向量和 value 向量,而这些向量的值其实是不会变的🧐

Tip

那么我们只需要把之前 token 的 key 向量和 value 向量都缓存起来,那么就没有必要重复计算了,这就是 KV Cache 的核心思想

在运用了 KV Cache 之后,除了第一轮以外的每一轮,我们都只需要关注输入的最后一个 token,用这个 token 计算得到它的 query, key, value 向量,然后拿着这三个新向量和之前所有缓存的 key/value 向量计算自注意力

我画了一个示意图来帮助理解开启 KV Cache 之后发生了什么,考虑 LLM 从无到有开始生成 4 个 token,其中蓝色的部分表示 KV Cache 里面缓存的值;红色则表示没有被缓存

KV Cache 加速推理的原理是:在自注意力层,本来每次要做矩阵乘法 $\mathbf Q\mathbf K^T$,现在因为 KV Cache 的存在,我们不需要整个 $\mathbf Q$ 和 $\mathbf K^T$ 做矩阵乘法,只需要每次输入的最后一个 token 的 query 向量 $\mathbf q$ 和 $\mathbf K$ 做向量 - 矩阵乘法,之后更新 KV Cache 缓存即可

采用了 KV Cache 的话 LLM 的推理过程可以看成 2 个阶段

  1. 第一次迭代的时候,此时 KV Cache 为空,所有的输入的 token 都需要为其计算 key, value, query 向量,其中 key 和 value 会被缓存起来
  2. 后续的每一次迭代,只需要为新的 token 计算 key、value、query,并更新 KV Cache

KV Cache 加速推理的代价是显存占用会变高,所以它是空间换时间的办法,关于开不开 KV Cache 的显存占用峰值的对比可以看 这里。我在这里放一个总结:

  • 用 KV Cache - 2 * hidden_size * num_layers * decoder_length
    • 2:因为要存储 key 和 value 向量
    • hidden_size:key 和 value 向量的长度
    • num_layers每一层都要缓存 key、value 向量
    • decoder_length:序列长度
  • 不用 KV Cache - 2 * hidden_size * 1 * decoder_length
注意
可以看到,开启 KV Cache 之后,显存占用的峰值是原来的 num_layers。这里用“峰值”这个词是因为,KV Cache 开启之后显存是不断累积增加的;关闭的话每次都会重新计算。所以用“峰值”会更为准确一些

Huggingface 的 model.generate API 有一个参数为 use_cache,可以用来控制是否开启 KV Cache,这个选项是默认打开的1