CNN Variational Autoencoder
PyTorch CNN VAEのサンプルコード解説
MNISTデータセットを使ったCNN版Variational Autoencoder(VAE)のコードを解説します。
完全なサンプルコード
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# ハイパーパラメータ
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 1e-3
LATENT_DIM = 20 # 潜在変数の次元数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# データセットの準備
transform = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = datasets.MNIST('~/.pytorch/data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# VAEモデルの定義
class VAE(nn.Module):
def __init__(self, latent_dim=20):
super(VAE, self).__init__()
# エンコーダー(画像 → 潜在変数)
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), # 28x28 -> 14x14
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 14x14 -> 7x7
nn.ReLU(),
nn.Flatten(), # 64*7*7 = 3136
)
# 潜在変数の平均と分散を出力
self.fc_mu = nn.Linear(64*7*7, latent_dim)
self.fc_logvar = nn.Linear(64*7*7, latent_dim)
# デコーダー(潜在変数 → 画像)
self.decoder_input = nn.Linear(latent_dim, 64*7*7)
self.decoder = nn.Sequential(
nn.Unflatten(1, (64, 7, 7)), # 3136 -> 64x7x7
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), # 7x7 -> 14x14
nn.ReLU(),
nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 14x14 -> 28x28
nn.Sigmoid(), # [0, 1]の範囲に正規化
)
def encode(self, x):
"""エンコーダー部分"""
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu, logvar):
"""再パラメータ化トリック: z = μ + σ * ε"""
std = torch.exp(0.5 * logvar) # 標準偏差
eps = torch.randn_like(std) # 標準正規分布からサンプリング
z = mu + eps * std
return z
def decode(self, z):
"""デコーダー部分"""
h = self.decoder_input(z)
reconstruction = self.decoder(h)
return reconstruction
def forward(self, x):
"""順伝播"""
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
reconstruction = self.decode(z)
return reconstruction, mu, logvar
# 損失関数の定義
def vae_loss(recon_x, x, mu, logvar):
"""
VAEの損失関数 = 再構成誤差 + KLダイバージェンス
"""
# 再構成誤差(Binary Cross Entropy)
recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
# KLダイバージェンス
# KL(N(μ,σ²) || N(0,1)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_divergence
# モデルの初期化
model = VAE(latent_dim=LATENT_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 学習ループ
def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
# 順伝播
recon_batch, mu, logvar = model(data)
# 損失計算
loss = vae_loss(recon_batch, data, mu, logvar)
# 逆伝播
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] '
f'Loss: {loss.item() / len(data):.4f}')
avg_loss = train_loss / len(train_loader.dataset)
print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
# 学習実行
for epoch in range(1, EPOCHS + 1):
train(epoch)
# 生成画像の可視化
model.eval()
with torch.no_grad():
# ランダムサンプリングから生成
z = torch.randn(64, LATENT_DIM).to(device)
sample = model.decode(z).cpu()
# 画像表示
fig, axes = plt.subplots(8, 8, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
ax.imshow(sample[i].squeeze(), cmap='gray')
ax.axis('off')
plt.tight_layout()
plt.savefig('vae_generated_samples.png')
plt.show()
# 再構成画像の可視化
with torch.no_grad():
data, _ = next(iter(train_loader))
data = data[:8].to(device)
recon, _, _ = model(data)
# 元画像と再構成画像を比較
fig, axes = plt.subplots(2, 8, figsize=(15, 4))
for i in range(8):
# 元画像
axes[0, i].imshow(data[i].cpu().squeeze(), cmap='gray')
axes[0, i].axis('off')
# 再構成画像
axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
axes[1, i].axis('off')
axes[0, 0].set_ylabel('Original', size=20)
axes[1, 0].set_ylabel('Reconstructed', size=20)
plt.tight_layout()
plt.savefig('vae_reconstruction.png')
plt.show()
主要な構成要素の解説
1. エンコーダー(Encoder)
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten(),
)
- 入力画像を畳み込み層で特徴抽出
- 28×28 → 14×14 → 7×7 と縮小
- 平均(μ)と分散(logvar)を出力
2. 再パラメータ化トリック
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
return z
- 目的: 確率的なサンプリングでも勾配が伝播できるようにする
- ε ~ N(0,1) を使って z = μ + σε と変換
3. デコーダー(Decoder)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, ...),
nn.ConvTranspose2d(32, 1, ...),
nn.Sigmoid(),
)
- 潜在変数から画像を再構成
- ConvTranspose2d(転置畳み込み)で画像を拡大
4. VAE損失関数
loss = 再構成誤差 + KLダイバージェンス
- 再構成誤差: 元画像と再構成画像の差
- KLダイバージェンス: 潜在変数分布をN(0,1)に近づける正則化項
実行結果
このコードを実行すると:
- MNISTデータセットで学習
$ python cnn-variational-autoencoder.py
Epoch 1 [0/60000] Loss: 666.3545
Epoch 1 [12800/60000] Loss: 177.2791
Epoch 1 [25600/60000] Loss: 147.5417
Epoch 1 [38400/60000] Loss: 132.0728
Epoch 1 [51200/60000] Loss: 127.7308
====> Epoch: 1 Average loss: 172.7988
Epoch 2 [0/60000] Loss: 122.6681
Epoch 2 [12800/60000] Loss: 120.1804
Epoch 2 [25600/60000] Loss: 119.0149
Epoch 2 [38400/60000] Loss: 118.5517
Epoch 2 [51200/60000] Loss: 118.3539
====> Epoch: 2 Average loss: 121.4129
Epoch 3 [0/60000] Loss: 121.5787
Epoch 3 [12800/60000] Loss: 117.8588
Epoch 3 [25600/60000] Loss: 113.7935
Epoch 3 [38400/60000] Loss: 117.5502
Epoch 3 [51200/60000] Loss: 111.4696
====> Epoch: 3 Average loss: 115.4343
Epoch 4 [0/60000] Loss: 110.1155
Epoch 4 [12800/60000] Loss: 111.0482
Epoch 4 [25600/60000] Loss: 114.4668
Epoch 4 [38400/60000] Loss: 113.0274
Epoch 4 [51200/60000] Loss: 113.2623
...
====> Epoch: 7 Average loss: 106.9868
Epoch 8 [0/60000] Loss: 102.2845
Epoch 8 [12800/60000] Loss: 106.8972
Epoch 8 [25600/60000] Loss: 104.2490
Epoch 8 [38400/60000] Loss: 106.9311
Epoch 8 [51200/60000] Loss: 103.4750
====> Epoch: 8 Average loss: 106.2541
Epoch 9 [0/60000] Loss: 105.2472
Epoch 9 [12800/60000] Loss: 105.3693
Epoch 9 [25600/60000] Loss: 109.0917
Epoch 9 [38400/60000] Loss: 105.7611
Epoch 9 [51200/60000] Loss: 108.3643
====> Epoch: 9 Average loss: 105.6413
Epoch 10 [0/60000] Loss: 107.4038
Epoch 10 [12800/60000] Loss: 102.9440
Epoch 10 [25600/60000] Loss: 106.6951
Epoch 10 [38400/60000] Loss: 106.7149
Epoch 10 [51200/60000] Loss: 107.3842
====> Epoch: 10 Average loss: 105.2085
vae_generated_samples.png: ランダム生成された数字画像
vae_reconstruction.png: 元画像と再構成画像の比較
VAEは生成モデルなので、学習後に新しい手書き数字を生成できます!