摘要
1) 一句话总结
锐度感知最小化(SAM)是一种通过在损失曲面上寻找平坦极小值(而非仅仅最小化损失值)来有效提升过参数化深度学习模型泛化能力的优化算法。
2) 关键要点
- 背景与双重下降:现代深度学习模型通常是过参数化的,其泛化能力随模型规模增加会呈现“双重下降”曲线,因此选择合适的优化器对于跨越插值阈值并提升泛化能力至关重要。
- 锐度与泛化能力的关系:损失曲面上的“平坦极小值”(低锐度)与模型良好的泛化能力正相关,而“尖锐极小值”容易导致过拟合。
- SAM 的核心机制:SAM 不直接最小化损失函数 ,而是寻找给定半径 邻域内的最大损失(对抗性扰动),并最小化这个最大损失,从而迫使模型收敛于平坦区域。
- 算法执行步骤:单次 SAM 迭代包含两次前向传播和两次反向传播。第一步计算梯度以寻找对抗性扰动;第二步在扰动点计算梯度,并使用基础优化器更新原始权重。
- 代码实现要求:在 PyTorch 中实现 SAM 需要指定一个基础优化器(如 SGD 或 Adam)以及限制扰动大小的超参数 。
- BatchNorm 处理关键决策:对于包含批量归一化(BatchNorm)层的模型,必须在第二次(基于扰动权重的)前向传播期间显式禁用运行统计数据(running statistics)的更新,以防止污染原始模型的统计信息。
- 实验基准与设置:在 Fashion-MNIST 数据集上使用 PreAct ResNet-18(约 1120 万参数)进行图像分类测试。为保证公平对比,引入了“标准化轮次”(SAM 训练 1 轮对比非 SAM 训练 2 轮)。
- 实验数据结果:在 150 个标准化轮次后,SAM+SGD 的测试准确率略高于纯 SGD(92.5% vs 92.0%),但其泛化差距(训练与测试准确率之差)显著更低(2.3% vs 6.8%),证明其有效缓解了过拟合。
- 最佳应用场景:SAM 在小数据集上微调预训练模型时表现尤为出色。
3) 风险与缺口
- 计算成本翻倍与收敛变慢:由于每次权重更新需要两次完整的前向和反向传播,SAM 的计算开销更大,且在实验中达到近乎完美的训练准确率需要比标准 SGD 更多的训练轮次。
- BatchNorm 统计污染风险:如果在代码实现中遗漏了对 BatchNorm 状态的切换控制,中间步骤的扰动权重会错误地更新模型的全局统计数据,导致模型行为异常。
- 理论验证缺口:虽然实证表明 SAM 提升了泛化能力,但要严格证明模型确实收敛于更平坦的极小值,需要对训练后模型的 Hessian 谱进行复杂的分析,这在常规训练指标中无法直接体现。
正文
使用 SAM 优化深度学习模型
深入探讨锐度感知最小化(Sharpness-Aware-Minimization, SAM)算法及其如何提升现代深度学习模型的泛化能力。
引言:过参数化、泛化能力与 SAM
现代深度学习——特别是在计算机视觉和自然语言处理领域——取得的巨大成功,建立在“过参数化(overparameterized)”模型的基础之上:这些模型的参数量远超完美记忆训练数据所需的数量。从功能上看,当一个模型在特定任务上能够轻松达到近乎完美的训练准确率(接近 100%)且训练损失接近于零时,就可以被诊断为过参数化。
然而,这类模型的实用性取决于它在保留的测试数据上的表现,这些测试数据与训练集来自同一分布,但在训练期间是不可见的。这种属性被称为“泛化能力(generalizability)”——即模型在全新样本上保持性能的能力——这对于任何深度学习模型在实际应用中发挥作用都是至关重要的。
经典的机器学习理论告诉我们,过参数化模型应该会发生灾难性的过拟合,因此泛化能力会很差。然而,过去十年中最令人惊讶的发现之一是,这类模型通常具有非常出色的泛化能力。
这一极具反直觉的现象在一系列论文中得到了研究,始于 Belkin 等人(2018)和 Nakkiran 等人(2019)的开创性工作。他们证明了泛化能力存在一条“双重下降(double descent)”曲线:随着模型规模的增加,泛化能力首先会恶化(正如经典理论所预测的那样),然后在超过一个临界阈值后再次改善——前提是模型使用了合适的优化方法进行训练。
图 1 展示了双重下降曲线的示意图。y 轴表示测试误差(test error)——这是衡量泛化能力的一个指标,误差越低表示泛化能力越好——而 x 轴表示模型参数的数量。正如预期的那样,随着模型规模的增加,训练误差(蓝色虚线)迅速趋近于零。
测试误差(蓝色实线)表现出更有趣的行为:它最初随着模型规模的增加而下降——即第一次下降,由左侧的红圈标出——然后上升至垂直虚线标记的插值阈值(interpolation threshold)处的峰值,此时模型的泛化能力最差。然而,越过这个阈值进入过参数化区域后,测试误差再次下降——即第二次下降,由右侧的红圈标出——并随着参数的增加继续下降。这就是现代深度学习模型所关注的区域。
在机器学习中,人们通过最小化训练数据集上的损失函数来寻找模型参数。但是,仅仅在训练数据集上最小化我们常用的损失函数(如交叉熵),就能保证过参数化模型具有令人满意的泛化特性吗?一般来说,答案是否定的!无论你是想微调一个预训练模型,还是从头开始训练一个模型,优化你的训练算法以确保模型具有足够的泛化能力都是非常重要的。这使得优化器的选择成为一个关键的设计决策。
锐度感知最小化(Sharpness-Aware-Minimization, SAM)——由 Foret 等人(2019)在一篇论文中提出——是一种旨在提高过参数化模型泛化能力的优化器。在本文中,我将对 SAM 进行教学式的回顾,包括:
-
直观理解 SAM 的工作原理及其提高泛化能力的原因。
-
深入探讨该算法,解释其中涉及的关键数学步骤。
-
在训练循环中实现该优化器类的 PyTorch 代码,包括针对带有 BatchNorm 层的模型的一个重要注意事项。
-
快速演示该优化器在使用 ResNet-18 模型进行图像分类任务时提升泛化能力的有效性。
本文使用的完整代码可以在此 Github 仓库中找到——欢迎随意尝试!
锐度(Sharpness)的概念
首先,让我们试着直观地理解为什么仅仅最小化损失函数可能不足以获得最佳的泛化能力。
脑海中浮现出损失曲面(loss landscape)的画面会很有帮助。对于一个大型的过参数化模型,损失曲面具有多个局部和全局极小值。这些极小值周围的局部几何形状在曲面上可能有很大差异。例如,两个极小值可能具有几乎相同的损失值,但它们的局部几何形状却截然不同:一个可能是尖锐的(狭窄的山谷),而另一个可能是平坦的(宽阔的山谷)。
比较这些局部几何形状的一个正式指标是“锐度(sharpness)”。在损失函数为 L(w) 的损失曲面上的任意给定点 w 处,锐度 S(w) 定义为:
让我来拆解一下这个定义。想象你处于损失曲面上的点 w,你对参数进行扰动,使得新的参数始终位于以 w 为中心、半径为 ρ 的球体内。锐度就被定义为在这个扰动族内损失函数的最大变化量。在文献中,由于显而易见的原因,它也被称为最坏方向锐度(worst-direction sharpness)。
人们很容易看出,对于一个尖锐的极小值——一个陡峭、狭窄的山谷——损失函数的值会随着某些方向上的微小扰动而发生剧烈变化,从而导致较高的锐度值。另一方面,对于一个平坦的极小值——一个宽阔的山谷——损失函数的值会随着微小扰动而相对缓慢地变化,从而导致较低的锐度值。因此,锐度提供了一个衡量损失曲面中给定极小值平坦程度的指标。
极小值的局部几何形状(尤其是锐度指标)与所得模型的泛化特性之间存在着深刻的联系。在过去十年中,大量的理论和实证研究致力于阐明这种联系。例如——正如 Keskar 等人(2016)的论文所指出的——具有相似损失函数值的全局极小值,根据其锐度指标的不同,可能会具有截然不同的泛化特性。
从这些研究中得出的基本经验似乎是:更平坦(锐度更低)的极小值与模型更好的泛化能力呈正相关。特别是,如果模型想要具有良好的泛化能力,就应该在训练过程中避免陷入尖锐的极小值。因此,为了训练出一个具有良好泛化能力的模型,需要确保优化过程不仅最小化损失函数,而且还要寻求最大化极小值的平坦度(或者等效地,最小化锐度)。
这正是 SAM 优化器旨在解决的问题,也是我们在下一节中要讨论的内容。
顺便提一下:请注意,上述图景从概念上解释了为什么过参数化模型有可能避免过拟合问题。这是因为大型模型具有丰富的损失曲面,提供了大量具有优异泛化特性的平坦全局极小值。
锐度感知最小化(SAM)算法
让我们回顾一下模型的标准优化过程。它涉及寻找能够最小化在小批量(mini-batch)B 上计算的给定损失函数的模型参数。在每个时间步,计算损失相对于参数的梯度,并根据以下规则更新参数:
与 SGD 或 Adam 不同,SAM 并不直接最小化 L。相反,在损失曲面上的给定点,它首先扫描给定大小 ρ 的邻域,并找到使损失函数最大化的扰动。在第二步中,它最小化这个最大损失函数。这使得优化器能够找到位于具有均匀低损失值的邻域中的参数,从而产生更小的锐度值和更平坦的极小值。
让我们更详细地讨论一下这个过程。SAM 优化器的损失函数为:
其中 ρ 表示扰动大小的上限。使函数 L 最大化的扰动(通常称为对抗性扰动,因为它最大化了常规损失)可以通过注意以下几点来找到:
其中第二个等式是通过在第一步中对扰动函数进行泰勒展开获得的近似值,最后一个等式是因为上一步方括号中的第一项与 ϵ 无关。最后一个等式可以按如下方式求解对抗性扰动:
将其代回 SAM 损失的方程中,可以计算出 SAM 损失在 ϵ 导数的领头阶(leading order)上的梯度:
这是优化过程中最关键的方程。在 ϵ 导数的领头阶上,SAM 损失函数的梯度可以近似为在对抗性扰动点处评估的常规损失函数的梯度。使用上述梯度公式,现在可以执行标准优化器步骤:
这就完成了一次完整的 SAM 迭代。接下来,让我们将该算法从英语翻译成 PyTorch 代码。
训练循环中的 PyTorch 实现
代码块 sam_training_loop.py 中给出了一个带有 SAM 优化器的训练循环的说明性示例。为了具体起见,我们选择了一个通用的图像分类问题,但相同的结构广泛适用于各种计算机视觉和 NLP 任务。SAM 优化器类显示在代码块 sam_optimizer_class.py 中。
请注意,定义 SAM 优化器需要指定两项数据:
-
一个基础优化器(如 SGD 或 Adam),因为 SAM 最终包含一个标准优化器步骤。
-
一个超参数 ρ,它为允许的扰动大小设定了上限。
优化器的单次迭代涉及两次前向传播和两次反向传播。让我们梳理一下 sam_training_loop.py 中代码的关键步骤:
-
第 5 行计算当前小批量 B 的损失函数 L(w, B)——第一次前向传播。
-
第 6 行计算损失函数 L(w, B) 的梯度——第一次反向传播。
-
第 7 行调用 SAM 优化器类(见下文)中的
sam_optimizer.first_step函数,该函数使用上述公式计算对抗性扰动,并如前所述对模型的权重进行扰动。 -
第 10 行计算扰动后模型的损失函数——第二次前向传播。
-
第 11 行计算扰动后模型的损失函数的梯度——第二次反向传播。
-
第 12 行调用优化器类(见下文)中的
sam_optimizer.second_step函数,该函数将权重恢复为 w_t,然后使用基础优化器利用在扰动点计算出的梯度来更新权重 w_t。
注意事项:SAM 与 BatchNorm
如果在训练循环中部署 SAM,且模型包含任何带有批量归一化(batch-normalization, BatchNorm)层的模块,则需要牢记一个重要的一点。在训练期间,BatchNorm 使用当前批次的统计数据执行归一化,并在每次前向传播时更新运行统计数据(running statistics)。在评估期间,它使用运行统计数据。
现在,正如我们上面看到的,SAM 每次迭代涉及两次前向传播。在第一次传播中,BatchNorm 以标准方式工作。然而,在第二次传播期间,我们使用扰动后的权重来计算损失,而代码块 sam_training_loop.py 中朴素的训练函数将允许 BatchNorm 层在第二次传播期间也更新运行统计数据。这是不可取的,因为运行统计数据应该只反映原始模型的行为,而不是扰动后模型的行为,后者只是计算梯度的中间步骤。因此,必须在第二次传播期间显式禁用运行统计数据的更新,并在下一次迭代之前启用它。
为此,我们将在训练循环中使用两个显式函数 disable_bn_stats 和 enable_bn_stats——代码块 running_stat.py 中展示了此类函数的简单示例——它们切换 PyTorch 中 BatchNorm 函数的 track_running_stats 参数(第 4 行和第 9 行)。修改后的训练循环在代码块 mod_train.py 中给出。
演示:使用 ResNet-18 进行图像分类
最后,让我们通过一个具体的例子来演示 SAM 优化如何提高模型的泛化能力。我们将考虑一个使用 Fashion-MNIST 数据集(MIT 许可证)的图像分类问题:它包含 60,000 张训练图像和 10,000 张测试图像,分为 10 个互斥的类别,每张图像都是 28*28 像素的灰度图。
作为分类器模型,我们将选择一个没有任何预训练的 PreAct ResNet-18。虽然讨论精确的 ResNet-18 架构与我们的目的关系不大,但让我们回顾一下,该模型由一系列构建块组成,每个构建块由卷积层、BatchNorm 层、带有跳跃连接(skipped connections)的 ReLU 激活函数组成。PreAct(预激活)表示在每个块中,激活函数(ReLU)位于卷积层之前。对于标准的 ResNet-18,情况正好相反。我建议读者参考论文——He 等人(2015)——以获取有关该架构的更多详细信息。
然而,需要注意的是,该模型拥有约 1120 万个参数,因此从经典机器学习的角度来看,它是一个过参数化模型,参数与样本的比例约为 186:1。此外,由于该模型包含 BatchNorm 层,在使用 SAM 时,我们必须小心地在第二次传播时禁用运行统计数据。
我们现在准备进行以下实验。我们首先使用标准的 SGD 优化器在 Fashion-MNIST 数据集上训练模型,然后使用以相同的 SGD 作为基础优化器的 SAM 优化器进行训练。我们将考虑一个简单的设置,固定学习率 lr=0.05,动量(momentum)和权重衰减(weight-decay)均设置为零。SAM 中的超参数 ρ 设置为 0.05。所有运行均在单张 A100 GPU 上执行。
由于每次 SAM 权重更新都需要两个反向传播步骤——一个用于计算扰动,另一个用于计算最终梯度——为了公平比较,每次非 SAM 训练运行执行的轮数(epochs)必须是每次 SAM 训练运行的两倍。因此,我们将不得不把 SAM 训练运行一轮的指标与非 SAM 训练运行两轮的指标进行比较。我们将此称为“标准化轮次(standardized epoch)”,在标准化轮次记录的指标将被标记为 metric_st。我们将实验限制在 150 个标准化轮次,这意味着 SAM 训练运行 150 轮,而非 SAM 训练运行 300 轮。我们将对经过 SAM 优化的模型额外训练 50 轮,以了解模型在更长时间训练下的表现。
为了检查哪种优化器能提供更好的泛化能力,我们将在每个标准化训练轮次后比较以下两个指标:
-
测试准确率(Test accuracy):模型在测试数据集上的表现。
-
泛化差距(Generalizability gap):训练准确率与测试准确率之间的差值。
测试准确率是衡量模型在一定数量的训练轮次后泛化能力好坏的绝对指标。另一方面,泛化差距是一种诊断指标,它告诉你模型在特定训练阶段的过拟合程度。
让我们首先比较 training_loss_st 和 training_accuracy_st 图表,如图 3 所示。正如对过参数化模型的预期,使用 SGD 的模型在 150 轮内达到了接近零的损失和接近 99% 的训练准确率。很明显,与 SGD 相比,SAM 的训练速度较慢,并且需要更多的标准化轮次才能达到近乎完美的训练准确率。这一点显而易见,因为当对 SAM 优化模型进行超过规定的 150 轮的训练时,训练损失和训练准确率都在继续改善。
测试准确率。图 4 中的图表比较了两种情况在每个标准化轮次后的测试准确率。
SGD 优化模型在第 50 轮左右达到 92% 的测试准确率,并在接下来的 100 轮中稳定在该值附近。SAM 优化模型在训练的初始阶段(直到大约 80 轮)泛化能力较差,这从该阶段与 SGD 图表相比更低的测试准确率中可以明显看出。然而,在大约第 80 轮时,它赶上了 SGD 图表,并最终以微弱优势超越了它。
对于这次特定的运行,在 150 轮结束时,SAM 的测试准确率为 test_SAM = 92.5%,而 SGD 的测试准确率为 test_SGD = 92.0%。请注意,尽管此时 SAM 训练的模型具有低得多的训练准确率和训练损失,但仍取得了这样的结果。如果将 SAM 模型再训练 50 轮,测试准确率会略微提高到 92.7%。
泛化差距。在训练过程中,每个标准化轮次后泛化差距的演变如图 5 所示。
SGD 模型的差距随着训练稳步增长,在 150 轮后达到 gap_SGD=6.8%,而 SAM 的差距增长要慢得多,达到 gap_SAM=2.3%。在进一步训练 50 轮后,SAM 的差距攀升至 3% 左右,但与 SGD 的值相比仍然低得多。
虽然在 Fashion-MNIST 数据集上,两种优化器之间的测试准确率差异很小,但在泛化差距上存在着不可忽视的差异,这证明了使用 SAM 进行优化可以带来更好的泛化能力。
结语
在本文中,我以教学的方式回顾了 SAM 作为一种显著提高过参数化深度学习模型泛化能力的优化器。我们讨论了 SAM 背后的动机和直觉,逐步分解了该算法,并研究了一个简单的例子,证明了它与标准 SGD 优化器相比的有效性。
SAM 还有几个有趣的方面我在这里没有机会涉及。让我简要提及其中两个。首先,作为一种实用工具,SAM 在小数据集上微调预训练模型时特别有用——Foret 等人(2019)针对 CNN 类型的架构详细探讨了这一点,随后的许多工作也针对更通用的架构进行了探讨。其次,既然我们以损失曲面中平坦极小值与泛化能力之间的联系开始了讨论,那么很自然地会问:一个经过 SAM 训练的模型(已被证明能提高泛化能力)是否真的收敛到了一个更平坦的极小值?这是一个不简单的问题,需要仔细分析训练后模型的 Hessian 谱,并与其经过 SGD 训练的对应模型进行比较。但那是另一个话题了!
感谢阅读!如果你喜欢这篇文章,并且有兴趣阅读更多关于深度学习的教学文章,请在 Medium 和 LinkedIn 上关注我。除非另有说明,本文中使用的所有图像和图表均由作者生成。
作者
分享本文
-
在 Facebook 上分享
-
在 LinkedIn 上分享
-
在 X 上分享
Towards Data Science 是一个社区出版物。提交您的见解以触达我们的全球受众,并通过 TDS 作者付款计划赚取收益。