PyTorch中一般约定是使用.pt或.pth文件扩展名保存模型,通过torch.save保存模型,通过torch.load加载模型。torch.save和torch.load函数的实现在torch/serialization.py文件中。

      这里以LeNet5模型为例进行说明。LeNet5的介绍过程参考:https://blog.csdn.net/fengbingchun/article/details/125462001

      你应该保存模型的参数,而不是模型本身(you should keep the parameters of the model, not the model itself)。保存模型进行推理(inference)时,只需要保存训练模型的学习参数即可。使用torch.save函数保存模型的state_dict将为你以后恢复模型提供最大的灵活性,这就是为什么它是保存模型的推荐方法。

      torch.save函数有两种保存方式:一种是保存整个模型,此时模型的type应该为继承自nn.Module的类,这里则为类LeNet5;另一种是仅保存模型的参数,此时模型的type应该为有序字典即类OrderedDict。

      torch.save函数将序列化的对象保存到磁盘。此函数使用Python的pickle进行序列化。通过pickle可以保存各种对象的模型、张量和字典。

      pickle的介绍参考参考:https://blog.csdn.net/fengbingchun/article/details/125584682

      torch.load函数使用pickle的unpickling将pickle对象文件反序列化到内存中。

      torch.nn.Module的load_state_dict函数:使用反序列化的state_dict加载模型的参数字典。

      torch.nn.Module的state_dict函数:在PyTorch中,torch.nn.Module模型的可学习参数(即weights和biases)包含在模型的参数中(通过model.parameters函数访问)。state_dict只是一个Python字典对象,它将每一层映射到其参数张量(tensor)。注意:只有具有可学习参数的层(卷积层,线性层等)和注册缓冲区(batchnorm’s running_mean)在模型的state_dict中有条目( Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict)。优化器对象(torch.optim)也有一个state_dict,其中包含有关优化器状态的信息,以及使用的超参数。因为state_dict对象是Python字典,所以它们可以很容易地保存、更新、更改和恢复。

      注意:

      (1).在运行推理之前,你必须调用model.eval函数将dropout和批量标准化层(batch normalization layers)设置为评估模式。不这样做会产生不一致的推理结果。

      (2).load_state_dict函数采用字典对象,而不是保存对象的路径(load_state_dict function takes a dictionary object, NOT a path to a saved object)。这意味着你必须在将保存的state_dict传递给load_state_dict函数之前对其进行反序列化(you must deserialize the saved state_dict before you pass it to the load_state_dict function)。

      (3).如果你只打算保留性能最佳的模型,不要忘记best_model_state = model.state_dict()返回对状态的引用而不是其副本(not its copy)。你必须序列化best_model_state或使用best_model_state = deepcopy(model.state_dict()) 否则你的best_model_state将通过后续训练迭代不断更新。结果,最终的模型状态将是过拟合模型的状态。

      (4).推荐:torch.save(model.state_dict(), PATH)/model.load_state_dict(torch.load(PATH));

      不推荐:torch.save(model, PATH)/model = torch.load(PATH):保存整个模型。

      以下是测试的代码段:

def save_load_model(model):
    '''saving and loading models'''
    model.load_state_dict(torch.load("../../data/Lenet-5.pth")) # 加载模型
    model.eval() # 将网络设置为评估模式

    # state_dict:返回一个字典,保存着module的所有状态,参数和persistent buffers都会包含在字典中,字典的key就是参数和buffer的names
    print("model state dict keys:", model.state_dict().keys())
    print("model type:", type(model)) # model type: <class 'pytorch.lenet5.test_lenet5_mnist.LeNet5'>
    print("model state dict type:", type(model.state_dict())) # model state dict type: <class 'collections.OrderedDict'>

    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    print_state_dict(model.state_dict(), optimizer.state_dict())

    torch.save(model, "../../data/Lenet-5_all.pth") # 保存整个模型
    torch.save(model.state_dict(), "../../data/Lenet-5_parameters.pth") # 推荐:仅保存训练模型的参数,为以后恢复模型提供最大的灵活性

      保存一般检查点(checkpoint)用于推理或恢复训练时,你保存的不仅仅是模型的state_dict,保存优化器的state_dict也很重要,因为它包含随着模型训练而更新的缓冲区和参数(buffers and parameters)。你可能还想要保存已训练的epoch编号、最新记录的训练损失、以及外部的torch.nn.Embedding层等。这样的checkpoint通常比单独的模型大2至3倍。

      要保存多个组件,需要将它们组织在字典中并使用torch.save序列化字典。一个常见的PyTorch约定是使用.tar文件扩展名保存这些checkpoint

      以下是测试的代码段:

def save_load_checkpoint(model):
    '''saving & loading a general checkpoint for inference and/or resuming training'''
    path = "../../data/Lenet-5_parameters.tar"
    model.load_state_dict(torch.load("../../data/Lenet-5.pth")) # 加载模型
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    torch.save({
                'epoch': 5,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }, path)

    checkpoint = torch.load(path)
    model2 = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
    model2.load_state_dict(checkpoint['model_state_dict'])
    optimizer2 = torch.optim.SGD(params=model2.parameters(), lr=0.1)
    optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print("epoch:", epoch)
    model.eval() # 将网络设置为评估模式
    #model.train() # 恢复训练,将网络设置为训练模式

    print_state_dict(model2.state_dict(), optimizer2.state_dict())

      保存由多个torch.nn.Modules组成的模型时,例如GAN、sequence-to-sequence model或模型集合,需遵循与保存checkpoint时相同的方法。即保存每个模型的state_dict和相应优化器的dictionary。

      以下是测试的代码段:

def save_load_multiple_models():
    '''saving multiple models in one file'''
    path1 = "../../data/Lenet-5.pth"
    path2 = "../../data/Lenet-5_parameters_mul.tar"
    model1 = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
    model1.load_state_dict(torch.load(path1)) # 加载模型
    optimizer1 = torch.optim.Adam(params=model1.parameters(), lr=0.001)

    model2 = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
    model2.load_state_dict(torch.load(path1)) # 加载模型
    optimizer2 = torch.optim.SGD(params=model2.parameters(), lr=0.1)

    torch.save({
            'epoch': 100,
            'model1_state_dict': model1.state_dict(),
            'model2_state_dict': model2.state_dict(),
            'optimizer1_state_dict': optimizer1.state_dict(),
            'optimizer2_state_dict': optimizer2.state_dict(),
            }, path2)

    checkpoint = torch.load(path2)
    modelA = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
    modelA.load_state_dict(checkpoint['model1_state_dict'])
    optimizerA = torch.optim.SGD(params=modelA.parameters(), lr=0.1)
    optimizerA.load_state_dict(checkpoint['optimizer1_state_dict'])

    modelB = LeNet5(n_classes=10).to('cpu') # 实例化一个LeNet5网络对象
    modelB.load_state_dict(checkpoint['model2_state_dict'])
    optimizerB = torch.optim.Adam(params=modelB.parameters(), lr=0.01)
    optimizerB.load_state_dict(checkpoint['optimizer2_state_dict'])

    epoch = checkpoint['epoch']
    print("epoch:", epoch)
    modelA.eval() # 将网络设置为评估模式
    #modelA.train() # 恢复训练,将网络设置为训练模式

    #modelB.eval() # 将网络设置为评估模式
    modelB.train() # 恢复训练,将网络设置为训练模式

    print_state_dict(modelA.state_dict(), optimizerA.state_dict())
    print_state_dict(modelB.state_dict(), optimizerB.state_dict())

      部分加载模型或加载部分模型(partially loading a model or loading a partial model)是迁移学习或训练新的复杂模型时的常见场景。利用经过训练的参数,即使只有少数可用,也将有助于热启动(warmstart)训练过程,并有望帮助你的模型从头开始训练更快地收敛。

      PyTorch中其它保存模型的方法

      (1).torch.package是一种以独立、稳定的格式打包PyTorch模型的新方法。它包含模型参数、元数据(metadata)及架构。此外,torch.package添加了对创建包含任意PyTorch code的密封包(hermetic package)的支持,这意味着你可能会使用它来打包你想要的任何东西,例如PyTorch DataLoaders、Datasets等。

      (2).使用经过训练的模型进行推理的另一种方法是使用TorchScript,它是PyTorch模型的中间表示,可以在Python以及C++等环境中运行。TorchScript实际上是用于扩展推理和部署的推荐模型格式。注意:使用TorchScript格式,你将能够加载导出的模型并进行推理,而无需定义模型类。

      以上文字描述主要翻译自:https://pytorch.org/tutorials/beginner/saving_loading_models.html

      GitHubhttps://github.com/fengbingchun/PyTorch_Test

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐