One for all: the torch.einsum API
Motivations
In PyTorch, there are multiple APIs for matrix multiplication (such as torch.mm, torch.bmm, and torch.matmul), and it is easy to forget which one applies to which case. Moreover, many of them require explicit dimension manipulation (e.g., permuting or reshaping) before they can be used.
Does there exist a magic API that can cover all the use cases? A potential unified solution is the torch.einsum API.
What is torch.einsum?
The syntax of torch.einsum is
torch.einsum(<einstein_notations>, input1, input2, ...)
where input1, input2, ... are all the tensors involved in the computation, and <einstein_notations> is a string following the Einstein Summation Convention format:
"<dimensions_of_input_1>,<dimensions_of_input_2>-><dimensions_of_output>"
This string representation can be divided into two parts.
- On the left-hand side of -> is <dimensions_of_input_*>, which, as the name suggests, represents the dimension information of each input.
- On the right-hand side of -> is <dimensions_of_output>, specifying the expected output dimensions.
The dimension information is represented as a string of letters. For example, the dimensions of $\mathbf M\in\mathcal R^{i\times j}$ are written as ij.
When using the Einstein Summation Convention, the following rules apply:
- If the same index letter appears in two different input tensors, their corresponding dimensions are multiplied element-wise; if that letter is omitted from the output, the products are also summed over that index.
- If a letter appears in the inputs but not in the output, that dimension is summed over.
- Within a single input, a repeated index refers to the diagonal of those dimensions (for example, ii->i extracts the diagonal of a square matrix, and ii-> computes its trace); see the sketch after this list.
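As a quick illustration of these rules, here is a minimal sketch (the tensor A and its shape are arbitrary choices for demonstration, not part of the API):

```python
import torch

A = torch.randn(4, 4)

# A letter that appears in the input but not in the output is summed over:
# "ij->" sums over both i and j, i.e., the sum of all entries.
assert torch.allclose(torch.einsum("ij->", A), A.sum())

# A repeated index within a single input refers to the diagonal:
# "ii->i" extracts it, and "ii->" sums it (the trace).
assert torch.allclose(torch.einsum("ii->i", A), torch.diagonal(A))
assert torch.allclose(torch.einsum("ii->", A), torch.trace(A))
```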
Various Cases
The torch.einsum API is best taught by examples.
Let’s take matrix multiplication as an example. The inputs are $\mathbf M\in\mathcal R^{i\times j}$ and $\mathbf N\in\mathcal R^{j\times k}$. Then the function call for matrix multiplication is
torch.einsum("ij,jk->ik", M, N)
To sum the entries of each row of matrix $\mathbf M$ (resulting in a column vector of length $i$):
torch.einsum("ij->i", M)
Similarly, summing the entries of each column of $\mathbf M$ (resulting in a row vector of length $j$) is
torch.einsum("ij->j", M)
The torch.einsum API can also transpose a matrix.
torch.einsum("ij->ji", M)
It also supports the element-wise product.
torch.einsum("ij,ij->ij", M, M)
As a more complex example, let’s examine the batch matrix-matrix product. Here, $\mathbf M$ has dimensions $b\times i\times j$ and $\mathbf N$ has dimensions $b\times j\times k$:
torch.einsum("bij,bjk->bik", M, N)
When specifying dimensions, you can use ... to represent dimensions of no interest. Note that each dimension specification can contain ... at most once. Thus, the batch matrix-matrix product can alternatively be expressed by
torch.einsum("...ij,...jk->...ik", M, N)
Let’s use $\mathbf v$ to represent a vector of length $d$. To compute its inner product with itself, use the following code.
torch.einsum("i,i->", v, v)
Wrap-up
I like the torch.einsum API for its simplicity. It gives me the experience of declarative programming: you describe the inputs and the output, and it automatically figures out how to do the calculation :)