Oscillation-Reduced MXFP4 Training for Vision Transformers
作者: Yuxiang Chen, Haocheng Xi, Jun Zhu, Jianfei Chen
发布时间: 2025-03-03
来源: arxiv
研究方向: 机器学习,低精度训练,视觉Transformer
主要内容
本文提出了TetraJet,一种基于MXFP4格式的视觉Transformer训练方法,旨在提高4位精度训练的准确性,并通过引入EMA Quantizer(Q-EMA)和Adaptive Ramping Optimizer(Q-Ramping)来减少训练过程中的振荡问题。
主要贡献
1. 提出了TetraJet,一种新的基于MXFP4的视觉Transformer训练方法,提高了4位精度训练的准确性。
2. 识别了前向传播中权重振荡是MXFP4训练精度损失的主要原因,并提出了Q-EMA和Q-Ramping来解决振荡问题。
3. 通过实验验证了TetraJet在视觉Transformer上的有效性,并证明了Q-EMA和Q-Ramping可以进一步提高性能。
4. 将MXFP4训练与全精度训练的性能差距缩小了50%以上。
研究方法
1. MXFP4格式量化
2. 双量化策略
3. 无截断缩放方法
4. 随机舍入
5. EMA量化器(Q-EMA)
6. 自适应调整优化器(Q-Ramping)
实验结果
在视觉Transformer预训练实验中,TetraJet在准确性方面优于现有的4位训练方法,Q-EMA和Q-Ramping通过减少振荡提供了额外的提升。与基线相比,准确性损失降低了50%以上,甚至可以达到与全精度训练相媲美的性能。
未来工作
探索更有效的低精度训练方法,提高MXFP4训练的准确性和效率,并进一步研究振荡问题的解决方案。