Reading Notes: In-Context Retrieval-Augmented Language Models
The idea
In-Context RALM [1] applies retrieval-augmented generation (RAG) to autoregressive LMs. In RAG, a retriever fetches relevant documents at inference time, and these documents are concatenated with the original input.
In In-Context Learning, a few examples are placed before the user's input, and the combined text is fed to the LLM. In-Context RALM works the same way: it directly prepends the most relevant retrieved document to the model's input. The advantage is that the LLM does not need to be retrained. A diagram created with Mermaid is shown below.
```mermaid
flowchart
    subgraph Input
        direction TB
        document
        query
    end
    query --> Retriever(Retriever) --> document
    document ---|Concat| query
    Input --> LLM --> Output
```
Text generation with In-Context RALM can be formulated as follows:
$$ P_\theta(x_1,x_2,…,x_L)=\prod_{i=1}^L P_\theta\big(x_i|[R_{\mathcal C}(x_{<i});x_{<i}]\big) $$
Where
- $R_{\mathcal C}(x_{<i})$ is the document retrieved from the corpus $\mathcal C$ using the prefix $x_{<i}$ as the query, where $x_{<i}$ denotes the tokens $x_1, x_2, \ldots, x_{i-1}$.
- $[\cdot;\cdot]$ denotes concatenation, so $[x;y]$ is $x$ followed by $y$.
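The equation can be read as a recipe for building the LM input at every step: retrieve with the current prefix, then put the document in front of that prefix. The sketch below only builds those inputs; it does not run an LM. `toy_retrieve` is a hypothetical stand-in for a real retriever, and positions are 0-indexed in the code.

```python
from typing import Callable, List

# A retriever maps a token prefix to the tokens of one retrieved document.
Retriever = Callable[[List[str]], List[str]]


def build_ralm_inputs(tokens: List[str], retrieve: Retriever) -> List[List[str]]:
    """Return the LM input used to predict each token (0-indexed).

    Mirrors P(x_i | [R_C(x_{<i}); x_{<i}]): retrieval runs on the whole
    prefix before every token, and the document is placed in front of it.
    """
    inputs = []
    for i in range(len(tokens)):
        prefix = tokens[:i]                # x_{<i}
        document = retrieve(prefix)        # R_C(x_{<i})
        inputs.append(document + prefix)   # [R_C(x_{<i}); x_{<i}]
    return inputs


# Hypothetical stand-in retriever: returns a fake document naming the prefix length.
def toy_retrieve(prefix: List[str]) -> List[str]:
    return ["<doc>", f"len={len(prefix)}", "</doc>"]


inputs = build_ralm_inputs(["the", "cat", "sat"], toy_retrieve)
# inputs[2] -> ['<doc>', 'len=2', '</doc>', 'the', 'cat']
```

Because retrieval here happens before every single token, this is the most expensive variant; the design choices below are about relaxing exactly that.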
The equation summarizes the design of In-Context RALM. However, there are a few design choices worth delving into. The authors specifically focus on the following topics.
- What kind of retriever should be used?
- How often should the retrieval be triggered?
- What is the input of the retriever?
- Does using more than one retrieved document improve the model's performance?
Now let’s talk about the answers as suggested in this paper.
Retriever Choice
The authors investigate two kinds of retrievers.
- Sparse retriever
  - BM25
- Dense retrievers
  - a frozen BERT-base model
  - Contriever
  - Spider
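Since BM25 turns out to be the strongest retriever in this paper (see the wrap-up table), here is a minimal from-scratch sketch of BM25 scoring over whitespace-tokenized documents. It is not the authors' implementation: the parameters `k1=1.5` and `b=0.75` are common defaults assumed here, and the IDF uses the non-negative `log(1 + ...)` variant.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: List[str], docs: List[List[str]],
                k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Score every document against the query with BM25 (purely lexical)."""
    n = len(docs)
    avgdl = sum(len(d) for d in docs) / n
    df = Counter()                      # document frequency of each term
    for d in docs:
        df.update(set(d))
    scores = []
    for d in docs:
        tf = Counter(d)
        score = 0.0
        for term in query:
            if term not in tf:
                continue                # no lexical match, no contribution
            idf = math.log(1 + (n - df[term] + 0.5) / (df[term] + 0.5))
            f = tf[term]
            # Term-frequency saturation (k1) and length normalization (b).
            score += idf * f * (k1 + 1) / (f + k1 * (1 - b + b * len(d) / avgdl))
        scores.append(score)
    return scores


docs = [["the", "cat", "sat"], ["dogs", "bark", "loudly"], ["the", "cat", "and", "the", "dog"]]
scores = bm25_scores(["cat"], docs)
best = max(range(len(docs)), key=scores.__getitem__)   # shorter matching doc wins
```

Note that BM25 only counts exact term overlap, which is what the wrap-up table means by "lexical information".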
Retrieval Stride
Previously, I posted reading notes on the REALM paper; feel free to check them out if you are interested.
This brings us to the second question: how often should retrieval be triggered? A naive approach, similar to REALM's, is to retrieve only once at the very beginning. Instead, the authors of In-Context RALM propose retrieving at fixed intervals, where the retrieval stride (denoted $s$) is measured in tokens.
The authors find that the smaller $s$ is, the better the generated text, but the higher the computational cost. With retrieval every $s$ tokens and $n_s = L/s$ retrieval steps, the model becomes:
$$ P_\theta(x_1,x_2,…,x_L)=\prod_{j=0}^{n_s-1}\prod_{i=1}^{s} P_\theta\big(x_{s\cdot j +i}|[R_{\mathcal C}(x_{\le s\cdot j});x_{<(s\cdot j + i)}]\big) $$
The computational cost can be broken into two parts.
- The cost of retrieving
- The cost of re-encoding the newly retrieved document with the LLM, since the document in front of the input changes at every retrieval
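To make the stride concrete, the sketch below lists, for each token $x_i$, the prefix length $s\cdot j$ whose tokens form its retrieval query, and counts retriever calls. The equation assumes $L$ is divisible by $s$; rounding $n_s$ up for a final partial block is an assumption of this sketch.

```python
from typing import List


def retrieval_prefix_lengths(length: int, stride: int) -> List[int]:
    """For each token x_i (i = 1..length), the prefix length s*j used for retrieval.

    Token x_i belongs to block j = (i - 1) // s, and its retrieved document
    depends only on x_{<= s*j}.
    """
    if stride < 1:
        raise ValueError("stride must be >= 1")
    return [((i - 1) // stride) * stride for i in range(1, length + 1)]


def num_retrievals(length: int, stride: int) -> int:
    """Number of retriever calls n_s = ceil(L / s)."""
    return -(-length // stride)


schedule = retrieval_prefix_lengths(10, 4)
# schedule -> [0, 0, 0, 0, 4, 4, 4, 4, 8, 8]
```

For a 512-token sequence, $s=1$ means 512 retriever calls (and 512 document re-encodings), while $s=4$ cuts that to 128, which is the cost side of the trade-off described above.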
Retrieval Query Length
So, what should the retriever's input be? A simple approach is to use the user's input together with all tokens the LLM has generated so far (denoted $x_{\le s\cdot j}$) as the retrieval query. However, the authors suggest that the most recent $l$ tokens matter most for guiding what the LLM generates next.
Therefore, when retrieval is performed every $s$ tokens, the authors use only the last $l$ tokens as the query, rather than the whole prefix $x_{\le s\cdot j}$. Formally:
$$ P_\theta(x_1,x_2,…,x_L)=\prod_{j=0}^{n_s-1}\prod_{i=1}^{s} P_\theta\big(x_{s\cdot j +i}|[R_{\mathcal C}(x_{s\cdot j-l+1:{s\cdot j}});x_{<(s\cdot j + i)}]\big) $$
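Here the subscript $s\cdot j-l+1:s\cdot j$ denotes the index range $[s\cdot j-l+1,\ s\cdot j]$, that is, the last $l$ tokens before the retrieval point.
You may wonder whether $l$ should equal $s$. It should not: the two are chosen independently, and the recommended setting in the wrap-up table is $l=32$, $s=4$.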
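The sketch below builds the retrieval query for each stride, combining the stride $s$ with the query length $l$. With $l > s$, consecutive queries overlap. Returning an empty query at $j=0$ is a simplification of this sketch; a real system would have some initial context there.

```python
from typing import List


def retrieval_queries(tokens: List[str], stride: int, query_len: int) -> List[List[str]]:
    """Query for each retrieval step j: the last l tokens of x_{<= s*j}."""
    if stride < 1 or query_len < 1:
        raise ValueError("stride and query_len must be positive")
    queries = []
    for start in range(0, len(tokens), stride):   # start = s*j
        lo = max(0, start - query_len)
        queries.append(tokens[lo:start])          # x_{s*j-l+1 : s*j} (1-indexed)
    return queries


queries = retrieval_queries(list("abcdefghij"), stride=4, query_len=6)
# queries -> [[], ['a', 'b', 'c', 'd'], ['c', 'd', 'e', 'f', 'g', 'h']]
```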
Reranking
Usually, only the top-ranked retrieved document is used. But is the document chosen purely by its relevance score always the best choice? Could we instead retrieve the top-$k$ documents and then rerank them?
That’s precisely what the authors aim to explore.
The first reranking method uses an LLM directly as the document reranker. Since the tokens to be generated are not known yet, the most recent tokens serve as a proxy: given $k$ candidate documents, take the last $s'$ tokens of the prefix as the prediction target and remove them from the LLM's input. For each document $d_i$, prepend $d_i$ to this shortened input and compute the probability of the held-out tokens. The best document is the one that maximizes this probability:
$$ \arg\max_{i\in\{1,\ldots,k\}}\ p\big(x_{s\cdot j-s'+1:s\cdot j}\,\big|\,[d_i;x_{\le s\cdot j-s'}]\big) $$
In this method, the reranker LLM need not be the same as the LLM used for generation. To reduce computational cost, a smaller LLM can serve as the reranker.
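A minimal sketch of LM-based reranking, assuming a scoring function `log_prob(context, target)` that returns $\log p(\text{target}\mid\text{context})$. In practice this would come from a (possibly smaller) LLM; here `toy_log_prob` is a hypothetical add-one-smoothed unigram model used only so the example runs.

```python
import math
from collections import Counter
from typing import Callable, List, Sequence, Tuple

# (context, target) -> log p(target | context); stands in for an LLM.
LogProbFn = Callable[[List[str], List[str]], float]


def rerank_with_lm(prefix: List[str], docs: Sequence[List[str]], s_prime: int,
                   log_prob: LogProbFn) -> Tuple[int, List[float]]:
    """Pick the document that best predicts the last s' tokens of the prefix."""
    if not 0 < s_prime < len(prefix):
        raise ValueError("need 0 < s' < len(prefix)")
    context = prefix[:-s_prime]   # x_{<= s*j - s'}
    target = prefix[-s_prime:]    # x_{s*j-s'+1 : s*j}, held out as a proxy
    scores = [log_prob(list(d) + context, target) for d in docs]  # [d_i; context]
    best = max(range(len(docs)), key=scores.__getitem__)
    return best, scores


def toy_log_prob(context: List[str], target: List[str], vocab_size: int = 1000) -> float:
    """Hypothetical add-one-smoothed unigram 'LM' estimated from the context."""
    counts = Counter(context)
    return sum(math.log((counts[t] + 1) / (len(context) + vocab_size)) for t in target)


prefix = "the capital of france is paris".split()
docs = [["berlin", "is", "in", "germany"], ["paris", "is", "the", "capital", "of", "france"]]
best, scores = rerank_with_lm(prefix, docs, s_prime=2, log_prob=toy_log_prob)
# best -> 1 (the document mentioning "paris" makes the held-out tokens likelier)
```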
The second reranking method trains a dedicated reranker, called Predictive Reranking.
- Input of the reranker:
  - the prefix $x_{\le s\cdot j}$
  - a candidate document $d_i$
- Output of the reranker: a scalar score for how relevant $d_i$ is to $x_{\le s\cdot j}$
Each document $d_i$ receives a relevance score. Applying the $\texttt{softmax}$ function over these scores gives a probability distribution over the $k$ documents, and the document with the highest probability is selected.
How are the training labels obtained, and how is the training set constructed?
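The selection step can be sketched as below. The scores are assumed to come from a trained reranker, which is not implemented here; the numbers in the example are made up for illustration.

```python
import math
from typing import List, Tuple


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over k reranker scores."""
    m = max(scores)
    exps = [math.exp(x - m) for x in scores]
    z = sum(exps)
    return [e / z for e in exps]


def select_document(scores: List[float]) -> Tuple[int, List[float]]:
    """Return the index of the most probable document and the full distribution."""
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs


best, probs = select_document([0.5, 2.0, -1.0])   # hypothetical reranker outputs
# best -> 1
```

Because softmax is monotonic, the selected document is simply the top-scoring one; the normalized distribution matters mainly when training the reranker.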
- Let the LLM take $x_{\le s\cdot j}$ as input and generate $s$ tokens (denoted $y$), where $s$ is the retrieval stride.
- Let the retriever take the last $l$ tokens as its query and return $k$ relevant documents.
- The label for each $d_i$ is the LLM's probability of generating $y$ given $[d_i; x_{\le s\cdot j}]$ as input.
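The three steps above can be sketched as one function. `generate`, `retrieve` and `log_prob` are hypothetical callables standing in for the LLM and the retriever, and the toy versions at the bottom exist only so the example runs.

```python
import math
from collections import Counter
from typing import Callable, List, Tuple

Tokens = List[str]


def build_reranker_examples(
    prefix: Tokens, s: int, l: int, k: int,
    generate: Callable[[Tokens, int], Tokens],
    retrieve: Callable[[Tokens, int], List[Tokens]],
    log_prob: Callable[[Tokens, Tokens], float],
) -> List[Tuple[Tokens, Tokens, float]]:
    """Return (prefix, d_i, label) triples for training a predictive reranker."""
    if s < 1 or l < 1:
        raise ValueError("s and l must be positive")
    y = generate(prefix, s)            # s tokens generated from x_{<= s*j}
    query = prefix[-l:]                # last l tokens as the retrieval query
    examples = []
    for d in retrieve(query, k):       # top-k candidate documents d_i
        label = math.exp(log_prob(d + prefix, y))   # p(y | [d_i; x_{<= s*j}])
        examples.append((prefix, d, label))
    return examples


# Hypothetical stand-ins for the LLM and the retriever, for illustration only.
def toy_generate(prefix: Tokens, n: int) -> Tokens:
    return ["paris"][:n]


def toy_retrieve(query: Tokens, k: int) -> List[Tokens]:
    return [["berlin"], ["paris"]][:k]


def toy_log_prob(context: Tokens, target: Tokens, vocab_size: int = 10) -> float:
    counts = Counter(context)
    return sum(math.log((counts[t] + 1) / (len(context) + vocab_size)) for t in target)


examples = build_reranker_examples(["capital", "of", "france"], s=1, l=2, k=2,
                                   generate=toy_generate, retrieve=toy_retrieve,
                                   log_prob=toy_log_prob)
```

The document that makes $y$ more likely gets the larger label, which is the signal the reranker learns to predict from $x_{\le s\cdot j}$ and $d_i$ alone.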
Wrap-up
I have summarized the research findings of this paper into a table, as shown below:
| Conclusion | Reference |
|---|---|
| Predictive Reranking gives a further improvement over retrieval without reranking | Figure 1 |
| BM25 considers only lexical information, yet it outperforms the neural (dense) retrievers (recommended setting: $l=32, s=4$) | Figure 3 |
| In-Context RALM improves LLMs of every size tested | Figure 4 |
| The smaller $s$ is, the better the text generated by the LLM | Figure 5 |
| The last $l$ tokens are enough for the retriever to find a relevant document | Figure 6 |
| Reranking is clearly helpful | Figure 7 |
| Using more than one retrieved document ($>1$) only slightly improves the LLM's generation | Figure 8 |