PyTorchのVersion1.5.0がリリースされました。 いくつかの変更がされていますが、その中の一つが、PyTorchでXLAの利用が可能となったというものです。 XLAを利用できると、PyTorch実装をTPU上で実行できるようになります。 本記事ではPyTorch1.5.0を使ってGoogle ColabのTPUを利用できるようになるところまでの流れを説明します。

目次

XLA

XLAはGoogleが開発している線形代数に最適化されたコンパイラです。XLAでコンパイルされたプログラムはCloud TPUで実行できるようになります。 主にTensorFlowで利用されていましたが、TPUを利用するためのプログラムに利用することもできます。 PyTorchで書かれたプログラムをTPUで実行するための開発も進められていて、Version1.5.0になって利用可能となりました。

PyTorchではCPUやCUDA上のテンソルと異なり、XLA上のテンソルにおける演算は遅延評価となります。 遅延評価では実際に計算結果が必要となるまで計算をしません。これによってXLAでは計算グラフを最適化します。最適化によって、複数の演算が一つの演算にまとめられることもあります。 つまり、ユーザが実装したforward処理に無駄な演算があった場合、XLAで処理することで高速化されることがあるということです。

Colab上でデバイスをTPUに設定

デバイスをTPUに設定 デバイスをTPUに設定

PyTorchのインストール

以下のようにしてColab上にPyTorch/XLAをインストールします。

!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                        Dload  Upload   Total   Spent    Left  Speed
100  3727  100  3727    0     0  41411      0 --:--:-- --:--:-- --:--:-- 41411
!python pytorch-xla-env-setup.py --tpu-ip 10.240.0.0 
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200325 …
…

インストール後、以下のように実行されます。

import torch_xla
import torch
print(torch.__version__)
1.5.0a0+d6149a7

PyTorchでXLAを利用する

ここを参考に利用します。

XLAテンソルの作成

XLA上にテンソルを作成するにはdevice=xm.xla_device()としてテンソル作成で使われる関数に与えるだけで良いため、ユーザはそれ以外のプログラムはCPUやCUDAと同じように使えます。

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
xla:1
tensor([[-0.3916,  0.4222],
        [ 1.0496, -0.4849]], device='xla:1')

XLAテンソルによる演算

XLA上のテンソルに対する演算もCPUやCUDAと同様に扱えます。

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
tensor([[ 2.5711, -2.3486],
        [-2.4695,  1.3503]], device='xla:1')

torch.nn.Moduleに対する演算もCPUやCUDAと同様に扱えます。

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
tensor([ 0.4504,  1.1653, -1.2658, -0.4964, -0.6305, -0.0203, -0.6684,  0.1233,
         0.0032,  0.5822, -0.7714, -0.0253,  0.0629, -0.3567,  0.0965, -1.6003,
         -0.3350, -0.8112, -0.3519, -0.6657], device='xla:1',
         grad_fn=<AddBackward0>)

XLAデバイスでモデルの処理を実行する

MNISTを使って手書き文字認識モデルを学習します。

モデルは以下のように実装します。

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

data loaderは以下のように実装します。

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=10, shuffle=True)

ここまでで、XLA依存な実装はありません。 次に実際の学習部分を実装します。

device = xm.xla_device()
model = Net().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.95)

for epoch in range(3):
    loss_train = 0
    start_at = time.time()
    for data, target in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        loss_train += loss.item()
      
        xm.optimizer_step(optimizer, barrier=True)
    elapsed = time.time() - start_at
    print(f"epoch:{epoch} elaplsed:{elapsed:.3f} loss:{loss:.3f}")

ここで初めてXLA依存な部分が出てきます。デバイスをXLAとして指定する部分と、optimizerによるパラメータのアップデート部分です。 XLAは遅延評価と説明しましたが、xm.optimizer_step(optimizer, barrier=True)の箇所で実際に値が計算されます。

epoch:0 elaplsed:93.684 loss:0.053
epoch:1 elaplsed:89.488 loss:0.052
epoch:2 elaplsed:90.413 loss:0.003

複数のXLAデバイスを使って処理を並列化していないためか、学習は速くないのですが、ひとまずXLA上でPyTorch実装を動作させることができました。

おわり

本記事ではPyTorch1.5.0で利用できるようになったXLAを使って、Google Colab上のTPUで手書き文字認識の分類モデルを学習させるところまで紹介しました。 TPUの恩恵に預かるには、データの並列化などの実装も必要そうです。


関連記事






最近の記事