【PyTorch Lightning】「KeyError: ‘val_loss’」から学ぶコールバックとフレームワークの黒魔術

PyTorchの面倒な学習ループ(for-loop地獄)を排除し、コードを劇的に綺麗にしてくれる神フレームワーク「PyTorch Lightning」。
「これで学習の進捗保存(ModelCheckpoint)も Early Stopping(過学習の早期停止)も自動でやってくれるぞ!」と意気込みコンポーネントをセットアップし、trainer.fit(model) を走らせた瞬間。

第1エポックの終わりに差し掛かったところで、突然こんなエラーが降ってきます。

pytorch_lightning.utilities.exceptions.MisconfigurationException: Early stopping conditioned on metric `val_loss` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: ``

(※シンプルな環境では KeyError: 'val_loss' と出ます。)

「えっ!? Validation用の val_loss(検証データに対する誤差)を監視させておいたのに、なぜ『その値は見つかりません』なんて言われるの?」
実はこれ、Lightningというフレームワーク特有の「内部通信(ロギング)」の仕組みである「黒魔術(ブラックボックス)」を正しく理解していないために起こる、設計ミスなのです。

1. そもそも誰が「val_loss」という名前を決めたのか?

PyTorch Lightningでは、ModelCheckpoint(最高精度のモデルを保存する機能)や EarlyStopping(精度が上がらなくなったら学習を途中で止める機能)といった「コールバック機能」を使います。
これらの機能は、設定画面で monitor='val_loss' などと文字通り「監視対象の名前」だけを指定します。

from pytorch_lightning.callbacks import EarlyStopping
# コールバックの定義:「val_loss が下がらなくなったら3エポックで止めてね」
early_stop = EarlyStopping(
    monitor="val_loss",  # ここで監視対象の名前を指定している
    patience=3, 
    mode="min"
)

ここで重要な質問です。
Lightning側は、どのようにしてあなたのモデルが計算した「誤差(Loss)」のうち、どれが「val_loss」だと知ることができるのでしょうか?

AIは勝手に「あ、これが検証ロスだな」と忖度してくれません。
エラーの原因は、あなたが「検証ステップ(validation_step)」の中で、自分から「この値が val_loss という名前ですよ!」とLightningのシステムに対して放送(ログ記録)し忘れていたからなのです。

2. 真犯人:self.log() の記述漏れ

Lightningモジュール内で定義する validation_step の中で、ただLossを計算して返り値として return しているだけでは、それはEarly Stoppingの監視対象にはなりません。
必ず self.log("キーの名前", 記録したい値) という黒魔術メソッドを唱えて、システム全体に周知(ブロードキャスト)してやらなければならないのです。

import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    # 【悪い例(エラーが起きる)】
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        
        # ただ loss を return するだけでは、Lightning側はそれが 何なのか 分からない!
        return loss 
    # ==========================
    # 【正しい例(防弾仕様)】
    # ==========================
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        
        # 【最重要】システムに対して「これが 'val_loss' だよ!」とログ記録してあげる
        # 内部で各バッチの平均値なども自動計算してくれます
        self.log("val_loss", loss, prog_bar=True)
        
        return loss

3. もう一つの罠:スペルミスと大文字小文字

「私はちゃんと self.log('Val_Loss', loss) って書いたのにエラーが出たぞ!」
はい、もうお気づきですね。
monitor="val_loss" と指定した側と、self.log で保存した側の「文字列(文字列キー)」が1文字でも違っていたり、大文字・小文字が異なっていれば、それは完全な別物として扱われ、エラーになってしまいます。
これはPythonの辞書(Dict)のキーと同じシステムで通信しているためです。
このような人的ミスを防ぐため、実務では文字列をベタ打ちするのではなく、定数(Constant)として定義して使い回すのがクリーンなコードの鉄則になります。

まとめ:フレームワークは「規約の塊」である

PyTorch LightningやKerasのような「ラッパー層(便利にしてくれるカバー)」の深いフレームワークを使うと、数行で膨大な機能を実現できる一方で、「裏でどんなシステムが、どういう文字列キーを使って連携しているのか」という「設計の規約」を無視した途端に意味不明なエラーに殺されます。

エラーが出たときは怒るのではなく、「なるほど、Lightningは self.log というメソッドを通じて、プログレスバーの描画からコールバックの監視までを裏でイベント駆動(Pub/Subのような仕組み)で回しているんだな」というシステムアーキテクチャの視点を持ってみましょう。実践的な運用書に目を通すことで、こうしたフレームワークの奥深い真の実力を120%引き出せるプロになれます。

コメント

タイトルとURLをコピーしました