[报错]-RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) shoul
模型输入的数据类型要与模型参数的数据类型一致。
·
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
模型输入的数据类型要与模型参数的数据类型一致。
torch.cuda.HalfTensor:对应
np.array(x, dtype = 'float32')
torch.cuda.FloatTensor:对应
np.array(x, dtype = 'float16')
参考链接:
https://stackoverflow.com/questions/65029217/runtimeerror-input-type-torch-cuda-floattensor-and-weight-type-torch-cuda-ha
更多推荐
所有评论(0)