[CVPR2024]DeiT-LT Distillation Strikes Back for Vision Transformer Training on Long-Tailed Datasets

news/2024/9/22 20:50:15

在长尾数据集上,本文引入强增强(文中也称为OOD)实现对DeiT的知识蒸馏的改进,实现尾部类分类性能的提升。

动机

  1. ViT相较于CNN缺少归纳偏置,如局部性(一个像素与周围的区域关系更紧密)、平移不变性(图像的主体在图像的任意位置都应该一样重要)。因此需要大型数据集进行预训练。
  2. 长尾数据学习的工作有很多,例如Re-weighting、Re-sampling,这些方法可以帮助尾部类学习,但会损害头部类的表现。一些工作提出了“多专家”去专攻不同的类别,最后汇总预测结果得到最终输出以改善性能。但是这些方法都是基于CNN的,而本文将“多专家”的思想代入transformer架构(ViT)。

对于第1点,之前的工作——DeiT通过对预训练的CNN架构模型知识蒸馏,改进了ViT的效率。而在本文中,作者希望将这一知识蒸馏的思想应用到长尾数据集上,并提高尾部类的分类性能。

还有一些工作虽然使得 ViT 在长尾识别任务上的性能有所提高,但它们通常需要在大规模数据集预训练。在这项工作中,目标是从头开始研究和改进视觉变换器的训练,而不需要对不同的长尾数据集(图像大小和分辨率各不相同)进行大规模预训练。

方法

回顾下DeiT[1],相较于ViT,多了DIS(distillation) token,它是教师模型对x的预测结果,作为\(\mathcal{L}_\text{teacher}\)的标签输入。

本文中的DeiT-LT,是在DeiT架构基础上:

  • 对于输入的样本使用强增强(文中管强增强后的样本为OOD样本)。
  • 增加了一个分类器用来表示尾部类专家,使用DRW(Deferred Re-weighting)loss优化。
  • 通过蒸馏,从扁平的教师模型学习低秩特征。

引入OOD样本的蒸馏

表中比较了教师、学生模型,是否使用强图像增强,和是否使用mixup(用X、√表示)的精度表现。

\[\begin{array}{c|cc|ccc|c}\hline\text{Tch} & \text{Stu} & \text{Tch} & \text{Tch} & \text{Stu} & \text{Train} \\\textbf{Model} & \text{Augs.} & \text{Augs.} & \text{Acc.} & \text{Acc.} & \text{Time} \\\hline\text{RegNetY16GF} & \text{Strong}\left(\checkmark\right) & \text{Strong}\left(\checkmark\right) & \text{79.1} & \text{70.2} & \text{33.3} \\\hline\text{ResNet-32} & \text{Strong}\left(X\right) & \text{Weak}\left(X\right) & \text{97.2} & \text{54.2} & \text{17.8} \\& \text{Strong}\left(X\right) & \text{Strong}\left(X\right) & \text{71.9} & \text{69.6} & \text{17.8} \\& \text{Strong}\left(\checkmark\right) & \text{Strong}\left(\checkmark\right) & \text{56.6} & \text{79.4} & \text{19.0} \end{array} \]

可看到,使用强增强训练的教师模型虽然精度下降了,但是学生模型的表现提升了。作者认为,这是学生模仿了老师对OOD的错误预测,从而学习到了老师的归纳偏置(例如下图的局部性),即\(f^d(X)\approx g(X),X\sim A(x)\)

由于OOD样本对教师的影响,使得教师模型的预测\(y_t\)与ground-truth \(y\)不同。下图展示了,随着epoch增加,头部类专家和尾部类专家的余弦距离(1-余弦相似度)的变化。两条主要的线分别表示在强增强/非强增强(OOD/ID)训练得到的教师模型。

这也揭示了一个现象,CLS token与预测的DIS token不那么一致,也能进行有效蒸馏。

作者引入了DRW:

\[\mathcal{L}=\frac12\mathcal{L}_{CE}(f^c(x),y)+\frac12\mathcal{L}_{DRW}(f^d(x),y_t),\ \mathrm{where~}\mathcal{L}_{DRW}=-w_{y_t} log(f^d(x)_{y_t}) \]

其中,\(w_y=1/\{1+(e_y-1)\mathbb{1}_{\mathrm{epoch\geq K}}\}\)\(e_y=\frac{1-\beta^{N_y}}{1-\beta}\),在上图也可以看到,DRW进一步增加了CLS token与DIS token间的多样性。

证明引入强增强(OOD)对知识蒸馏的有效性,可通过尾部类特征来判断。下图表示,尾部类特征的平均attention distance与transformer头的关系。可以看到没有OOD蒸馏的ViT和DeiT,过拟合了虚特征,使得尾部类泛化较差。

通过SAM 教师模型得到低秩特征

Sharpness Aware Minimization (SAM) 相当于在计算损失时,对模型参数增加扰动提高模型的泛化性。

对于低秩矩阵的计算,令\(\mathcal{X}_{all},\mathcal{X}_{min}\subset\mathcal{X}\),其中\(\mathcal{X}_{all}\)表示所有样本,\(\mathcal{X}_{min}\)表示尾部类样本。对应的特征矩阵为\(F_{n_h,d}^{all},\ F_{n_t,d}^{min}\),n表示样本数,d表示特征维度。对前者进行奇异值分解\(U,S,V^T=\mathsf{SVD}(F_{n_h,d}^{all})\),并使用右奇异值矩阵对\(F_{n_t,d}^{min}\)进行投影降维。对角阵k的取值满足

\[\frac{\left\|F_{n_t,d}^{min}-F_{recon}^{min}(k)\right\|^2}{\left\|F_{n_t,d}^{min}\right\|^2}\leq0.01 \]

其中\(F_{recon}^{min}(k)=F_{proj}^{min}(k)*{V_k}^T.\)

对比CLS token和DIS token在不同block中输出特征的秩。可以看到DIST token从多数类中学到判别性特征,充分保证了尾部类的学习。

参考文献

  1. Touvron, Hugo, et al. "Training data-efficient image transformers & distillation through attention." International conference on machine learning. PMLR, 2021.
  2. Rangwani, Harsh, et al. "DeiT-LT: Distillation Strikes Back for Vision Transformer Training on Long-Tailed Datasets." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2024.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/63471.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

MobaXterm24.2 分析

MobaXterm 目录MobaXterm0、启动窗口 TForm11、TForm1_FormCreatedecrypt_9FDA481)xxBase64Decode_9FD80C2)DecryptBytes_9FD9DC2、许可结构1) Type2) version_info_3A83) user_limit4) Version5) unuse6)NoGames7)NoPlugins解析函数parse_9FEB5Cothersub_A03F80TFormAbout…

ABC372 F 题解

ABC372 F 题解F - Teleporting Takahashi 2 先把问题转化一下:把环断开成链,复制 \((K + 1)\) 层,每走一步就相当于前进一层:可以想到一个简单的 dp:设 \(f(i, j)\) 表示走到第 \(i\) 层第 \(j\) 个位置的方案数。初始化:\(f(0, 1) = 1\),其它均为 \(0\),表示 Takahash…

【做题笔记】收集邮票 做题笔记

水。P4550 收集邮票展开目录 目录P4550 收集邮票ReadingStep 1Step 2Code彩蛋Reading \(k\ge 1\) 时,可以通过支付 \(k\) 元钱获得一张 \(n\) 种邮票中的某种邮票。这 \(n\) 种邮票等概率出现,求买到全部 \(n\) 种邮票的花费期望。 Step 1 \(k\) 次 \(k\) 元太难搞了,干脆直…

单机版 ClickHouse 部署和 SpringBoot 程序访问

ClickHouse 是俄罗斯的 Yandex 于 2016 年开源的列式存储数据库(DBMS),使用C++语言编写,主要用于在线分析处理查询(OLAP),能够使用SQL查询实时生成分析数据报告。 OLAP 为联机分析处理,专注于统计查询;OLTP 为联机事务处理,专注于增删改。 ClickHouse 的优势在于单表…

用户验收测试指南8实施测试

8 实施测试 到目前为止,我们已经规划了我们的 UAT 演习,并制定了测试的总体战略,然后设计了所有测试并编写了测试脚本。现在,我们已准备好实施计划和进行测试。 在本章中,我们将介绍如何安排所有测试,以实现我们的测试策略,并根据验收标准评估系统。为此,我们需要记录所…

Linux中删除文本中所有的重复的字符保持唯一

001、[root@PC1 test]# ls a.txt [root@PC1 test]# cat a.txt ## 测试文本 abk akkkkccc 8777 ,,, aaaf 333444 --- uukk22 [root@PC1 test]# cat a.txt | tr -s [:alnum:] ## 删除连续的重复字符 abk akc 87 ,,, af 34 --- uk2 [root@PC1 test]# …

ConcurrentLinkedQueue详解(图文并茂)

前言 ConcurrentLinkedQueue是基于链接节点的无界线程安全队列。此队列按照FIFO(先进先出)原则对元素进行排序。队列的头部是队列中存在时间最长的元素,而队列的尾部则是最近添加的元素。新的元素总是被插入到队列的尾部,而队列的获取操作(例如poll或peek)则是从队列头部…