einsum方法详解(爱因斯坦求和)
einsum方法详解(爱因斯坦求和)einsum是pytorch、numpy中一个十分优雅的方法,如果利用得当,可完全代替所有其他的矩阵计算方法,不过这需要一定的学习成本。本文旨在详细解读einsum方法的原理,并给出一些基本示例。一、爱因斯坦求和爱因斯坦求和是一种对求和公式简洁高效的记法,其原则是当变量下标重复出现时,即可省略繁琐的求和符号。比如求和公式:∑i=1naibi=a1b1+a2b2+
einsum方法详解(爱因斯坦求和)
einsum是pytorch、numpy中一个十分优雅的方法,如果利用得当,可完全代替所有其他的矩阵计算方法,不过这需要一定的学习成本。本文旨在详细解读einsum方法的原理,并给出一些基本示例。
一、爱因斯坦求和
爱因斯坦求和是一种对求和公式简洁高效的记法,其原则是当变量下标重复出现时,即可省略繁琐的求和符号。
比如求和公式:
∑
i
=
1
n
a
i
b
i
=
a
1
b
1
+
a
2
b
2
+
.
.
.
+
a
n
b
n
\begin{aligned} \sum_{i=1}^na_ib_i=a_1b_1+a_2b_2+...+a_nb_n \end{aligned}
i=1∑naibi=a1b1+a2b2+...+anbn
其中变量
a
a
a与变量
b
b
b的下标是相同的(即重复出现),则可将其记为:
a
i
b
i
=
∑
i
=
1
n
a
i
b
i
\begin{aligned} a_ib_i=\sum_{i=1}^na_ib_i \end{aligned}
aibi=i=1∑naibi
二、einsum方法原理
einsum方法正是利用了爱因斯坦求和简介高效的表示方法,从而可以驾驭任何复杂的矩阵计算操作。基本的框架如下:
C = einsum('ij,jk->ik', A, B)
上述操作表示矩阵A与矩阵B的点积。输入的参数分为两部分,前面表示计算操作的字符串,后面是以逗号隔开的操作对象(数量需与前面对应)。其中在计算操作表示中,"->“左边是以逗号隔开的下标索引,重复出现的索引即是需要爱因斯坦求和的;”->"右边的是最后输出的结果形式。
以上式为例,其计算公式为:
C
i
k
=
∑
j
A
i
j
B
j
k
C_{ik} = \sum_jA_{ij}B_{jk}
Cik=∑jAijBjk,其等价于矩阵A与B的点积。
这里有几条原则需要注意,之后也会和结合示例进行详解:
- "->"左边的是对应维度,以逗号隔开
- "->"右边的是最终output的形式
- 如果符号"->"被省略则代表输出为整体求和
- "…"表示省略之前或之后的所有维度
- einsum中涉及到的计算操作有很多,包括但不限于点积、对应元素相乘、求和、转置等
三、一些示例
einsum方法在numpy和pytorch中均有内置,这里以pytorch为例,首先定义一些需要用到的变量:
import torch
from torch import einsum
a = torch.ones(3,4)
b = torch.ones(4,5)
c = torch.ones(6,7,8)
d = torch.ones(3,4)
x, y = torch.randn(5), torch.randn(5)
- 计算矩阵所有元素之和
einsum('i,j', a) # 等价于einsum('i,j->', a)
einsum('i,j,k', c)
- 计算矩阵的迹
einsum('ii', a)
- 获取矩阵对角线元素组成的向量
einsum('ii->i', a)
- 向量相乘得到矩阵
einsum('i,j->ij', x, y)
- 矩阵点积
einsum('ij,jk->ik', a, b)
- 矩阵对应元素相乘
einsum('ij,ij->ij', a, d)
- 矩阵的转置
einsum('ijk->ikj', c)
einsum('...jk->...kj', c) # 两种形式等价
- 双线性运算
A = torch.randn(3,5,4)
l = torch.randn(2,5)
r = torch.randn(2,4)
torch.einsum('bn,anm,bm->ba', l, A, r)
- 最后来一个复杂的
a = torch.randn(3,4,5)
b = torch.randn(6,5)
c = torch.randn(6,3)
target = einsum('fti,di,df->dt', a,b,c)
# 把上面的求和过程写成循环的形式方便理解
# 对f和i进行求和
dt = torch.zeros(6,4)
for d in range(6):
for t in range(4):
tmp = 0
for f in range(3):
for i in range(5):
tmp += a[f,t,i].item() * b[d,i].item() * c[d,f].item()
dt[d,t] = tmp
print(target)
print('\n', dt)
# 结果:
tensor([[ 0.5712, -0.9473, 0.1972, 0.6474],
[ 0.6273, -1.3820, 2.0779, -0.5729],
[ 4.0083, -2.9979, 0.9493, 0.8467],
[15.9985, -7.1897, 1.1957, 2.7394],
[ 0.4298, -0.6509, 0.5128, 1.7027],
[-4.6843, -4.7948, 1.1389, -6.9389]])
tensor([[ 0.5712, -0.9473, 0.1972, 0.6474],
[ 0.6273, -1.3820, 2.0779, -0.5729],
[ 4.0083, -2.9979, 0.9493, 0.8467],
[15.9985, -7.1897, 1.1957, 2.7394],
[ 0.4298, -0.6509, 0.5128, 1.7027],
[-4.6843, -4.7948, 1.1389, -6.9389]])
更多推荐










所有评论(0)