问题描述:

运行torch.sum(torch.mul(users, pos_items), axis=1)时报错:

TypeError: sum() received an invalid combination of arguments - got (Tensor, axis=int), but expected one of:
 * (Tensor input)
 * (Tensor input, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: axis
 * (Tensor input, tuple of ints dim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, Tensor out)

其中,torch.mul函数的功能是两个维度相等的矩阵的对应位相乘,其中users和pos_items的大小都是:torch.Size([1024, 256])。

另外,torch.matmul是tensor的乘法,当输入是二维时和tensor.mm函数用法相同做普通的矩阵乘法,也能用作高维矩阵乘法。

解决方法

按照提示,axis关键字错误,经查,torch中用dim,或者直接把axis关键字去掉,即改成:

torch.sum(torch.mul(users, pos_items), dim=1)

或者

torch.sum(torch.mul(users, pos_items), 1)

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐