torch.nn.identity()方法详解
先看代码m = nn.Identity(54,unused_argument1=0.1,unused_argument2=False)input = torch.randn(128, 20)output = m(input)>>> print(output.size())torch.Size([128, 20])这是官方文档中给出的代码,很明显,没有什么变化,输入的是torch,
·
先看代码
m = nn.Identity(
54,
unused_argument1=0.1,
unused_argument2=False
)
input = torch.randn(128, 20)
output = m(input)
>>> print(output.size())
torch.Size([128, 20])
这是官方文档中给出的代码,很明显,没有什么变化,输入的是torch,输出也是,并且给定的参数似乎并没有起到变化的效果。
看源码
class Identity(Module):
r"""A placeholder identity operator that is argument-insensitive.
Args:
args: any argument (unused)
kwargs: any keyword argument (unused)
Examples::
>>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 20])
"""
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, input: Tensor) -> Tensor:
return input
这相当的简洁明了啊,输入是啥,直接给输出,不做任何的改变。再看文档中的一句话:A placeholder identity operator that is argument-insensitive.
翻译一下就是:不区分参数的占位符标识运算符。百度翻译,其实意思就是这个网络层的设计是用于占位的,即不干活,只是有这么一个层,放到残差网络里就是在跳过连接的地方用这个层,显得没有那么空虚!
更多推荐
已为社区贡献2条内容
所有评论(0)