# QAT follows the same steps as PTQ, except that a training loop runs before
# the model is actually converted to its quantized version.
'''Quantization-aware training steps:
step 1. Build the model
step 2. Fuse modules (optional)
step 3. Insert stubs (steps 1 and 3 can be combined)
step 4. Prepare (mainly selecting the backend configuration)
step 5. Training loop
step 6. Convert the model
'''
import torch
from torch import nn

backend = "fbgemm"  # running on an x86 CPU. Use "qnnpack" if running on ARM.
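# The runtime kernel engine must match the chosen qconfig backend; pinning it
# explicitly is a safe default (a minimal addition, assuming training and
# inference happen on the same machine).
torch.backends.quantized.engine = backend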

'''step 1. Build the model'''
m = nn.Sequential(
    nn.Conv2d(2, 64, 8),
    nn.ReLU(),
    nn.Conv2d(64, 128, 8),
    nn.ReLU(),
)

"""step 2. Fuse modules (optional)"""
torch.quantization.fuse_modules(m, ['0','1'], inplace=True) # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2','3'], inplace=True) # fuse second Conv-ReLU pair
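
# Optional sanity check (a sketch, not part of the recipe): fusion folds each
# ReLU into its preceding Conv2d and leaves nn.Identity placeholders behind.
# Newer PyTorch versions also offer torch.ao.quantization.fuse_modules_qat
# for models left in training mode.
assert isinstance(m[1], nn.Identity) and isinstance(m[3], nn.Identity)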

"""step 3. Insert quant/dequant stubs around the model"""
m = nn.Sequential(torch.quantization.QuantStub(),
                  *m,
                  torch.quantization.DeQuantStub())

"""step 4. Prepare"""
m.train()
m.qconfig = torch.quantization.get_default_qat_qconfig(backend)  # QAT qconfig (with fake-quant), not the PTQ one
torch.quantization.prepare_qat(m, inplace=True)
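
# Optional inspection (a sketch): prepare_qat swaps each fused Conv-ReLU for a
# QAT variant that carries FakeQuantize modules on weights and activations.
print(m[1])  # expect a qat ConvReLU2d with (weight_fake_quant): FakeQuantize(...)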

"""step 5. Training loop"""
n_epochs = 10
opt = torch.optim.SGD(m.parameters(), lr=0.1)
loss_fn = lambda out, tgt: torch.pow(tgt - out, 2).mean()  # simple MSE loss
for epoch in range(n_epochs):
  x = torch.rand(10, 2, 24, 24)  # dummy random batch; substitute real data here
  out = m(x)
  loss = loss_fn(out, torch.rand_like(out))
  opt.zero_grad()
  loss.backward()
  opt.step()
  print(loss.item())
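
# Optional refinement (an assumption, not part of the original recipe): on
# longer runs it is common to freeze the observers for the last few epochs so
# the learned quantization ranges stop drifting, e.g.:
# m.apply(torch.quantization.disable_observer)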

"""step 6. Convert"""
m.eval()
torch.quantization.convert(m, inplace=True)
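
# Quick post-conversion check (a minimal sketch, reusing the random input
# shape from the training loop): the model now runs quantize -> int8 convs ->
# dequantize end to end.
with torch.no_grad():
    y = m(torch.rand(4, 2, 24, 24))
print(type(m[1]).__name__)  # expect a quantized ConvReLU2d
print(y.dtype)              # torch.float32, restored by the trailing DeQuantStub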