Vision Transformer (ViT) Image Classification
A walkthrough of a PyTorch Vision Transformer image-classification sample
This article implements a Vision Transformer from scratch and explains each of its components.
1. Basic implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# ===== Patch embedding layer =====
class PatchEmbedding(nn.Module):
"""画像をパッチに分割し、埋め込みベクトルに変換"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
        # Patch embedding implemented as a strided convolution
self.proj = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
        # x: (B, C, H, W) → (B, embed_dim, H/patch_size, W/patch_size)
x = self.proj(x)
# (B, embed_dim, H', W') → (B, embed_dim, n_patches)
x = x.flatten(2)
# (B, embed_dim, n_patches) → (B, n_patches, embed_dim)
x = x.transpose(1, 2)
return x
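# Optional sanity check, meant to be run separately (e.g. in an interactive session):
# the Conv2d projection above is equivalent to cutting out non-overlapping 16x16 patches,
# flattening each patch, and applying a single Linear layer.
# Shapes below assume the default 224x224 image / 16x16 patch configuration.
pe = PatchEmbedding()
x = torch.randn(2, 3, 224, 224)                           # dummy batch
out_conv = pe(x)                                          # (2, 196, 768) via Conv2d
patches = F.unfold(x, kernel_size=16, stride=16)          # (2, 3*16*16, 196)
patches = patches.transpose(1, 2)                         # (2, 196, 768)
w = pe.proj.weight.reshape(768, -1)                       # (768, 3*16*16)
out_linear = patches @ w.t() + pe.proj.bias               # (2, 196, 768)
print(torch.allclose(out_conv, out_linear, atol=1e-4))    # True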
# ===== Multi-Head Attention =====
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
        # Linear projection for Query, Key and Value
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, N, C = x.shape
        # Compute Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
        # Compute the attention weights
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, num_heads, N, N)
attn = attn.softmax(dim=-1)
attn = self.dropout(attn)
        # Apply the attention to the values and merge the heads
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.dropout(x)
return x
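# Optional check, meant to be run separately: on PyTorch 2.0 or later (an assumption,
# the model itself does not require it) the manual attention above matches
# F.scaled_dot_product_attention when dropout is disabled.
mha = MultiHeadAttention(embed_dim=768, num_heads=12, dropout=0.0)
mha.eval()
x = torch.randn(2, 197, 768)                              # (B, N, C): 196 patches + CLS token
out_manual = mha(x)
qkv = mha.qkv(x).reshape(2, 197, 3, 12, 64).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
out_sdpa = F.scaled_dot_product_attention(q, k, v)        # default scale = head_dim ** -0.5
out_sdpa = mha.proj(out_sdpa.transpose(1, 2).reshape(2, 197, 768))
print(torch.allclose(out_manual, out_sdpa, atol=1e-4))    # True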
# ===== MLP (Feed Forward Network) =====
class MLP(nn.Module):
def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.1):
super().__init__()
hidden_dim = int(embed_dim * mlp_ratio)
self.fc1 = nn.Linear(embed_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = F.gelu(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
# ===== Transformer Block =====
class TransformerBlock(nn.Module):
def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(embed_dim, mlp_ratio, dropout)
def forward(self, x):
        # Pre-Norm structure: LayerNorm before each sub-layer, with residual connections
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
# ===== Vision Transformer =====
class VisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
num_classes=10,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
dropout=0.1
):
super().__init__()
        # Patch embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
num_patches = self.patch_embed.n_patches
        # CLS token (a special learnable token used for classification)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(dropout)
# Transformer Blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
        # Classification head
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
        # Weight initialization
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
B = x.shape[0]
        # Patch embedding
x = self.patch_embed(x) # (B, n_patches, embed_dim)
        # Prepend the CLS token
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, embed_dim)
x = torch.cat([cls_tokens, x], dim=1) # (B, n_patches+1, embed_dim)
        # Add the positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
        # Pass through the Transformer blocks
for block in self.blocks:
x = block(x)
        # Final layer normalization
x = self.norm(x)
        # Use only the CLS token for classification
cls_token_final = x[:, 0]
logits = self.head(cls_token_final)
return logits
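# Quick shape check (a minimal sketch, run separately): a dummy batch goes in,
# one logit vector per image comes out.
model = VisionTransformer(img_size=224, patch_size=16, num_classes=10,
                          embed_dim=384, depth=6, num_heads=6)
dummy = torch.randn(4, 3, 224, 224)
logits = model(dummy)
print(logits.shape)    # torch.Size([4, 10])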
2. Training code
# ===== Data preparation =====
def get_dataloaders(batch_size=32):
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
# ===== Training function =====
def train_one_epoch(model, dataloader, criterion, optimizer, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
epoch_loss = running_loss / len(dataloader)
epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc
# ===== Evaluation function =====
def evaluate(model, dataloader, criterion, device):
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
epoch_loss = running_loss / len(dataloader)
epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc
# ===== Main =====
def main():
    # Hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 50
batch_size = 64
learning_rate = 3e-4
    # Build the model (a small configuration)
model = VisionTransformer(
img_size=224,
patch_size=16,
in_channels=3,
num_classes=10,
        embed_dim=384,   # smaller than the default 768
        depth=6,         # shallower than the default 12
num_heads=6,
mlp_ratio=4.0,
dropout=0.1
).to(device)
print(f"モデルパラメータ数: {sum(p.numel() for p in model.parameters()):,}")
    # Data loaders
train_loader, test_loader = get_dataloaders(batch_size)
    # Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Training loop
for epoch in range(num_epochs):
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
scheduler.step()
print(f"Epoch [{epoch+1}/{num_epochs}]")
print(f" Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
print(f" Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    # Save the model
torch.save(model.state_dict(), 'vit_model.pth')
print("訓練完了!")
if __name__ == "__main__":
main()
3. Using a pretrained model
# The simple way: use the pretrained ViT from torchvision
from torchvision.models import vit_b_16, ViT_B_16_Weights
def use_pretrained_vit():
    # Load the pretrained model
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)
    # Replace the classification head for fine-tuning
num_classes = 10 # CIFAR-10
model.heads = nn.Linear(model.hidden_dim, num_classes)
    # Also get the matching preprocessing transforms
preprocess = weights.transforms()
return model, preprocess
# Usage example
model, preprocess = use_pretrained_vit()
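# A minimal fine-tuning sketch (illustrative hyperparameters, assumes train_one_epoch
# from Section 2 is available): train the pretrained model on CIFAR-10 using the
# preprocessing transforms that ship with the weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=preprocess)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
for epoch in range(3):
    loss, acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}: loss={loss:.4f}, acc={acc:.2f}%")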
4. Execution results
$ python vision-transformer-classification.py
Number of model parameters: 11,022,730
Epoch [1/50]
Train Loss: 1.7205, Train Acc: 36.19%
Test Loss: 1.5259, Test Acc: 44.82%
Epoch [2/50]
Train Loss: 1.4613, Train Acc: 46.57%
Test Loss: 1.3900, Test Acc: 49.07%
...
Epoch [46/50]
Train Loss: 0.0055, Train Acc: 99.86%
Test Loss: 1.6752, Test Acc: 74.55%
Epoch [47/50]
Train Loss: 0.0057, Train Acc: 99.82%
Test Loss: 1.6772, Test Acc: 74.43%
Epoch [48/50]
Train Loss: 0.0044, Train Acc: 99.89%
Test Loss: 1.6697, Test Acc: 74.76%
Epoch [49/50]
Train Loss: 0.0041, Train Acc: 99.90%
Test Loss: 1.6655, Test Acc: 74.78%
Epoch [50/50]
Train Loss: 0.0043, Train Acc: 99.89%
Test Loss: 1.6649, Test Acc: 74.76%
Training finished!
Explanation of the main components
- Patch embedding: splits the image into patches (e.g. 16×16) and turns each patch into a vector
- CLS token: a special token used for classification
- Positional embedding: learns positional information for each patch
- Transformer: Self-Attention captures relationships across the whole image
- Classification head: produces the final prediction from the CLS token
With this code you can run image classification on CIFAR-10!
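For example, once main() has saved vit_model.pth, a single test image can be classified as follows (a minimal sketch; the model configuration here must match the one used in main()):
# Classify one CIFAR-10 test image with the trained weights
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
model = VisionTransformer(img_size=224, patch_size=16, num_classes=10,
                          embed_dim=384, depth=6, num_heads=6)
model.load_state_dict(torch.load('vit_model.pth', map_location='cpu'))
model.eval()
_, test_loader = get_dataloaders(batch_size=1)
image, label = next(iter(test_loader))
with torch.no_grad():
    pred = model(image).argmax(dim=1).item()
print(f"prediction: {classes[pred]}, ground truth: {classes[label.item()]}")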