论文标题

冻结然后训练:在虚假相关和特征噪音下进行可证明的代表学习

Freeze then Train: Towards Provable Representation Learning under Spurious Correlations and Feature Noise

论文作者

Ye, Haotian, Zou, James, Zhang, Linjun

论文摘要

在训练环境中的图像背景等虚假相关性的存在可能会使经验风险最小化(ERM)在测试环境中表现不佳。为了解决这个问题,Kirichenko等。 (2022)从经验上发现,即使存在虚假相关性,与结果相关的核心特征仍然可以很好地学习。这为首先训练功能学习者而不是分类器开设了有希望的策略,然后在测试环境中执行线性探测(最后一层再培训)。但是,对这种方法何时以及为什么缺乏这种方法的理论理解。在本文中,我们发现,只有当其相关的不可通知的噪声小于虚假特征时,核心特征才能很好地学习,这在实践中不一定是正确的。我们既提供理论和实验,以支持这一发现并说明不可交流噪声的重要性。此外,我们提出了一种称为冻结然后训练(FTT)的算法,该算法首先冻结某些显着特征,然后使用ERM训练其余功能。从理论上讲,我们表明FTT保留了对测试时间探测更有益的功能。在两个常用的虚假相关数据集中,FTT的表现优于ERM,IRM,JTT和CVAR-DRO,当特征噪声较大时,准确性(4.5%)的准确性大大提高。 FTT在一般分配偏移基准测试方面的性能也更好。

The existence of spurious correlations such as image backgrounds in the training environment can make empirical risk minimization (ERM) perform badly in the test environment. To address this problem, Kirichenko et al. (2022) empirically found that the core features that are related to the outcome can still be learned well even with the presence of spurious correlations. This opens a promising strategy to first train a feature learner rather than a classifier, and then perform linear probing (last layer retraining) in the test environment. However, a theoretical understanding of when and why this approach works is lacking. In this paper, we find that core features are only learned well when their associated non-realizable noise is smaller than that of spurious features, which is not necessarily true in practice. We provide both theories and experiments to support this finding and to illustrate the importance of non-realizable noise. Moreover, we propose an algorithm called Freeze then Train (FTT), that first freezes certain salient features and then trains the rest of the features using ERM. We theoretically show that FTT preserves features that are more beneficial to test time probing. Across two commonly used spurious correlation datasets, FTT outperforms ERM, IRM, JTT and CVaR-DRO, with substantial improvement in accuracy (by 4.5%) when the feature noise is large. FTT also performs better on general distribution shift benchmarks.

扫码加入交流群

加入微信交流群

微信交流群二维码

扫码加入学术交流群,获取更多资源