神奇的 torch.einsum API
Motivations
在 PyTorch 里面存在着很多跟矩阵乘法、矩阵向量乘法等操作相关的 API,这对记忆来说是一种负担。并且,在使用这些 API 的过程中经常需要对矩阵进行 reshape 等操作,确保维度信息对得上
有没有一个神奇的 API 可以覆盖所有可能的情况呢?答案是有的,那就是 torch.einsum
torch.einsum 是什么?
torch.einsum
的用法是
torch.einsum(<einstein_notations>, input1, input2, ...)
其中 input1, input2, ...
是所有参与运算的 tensor, <einstein_notations>
指的是爱因斯坦求和约定,是 1 个字符串,遵循如下的格式
"<dimensions_of_input_1>,<dimensions_of_input_2>-><dimensions_of_output>"
可以将上面的式子拆成两个部分
- 在
->
左侧是<dimensions_of_intput_*>
,即每个输入的维度信息 - 在
->
右侧是<dimensions_of_output>
,即预期输出的维度信息
这里说的维度信息是用英文字母表示的字符串,比如 $\mathbf M\in\mathcal R^{i\times j}$ 就写成 ij
写爱因斯坦求和约定的时候记住如下几点
- 不同输入的维度信息如果用到了相同的字母(比如 $\mathbf M_{ik}\mathbf N_{kj}$ 里面的 $k$),那么对应的维度上会先按元素相乘然后求和,并且这些字母不会在输出的维度信息里面出现
- 如果输入的维度信息用到了某个字母但是这些字母没有在输出的维度信息上出现,那么这些字母上会做求和
- 同一个项(输入或者输出)用到的字母不能有重复
Various Cases
通过例子学习是最快上手 torch.einsum
的方法
以矩阵乘法为例,假设输入是 $\mathbf M\in\mathcal R^{i\times j}$ 和 $\mathbf N\in\mathcal R^{j\times k}$,那么矩阵乘法可以写为
torch.einsum("ik,kj->ij", M, N)
如果想要对矩阵 $\mathbf M$ 进行行求和,可以用
torch.einsum("ij->i", M)
同理,列求和是
torch.einsum("ij->i", M)
也可以不用求和的功能,比如只是想要转置一下矩阵 $\mathbf M$
torch.einsum("ij->ji", M)
element-wise 相乘也是支持的
torch.einsum("ij,ij->ij", M, M)
在深度学习中,经常会遇到的是 batch matrix-matrix product。现在假设矩阵 $\mathbf M$ 的维度变成了 $b\times i\times j$,矩阵 $\mathbf N$ 的维度变成了 $b\times j\times k$,不难写出如下的代码
torch.einsum("bij,bjk->bik", M, N)
在书写每个项(输入或者输出)维度信息的时候,都可以用 ...
表示不关心的维度,但每个项最多只能用一次。,根据这个技巧,batch-matrix-matrix product 的代码也可以写为
torch.einsum("...ij,...jk->...ik", M, N)
现在假设有向量 $\mathbf v\in\mathcal R^{d}$,想要计算它的 内积 可以这么写
torch.einsum("i,i->", v, v)
总结
就我个人而言,我很喜欢 torch.einsum
这个 API,因为写爱因斯坦求和记号约定的时候其实是在做声明式编程,就像你写 SQL 一样。此刻的你只会关心输入、输出是什么,而不关心具体要用什么操作完成,即关心的是 what,不关心的是 how :)