
Transformer architecture variation: Rotary Position Embedding (RoPE)

In self-attention, the query ($\mathbf q_m$), key ($\mathbf k_n$), and value ($\mathbf v_n$) vectors are computed as follows:

$$ \begin{equation} \begin{aligned} \mathbf q_m&=f_q(\mathbf x_m,m)\\ \mathbf k_n&=f_k(\mathbf x_n,n)\\ \mathbf v_n&=f_v(\mathbf x_n,n) \end{aligned} \end{equation} $$

Here, $\mathbf x_i$ is the embedding of the $i$-th token, and $m$ and $n$ denote positions in the sequence.

The attention score between position $m$ and $n$ is computed as:

$$ \alpha_{m,n}=\frac{\exp\left(\frac{\mathbf q_m^T\mathbf k_n}{\sqrt d}\right)}{\sum_{j=1}^N\exp\left(\frac{\mathbf q_m^T\mathbf k_j}{\sqrt d}\right)} $$

Here, $N$ denotes the length of the input sequence and $d$ the dimension of the query and key vectors.

Each position $i$ has a corresponding attention score $\alpha_{m,i}$. Using these attention scores, the updated representation ($\mathbf o_m$) for position $m$ is computed as

$$ \mathbf o_m=\sum_{i=1}^N\alpha_{m,i}\mathbf v_i $$
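As a concrete check of the two formulas above, here is a minimal plain-Python sketch of the attention weights $\alpha_{m,i}$ and the output $\mathbf o_m$ for a single query position. It assumes a single head, no batching, and no causal mask; the function name `attention_row` and the toy vectors are hypothetical.

```python
import math

def attention_row(q_m, keys, values):
    """Return (alpha, o_m) for one query vector q_m.

    keys and values are lists of N vectors (lists of floats).
    """
    d = len(q_m)
    # Scaled dot products q_m^T k_j / sqrt(d) for every position j
    scores = [sum(a * b for a, b in zip(q_m, k_j)) / math.sqrt(d) for k_j in keys]
    # Softmax; subtracting the max is a standard stability trick
    # that leaves the result unchanged
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # o_m = sum_i alpha_{m,i} v_i, computed coordinate by coordinate
    o_m = [sum(alpha[i] * values[i][c] for i in range(len(values)))
           for c in range(len(values[0]))]
    return alpha, o_m


# Toy example with N = 2 positions and d = 2
q = [1.0, 0.0]
ks = [[1.0, 0.0], [0.0, 1.0]]
vs = [[1.0, 2.0], [3.0, 4.0]]
alpha, o = attention_row(q, ks, vs)
```

The query is aligned with the first key, so the first position receives the larger weight.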

The core of self-attention is the score computed between a query and a key vector. The authors of RoPE look for functions $f_q$ and $f_k$ such that the inner product of the query and the key depends on the positions only through their relative distance $m-n$ (in addition to the token embeddings $\mathbf x_m, \mathbf x_n$).

That is, there should be some function $g$ such that

$$ \langle f_q(\mathbf x_m,m),f_k(\mathbf x_n,n)\rangle=g(\mathbf x_m,\mathbf x_n,m-n) $$

When the vectors are 2-dimensional ($d=2$), a simple construction satisfies this requirement: first apply a linear projection, then multiply the result by a rotation matrix $\mathbf R_m$ whose angle depends on the position. For example, $f_q$ is defined as

$$ \begin{split} f_q(\mathbf x_m,m)&=\mathbf{R}_m(\mathbf W_q\mathbf x_m) \\ &=\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} (\mathbf W_q\mathbf x_m) =\mathbf q_m \end{split} $$

In other words, the projected vector at position $m$ is rotated by the angle $m\theta$, which grows linearly with the position.

In the same manner, $f_k$ is defined as

$$ \begin{split} f_k(\mathbf x_n,n)&=\mathbf{R}_n(\mathbf W_k\mathbf x_n) \\ &=\begin{pmatrix} \cos n\theta & -\sin n\theta \\ \sin n\theta & \cos n\theta \end{pmatrix} (\mathbf W_k\mathbf x_n) =\mathbf k_n \end{split} $$
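The 2-D construction of $f_q$ and $f_k$ can be sketched in plain Python: project with a weight matrix, then rotate by position times $\theta$. The weight matrices, embeddings, and $\theta=0.5$ below are arbitrary illustrative values, not taken from any trained model, and the helper names are hypothetical.

```python
import math

def rotation_2d(angle):
    """2x2 rotation matrix for the given angle in radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matvec(M, v):
    """Plain-Python matrix-vector product M v."""
    return [sum(M[r][c] * v[c] for c in range(len(v))) for r in range(len(M))]

def f_rope_2d(W, x, pos, theta):
    """R_pos (W x): project the embedding with W, then rotate by pos * theta."""
    return matvec(rotation_2d(pos * theta), matvec(W, x))

# Hypothetical example values
W_q = [[1.0, 0.5], [-0.2, 0.8]]
W_k = [[0.3, -0.7], [0.9, 0.1]]
theta = 0.5
x_m = [1.0, 2.0]
x_n = [0.5, -1.0]

q_m = f_rope_2d(W_q, x_m, pos=3, theta=theta)  # rotated by 3 * theta
k_n = f_rope_2d(W_k, x_n, pos=5, theta=theta)  # rotated by 5 * theta
```

Because a rotation preserves length, $\|\mathbf q_m\| = \|\mathbf W_q\mathbf x_m\|$; position is encoded purely in the direction of the vector.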

The inner product between $\mathbf q_m$ and $\mathbf k_n$ is then:

$$ \begin{split} \mathbf q_m^T\mathbf k_n&=(\mathbf R_m(\mathbf W_q\mathbf x_m))^T(\mathbf R_n(\mathbf W_k\mathbf x_n))\\ &=(\mathbf W_q\mathbf x_m)^T\mathbf R_m^T\mathbf R_n(\mathbf W_k\mathbf x_n) \\ &=(\mathbf W_q\mathbf x_m)^T\mathbf R_m^{-1}\mathbf R_n(\mathbf W_k\mathbf x_n) \\ &=(\mathbf W_q\mathbf x_m)^T\mathbf R_{-m}\mathbf R_n(\mathbf W_k\mathbf x_n) \\ &=(\mathbf W_q\mathbf x_m)^T\mathbf R_{n-m}(\mathbf W_k\mathbf x_n) \end{split} $$

The result depends on the two positions only through the offset $n-m$, so the attention score encodes relative position, as required.
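A quick numerical check of the derivation, as a plain-Python sketch: keeping the projected vectors fixed, shifting both positions by the same amount leaves $\mathbf q_m^T\mathbf k_n$ unchanged, and the value matches the last line of the derivation. The vectors `wq_x` and `wk_x` are arbitrary stand-ins for $\mathbf W_q\mathbf x_m$ and $\mathbf W_k\mathbf x_n$, and $\theta=0.5$ is a hypothetical value.

```python
import math

def rotate(v, angle):
    """Multiply a 2-D vector by the rotation matrix for `angle`."""
    c, s = math.cos(angle), math.sin(angle)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1]]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

theta = 0.5          # hypothetical frequency
wq_x = [0.3, -1.2]   # stand-in for W_q x_m
wk_x = [0.8, 0.4]    # stand-in for W_k x_n

def score(m, n):
    # q_m^T k_n with q_m = R_m (W_q x_m) and k_n = R_n (W_k x_n)
    return dot(rotate(wq_x, m * theta), rotate(wk_x, n * theta))

def score_relative(m, n):
    # (W_q x_m)^T R_{n-m} (W_k x_n): the last line of the derivation
    return dot(wq_x, rotate(wk_x, (n - m) * theta))

# Same offset n - m = 2 at different absolute positions
s1 = score(1, 3)
s2 = score(10, 12)
s3 = score_relative(1, 3)
```

In a real model the token embeddings also differ between positions; holding them fixed isolates the positional part of the score.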

Tip

The derivation here uses the following properties of the rotation matrix:

$$ \mathbf R_m\mathbf R_n=\mathbf R_{m+n} $$

and

$$ \mathbf R^{-1}=\mathbf R^T $$

and

$$ \mathbf R_m^{-1}=\mathbf R_{-m} $$
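The three properties can be spot-checked numerically. This plain-Python sketch uses $2\times 2$ matrices, where $\mathbf R_m$ means rotation by $m\theta$; the values of $\theta$, $m$, and $n$ are arbitrary, and the helper names are hypothetical.

```python
import math

def R(angle):
    """2x2 rotation matrix for `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def transpose(A):
    return [[A[j][i] for j in range(2)] for i in range(2)]

def close(A, B, tol=1e-12):
    return all(abs(A[i][j] - B[i][j]) < tol for i in range(2) for j in range(2))

theta, m, n = 0.3, 4, 7   # arbitrary example values
Rm, Rn = R(m * theta), R(n * theta)
I = [[1.0, 0.0], [0.0, 1.0]]

prop_compose = close(matmul(Rm, Rn), R((m + n) * theta))  # R_m R_n = R_{m+n}
prop_orthogonal = close(matmul(Rm, transpose(Rm)), I)     # R R^T = I, so R^{-1} = R^T
prop_negate = close(matmul(Rm, R(-m * theta)), I)         # R_m^{-1} = R_{-m}
```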

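In practice, however, the projected query and key vectors have dimension $d$ (an even number), not 2. How should the rotation matrix look then? The answer is to split the $d$ coordinates into $d/2$ pairs and rotate each pair by its own angle $m\theta_i$, which gives a block-diagonal matrix (written $\mathbf R$ for short, with the position $m$ implicit):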

$$ \mathbf R=\begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix} $$

Tip

The matrix $\mathbf R$ here is still a rotation matrix: each $2\times 2$ diagonal block is a rotation, so $\mathbf R\mathbf R^{T}=\mathbf I$ and $\det (\mathbf R)=1$.
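The block-diagonal matrix can be built explicitly to confirm the claim. This plain-Python sketch uses 0-based indices (block $i$ rotates coordinates $2i$ and $2i+1$), and the angles passed in are hypothetical example values rather than the formula below.

```python
import math

def rope_matrix_adjacent(m, thetas):
    """Dense d x d block-diagonal RoPE matrix for adjacent pairing.

    Block i (0-based) rotates coordinates (2i, 2i+1) by m * thetas[i].
    """
    d = 2 * len(thetas)
    R = [[0.0] * d for _ in range(d)]
    for i, th in enumerate(thetas):
        c, s = math.cos(m * th), math.sin(m * th)
        r = 2 * i
        R[r][r], R[r][r + 1] = c, -s
        R[r + 1][r], R[r + 1][r + 1] = s, c
    return R

def is_orthogonal(R, tol=1e-12):
    """Check R R^T = I entry by entry."""
    d = len(R)
    for i in range(d):
        for j in range(d):
            entry = sum(R[i][k] * R[j][k] for k in range(d))  # (R R^T)_{ij}
            if abs(entry - (1.0 if i == j else 0.0)) > tol:
                return False
    return True

# Hypothetical angles, giving d = 6
R = rope_matrix_adjacent(m=5, thetas=[1.0, 0.1, 0.01])
```

Only $2d$ of the $d^2$ entries are nonzero (four per block), which is why the dense product is wasteful.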

The angles $\theta_i$ are defined as

$$ \theta_i=10000^{-2(i-1)/d} $$

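where $i=1,2,\ldots,d/2$. The first pair therefore rotates fastest with position ($\theta_1=1$), and each subsequent pair rotates more slowly.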
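A short sketch computing the angles with the base 10000 from the formula above (exposed here as a hypothetical `base` parameter):

```python
def rope_thetas(d, base=10000.0):
    """theta_i = base^(-2(i-1)/d) for i = 1, ..., d/2 (d must be even)."""
    if d % 2 != 0:
        raise ValueError("d must be even")
    return [base ** (-2 * (i - 1) / d) for i in range(1, d // 2 + 1)]

thetas = rope_thetas(8)
```

With $d=8$ the angles are approximately $1, 0.1, 0.01, 0.001$, so the last pair needs about 1000 times more positions than the first to complete one full rotation.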

Because $\mathbf R$ is sparse (only $2d$ of its $d^2$ entries are nonzero), a dense matrix multiplication would waste computation. Instead, we use the following equivalent form, in which $\otimes$ denotes element-wise multiplication:

$$ \mathbf R\mathbf x= \begin{pmatrix} x_1\\x_2\\x_3\\x_4\\\vdots\\x_{d-1}\\x_d \end{pmatrix}\otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1\\\cos m\theta_2\\\cos m\theta_2\\\vdots\\\cos m\theta_{d/2}\\\cos m\theta_{d/2} \end{pmatrix}+ \begin{pmatrix} -x_2\\x_1\\-x_4\\x_3\\\vdots\\-x_{d}\\x_{d-1} \end{pmatrix}\otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1\\\sin m\theta_2\\\sin m\theta_2\\\vdots\\\sin m\theta_{d/2}\\\sin m\theta_{d/2} \end{pmatrix} $$
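The element-wise form can be checked against an explicit multiplication by the dense block-diagonal $\mathbf R$. This is a plain-Python sketch with hypothetical function names and an arbitrary test vector; the element-wise version does a constant amount of work per coordinate, while the dense product costs $O(d^2)$.

```python
import math

def rope_thetas(d, base=10000.0):
    return [base ** (-2 * (i - 1) / d) for i in range(1, d // 2 + 1)]

def rope_adjacent(x, m, base=10000.0):
    """x * cos + swapped(x) * sin for pairs (x1, x2), (x3, x4), ..."""
    d = len(x)
    thetas = rope_thetas(d, base)
    # (cos m*th_1, cos m*th_1, cos m*th_2, cos m*th_2, ...)
    cos_v = [math.cos(m * thetas[j // 2]) for j in range(d)]
    sin_v = [math.sin(m * thetas[j // 2]) for j in range(d)]
    # (-x2, x1, -x4, x3, ...)
    swapped = []
    for j in range(0, d, 2):
        swapped += [-x[j + 1], x[j]]
    return [x[j] * cos_v[j] + swapped[j] * sin_v[j] for j in range(d)]

def rope_adjacent_dense(x, m, base=10000.0):
    """Reference: build the full block-diagonal R and multiply R x."""
    d = len(x)
    R = [[0.0] * d for _ in range(d)]
    for i, th in enumerate(rope_thetas(d, base)):
        c, s = math.cos(m * th), math.sin(m * th)
        R[2 * i][2 * i], R[2 * i][2 * i + 1] = c, -s
        R[2 * i + 1][2 * i], R[2 * i + 1][2 * i + 1] = s, c
    return [sum(R[r][k] * x[k] for k in range(d)) for r in range(d)]

x = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5]
fast = rope_adjacent(x, m=7)
dense = rope_adjacent_dense(x, m=7)
```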

Two pairing conventions are used in practice:

  • Adjacent pairing: pair $i$ consists of $x_{2i-1}, x_{2i}$ for $i=1,\ldots,d/2$. This is the convention described above. You can refer to LLaMA’s rotary embedding implementation here.
  • Half-split pairing: pair $i$ consists of $x_{i}, x_{i+d/2}$ for $i=1,\ldots,d/2$. The code can be found in this amazing blog.
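To make the two conventions concrete, this sketch lists the 1-based coordinate indices of each rotated pair for $d=8$; in both conventions pair $i$ is rotated by $m\theta_i$. The function names are hypothetical.

```python
def adjacent_pairs(d):
    """1-based pairs for adjacent pairing: (x1, x2), (x3, x4), ..."""
    return [(2 * i - 1, 2 * i) for i in range(1, d // 2 + 1)]

def half_split_pairs(d):
    """1-based pairs for half-split pairing: (x1, x_{d/2+1}), (x2, x_{d/2+2}), ..."""
    return [(i, i + d // 2) for i in range(1, d // 2 + 1)]

print(adjacent_pairs(8))    # [(1, 2), (3, 4), (5, 6), (7, 8)]
print(half_split_pairs(8))  # [(1, 5), (2, 6), (3, 7), (4, 8)]
```

The two conventions rotate different coordinate pairs, so they produce different outputs for the same input vector. A checkpoint trained with one convention therefore cannot be used with the other unless the query and key projection weights are permuted to match.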

To see why the second convention has the same relative-position property, let’s write out its matrix $\mathbf R$ explicitly:

$$ \mathbf R\mathbf x= \begin{pmatrix} \cos m\theta_1 & 0 & \cdots & 0 & -\sin m\theta_1 & 0 & \cdots & 0 \\ 0 & \cos m\theta_2 & \cdots & 0 & 0 & -\sin m\theta_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \cos m\theta_{d/2} & 0 & 0 & \cdots & -\sin m\theta_{d/2} \\ \sin m\theta_1 & 0 & \cdots & 0 & \cos m\theta_1 & 0 & \cdots & 0 \\ 0 & \sin m\theta_2 & \cdots & 0 & 0 & \cos m\theta_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \sin m\theta_{d/2} & 0 & 0 & \cdots & \cos m\theta_{d/2} \end{pmatrix} \begin{pmatrix} x_1\\x_2\\\vdots\\x_{d/2}\\x_{d/2+1}\\x_{d/2+2}\\\vdots\\x_d \end{pmatrix} $$

This $\mathbf R$ is also a rotation matrix:

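  1. $\mathbf R\mathbf R^T=\mathbf I$. Rows $i$ and $i+d/2$ are the only rows with nonzero entries in columns $i$ and $i+d/2$, where they hold $(\cos m\theta_i, -\sin m\theta_i)$ and $(\sin m\theta_i, \cos m\theta_i)$ respectively, so every row has unit length and all rows are mutually orthogonal.
  2. $\det (\mathbf R) = 1$. Reordering the coordinates as $(x_1, x_{d/2+1}, x_2, x_{d/2+2}, \ldots)$ turns $\mathbf R$ into the block-diagonal matrix of the first implementation. This reordering is a permutation matrix $\mathbf P$ applied to both the rows and the columns, i.e. $\mathbf P\mathbf R\mathbf P^T$. Every row swap is matched by the same column swap, so the total number of swaps is even, and since each swap multiplies the determinant by $-1$, the determinant is unchanged: $\det(\mathbf P\mathbf R\mathbf P^T)=\det(\mathbf P)^2\det(\mathbf R)=\det(\mathbf R)$. Because the first implementation's matrix has determinant 1, so does $\mathbf R$.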
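To make the argument concrete, this plain-Python sketch (hypothetical helper names; $d=6$ and $m=3$ chosen arbitrarily; 0-based indices) builds the dense half-split matrix, checks $\mathbf R\mathbf R^T=\mathbf I$, and confirms that applying the reordering $(x_1, x_{d/2+1}, x_2, x_{d/2+2}, \ldots)$ to both rows and columns yields exactly the block-diagonal matrix of the first implementation.

```python
import math

def thetas_for(d, base=10000.0):
    return [base ** (-2 * (i - 1) / d) for i in range(1, d // 2 + 1)]

def rope_matrix_half_split(m, d, base=10000.0):
    """Dense R for half-split pairing: pair i couples coordinates i and i + d/2."""
    h = d // 2
    R = [[0.0] * d for _ in range(d)]
    for i, th in enumerate(thetas_for(d, base)):
        c, s = math.cos(m * th), math.sin(m * th)
        R[i][i], R[i][i + h] = c, -s
        R[i + h][i], R[i + h][i + h] = s, c
    return R

def rope_matrix_adjacent(m, d, base=10000.0):
    """Dense block-diagonal R for adjacent pairing."""
    R = [[0.0] * d for _ in range(d)]
    for i, th in enumerate(thetas_for(d, base)):
        c, s = math.cos(m * th), math.sin(m * th)
        R[2 * i][2 * i], R[2 * i][2 * i + 1] = c, -s
        R[2 * i + 1][2 * i], R[2 * i + 1][2 * i + 1] = s, c
    return R

def is_orthogonal(R, tol=1e-12):
    d = len(R)
    return all(abs(sum(R[i][k] * R[j][k] for k in range(d)) - (i == j)) <= tol
               for i in range(d) for j in range(d))

d, m = 6, 3
H = rope_matrix_half_split(m, d)
A = rope_matrix_adjacent(m, d)

# perm[j]: half-split index placed at position j of the adjacent layout,
# i.e. the order (x1, x_{h+1}, x2, x_{h+2}, ...) in 0-based indices
h = d // 2
perm = [j // 2 if j % 2 == 0 else h + j // 2 for j in range(d)]

# P H P^T, whose (a, b) entry is H[perm[a]][perm[b]]
PHPt = [[H[perm[a]][perm[b]] for b in range(d)] for a in range(d)]
```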

As with the first implementation, the product can be computed with an equivalent element-wise form that avoids materializing $\mathbf R$:

$$ \mathbf R\mathbf x= \begin{pmatrix} x_1\\x_2\\\vdots\\x_{d/2}\\x_{d/2+1}\\x_{d/2+2}\\\vdots\\x_d \end{pmatrix}\otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_2\\\vdots\\\cos m\theta_{d/2}\\\cos m\theta_{1}\\\cos m\theta_{2}\\\vdots\\\cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_{d/2+1}\\-x_{d/2+2}\\\vdots\\-x_d\\x_1\\x_2\\\vdots\\x_{d/2} \end{pmatrix}\otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_2\\\vdots\\\sin m\theta_{d/2}\\\sin m\theta_{1}\\\sin m\theta_{2}\\\vdots\\\sin m\theta_{d/2} \end{pmatrix} $$
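A plain-Python sketch of the half-split element-wise form, checking that the score between a rotated query and key is unchanged when both positions shift by the same amount. The vector $(-x_{d/2+1},\ldots,-x_d,x_1,\ldots,x_{d/2})$ is what some open-source code (for example Hugging Face Transformers) computes with a helper named `rotate_half`; the function names here are hypothetical, and `q` and `k` are arbitrary stand-ins for $\mathbf W_q\mathbf x_m$ and $\mathbf W_k\mathbf x_n$.

```python
import math

def rope_thetas(d, base=10000.0):
    return [base ** (-2 * (i - 1) / d) for i in range(1, d // 2 + 1)]

def rope_half_split(x, m, base=10000.0):
    """x * cos + rotate_half(x) * sin for half-split pairing."""
    d = len(x)
    h = d // 2
    thetas = rope_thetas(d, base)
    # (cos m*th_1, ..., cos m*th_{d/2}, cos m*th_1, ..., cos m*th_{d/2})
    cos_v = [math.cos(m * th) for th in thetas] * 2
    sin_v = [math.sin(m * th) for th in thetas] * 2
    # rotate_half: (-x_{h+1}, ..., -x_d, x_1, ..., x_h)
    rotated = [-v for v in x[h:]] + x[:h]
    return [a * c + b * s for a, b, c, s in zip(x, rotated, cos_v, sin_v)]

def dot(a, b):
    return sum(p * r for p, r in zip(a, b))

q = [0.2, -0.4, 1.1, 0.7, -0.3, 0.9]   # stand-in for W_q x_m
k = [0.5, 0.1, -0.8, 0.6, 0.4, -0.2]   # stand-in for W_k x_n

# The offset n - m is 4 in both cases, so the scores should match
s_near = dot(rope_half_split(q, 2), rope_half_split(k, 6))
s_far = dot(rope_half_split(q, 100), rope_half_split(k, 104))
```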