批归一化到底做了什么?DeepMind研究者进行了拆解
作者:Soham De、Samuel L. Smith
机器之心编译
参与:魔王
批归一化有很多作用,其最重要的一项功能是大幅提升残差网络的最大可训练深度。DeepMind 这项研究探寻了其中的原因,并进行了大量验证。
论文链接:https://arxiv./abs/2002.10444
批归一化用处很多。它可以改善损失分布(loss landscape),同时还是效果惊人的正则化项。但是,它最重要的一项功能出现在残差网络中大幅提升网络的最大可训练深度。
DeepMind 近期一项研究找到了这项功能的原因:在初始化阶段,批归一化使用与网络深度的平方根成比例的归一化因子来缩小与跳跃连接相关的残差分支的大小。这可以确保在训练初期,深度归一化残差网络计算的函数由具备表现良好的梯度的浅路径(shallo path)主导。
该研究基于此想法开发了一种简单的初始化机制,可以在不使用归一化的情况下训练非常深的残差网络。研究者还发现,尽管批归一化可以维持模型以较大的学习率进行稳定训练,但这只在批大小较大的并行化训练中才有用。这一结果有助于厘清批归一化在不同架构中的不同功能。
批归一化到底干了什么
跳跃连接和批归一化结合起来可以大幅提升的最大可训练深度。
DeepMind 研究者将残差网络看作多个路径的集成,这些路径共享权重,但是深度各有不同(与 Veit 等人 2016 年的研究类似),进而发现批归一化如何确保非常深的残差网络(数万层)在训练初期被仅包含几十个层的浅路径主导。原因在于,批归一化使用与网络深度的平方根成比例的因子缩小与跳跃连接相关的残差分支的大校这就为深度归一化残差网络在训练初期可得到高效优化提供了直观解释,它们只是把具备表现良好的梯度的浅层网络集成起来罢了。
上述观察表明,要想在不使用归一化或不进行认真初始化的前提下训练深度残差网络,只需要缩小残差分支即可。
为了确认这一点,研究者改动了一行代码,实现不使用归一化的深度残差网络训练(SkipInit)。结合额外的正则化后,SkipInit 网络的性能可与经过批归一化的对应网络不相上下(该网络使用常规的批大小设置)。
为什么深度归一化残差网络是可训练的?
残差分支经过归一化后,假设 f_i 的输出方差为 1。每个残差块的方差增加 1,则第 i 个残差块前的激活的预期方差为 i。因此,对于任意遍历第 i 个残差分支的路径,其方差缩小到 1/i,这说明隐藏层激活缩小到 1/√ i。
如图 3 所示,该缩小因子很强大,可确保具备 10000 个残差块的网络 97% 的方差来自遍历 15 个或者更少残差分支的浅路径。典型残差块的深度与残差块总数 d 成比例,这表明批归一化将残差分支缩小到 1/√ d。
图 3:此图模拟了初始化阶段不同深度的路径对 logits 方差的贡献。
为了验证这一观点,研究者评估两个归一化残差网络的不同通道的方差,以及批统计量(batch statistics),如下图 4 所示。
图 4(a) 中,深度线性 ResNet 的跳跃路径方差几乎等于当前深度 i,而每个残差分支末端的方差约为 1。这是因为批归一化移动方差约等于深度,从而证实归一化将残差分支缩小到原本的 1/√ i。
图 4(b) 中,研究者在 CIFAR-10 数据集上评估使用 ReLU 激活函数的卷积 ResNet。跳跃路径的方差仍与深度成正比,但系数略低于 1。这些关联也导致批归一化移动平均数的平方随着深度的增加而增大。
图4。
这就为「深度归一化残差网络是可训练的」提供了简洁的解释。这一观点可以扩展至其他归一化方法和模型架构。
SkipInit:归一化的替代方案
研究者发现,归一化之所以能够确保深度残差网络的训练,是因为它在初始化阶段按与网络深度平方根成正比的归一化因子缩小残差分支。
为了验证该观点,研究者提出了一个简单的替代方法SkipInit:在每个残差分支末端放置一个标量乘数,并将每个乘数初始化为 α。
图 1:A) 使用批归一化的残差块。B) SkipInit 用一个可学习标量 α 替代了批归一化。
移除归一化之后,只需改动一行代码即可实现 SkipInit。研究者证明,按 (1/ √ d) 或更小的值初始化 α 就可以训练深度残差网络(d 表示残差块数量)。
研究者引入了 Fixup,它也可以确保残差块在初始化时表示 identity。但是,Fixup 包含多个额外组件。在实践中,研究者发现 Fixup 的组件 1 或组件 2 就足以在不使用归一化的前提下训练深度 ResNet-V2 了。
实证研究
下表 1 展示了 n-2 Wide-ResNet 在 CIFAR-10 数据集上训练 200 epoch 后的平均性能,模型深度 n 在 16 到 1000 层之间。
表 1:批归一化使得我们可以训练深度残差网络。然而在残差分支末端添加标量乘数 α 后,不使用归一化也能实现同样的效果。
下表 2 验证了,当 α = 1 时使用 SkipInit 无法训练深度残差网络,因此必须缩小残差分支。研究者还确认了,对于未经归一化的残差网络,只确保激活函数不在前向传播上爆炸还不够(只需在每次残差分支和跳过路径合并时将激活乘以 (1/ √ 2) 即可实现)。
表 2:如果 α = 1,我们无法训练深度残差网络。
批归一化的主要功能是改善损失分布,增加最大稳定学习率。下图 5 提供了 16-4 Wide-ResNet 在 CIFAR-10 数据集上训练 200 epoch 后的平均性能,批大小的范围很大。
图 5:使用批归一化要比不使用获得的测试准确率更高,研究者还能够以非常大的批大小执行高效训练。
为了更好地理解批归一化网络能够以更大批大小进行高效训练的原因,研究者在下图 6 中展示了最优学习率,它可以最大化测试准确率、最小化训练损失。
图 6:使用和不使用批归一化情况下的最优学习率。
研究者在 ImageNet 数据集上对 SkipInit、Fixup 初始化和批归一化进行了实验对比,证明 SkipInit 可扩展至大型高难度数据分布。
下表 3 展示了最优验证准确率。研究者发现卷积层包含偏置可使 SkipInit 的验证准确率出现小幅提升,因此研究者在所有 SkipInit 运行中添加了偏置。SkipInit 的验证性能与批归一化相当,与使用标准批大小 256 的 Fixup 相当。但是,当批大小非常大时,SkipInit 和 Fixup 的性能不如批归一化。
表 3:研究者训练了 90 个 epoch,并执行网格搜索,以找出最优学习率,从而最大化模型在 ImageNet 数据集上的 -1 验证准确率。
本文为机器之心编译,转载请联系本公众号获得授权。
------------------------------------------------