torch.load()的作用:从文件加载用torch.save()保存的对象。

格式:torch.load — PyTorch 1.12 documentation

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数解释:

  • f :类似类文件的对象(必须实现read(),:meth ' readline ',:meth ' tell '和:meth ' seek '),或者是包含文件名的类路径对象的字符串。
  • map_location : 函数、torch.device、字符串或字典指明如何重新映射存储位置。
  • pickle_module :用于解pickle元数据和对象的模块(必须匹配用于序列化文件pickle_module)
  • pickle_load_args : 传递给pickle_module.load()和pickle_module. unpickpickler()的可选关键字参数,例如errors=…(仅适用于Python 3)
  • 注意:常见的使用是 f 和 map_location。

常用使用方式:torch.load — PyTorch 1.12 documentation

# 常用根据设备,加载Tensor
>>> torch.load('modelparameters.pth', map_location = device)

# 默认加载方式,使用cpu加载cpu训练得出的模型或者用gpu调用gpu训练的模型
>>> torch.load('tensors.pt')

# Load all tensors onto the CPU
# ♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥将全部Tensor全部加载到cpu上
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))

# Load all tensors onto the CPU, using a function
# 使用一个函数将所有的Tensor加载到CPU上
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1

# 加载全部Tensor到GPU 1上
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))

# Map tensors from GPU 1 to GPU 0
#将张量从GPU 1映射到GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

# Load tensor from io.BytesIO object
# 从io加载张量。BytesIO对象
>>> with open('tensor.pt', 'rb') as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer)

# Load a module with 'ascii' encoding for unpickling
# 加载一个带有'ascii'编码的模块用于反pickle
>>> torch.load('module.pt', encoding='ascii')

 

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐