What does torch.utils.data.distributed.DistributedSampler do during distributed training?

Code used in the experiments
import os
import sys

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision

from torch.utils.data import Dataset, DataLoader

import numpy as np

# Toy dataset: each item is a random image, a random label, and its own index,
# so the printed indices show exactly which samples each rank receives.
class InnerDS(Dataset):
    def __init__(self, n=8):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, item):
        np_img = np.random.rand(3,224,224)
        image = torch.from_numpy(np_img).float()
        label = np.random.randint(0,9)
        return image, label, item


# torchrun / torch.distributed.launch export these variables for every spawned process
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])

# NCCL backend; MASTER_ADDR / MASTER_PORT are taken from the environment set by the launcher
dist.init_process_group('nccl', world_size=world_size, rank=rank)

# bind this process to its own GPU
torch.cuda.set_device(local_rank)


# case 1: sampler created but not passed to the DataLoader
# ds = InnerDS(8)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
# dataloader = DataLoader(ds, batch_size=4, drop_last=True)

# case 2: DistributedSampler passed to the DataLoader via sampler=
# ds = InnerDS(8)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
# dataloader = DataLoader(ds, batch_size=4, sampler=sampler, drop_last=True)

# case 3: same as case 2, plus sampler.set_epoch(epoch) in the loop below (case 3+)
# ds = InnerDS(8)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
# dataloader = DataLoader(ds, batch_size=4, sampler=sampler, drop_last=True)

# case 4: 6 samples, drop_last=False on the DataLoader
# ds = InnerDS(6)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
# dataloader = DataLoader(ds, batch_size=4, sampler=sampler, drop_last=False)


# case 5: 5 samples -> DistributedSampler pads the index list to 6
# ds = InnerDS(5)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
# dataloader = DataLoader(ds, batch_size=4, sampler=sampler, drop_last=False)

# case 6: 10 samples -> the last per-rank batch contains a single sample
# ds = InnerDS(10)
# sampler = torch.utils.data.distributed.DistributedSampler(ds)
# dataloader = DataLoader(ds, batch_size=4, sampler=sampler, drop_last=False)

# case 7: 10 samples, drop_last=True drops the trailing 1-sample batch
ds = InnerDS(10)
sampler = torch.utils.data.distributed.DistributedSampler(ds)
dataloader = DataLoader(ds, batch_size=4, sampler=sampler, drop_last=True)


for epoch in range(2):
    # case 3+
    # sampler.set_epoch(epoch)  # uncomment so each epoch uses a different shuffle
    for index, (_, labels, items) in enumerate(dataloader):
        print(items.cuda())  # dataset indices received by this rank in this batch
        # print('epoch:\t', epoch)
        dist.barrier()  # keep the two processes' printouts in step

Experiment procedure

Run case 1, where torch.utils.data.distributed.DistributedSampler is not passed to the DataLoader. The output shows that every GPU (every process) iterates over the entire dataset in every epoch.

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([0, 1, 2, 3], device='cuda:0')
tensor([0, 1, 2, 3], device='cuda:1')
tensor([4, 5, 6, 7], device='cuda:1')
tensor([4, 5, 6, 7], device='cuda:0')
tensor([0, 1, 2, 3], device='cuda:1')
tensor([0, 1, 2, 3], device='cuda:0')
tensor([4, 5, 6, 7], device='cuda:0')
tensor([4, 5, 6, 7], device='cuda:1')
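
This is expected: when no sampler is passed, each process builds its own DataLoader with the default SequentialSampler (shuffle defaults to False), which knows nothing about ranks and simply enumerates the whole dataset. A quick illustration of that default, standalone and not part of the experiment script:

from torch.utils.data import SequentialSampler

# The sampler DataLoader falls back to when neither sampler nor shuffle is given;
# it yields 0 .. len(dataset)-1 in order, identically in every process.
print(list(SequentialSampler(range(8))))  # [0, 1, 2, 3, 4, 5, 6, 7]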

Run case 2, where the DataLoader uses torch.utils.data.distributed.DistributedSampler. The output shows that the data is split evenly across the two GPUs, but every epoch assigns exactly the same samples to each GPU.

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([4, 7, 2, 1], device='cuda:0')
tensor([0, 3, 5, 6], device='cuda:1')
tensor([0, 3, 5, 6], device='cuda:1')
tensor([4, 7, 2, 1], device='cuda:0')
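
The repetition across epochs follows from how DistributedSampler derives its per-rank indices. The sketch below is a simplified re-implementation (the function name dist_sampler_indices is made up; it assumes the default shuffle=True, seed=0, drop_last=False behaviour): shuffle the full index list with a generator seeded by seed + epoch, pad it so it divides evenly across ranks, then take every num_replicas-th index starting at rank. Because the sampler's epoch counter stays at 0 unless set_epoch is called, the permutation is identical in every epoch.

import math
import torch

def dist_sampler_indices(n, num_replicas, rank, epoch=0, seed=0):
    g = torch.Generator()
    g.manual_seed(seed + epoch)  # set_epoch(epoch) is what changes this per epoch
    indices = torch.randperm(n, generator=g).tolist()
    num_samples = math.ceil(n / num_replicas)          # indices each rank will yield
    total_size = num_samples * num_replicas
    indices += indices[: total_size - len(indices)]    # pad by repeating the head of the list
    return indices[rank:total_size:num_replicas]       # rank-strided slice

# With epoch fixed at 0, both "epochs" produce the same split, as in case 2.
for epoch in range(2):
    print(dist_sampler_indices(8, 2, rank=0), dist_sampler_indices(8, 2, rank=1))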

To fix the problem in case 2, where every epoch assigns the same samples to each GPU, run case 3, which calls sampler.set_epoch(epoch) at the start of every epoch.

 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([0, 3, 5, 6], device='cuda:1')
tensor([4, 7, 2, 1], device='cuda:0')
tensor([5, 2, 7, 1], device='cuda:0')
tensor([4, 6, 3, 0], device='cuda:1')

Run case 4. The dataset contains 6 samples; with two GPUs, batch_size=4, and drop_last=False, each GPU is assigned 3 samples. With drop_last=True, any batch with fewer than 4 samples is discarded, so with only 6 samples in the dataset the 3 samples on each GPU would all be dropped.

 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([2, 3, 1], device='cuda:0')
tensor([5, 0, 4], device='cuda:1')
tensor([5, 0, 4], device='cuda:1')
tensor([2, 3, 1], device='cuda:0')
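
The case 4 split follows directly from the sampler's arithmetic: each rank yields ceil(6 / 2) = 3 indices, and the DataLoader then batches those 3 indices, so drop_last=True on the DataLoader would discard the only (3-sample) batch on every rank. A small sketch of that arithmetic (per_rank_batches is a made-up helper for illustration):

import math

def per_rank_batches(n, num_replicas, batch_size, drop_last):
    num_samples = math.ceil(n / num_replicas)      # indices assigned to each rank
    full, tail = divmod(num_samples, batch_size)   # full batches plus an optional short one
    return [batch_size] * full + ([tail] if tail and not drop_last else [])

print(per_rank_batches(6, 2, 4, drop_last=False))  # [3]  -> one 3-sample batch per rank
print(per_rank_batches(6, 2, 4, drop_last=True))   # []   -> the short batch is dropped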

Run case 5. The dataset contains 5 samples; with two GPUs, batch_size=4, and drop_last=False, each GPU would get 2.5 samples on average, so the index list is padded up to 6 samples, 3 per GPU. The padding reuses the first entry of the shuffled index order (index 4 in the first run below). Adding sampler.set_epoch(epoch) does not change this rule: in the second run, epoch 0 pads with index 4 and epoch 1 pads with index 0.

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([4, 1, 2], device='cuda:0')
tensor([0, 3, 4], device='cuda:1')
tensor([0, 3, 4], device='cuda:1')
tensor([4, 1, 2], device='cuda:0')
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([4, 1, 2], device='cuda:0')
tensor([0, 3, 4], device='cuda:1')
tensor([4, 3, 0], device='cuda:1')
tensor([0, 2, 1], device='cuda:0')
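
The padding seen here matches what DistributedSampler does when its own drop_last is left at False: it extends the shuffled index list by repeating its head until the length divides evenly across replicas, which is why the first index of the shuffled order (4 in the first run; 4 and then 0 in the two epochs of the second run) appears twice. If duplicated samples are undesirable, recent PyTorch versions also accept drop_last=True on the DistributedSampler itself, which trims the uneven tail instead of padding. A sketch of both behaviours on an illustrative shuffled order:

import math

def split_indices(shuffled, num_replicas, sampler_drop_last=False):
    if sampler_drop_last:
        num_samples = len(shuffled) // num_replicas            # trim the uneven tail
    else:
        num_samples = math.ceil(len(shuffled) / num_replicas)  # pad by repeating the head
    total_size = num_samples * num_replicas
    if total_size > len(shuffled):
        indices = shuffled + shuffled[: total_size - len(shuffled)]
    else:
        indices = shuffled[:total_size]
    return [indices[r:total_size:num_replicas] for r in range(num_replicas)]

shuffled = [4, 1, 0, 3, 2]                 # one possible shuffle of a 5-sample dataset
print(split_indices(shuffled, 2))          # index 4 is repeated -> 3 indices per rank
print(split_indices(shuffled, 2, True))    # index 2 is dropped  -> 2 indices per rank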

With multiple processes running, case 6 produces iterations whose batch size is 1; if the model contains modules such as BatchNorm, this can raise an error at runtime.

 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([4, 7, 3, 0], device='cuda:0')
tensor([1, 5, 9, 8], device='cuda:1')
tensor([6], device='cuda:0')
tensor([2], device='cuda:1')
tensor([1, 5, 9, 8], device='cuda:1')
tensor([4, 7, 3, 0], device='cuda:0')
tensor([6], device='cuda:0')
tensor([2], device='cuda:1')
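
A batch of size 1 is harmless for most layers, but BatchNorm in training mode cannot compute statistics from a single value per channel; a BatchNorm1d that sees a (1, C) activation (for example after global pooling on a one-sample batch) raises an error. A minimal reproduction, independent of the distributed setup:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(8)
bn.train()                    # training mode requires batch statistics
try:
    bn(torch.randn(1, 8))     # one sample -> only one value per channel
except ValueError as e:
    print("BatchNorm failed on a batch of size 1:", e)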

To avoid the situation in case 6, run case 7, which sets drop_last=True on the DataLoader so the trailing incomplete batch is discarded; a BatchSampler can be used to the same effect (see the sketch at the end).

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr='10.100.37.21' --master_port='29500' exp_ds_dist_sampler.py
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
tensor([4, 7, 3, 0], device='cuda:0')
tensor([1, 5, 9, 8], device='cuda:1')
tensor([1, 5, 9, 8], device='cuda:1')
tensor([4, 7, 3, 0], device='cuda:0')
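
An equivalent way to achieve this, using the BatchSampler mentioned above, is to wrap the DistributedSampler in a torch.utils.data.BatchSampler and hand that to the DataLoader via batch_sampler; the incomplete trailing batch is then dropped at the sampler level. A sketch, reusing the ds and sampler variables from the script above:

from torch.utils.data import BatchSampler, DataLoader

# Group this rank's indices into batches of 4 and drop the incomplete tail,
# instead of letting the DataLoader handle batch_size / drop_last itself.
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=True)
dataloader = DataLoader(ds, batch_sampler=batch_sampler)

# sampler.set_epoch(epoch) still applies to the underlying DistributedSampler each epoch.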