MNIST Digit Classification with PyTorch + CNN
This article walks step by step through a CNN implementation that classifies handwritten MNIST digits.
Full Code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # 1x28x28 → 32x28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32x14x14 → 64x14x14
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)  # halves the spatial size
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 10-class classification
        # Dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Convolution + ReLU + pooling
        x = self.pool(torch.relu(self.conv1(x)))  # 32x28x28 → 32x14x14
        x = self.pool(torch.relu(self.conv2(x)))  # 64x14x14 → 64x7x7
        # Flatten
        x = x.view(-1, 64 * 7 * 7)
        # Fully connected layers
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
# 2. Prepare the data
transform = transforms.Compose([
    transforms.ToTensor(),                # convert PIL images to Tensors
    transforms.Normalize((0.5,), (0.5,))  # normalize with (mean, std)
])

# Download and load the datasets
train_dataset = datasets.MNIST(root='~/.pytorch/data', train=True,
                               download=True, transform=transform)
test_dataset = datasets.MNIST(root='~/.pytorch/data', train=False,
                              download=True, transform=transform)

# Create the DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
# 3. Set up the model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Prefer MPS (Apple Silicon GPU) if it is available
if torch.backends.mps.is_available():
    device = torch.device("mps")

model = CNN().to(device)
criterion = nn.CrossEntropyLoss()  # loss function for multi-class classification
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 4. Training function
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()  # switch to training mode
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # Reset the gradients
        optimizer.zero_grad()
        # Forward pass
        output = model(data)
        # Compute the loss
        loss = criterion(output, target)
        # Backward pass
        loss.backward()
        # Update the parameters
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
    return total_loss / len(train_loader)
# 5. Test function
def test(model, device, test_loader):
    model.eval()  # switch to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # disable gradient computation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Take the class with the highest score as the prediction
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
# 6. Run the training loop
epochs = 5
for epoch in range(1, epochs + 1):
    train_loss = train(model, device, train_loader, optimizer, criterion, epoch)
    print(f'Average Loss: {train_loss:.4f}')
    test(model, device, test_loader)
    print('-' * 60)

# Save the model
torch.save(model.state_dict(), 'mnist_cnn.pth')
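The saved weights can later be restored for inference. Below is a minimal sketch of how that might look (it assumes the CNN class, device, and test_loader defined above are in scope; the variable names are illustrative):

# Rebuild the model and load the trained weights
model = CNN().to(device)
model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device))
model.eval()  # always switch to evaluation mode before inference

# Predict on one batch from the test set
images, labels = next(iter(test_loader))
with torch.no_grad():
    predictions = model(images.to(device)).argmax(dim=1)
print(predictions[:10].tolist())  # predicted digits
print(labels[:10].tolist())       # ground-truth digits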
Explanation of the Key Components
1. CNN Model Structure
Input (1x28x28)
↓
Conv2d (32 filters) → ReLU → MaxPool → (32x14x14)
↓
Conv2d (64 filters) → ReLU → MaxPool → (64x7x7)
↓
Flatten → (3136)
↓
FC (128) → ReLU → Dropout
↓
FC (10) → output
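You can confirm the shapes in this diagram by tracing a dummy tensor through the layers. A quick sanity-check sketch, reusing the CNN class from the full code above:

model = CNN()
x = torch.randn(1, 1, 28, 28)               # dummy batch: (N, C, H, W)
x = model.pool(torch.relu(model.conv1(x)))  # torch.Size([1, 32, 14, 14])
x = model.pool(torch.relu(model.conv2(x)))  # torch.Size([1, 64, 7, 7])
x = x.view(-1, 64 * 7 * 7)
print(x.shape)                              # torch.Size([1, 3136])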
2. Key Parameters
- kernel_size=3: 3×3 convolution filters
- padding=1: preserves the spatial size of the image (see the sketch below)
- MaxPool2d(2, 2): takes the maximum over each 2×2 region
- Dropout(0.5): helps prevent overfitting
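These choices follow the standard convolution output-size formula, out = (in + 2*padding - kernel_size) / stride + 1, so kernel_size=3 with padding=1 and stride=1 gives (28 + 2 - 3) + 1 = 28 and the size is preserved. A small sketch to confirm this (the layers mirror conv1 and pool from the model above):

conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2, 2)
x = torch.randn(1, 1, 28, 28)
print(conv(x).shape)        # torch.Size([1, 32, 28, 28]) -- padding=1 keeps 28x28
print(pool(conv(x)).shape)  # torch.Size([1, 32, 14, 14]) -- pooling halves H and W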
3. Execution Results
$ python3 mnist+cnn+classification.py
Epoch: 1, Batch: 0, Loss: 2.3096
Epoch: 1, Batch: 100, Loss: 0.3755
Epoch: 1, Batch: 200, Loss: 0.4064
Epoch: 1, Batch: 300, Loss: 0.2006
Epoch: 1, Batch: 400, Loss: 0.1787
Epoch: 1, Batch: 500, Loss: 0.0984
Epoch: 1, Batch: 600, Loss: 0.1776
Epoch: 1, Batch: 700, Loss: 0.2356
Epoch: 1, Batch: 800, Loss: 0.1497
Epoch: 1, Batch: 900, Loss: 0.1300
Average Loss: 0.2377
Test Accuracy: 98.61%
------------------------------------------------------------
Epoch: 2, Batch: 0, Loss: 0.1457
Epoch: 2, Batch: 100, Loss: 0.1019
Epoch: 2, Batch: 200, Loss: 0.0407
Epoch: 2, Batch: 300, Loss: 0.0687
Epoch: 2, Batch: 400, Loss: 0.0562
Epoch: 2, Batch: 500, Loss: 0.0583
Epoch: 2, Batch: 600, Loss: 0.0361
Epoch: 2, Batch: 700, Loss: 0.0554
Epoch: 2, Batch: 800, Loss: 0.0757
Epoch: 2, Batch: 900, Loss: 0.2820
Average Loss: 0.0859
Test Accuracy: 98.98%
------------------------------------------------------------
Epoch: 3, Batch: 0, Loss: 0.0496
Epoch: 3, Batch: 100, Loss: 0.1323
Epoch: 3, Batch: 200, Loss: 0.0146
Epoch: 3, Batch: 300, Loss: 0.0297
Epoch: 3, Batch: 400, Loss: 0.0217
Epoch: 3, Batch: 500, Loss: 0.0470
Epoch: 3, Batch: 600, Loss: 0.0499
Epoch: 3, Batch: 700, Loss: 0.0439
Epoch: 3, Batch: 800, Loss: 0.0967
Epoch: 3, Batch: 900, Loss: 0.0390
Average Loss: 0.0642
Test Accuracy: 99.00%
------------------------------------------------------------
Epoch: 4, Batch: 0, Loss: 0.0106
Epoch: 4, Batch: 100, Loss: 0.0114
Epoch: 4, Batch: 200, Loss: 0.0156
Epoch: 4, Batch: 300, Loss: 0.0550
Epoch: 4, Batch: 400, Loss: 0.0288
Epoch: 4, Batch: 500, Loss: 0.0282
Epoch: 4, Batch: 600, Loss: 0.1245
Epoch: 4, Batch: 700, Loss: 0.0610
Epoch: 4, Batch: 800, Loss: 0.0127
Epoch: 4, Batch: 900, Loss: 0.0365
Average Loss: 0.0516
Test Accuracy: 99.12%
------------------------------------------------------------
Epoch: 5, Batch: 0, Loss: 0.0130
Epoch: 5, Batch: 100, Loss: 0.0412
Epoch: 5, Batch: 200, Loss: 0.0173
Epoch: 5, Batch: 300, Loss: 0.0064
Epoch: 5, Batch: 400, Loss: 0.0599
Epoch: 5, Batch: 500, Loss: 0.0848
Epoch: 5, Batch: 600, Loss: 0.0542
Epoch: 5, Batch: 700, Loss: 0.0545
Epoch: 5, Batch: 800, Loss: 0.0207
Epoch: 5, Batch: 900, Loss: 0.0759
Average Loss: 0.0424
Test Accuracy: 99.20%
------------------------------------------------------------
With just five epochs of training, this code reaches 99.20% test accuracy!