One for all: the torch.einsum API

PyTorch provides several APIs for matrix multiplication, which makes them hard to remember. Moreover, many of them require explicit dimension manipulation (e.g., permuting or reshaping) before the call.

Is there a single magic API that covers all these use cases? One unified solution is the torch.einsum API.

The syntax of torch.einsum is

```python
torch.einsum(<einstein_notations>, input1, input2, ...)
```

where input1, input2, ... are the tensors involved in the computation, and <einstein_notations> is a string following the Einstein summation convention:

```text
"<dimensions_of_input_1>,<dimensions_of_input_2>-><dimensions_of_output>"
```

This string representation can be divided into two parts.

  • On the left-hand side of -> is <dimensions_of_input_*>, which, as the name suggests, represents the dimension information of each input.
  • On the right-hand side of -> is <dimensions_of_output>, specifying the expected output dimensions.

The dimension information is represented as a string of letters. For example, the dimensions of $\mathbf M\in\mathcal R^{i\times j}$ are written as ij.

When using the Einstein Summation Convention, the following rules apply:

  • If an index letter appears in the inputs but not in the output, that dimension is summed over; when the index is shared by two inputs, the corresponding elements are multiplied before the summation (as in matrix multiplication).
  • If a shared index also appears in the output, no summation takes place along it; it acts as an element-wise (or batch) dimension.
  • Within a single input, a repeated index refers to the diagonal along those dimensions (for example, "ii->i" extracts the diagonal of a square matrix).
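A small sketch (the example tensors here are hypothetical) verifies these rules against the equivalent plain PyTorch operations:

```python
import torch

M = torch.arange(4, dtype=torch.float32).reshape(2, 2)
N = torch.ones(2, 2)

# Shared index j is absent from the output: multiply and sum over j,
# which is exactly matrix multiplication.
assert torch.allclose(torch.einsum("ij,jk->ik", M, N), M @ N)

# Indices i and j are both omitted from the output: the matrix is fully summed.
assert torch.allclose(torch.einsum("ij->", M), M.sum())

# Shared indices that appear in the output are kept: element-wise product.
assert torch.allclose(torch.einsum("ij,ij->ij", M, N), M * N)
```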
Info

The torch.einsum API is best taught by examples.

Let’s take matrix multiplication as an example, with inputs $\mathbf M\in\mathcal R^{i\times j}$ and $\mathbf N\in\mathcal R^{j\times k}$. The corresponding call is

```python
torch.einsum("ij,jk->ik", M, N)
```

To sum the entries of each row of matrix $\mathbf M$ (summing over index $j$, resulting in a column vector of length $i$):

```python
torch.einsum("ij->i", M)
```

Similarly, summing the entries of each column of $\mathbf M$ (summing over index $i$, resulting in a row vector of length $j$) is

```python
torch.einsum("ij->j", M)
```
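To make the direction of each reduction concrete, here is a small sketch with a hypothetical 2×3 matrix, checked against torch.sum:

```python
import torch

M = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

row_sums = torch.einsum("ij->i", M)  # sum over j: one entry per row
col_sums = torch.einsum("ij->j", M)  # sum over i: one entry per column

assert torch.equal(row_sums, M.sum(dim=1))  # tensor([ 3., 12.])
assert torch.equal(col_sums, M.sum(dim=0))  # tensor([3., 5., 7.])
```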

The torch.einsum API can also transpose a matrix.

```python
torch.einsum("ij->ji", M)
```

It also supports the element-wise product.

```python
torch.einsum("ij,ij->ij", M, M)
```

As a more complex example, let’s examine the batch matrix-matrix product. Here, $\mathbf M$ has dimensions $b\times i\times j$, and $\mathbf N$ has dimensions $b\times j\times k$.

```python
torch.einsum("bij,bjk->bik", M, N)
```
Tip

When specifying dimensions, you can use ... to represent dimensions of no interest. Note that each operand's dimension specification can contain ... at most once. Thus, the batch matrix-matrix product can alternatively be expressed as

```python
torch.einsum("...ij,...jk->...ik", M, N)
```
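As a sketch (the shapes here are hypothetical), the same notation handles any number of leading batch dimensions, and agrees with torch.matmul, which broadcasts batch dimensions the same way:

```python
import torch

M = torch.randn(2, 4, 3, 5)  # two leading batch dimensions: (2, 4)
N = torch.randn(2, 4, 5, 6)

# "..." stands in for the (2, 4) batch dimensions.
out = torch.einsum("...ij,...jk->...ik", M, N)

assert out.shape == (2, 4, 3, 6)
assert torch.allclose(out, torch.matmul(M, N), atol=1e-5)
```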

Let’s use $\mathbf v$ to represent a vector of length $d$. Its inner product with itself can be computed with the following code.

```python
torch.einsum("i,i->", v, v)
```
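With a hypothetical example vector, the result matches torch.dot:

```python
import torch

v = torch.tensor([1.0, 2.0, 3.0])  # hypothetical example vector

# Shared index i is absent from the output, so the products are summed,
# yielding a zero-dimensional (scalar) tensor.
inner = torch.einsum("i,i->", v, v)

assert torch.isclose(inner, torch.dot(v, v))
assert inner.item() == 14.0  # 1*1 + 2*2 + 3*3
```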

I like the torch.einsum API for its simplicity. It gives me the experience of declarative programming: you describe the inputs and the output, and it figures out how to do the calculation :)