shuffle=True用于打乱数据集,每次都会以不同的顺序返回。

from torch.utils.data import Dataset, DataLoader


class DataSet(Dataset):
    def __init__(self, n):
        self.n = n
        self.data = [i for i in range(n)]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]


data_set = DataSet(32)
data_loader = DataLoader(data_set, batch_size=8, shuffle=True)
for num in data_loader:
    print(num, end='\t')
print()
for num in data_loader:
    print(num, end='\t')
print()
for num in data_loader:
    print(num, end='\t')

实验结果:

tensor([ 7, 23, 19,  4, 22, 26, 21, 31])	tensor([27, 14, 13, 12,  2, 24,  5, 10])	tensor([ 0, 15, 16, 30, 25,  8,  6, 29])	tensor([11,  3,  1, 18,  9, 20, 17, 28])	
tensor([10, 11, 22, 28, 24, 19, 31,  8])	tensor([15,  2,  9,  1, 20, 14, 23, 16])	tensor([ 5, 25,  4,  6, 21, 30, 18, 27])	tensor([ 3, 13, 17, 12, 29,  0,  7, 26])	
tensor([ 1,  2, 13, 20, 11, 19,  9, 22])	tensor([ 0, 14, 25, 24, 27, 31, 12, 28])	tensor([10,  4, 23, 15, 21, 30,  6, 16])	tensor([ 7, 17,  8, 26,  3, 29, 18,  5])

如果是shuffle=False的话,实验结果

tensor([0, 1, 2, 3, 4, 5, 6, 7])	tensor([ 8,  9, 10, 11, 12, 13, 14, 15])	tensor([16, 17, 18, 19, 20, 21, 22, 23])	tensor([24, 25, 26, 27, 28, 29, 30, 31])	
tensor([0, 1, 2, 3, 4, 5, 6, 7])	tensor([ 8,  9, 10, 11, 12, 13, 14, 15])	tensor([16, 17, 18, 19, 20, 21, 22, 23])	tensor([24, 25, 26, 27, 28, 29, 30, 31])	
tensor([0, 1, 2, 3, 4, 5, 6, 7])	tensor([ 8,  9, 10, 11, 12, 13, 14, 15])	tensor([16, 17, 18, 19, 20, 21, 22, 23])	tensor([24, 25, 26, 27, 28, 29, 30, 31])
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐