Pytorchで敵対生成ネットワーク(GAN)
PyTorch + GAN + MNISTサンプルコードの詳細解説
完全なサンプルコード
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# デバイスの設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用デバイス: {device}')
# ハイパーパラメータ
latent_dim = 100 # 潜在空間の次元数
img_size = 28 # 画像サイズ
channels = 1 # チャンネル数(グレースケール)
batch_size = 128
learning_rate = 0.0002
num_epochs = 50
beta1 = 0.5 # Adam最適化のβ1パラメータ
# データの準備
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # [-1, 1]に正規化
])
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
# Generatorの定義
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# ノイズから画像を生成するネットワーク
self.model = nn.Sequential(
# 入力: latent_dim次元のノイズベクトル
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2, inplace=True),
# 出力: 28*28 = 784次元
nn.Linear(1024, img_size * img_size * channels),
nn.Tanh() # [-1, 1]の範囲に出力
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), channels, img_size, img_size)
return img
# Discriminatorの定義
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# 画像が本物か偽物かを判定するネットワーク
self.model = nn.Sequential(
# 入力: 28*28 = 784次元の画像
nn.Linear(img_size * img_size * channels, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
# 出力: 本物である確率
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# モデルのインスタンス化
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 損失関数
adversarial_loss = nn.BCELoss()
# オプティマイザー
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
# 学習ループ
print("学習開始...")
for epoch in range(num_epochs):
for i, (real_imgs, _) in enumerate(dataloader):
# ラベルの準備
batch_size_current = real_imgs.size(0)
real_imgs = real_imgs.to(device)
# 本物と偽物のラベル
real_labels = torch.ones(batch_size_current, 1).to(device)
fake_labels = torch.zeros(batch_size_current, 1).to(device)
# ---------------------
# Discriminatorの学習
# ---------------------
optimizer_D.zero_grad()
# 本物の画像に対する損失
real_output = discriminator(real_imgs)
d_loss_real = adversarial_loss(real_output, real_labels)
# 偽物の画像を生成
z = torch.randn(batch_size_current, latent_dim).to(device)
fake_imgs = generator(z)
# 偽物の画像に対する損失
fake_output = discriminator(fake_imgs.detach())
d_loss_fake = adversarial_loss(fake_output, fake_labels)
# 合計損失
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_D.step()
# -----------------
# Generatorの学習
# -----------------
optimizer_G.zero_grad()
# Generatorは偽物をDiscriminatorに本物と判定させたい
fake_output = discriminator(fake_imgs)
g_loss = adversarial_loss(fake_output, real_labels)
g_loss.backward()
optimizer_G.step()
# 進捗表示
if i % 100 == 0:
print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
# エポック終了ごとに生成画像を保存
if epoch % 5 == 0:
with torch.no_grad():
z = torch.randn(16, latent_dim).to(device)
generated_imgs = generator(z).cpu()
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for idx, ax in enumerate(axes.flat):
img = generated_imgs[idx].squeeze().numpy()
img = (img + 1) / 2 # [-1, 1] -> [0, 1]
ax.imshow(img, cmap='gray')
ax.axis('off')
plt.suptitle(f'Epoch {epoch}')
plt.tight_layout()
plt.savefig(f'generated_epoch_{epoch}.png')
plt.close()
print("学習完了!")
# 最終的な生成画像の表示
with torch.no_grad():
z = torch.randn(25, latent_dim).to(device)
generated_imgs = generator(z).cpu()
fig, axes = plt.subplots(5, 5, figsize=(10, 10))
for idx, ax in enumerate(axes.flat):
img = generated_imgs[idx].squeeze().numpy()
img = (img + 1) / 2
ax.imshow(img, cmap='gray')
ax.axis('off')
plt.suptitle('最終生成画像')
plt.tight_layout()
plt.savefig('final_generated_images.png')
plt.show()
詳細解説
1. GANの基本概念
Generator (生成器) Discriminator (識別器)
↓ ↓
偽の画像を生成 本物/偽物を判定
↓ ↓
互いに競争しながら学習
2. データの正規化
transforms.Normalize([0.5], [0.5])
- MNIST画像を
[-1, 1]の範囲に正規化 - 計算式:
(x - 0.5) / 0.5
3. Generator(生成器)の役割
入力: ランダムノイズ (100次元)
↓
全結合層 + 活性化
↓
徐々に次元を拡大
↓
出力: 28×28の画像 (784次元)
重要なポイント:
LeakyReLU: 負の値も少し通す(勾配消失を防ぐ)BatchNorm: 学習の安定化Tanh: 出力を[-1, 1]に制限
4. Discriminator(識別器)の役割
入力: 28×28の画像
↓
次元を徐々に削減
↓
出力: 0~1のスコア(本物らしさ)
5. 学習プロセス
Discriminatorの学習
# 1. 本物の画像を「本物(1)」と判定するように学習
d_loss_real = loss(D(real_imgs), 1)
# 2. 偽物の画像を「偽物(0)」と判定するように学習
d_loss_fake = loss(D(G(z)), 0)
# 3. 合計損失で更新
d_loss = d_loss_real + d_loss_fake
Generatorの学習
# Discriminatorを騙すように学習
# 生成画像を「本物(1)」と判定させたい
g_loss = loss(D(G(z)), 1)
6. 重要な技術的ポイント
detach()の使用
fake_output = discriminator(fake_imgs.detach())
- Generatorへの勾配伝播を遮断
- Discriminator学習時にGeneratorが更新されないようにする
交互学習
1回のイテレーションで:
1. Discriminatorを更新
2. Generatorを更新
7. 損失関数: BCELoss
BCE = -[y*log(ŷ) + (1-y)*log(1-ŷ)]
y=1(本物)の時:-log(ŷ)→ ŷ=1に近づけたいy=0(偽物)の時:-log(1-ŷ)→ ŷ=0に近づけたい
8. 改良版: DCGAN風のConvolutional GAN
class ConvGenerator(nn.Module):
def __init__(self):
super(ConvGenerator, self).__init__()
self.init_size = img_size // 4 # 7
self.fc = nn.Linear(latent_dim, 128 * self.init_size ** 2)
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, channels, 3, stride=1, padding=1),
nn.Tanh(),
)
def forward(self, z):
out = self.fc(z)
out = out.view(out.size(0), 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
class ConvDiscriminator(nn.Module):
def __init__(self):
super(ConvDiscriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(channels, 32, 3, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25),
nn.Conv2d(32, 64, 3, 2, 1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25),
nn.Conv2d(64, 128, 3, 2, 1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25),
)
self.adv_layer = nn.Sequential(
nn.Linear(128 * 4 * 4, 1),
nn.Sigmoid()
)
def forward(self, img):
out = self.model(img)
out = out.view(out.size(0), -1)
validity = self.adv_layer(out)
return validity
9. 学習のTips
- 学習率の調整: 0.0002が標準的
- Beta1の調整: 0.5がGANでは効果的
- Label Smoothing:
real_labels = 0.9などにすると安定 - Noisy Labels: ラベルに少しノイズを加える
10. よくある問題と対処法
| 問題 | 対処法 |
|---|---|
| Mode Collapse | Minibatch Discrimination追加 |
| 学習が不安定 | Spectral Normalization使用 |
| 生成画像が低品質 | より深いネットワーク、WGAN使用 |
このコードを実行すると、数字を生成するGANが学習されます!