torch.bmm 函数
torch.bmm 函数这里只是记录一下,为了以后个人方便查找。具体参考:pytorch官方文档torch.bmm(input, mat2, *, deterministic=False, out=None) → Tensor功能:对 input 和 mat2 矩阵执行批处理矩阵积。input 和 mat2 必须是三维张量,每个张量包含相同数量的矩阵。输入:input tensor 维度:(b×n
·
torch.bmm 函数
这里只是记录一下,为了以后个人方便查找。
具体参考:pytorch官方文档
torch.bmm(input, mat2, *, deterministic=False, out=None) → Tensor
功能:对 input 和 mat2 矩阵执行批处理矩阵积。
- input 和 mat2 必须是三维张量,每个张量包含相同数量的矩阵。
输入:
- input tensor 维度:(b×n×m) ;
- mat2 tensor 维度: (b×m×p) ,
输出:
- out tensor 维度: (b×n×p) .
note:即,先不看第一个维度 b ,然后把后两个维度做矩阵乘法运算 (n * m) . (m * p) -> (n * p)。最终,得out维度 (b * n * p)。
Example:
input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
print(res.size()) # torch.Size([10, 3, 5])
更多推荐
已为社区贡献2条内容
所有评论(0)