pytorch 中 混合精度训练(真香)
一、什么是混合精度训练在pytorch的tensor中,默认的类型是float32,神经网络训练过程中,网络权重以及其他参数,默认都是float32,即单精度,为了节省内存,部分操作使用float16,即半精度,训练过程既有float32,又有float16,因此叫混合精度训练。二、如何进行混合精度训练pytorch中是自动混合精度训练,使用 torch.cuda.amp.autocast 和 t
一、什么是混合精度训练
在pytorch的tensor中,默认的类型是float32,神经网络训练过程中,网络权重以及其他参数,默认都是float32,即单精度,为了节省内存,部分操作使用float16,即半精度,训练过程既有float32,又有float16,因此叫混合精度训练。
二、如何进行混合精度训练
pytorch中是自动混合精度训练,使用 torch.cuda.amp.autocast 和 torch.cuda.amp.GradScaler 这两个模块。
torch.cuda.amp.autocast:在选择的区域中自动进行数据精度之间的转换,即提高了运算效率,又保证了网络的性能。
torch.cuda.amp.GradScaler:来解决数据溢出问题,即数据溢出问题:Overflow / Underflow
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
with autocast():
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
三、哪些运算操作可以自动转换,哪些不可以
首先:只有 CUDA 操作有资格进行自动转换
下面这些操作可以自动转换为float16:
matmul, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell
下面这些操作可以自动转换为float32:
pow, rdiv, rpow, rtruediv, acos, asin, binary_cross_entropy_with_logits, cosh, cosine_embedding_loss, cdist, cosine_similarity, cross_entropy, cumprod, cumsum, dist, erfinv, exp, expm1, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, log10, log1p, log2, margin_ranking_loss, mse_loss, multilabel_margin_loss, multi_margin_loss, nll_loss, norm, normalize, pdist, poisson_nll_loss, pow, prod, reciprocal, rsqrt, sinh, smooth_l1_loss, soft_margin_loss, softmax, softmin, softplus, sum, renorm, tan, triplet_margin_loss
有些操作并没有指定是float16还是float32,但是需要输入的数据类型一致,如果所有的输入都是float16,操作就是在float16中进行,如果输入中的任何一个是float32,操作就是在float32中进行。
四、遇到不能自动转换的操作怎么办
例如下面这句代码,在自动转换的区域,where操作中,tensor phi是float16,但是cosine是float32,where操作没有自动转换的能力,因此就会出现数据类型匹配,报错!
注意:下面的代码在非混合精度训练中,没有问题,因为所有生成的Tensor数据都是float32类型
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
把上面一句代码改为下面的代码就可以了
phi = torch.where(cosine.to(dtype=phi.dtype) > self.th, phi, cosine.to(dtype=phi.dtype) - self.mm)
更多推荐










所有评论(0)