问题分析

在这里插入图片描述

  • 就像是字面意思那样,这个错误是因为模型中的 weights 没有被转移到 cuda 上,而模型的数据转移到了 cuda 上而造成的
  • 但是造成这个问题的原因却没有那么简单。
  • 绝大多数时候,造成这个的原因是因为你定义好模型之后,没有对模型进行 to(device) 而造成的,但是,也有可能,是因为你的模型在定义的时候,没有定义好,导致模型的一部分在加载的时候没有办法转移到 cuda上。

细节举例

  • 比如我现在定义了一个模型 A,B,它们的结构如下:
# @Time : 2022/1/19 17:57 
# @Author : PeinuanQin
# @File : test.py
import torch.nn as nn
import torch
import torch.utils.data as Data
from tqdm import tqdm
from torchvision import transforms,datasets
import numpy as np
import torchvision
from torch.optim import lr_scheduler


class A(nn.Module):
    def __init__(self):
        super(A,self).__init__()
        self.conv = nn.Conv2d(in_channels=3
                              ,out_channels=8
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        out = self.conv(x)
        out = self.relu(out)
        B_model = B()
        out = B_model(out)
        return out

class B(nn.Module):
    def __init__(self):
        super(B,self).__init__()
        self.conv = nn.Conv2d(in_channels=8
                              ,out_channels=16
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out



在这里插入图片描述

  • 这个时候就会报错,而报错的原因,就是因为 torch 的流程是这样的:
    • 首先将所有的模型加载,先从 A 开始,进入 A 的 init 中把所有的内容加载,然后,通过 main 函数中的 to(device) 操作,就把加载的所有内容和网络定义都放到 cuda 上了,但是注意!!!
    • 第二步开始训练,训练的过程中,都是通过 forward 函数来调用的,但是这个时候程序发现,当进入 A 的 forward 中运行的时候,出现了几个 B 的网络层,但是注意:这些 B 中定义的网络层,在网络加载的过程中可是没有出现在 A 的 __init__里面,也就理所当然地没有加载到 cuda上,因此在 A 的 forward 中出现的时候,B 的这几个网络层的 weight 依然在 cpu 上,这就导致了错误。在这里插入图片描述

改错思路

  • 将所有的内容都放到 cpu 上运行,即:
    在这里插入图片描述
  • 但显然这是个治标不治本的方法,我们就没有办法使用 gpu 训练了,因此我们选择把所有的网络层(只要有参数需要训练的网络层)都放到 init 里面去定义,只在 forward 中写运行时的逻辑,即:
class A(nn.Module):
    def __init__(self):
        super(A,self).__init__()
        self.conv = nn.Conv2d(in_channels=3
                              ,out_channels=8
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.b_module = B()

    def forward(self,x):
        out = self.conv(x)
        out = self.relu(out)
        out = self.b_module(out)
        return out

class B(nn.Module):
    def __init__(self):
        super(B,self).__init__()
        self.conv = nn.Conv2d(in_channels=8
                              ,out_channels=16
                              ,kernel_size=3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.relu(out)
        return out

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐