【PyTorch】限られたメモリにおける大きなバッチサイズでの学習
ニューラルネットワークの学習ではミニバッチ学習という複数の学習事例に対して得られる損失の総和を最小化するようにパラメータを更新します。 バッチサイズは計算機のメモリ容量に応じて人が決める値ですが、 BERTはバッチサイズを大きくしたほうが学習が安定しやすいという報告があります。 しかし、デバイスのメモリに載りきらないサイズでは学習中にメモリーエラーを起こしてしまいます。 本記事ではPyTorchコードを使って、メモリ容量が限られた環境でも大きなバッチサイズでミニバッチ学習する方法を紹介します。
目次
ミニバッチ単位でパラメータ更新
よくある実装は、1つのミニバッチに対して損失を計算し、損失に基づく勾配を計算し、パラメータを更新するというものです。
for input, target in dataset:
pred = net(input)
loss = crit(pred, target)
opt.zero_grad()
loss.backward()
opt.step()
net
はニューラルネットワーク、crit
は損失を計算する関数とします。
loss.backward()
によって、その損失を計算するために利用したパラメータに対する勾配を計算し、opt.step()
によって勾配に基づいてパラメータを更新します。
この実装でバッチサイズを大きくすると、バッチサイズに依存して消費するメモリが大きくなります。
損失を計算するためにニューラルネットワークの中間層の計算結果をすべて保持するためです。
複数のミニバッチ単位でパラメータ更新
次に、複数のミニバッチに対して勾配を計算したあとでパラメータを更新する実装を紹介します。
ここではミニバッチを10
処理したあとにパラメータを更新します。
この方法にすると、バッチサイズ * 10
処理したあとにパラメータを更新することになります。
そのため上記の方法と比較して10倍のバッチサイズでの処理に相当します。
ただし、for
文で各バッチの損失を計算し、loss.backward()
で勾配を計算しているため、並列化による処理の高速化は期待できない点に注意が必要です。
optimizerによるパラメータ更新処理が上記実装方法に比べて少なくなるので、その分の処理の高速化はできると思いますが、微々たるものだと思います。
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
pred = net(input)
loss = crit(pred, target)
loss.backward()
if (i+1)%10 == 0:
opt.step()
opt.zero_grad()
loss.backward()
を実行すると、損失を計算するために利用されたパラメータ (param
) に対する勾配を計算します。
その後、計算した勾配は、以下のようにparam.grad
に加算されます。
param.grad += grad
opt.step()
で以下のように、複数のバッチにおける勾配を加算した結果に基づいてパラメータが更新されます。
param.data += lr * param.grad
ここでlr
はoptimizerの学習率とします。
パラメータの更新が終わったら、opt.zero_grad()
によって、param.grad
が0になります。
まとめ
本記事ではPyTorchを使う際の、限られたメモリ環境における大きなバッチサイズでのニューラルネットワークの学習について紹介しました。
メモリ容量が限られた環境においては、複数のバッチにおける勾配をfor
文で計算した後にoptimizerでパラメータを更新することを紹介しました。
本記事を書くにあたり参考にした記事を以下に記載します。
Why do we need to set the gradients manually to zero in pytorch?