张量感知量化杂谈

news/2024/10/2 3:23:19
感知量化训练 QAT
传统的训练后量化将模型从 FP32 量化到 INT8 精度时会产生较大的数值精度损失。感知量化训练(Aware Quantization Training)通过在训练期间模拟量化操作,可以最大限度地减少量化带来的精度损失。
QAT 的流程,如图7-9所示,首先基于预训练好的模型获取计算图,对计算图插入伪量化算子。准备好训练数据进行训练或者微调,在训练过程中最小化量化误差,最终得到 QAT 之后对神经网络模型。QAT 模型需要转换去掉伪量化算子,为推理部署做准备。
图7-9 QAT的流程方法
QAT 时会往模型中插入伪量化节点 FakeQuant 来模拟量化引入的误差。端测推理的时候折叠伪量化节点中的属性到张量中,在端侧推理的过程中直接使用张量中带有的量化属性参数。
1. 伪量化节点
在 QAT 过程中,所有权重和偏差都以 FP32 格式存储,反向传播照常进行。然而,在正向传播中,通过 FakeQuant 节点模拟量化。之所以称之为“fake”量化,是因为它们对数据进行量化并立即反量化,添加了类似于在量化推理过程中可能遇到的量化噪声,以模拟训练期间量化的效果。最终损失 loss 值因此包含了预期内的量化误差,使得将模型量化为 INT8 不会显著影响精度。
伪量化节点通常插入在模型的以下关键部分:
1)卷积层(Conv2D)前后:这可以帮助卷积操作在量化后适应低精度计算。
2)全连接层(Fully Connected Layer)前后:这对于处理密集矩阵运算的量化误差非常重要。
3)激活函数(如 ReLU)前后:这有助于在非线性变换中保持量化精度。
这些插入位置可以确保模型在训练期间模拟量化引入的噪声,从而在推理阶段更好地适应量化环境。
下面是一个计算图,同时对输入和权重插入伪量化算子,如图7-10所示。
图7-10 同时对输入和权重插入伪量化算子方法
伪量化节点的作用:
1)找到输入数据的分布,即找到 MIN 和 MAX 值;
2)模拟量化到低比特操作的时候的精度损失,把该损失作用到网络模型中,传递给损失函数,让优化器去在训练过程中对该损失值进行优化。
2. 正向传播
在正向传播中,FakeQuant 节点将输入数据量化为低精度(如 INT8),进行计算后再反量化为浮点数。这样,模型在训练期间就能体验到量化引入的误差,从而进行相应的调整。为了求得网络模型张量数据精确的 Min 和 Max 值,因此在模型训练的时候插入伪量化节点来模拟引入的误差,得到数据的分布。对于每一个算子,量化参数通过下面的方式得到:
𝑄=𝑅𝑆+𝑍𝑆=𝑅𝑚𝑎𝑥−𝑅𝑚𝑖𝑛𝑄𝑚𝑎𝑥−𝑄𝑚𝑖𝑛𝑍=𝑄𝑚𝑎𝑥−𝑅𝑚𝑎𝑥𝑆
FakeQuant 量化和反量化的过程:
𝑄(𝑥)=𝐹𝑎𝑘𝑒𝑄𝑢𝑎𝑛𝑡(𝑥)=𝐷𝑒𝑄𝑢𝑎𝑛𝑡(𝑄𝑢𝑎𝑛𝑡(𝑥))=𝑠∗(𝐶𝑙𝑎𝑚𝑝(𝑟𝑜𝑢𝑛𝑑(𝑥/𝑠)−𝑧)+𝑧)
原始权重为 W,伪量化之后得到浮点值 Q(W),同理得到激活的伪量化值 Q(X)。这些伪量化得到的浮点值虽然表示为浮点数,但仅能取离散的量化级别。
3. 反向传播
在反向传播过程中,模型需要计算损失函数相对于每个权重和输入的梯度。梯度通过 FakeQuant 节点进行传递,这些节点将量化误差反映到梯度计算中。模型参数的更新因此包含了量化误差的影响,使模型更适应量化后的部署环境。按照正向传播的公式,因为量化后的权重是离散的,反向传播的时候对 𝑊 求导数为 0:
𝜕𝑄(𝑊)𝜕𝑊=0
因为梯度为 0,所以网络学习不到任何内容,权重 𝑊 也不会更新:
𝑔𝑊=𝜕𝐿𝜕𝑊=𝜕𝐿𝜕𝑄(𝑊)⋅𝜕𝑄(𝑊)𝜕𝑊=0
如图7-11所示,使用直通估计器(Straight-Through Estimator,简称 STE),简单地将梯度通过量化传递,近似来计算梯度。这使得模型能够在前向传播中进行量化模拟,但在反向传播中仍然更新高精度的浮点数参数。STE 近似假设量化操作的梯度为 1,从而允许梯度直接通过量化节点:
𝑔𝑊=𝜕𝐿𝜕𝑊=𝜕𝐿𝜕𝑄(𝑊)
图7-11 前向传播与反向传播计算梯度
如果被量化的值在 [𝑥𝑚𝑖𝑛,𝑥𝑚𝑎𝑥] 范围内,STE 近似的结果为 1,否则为 0。这种方法使模型能够在训练期间适应量化噪声,从而在实际部署时能够更好地处理量化误差。
4. 量化感知训练的技巧
1)从已校准的表现最佳的 PTQ 模型开始
与其从未训练或随机初始化的模型开始量化感知训练,不如从已校准的 PTQ 模型开始,这样能为 QAT 提供更好的起点。特别是在低比特宽量化情况下,从头开始训练可能会非常困难,而使用表现良好的 PTQ 模型可以帮助确保更快的收敛和更好的整体性能。
2)微调时间为原始训练计划的 10%
量化感知训练不需要像原始训练那样耗时,因为模型已经相对较好地训练过,只需要调整到较低的精度。一般来说,微调时间为原始训练计划的 10% 是一个不错的经验法则。
3)使用从初始训练学习率 1% 开始的余弦退火学习率计划
①为了让 STE 近似效果更好,最好使用小学习率。大学习率更有可能增加 STE 近似引入的方差,从而破坏已训练的网络。
②使用余弦退火学习率计划可以帮助改善收敛,确保模型在微调过程中继续学
习。从较低的学习率(如初始训练学习率的 1%)开始有助于模型更平稳地适应较低
的精度,从而提高稳定性。直到达到初始微调学习率的 1%(相当于初始训练学习率
的 0.01%)。在 QAT 的早期阶段使用学习率预热和余弦退火可以进一步提高训练
的稳定性。
4)使用带动量的 SGD 优化器而不是 ADAM 或 RMSProp
尽管 ADAM 和 RMSProp 是深度学习中常用的优化算法,但它们可能不太适合量化感知微调。这些方法会按参数重新缩放梯度,可能会扰乱量化感知训练的敏感性。使用带动量的 SGD 优化器可以确保微调过程更加稳定,使模型能够更有控制地适应较低的精度。
通过 QAT,深度学习模型能够在保持高效推理的同时,尽量减少量化带来的精度损失,是模型压缩和部署的重要技术之一。在大多数情况下,一旦应用量化感知训练,量化推理精度几乎与浮点精度完全相同。然而,在 QAT 中重新训练模型的计算成本可能是数百个 epoch。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/55541.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

张量量化方法杂谈

量化方法对比 QAT 可以达到的精度较高,但是往往需要较多的量化训练时间,量化成本比较大。PTQ 的量化过程比较迅速,只需要少量数据集来校准,但是量化后精度往往损失较多,见表7-1。 表7-1 量化方法参数对比量化方法功能经典适用场景使用条件易用性精度损失预期收益量化训练 …

深度学习模型优化杂谈

深度学习模型优化概述 模型压缩跟轻量化网络模型不同,压缩主要是对轻量化或者非轻量化模型执行剪枝、蒸馏、量化等压缩算法和手段,使得模型更加小、更加轻便、更加利于执行。 基本介绍 随着神经网络模型的复杂性和规模不断增加,模型对存储空间和计算资源的需求越来越多,使得…

【新型进程注入】一种绕过EDR和XDR的新型进程注入技术

1.技术研究背景 2.进程注入 3.Windows API监控 4.认识下新的注入技术 5.查找有漏洞的DLL 6.注射方法 7.自注入和EDR脱钩 8.远程进程注入 9.安全影响 10.检测方法 11.结论1.技术研究背景 进程注入是指将恶意代码注入到某个目标进程内存空间, 使攻击者能够隐藏恶意代码并逃避检测…

【漏洞挖掘技巧】编码绕过

demo说明 初始请求: url/?f=etc/passwd在这里,我尝试通过URL参数f访问/etc/passwd文件,但服务器返回了403状态码,表示访问被禁止。这是因为Web应用防火墙(WAF)检测到请求试图访问敏感文件,因此拦截了该请求 Base64编码:为了绕过WAF的检测,我们对字符串/etc/passwd进行…

给 Minecraft 的 MTR 网页地图换个字体

介绍 呃。。。其实就是感觉 MTR 自带的那个地图有点不太合自己口味,就简单拆包改一下 CSS! 修改 找到 Jetty 目录 右键用压缩包打开 .jar,依次打开 assets/mtr/website 即网页。我这里就简单的改了下 index.css 和 utilities.js。还需要单独改一个 JS 文件是因为 MTR 的地图…

VR虚拟驾驶未来发展以及汽车自动驾驶的特点

在自动驾驶汽车的基础上,VR虚拟现实技术的应用也让自动驾驶汽车更加智能化,能够实现更高级的驾驶体验,今天这篇文章就和大家一起探讨一下 VR虚拟驾驶未来发展的趋势,以及虚拟现实自动驾驶汽车所带来的几个改变。在自动驾驶汽车的基础上,VR虚拟现实技术的应用也让自动驾驶汽…

对 Minecraft 的 Dynmap 做一些小美化

介绍 Dynmap 是 Minecraft 中以网页 Web 形式呈现地图的模组,和 BlueMap 等类似。我自己倒是 Dynmap 用多了感觉更习惯一些就一直用下去了,虽然如今 BlueMap 之类的确实更先进。 LiveAtlas LiveAtlas 是 Dynmap 的第三方皮肤扩展,下载好后直接导入 dynmap/web 文件夹覆盖即可…

docker 环境下的 iptables 复杂配置

最近在项目中,遇到了一个比较辣手的 iptables 规则配置问题。记录一下简化一下问题: 本文中使用到 docker-compose 服务启动示例如下,虚拟机 IP 为 192.168.111.138 services:db:image: mariadb:ltscontainer_name: mysql_zdprestart: alwaysports:- "3306:3306"en…