einsum

两个基本概念

自由索引(Free indices)和求和索引(Summation indices):

三条基本规则

特殊规则

特殊规则有两条:

a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
b = torch.einsum('...ij->...ji', [a])

二次变换(bilinear transformation)

np_a = a.numpy()
np_b = b.numpy()
np_c = c.numpy()
np_out = np.empty((2, 5), dtype=np.float32)

np_out = torch.einsum('ik,jkl,il->ij', [a, b, c]).numpy()
# ik broadcast成ikl
# il broadcast成 ikl
# 'ik,jkl,il->ij'可以理解成'ikl,jkl,ikl->ij'

for i in range(0, 2):
    for j in range(0, 5):
        # 求和索引内循环  这里是 k 和 l
        sum_result = 0
        for k in range(0, 3):
            for l in range(0, 7):
                sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
        np_out[i, j] = sum_result

总结

a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k] 和 b[k, j] 沿着 k 这个维度做内积得到的。

 

 


Revision #1
Created 11 January 2025 09:46:28 by Colin
Updated 3 March 2025 09:52:02 by Colin