Oscillation-Reduced MXFP4 Training for Vision Transformers

作者: Yuxiang Chen, Haocheng Xi, Jun Zhu, Jianfei Chen

发布时间: 2025-03-03

来源: arxiv

研究方向: 机器学习,低精度训练,视觉Transformer

主要内容

本文提出了TetraJet,一种基于MXFP4格式的视觉Transformer训练方法,旨在提高4位精度训练的准确性,并通过引入EMA Quantizer(Q-EMA)和Adaptive Ramping Optimizer(Q-Ramping)来减少训练过程中的振荡问题。

主要贡献

1. 提出了TetraJet,一种新的基于MXFP4的视觉Transformer训练方法,提高了4位精度训练的准确性。

2. 识别了前向传播中权重振荡是MXFP4训练精度损失的主要原因,并提出了Q-EMA和Q-Ramping来解决振荡问题。

3. 通过实验验证了TetraJet在视觉Transformer上的有效性,并证明了Q-EMA和Q-Ramping可以进一步提高性能。

4. 将MXFP4训练与全精度训练的性能差距缩小了50%以上。

研究方法

1. MXFP4格式量化

2. 双量化策略

3. 无截断缩放方法

4. 随机舍入

5. EMA量化器(Q-EMA)

6. 自适应调整优化器(Q-Ramping)

实验结果

在视觉Transformer预训练实验中,TetraJet在准确性方面优于现有的4位训练方法,Q-EMA和Q-Ramping通过减少振荡提供了额外的提升。与基线相比,准确性损失降低了50%以上,甚至可以达到与全精度训练相媲美的性能。

未来工作

探索更有效的低精度训练方法,提高MXFP4训练的准确性和效率,并进一步研究振荡问题的解决方案。