论文阅读: In-Context Retrieval-Augmented Language Models
The idea
In-Context RALM1 是用于 Autoregressive LM 上的 RAG 技术。RAG 说白了就是在模型推理的时候有个 Retriever 检索相关的文档,检索到的文档会和本来的输入拼接在一起
在 In-Context Learning 里面,会把一些例子放在用户输入的前面,再给 LLM。因此不难想象 In-Context RALM 也类似:In-Context RALM 就是将检索到的最相关的文档直接拼在模型输入的前面,优势是不需要再训练 LLM,我用 mermaid 画了一个图,如下所示
flowchart subgraph Input direction TB document query end query --> Retriever(Retriever) --> document document ---|Concat| query Input --> LLM --> Output
那么带有 RAG 的 LLM 做文本生成的公式表示是
$$ P_\theta(x_1,x_2,…,x_L)=\prod_{i=1}^L P_\theta\big(x_i|[R_{\mathcal C}(x_{<i});x_{<i}]\big) $$
其中
- $R_{\mathcal C}(x_{<i})$ 表示用输入 ${x_1,x_2,…,x_{i-1}}$ 让 Retriever 检索到的文档,$x_{<i}$ 对应 ${x_1,x_2,…,x_{i-1}}$
- $[;]$ 表示 Concat 操作,$[x;y]$ 就是将 $x$ 和 $y$ 拼接起来
上面一段话已经概括了 In-Context RALM 的设计(就是这么简单!)。但其实还是有几个细节值得深究,作者着重研究了下面几个
- 应该用什么 Retriever?
- Retriever 检索的频率应该如何设置?
- Retriever 的输入是什么?即它是根据什么做检索的?
- 如果使用更多 Retriever 检索到的文档而不止是最相关的,效果会更好吗?
下面我们一一来看上面问题的研究结论 :)
Retriever Choice
作者研究了 2 种 Retriever
- Sparse Retriever
- BM25
- Dense Retriever
- 冻结参数的 BERT-base
- Contriever
- Spider
Retrieval Stride
REALM 我之前也发过一篇论文阅读笔记,感兴趣的可以查看该链接
Retrieval Stride 要解决的是“Retriever 检索的频率应该如何设置”这个问题,像 REALM 一样在一开始的时候做检索(而且只做一次)当然可以,但作者提出:应该固定步长就做一次检索,这里的步长(用 $s$ 表示)用 token 数量来度量
作者发现,$s$ 越小,模型生成的文本越好,但开销也会越大,这里存在 trade-off。每隔 $s$ 个 token 的情况做一次检索的情况下,带有 RAG 的 LLM 的文本生成可以用公式表示为
$$ P_\theta(x_1,x_2,…,x_L)=\prod_{j=0}^{n_s-1}\prod_{i=1}^{s} P_\theta\big(x_{s\cdot j +i}|[R_{\mathcal C}(x_{\le s\cdot j});x_{<(s\cdot j + i)}]\big) $$
这里说的开销包括两个部分
- Retriever 检索的开销
- 检索到的文档需要重新让 LLM 计算 Embedding
Retrieval Query Length
那么 Retriever 的输入应该是什么呢?一个朴素的想法是,将用户的输入和和截止目前为止 LLM 生成的 token(这两者可以用 $x_{\le x_{s\cdot j}}$ 表示) 一起作为 Retriever 的输入,但作者认为最后几个 token 对 LLM 的生成更重要
因此作者想到,每隔 $s$ 个 token 要让 Retriever 去检索的时候只需要给最后 $l$ 个 token 就可以,并不需要给当前所有的 token(即 $x_{\le s\cdot j}$)。用公式表示就是
$$ P_\theta(x_1,x_2,…,x_L)=\prod_{j=0}^{n_s-1}\prod_{i=1}^{s} P_\theta\big(x_{s\cdot j +i}|[R_{\mathcal C}(x_{s\cdot j-l+1:{s\cdot j}});x_{<(s\cdot j + i)}]\big) $$
其中 $s\cdot j-l+1:{s\cdot j}$ 表示 $[s\cdot j-l+1,{s\cdot j}]$,也就是当前位置的前 $l$ 个 token
$l$ 应该等于 $s$ 吗?作者发现 $s=l$ 的话,效果会不好
Reranking
Retriever 会对语料库中的文档都计算相关性,一般只使用最相关的文档,但直接根据相关性计算出来的就一定最好吗?是否可以考虑返回 Top-k 个然后重新做排序操作?
作者研究了 2 种排序方法
第一种,直接用 LLM 判断,假设要对 k 个文档排序,那么对于每一个文档 $d_i$:取 $x_{s\cdot j}$ 的最后 $s’$ 个 token 作为预测目标,并从 LLM 输入中去掉这 $s’$ 个 token。文档 $d_i$ 会被添加到 LLM 输入的前面,然后再让 LLM 预测。最好的文档就是会让这 $s’$ 个 token 出现频率最高的
$$ \texttt{argmax}\ p(x_{s\cdot j-s’+1:s\cdot j}|[d_i;x_{\le s\cdot j-s’})) $$
第一种方法里面用于排序的 LLM 并没有要求要和负责生成文本的 LLM 一样。因此出于性能的考虑,负责排序的 LLM 完全可以用更小的模型
第二种,直接训练一个 reranker,作者称其为 Predictive Reranking
- Reranker 输入
- $x_{\le s\cdot j}$
- $d_i$
- Reranker 返回:一个标量,表示 $d_i$ 和 $x_{\le s\cdot j}$ 的相关性
那么每个文档 $d_i$ 都会有个相关性的得分,经过 $\texttt{softmax}$ 之后 k 个文档的得分就形成了一个概率分布,概率最大的文档就是相关性最强的
那么训练集的样本的 label 从哪里来呢?
- 让 LLM 模型用 $x_{\le s\cdot j}$ 生成 $s$ 个 token(用 $y$ 表示),注意这里的 $s$ 是步长
- 让 Retriever 用最后 $l$ 个 token 检索到 k 个文档
- 让 LLM 根据 $[d_i; x_{\le s\cdot j}]$ 预测 $y$,计算概率,就得到了 label
总结
我将本文的研究结论汇总成一张表格,如下所示
Conclusion | Reference |
---|---|
Predictive Reranking 的效果会更好一些 | Figure 1 |
BM25 虽然只考虑了语法信息,但是作为 Retriever 比基于神经网络的效果还要好(推荐设置 $l=32, s=4$) | Figure 3 |
不管 LLM 的模型大小,用了 In-Context RALM 的 RAG 技术之后,LLM 生成的文本的困惑度都下降了 | Figure 4 |
检索的频率越高(即 $s$ 越小),效果越好(极限情况下每个 token 生成的时候都检索一次) | Figure 5 |
给 Retriever 的输入只需要 LLM 输入的最后 $l$ 个 token 效果就很好了 | Figure 6 |
对 Retriever 返回的文档进行重新排序是有帮助的,提升效果明显 | Figure 7 |
只需要使用 Retriever 返回的最相关的文档就可以了,更多文档的提升并不明显 | Figure 8 |