PyTorch 中使用 torch.jit.trace 时出现 “forward() takes 2 positional arguments but 3 were given” 问题
摘要：下文依次给出会触发该错误的源代码，以及修正后的源代码。
·
错误的源代码（执行到 torch.jit.trace 时报错）：
import torch
class MyCell(torch.nn.Module):
    """Toy cell whose forward() receives x and h packed in ONE argument."""

    def __init__(self):
        super().__init__()
        # 4 -> 4 affine layer applied to x before h is added.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, input):
        """Return ``(new_h, new_h)`` with ``new_h = tanh(linear(input[0]) + input[1])``.

        ``input`` is indexed positionally: element 0 is x, element 1 is h.
        (The name shadows the builtin but is kept for caller compatibility.)
        """
        pre_activation = self.linear(input[0]) + input[1]
        new_h = torch.tanh(pre_activation)
        return new_h, new_h
# Failing example: reproduces
# "forward() takes 2 positional arguments but 3 were given".
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# Pack x and h into one tuple because forward(self, input) expects a single
# argument. (Naming it `input` also shadows the Python builtin.)
input = (x,h)
# torch.jit.trace treats a tuple of example inputs as the positional
# arguments for forward, so this effectively calls my_cell(x, h).
# forward only accepts (self, input), so the TypeError above is raised
# here, and the following lines never run.
# Two possible fixes: declare forward(self, x, h), as the corrected version
# does, or wrap the pair in an outer tuple:
# torch.jit.trace(my_cell, (input,)).
traced_cell = torch.jit.trace(my_cell, input)
print(traced_cell)
# With the (input,) workaround, the traced module would also have to be
# called as traced_cell((x, h)); this call passes two separate arguments.
traced_cell(x, h)
正确的源代码。torch.jit.trace 会把元组形式的示例输入展开为 forward 的多个位置参数（相当于调用 my_cell(x, h)），因此 forward 应分别接收 x 和 h：
import torch
class MyCell(torch.nn.Module):
    """Toy cell computing ``tanh(linear(x) + h)`` from two separate tensors."""

    def __init__(self):
        super().__init__()
        # 4 -> 4 affine layer applied to x.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return the new state twice, as the 2-tuple ``(new_h, new_h)``.

        x and h are separate positional arguments, which matches how
        torch.jit.trace unpacks a tuple of example inputs.
        """
        hidden = self.linear(x).add(h).tanh()
        return hidden, hidden
# Corrected example: build the cell, trace it, and run the traced module.
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# Named `example_inputs` rather than `input` to avoid shadowing the builtin.
example_inputs = (x, h)
# torch.jit.trace unpacks a tuple of example inputs into positional
# arguments, i.e. it calls my_cell(x, h), which matches forward(self, x, h).
traced_cell = torch.jit.trace(my_cell, example_inputs)
print(traced_cell)
# The traced module takes the same positional arguments as forward().
traced_cell(x, h)
更多推荐
已为社区贡献2条内容
所有评论(0)