点乘 torch.mul(a,b)

点乘是对应位置元素相乘
点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。

python中的广播机制(broadcasting)
broadcasting可以这样理解:如果你有一个(m,n)的矩阵,让它加减乘除一个(1,n)的矩阵,它会被复制m次,成为一个(m,n)的矩阵,然后再逐元素地进行加减乘除操作。同样地对(m,1)的矩阵成立
在这里插入图片描述
图源:https://www.jianshu.com/p/fadd169cd396

  • 当a, b维度满足广播机制时,会自动填充到相同维度相点乘。
    例如:a的维度为(2,3),b的维度为(1,3);
    或者:a的维度为(2,3),b的维度为(2,1)。
  • 当a, b维度不满足广播机制时,要求a和b的维度必须相等。
    a的维度为(1,2),b的维度为(2,3)就会报错:The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
    报错的意思是b中维度为3的位置必须和a中维度为2的位置相匹配,因为a中有个维度1,要想满足广播机制就必须是(1,2)和(2,2),否则就需要满足维度必须相等(2,3)和(2,3)

二维矩阵乘 torch.mm(a,b)

torch.mm(mat1, mat2, out=None)

在这里插入图片描述
二维矩阵乘法要求a、b两个参数的维度要满足乘法要求。
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。

三维矩阵乘 torch.bmm(a,b)

由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供

torch.bmm(bmat1, bmat2, out=None)

在这里插入图片描述
该函数的两个输入必须是三维矩阵并且第一维相同(表示Batch维度), 后两维符合矩阵乘法要求。不支持broadcast操作

高维矩阵乘 torch.matmul(a,b)

torch.matmul(input, other, out=None)

高维矩阵的最后两维满足矩阵乘法要求,前面维数认为是batch_size, 使用广播机制。

主要参考资料:
pytorch之torch中的几种乘法
pytorch 中矩阵乘法总结

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐