transformersのAutoModelで独自クラスを使う
本記事ではhuggingfaceのtransformersのAutoModelを使って独自クラスを利用する方法を紹介します。
transformersはAutoModelによって、事前学習済みモデルがどのモデルの実装なのかを意識せずに利用できます。
たとえばモデルのアーキテクチャや事前学習済みのパラメータを変えて実験をするプログラムははモデル名をmodel_name_or_path
とした場合 model = AutoModel.from_pretrained(model_name_or_path)
とだけ記述すれば事前学習済みパラメータで初期化されたモデルを読み込めます。
このAutoModel
で独自クラスを利用できればtransformersで実装されている他のモデルと同様に利用が容易になります。
目次
本記事で利用した環境
python = "3.10"
transformers = "4.27.4"
torch = "2.0.0"
プログラムの構成
├── poetry.lock
├── pyproject.toml
├── scripts
│ ├── load.py
│ └── save.py
└── transformer_ext
├── __init__.py
└── models
├── __init__.py
├── sample_config.py
└── sample_model.py
PretrainedConfigクラス
まず独自クラスで利用するハイパーパラメータを記載したPretrainedConfig
クラスを作成します。
ポイントはAutoConfig
に作成したクラスを登録することです。
ここではデコレータとして実装します。
本記事ではsample_model.py
として以下の内容を作成します。
from typing import Type
from transformers import AutoConfig, PretrainedConfig
def register_to_hf_auto_config(
config_class: Type[PretrainedConfig],
) -> Type[PretrainedConfig]:
AutoConfig.register(config_class.model_type, config_class)
return config_class
@register_to_hf_auto_config
class SampleConfig(PretrainedConfig):
model_type: str = "sample_model"
def __init__(self, vocab_size: int = 1000, **kwargs):
self.vocab_size: int = vocab_size
super().__init__(**kwargs)
今回はサンプルなので独自のハイパーパラメータはありません。必要に応じてここに利用するハイパーパラメータを記載します。
model_type
がすでにtransformersで利用されている文字列だとエラーが出ますので注意が必要です。例えばmodel_type="bert"
にすると以下のエラーになります。
ValueError: 'bert' is already used by a Transformers config, pick another name.
PreTrainedModelクラス
次に独自クラスを作成します。独自クラスはPreTrainedModel
を継承します。
サンプルなので中身は多層パーセプトロンです。
ここではAutoModel
に独自クラスを登録します。
本記事ではsample_model.py
に以下の内容を作成します。
import torch.nn as nn
from torch import FloatTensor, LongTensor
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from typing import Type, Optional
from .sample_config import SampleConfig
def register_to_hf_auto_model(
model_class: Type[PreTrainedModel],
) -> Type[PreTrainedModel]:
config_class: Type[PretrainedConfig] = model_class.config_class
AutoModel.register(config_class, model_class)
return model_class
@register_to_hf_auto_model
class SampleModel(PreTrainedModel):
config_class: PretrainedConfig = SampleConfig
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.fc1: nn.Module = nn.Linear(config.hidden_size, 2 * config.hidden_size)
self.fc2: nn.Module = nn.Linear(2 * config.hidden_size, config.hidden_size)
def forward(
self,
input_ids: Optional[LongTensor] = None,
attention_mask: Optional[LongTensor] = None,
position_ids: Optional[LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
):
embeddings: FloatTensor = self.embeddings(input_ids)
return self.fc2(self.fc1(embeddings))
AutoModelから独自クラスを利用する
ここまでで独自クラスを実装し、transformersのAutoModelに登録する準備ができました。
次からは作成した実装をAutoModelを介して利用する方法を紹介します。
まず、独自クラスを実装したライブラリをインストールします。
そのライブラリをtransformer_ext
とすると以下のコマンドを実行します。
cd transformer_ext
poetry install
モデルを保存できていることを確認
以下のプログラムを実行することで作成したモデルと保存したモデルの出力結果が一致することを確認できます。
import torch
from transformers import AutoModel, PreTrainedModel
from transformer_ext.models import SampleConfig, SampleModel
model_dir: str = "model"
config: SampleConfig = SampleConfig(
vocab_size=1000,
hidden_size=128,
)
print(config)
model: SampleModel = SampleModel(config)
# 実際はこの部分で学習
model.save_pretrained("model")
m: PreTrainedModel = AutoModel.from_pretrained("model")
input_ids: torch.LongTensor = torch.LongTensor([[0, 1, 2, 3]])
assert torch.allclose(model(input_ids), m(input_ids))
AutoModelから保存したモデルを読み込む
モデルを利用するときはAutoModel.from_pretrained
を使って簡単に読み込めます。
AutoModelに独自クラスを登録する必要があるためimport transformer_ext
を記載する必要がある点にご注意ください。
from transformers import AutoModel, PreTrainedModel
# 独自クラスをAutoModelに登録するためにimportが必要
import transformer_ext # type: ignore
m: PreTrainedModel = AutoModel.from_pretrained("model")
print(m)
おわり
本記事では独自クラスをAutoModelから利用するための手順を紹介しました。 学習したモデルとともにその実装をHubで公開すればAutoModelからの利用は比較的容易ですが必ずしも実装や学習済みモデルを公開可能とは限りません。
本記事で利用したプログラムをtransformer_extで公開しています。 記事の作成にあたりSharing custom modelsを参考にしました。