ニューラルネットワークの学習には、複数の事例 (たとえば単語の系列) に対して並列に損失関数を計算し、得られた勾配に基づいてパラメータを更新するミニバッチ学習が用いられます。自然言語処理において、ミニバッチ学習時は単語の系列を同じ長さにそろえて処理します。これはニューラルネットワーク内での計算において、データが密行列として扱われることが多いためです。 この長さをそろえる処理はパディングといわれています。 当然ながら、ミニバッチ内で系列の長さが不ぞろいなほど、パディングによって追加される疑似的な単語が増えるため、本来不要な計算が増えます。また、ミニバッチを表す密行列が大きいほど、計算にかかる時間が大きくなります。 本記事ではPyTorchにおける実装において、系列の長さが近い事例でミニバッチを作成することで、不要なパディングをできるだけ減らし、ミニバッチを表す密行列の大きさを小さくする方法を紹介します。

目次

本記事で計算しているコードはPyTorch1.4.0を利用しています。

自然言語処理におけるミニバッチ作成時のパディング

二つの単語系列 [1, 2, 3, 4, 5][1, 2] をまとめて一つのミニバッチを作成することを考えます。 このとき、パディングは、ミニバッチ内の各単語系列に対して、最長の単語系列長になるように、単語系列の末尾に疑似的な単語 (ここでは0) を追加してミニバッチ内のすべての単語系列が同じ長さとなるようにします。 最長の単語系列長は 5 なので、 [1, 2] に対して、 0 を3つ末尾に追加します。 PyTorchで実装するとたとえば以下のように記述できます。

import torch

x = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([1, 2])]
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)

得られる結果は以下の通りです。

tensor([[1, 2, 3, 4, 5],
        [1, 2, 0, 0, 0]])

ミニバッチ学習

ニューラルネットワークを学習する際は、学習データから指定した数だけ事例を選択し、選択した事例集合に対してパディングを適用します。 ここで、どのような基準で事例を選択するかを考える必要がありますが、良く用いられるのは、無作為に事例を選択する方法です。 無作為に事例を選択するのは、ミニバッチがデータの順序などの偏りが無くなるようにすることで、ニューラルネットワークが偏った学習をしないように意図したものです。 この方法の問題のひとつは選択された事例の単語系列長のばらつきが大きいと、パディングによって追加される単語が多くなり、計算にかかる時間が多くなるということです。

まず、良く用いられる無作為に事例を選択することでミニバッチを作成する方法と、できるだけパディングによって追加する単語を減らす作成方法を紹介します。

以降では5つの事例に対してバッチサイズ2でミニバッチを作成することを考えます。

batch_size = 2

data = [torch.tensor([1, 2, 3, 4]),
        torch.tensor([1, 2, 3, 4, 5]),
        torch.tensor([1, 2]),
        torch.tensor([1, 2, 3, 4, 5, 6]),
        torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])]

またパディングを適用する関数を以下の様に定義しておきます。

def collate_fn(batch):
    x = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    return x

無作為に事例を選択してミニバッチを作成

無作為に事例を選択してミニバッチを作成する方法は以下の様に実装できます。

    data_loader = torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn)

    for epoch in range(2):
        print('Epoch:', epoch)

        for batch in data_loader:
            print(batch)
            print('---')

結果は以下の通りです。

Epoch: 0
tensor([[1, 2, 3, 4, 5, 6, 0, 0],
        [1, 2, 3, 4, 5, 6, 7, 8]])
---
tensor([[1, 2, 0, 0, 0],
	[1, 2, 3, 4, 5]])
---
tensor([[1, 2, 3, 4]])
---
Epoch: 1
tensor([[1, 2, 0, 0, 0],
	[1, 2, 3, 4, 5]])
---
tensor([[1, 2, 3, 4, 0, 0, 0, 0],
	[1, 2, 3, 4, 5, 6, 7, 8]])
---
tensor([[1, 2, 3, 4, 5, 6]])
---

系列の長さでソートしてミニバッチを作成

次に単語系列長が近い事例でミニバッチを作成する方法は次の様に実装できます。

def create_batch_sampler(data, batch_size):
    indices = torch.arange(len(data)).tolist()
    sorted_indices = sorted(indices, key=lambda idx: len(data[idx]))

    batch_indices = []
        start = 0
        end = min(start + batch_size, len(data))
        while True:
            batch_indices.append(sorted_indices[start: end])

            if end >= len(data):
                break

            start = end
            end = min(start + batch_size, len(data))

    return batch_indices


def length_sorted_data_loader(data, batch_size):
    batch_sampler = create_batch_sampler(data, batch_size)

    num_token = 0
    num_pad = 0
    for epoch in range(2):
        print('Epoch:', epoch)

        random.shuffle(batch_sampler)
        data_loader = torch.utils.data.DataLoader(
            data,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn)

        for batch in data_loader:
            print(batch)
            print('---')

無作為に事例を選択する方法とは異なり、まず事例を単語系列長でデータのインデックスをソートし、ソートされたインデックスに対して、先頭から順にミニバッチサイズだけ事例を選択し、 batch_samplercreate_batch_sampler から作成します。 batch_sampler は各エポックの最初に無作為に並び替えて torch.utils.data.DataLoader の引数に渡されます。

結果は以下の様になります。

Epoch: 0
tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
---
tensor([[1, 2, 3, 4, 5, 0],
        [1, 2, 3, 4, 5, 6]])
---
tensor([[1, 2, 0, 0],
	[1, 2, 3, 4]])
---
Epoch: 1
tensor([[1, 2, 3, 4, 5, 0],
	[1, 2, 3, 4, 5, 6]])
---
tensor([[1, 2, 0, 0],
	[1, 2, 3, 4]])
---
tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
---

無作為に事例を選択する方法と異なり、ミニバッチ内の事例は常に同じではあるものの、ミニバッチの順序が無作為になるようなミニバッチ作成ができます。 単語系列長が近い事例でミニバッチを作成するため、無作為に事例を選択する方法と比較して、0の数が少なくなる傾向にあることが分かります。

パディングの数を比較

本来ならGPU上でのニューラルネットワークの学習時間を計測すべきですが、簡単のため、パディングの数にどれくらい差が出るか比較してみます。

import random

import numpy
import torch


def collate_fn(batch):
    x = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    return x


def create_batch_sampler(data, batch_size):
    indices = torch.arange(len(data)).tolist()
    sorted_indices = sorted(indices, key=lambda idx: len(data[idx]))

    batch_indices = []
    start = 0
    end = min(start + batch_size, len(data))
    while True:

        batch_indices.append(sorted_indices[start: end])

        if end >= len(data):
            break

        start = end
        end = min(start + batch_size, len(data))

    return batch_indices


def length_sorted_data_loader(data, batch_size):
    batch_sampler = create_batch_sampler(data, batch_size)

    num_token = 0
    num_pad = 0
    for epoch in range(100):
        random.shuffle(batch_sampler)
        data_loader = torch.utils.data.DataLoader(
            data,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn)

        for batch in data_loader:
            num_token += (batch != 0).sum().item()
            num_pad += batch.eq(0).sum().item()

    print('num_token', num_token, 'num_pad', num_pad)


def random_sample_data_loader(data, batch_size):
    data_loader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn)

    num_token = 0
    num_pad = 0
    for epoch in range(100):
        for batch in data_loader:
            num_token += (batch != 0).sum().item()
            num_pad += batch.eq(0).sum().item()

    print('num_token', num_token, 'num_pad', num_pad)


if __name__ == '__main__':
    batch_size = 2

    data = [torch.tensor([1, 2, 3, 4]),
            torch.tensor([1, 2, 3, 4, 5]),
            torch.tensor([1, 2]),
            torch.tensor([1, 2, 3, 4, 5, 6]),
            torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])]

    print('Random sample')
    random_sample_data_loader(data, batch_size)
    print('Length sorted')
    length_sorted_data_loader(data, batch_size)

結果は以下の様になります。

Random sample
num_token 2500 num_pad 580
Length sorted
num_token 2500 num_pad 300

無作為に事例を選択する方法と比較して、単語系列長が近い事例でミニバッチを作成したほうが、パディングによって追加される単語数が少なくなっていることが分かります。 もちろんパディングの数は学習にかかるエポック数や単語系列長のばらつきによりますが、上記の設定だとパディングの数が半分に減っています。

まとめ

本記事ではニューラルネットワークにおける自然言語処理において、単語系列長が近い事例でミニバッチを作成することでパディングによって追加される単語数を減らす実装方法をPyTorchを使って説明しました。 単語系列長が近い事例でミニバッチを作成する方法はFacebookが公開しているニューラルネットワークのsequence-to-sequence実装 fairseq でも実装されており、より学習時間を短縮する方法の一つとして活用できます。


関連記事






最近の記事