在这里插入图片描述
RuntimeError: scatter(): Expected dtype int64 for index

1.报错原因:

scatter要求数据是int64类型,而我在定义tensor时写的是torch.Tensor(x),应该写成torch.LongTensor(x),指定为int64类型。

2.解决方法

找到原数据的定义方式,改!
一般在dtype=np.int64;dtype=np.float32中
(多数定义函数都有dtype属性)
最好int和float的位数要一致

import numpy as np
a = np.random.randint(100, size=(10**6), dtype="int64")
print(a)
print(type(a[0]))

在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐