【torch.no_grad()】
torch.no_grad()的两种写法,with写法和装饰器写法。
·
torch.no_grad是一个类,pytorch官网描述如下:
一个上下文管理器,disable梯度计算。disable梯度计算对于推理是有用的,当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。这个模式下,每个计算结果的requires_grad=False,尽管输入的requires_grad=True。
上下文管理器是thread local的,不会影响其它线程的计算。
x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
y = x * 2
y.requires_grad # False
也可以作为装饰器。
@torch.no_grad()
def doubler(x):
return x * 2
z = doubler(x)
z.requires_grad # False
在我们对模型进行验证的时候,可以使用下面两种格式:
model.eval()
with torch.no_grad():
pass
或者使用装饰器
@torch.no_grad()
def eval():
...
更多推荐
已为社区贡献1条内容
所有评论(0)