Group-robust Sample Reweighting for Subpopulation Shifts via Influence Functions

作者: Rui Qiao, Zhaoxuan Wu, Jingtan Wang, Pang Wei Koh, Bryan Kian Hsiang Low

发布时间: 2025-03-11

来源: arxiv

研究方向: 机器学习中的群体鲁棒性

主要内容

该研究提出了一种名为 Group-robust Sample Reweighting (GSR) 的方法,旨在提高机器学习模型对群体偏移的鲁棒性。GSR 通过利用群体标签数据来优化未标记数据的权重,从而提高模型的泛化能力。

主要贡献

1. 提出了一种新的群体鲁棒样本重加权方法 GSR,该方法通过使用群体标签数据来优化未标记数据的权重,从而提高模型的鲁棒性。

2. 设计了基于隐式微分和影响函数的样本重加权策略,通过使用 Hessian 矩阵来准确估计样本权重更新的梯度,无需回溯整个训练轨迹。

3. 在视觉和自然语言处理任务上进行了实验,证明了 GSR 在提高群体鲁棒性方面的有效性。

研究方法

1. Group-robust Sample Reweighting (GSR):一种两阶段方法,首先从群体未标记数据中学习表示,然后使用影响函数迭代地重新训练模型的最后一层。

2. Last-layer Retraining (LLR):一种轻量级方法,通过仅重新训练神经网络的最后一层来减少虚假相关性并提高群体鲁棒性。

3. Influence Function:一种技术,可以估计如果将训练数据点无限小地加权,模型参数将如何变化。

4. Adaptive Aggregation:一种方法,通过根据群体错误率乘积更新聚合权重,以更好地优化样本权重。

实验结果

GSR 在多个数据集上进行了实验,包括 Waterbirds、CelebA、MultiNLI 和 CivilComments。结果表明,GSR 在提高群体鲁棒性方面优于现有的最先进方法,包括需要更多群体标签的方法。

未来工作

未来工作将探索如何提高 GSR 的性能,包括:1) 缓解最坏群体性能和平均性能之间的权衡;2) 直接提高表示学习质量;3) 开发更有效的策略,以同时优化样本权重和深度模型参数。