BERTの学習で用いるoptimizerでbiasやlayer normalizationのパラメータだけがweight decayの対象外となっていることについて疑問は持ったことはあるでしょうか。たとえばhuggingfaceのtransformersのissueでもそのような質問がありますが、「Googleの公開しているBERTがそうしているから再現性のために合わせた」と回答されています。ではなぜGoogleのBERT実装はそのような設定にしたのでしょうか。これらのOSSを利用されている方にも天下り的に設定している方もいらっしゃると思います。本記事ではBERTなどの学習で用いられるoptimizerのweight decayで、biasやlayer normalizationのパラメータが対象外となっている理由について解説します。

目次

BERTの学習で用いられるoptimizer

GoogleのTensorFlow実装で利用されているoptimizerは下記のように記述されています。 引数exclude_from_weight_decayで指定されたlayer_normbiasに関するパラメータがweight decayの対象から除外されています。

  optimizer = AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

どういった理由でこれらのパラメータをweight decayの対象から除外しているのかを以降で解説します。

weight decay

weight decayはL2正則化のことで、モデルの過学習を抑えるために用いられます。モデルのパラメータ$\theta$に対して、ある損失関数$L(\theta)$に対してweight decayを足し合わせた関数を$E(\theta)$とします。

$$ E(\theta) = L(\theta) + \frac{C}{2} ||\theta||^2 \tag{1} $$

ハイパーパラメータ$C$はweight decayに対する重みです。学習時は$E(\theta)$を最小化するパラメータを探索します。 まず誤差逆伝播を用いてパラメータ$\theta$の更新に利用する勾配を以下のように求めます。

$$ \begin{eqnarray} \frac{\nabla E(\theta)}{\nabla \theta} &=& \frac{\nabla L(\theta)}{\nabla \theta} + C \theta \tag{2} \end{eqnarray} $$

次に得られた勾配と学習率$\lambda$に基づいて$\theta$を更新します。

$$ \begin{eqnarray} \theta &\leftarrow& \theta - \lambda \frac{\nabla E(\theta)}{\nabla \theta} \\\
&=& \theta - \lambda \frac{\nabla L(\theta)}{\nabla \theta} - \lambda C \theta \\\
&=& - \lambda \frac{\nabla L(\theta)}{\nabla \theta} + \theta (1 - \lambda C) \tag{3} \end{eqnarray} $$

次に実装について見ていきいます。 PyTorchのAdamWの実装はこちら です。weight decayに関する部分を抜粋します。

    for i, param in enumerate(params):
        # 省略

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # 省略。以降で勾配の項を使ってパラメータを更新する

これを数式で表すと以下のようになり式 (3)の第二項と一致します。ここでparamは$\theta$、lrは$\lambda$、weight_decayは$C$と対応します。

$$ \begin{eqnarray} \mathbf{\theta} &=& \mathbf{\theta} (1 - \lambda C) \end{eqnarray} $$

この式から、weight decayを適用することにより、パラメータが0に近づくことがわかります。またパラメータが0に近づくことで入力の一部の特徴量に引っ張られた予測を抑制したり、あまり重要でない特徴量を無視できるようになることで過学習を抑制することがわかります。

weight decayの対象外となるパラメータ

これまでにweight decayの仕組みについて解説しました。ここからは、なぜBERTなどのモデルの学習ではbiasやlayer normalizationのパラメータがweight decayの対象外となるかを見ていきます。

bias

biasは線形変換などで出てくるパラメータです。入力$\mathbf{x} \in \mathbb{R}^D$に対して、パラメータ$\mathbf{W} \in \mathbb{R}^{H \times D}$と$\mathbf{b} \in \mathbb{R}^H$を持つ線形変換は以下の計算をして出力$\mathbf{y} \in \mathbb{R}^H$を得ます。

$$ \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} \tag{4} $$

$\mathbf{W}$は入力$\mathbf{x}$との積をとって計算するものです。例えば$i$番目の次元の出力は$y_i = \sum_{j=1}^H W_{ij} x_j$のように掛け算の総和によって得られます。この式から$W_{ij} (1 \leq j \leq H)$の中に大きな値のものがあると、出力に大きく影響することがわかります。このことから$\mathbf{W}$の値の大きさそのものよりも他の次元との大小関係が大事なので、入力の特定の値に対して計算結果が大きく変動しないよう、値そのものは小さいほうが高い汎化性能が期待されます。

$\mathbf{b}$は入力$\mathbf{x}$に依存せず、値の大きさそのものを調整する役割を担います。たとえば$\mathbf{y}$が大きな値である必要があることが期待されるとき、$\mathbf{b}$も大きな値になる必要があります。ReLUのように値が0以上かそうでないかによって出力が変わる際には値の大きさが重要になります。weight decayのたびに小さな値になるとbiasの役割を果たせません。

layer normalization

layer normalizationはデータの分布を正規化することで学習時間を減らしつつ精度改善も期待できる手法です。直前の層の出力$\mathbf{y} \in \mathbb{R}^H$に対する平均値$\mu$と標準偏差$\sigma$で正規化した結果$\mathbf{h} \in \mathbb{R}^H$を出力します。$\mathbf{g} \in \mathbb{R}^H$と$\mathbf{b} \in \mathbb{R}^H$はlayer normalizationにおける学習対象のパラメータであり、値をスケールする際に用います。

$$ \begin{eqnarray} \mu &=& \frac{1}{H} \sum_{i=1}^{H} y_i \tag{5} \\\
\sigma &=& \sqrt{\frac{1}{H} \sum_{i=1}^{H} (y_i - \mu)^2} \tag{6} \\\
\mathbf{h} &=& f(\frac{\mathbf{g}}{\sigma} \odot (\mathbf{y} - \mu) + \mathbf{b}) \tag{7} \end{eqnarray} $$

layer normalizationにおける$\mathbf{g}$および$\mathbf{b}$も線形変換でのbiasと同様に値の大きさそのものを調整するために用いています。 そのためweight decayの対象から除外します。 layer normalizationの出力が活性化関数fへの入力となります。

おわりに

本記事ではbiasやlayer normalizationのパラメータがweight decayの対象から外す理由を解説しました。 これらのパラメータは値の大きさを調整するために用いられるものなので、正則化によって0に近づけると本来の役割を実現できなくなってしまいます。 最後に本記事を執筆するにあたり参考にした記事を参照します。

Is weight decay applied to the bias term?


関連記事






最近の記事