论文阅读: In-Context Retrieval-Augmented Language Models

In-Context RALM1 是用于 Autoregressive LM 上的 RAG 技术。RAG 说白了就是在模型推理的时候有个 Retriever 检索相关的文档,检索到的文档会和本来的输入拼接在一起

在 In-Context Learning 里面,会把一些例子放在用户输入的前面,再给 LLM。因此不难想象 In-Context RALM 也类似:In-Context RALM 就是将检索到的最相关的文档直接拼在模型输入的前面,优势是不需要再训练 LLM,我用 mermaid 画了一个图,如下所示

flowchart
    subgraph Input
        direction TB
        document
        query
    end
    query --> Retriever(Retriever) --> document
    document ---|Concat| query

    Input --> LLM --> Output

那么带有 RAG 的 LLM 做文本生成的公式表示是

$$ P_\theta(x_1,x_2,…,x_L)=\prod_{i=1}^L P_\theta\big(x_i|[R_{\mathcal C}(x_{<i});x_{<i}]\big) $$

其中

  • $R_{\mathcal C}(x_{<i})$ 表示用输入 ${x_1,x_2,…,x_{i-1}}$ 让 Retriever 检索到的文档,$x_{<i}$ 对应 ${x_1,x_2,…,x_{i-1}}$
  • $[;]$ 表示 Concat 操作,$[x;y]$ 就是将 $x$ 和 $y$ 拼接起来

上面一段话已经概括了 In-Context RALM 的设计(就是这么简单!)。但其实还是有几个细节值得深究,作者着重研究了下面几个

  • 应该用什么 Retriever
  • Retriever 检索的频率应该如何设置
  • Retriever 的输入是什么?即它是根据什么做检索的
  • 如果使用更多 Retriever 检索到的文档而不止是最相关的,效果会更好吗

下面我们一一来看上面问题的研究结论 :)

作者研究了 2 种 Retriever

  • Sparse Retriever
    • BM25
  • Dense Retriever
    • 冻结参数的 BERT-base
    • Contriever
    • Spider
Info

REALM 我之前也发过一篇论文阅读笔记,感兴趣的可以查看该链接

Retrieval Stride 要解决的是“Retriever 检索的频率应该如何设置”这个问题,像 REALM 一样在一开始的时候做检索(而且只做一次)当然可以,但作者提出:应该固定步长就做一次检索,这里的步长(用 $s$ 表示)用 token 数量来度量

作者发现,$s$ 越小,模型生成的文本越好,但开销也会越大,这里存在 trade-off。每隔 $s$ 个 token 的情况做一次检索的情况下,带有 RAG 的 LLM 的文本生成可以用公式表示为

$$ P_\theta(x_1,x_2,…,x_L)=\prod_{j=0}^{n_s-1}\prod_{i=1}^{s} P_\theta\big(x_{s\cdot j +i}|[R_{\mathcal C}(x_{\le s\cdot j});x_{<(s\cdot j + i)}]\big) $$

这里说的开销包括两个部分

  • Retriever 检索的开销
  • 检索到的文档需要重新让 LLM 计算 Embedding

那么 Retriever 的输入应该是什么呢?一个朴素的想法是,将用户的输入和和截止目前为止 LLM 生成的 token(这两者可以用 $x_{\le x_{s\cdot j}}$ 表示) 一起作为 Retriever 的输入,但作者认为最后几个 token 对 LLM 的生成更重要

因此作者想到,每隔 $s$ 个 token 要让 Retriever 去检索的时候只需要给最后 $l$ 个 token 就可以,并不需要给当前所有的 token(即 $x_{\le s\cdot j}$)。用公式表示就是

$$ P_\theta(x_1,x_2,…,x_L)=\prod_{j=0}^{n_s-1}\prod_{i=1}^{s} P_\theta\big(x_{s\cdot j +i}|[R_{\mathcal C}(x_{s\cdot j-l+1:{s\cdot j}});x_{<(s\cdot j + i)}]\big) $$

其中 $s\cdot j-l+1:{s\cdot j}$ 表示 $[s\cdot j-l+1,{s\cdot j}]$,也就是当前位置的前 $l$ 个 token

Faq

$l$ 应该等于 $s$ 吗?作者发现 $s=l$ 的话,效果会不好

Retriever 会对语料库中的文档都计算相关性,一般只使用最相关的文档,但直接根据相关性计算出来的就一定最好吗?是否可以考虑返回 Top-k 个然后重新做排序操作

作者研究了 2 种排序方法

第一种,直接用 LLM 判断,假设要对 k 个文档排序,那么对于每一个文档 $d_i$:取 $x_{s\cdot j}$ 的最后 $s’$ 个 token 作为预测目标,并从 LLM 输入中去掉这 $s’$ 个 token。文档 $d_i$ 会被添加到 LLM 输入的前面,然后再让 LLM 预测。最好的文档就是会让这 $s’$ 个 token 出现频率最高的

$$ \texttt{argmax}\ p(x_{s\cdot j-s’+1:s\cdot j}|[d_i;x_{\le s\cdot j-s’})) $$

Tip

第一种方法里面用于排序的 LLM 并没有要求要和负责生成文本的 LLM 一样。因此出于性能的考虑,负责排序的 LLM 完全可以用更小的模型

第二种,直接训练一个 reranker,作者称其为 Predictive Reranking

  • Reranker 输入
    • $x_{\le s\cdot j}$
    • $d_i$
  • Reranker 返回:一个标量,表示 $d_i$ 和 $x_{\le s\cdot j}$ 的相关性

那么每个文档 $d_i$ 都会有个相关性的得分,经过 $\texttt{softmax}$ 之后 k 个文档的得分就形成了一个概率分布,概率最大的文档就是相关性最强的

Tip

那么训练集的样本的 label 从哪里来呢?

  1. 让 LLM 模型用 $x_{\le s\cdot j}$ 生成 $s$ 个 token(用 $y$ 表示),注意这里的 $s$ 是步长
  2. 让 Retriever 用最后 $l$ 个 token 检索到 k 个文档
  3. 让 LLM 根据 $[d_i; x_{\le s\cdot j}]$ 预测 $y$,计算概率,就得到了 label

我将本文的研究结论汇总成一张表格,如下所示

Conclusion Reference
Predictive Reranking 的效果会更好一些 Figure 1
BM25 虽然只考虑了语法信息,但是作为 Retriever 比基于神经网络的效果还要好(推荐设置 $l=32, s=4$) Figure 3
不管 LLM 的模型大小,用了 In-Context RALM 的 RAG 技术之后,LLM 生成的文本的困惑度都下降了 Figure 4
检索的频率越高(即 $s$ 越小),效果越好(极限情况下每个 token 生成的时候都检索一次) Figure 5
给 Retriever 的输入只需要 LLM 输入的最后 $l$ 个 token 效果就很好了 Figure 6
对 Retriever 返回的文档进行重新排序是有帮助的,提升效果明显 Figure 7
只需要使用 Retriever 返回的最相关的文档就可以了,更多文档的提升并不明显 Figure 8

  1. Ram, Ori, et al. “In-context retrieval-augmented language models.” Transactions of the Association for Computational Linguistics 11 (2023): 1316-1331. src ↩︎