PyTorch torch.no_grad()
torch.no_grad() 一般用于神经网络的推理阶段, 表示张量的计算过程中无需计算梯度torch.no_grad 是一个类, 实现了 __enter__ 和 __exit__ 方法, 在进入环境管理器时记录梯度使能状态以及禁止梯度计算, 退出环境管理器时还原, 它还继承了 _DecoratorContextManager, 拥有装饰器的能力(依然是使用 with 语句)# 摘自源码clas
·
torch.no_grad()
一般用于神经网络的推理阶段, 表示张量的计算过程中无需计算梯度
torch.no_grad
是一个类, 实现了 __enter__ 和 __exit__ 方法, 在进入环境管理器时记录梯度使能状态以及禁止梯度计算, 退出环境管理器时还原, 它还继承了 _DecoratorContextManager
, 拥有装饰器的能力(依然是使用 with 语句)
# 摘自源码
class no_grad(_DecoratorContextManager):
def __init__(self):
self.prev = False
def __enter__(self):
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""
def __call__(self, func: F) -> F:
@functools.wraps(func)
def decorate_context(*args, **kwargs):
with self.__class__():
return func(*args, **kwargs)
return cast(F, decorate_context)
def __enter__(self) -> None:
raise NotImplementedError
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
raise NotImplementedError
另外, torch.no_grad
用于代替旧版本的 volatile=True
import torch
x = torch.tensor([1.0], requires_grad=True)
y_1: torch.Tensor = x * x
y_1.backward()
print("y_1:", y_1.requires_grad, x.grad)
with torch.no_grad():
y_2 = x * x
print("y_2:", y_2.requires_grad)
@torch.no_grad()
def demo(x):
y_3 = x * x
print("y_3:", y_3.requires_grad)
demo(x)
打印
y_1: True tensor([2.])
y_2: False
y_3: False
y_1 是通常情况, y_1依赖于x, 而x需要求导, 所以y_1也需要求导, y_2 和 y_3 明确无需求导
除了 torch.no_grad()
还有 torch.enable_grad()
明确需要求导以及 torch.set_grad_enabled(mode)
, 它们均支持环境管理器和装饰器
# 单独使用 torch.set_grad_enabled
torch.set_grad_enabled(False)
y_4 = x * x
print("y_4:", y_4.requires_grad)
torch.set_grad_enabled(True)
y_5 = x * x
print("y_5:", y_5.requires_grad)
结果
y_4: False
y_5: True
底层实现位于 “aten/src/ATen/core/grad_mode.cpp”
thread_local bool GradMode_enabled = true;
bool GradMode::is_enabled() {
return GradMode_enabled;
}
void GradMode::set_enabled(bool enabled) {
GradMode_enabled = enabled;
}
更多推荐
已为社区贡献5条内容
所有评论(0)