项目场景:

进行ReLU类的实例运算时,出现了问题


问题描述

import numpy as np

class ReLU():
    def __init__(self):
        self.mask = None
    
    def forward(self, x):
        self.mask = (x <= 0)
        out = x.copy()
        out[self.mask] = 0
        
        return out
    
    def backward(self, dout):
        dout[self.mask] = 0
        dx = dout
        
        return dx
    
a = np.array([[1.0, -0.5], [-2.0, 3.0]])
relu = ReLU()
out = ReLU.forward(a)
print(out)

  运行程序时会出现报错:

runfile('C:/Users/Administrator/.spyder-py3/temp.py', wdir='C:/Users/Administrator/.spyder-py3')
Traceback (most recent call last):

  File "C:\Users\Administrator\.spyder-py3\temp.py", line 22, in <module>
    out = ReLU.forward(a)

TypeError: forward() missing 1 required positional argument: 'x'

 


原因分析:

看22行,前面已经将类赋给了relu,在下一行调用内部函数的时候,要用relu,不应该用原来类的名称ReLU.


解决方案:

import numpy as np

class ReLU():
    def __init__(self):
        self.mask = None
    
    def forward(self, x):
        self.mask = (x <= 0)
        out = x.copy()
        out[self.mask] = 0
        
        return out
    
    def backward(self, dout):
        dout[self.mask] = 0
        dx = dout
        
        return dx
    
a = np.array([[1.0, -0.5], [-2.0, 3.0]])
relu = ReLU()
out = relu.forward(a)
print(out)

    

结果:

[[1. 0.]
 [0. 3.]]

在运用类的时候,要先实例化类,然后用实例化之后的名称进一步调用类内部的函数。

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐