[PyTorch] Datasetの読み込みにかかるメモリ消費量を節約する
ニューラルネットワークを用いた自然言語処理では、大量のラベルなしテキストを利用した事前学習によって、目的のタスクの予測モデルの精度を改善することが報告されています。 事前学習に用いるテキストの量が多いと、データを計算機上のメモリに一度に載りきらない場合があります。 この記事ではPyTorchでニューラルネットワークの学習を記述する際に、テキストをファイルに分割して、ファイル単位でテキストを読み込むことで、計算機上で利用するメモリの使用量を節約する方法を紹介します。
目次
前置き
本記事ではPyTorch1.4.0を利用しています。
データをファイル単位で読み込むために、ある程度の大きさでデータをファイルに分割し、同じディレクトリ以下に配置しているものとします。
data
├── 0.txt
└── 1.txt
$cat data/0.txt
This is a sentence A in 0.txt
This is a sentence B in 0.txt
This is a sentence C in 0.txt
$cat data/1.txt
This is a sentence A in 1.txt
This is a sentence B in 1.txt
This is a sentence C in 1.txt
This is a sentence D in 1.txt
簡単な例
このようなデータをファイル単位で読み込むために、 torch.utils.Dataset
を継承した以下のクラスを実装します。 torch.utils.Dataset
を継承するために、 __len__
と __getitem__
が実装されています。 以下のクラスでは、 __len__
は dirname
以下に配置されているファイルの数、 __getitem__
は指定されたインデックスに対応するファイルを読み込み、そこに含まれるデータを返します。
import os
import torch
class FileDataset(torch.utils.data.Dataset):
def __init__(self, dirname):
self.filenames = [os.path.join(dirname, n) for n in os.listdir(dirname)]
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
texts = []
with open(self.filenames[idx], 'r') as f:
for line in f:
texts.append(line.strip())
return texts
(実際にニューラルネットワークを学習する際には、 __getitem__
でテキストを単語やサブワードに分割し、それらを対応する整数に変換する処理も必要となりますが、ここでは簡単ため割愛します。)
実際にこのクラスを使ってデータを読み込む際は以下の様に利用します。
dataset = FileDataset('./data')
file_loader = torch.utils.data.DataLoader(dataset)
for epoch in range(2):
print('epoch', epoch)
for texts in file_loader:
batch_loader = torch.utils.data.DataLoader(texts, batch_size=2, shuffle=True)
for batch in batch_loader:
print(batch)
print('---')
結果は以下の様になります。
epoch 0
[('This is a sentence D in 1.txt', 'This is a sentence A in 1.txt')]
[('This is a sentence C in 1.txt', 'This is a sentence B in 1.txt')]
---
[('This is a sentence C in 0.txt', 'This is a sentence B in 0.txt')]
[('This is a sentence A in 0.txt',)]
---
epoch 1
[('This is a sentence B in 1.txt', 'This is a sentence D in 1.txt')]
[('This is a sentence A in 1.txt', 'This is a sentence C in 1.txt')]
---
[('This is a sentence A in 0.txt', 'This is a sentence C in 0.txt')]
[('This is a sentence B in 0.txt',)]
---
ミニバッチは同じファイル内のデータから作成されるので、その点は認識したうえで利用する必要があります。 ここまでで、ファイル単位でデータを読み込む方法を紹介しました。
ファイル単位でデータを読み込むことでメモリの消費量を節約することができました。ただし、この実装では毎回ファイルを開いてテキストを読み込むため、モデルの学習以外にかかる時間も大きくなってきます。 計算機のメモリ容量が十分にあれば、毎回ファイルからデータを読み込むようなことはしたくないし、そうでなければメモリ使用量を節約してデータを読み込みたい場合、次のような実装も考えられます。
テキストを読み込んだ結果をキャッシュする
この方法は、計算機ごとにメモリ容量が異なり、メモリ容量が大きな計算機では初回のみファイルから読み込んで、以降はキャッシュした結果を利用して、メモリ容量が小さな計算機では毎回ファイルから結果を読み込む、というように計算機の環境に応じてオプションで切り分けたい場合に利用する方法です。
class FileDatasetWithCache(torch.utils.data.Dataset):
def __init__(self, dirname, use_cache=True):
self.filenames = [os.path.join(dirname, n) for n in os.listdir(dirname)]
self.cache = dict((idx, None) for idx in range(len(self)))
self.use_cache = use_cache
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
texts = []
if self.use_cache and self.cache[idx] is not None:
print('return cached text')
return self.cache[idx]
else:
with open(self.filenames[idx], 'r') as f:
for line in f:
texts.append(line.strip())
if self.use_cache and self.cache[idx] is None:
print('cache text', idx)
self.cache[idx] = texts
return texts
FileDataset
を FileDatasetWithCache
に置き換えることで、初回のみファイルからデータを読み込み、以降はキャッシュした結果を返すだけとすることができます。
epoch 0
cache text 0
[('This is a sentence B in 1.txt', 'This is a sentence C in 1.txt')]
[('This is a sentence D in 1.txt', 'This is a sentence A in 1.txt')]
---
cache text 1
[('This is a sentence B in 0.txt', 'This is a sentence C in 0.txt')]
[('This is a sentence A in 0.txt',)]
---
epoch 1
return cached text
[('This is a sentence D in 1.txt', 'This is a sentence C in 1.txt')]
[('This is a sentence A in 1.txt', 'This is a sentence B in 1.txt')]
---
return cached text
[('This is a sentence C in 0.txt', 'This is a sentence A in 0.txt')]
[('This is a sentence B in 0.txt',)]
---
おわり
本記事では、PyTorchでニューラルネットワークを学習する際に、メモリ使用量を節約してデータを読み込む方法を紹介しました。 データの読み込みにかかる使用量を節約するために、データをファイル単位で分割し、ファイル単位でデータを読み込むような実装を紹介しました。