PyTorch 实现联邦学习FedAvg (详解)

开始做第二个工作了,又把之前看的FedAvg的代码看了一遍。联邦学习好难啊…

1. 介绍

简单介绍一下FedAvg

FedAvg是一种分布式框架,允许多个用户同时训练一个机器学习模型。在训练过程中并不需要上传任何私有的数据到服务器。本地用户负责训练本地数据得到本地模型,中心服务器负责加权聚合本地模型,得到全局模型,经过多轮迭代后最终得到一个趋近于集中式机器学习结果的模型,有效地降低了传统机器学习源数据聚合带来的许多隐私风险。

(1) 首先用户用户从服务器中下载模型参数,更新本地模型参数,进行本地机器学习训练。

(2) 其次在用户中通过本地随机梯度下降不断更新模型的精度,当达到预定的本地训练次数时,将本地训练后的模型参数上传到服务端中。

(3) 服务端随机抽取用户设备,并接收本地用户上传的模型参数梯度进行聚合。表示服务端模型参数梯度聚合,服务端将抽样的用户模型参数梯度加权平均并与上一轮聚合后模型参数相加,更新全局模型参数。最后将聚合后的模型参数回传给抽样用户设备,然后继续执行步骤1操作。重复执行上述步骤直至通讯次数达到t。

2. 参数配置

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="FedAvg")

parser.add_argument('-g', '--gpu', type=str, default='0', help='gpu id to use(e.g. 0,1,2,3)')

# 客户端的数量
parser.add_argument('-nc', '--num_of_clients', type=int, default=100, help='numer of the clients')

# 随机挑选的客户端的数量
parser.add_argument('-cf', '--cfraction', type=float, default=0.1, help='C fraction, 0 means 1 client, 1 means total clients')

# 训练次数(客户端更新次数)
parser.add_argument('-E', '--epoch', type=int, default=5, help='local train epoch')

# batchsize大小
parser.add_argument('-B', '--batchsize', type=int, default=10, help='local train batch size')

# 模型名称
parser.add_argument('-mn', '--model_name', type=str, default='mnist_cnn', help='the model to train')

# 学习率
parser.add_argument('-lr', "--learning_rate", type=float, default=0.01, help="learning rate, \
                    use value from origin paper as default")

parser.add_argument('-dataset',"--dataset",type=str,default="mnist",help="需要训练的数据集")

# 模型验证频率(通信频率)
parser.add_argument('-vf', "--val_freq", type=int, default=5, help="model validation frequency(of communications)")

parser.add_argument('-sf', '--save_freq', type=int, default=20, help='global model save frequency(of communication)')

#n um_comm 表示通信次数,此处设置为1k
parser.add_argument('-ncomm', '--num_comm', type=int, default=1000, help='number of communications')

parser.add_argument('-sp', '--save_path', type=str, default='./checkpoints', help='the saving path of checkpoints')

parser.add_argument('-iid', '--IID', type=int, default=0, help='the way to allocate data to clients')

简单说明一下,需要配置的参数

参数说明备注
–gpu配置gpu
–num_of_clients设置客户端设备的数量
–cfraction随机挑选客户端的数量(默认值是0.1)当客户端达到一定数量之后,联邦学习对客户端进行抽样聚合。
–epoch客户端本地训练的次数默认值5
–batchsize客户端本地训练的batchsize大小
–model_name本地训练的model名称
–learning_rate本地训练的学习率
–dataset采用的数据集
–val_freq模型验证频率(通信频率)本地训练与服务器之间通讯的频率(默认值5轮)这里与–epoch 设置的值,应相同。
–save_freqglobal model save frequency(of communication)每20轮,服务器端将保存聚合的模型
–num_comm客户端与服务端通讯的轮次
–save_path最终聚合模型,保存的地址
–IID本地数据是否是独立同分布

3. 数据重构

目的是将数据按照, 非独立同分布(non-iid)或者是(iid)的方式重构数据结构, 并返回重构的后的 数据标签 与 数据。

self.train_data 
self.train_label 

3.1 加载数据

Mnist 手写数字数据集是由60000张28*28 的照片组成,一共分为10类(从0-9),每一类有6000张照片,训练集有60000张照片,测试集是10000张照片.

train_images 表示训练集的图片大小是(60000, 28, 28, 1), train_labels 表示训练集标签,与train_images 一一对应 ,图像大小是(60000,10)
test_images 表示测试集,测试集大小为(10000,28,28,1).test_labels 表示测试集标签,大小 (10000, 10)

 # 加载数据集
    data_dir = r'.\data\MNIST'
    # data_dir = r'./data/MNIST'
    # python路径拼接os.path.join() 路径变为.\data\MNIST\train-images-idx3-ubyte.gz
    train_images_path = os.path.join(data_dir, 'train-images-idx3-ubyte.gz')
    train_labels_path = os.path.join(data_dir, 'train-labels-idx1-ubyte.gz')
    test_images_path = os.path.join(data_dir, 't10k-images-idx3-ubyte.gz')
    test_labels_path = os.path.join(data_dir, 't10k-labels-idx1-ubyte.gz')


    train_images = extract_images(train_images_path)
    print(train_images.shape) # (60000, 28, 28, 1) 一共60000 张图片,每一张是28*28*1
   
    train_labels = extract_labels(train_labels_path)
    print(train_labels.shape) # (60000, 10)
   
    test_images = extract_images(test_images_path)
    print(test_images.shape) # (10000, 28, 28, 1)
   
    test_labels = extract_labels(test_labels_path)
    print(test_labels.shape) # (10000, 10) 10000维
  

将数据集压缩成60000 * 784

并对数据集,进行归一化处理.将数组中的每个位置都与1.0 / 255.0 相乘

# 讲图片每一张图片变成28*28 = 784
# reshape(60000,28*28)
train_images = train_images.reshape(train_images.shape[0], train_images.shape[1] * train_images.shape[2])
test_images = test_images.reshape(test_images.shape[0], test_images.shape[1] * test_images.shape[2])

#--------------------归一化处理--------------------#
train_images = train_images.astype(np.float32)
train_images = np.multiply(train_images, 1.0 / 255.0)# 数组对应元素位置相乘

test_images = test_images.astype(np.float32)
test_images = np.multiply(test_images, 1.0 / 255.0)

3.2 分割数据集

思路
一共有60000个样本, 分到100个客户端
IID:
我们首先将数据集打乱,然后为每个Client分配600个样本。
Non-IID:
我们首先根据数据标签将数据集排序(即MNIST中的数字大小),
然后将其划分为200组大小为300的数据切片,然后分给每个Client两个切片。

3.2.1 IID
#一个参数 默认起点0,步长为1 输出:[0 1 2]
# a = np.arange(3)
# 一共60000个
# numpy 中的随机打乱数据方法np.random.shuffle
'''
  num = np.arange(20)
  print(num)
  # [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
  np.random.shuffle(num)
  print(num)
  # [ 1  5 19  9 14  2 12  3  6 18  4  8 16  0 10 17 13  7 15 11]
'''
order = np.arange(self.train_data_size)
np.random.shuffle(order)
self.train_data = train_images[order]
self.train_label = train_labels[order]
3.2.2 Non-IID
'''
 numpy.argmax(array, axis) 用于返回一个numpy数组中最大值的索引值。当一组中同时出现几个最大值时,返回第一个最大值的索引值。
 two_dim_array = np.array([[1, 3, 5], [0, 4, 3]])
 max_index_axis0 = np.argmax(two_dim_array, axis = 0) # 找 纵向 最大值的下标 
 max_index_axis1 = np.argmax(two_dim_array, axis = 1) # 找 横向 最大值的下标
 print(max_index_axis0)
 print(max_index_axis1)
 # [0 1 0] 
 # [2 1]

'''
labels = np.argmax(train_labels, axis=1)
# 对数据标签进行排序

order = np.argsort(labels)
print("标签下标排序")
print(train_labels[order[0:10]])
self.train_data = train_images[order]
self.train_label = train_labels[order]

4. 初始化客户端

初始化100个Clients。传入数据mnist,以及数据分布方式(IID)。

重构数据 (请看上以一部分)

myClients = ClientsGroup('mnist', args['IID'], args['num_of_clients'], dev)

加载数据

# 得到已经被重新分配的数据
mnistDataSet = GetDataSet(self.data_set_name, self.is_iid)

test_data = torch.tensor(mnistDataSet.test_data)
test_label = torch.argmax(torch.tensor(mnistDataSet.test_label), dim=1)
# 加载测试数据
self.test_data_loader = DataLoader(TensorDataset( test_data, test_label), batch_size=100, shuffle=False)

train_data = mnistDataSet.train_data
train_label = mnistDataSet.train_label

4.1数据分配到客户端

然后将其划分为200组大小为300的数据切片,然后分给每个Client两个切片

一共200组,没组大小为300

# 60000 /100 = 600/2 = 300
shard_size = mnistDataSet.train_data_size // self.num_of_clients // 2
# print("shard_size:"+str(shard_size))

# np.random.permutation 将序列进行随机排序
# np.random.permutation(60000//300=200)
shards_id = np.random.permutation(mnistDataSet.train_data_size // shard_size)
# 一共200个
# print(shards_id)   

迭代客户端,shard_id1 与 shards_id2 所对应的数据块分发到客户端中(数据数600)。

for i in range(self.num_of_clients):

    ## shards_id1
    ## shards_id2
    ## 是所有被分得的两块数据切片
    # 0 2 4 6...... 偶数
    shards_id1 = shards_id[i * 2]
    # 0+1 = 1 2+1 = 3 .... 奇数
    shards_id2 = shards_id[i * 2 + 1]
    #
    # 例如shard_id1 = 10
    # 10* 300 : 10*300+300
    # 将数据以及的标签分配给该客户端
    data_shards1 = train_data[shards_id1 * shard_size: shards_id1 * shard_size + shard_size]
    data_shards2 = train_data[shards_id2 * shard_size: shards_id2 * shard_size + shard_size]
    label_shards1 = train_label[shards_id1 * shard_size: shards_id1 * shard_size + shard_size]
    label_shards2 = train_label[shards_id2 * shard_size: shards_id2 * shard_size + shard_size]

    #
    # np.vstack 是按照垂直方向堆叠
    # np.hstack: 按水平方向(列顺序)堆叠数组构成一个新的数组
    '''
                In[4]:
                a = np.array([[1,2,3]])
                a.shape
                # (1, 3)

                In [5]:
                b = np.array([[4,5,6]])
                b.shape             
                # (1, 3)

                In [6]:
                c = np.vstack((a,b)) # 将两个(1,3)形状的数组按垂直方向叠加
                print(c)
                c.shape # 输出形状为(2,3)
                [[1 2 3]
                 [4 5 6]]
                # (2, 3)

            '''
# 将两个被分到得数据块进行堆叠
    local_data, local_label = np.vstack((data_shards1, data_shards2)), np.vstack((label_shards1, label_shards2))
    local_label = np.argmax(local_label, axis=1)

    # 创建一个客户端
    someone = client(TensorDataset(torch.tensor(local_data), torch.tensor(local_label)), self.dev)
    # 为每一个clients 设置一个名字
    # client10
    self.clients_set['client{}'.format(i)] = someone

5. Server

前4部分针对联邦学习的准备工作已经完成,在server部分,将对联邦学习进行训练。

设置随机选取客户端

 # 每次随机选取10个Clients
    num_in_comm = int(max(args['num_of_clients'] * args['cfraction'], 1))

得到当前的全局模型

    # 得到全局的参数
    global_parameters = {}
    # net.state_dict()  # 获取模型参数以共享

    # 得到每一层中全连接层中的名称fc1.weight
    # 以及权重weights(tenor)
    # 得到网络每一层上
    for key, var in net.state_dict().items():
        # print("key:"+str(key)+",var:"+str(var))
        print("张量的维度:"+str(var.shape))
        print("张量的Size"+str(var.size()))
        global_parameters[key] = var.clone()

5.1训练

客户端与服务端进行通讯的轮次为1000次

首先得到被挑选的10个客户端的order

    # 得到被挑选的10个客户端
    order = np.random.permutation(args['num_of_clients'])
    print("order:")
    print(len(order))
    print(order)
    # 得到10个客户端
    clients_in_comm = ['client{}'.format(i) for i in order[0:num_in_comm]]
    print("客户端"+str(clients_in_comm))
    print(type(clients_in_comm)) # <class 'list'>

5.2 本地客户端更新模型

迭代被挑选的客户端。

并将当前的全局模型 传入到客户端中,更新客户端的本地模型。

该方法为 localUpdate()

5.2.1参数如下:

param: localEpoch 当前Client的迭代次数
param: localBatchSize 当前Client的batchsize大小
param: Net Server共享的模型
param: LossFun 损失函数
param: opti 优化函数
param: global_parmeters 当前通讯中最全局参数
return: 返回当前Client基于自己的数据训练得到的新的模型参数

5.2.2 流程
  1. 更新本地模型

    Net.load_state_dict(global_parameters, strict=True)
    
  2. 加载本地数据

    self.train_dl = DataLoader(self.train_ds, batch_size=localBatchSize, shuffle=True)
    
  3. 进行本地训练

     for epoch in range(localEpoch):
                for data, label in self.train_dl:
                    # 加载到GPU上
                    data, label = data.to(self.dev), label.to(self.dev)
                    # 模型上传入数据
                    preds = Net(data)
                    # 计算损失函数
                    '''
                        这里应该记录一下模型得损失值 写入到一个txt文件中
                    '''
                    loss = lossFun(preds, label)
                    # 反向传播
                    loss.backward()
                    # 计算梯度,并更新梯度
                    opti.step()
                    # 将梯度归零,初始化梯度
                    opti.zero_grad()
    
  4. 最后返回本地模型

  5. 整体代码

    def localUpdate(self, localEpoch, localBatchSize, Net, lossFun, opti, global_parameters):
          '''
              param: localEpoch 当前Client的迭代次数
              param: localBatchSize 当前Client的batchsize大小
              param: Net Server共享的模型
              param: LossFun 损失函数
              param: opti 优化函数
              param: global_parmeters 当前通讯中最全局参数
              return: 返回当前Client基于自己的数据训练得到的新的模型参数
          '''
          # 加载当前通信中最新全局参数
          # 传入网络模型,并加载global_parameters参数的
          Net.load_state_dict(global_parameters, strict=True)
          # 载入Client自有数据集
          # 加载本地数据
          self.train_dl = DataLoader(self.train_ds, batch_size=localBatchSize, shuffle=True)
          # 设置迭代次数
          for epoch in range(localEpoch):
              for data, label in self.train_dl:
                  # 加载到GPU上
                  data, label = data.to(self.dev), label.to(self.dev)
                  # 模型上传入数据
                  preds = Net(data)
                  # 计算损失函数
                  '''
                      这里应该记录一下模型得损失值 写入到一个txt文件中
                  '''
                  loss = lossFun(preds, label)
                  # 反向传播
                  loss.backward()
                  # 计算梯度,并更新梯度
                  opti.step()
                  # 将梯度归零,初始化梯度
                  opti.zero_grad()
          # 返回当前Client基于自己的数据训练得到的新的模型参数
          return Net.state_dict()
    

5.3 服务端聚合

将客户端上传的模型参数数据,线性求和

  # 对所有的Client返回的参数累加(最后取平均值)           
        if sum_parameters is None:
                sum_parameters = {}
                for key, var in local_parameters.items():
                    sum_parameters[key] = var.clone()
            else:
                for var in sum_parameters:
                    sum_parameters[var] = sum_parameters[var] + local_parameters[var]

FedAvg聚合

求出模型参数的平均值,并更新服务端的模型。

    # 取平均值,得到本次通信中Server得到的更新后的模型参数       
       for var in global_parameters:
            global_parameters[var] = (sum_parameters[var] / num_in_comm)
        net.load_state_dict(global_parameters, strict=True)

6. 测试结果

这是通讯1000次之后的结果

在这里插入图片描述

7. 运行

运行

  1. getData.py 下载数据集

  2. server.py 训练

python server.py -nc 100 -cf 0.1 -E 5 -B 10 -mn mnist_cnn  -ncomm 1000 -iid 0 -lr 0.01 -vf 20 -g 0

代码下载地址

https://download.csdn.net/download/qq_36018871/43064799
或者
链接:https://pan.baidu.com/s/13gOszw2TED2X6uezUxyS7w
提取码:zaxf

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐