【PyTorch】恐怖の「Shape mismatch (B, C, H, W) vs (B, H, W)」が教えるセグメンテーション実装の罠

プログラミング

U-Netなどの画像セグメンテーション(画像の中のピクセル単位での領域分割)モデルを自作し、初めて学習ループを回した際に、おそらく90%以上の確率で直面するエラーがこれです。

RuntimeError: Expected target size [Batch, Channel, Height, Width], got [Batch, Height, Width]

「えっ?入力の画像も、正解のマスク画像(白黒画像)も同じサイズにしたはずなのに、次元の数が合わないってどういうこと?」と、多くの方がテンソルの形状(Shape)の迷宮に迷い込み、小一時間時間を溶かすことになります。
本記事では、このエラーを通してPyTorchにおける「画像の多次元表現(テンソル構造)」と、「CrossEntropyLossの特殊仕様」について深く掘り下げて解説します!

1. そもそも画像テンソルの「チャンネル」とは?

PyTorchにおいて、画像データは基本的に (Batch, Channels, Height, Width)、通称「BCHW」の4次元テンソル(多次元の箱)として扱われます。
例えば、バッチサイズが8まい、フルHDのカラー画像の場合、そのテンソルの形状(Shape)は (8, 3, 1080, 1920) となりますね。(3はRGBの3チャンネルです)

さて、ここからが問題です。
U-Netで「背景(0)」「道路(1)」「車(2)」の3つのクラスを予測するセグメンテーションを行う場合、ネットワークの推論出力(AIが予測した結果)はどんな形状を出力するべきでしょうか?

正解は、(8, 3, 1080, 1920) です。
ただしこの「3」はRGBの3色ではありません。「各ピクセルが背景である確率」「道路である確率」「車である確率」という、クラス数(3クラス)分の厚みを持ったロジット(確率)が出力されるのがセグメンテーションモデルの正解なのです。これをチャンネル方向に重ね合わせたミルフィーユのような状態を想像してください。

2. 真の罠:教師アノテーション(正解マスク)の渡し方

エラーが発生する原因は、AIの出力側ではなく、あなたが用意した「正解のマスク画像(Ground Truth)」の渡し方にあります。
多くの人は、正解マスク画像(グレースケールのPNG画像など)を読み込んだ際、そのままモデルに投げてしまいます。
すると形状は (Batch, 1, Height, Width) になります(グレースケールなのでチャンネルが1つ)。

しかし、PyTorchの nn.CrossEntropyLoss は非常に賢く(かつ独特な)挙動をします。
多クラスの画像セグメンテーションにおいて、この損失関数は以下のような入力を強固に要求するのです。
・モデル出力: (B, C, H, W) の浮動小数点(Float)テンソル
正解データ: (B, H, W) の整数(Long型)テンソル!!

そう、正解データにはチャンネルの次元(C)があってはいけない仕様なのです!

# 【よくある間違った実装】
# 正解マスク画像をモノクロで読み込んでそのままTensor化した場合、
# mask_tensor.shape は (Batch, 1, Height, Width) と極薄のチャンネル次元を持ちます。
loss = criterion(outputs, mask_tensor) 
# => エラー! PyTorchは (B, H, W) を期待しているのに、余分な「1」次元が挟まっている!
# ================================
# 【正しい実装(エラー回避)】
# ================================
# 邪魔なチャンネル次元(C=1)を squeeze()関数で「押し潰して消去」し、
# 同時に Long型(整数)にキャストしてから渡すのが正解です!!
mask_tensor_correct = mask_tensor.squeeze(1).long()
loss = criterion(outputs, mask_tensor_correct)
# => 完璧!正常にロスが計算され、学習が進みます!!

すなわち、正解マスクのテンソルの中身は、「ピクセルのRGB値」ではなく、「その座標のピクセルが属するクラスのID番号(0, 1, 2…等)」がただ羅列された1枚の二次元配列でなければならないのです。

3. PyTorchの「スマートさ」の弊害とNumPy力

KerasやTensorFlowのように「出力も正解も全部ワンホットベクトル (B, H, W, C) で統一して記述する!」といったおせっかいな思想とは異なり、PyTorchは計算効率とメモリ効率を極限まで高めるため、極限まで無駄を省いた次元圧縮(ワンホットを展開せず整数のまま渡す)をユーザーに強要します。

このエラーは、「テンソルの形状をこまめに print(tensor.shape) して確認する」という、データサイエンティストにとって最も重要で泥臭いデバッグ作業の重要性を強烈に教えてくれます。

まとめ:エラーメッセージは次元の道標

「Shape mismatch(形状不一致)」のエラーは、深層学習プログラミングにおける日常茶飯事であり息をするようなものです。
view()reshape()squeeze()(次元を減らす)、unsqueeze()(次元を増やす)といった次元操作の魔法関数を完全にマスターしない限り、一生エラーに苦しめられることになります。

配列計算と多次元の概念をマスターするためには、PyTorchの根底にある基礎ライブラリ「NumPy」の配列操作から復習するのが一番の近道です。ぜひ専門書を使って、配列スライスの達人を目指してください!

コメント

タイトルとURLをコピーしました