Paper Title
Sharpness-Aware Training for Free
Paper Authors
Paper Abstract
Modern deep neural networks (DNNs) have achieved state-of-the-art performances but are typically over-parameterized. The over-parameterization may result in undesirably large generalization error in the absence of other customized training strategies. Recently, a line of research under the name of Sharpness-Aware Minimization (SAM) has shown that minimizing a sharpness measure, which reflects the geometry of the loss landscape, can significantly reduce the generalization error. However, SAM-like methods incur a two-fold computational overhead of the given base optimizer (e.g., SGD) for approximating the sharpness measure. In this paper, we propose Sharpness-Aware Training for Free, or SAF, which mitigates the sharp landscape at almost zero additional computational cost over the base optimizer. Intuitively, SAF achieves this by avoiding sudden drops in the loss in the sharp local minima throughout the trajectory of the updates of the weights. Specifically, we suggest a novel trajectory loss, based on the KL-divergence between the outputs of DNNs with the current weights and past weights, as a replacement for SAM's sharpness measure. This loss captures the rate of change of the training loss along the model's update trajectory. By minimizing it, SAF ensures convergence to a flat minimum with improved generalization capabilities. Extensive empirical results show that SAF minimizes the sharpness in the same way that SAM does, yielding better results on the ImageNet dataset with essentially the same computational cost as the base optimizer.
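The abstract describes SAF's trajectory loss as a KL-divergence between the model's current outputs and outputs recorded from past weights, added on top of the ordinary task loss. Below is a minimal PyTorch sketch of that idea, assuming cached past logits and hypothetical hyperparameters `lam` (trade-off weight) and `tau` (softmax temperature); it illustrates the concept and is not the authors' official implementation.

```python
# Sketch of a SAF-style trajectory loss: task loss + KL-divergence between
# predictions from past weights (cached, no gradient) and current weights.
import torch
import torch.nn.functional as F

def saf_loss(logits_now, logits_past, targets, lam=0.3, tau=2.0):
    """logits_now:  outputs of the model with the current weights
    logits_past: outputs cached from an earlier point on the update trajectory
    lam, tau:    assumed trade-off weight and temperature (hypothetical values)
    """
    task_loss = F.cross_entropy(logits_now, targets)
    # KL(past || current): penalizes abrupt changes of the predictions along
    # the update trajectory, which the paper relates to sharp minima.
    trajectory_loss = F.kl_div(
        F.log_softmax(logits_now / tau, dim=1),
        F.softmax(logits_past.detach() / tau, dim=1),
        reduction="batchmean",
    ) * (tau ** 2)
    return task_loss + lam * trajectory_loss

# Usage per mini-batch: look up the cached logits for the same samples from a
# past epoch, then add the trajectory term at almost no extra compute.
#   logits_past = cached_logits[batch_indices]
#   loss = saf_loss(model(x), logits_past, y)
#   loss.backward(); optimizer.step()
```

Because the past logits are simply read from a cache rather than recomputed, this term requires no second forward-backward pass, which is how SAF avoids SAM's two-fold overhead over the base optimizer.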