torch.bmm 函数

这里只是记录一下,为了以后个人方便查找。

具体参考:pytorch官方文档

torch.bmm(input, mat2, *, deterministic=False, out=None) → Tensor

功能:对 input 和 mat2 矩阵执行批处理矩阵积。

  • input 和 mat2 必须是三维张量,每个张量包含相同数量的矩阵。

输入

  • input tensor 维度:(b×n×m) ;
  • mat2 tensor 维度: (b×m×p) ,

输出

  • out tensor 维度: (b×n×p) .

note:即,先不看第一个维度 b ,然后把后两个维度做矩阵乘法运算 (n * m) . (m * p) -> (n * p)。最终,得out维度 (b * n * p)。

Example

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
print(res.size()) # torch.Size([10, 3, 5])
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐