TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。


from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99]])
b = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
train_ids = TensorDataset(a, b) 
# 切片输出
print(train_ids[0:2])
print('#' * 30)
# 循环取数据
for x_train, y_label in train_ids:
    print(x_train, y_label)
# DataLoader进行数据封装
print('#' * 30)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):  # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    x_data, label = data
    print(' batch:{0} x_data:{1}  label: {2}'.format(i, x_data, label))   # y data (torch tensor)

运行结果:

注意:TensorDataset 中的参数必须是 tensor 

pytorch中使用torch.utils.data.TensorDataset时报错TypeError: 'int' object is not callable,同时在代码中并没有与TensorDataset重名的函数的解决办法。

使用TensorDataset函数的代码为:

train_dataset = Data.TensorDataset(x_train,y_train)

执行之后发现报错:

TypeError: 'int' object is not callable。但是检查代码发现并没有与TensorDataset重名的函数。

经过研究TensorDataset函数的源码发现,这个函数传入的参数必须是tensor类型的,所以把x_train与y_train转换为tensor类型在执行这个函数就不报错了,更改后的代码为:

train_dataset = Data.TensorDataset(pt.tensor(x_train),pt.tensor(y_train))

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐