torch.no_grad是一个类,pytorch官网描述如下:

 一个上下文管理器,disable梯度计算。disable梯度计算对于推理是有用的,当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。这个模式下,每个计算结果的requires_grad=False,尽管输入的requires_grad=True。

上下文管理器是thread local的,不会影响其它线程的计算。

x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad  # False

也可以作为装饰器。 

@torch.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad  # False

在我们对模型进行验证的时候,可以使用下面两种格式:

model.eval()
with torch.no_grad():
    pass

或者使用装饰器

@torch.no_grad()
def eval():
    ...

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐