PyTorch Lightningで画像分類
PyTorch Lightningで画像分類のサンプルコード解説
PyTorch Lightningを使った画像分類の完全なサンプルコードを解説します。
完全なサンプルコード
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# 1. モデル定義
class ImageClassifier(pl.LightningModule):
def __init__(self, num_classes=10, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters() # ハイパーパラメータを自動保存
# 簡単なCNNモデル
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(128 * 4 * 4, 512)
self.fc2 = nn.Linear(512, num_classes)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# 順伝播の定義
x = self.pool(F.relu(self.conv1(x))) # 32x32 -> 16x16
x = self.pool(F.relu(self.conv2(x))) # 16x16 -> 8x8
x = self.pool(F.relu(self.conv3(x))) # 8x8 -> 4x4
x = x.view(-1, 128 * 4 * 4)
x = self.dropout(F.relu(self.fc1(x)))
x = self.fc2(x)
return x
def training_step(self, batch, batch_idx):
# 訓練時の1ステップ
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
# 精度計算
preds = torch.argmax(logits, dim=1)
acc = (preds == y).float().mean()
# ログ記録
self.log('train_loss', loss, prog_bar=True)
self.log('train_acc', acc, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
# 検証時の1ステップ
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
preds = torch.argmax(logits, dim=1)
acc = (preds == y).float().mean()
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss
def test_step(self, batch, batch_idx):
# テスト時の1ステップ
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
preds = torch.argmax(logits, dim=1)
acc = (preds == y).float().mean()
self.log('test_loss', loss)
self.log('test_acc', acc)
return loss
def configure_optimizers(self):
# オプティマイザとスケジューラの設定
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=3
)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss'
}
}
# 2. データモジュール定義
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(self, data_dir='./data', batch_size=32, num_workers=4):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
# データ前処理
self.transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def prepare_data(self):
# データのダウンロード(1度だけ実行)
datasets.CIFAR10(self.data_dir, train=True, download=True)
datasets.CIFAR10(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# データセットの設定
if stage == 'fit' or stage is None:
cifar_full = datasets.CIFAR10(
self.data_dir, train=True, transform=self.transform_train
)
# 訓練データと検証データに分割
self.cifar_train, self.cifar_val = random_split(
cifar_full, [45000, 5000]
)
if stage == 'test' or stage is None:
self.cifar_test = datasets.CIFAR10(
self.data_dir, train=False, transform=self.transform_test
)
def train_dataloader(self):
return DataLoader(
self.cifar_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers
)
def val_dataloader(self):
return DataLoader(
self.cifar_val,
batch_size=self.batch_size,
num_workers=self.num_workers
)
def test_dataloader(self):
return DataLoader(
self.cifar_test,
batch_size=self.batch_size,
num_workers=self.num_workers
)
# 3. 訓練実行
def main():
# データモジュール作成
dm = CIFAR10DataModule(batch_size=64, num_workers=4)
# モデル作成
model = ImageClassifier(num_classes=10, learning_rate=1e-3)
# コールバック設定
checkpoint_callback = ModelCheckpoint(
monitor='val_loss',
dirpath='checkpoints/',
filename='cifar10-{epoch:02d}-{val_loss:.2f}',
save_top_k=3,
mode='min'
)
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=5,
mode='min'
)
# Trainer設定
trainer = pl.Trainer(
max_epochs=20,
accelerator='auto', # 自動でGPU/CPU選択
devices=1,
callbacks=[checkpoint_callback, early_stop_callback],
log_every_n_steps=10
)
# 訓練実行
trainer.fit(model, dm)
# テスト実行
trainer.test(model, dm)
if __name__ == '__main__':
main()
主要な構成要素の解説
1. LightningModule (モデル定義)
__init__: モデルの層を定義forward: 順伝播処理training_step: 訓練時の処理(損失計算など)validation_step: 検証時の処理configure_optimizers: オプティマイザ設定
2. LightningDataModule (データ管理)
prepare_data: データダウンロードsetup: データセット分割train/val/test_dataloader: データローダー提供
3. Trainer (訓練管理)
- エポック数、GPU設定、コールバックなどを統合管理
4. 実行の結果
$ python pytorch-lightning-cnn-classification.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
----------------------------------------------
0 | conv1 | Conv2d | 896 | train
1 | conv2 | Conv2d | 18.5 K | train
2 | conv3 | Conv2d | 73.9 K | train
3 | pool | MaxPool2d | 0 | train
4 | fc1 | Linear | 1.0 M | train
5 | fc2 | Linear | 5.1 K | train
6 | dropout | Dropout | 0 | train
----------------------------------------------
1.1 M Trainable params
0 Non-trainable params
1.1 M Total params
4.590 Total estimated model params size (MB)
7 Modules in train mode
0 Modules in eval mode
Epoch 19: 100%|███████████████████████████████████| 704/704 [00:03<00:00, 195.75it/s, v_num=0, train_loss=0.0607, train_acc=1.000, val_loss=0.604, val_acc=0.790]`Trainer.fit` stopped: `max_epochs=20` reached.
Epoch 19: 100%|███████████████████████████████████| 704/704 [00:03<00:00, 193.97it/s, v_num=0, train_loss=0.0607, train_acc=1.000, val_loss=0.604, val_acc=0.790]LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:00<00:00, 413.72it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_acc │ 0.8090999722480774 │
│ test_loss │ 0.566657304763794 │
└───────────────────────────┴───────────────────────────┘
PyTorch Lightningの利点
- コードが整理される: 訓練ループを書く必要なし
- GPU対応が簡単:
accelerator='auto'だけ - 再現性が高い: ハイパーパラメータ自動保存
- ログ管理が楽: TensorBoard等に自動記録
このコードをそのまま実行すれば、CIFAR-10での画像分類が動きます!