画像分類のためのオリジナルCNN(畳み込みニューラルネットワーク)をご自身でコーディングしたとき。Conv2d(畳み込み層)とMaxPoolingで順調に画像を小さくしていき、いざ「最後の分類器(全結合層:Linear層)」へとデータを流し込んだ瞬間に、おそらく世界で一番遭遇率の高いであろうこのエラーが炸裂します。
RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x2048 and 1024x10)
「mat1 と mat2 の掛け算ができないって? 行列(マトリックス)のサイズが違います、だと?
入力画像は 224×224 だし、最後の出力は 10クラス にしたいだけなんだけど、2048とか1024って一体どこから湧いてきた数字なんだ!?」
本記事では、初心者が必ず直面する「CNNの立体的なデータ(3次元テンソル)」から「全結合層の平べったいデータ(1次元テンソル)」への橋渡し(Flatten)における、次元計算のカラクリを徹底解剖します!
1. 何と何の「掛け算」に失敗したのか?
エラーメッセージにある mat1 と mat2 とは、全結合層(nn.Linear)の内部で行われている「入力データ」と「重み行列」のことです。
エラー文の 64x2048 という数字は、「バッチサイズが64で、1枚の画像のデータ容量(要素数)が2048個になっている入力データ(mat1)」を意味しています。
そして 1024x10 という数字は、「あなたが定義した nn.Linear(1024, 10) という重み行列(mat2)」を意味しています。
高校数学で行列の掛け算を習った方ならご存知の通り、行列の掛け算は「左の列数」と「右の行数」が完全に一致していないと計算できません。
つまり、2048個のデータが流れてきているのに、受け取る側の扉は1024個分しか用意されていなかった(2048 != 1024)ため、「掛け算ができないよ!」と大爆発を起こしたのです。
2. プーリング(Pooling)が引き起こす次元のブラックボックス
では、なぜ流れてきたデータが「2048個」になっていたのでしょうか?
画像データは、畳み込み層(Conv2d)を通過するたびにチャンネル数(厚み)が増え、プーリング層(MaxPool2dなど)を通過するたびに縦横のサイズが半分(1/4)になっていきます。
例えば、224×224(厚み3)のカラー画像が、数回の層をくぐり抜けるうちに、最終的に 7×7 の大きさに縮みつつ、厚みが 512 チャンネルに増えたとします。
このデータを全結合層(Linear)に渡すためには、これをペチャンコに潰して(Flatten)一列の長いベクトルにする必要があります。
潰したあとの長さは、7(縦) × 7(横) × 512(厚み) = 25,088個 となります。
そう、あなたが書くべきだった全結合層は nn.Linear(1024, 10) ではなく、nn.Linear(25088, 10) が正解だったのです!
# 【エラーを生むネットワークの例】
import torch.nn as nn
import torch.nn.functional as F
class BadCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
# 適当に入力を1024にしてしまっている!ここが諸悪の根源!
self.fc1 = nn.Linear(1024, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# データをペチャンコに潰す(バッチ次元は残す)
x = x.view(x.size(0), -1)
# さあ、ここで x のサイズと fc1の入口のサイズが合わずに大爆発!
x = self.fc1(x)
return x3. ダサい計算からの卒業:適応的プーリング (Adaptive Pooling)
「えっ、じゃあ層を追加したり入力画像のサイズを変えるたびに、手計算で (W-K+2P)/S+1 みたいな畳み込みサイズの数式を解いて、最後のかけ算(7x7x512…)を自分で求めないといけないの?頭がおかしくなりそう…」
安心してください、現代のPyTorchには、そんな無駄な手計算を葬り去ってくれる最強の魔法機能「AdaptiveAvgPool2d(適応的平均プーリング)」が用意されています。
この層は、「今流れてきている画像がどんな縦横サイズだろうが関係なく、強制的に指定されたサイズ(例えば 1×1 のピクセルサイズ)に平均化して縮めてね」という強引な指示を出してくれます。
# 【手計算から解放される最強のモダンCNN実装】
class SmartCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 32, 3),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 512, 3),
nn.ReLU()
)
# 【最大のポイント!】
# どんな縦横サイズで流れてきても、強制的に 1x1 に縮小して512の厚みだけを残す!
self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
# だから fc 層の入口は、計算などせず常に「512」で固定できる!!入力画像が何サイズになろうが関係なし!
self.fc = nn.Linear(512, 10)
def forward(self, x):
x = self.conv_layers(x)
x = self.adaptive_pool(x)
x = torch.flatten(x, 1) # 1x1x512 を 512次元のベクトルに潰す
x = self.fc(x)
return xまとめ:エラーを楽しむためのテンソル力
行列の形状(Shape)不一致エラーは、あなたがモデルのアーキテクチャを手探りで組み上げている証拠です。ResNetやEfficientNetなどの近代的な画像分類モデルは、その全ての最深部に今回紹介した AdaptiveAvgPool2d (または GlobalAveragePooling)を仕込んでおり、そのためどんな画像サイズを突っ込んでもエラーがなく柔軟に動くようになっています。
このようなPyTorchの実践的な設計思想(デザインパターン)については、体系化された専門書を読むことで劇的に理解が深まります。エラーメッセージを「ヒント」として読み解く数学的センスを、この機会にぜひ鍛えてみてください。


コメント