【bug解决】TypeError: forward() missing 1 required positional argument: ‘x‘
【bug解决】TypeError: forward() missing 1 required positional argument: 'x' 。类的实例化
·
项目场景:
进行ReLU类的实例运算时,出现了问题
问题描述
import numpy as np
class ReLU():
def __init__(self):
self.mask = None
def forward(self, x):
self.mask = (x <= 0)
out = x.copy()
out[self.mask] = 0
return out
def backward(self, dout):
dout[self.mask] = 0
dx = dout
return dx
a = np.array([[1.0, -0.5], [-2.0, 3.0]])
relu = ReLU()
out = ReLU.forward(a)
print(out)
运行程序时会出现报错:
runfile('C:/Users/Administrator/.spyder-py3/temp.py', wdir='C:/Users/Administrator/.spyder-py3')
Traceback (most recent call last):
File "C:\Users\Administrator\.spyder-py3\temp.py", line 22, in <module>
out = ReLU.forward(a)
TypeError: forward() missing 1 required positional argument: 'x'
原因分析:
看22行,前面已经将类赋给了relu,在下一行调用内部函数的时候,要用relu,不应该用原来类的名称ReLU.
解决方案:
import numpy as np
class ReLU():
def __init__(self):
self.mask = None
def forward(self, x):
self.mask = (x <= 0)
out = x.copy()
out[self.mask] = 0
return out
def backward(self, dout):
dout[self.mask] = 0
dx = dout
return dx
a = np.array([[1.0, -0.5], [-2.0, 3.0]])
relu = ReLU()
out = relu.forward(a)
print(out)
结果:
[[1. 0.]
[0. 3.]]
在运用类的时候,要先实例化类,然后用实例化之后的名称进一步调用类内部的函数。
更多推荐
已为社区贡献2条内容
所有评论(0)