前言

  在项目开发过程中,遇到了一个PyTorch版本更新带来的问题,搜索了相关博客之后,都是降低PyTorch的版本,看似解决了问题,实则不然,治标不治本而已。本篇博客主要介绍一下如何从根源上解决这些问题。

1. 问题描述

# 此处粘贴了问题出错的部分代码
......
with torch.onnx.set_training(model_clone, False):
	device = distiller.model_device(model_clone)
	dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
	self.dummy_input = dummy_input
	......

  运行项目的过程中出现错误:AttributeError: module 'torch.onnx' has no attribute 'set_training'

2. 问题原因

  查询了相关博客说是PyTorch的版本过高需要降低版本即可解决,没错,该项目所附带的requirements.txt要求的PyTorch版本为1.3.1。由于我所开发的项目实是在新版本上(PyTorch 1.8.1)开发的,所以降低PyTorch的版本解决此问题不是一个好的 idea 。

3. 解决过程

  要想解决这个问题,肯定是要知道torch.onnx.set_training()这行代码是干什么的,看了一下PyTorch 1.3.1官方文档,官方是这样说的:

def set_training(model, mode):
    r"""
    A context manager to temporarily set the training mode of 'model'
    to 'mode', resetting it when we exit the with-block.  A no-op if
    mode is None.
    """

    from torch.onnx import utils
    return utils.set_training(model, mode)

  大致意思就是说,在和一个with上下文管理器一起用是,其作用就是临时将model设置为mode模式,退出with再将modelmode重置。
  然后我看了一下官方的源码,发现也是很简单,就一行。我尝试在新版本中导出一下from torch.onnx.utils import set_training,发现没有此函数,ennnnn,合情合理,但有了另一个函数from torch.onnx.utils import select_model_mode_for_export,从函数名字上感觉很像,而且参数一毛一样,看了一下PyTorch 1.8.1版本的官方文档

def select_model_mode_for_export(model, mode):
    r"""
    A context manager to temporarily set the training mode of 'model'
    to 'mode', resetting it when we exit the with-block.  A no-op if
    mode is None.

    In version 1.6 changed to this from set_training
    """

    from torch.onnx import utils
    return utils.select_model_mode_for_export(model, mode)

  很明确了,在PyTorch 1.6版本中set_training变成了select_model_mode_for_export,改一下就可以了。

with torch.onnx.set_training(model_clone, False)
# 改为
with torch.onnx.select_model_mode_for_export(model_clone, torch.onnx.TrainingMode.EVAL)

在这里插入图片描述

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐