项目场景:

进行ReLU类的实例运算时,出现了问题


问题描述

import numpy as np

class ReLU():
    def __init__(self):
        self.mask = None
    
    def forward(self, x):
        self.mask = (x <= 0)
        out = x.copy()
        out[self.mask] = 0
        
        return out
    
    def backward(self, dout):
        dout[self.mask] = 0
        dx = dout
        
        return dx
    
a = np.array([[1.0, -0.5], [-2.0, 3.0]])
relu = ReLU()
out = ReLU.forward(a)
print(out)

  运行程序时会出现报错:

runfile('C:/Users/Administrator/.spyder-py3/temp.py', wdir='C:/Users/Administrator/.spyder-py3')
Traceback (most recent call last):

  File "C:\Users\Administrator\.spyder-py3\temp.py", line 22, in <module>
    out = ReLU.forward(a)

TypeError: forward() missing 1 required positional argument: 'x'

 


原因分析:

看22行,前面已经将类赋给了relu,在下一行调用内部函数的时候,要用relu,不应该用原来类的名称ReLU.


解决方案:

import numpy as np

class ReLU():
    def __init__(self):
        self.mask = None
    
    def forward(self, x):
        self.mask = (x <= 0)
        out = x.copy()
        out[self.mask] = 0
        
        return out
    
    def backward(self, dout):
        dout[self.mask] = 0
        dx = dout
        
        return dx
    
a = np.array([[1.0, -0.5], [-2.0, 3.0]])
relu = ReLU()
out = relu.forward(a)
print(out)

    

结果:

[[1. 0.]
 [0. 3.]]

在运用类的时候,要先实例化类,然后用实例化之后的名称进一步调用类内部的函数。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐