论文阅读: Generalization through Memorization: Nearest Neighbor Language Models

语言模型解决 2 种问题

  1. 用一个特征向量表示句子前缀
  2. 使用该特征向量预测下一个 token

本文提出的 $k\texttt{NN-LM}$ 基于这么一个假设:学习特征向量表示比预测下一个 token,因此本文的方法主要基于该假设进行设计

下面这张图概括了 $k\texttt{NN-LM}$ 的思想1

使用 $k\texttt{NN-LM}$ 需要先对语料库的文本进行处理,分成下面几个步骤,下面我会结合一个详细的例子来进行说明,比如语料库里面有这么一个句子

text

Today is a good [day]

这句子可以被分为 2 个部分

  • Today is a good day 是上下文 $c_i$(Context),把这个上下文扔给 LLM,就可以得到一个向量表示 $f(c_i)$
  • day 是下一个 token $w_i$(Target)

这样就得到了一个 $(f(c_i),w_i)$ 的 pair,可以存储在 KV 数据库里面,在论文里叫做 Datastore

Info

KNN 算法我之前写过一篇博客介绍,在这里

那么在模型推理的时候如何使用这个 KV 数据库呢?假设用 $x$ 表示 LLM 的输入,那么步骤如下

  1. 用 LLM 生成它的向量表示 $f(x)$
  2. 用 $f(x)$ 作为 Query,在 KV 数据库里面用 KNN 算法找到最近的 $k$ 个邻居(用 $\mathcal N$ 表示)。换句话说,找和 LLM 输入 $x$ 最相似 $k$ 个上下文,注意每一个上下文都有下一个 token
  3. 假设邻居 $i$ 到 $f(x)$ 的距离是 $d_i$,$k$ 个邻居的距离就得到了一个距离向量数组 $[d_1, d_2,…,d_k]$,将距离取负然后做 Softmax 就得到了概率分布,这是关于距离的概率分布,但也可以转化为关于每个邻居背后的 token $w_i$ 的概率 $[p_1,p_2, p_3,…,p_k]$。注意这里可能不同邻居的下一个 token 是一样的,那么对应的概率要相加(可以看前面的图,有 2 个 Hawaii)
  4. 此时再把 Vocab $\mathcal V$ 的其他 token 的概率设置为 0,就得到了在 Vocab $\mathcal V$ 上的概率分布 $p_\texttt{kNN}(y|x)$,一个长度为 $|\mathcal V|$ 的向量

而 LLM 也会为输入 $x$ 输出下一个 token 的概率分布 $p_\texttt{LM}(y|x)$

引入一个超参数 $\lambda$ 综合这两个概率分布

$$ p(y|x)=\lambda p_\texttt{kNN}(y|x)+(1-\lambda)p_\texttt{LM}(y|x) $$

用这下一个 token 的概率分布预测下一个 token 即可

FAISS 快速检索 KV 数据库中的邻居

距离的计算则采用欧几里得距离而不是向量内积,因为作者发现这个效果更好

根据实验结果,有如下几个发现(知识库指的就是 KV 数据库)

发现 索引
将 LLM 的训练数据作为知识库能降低 LLM 输出的困惑度 Table 1
在小数据集上微调 LLM,并外挂一个大知识库,比直接在这个大知识库上做训练效果更好 Table 3
最好的上下文表示是:LLM 的最后一个 transformer 层其中的 FFN 层的输入 Figure 3
检索的邻居越多,效果越好 Figure 4
如果知识库是同个领域的,那么就把 $\lambda$ 设置得小一点,反之就设置得大一点
Figure 5

在了解了 $k\texttt{NN-LM}$ 的原理之后,我们可以尝试理解一下它的优缺点

优点是 1

  • 能够捕捉到文本中一些罕见模式,对 out-of-domain 的文本生成效果也比较好,这体现在 $k\texttt{NN}$ 和 $\texttt{LM}$ 的协同
    • $k\texttt{NN}$ 可以捕捉到:上下文相似,下一个 token 一样,但 token 的含义不同
    • $\texttt{LM}$ 可以捕捉到:上下文不同,下一个 token 一样

不足点是 1

  1. 空间开销大,因为给定语料库我们是为每个 token 构造 KV pair,所以 KV 数据库的大小跟跟 token 数量挂钩。所以在 Scalability 上 $k-\texttt{NN-LM}$ 存在比较大的问题
  2. 在模型输入和检索结果之间没有注意力的计算,这会降低模型的表现