torch.bmm 函数

这里只是记录一下,为了以后个人方便查找。

具体参考:pytorch官方文档

torch.bmm(input, mat2, *, deterministic=False, out=None) → Tensor

功能:对 input 和 mat2 矩阵执行批处理矩阵积。

  • input 和 mat2 必须是三维张量,每个张量包含相同数量的矩阵。

输入

  • input tensor 维度:(b×n×m) ;
  • mat2 tensor 维度: (b×m×p) ,

输出

  • out tensor 维度: (b×n×p) .

note:即,先不看第一个维度 b ,然后把后两个维度做矩阵乘法运算 (n * m) . (m * p) -> (n * p)。最终,得out维度 (b * n * p)。

Example

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
print(res.size()) # torch.Size([10, 3, 5])
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐