2021年6月3日

ResNetからSkip Connectionを取り除く研究についてのサーベイ

Deep LearningSkip Connection

LeapMindの徳永です。

ResNet、使ってますか?

ResNetからSkip Connectionを取り除きたくなること、長い人生で、たまにはありますよね。今日はそんな話をしたいと思います。

はじめに

Residual Networkは、画像処理においてよく用いられる、代表的なネットワークの一つです。

2015年に提案されたこのネットワークは、学習のしやすさにおいて、それまでのものとは一線を画していました。

さて、そんなResidual NetworkはSkip Connection、Batch NormalizationKaiming Initializationの3つの組み合わせでできた手法です。これらのどれかが欠けてしまうと、学習が遅くなったり、そもそも性能が出なくなったりしてしまいます。例えば、Skip Connectionがないと、層数が増えた場合に精度が向上しないどころか、逆に低下してしまいます。

このSkip Connectionですが、実用上は以下のような課題があります。

  • 推論時のメモリ消費量が増える
  • 推論時に計算量の割に実際の計算が重たくなりがち(特にDNN専用アクセラレーターにおいてその傾向がありがち)

このため「Skip Connectionなしで精度を出したい」という研究がいくつも試されてきました。今回は、それらの研究の中から、「Residual NetworkからSkip Connectionを取り除く」ことを主眼とした研究を中心に紹介します。DenseNetのように大きくネットワークの形が変わるものについては、今回は触れません。

DiracNets

Zagoruykoらは、Skip Connectionを使う代わりに、似たような挙動をする畳み込みパラメーターをデザインしました。

y=x+σ(Wx)\begin{equation} \boldsymbol{y} = \boldsymbol{x} + \sigma (\boldsymbol{W} \odot \boldsymbol{x}) \end{equation}

これに対して、DiracNetsでは、次式のように、畳み込みの結果に対してx\boldsymbol{x}を足し合わせた後で非線形変換を行います。

y=σ(x+Wx)\begin{equation} \boldsymbol{y} = \sigma\big(\boldsymbol{x} + \boldsymbol{W} \odot \boldsymbol{x}\big) \end{equation}

DNNにおける畳み込みは線形変換なので、これは結局、畳み込みパラメーターに対して同次元のidentity matrixを足し合わせることと等価であることがわかります。

y=σ(x+Wx)=σ((I+W)x)\begin{equation} \boldsymbol{y} = \sigma\big(\boldsymbol{x} + \boldsymbol{W} \odot \boldsymbol{x}\big) = \sigma\big((\boldsymbol{I} + \boldsymbol{W}) \odot \boldsymbol{x}\big) \end{equation}

つまり、畳み込みパラメーターに対してidentity matrixを足し合わせることで、Skip Connectionそのものではありませんが、似たような効果が期待できます。実際の仕組みとしては、identity matrixの部分には学習可能なscaling parameterが乗されており、もう少し複雑ではありますが、DiracNetsのお気持ちはこの式で表現されるものがほぼ全てです。

Plainな34層のネットワークでは18層のネットワークよりもむしろ精度が低下することが知られていますが、DiracNetsを用いることで、34層のネットワークの方が18層のネットワークよりも高精度になります。ただし、Skip Connectionを用いた場合と比べると、精度は幾分低下します。

Avoiding degradation in deep feed-forward networks by phasing out skip-connections

Montiらは、Skip Connectionありのネットワークから学習を開始し、最終的にネットワークからSkip Connectionを取り除く手法を提案しました。

具体的には、Skip Connectionの部分に(1-α\alpha)を係数として乗じました。- α\alphaが1であればPlainCNNに、α\alphaが0であればResNetになります。最終的にはα=1\alpha=1になっていてほしいので、目的関数にラグランジュ変数を導入し、L+λ(α1)L + \lambda (\alpha-1)を最適化します。

残念ながら、実験結果は他の手法と適切に比較できるようになっておらず、どれくらいの効果があるのかをフェアに比較できる状態にはなっていません。

Residual Distillation

Liらは、まずResNetを学習し、それを教師としてPlainCNNを学習する手法を提案しました。通常のDistillationでは最終層の類似度のみを用いますが、Residual Distillationでは、生徒ネットワークの中間層の出力を教師ネットワークの中間に入れてからbackpropして勾配を得ます。以下に特徴を箇条書きで挙げます。

  • LIT: Learned Intermediate Representation Training for Model Compressionと似ているが、こちらが中間層の類似度のノルムをloss functionに加えるのに対し、Residual Distillationでは、生徒ネットワークの中間データを別のネットワークに流し込むという荒業を用いています。手元での再現実験では、理由はよくわかりませんが、Residual Distillationの方が高い効果が得られました。ただ、学習はResidual Distillationの方がかなり遅くなります。
  • 生徒ネットワークの初期化にはDiracNetのやり方を使います。
  • CIFAR-10/100では50層にしても性能が下がらない、程度の効果しかないが、ImageNetでは34層より50層の方が3ptくらい精度が上がります。
  • さらに通常のKDと組み合わせるともっと精度が上がります。

注目すべきは、50層のPlainCNNで76.08%のtop-1 accuracyを達成していることです。ResNet-50は76.11%なので、ほぼ同精度を達成できていると言って良いでしょう。

また、この論文は、単にSkip Connectionを取り除ける、というだけではなく、取り除くことで、最大メモリ消費量がどれくらい減るか、mobile NPUでの実行速度がどれくらい向上するのかを定量的に評価しています。(メモリ消費量は20%弱削減され、実行速度は20%〜30%程度高速になります。)隙がないですね。

Residual Distillationは50層でもResNetとほぼ変わらない性能をPlainCNNで実現できており、かなり有望な手法です。今のところ、この手法が決定版と言ってよいでしょう。

RepVGG

Dingらは、DiracNetsの考え方を発展させ、3x3convと1x1convとDiracNetsを組み合わせたRepVGGというネットワークを提案しました。肝となるアイデアは、これらが最終的に単なる1つの3x3convに置き換え可能であることです。

最終的なRepVGGネットワークは、GPU上では実行速度を基準にして同程度のモデルであれば、ResNetやEfficientNetよりも精度が高いネットワークとなっています。

手法としてはDiracNetsの発展形となっており、これでResNetからSkip Connectionをなくしたらどうなるのか、という点に興味がわきますが、残念ながら論文中にはそのような実験結果は載っていませんでした。コードは公開されているので、いつか自分で実験したみたいところです。

まとめ

Skip Connectionを使わずに深いネットワークで精度を出すための研究をいくつか紹介しました。これらの手法は技術的に面白いだけではなく、ネットワーク構造がシンプルになり、ハードウェアに優しい、という利点があります。数年後には一般的に使われる技術になっているかもしれません。かもしれませんと書きましたが、個人的には、使われるようになっている可能性がかなり高いと予想しています。

ところで、LeapMindでは、Skip Connectionを使わないで高い精度を出すニューラルネットワークに興味がある方もそうでない方も、様々な職種で現在採用活動を実施しております。特に機械学習、深層学習が得意なエンジニアを強く募集しておりますので、興味がある方はぜひ一度LeapMindの採用ページから応募してみてください。いきなり採用面接は怖い、という人はカジュアル面談から始めることもできます(もちろん、カジュアル面談の後に、実際に選考プロセスを進めない、という選択も可能です。そのためのカジュアル面談ですので。)