问题

pytorch跑的时候报了这个:

TypeError
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

分析

  • 代码过程
    先继承 torch.utils.data.Dataset 类写一个子类,然后 init 一个 torch.utils.data.DataLoader 对象

结果在调用的时候:

for x, y in trainloader:
    print(x.shape)
    print(y.shape)
    break

TypeError
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

torch.utils.data.Dataset.__getitem__下面写的是返回numpy数组:

torch.utils.data.Dataset.__getitem__ 下的返回部分
number_info_array, label_info 都是 ndarray

if self.mode != 'test':
	label_info  = currrent_Series[17:-1].values 
	return number_info_array, label_info
else:
	return number_info_array

打印 number_info_array, label_info都是 ndarray
但其对应的dtype=object!!!

解决方式

这个数组里边的 dtype 都是 object ,想起来之前在这个数组里边存过 str, 怪不得现在是 object

解决方式:手动改类型
array.astype(np.float32)

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐