【PyTorch】Version1.5でTPUを利用する方法
PyTorchのVersion1.5.0がリリースされました。 いくつかの変更がされていますが、その中の一つが、PyTorchでXLAの利用が可能となったというものです。 XLAを利用できると、PyTorch実装をTPU上で実行できるようになります。 本記事ではPyTorch1.5.0を使ってGoogle ColabのTPUを利用できるようになるところまでの流れを説明します。
目次
XLA
XLAはGoogleが開発している線形代数に最適化されたコンパイラです。XLAでコンパイルされたプログラムはCloud TPUで実行できるようになります。 主にTensorFlowで利用されていましたが、TPUを利用するためのプログラムに利用することもできます。 PyTorchで書かれたプログラムをTPUで実行するための開発も進められていて、Version1.5.0になって利用可能となりました。
PyTorchではCPUやCUDA上のテンソルと異なり、XLA上のテンソルにおける演算は遅延評価となります。 遅延評価では実際に計算結果が必要となるまで計算をしません。これによってXLAでは計算グラフを最適化します。最適化によって、複数の演算が一つの演算にまとめられることもあります。 つまり、ユーザが実装したforward処理に無駄な演算があった場合、XLAで処理することで高速化されることがあるということです。
Colab上でデバイスをTPUに設定
PyTorchのインストール
以下のようにしてColab上にPyTorch/XLAをインストールします。
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 3727 100 3727 0 0 41411 0 --:--:-- --:--:-- --:--:-- 41411
!python pytorch-xla-env-setup.py --tpu-ip 10.240.0.0
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200325 …
…
インストール後、以下のように実行されます。
import torch_xla
import torch
print(torch.__version__)
1.5.0a0+d6149a7
PyTorchでXLAを利用する
ここを参考に利用します。
XLAテンソルの作成
XLA上にテンソルを作成するにはdevice=xm.xla_device()
としてテンソル作成で使われる関数に与えるだけで良いため、ユーザはそれ以外のプログラムはCPUやCUDAと同じように使えます。
import torch
import torch_xla
import torch_xla.core.xla_model as xm
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
xla:1
tensor([[-0.3916, 0.4222],
[ 1.0496, -0.4849]], device='xla:1')
XLAテンソルによる演算
XLA上のテンソルに対する演算もCPUやCUDAと同様に扱えます。
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
tensor([[ 2.5711, -2.3486],
[-2.4695, 1.3503]], device='xla:1')
torch.nn.Module
に対する演算もCPUやCUDAと同様に扱えます。
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
tensor([ 0.4504, 1.1653, -1.2658, -0.4964, -0.6305, -0.0203, -0.6684, 0.1233,
0.0032, 0.5822, -0.7714, -0.0253, 0.0629, -0.3567, 0.0965, -1.6003,
-0.3350, -0.8112, -0.3519, -0.6657], device='xla:1',
grad_fn=<AddBackward0>)
XLAデバイスでモデルの処理を実行する
MNISTを使って手書き文字認識モデルを学習します。
モデルは以下のように実装します。
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
data loaderは以下のように実装します。
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=10, shuffle=True)
ここまでで、XLA依存な実装はありません。 次に実際の学習部分を実装します。
device = xm.xla_device()
model = Net().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.95)
for epoch in range(3):
loss_train = 0
start_at = time.time()
for data, target in train_loader:
optimizer.zero_grad()
data = data.to(device)
target = target.to(device)
output = model(data)
loss = loss_fn(output, target)
loss.backward()
loss_train += loss.item()
xm.optimizer_step(optimizer, barrier=True)
elapsed = time.time() - start_at
print(f"epoch:{epoch} elaplsed:{elapsed:.3f} loss:{loss:.3f}")
ここで初めてXLA依存な部分が出てきます。デバイスをXLAとして指定する部分と、optimizerによるパラメータのアップデート部分です。
XLAは遅延評価と説明しましたが、xm.optimizer_step(optimizer, barrier=True)
の箇所で実際に値が計算されます。
epoch:0 elaplsed:93.684 loss:0.053
epoch:1 elaplsed:89.488 loss:0.052
epoch:2 elaplsed:90.413 loss:0.003
複数のXLAデバイスを使って処理を並列化していないためか、学習は速くないのですが、ひとまずXLA上でPyTorch実装を動作させることができました。
おわり
本記事ではPyTorch1.5.0で利用できるようになったXLAを使って、Google Colab上のTPUで手書き文字認識の分類モデルを学習させるところまで紹介しました。 TPUの恩恵に預かるには、データの並列化などの実装も必要そうです。