pytorch训练模型时dataloader报错“default_collate: batch must contain tensors, numpy arrays, numbers, dicts ”
这个数组里边的dtype都是object,想起来之前在这个数组里边存过str,怪不得现在是object。打印number_info_array,label_info都是ndarray。解决方式手动改类型。
·
问题
pytorch跑的时候报了这个:
TypeError
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
分析
- 代码过程
先继承 torch.utils.data.Dataset 类写一个子类,然后 init 一个 torch.utils.data.DataLoader 对象
结果在调用的时候:
for x, y in trainloader:
print(x.shape)
print(y.shape)
break
TypeError
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
torch.utils.data.Dataset.__getitem__
下面写的是返回numpy数组:
torch.utils.data.Dataset.__getitem__ 下的返回部分
number_info_array, label_info 都是 ndarray
if self.mode != 'test':
label_info = currrent_Series[17:-1].values
return number_info_array, label_info
else:
return number_info_array
打印 number_info_array, label_info都是 ndarray
但其对应的dtype=object!!!
解决方式
这个数组里边的 dtype 都是 object ,想起来之前在这个数组里边存过 str, 怪不得现在是 object
解决方式:手动改类型
array.astype(np.float32)
更多推荐
已为社区贡献3条内容
所有评论(0)