参考https://zhuanlan.zhihu.com/p/30934236https://blog.csdn.net/jzwong/article/details/108867297?ops_request_misc=%25257B%252522request%25255Fid%252522%25253A%252522161179811616780269879401%252522%25252C%252522scm%252522%25253A%25252220140713.130102334.pc%25255Fblog.%252522%25257D&request_id=161179811616780269879401&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~blog~first_rank_v1~rank_blog_v1-1-108867297.pc_v1_rank_blog_v1&utm_term=dataloader

Pytorch的数据读取主要包括3个类:

1.Dataset

2.DataLoader

3.DataLoaderIter

这三者的大致是依次封装的关系,1被装进2,2被装进3

一.torch.utils.data.Dataset是一个抽象类,自定义的Dataset需要集成它并实现两个成员方法:

1.__getitem__()

2.__len__()

第一个最为重要,即每次怎么读数据。以图片为例:

def __getitem__(self,index):
    img_path,label=self.data[index].img_path,self.data[index].label
    img=Image.open(img_path)
    return img,label

因为Dataset被封进Dataloader,从这里基本上就知道dataloader返回哪些东西了。如果需要看dataloader的东西,大致可以采用以下几行代码简单看一下dataloader里面的东西:

for inputs,label in dataloader:
    print(inputs,labels)

也就是说,默认情况下遍历dataloader其实就是输出一个batch内的图像和对应的label

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐