<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Posts on 機械学習の備備忘</title><link>https://ml.askbox.net/posts/</link><description>Recent content in Posts on 機械学習の備備忘</description><generator>Hugo</generator><language>ja</language><lastBuildDate>Thu, 19 Feb 2026 18:10:52 +0900</lastBuildDate><atom:link href="https://ml.askbox.net/posts/index.xml" rel="self" type="application/rss+xml"/><item><title>CNN Variational Autoencoder</title><link>https://ml.askbox.net/posts/cnn-variational-autoencoder/</link><pubDate>Thu, 19 Feb 2026 18:10:52 +0900</pubDate><guid>https://ml.askbox.net/posts/cnn-variational-autoencoder/</guid><description>&lt;h2 id="pytorch-cnn-vaeのサンプルコード解説"&gt;PyTorch CNN VAEのサンプルコード解説&lt;/h2&gt;
&lt;p&gt;MNISTデータセットを使ったCNN版Variational Autoencoder（VAE）のコードを解説します。&lt;/p&gt;
&lt;h2 id="完全なサンプルコード"&gt;完全なサンプルコード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn.functional &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; F
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; matplotlib.pyplot &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; plt
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ハイパーパラメータ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;BATCH_SIZE &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;EPOCHS &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;LEARNING_RATE &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1e-3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;LATENT_DIM &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;20&lt;/span&gt; &lt;span style="color:#75715e"&gt;# 潜在変数の次元数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#34;cuda&amp;#34;&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available() &lt;span style="color:#66d9ef"&gt;else&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;cpu&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# データセットの準備&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;transform &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(train_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;BATCH_SIZE, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# VAEモデルの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;VAE&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, latent_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;20&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(VAE, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# エンコーダー（画像 → 潜在変数）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 28x28 -&amp;gt; 14x14&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 14x14 -&amp;gt; 7x7&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Flatten(), &lt;span style="color:#75715e"&gt;# 64*7*7 = 3136&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 潜在変数の平均と分散を出力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc_mu &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;&lt;span style="color:#f92672"&gt;*&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;&lt;span style="color:#f92672"&gt;*&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;, latent_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc_logvar &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;&lt;span style="color:#f92672"&gt;*&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;&lt;span style="color:#f92672"&gt;*&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;, latent_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# デコーダー（潜在変数 → 画像）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder_input &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(latent_dim, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;&lt;span style="color:#f92672"&gt;*&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;&lt;span style="color:#f92672"&gt;*&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Unflatten(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, (&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;)), &lt;span style="color:#75715e"&gt;# 3136 -&amp;gt; 64x7x7&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, output_padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 7x7 -&amp;gt; 14x14&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, output_padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 14x14 -&amp;gt; 28x28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sigmoid(), &lt;span style="color:#75715e"&gt;# [0, 1]の範囲に正規化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;encode&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;&amp;#34;&amp;#34;エンコーダー部分&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; h &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; mu &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc_mu(h)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; logvar &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc_logvar(h)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; mu, logvar
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;reparameterize&lt;/span&gt;(self, mu, logvar):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;&amp;#34;&amp;#34;再パラメータ化トリック: z = μ + σ * ε&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; std &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;exp(&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; logvar) &lt;span style="color:#75715e"&gt;# 標準偏差&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; eps &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;randn_like(std) &lt;span style="color:#75715e"&gt;# 標準正規分布からサンプリング&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; mu &lt;span style="color:#f92672"&gt;+&lt;/span&gt; eps &lt;span style="color:#f92672"&gt;*&lt;/span&gt; std
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; z
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;decode&lt;/span&gt;(self, z):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;&amp;#34;&amp;#34;デコーダー部分&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; h &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder_input(z)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstruction &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder(h)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; reconstruction
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;&amp;#34;&amp;#34;順伝播&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; mu, logvar &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encode(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;reparameterize(mu, logvar)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstruction &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decode(z)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; reconstruction, mu, logvar
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 損失関数の定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;vae_loss&lt;/span&gt;(recon_x, x, mu, logvar):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;&amp;#34;&amp;#34;
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#e6db74"&gt; VAEの損失関数 = 再構成誤差 + KLダイバージェンス
&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#e6db74"&gt; &amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 再構成誤差（Binary Cross Entropy）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; recon_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;binary_cross_entropy(recon_x, x, reduction&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;sum&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# KLダイバージェンス&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# KL(N(μ,σ²) || N(0,1)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; kl_divergence &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;sum(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt; &lt;span style="color:#f92672"&gt;+&lt;/span&gt; logvar &lt;span style="color:#f92672"&gt;-&lt;/span&gt; mu&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pow(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;) &lt;span style="color:#f92672"&gt;-&lt;/span&gt; logvar&lt;span style="color:#f92672"&gt;.&lt;/span&gt;exp())
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; recon_loss &lt;span style="color:#f92672"&gt;+&lt;/span&gt; kl_divergence
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# モデルの初期化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; VAE(latent_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;LATENT_DIM)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;optimizer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;LEARNING_RATE)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 学習ループ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;train&lt;/span&gt;(epoch):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;train()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; batch_idx, (data, _) &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(train_loader):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 順伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; recon_batch, mu, logvar &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 損失計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; vae_loss(recon_batch, data, mu, logvar)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 逆伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; batch_idx &lt;span style="color:#f92672"&gt;%&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100&lt;/span&gt; &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt; [&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;batch_idx &lt;span style="color:#f92672"&gt;*&lt;/span&gt; len(data)&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;len(train_loader&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dataset)&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;] &amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item() &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(data)&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; avg_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; train_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(train_loader&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dataset)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;====&amp;gt; Epoch: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt; Average loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;avg_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 学習実行&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, EPOCHS &lt;span style="color:#f92672"&gt;+&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train(epoch)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 生成画像の可視化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eval()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ランダムサンプリングから生成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;randn(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, LATENT_DIM)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; sample &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decode(z)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 画像表示&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fig, axes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;subplots(&lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;, figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; i, ax &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(axes&lt;span style="color:#f92672"&gt;.&lt;/span&gt;flat):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ax&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(sample[i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze(), cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ax&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;tight_layout()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;&amp;#39;vae_generated_samples.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;show()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 再構成画像の可視化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data, _ &lt;span style="color:#f92672"&gt;=&lt;/span&gt; next(iter(train_loader))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data[:&lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; recon, _, _ &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 元画像と再構成画像を比較&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fig, axes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;subplots(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;, figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;15&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; i &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(&lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 元画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(data[i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze(), cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 再構成画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(recon[i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze(), cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;set_ylabel(&lt;span style="color:#e6db74"&gt;&amp;#39;Original&amp;#39;&lt;/span&gt;, size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;20&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;set_ylabel(&lt;span style="color:#e6db74"&gt;&amp;#39;Reconstructed&amp;#39;&lt;/span&gt;, size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;20&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;tight_layout()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;&amp;#39;vae_reconstruction.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;show()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="主要な構成要素の解説"&gt;主要な構成要素の解説&lt;/h2&gt;
&lt;h3 id="1-エンコーダーencoder"&gt;1. &lt;strong&gt;エンコーダー（Encoder）&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Flatten(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;入力画像を畳み込み層で特徴抽出&lt;/li&gt;
&lt;li&gt;28×28 → 14×14 → 7×7 と縮小&lt;/li&gt;
&lt;li&gt;平均（μ）と分散（logvar）を出力&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="2-再パラメータ化トリック"&gt;2. &lt;strong&gt;再パラメータ化トリック&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;reparameterize&lt;/span&gt;(self, mu, logvar):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; std &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;exp(&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; logvar)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; eps &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;randn_like(std)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; mu &lt;span style="color:#f92672"&gt;+&lt;/span&gt; eps &lt;span style="color:#f92672"&gt;*&lt;/span&gt; std
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; z
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;目的&lt;/strong&gt;: 確率的なサンプリングでも勾配が伝播できるようにする&lt;/li&gt;
&lt;li&gt;ε ~ N(0,1) を使って z = μ + σε と変換&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="3-デコーダーdecoder"&gt;3. &lt;strong&gt;デコーダー（Decoder）&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#f92672"&gt;...&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#f92672"&gt;...&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sigmoid(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;潜在変数から画像を再構成&lt;/li&gt;
&lt;li&gt;ConvTranspose2d（転置畳み込み）で画像を拡大&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="4-vae損失関数"&gt;4. &lt;strong&gt;VAE損失関数&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; 再構成誤差 &lt;span style="color:#f92672"&gt;+&lt;/span&gt; KLダイバージェンス
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;再構成誤差&lt;/strong&gt;: 元画像と再構成画像の差&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;KLダイバージェンス&lt;/strong&gt;: 潜在変数分布をN(0,1)に近づける正則化項&lt;/li&gt;
&lt;/ul&gt;
&lt;h2 id="実行結果"&gt;実行結果&lt;/h2&gt;
&lt;p&gt;このコードを実行すると：&lt;/p&gt;</description></item><item><title>Pytorchで敵対生成ネットワーク(GAN)</title><link>https://ml.askbox.net/posts/pytorch-gan-mnist/</link><pubDate>Tue, 10 Feb 2026 20:34:54 +0900</pubDate><guid>https://ml.askbox.net/posts/pytorch-gan-mnist/</guid><description>&lt;h2 id="pytorch--gan--mnistサンプルコードの詳細解説"&gt;PyTorch + GAN + MNISTサンプルコードの詳細解説&lt;/h2&gt;
&lt;h2 id="完全なサンプルコード"&gt;完全なサンプルコード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.optim &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; optim
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; matplotlib.pyplot &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; plt
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; numpy &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; np
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# デバイスの設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#39;cuda&amp;#39;&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available() &lt;span style="color:#66d9ef"&gt;else&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;cpu&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;使用デバイス: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;device&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ハイパーパラメータ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;latent_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100&lt;/span&gt; &lt;span style="color:#75715e"&gt;# 潜在空間の次元数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;img_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;28&lt;/span&gt; &lt;span style="color:#75715e"&gt;# 画像サイズ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;channels &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt; &lt;span style="color:#75715e"&gt;# チャンネル数（グレースケール）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;batch_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;learning_rate &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.0002&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;num_epochs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;50&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;beta1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt; &lt;span style="color:#75715e"&gt;# Adam最適化のβ1パラメータ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# データの準備&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;transform &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Normalize([&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;], [&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;]) &lt;span style="color:#75715e"&gt;# [-1, 1]に正規化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;./data&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;dataloader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_dataset,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;batch_size,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_workers&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# Generatorの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;Generator&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(Generator, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ノイズから画像を生成するネットワーク&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 入力: latent_dim次元のノイズベクトル&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(latent_dim, &lt;span style="color:#ae81ff"&gt;256&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LeakyReLU(&lt;span style="color:#ae81ff"&gt;0.2&lt;/span&gt;, inplace&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;256&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm1d(&lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LeakyReLU(&lt;span style="color:#ae81ff"&gt;0.2&lt;/span&gt;, inplace&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1024&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm1d(&lt;span style="color:#ae81ff"&gt;1024&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LeakyReLU(&lt;span style="color:#ae81ff"&gt;0.2&lt;/span&gt;, inplace&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 出力: 28*28 = 784次元&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;1024&lt;/span&gt;, img_size &lt;span style="color:#f92672"&gt;*&lt;/span&gt; img_size &lt;span style="color:#f92672"&gt;*&lt;/span&gt; channels),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Tanh() &lt;span style="color:#75715e"&gt;# [-1, 1]の範囲に出力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, z):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;model(z)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img &lt;span style="color:#f92672"&gt;=&lt;/span&gt; img&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(img&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), channels, img_size, img_size)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; img
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# Discriminatorの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;Discriminator&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(Discriminator, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 画像が本物か偽物かを判定するネットワーク&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 入力: 28*28 = 784次元の画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(img_size &lt;span style="color:#f92672"&gt;*&lt;/span&gt; img_size &lt;span style="color:#f92672"&gt;*&lt;/span&gt; channels, &lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LeakyReLU(&lt;span style="color:#ae81ff"&gt;0.2&lt;/span&gt;, inplace&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;256&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LeakyReLU(&lt;span style="color:#ae81ff"&gt;0.2&lt;/span&gt;, inplace&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 出力: 本物である確率&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;256&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sigmoid()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, img):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img_flat &lt;span style="color:#f92672"&gt;=&lt;/span&gt; img&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(img&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; validity &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;model(img_flat)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; validity
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# モデルのインスタンス化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;generator &lt;span style="color:#f92672"&gt;=&lt;/span&gt; Generator()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;discriminator &lt;span style="color:#f92672"&gt;=&lt;/span&gt; Discriminator()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 損失関数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;adversarial_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BCELoss()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# オプティマイザー&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;optimizer_G &lt;span style="color:#f92672"&gt;=&lt;/span&gt; optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(generator&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;learning_rate, betas&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(beta1, &lt;span style="color:#ae81ff"&gt;0.999&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;optimizer_D &lt;span style="color:#f92672"&gt;=&lt;/span&gt; optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(discriminator&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;learning_rate, betas&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(beta1, &lt;span style="color:#ae81ff"&gt;0.999&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 学習ループ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;print(&lt;span style="color:#e6db74"&gt;&amp;#34;学習開始...&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(num_epochs):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; i, (real_imgs, _) &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(dataloader):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ラベルの準備&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size_current &lt;span style="color:#f92672"&gt;=&lt;/span&gt; real_imgs&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; real_imgs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; real_imgs&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 本物と偽物のラベル&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; real_labels &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ones(batch_size_current, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fake_labels &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zeros(batch_size_current, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ---------------------&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Discriminatorの学習&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ---------------------&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer_D&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 本物の画像に対する損失&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; real_output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; discriminator(real_imgs)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; d_loss_real &lt;span style="color:#f92672"&gt;=&lt;/span&gt; adversarial_loss(real_output, real_labels)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 偽物の画像を生成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;randn(batch_size_current, latent_dim)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fake_imgs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; generator(z)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 偽物の画像に対する損失&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fake_output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; discriminator(fake_imgs&lt;span style="color:#f92672"&gt;.&lt;/span&gt;detach())
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; d_loss_fake &lt;span style="color:#f92672"&gt;=&lt;/span&gt; adversarial_loss(fake_output, fake_labels)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 合計損失&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; d_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; d_loss_real &lt;span style="color:#f92672"&gt;+&lt;/span&gt; d_loss_fake
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; d_loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer_D&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# -----------------&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Generatorの学習&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# -----------------&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer_G&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Generatorは偽物をDiscriminatorに本物と判定させたい&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fake_output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; discriminator(fake_imgs)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; g_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; adversarial_loss(fake_output, real_labels)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; g_loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer_G&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 進捗表示&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; i &lt;span style="color:#f92672"&gt;%&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100&lt;/span&gt; &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;[Epoch &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;num_epochs&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;] [Batch &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;i&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;len(dataloader)&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;] &amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;[D loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;d_loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;] [G loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;g_loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;]&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# エポック終了ごとに生成画像を保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;%&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;5&lt;/span&gt; &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;randn(&lt;span style="color:#ae81ff"&gt;16&lt;/span&gt;, latent_dim)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; generated_imgs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; generator(z)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fig, axes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;subplots(&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;8&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; idx, ax &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(axes&lt;span style="color:#f92672"&gt;.&lt;/span&gt;flat):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img &lt;span style="color:#f92672"&gt;=&lt;/span&gt; generated_imgs[idx]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;numpy()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (img &lt;span style="color:#f92672"&gt;+&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#f92672"&gt;/&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt; &lt;span style="color:#75715e"&gt;# [-1, 1] -&amp;gt; [0, 1]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ax&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(img, cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ax&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;suptitle(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;tight_layout()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;generated_epoch_&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;close()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;print(&lt;span style="color:#e6db74"&gt;&amp;#34;学習完了!&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 最終的な生成画像の表示&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; z &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;randn(&lt;span style="color:#ae81ff"&gt;25&lt;/span&gt;, latent_dim)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; generated_imgs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; generator(z)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fig, axes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;subplots(&lt;span style="color:#ae81ff"&gt;5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;5&lt;/span&gt;, figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; idx, ax &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(axes&lt;span style="color:#f92672"&gt;.&lt;/span&gt;flat):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img &lt;span style="color:#f92672"&gt;=&lt;/span&gt; generated_imgs[idx]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;numpy()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (img &lt;span style="color:#f92672"&gt;+&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#f92672"&gt;/&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ax&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(img, cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ax&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;suptitle(&lt;span style="color:#e6db74"&gt;&amp;#39;最終生成画像&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;tight_layout()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;&amp;#39;final_generated_images.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;show()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="詳細解説"&gt;詳細解説&lt;/h2&gt;
&lt;h3 id="1-ganの基本概念"&gt;1. &lt;strong&gt;GANの基本概念&lt;/strong&gt;&lt;/h3&gt;
&lt;pre tabindex="0"&gt;&lt;code&gt;Generator (生成器) Discriminator (識別器)
 ↓ ↓
 偽の画像を生成 本物/偽物を判定
 ↓ ↓
 互いに競争しながら学習
&lt;/code&gt;&lt;/pre&gt;&lt;h3 id="2-データの正規化"&gt;2. &lt;strong&gt;データの正規化&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Normalize([&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;], [&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;])
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;MNIST画像を&lt;code&gt;[-1, 1]&lt;/code&gt;の範囲に正規化&lt;/li&gt;
&lt;li&gt;計算式: &lt;code&gt;(x - 0.5) / 0.5&lt;/code&gt;&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="3-generator生成器の役割"&gt;3. &lt;strong&gt;Generator（生成器）の役割&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;入力: ランダムノイズ (&lt;span style="color:#ae81ff"&gt;100&lt;/span&gt;次元)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;↓&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 全結合層 &lt;span style="color:#f92672"&gt;+&lt;/span&gt; 活性化
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;↓&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 徐々に次元を拡大
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;↓&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;出力: &lt;span style="color:#ae81ff"&gt;28&lt;/span&gt;&lt;span style="color:#960050;background-color:#1e0010"&gt;×&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;28&lt;/span&gt;の画像 (&lt;span style="color:#ae81ff"&gt;784&lt;/span&gt;次元)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;&lt;strong&gt;重要なポイント:&lt;/strong&gt;&lt;/p&gt;</description></item><item><title>Vision Transformer(ViT)画像分類</title><link>https://ml.askbox.net/posts/vision-transformer-classification/</link><pubDate>Mon, 09 Feb 2026 16:01:11 +0900</pubDate><guid>https://ml.askbox.net/posts/vision-transformer-classification/</guid><description>&lt;h2 id="pytorch-vision-transformer画像分類サンプルコード解説"&gt;PyTorch Vision Transformer画像分類サンプルコード解説&lt;/h2&gt;
&lt;p&gt;Vision Transformerの実装と解説をします。&lt;/p&gt;
&lt;h2 id="1-基本的な実装"&gt;1. 基本的な実装&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn.functional &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; F
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== パッチ埋め込み層 =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;PatchEmbedding&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;&amp;#34;&amp;#34;画像をパッチに分割し、埋め込みベクトルに変換&amp;#34;&amp;#34;&amp;#34;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, img_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;224&lt;/span&gt;, patch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;16&lt;/span&gt;, in_channels&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, embed_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;768&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;img_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; img_size
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;patch_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; patch_size
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;n_patches &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (img_size &lt;span style="color:#f92672"&gt;//&lt;/span&gt; patch_size) &lt;span style="color:#f92672"&gt;**&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 畳み込みでパッチ埋め込みを実現&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;proj &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; in_channels, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; embed_dim, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;patch_size, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;patch_size
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# x: (B, C, H, W) → (B, embed_dim, n_patches**0.5, n_patches**0.5)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;proj(x) 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# (B, embed_dim, H&amp;#39;, W&amp;#39;) → (B, embed_dim, n_patches)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;flatten(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;) 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# (B, embed_dim, n_patches) → (B, n_patches, embed_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transpose(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;) 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== Multi-Head Attention =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;MultiHeadAttention&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, embed_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;768&lt;/span&gt;, num_heads&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;12&lt;/span&gt;, dropout&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.1&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;embed_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; embed_dim
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;num_heads &lt;span style="color:#f92672"&gt;=&lt;/span&gt; num_heads
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;head_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; embed_dim &lt;span style="color:#f92672"&gt;//&lt;/span&gt; num_heads
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;scale &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;head_dim &lt;span style="color:#f92672"&gt;**&lt;/span&gt; &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Query, Key, Value の線形変換&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;qkv &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(embed_dim, embed_dim &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;proj &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(embed_dim, embed_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Dropout(dropout)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; B, N, C &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;shape
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# QKV計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; qkv &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;qkv(x)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;reshape(B, N, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;num_heads, self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;head_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; qkv &lt;span style="color:#f92672"&gt;=&lt;/span&gt; qkv&lt;span style="color:#f92672"&gt;.&lt;/span&gt;permute(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# (3, B, num_heads, N, head_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; q, k, v &lt;span style="color:#f92672"&gt;=&lt;/span&gt; qkv[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;], qkv[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;], qkv[&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;]
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Attention計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; attn &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (q &lt;span style="color:#f92672"&gt;@&lt;/span&gt; k&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transpose(&lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)) &lt;span style="color:#f92672"&gt;*&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;scale &lt;span style="color:#75715e"&gt;# (B, num_heads, N, N)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; attn &lt;span style="color:#f92672"&gt;=&lt;/span&gt; attn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;softmax(dim&lt;span style="color:#f92672"&gt;=-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; attn &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout(attn)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 値と結合&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (attn &lt;span style="color:#f92672"&gt;@&lt;/span&gt; v)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transpose(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;reshape(B, N, C)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;proj(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== MLP (Feed Forward Network) =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;MLP&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, embed_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;768&lt;/span&gt;, mlp_ratio&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4.0&lt;/span&gt;, dropout&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.1&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; hidden_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; int(embed_dim &lt;span style="color:#f92672"&gt;*&lt;/span&gt; mlp_ratio)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(embed_dim, hidden_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc2 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(hidden_dim, embed_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Dropout(dropout)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc1(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;gelu(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc2(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== Transformer Block =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;TransformerBlock&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, embed_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;768&lt;/span&gt;, num_heads&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;12&lt;/span&gt;, mlp_ratio&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4.0&lt;/span&gt;, dropout&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.1&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;norm1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LayerNorm(embed_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;attn &lt;span style="color:#f92672"&gt;=&lt;/span&gt; MultiHeadAttention(embed_dim, num_heads, dropout)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;norm2 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LayerNorm(embed_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mlp &lt;span style="color:#f92672"&gt;=&lt;/span&gt; MLP(embed_dim, mlp_ratio, dropout)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Pre-Norm構造&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x &lt;span style="color:#f92672"&gt;+&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;attn(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;norm1(x))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x &lt;span style="color:#f92672"&gt;+&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mlp(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;norm2(x))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== Vision Transformer =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;VisionTransformer&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;224&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; patch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;16&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; in_channels&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_classes&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; embed_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;768&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; depth&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;12&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_heads&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;12&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; mlp_ratio&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4.0&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; dropout&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# パッチ埋め込み&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;patch_embed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_patches &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;patch_embed&lt;span style="color:#f92672"&gt;.&lt;/span&gt;n_patches
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# CLSトークン (分類用の特別なトークン)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cls_token &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Parameter(torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zeros(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, embed_dim))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 位置埋め込み&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pos_embed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Parameter(torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zeros(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, num_patches &lt;span style="color:#f92672"&gt;+&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, embed_dim))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pos_drop &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Dropout(dropout)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Transformer Blocks&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;blocks &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ModuleList([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; _ &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(depth)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 分類ヘッド&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;norm &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LayerNorm(embed_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;head &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(embed_dim, num_classes)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 重み初期化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;init&lt;span style="color:#f92672"&gt;.&lt;/span&gt;trunc_normal_(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pos_embed, std&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.02&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;init&lt;span style="color:#f92672"&gt;.&lt;/span&gt;trunc_normal_(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cls_token, std&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.02&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; B &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;shape[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;]
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# パッチ埋め込み&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;patch_embed(x) &lt;span style="color:#75715e"&gt;# (B, n_patches, embed_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# CLSトークンを追加&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; cls_tokens &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cls_token&lt;span style="color:#f92672"&gt;.&lt;/span&gt;expand(B, &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# (B, 1, embed_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cat([cls_tokens, x], dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# (B, n_patches+1, embed_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 位置埋め込みを追加&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x &lt;span style="color:#f92672"&gt;+&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pos_embed
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pos_drop(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Transformer Blocksを通過&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; block &lt;span style="color:#f92672"&gt;in&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;blocks:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; block(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 正規化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;norm(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# CLSトークンのみを使用して分類&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; cls_token_final &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x[:, &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;]
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; logits &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;head(cls_token_final)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; logits
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="2-訓練コード"&gt;2. 訓練コード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== データ準備 =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;get_dataloaders&lt;/span&gt;(batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Resize((&lt;span style="color:#ae81ff"&gt;224&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;224&lt;/span&gt;)),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Normalize(mean&lt;span style="color:#f92672"&gt;=&lt;/span&gt;[&lt;span style="color:#ae81ff"&gt;0.485&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.456&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.406&lt;/span&gt;], 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; std&lt;span style="color:#f92672"&gt;=&lt;/span&gt;[&lt;span style="color:#ae81ff"&gt;0.229&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.224&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.225&lt;/span&gt;])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CIFAR10(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;./data&amp;#39;&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CIFAR10(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;./data&amp;#39;&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(train_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;batch_size, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(test_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;batch_size, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; train_loader, test_loader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== 訓練関数 =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;train_one_epoch&lt;/span&gt;(model, dataloader, criterion, optimizer, device):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;train()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; running_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; correct &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; images, labels &lt;span style="color:#f92672"&gt;in&lt;/span&gt; dataloader:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; images, labels &lt;span style="color:#f92672"&gt;=&lt;/span&gt; images&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device), labels&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; outputs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(images)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; criterion(outputs, labels)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; running_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; _, predicted &lt;span style="color:#f92672"&gt;=&lt;/span&gt; outputs&lt;span style="color:#f92672"&gt;.&lt;/span&gt;max(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; labels&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; correct &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; predicted&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eq(labels)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;sum()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epoch_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; running_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(dataloader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epoch_acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100.&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; correct &lt;span style="color:#f92672"&gt;/&lt;/span&gt; total
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; epoch_loss, epoch_acc
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== 評価関数 =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;evaluate&lt;/span&gt;(model, dataloader, criterion, device):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eval()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; running_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; correct &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; images, labels &lt;span style="color:#f92672"&gt;in&lt;/span&gt; dataloader:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; images, labels &lt;span style="color:#f92672"&gt;=&lt;/span&gt; images&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device), labels&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; outputs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(images)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; criterion(outputs, labels)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; running_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; _, predicted &lt;span style="color:#f92672"&gt;=&lt;/span&gt; outputs&lt;span style="color:#f92672"&gt;.&lt;/span&gt;max(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; labels&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; correct &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; predicted&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eq(labels)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;sum()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epoch_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; running_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(dataloader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epoch_acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100.&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; correct &lt;span style="color:#f92672"&gt;/&lt;/span&gt; total
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; epoch_loss, epoch_acc
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ===== メイン実行 =====&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;main&lt;/span&gt;():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ハイパーパラメータ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#39;cuda&amp;#39;&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available() &lt;span style="color:#66d9ef"&gt;else&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;cpu&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_epochs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;50&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; learning_rate &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;3e-4&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# モデル作成（小型版）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; VisionTransformer(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; img_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;224&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; patch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;16&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; in_channels&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_classes&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; embed_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;384&lt;/span&gt;, &lt;span style="color:#75715e"&gt;# 小さめ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; depth&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;6&lt;/span&gt;, &lt;span style="color:#75715e"&gt;# 浅め&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_heads&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;6&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; mlp_ratio&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4.0&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; dropout&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.1&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;モデルパラメータ数: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;sum(p&lt;span style="color:#f92672"&gt;.&lt;/span&gt;numel() &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; p &lt;span style="color:#f92672"&gt;in&lt;/span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters())&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;,&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データローダー&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loader, test_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; get_dataloaders(batch_size)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 損失関数とオプティマイザ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; criterion &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CrossEntropyLoss()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;AdamW(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;learning_rate, weight_decay&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.05&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; scheduler &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;lr_scheduler&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CosineAnnealingLR(optimizer, T_max&lt;span style="color:#f92672"&gt;=&lt;/span&gt;num_epochs)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 訓練ループ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(num_epochs):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loss, train_acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; train_one_epoch(model, train_loader, criterion, optimizer, device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test_loss, test_acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; evaluate(model, test_loader, criterion, device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; scheduler&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;Epoch [&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#f92672"&gt;+&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;num_epochs&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;]&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34; Train Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;train_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;, Train Acc: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;train_acc&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.2f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;%&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34; Test Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;test_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;, Test Acc: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;test_acc&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.2f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;%&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# モデル保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;save(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;state_dict(), &lt;span style="color:#e6db74"&gt;&amp;#39;vit_model.pth&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;&amp;#34;訓練完了!&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; __name__ &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#34;__main__&amp;#34;&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; main()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="3-事前学習済みモデルの使用"&gt;3. 事前学習済みモデルの使用&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# torchvisionの事前学習済みViTを使う簡単な方法&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision.models &lt;span style="color:#f92672"&gt;import&lt;/span&gt; vit_b_16, ViT_B_16_Weights
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;use_pretrained_vit&lt;/span&gt;():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 事前学習済みモデルをロード&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; weights &lt;span style="color:#f92672"&gt;=&lt;/span&gt; ViT_B_16_Weights&lt;span style="color:#f92672"&gt;.&lt;/span&gt;IMAGENET1K_V1
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; vit_b_16(weights&lt;span style="color:#f92672"&gt;=&lt;/span&gt;weights)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ファインチューニング用にヘッドを置き換え&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_classes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;10&lt;/span&gt; &lt;span style="color:#75715e"&gt;# CIFAR-10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;heads &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;hidden_dim, num_classes)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 前処理も取得&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; preprocess &lt;span style="color:#f92672"&gt;=&lt;/span&gt; weights&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transforms()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; model, preprocess
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 使用例&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;model, preprocess &lt;span style="color:#f92672"&gt;=&lt;/span&gt; use_pretrained_vit()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="4-実行の結果"&gt;4. 実行の結果&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;$ python vision-transformer-classification.py 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;モデルパラメータ数: 11,022,730
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;1/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 1.7205, Train Acc: 36.19%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.5259, Test Acc: 44.82%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;2/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 1.4613, Train Acc: 46.57%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.3900, Test Acc: 49.07%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;...
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;46/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 0.0055, Train Acc: 99.86%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.6752, Test Acc: 74.55%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;47/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 0.0057, Train Acc: 99.82%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.6772, Test Acc: 74.43%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;48/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 0.0044, Train Acc: 99.89%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.6697, Test Acc: 74.76%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;49/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 0.0041, Train Acc: 99.90%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.6655, Test Acc: 74.78%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch &lt;span style="color:#f92672"&gt;[&lt;/span&gt;50/50&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Train Loss: 0.0043, Train Acc: 99.89%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; Test Loss: 1.6649, Test Acc: 74.76%
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;訓練完了!
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="主要な構成要素の解説"&gt;主要な構成要素の解説&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;パッチ埋め込み&lt;/strong&gt;: 画像を16×16などのパッチに分割し、ベクトル化&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;CLSトークン&lt;/strong&gt;: 分類に使用する特別なトークン&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;位置埋め込み&lt;/strong&gt;: パッチの位置情報を学習&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Transformer&lt;/strong&gt;: Self-Attentionで画像全体の関係性を捉える&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;分類ヘッド&lt;/strong&gt;: CLSトークンから最終的な予測を出力&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;このコードでCIFAR-10での画像分類が実行できます！&lt;/p&gt;</description></item><item><title>PyTorch Lightningで画像分類</title><link>https://ml.askbox.net/posts/pytorch-lightning-cnn-classification/</link><pubDate>Mon, 09 Feb 2026 14:44:37 +0900</pubDate><guid>https://ml.askbox.net/posts/pytorch-lightning-cnn-classification/</guid><description>&lt;h2 id="pytorch-lightningで画像分類のサンプルコード解説"&gt;PyTorch Lightningで画像分類のサンプルコード解説&lt;/h2&gt;
&lt;p&gt;PyTorch Lightningを使った画像分類の完全なサンプルコードを解説します。&lt;/p&gt;
&lt;h2 id="完全なサンプルコード"&gt;完全なサンプルコード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn.functional &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; F
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader, random_split
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; pytorch_lightning &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; pl
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; pytorch_lightning.callbacks &lt;span style="color:#f92672"&gt;import&lt;/span&gt; ModelCheckpoint, EarlyStopping
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 1. モデル定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;ImageClassifier&lt;/span&gt;(pl&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LightningModule):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, num_classes&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;, learning_rate&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1e-3&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;save_hyperparameters() &lt;span style="color:#75715e"&gt;# ハイパーパラメータを自動保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 簡単なCNNモデル&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv2 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv3 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MaxPool2d(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc2 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;512&lt;/span&gt;, num_classes)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Dropout(&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 順伝播の定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool(F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv1(x))) &lt;span style="color:#75715e"&gt;# 32x32 -&amp;gt; 16x16&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool(F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv2(x))) &lt;span style="color:#75715e"&gt;# 16x16 -&amp;gt; 8x8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool(F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv3(x))) &lt;span style="color:#75715e"&gt;# 8x8 -&amp;gt; 4x4&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(&lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout(F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc1(x)))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc2(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;training_step&lt;/span&gt;(self, batch, batch_idx):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 訓練時の1ステップ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x, y &lt;span style="color:#f92672"&gt;=&lt;/span&gt; batch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; logits &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cross_entropy(logits, y)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 精度計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; preds &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;argmax(logits, dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (preds &lt;span style="color:#f92672"&gt;==&lt;/span&gt; y)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;float()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mean()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ログ記録&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;log(&lt;span style="color:#e6db74"&gt;&amp;#39;train_loss&amp;#39;&lt;/span&gt;, loss, prog_bar&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;log(&lt;span style="color:#e6db74"&gt;&amp;#39;train_acc&amp;#39;&lt;/span&gt;, acc, prog_bar&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; loss
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;validation_step&lt;/span&gt;(self, batch, batch_idx):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 検証時の1ステップ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x, y &lt;span style="color:#f92672"&gt;=&lt;/span&gt; batch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; logits &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cross_entropy(logits, y)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; preds &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;argmax(logits, dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (preds &lt;span style="color:#f92672"&gt;==&lt;/span&gt; y)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;float()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mean()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;log(&lt;span style="color:#e6db74"&gt;&amp;#39;val_loss&amp;#39;&lt;/span&gt;, loss, prog_bar&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;log(&lt;span style="color:#e6db74"&gt;&amp;#39;val_acc&amp;#39;&lt;/span&gt;, acc, prog_bar&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; loss
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;test_step&lt;/span&gt;(self, batch, batch_idx):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# テスト時の1ステップ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x, y &lt;span style="color:#f92672"&gt;=&lt;/span&gt; batch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; logits &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; F&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cross_entropy(logits, y)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; preds &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;argmax(logits, dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; acc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; (preds &lt;span style="color:#f92672"&gt;==&lt;/span&gt; y)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;float()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mean()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;log(&lt;span style="color:#e6db74"&gt;&amp;#39;test_loss&amp;#39;&lt;/span&gt;, loss)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;log(&lt;span style="color:#e6db74"&gt;&amp;#39;test_acc&amp;#39;&lt;/span&gt;, acc)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; loss
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;configure_optimizers&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# オプティマイザとスケジューラの設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;hparams&lt;span style="color:#f92672"&gt;.&lt;/span&gt;learning_rate)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; scheduler &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;lr_scheduler&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReduceLROnPlateau(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer, mode&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;min&amp;#39;&lt;/span&gt;, factor&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, patience&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; {
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;optimizer&amp;#39;&lt;/span&gt;: optimizer,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;lr_scheduler&amp;#39;&lt;/span&gt;: {
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;scheduler&amp;#39;&lt;/span&gt;: scheduler,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;monitor&amp;#39;&lt;/span&gt;: &lt;span style="color:#e6db74"&gt;&amp;#39;val_loss&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; }
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; }
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 2. データモジュール定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;CIFAR10DataModule&lt;/span&gt;(pl&lt;span style="color:#f92672"&gt;.&lt;/span&gt;LightningDataModule):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, data_dir&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;./data&amp;#39;&lt;/span&gt;, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, num_workers&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;data_dir &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data_dir
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;batch_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; batch_size
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;num_workers &lt;span style="color:#f92672"&gt;=&lt;/span&gt; num_workers
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データ前処理&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transform_train &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;RandomHorizontalFlip(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;RandomCrop(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Normalize((&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;), (&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transform_test &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Normalize((&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;), (&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;prepare_data&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データのダウンロード(1度だけ実行)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CIFAR10(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;data_dir, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CIFAR10(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;data_dir, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;, download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;setup&lt;/span&gt;(self, stage&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;None&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データセットの設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; stage &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;fit&amp;#39;&lt;/span&gt; &lt;span style="color:#f92672"&gt;or&lt;/span&gt; stage &lt;span style="color:#f92672"&gt;is&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;None&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; cifar_full &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CIFAR10(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;data_dir, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transform_train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 訓練データと検証データに分割&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cifar_train, self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cifar_val &lt;span style="color:#f92672"&gt;=&lt;/span&gt; random_split(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; cifar_full, [&lt;span style="color:#ae81ff"&gt;45000&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;5000&lt;/span&gt;]
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; stage &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;test&amp;#39;&lt;/span&gt; &lt;span style="color:#f92672"&gt;or&lt;/span&gt; stage &lt;span style="color:#f92672"&gt;is&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;None&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cifar_test &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CIFAR10(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;data_dir, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;, transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;transform_test
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;train_dataloader&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; DataLoader(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cifar_train,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;batch_size,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_workers&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;num_workers
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;val_dataloader&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; DataLoader(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cifar_val,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;batch_size,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_workers&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;num_workers
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;test_dataloader&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; DataLoader(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cifar_test,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;batch_size,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; num_workers&lt;span style="color:#f92672"&gt;=&lt;/span&gt;self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;num_workers
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 3. 訓練実行&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;main&lt;/span&gt;():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データモジュール作成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; dm &lt;span style="color:#f92672"&gt;=&lt;/span&gt; CIFAR10DataModule(batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, num_workers&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# モデル作成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; ImageClassifier(num_classes&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;, learning_rate&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1e-3&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# コールバック設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; checkpoint_callback &lt;span style="color:#f92672"&gt;=&lt;/span&gt; ModelCheckpoint(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; monitor&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;val_loss&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; dirpath&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;checkpoints/&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; filename&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;cifar10-&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{epoch:02d}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;-&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{val_loss:.2f}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; save_top_k&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; mode&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;min&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; early_stop_callback &lt;span style="color:#f92672"&gt;=&lt;/span&gt; EarlyStopping(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; monitor&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;val_loss&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; patience&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;5&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; mode&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;min&amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# Trainer設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; trainer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; pl&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Trainer(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; max_epochs&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;20&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; accelerator&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;auto&amp;#39;&lt;/span&gt;, &lt;span style="color:#75715e"&gt;# 自動でGPU/CPU選択&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; devices&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; callbacks&lt;span style="color:#f92672"&gt;=&lt;/span&gt;[checkpoint_callback, early_stop_callback],
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; log_every_n_steps&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 訓練実行&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; trainer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fit(model, dm)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# テスト実行&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; trainer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;test(model, dm)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; __name__ &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;__main__&amp;#39;&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; main()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="主要な構成要素の解説"&gt;主要な構成要素の解説&lt;/h2&gt;
&lt;h3 id="1-lightningmodule-モデル定義"&gt;1. &lt;strong&gt;LightningModule (モデル定義)&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;__init__&lt;/code&gt;: モデルの層を定義&lt;/li&gt;
&lt;li&gt;&lt;code&gt;forward&lt;/code&gt;: 順伝播処理&lt;/li&gt;
&lt;li&gt;&lt;code&gt;training_step&lt;/code&gt;: 訓練時の処理（損失計算など）&lt;/li&gt;
&lt;li&gt;&lt;code&gt;validation_step&lt;/code&gt;: 検証時の処理&lt;/li&gt;
&lt;li&gt;&lt;code&gt;configure_optimizers&lt;/code&gt;: オプティマイザ設定&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="2-lightningdatamodule-データ管理"&gt;2. &lt;strong&gt;LightningDataModule (データ管理)&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;prepare_data&lt;/code&gt;: データダウンロード&lt;/li&gt;
&lt;li&gt;&lt;code&gt;setup&lt;/code&gt;: データセット分割&lt;/li&gt;
&lt;li&gt;&lt;code&gt;train/val/test_dataloader&lt;/code&gt;: データローダー提供&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="3-trainer-訓練管理"&gt;3. &lt;strong&gt;Trainer (訓練管理)&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;エポック数、GPU設定、コールバックなどを統合管理&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="4-実行の結果"&gt;4. &lt;strong&gt;実行の結果&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-bash" data-lang="bash"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;$ python pytorch-lightning-cnn-classification.py 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;GPU available: True &lt;span style="color:#f92672"&gt;(&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;)&lt;/span&gt;, used: True
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;TPU available: False, using: &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; TPU cores
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;HPU available: False, using: &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; HPUs
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;LOCAL_RANK: &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; - CUDA_VISIBLE_DEVICES: &lt;span style="color:#f92672"&gt;[&lt;/span&gt;0&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; | Name | Type | Params | Mode 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;----------------------------------------------
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; | conv1 | Conv2d | &lt;span style="color:#ae81ff"&gt;896&lt;/span&gt; | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt; | conv2 | Conv2d | 18.5 K | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt; | conv3 | Conv2d | 73.9 K | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt; | pool | MaxPool2d | &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt; | fc1 | Linear | 1.0 M | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;5&lt;/span&gt; | fc2 | Linear | 5.1 K | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;6&lt;/span&gt; | dropout | Dropout | &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; | train
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;----------------------------------------------
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;1.1 M Trainable params
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; Non-trainable params
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;1.1 M Total params
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;4.590 Total estimated model params size &lt;span style="color:#f92672"&gt;(&lt;/span&gt;MB&lt;span style="color:#f92672"&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;7&lt;/span&gt; Modules in train mode
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; Modules in eval mode
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch 19: 100%|███████████████████████████████████| 704/704 &lt;span style="color:#f92672"&gt;[&lt;/span&gt;00:03&amp;lt;00:00, 195.75it/s, v_num&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0, train_loss&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0.0607, train_acc&lt;span style="color:#f92672"&gt;=&lt;/span&gt;1.000, val_loss&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0.604, val_acc&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0.790&lt;span style="color:#f92672"&gt;]&lt;/span&gt;&lt;span style="color:#e6db74"&gt;`&lt;/span&gt;Trainer.fit&lt;span style="color:#e6db74"&gt;`&lt;/span&gt; stopped: &lt;span style="color:#e6db74"&gt;`&lt;/span&gt;max_epochs&lt;span style="color:#f92672"&gt;=&lt;/span&gt;20&lt;span style="color:#e6db74"&gt;`&lt;/span&gt; reached. 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Epoch 19: 100%|███████████████████████████████████| 704/704 &lt;span style="color:#f92672"&gt;[&lt;/span&gt;00:03&amp;lt;00:00, 193.97it/s, v_num&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0, train_loss&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0.0607, train_acc&lt;span style="color:#f92672"&gt;=&lt;/span&gt;1.000, val_loss&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0.604, val_acc&lt;span style="color:#f92672"&gt;=&lt;/span&gt;0.790&lt;span style="color:#f92672"&gt;]&lt;/span&gt;LOCAL_RANK: &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; - CUDA_VISIBLE_DEVICES: &lt;span style="color:#f92672"&gt;[&lt;/span&gt;0&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 &lt;span style="color:#f92672"&gt;[&lt;/span&gt;00:00&amp;lt;00:00, 413.72it/s&lt;span style="color:#f92672"&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;┃ Test metric ┃ DataLoader &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt; ┃
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;│ test_acc │ 0.8090999722480774 │
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;│ test_loss │ 0.566657304763794 │
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;└───────────────────────────┴───────────────────────────┘
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="pytorch-lightningの利点"&gt;PyTorch Lightningの利点&lt;/h2&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;コードが整理される&lt;/strong&gt;: 訓練ループを書く必要なし&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;GPU対応が簡単&lt;/strong&gt;: &lt;code&gt;accelerator='auto'&lt;/code&gt;だけ&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;再現性が高い&lt;/strong&gt;: ハイパーパラメータ自動保存&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;ログ管理が楽&lt;/strong&gt;: TensorBoard等に自動記録&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;このコードをそのまま実行すれば、CIFAR-10での画像分類が動きます!&lt;/p&gt;</description></item><item><title>PyTorch + CNNでMNIST数字分類</title><link>https://ml.askbox.net/posts/mnist+cnn+classification/</link><pubDate>Sun, 08 Feb 2026 22:06:53 +0900</pubDate><guid>https://ml.askbox.net/posts/mnist+cnn+classification/</guid><description>&lt;h2 id="pytorch--cnnでmnist数字分類の解説"&gt;PyTorch + CNNでMNIST数字分類の解説&lt;/h2&gt;
&lt;p&gt;MNISTの手書き数字を分類するCNNの実装を段階的に解説します。&lt;/p&gt;
&lt;h2 id="完全なコード"&gt;完全なコード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.optim &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; optim
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 1. CNNモデルの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;CNN&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(CNN, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 畳み込み層&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# 28x28x1 → 28x28x32&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv2 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# 14x14x32 → 14x14x64&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# プーリング層&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MaxPool2d(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# サイズを半分に&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 全結合層&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc1 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;7&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc2 &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# 10クラス分類&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ドロップアウト&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Dropout(&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 畳み込み + ReLU + プーリング&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool(torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv1(x))) &lt;span style="color:#75715e"&gt;# 28x28x32 → 14x14x32&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;pool(torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv2(x))) &lt;span style="color:#75715e"&gt;# 14x14x64 → 7x7x64&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 平坦化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(&lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;7&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;7&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 全結合層&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;relu(self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc1(x))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;dropout(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc2(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 2. データの準備&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;transform &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(), &lt;span style="color:#75715e"&gt;# PIL画像をTensorに変換&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Normalize((&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;,), (&lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;,)) &lt;span style="color:#75715e"&gt;# 正規化 (平均, 標準偏差)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# データセットのダウンロードと読み込み&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;test_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;, train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# DataLoader作成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(train_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;test_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(test_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1000&lt;/span&gt;, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 3. モデル、損失関数、最適化手法の設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#39;cuda&amp;#39;&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available() &lt;span style="color:#66d9ef"&gt;else&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;cpu&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# MPSが利用可能かチェック&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backends&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mps&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#34;mps&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; CNN()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;criterion &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;CrossEntropyLoss() &lt;span style="color:#75715e"&gt;# 多クラス分類用の損失関数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;optimizer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;0.001&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 4. 訓練関数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;train&lt;/span&gt;(model, device, train_loader, optimizer, criterion, epoch):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;train() &lt;span style="color:#75715e"&gt;# 訓練モードに設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; batch_idx, (data, target) &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(train_loader):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data, target &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device), target&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 勾配をゼロに&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 順伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 損失計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; criterion(output, target)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 逆伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# パラメータ更新&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; batch_idx &lt;span style="color:#f92672"&gt;%&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100&lt;/span&gt; &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;, Batch: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;batch_idx&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;, Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; total_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(train_loader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 5. テスト関数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;test&lt;/span&gt;(model, device, test_loader):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eval() &lt;span style="color:#75715e"&gt;# 評価モードに設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; correct &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad(): &lt;span style="color:#75715e"&gt;# 勾配計算を無効化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; data, target &lt;span style="color:#f92672"&gt;in&lt;/span&gt; test_loader:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data, target &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device), target&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 最も確率の高いクラスを予測&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; _, predicted &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;max(output&lt;span style="color:#f92672"&gt;.&lt;/span&gt;data, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; target&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; correct &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; (predicted &lt;span style="color:#f92672"&gt;==&lt;/span&gt; target)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;sum()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; accuracy &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; correct &lt;span style="color:#f92672"&gt;/&lt;/span&gt; total
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Test Accuracy: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;accuracy&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.2f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;%&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; accuracy
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 6. 訓練実行&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;epochs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;5&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, epochs &lt;span style="color:#f92672"&gt;+&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; train(model, device, train_loader, optimizer, criterion, epoch)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Average Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;train_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test(model, device, test_loader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;&amp;#39;-&amp;#39;&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;60&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# モデルの保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;save(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;state_dict(), &lt;span style="color:#e6db74"&gt;&amp;#39;mnist_cnn.pth&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="主要な構成要素の解説"&gt;主要な構成要素の解説&lt;/h2&gt;
&lt;h3 id="1-cnnモデル構造"&gt;1. &lt;strong&gt;CNNモデル構造&lt;/strong&gt;&lt;/h3&gt;
&lt;pre tabindex="0"&gt;&lt;code&gt;入力 (1x28x28)
 ↓
Conv2d (32フィルター) → ReLU → MaxPool → (32x14x14)
 ↓
Conv2d (64フィルター) → ReLU → MaxPool → (64x7x7)
 ↓
Flatten → (3136)
 ↓
FC (128) → ReLU → Dropout
 ↓
FC (10) → 出力
&lt;/code&gt;&lt;/pre&gt;&lt;h3 id="2-重要なパラメータ"&gt;2. &lt;strong&gt;重要なパラメータ&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;code&gt;kernel_size=3&lt;/code&gt;: 3×3の畳み込みフィルター&lt;/li&gt;
&lt;li&gt;&lt;code&gt;padding=1&lt;/code&gt;: 画像サイズを維持&lt;/li&gt;
&lt;li&gt;&lt;code&gt;MaxPool2d(2,2)&lt;/code&gt;: 2×2領域の最大値を取得&lt;/li&gt;
&lt;li&gt;&lt;code&gt;Dropout(0.5)&lt;/code&gt;: 過学習防止&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="3-実行結果"&gt;3. &lt;strong&gt;実行結果&lt;/strong&gt;&lt;/h3&gt;
&lt;pre tabindex="0"&gt;&lt;code&gt;$ python3 mnist+cnn+classification.py
Epoch: 1, Batch: 0, Loss: 2.3096
Epoch: 1, Batch: 100, Loss: 0.3755
Epoch: 1, Batch: 200, Loss: 0.4064
Epoch: 1, Batch: 300, Loss: 0.2006
Epoch: 1, Batch: 400, Loss: 0.1787
Epoch: 1, Batch: 500, Loss: 0.0984
Epoch: 1, Batch: 600, Loss: 0.1776
Epoch: 1, Batch: 700, Loss: 0.2356
Epoch: 1, Batch: 800, Loss: 0.1497
Epoch: 1, Batch: 900, Loss: 0.1300
Average Loss: 0.2377
Test Accuracy: 98.61%
------------------------------------------------------------
Epoch: 2, Batch: 0, Loss: 0.1457
Epoch: 2, Batch: 100, Loss: 0.1019
Epoch: 2, Batch: 200, Loss: 0.0407
Epoch: 2, Batch: 300, Loss: 0.0687
Epoch: 2, Batch: 400, Loss: 0.0562
Epoch: 2, Batch: 500, Loss: 0.0583
Epoch: 2, Batch: 600, Loss: 0.0361
Epoch: 2, Batch: 700, Loss: 0.0554
Epoch: 2, Batch: 800, Loss: 0.0757
Epoch: 2, Batch: 900, Loss: 0.2820
Average Loss: 0.0859
Test Accuracy: 98.98%
------------------------------------------------------------
Epoch: 3, Batch: 0, Loss: 0.0496
Epoch: 3, Batch: 100, Loss: 0.1323
Epoch: 3, Batch: 200, Loss: 0.0146
Epoch: 3, Batch: 300, Loss: 0.0297
Epoch: 3, Batch: 400, Loss: 0.0217
Epoch: 3, Batch: 500, Loss: 0.0470
Epoch: 3, Batch: 600, Loss: 0.0499
Epoch: 3, Batch: 700, Loss: 0.0439
Epoch: 3, Batch: 800, Loss: 0.0967
Epoch: 3, Batch: 900, Loss: 0.0390
Average Loss: 0.0642
Test Accuracy: 99.00%
------------------------------------------------------------
Epoch: 4, Batch: 0, Loss: 0.0106
Epoch: 4, Batch: 100, Loss: 0.0114
Epoch: 4, Batch: 200, Loss: 0.0156
Epoch: 4, Batch: 300, Loss: 0.0550
Epoch: 4, Batch: 400, Loss: 0.0288
Epoch: 4, Batch: 500, Loss: 0.0282
Epoch: 4, Batch: 600, Loss: 0.1245
Epoch: 4, Batch: 700, Loss: 0.0610
Epoch: 4, Batch: 800, Loss: 0.0127
Epoch: 4, Batch: 900, Loss: 0.0365
Average Loss: 0.0516
Test Accuracy: 99.12%
------------------------------------------------------------
Epoch: 5, Batch: 0, Loss: 0.0130
Epoch: 5, Batch: 100, Loss: 0.0412
Epoch: 5, Batch: 200, Loss: 0.0173
Epoch: 5, Batch: 300, Loss: 0.0064
Epoch: 5, Batch: 400, Loss: 0.0599
Epoch: 5, Batch: 500, Loss: 0.0848
Epoch: 5, Batch: 600, Loss: 0.0542
Epoch: 5, Batch: 700, Loss: 0.0545
Epoch: 5, Batch: 800, Loss: 0.0207
Epoch: 5, Batch: 900, Loss: 0.0759
Average Loss: 0.0424
Test Accuracy: 99.20%
------------------------------------------------------------
&lt;/code&gt;&lt;/pre&gt;&lt;p&gt;このコードで99.20%の精度が達成できます！&lt;/p&gt;</description></item><item><title>CNNによる二値画像AutoEncoder</title><link>https://ml.askbox.net/posts/binary-image-cnn-autoender/</link><pubDate>Sun, 08 Feb 2026 15:53:08 +0900</pubDate><guid>https://ml.askbox.net/posts/binary-image-cnn-autoender/</guid><description>&lt;h2 id="pytorch-cnnによる二値画像autoencoderの解説"&gt;PyTorch CNNによる二値画像AutoEncoderの解説&lt;/h2&gt;
&lt;p&gt;二値画像（白黒画像）を扱うAutoEncoderのサンプルコードを解説します。&lt;/p&gt;
&lt;h2 id="完全なサンプルコード"&gt;完全なサンプルコード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.optim &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; optim
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; matplotlib.pyplot &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; plt
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# デバイスの設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#39;cuda&amp;#39;&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available() &lt;span style="color:#66d9ef"&gt;else&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;cpu&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# MPSが利用可能かチェック&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backends&lt;span style="color:#f92672"&gt;.&lt;/span&gt;mps&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#34;mps&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# ハイパーパラメータ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;BATCH_SIZE &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;LEARNING_RATE &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.001&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;EPOCHS &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;LATENT_DIM &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt; &lt;span style="color:#75715e"&gt;# 潜在空間の次元数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# データの準備（MNISTを例に使用）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;transform &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;#transforms.Normalize((0.5,), (0.5,)) # -1～1の範囲に正規化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Lambda(&lt;span style="color:#66d9ef"&gt;lambda&lt;/span&gt; x: (x &lt;span style="color:#f92672"&gt;&amp;gt;&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;float()) &lt;span style="color:#75715e"&gt;# 二値化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;train_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_dataset,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;BATCH_SIZE,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# Encoderの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;Encoder&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, latent_dim):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(Encoder, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 畳み込み層&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv_layers &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 入力: 1x28x28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 32x14x14&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 64x7x7&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 128x4x4&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm2d(&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 全結合層で潜在空間へ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, latent_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;conv_layers(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# Flatten&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# Decoderの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;Decoder&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, latent_dim):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(Decoder, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 潜在空間から特徴マップへ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(latent_dim, &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt; &lt;span style="color:#f92672"&gt;*&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 転置畳み込み層（逆畳み込み）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;deconv_layers &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 入力: 128x4x4&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 64x8x8&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 32x16x16&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BatchNorm2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;), &lt;span style="color:#75715e"&gt;# 1x28x28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;#nn.Tanh() # -1～1の範囲に出力&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sigmoid() &lt;span style="color:#75715e"&gt;# 0-1の範囲に出力を制限&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;fc(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(x&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;) &lt;span style="color:#75715e"&gt;# Reshape&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; x &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;deconv_layers(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; x
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# AutoEncoderの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;AutoEncoder&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, latent_dim):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(AutoEncoder, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; Encoder(latent_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; Decoder(latent_dim)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; latent &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstructed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder(latent)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; reconstructed
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# モデルの初期化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; AutoEncoder(LATENT_DIM)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 損失関数と最適化手法&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# criterion = nn.MSELoss() # 平均二乗誤差&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 損失関数（二値交差エントロピー）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;criterion &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BCELoss()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;optimizer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;LEARNING_RATE)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 訓練ループ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;train&lt;/span&gt;(model, train_loader, criterion, optimizer, epochs):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;train()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; losses &lt;span style="color:#f92672"&gt;=&lt;/span&gt; []
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(epochs):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epoch_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; batch_idx, (data, _) &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(train_loader):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 勾配の初期化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 順伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstructed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 損失計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; criterion(reconstructed, data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 逆伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# パラメータ更新&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epoch_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; batch_idx &lt;span style="color:#f92672"&gt;%&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;100&lt;/span&gt; &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch [&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#f92672"&gt;+&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epochs&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;], Step [&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;batch_idx&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;len(train_loader)&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;], Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; avg_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; epoch_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(train_loader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; losses&lt;span style="color:#f92672"&gt;.&lt;/span&gt;append(avg_loss)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch [&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#f92672"&gt;+&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epochs&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;], Average Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;avg_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; losses
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 訓練実行&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;losses &lt;span style="color:#f92672"&gt;=&lt;/span&gt; train(model, train_loader, criterion, optimizer, EPOCHS)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 結果の可視化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;visualize_results&lt;/span&gt;(model, test_loader, num_images&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eval()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data, _ &lt;span style="color:#f92672"&gt;=&lt;/span&gt; next(iter(test_loader))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data[:num_images]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstructed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# CPU に移動して表示&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstructed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; reconstructed&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fig, axes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;subplots(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, num_images, figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;15&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; i &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(num_images):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 元画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(data[i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze(), cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; i &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;set_title(&lt;span style="color:#e6db74"&gt;&amp;#39;Original&amp;#39;&lt;/span&gt;, fontsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 再構成画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(reconstructed[i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;squeeze(), cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; i &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;set_title(&lt;span style="color:#e6db74"&gt;&amp;#39;Reconstructed&amp;#39;&lt;/span&gt;, fontsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;tight_layout()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;&amp;#39;autoencoder_results.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;show()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# テストデータで可視化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;test_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;,
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;test_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(test_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;BATCH_SIZE, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;visualize_results(model, test_loader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 損失の推移をプロット&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;figure(figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;5&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;plot(losses)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;xlabel(&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ylabel(&lt;span style="color:#e6db74"&gt;&amp;#39;Loss&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;title(&lt;span style="color:#e6db74"&gt;&amp;#39;Training Loss&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;grid(&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;&amp;#39;autoencoder_loss.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;show()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="主要部分の解説"&gt;主要部分の解説&lt;/h2&gt;
&lt;h3 id="1-encoder符号化器"&gt;1. &lt;strong&gt;Encoder（符号化器）&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Conv2d(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;畳み込み層で画像の特徴を抽出&lt;/li&gt;
&lt;li&gt;&lt;code&gt;stride=2&lt;/code&gt;で画像サイズを半分に縮小&lt;/li&gt;
&lt;li&gt;徐々にチャネル数を増やして特徴を豊かに&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="2-decoder復号化器"&gt;2. &lt;strong&gt;Decoder（復号化器）&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ConvTranspose2d(&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;64&lt;/span&gt;, kernel_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;4&lt;/span&gt;, stride&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, padding&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;転置畳み込みで画像を拡大&lt;/li&gt;
&lt;li&gt;Encoderの逆操作を実行&lt;/li&gt;
&lt;li&gt;最終的に元の画像サイズに復元&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="3-損失関数"&gt;3. &lt;strong&gt;損失関数&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;criterion &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MSELoss()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;ul&gt;
&lt;li&gt;元画像と再構成画像の差を最小化&lt;/li&gt;
&lt;li&gt;二値画像には&lt;code&gt;BCELoss&lt;/code&gt;も使用可能&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;このコードを実行すると、画像の圧縮・復元が学習され、ノイズ除去や特徴抽出にも応用できます。&lt;/p&gt;</description></item><item><title>PyTorchによる二値画像AutoEncoder</title><link>https://ml.askbox.net/posts/binary-image-autoender/</link><pubDate>Sun, 08 Feb 2026 15:06:16 +0900</pubDate><guid>https://ml.askbox.net/posts/binary-image-autoender/</guid><description>&lt;h2 id="pytorchによる二値画像autoencoderのサンプルコード解説"&gt;PyTorchによる二値画像AutoEncoderのサンプルコード解説&lt;/h2&gt;
&lt;p&gt;二値画像（白黒画像）を扱うAutoEncoderの実装例を解説します。&lt;/p&gt;
&lt;h2 id="完全なサンプルコード"&gt;完全なサンプルコード&lt;/h2&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.nn &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; nn
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; torch.optim &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; optim
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torch.utils.data &lt;span style="color:#f92672"&gt;import&lt;/span&gt; DataLoader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;from&lt;/span&gt; torchvision &lt;span style="color:#f92672"&gt;import&lt;/span&gt; datasets, transforms
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#f92672"&gt;import&lt;/span&gt; matplotlib.pyplot &lt;span style="color:#66d9ef"&gt;as&lt;/span&gt; plt
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 1. AutoEncoderモデルの定義&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;class&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;BinaryAutoEncoder&lt;/span&gt;(nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Module):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;(self, input_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;784&lt;/span&gt;, hidden_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;, latent_dim&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; super(BinaryAutoEncoder, self)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;&lt;span style="color:#a6e22e"&gt;__init__&lt;/span&gt;()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# エンコーダ部分&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(input_dim, hidden_dim),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(hidden_dim, latent_dim),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# デコーダ部分&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sequential(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(latent_dim, hidden_dim),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ReLU(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Linear(hidden_dim, input_dim),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Sigmoid() &lt;span style="color:#75715e"&gt;# 0-1の範囲に出力を制限&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;forward&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# エンコード&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; encoded &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# デコード&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; decoded &lt;span style="color:#f92672"&gt;=&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;decoder(encoded)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; decoded
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;encode&lt;/span&gt;(self, x):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; self&lt;span style="color:#f92672"&gt;.&lt;/span&gt;encoder(x)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 2. データの準備&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;prepare_data&lt;/span&gt;(batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# MNISTデータセットを使用（二値化）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform &lt;span style="color:#f92672"&gt;=&lt;/span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Compose([
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;ToTensor(),
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transforms&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Lambda(&lt;span style="color:#66d9ef"&gt;lambda&lt;/span&gt; x: (x &lt;span style="color:#f92672"&gt;&amp;gt;&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.5&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;float()) &lt;span style="color:#75715e"&gt;# 二値化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; ])
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test_dataset &lt;span style="color:#f92672"&gt;=&lt;/span&gt; datasets&lt;span style="color:#f92672"&gt;.&lt;/span&gt;MNIST(
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; root&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;~/.pytorch/data&amp;#39;&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; download&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;, 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; transform&lt;span style="color:#f92672"&gt;=&lt;/span&gt;transform
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; )
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(train_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;batch_size, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;True&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; DataLoader(test_dataset, batch_size&lt;span style="color:#f92672"&gt;=&lt;/span&gt;batch_size, shuffle&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#66d9ef"&gt;False&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; train_loader, test_loader
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 3. 訓練関数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;train&lt;/span&gt;(model, train_loader, optimizer, criterion, device):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;train()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; batch_idx, (data, _) &lt;span style="color:#f92672"&gt;in&lt;/span&gt; enumerate(train_loader):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データを平坦化 (batch_size, 1, 28, 28) -&amp;gt; (batch_size, 784)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 勾配をゼロに&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;zero_grad()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 順伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 損失計算&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; criterion(output, data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 逆伝播&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;backward()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# パラメータ更新&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer&lt;span style="color:#f92672"&gt;.&lt;/span&gt;step()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; total_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(train_loader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 4. 評価関数&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;evaluate&lt;/span&gt;(model, test_loader, criterion, device):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eval()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; data, _ &lt;span style="color:#f92672"&gt;in&lt;/span&gt; test_loader:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; output &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; criterion(output, data)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; total_loss &lt;span style="color:#f92672"&gt;+=&lt;/span&gt; loss&lt;span style="color:#f92672"&gt;.&lt;/span&gt;item()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;return&lt;/span&gt; total_loss &lt;span style="color:#f92672"&gt;/&lt;/span&gt; len(test_loader)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 5. 結果の可視化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;visualize_results&lt;/span&gt;(model, test_loader, device, num_images&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;eval()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;with&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;no_grad():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data, _ &lt;span style="color:#f92672"&gt;=&lt;/span&gt; next(iter(test_loader))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data[:num_images]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data_flat &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;size(&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;), &lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 再構成&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstructed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; model(data_flat)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 画像を元の形状に戻す&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; data &lt;span style="color:#f92672"&gt;=&lt;/span&gt; data&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(&lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;28&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;28&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; reconstructed &lt;span style="color:#f92672"&gt;=&lt;/span&gt; reconstructed&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cpu()&lt;span style="color:#f92672"&gt;.&lt;/span&gt;view(&lt;span style="color:#f92672"&gt;-&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;28&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;28&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# プロット&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; fig, axes &lt;span style="color:#f92672"&gt;=&lt;/span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;subplots(&lt;span style="color:#ae81ff"&gt;2&lt;/span&gt;, num_images, figsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;(&lt;span style="color:#ae81ff"&gt;15&lt;/span&gt;, &lt;span style="color:#ae81ff"&gt;3&lt;/span&gt;))
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; i &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(num_images):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 元画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(data[i], cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; i &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;set_title(&lt;span style="color:#e6db74"&gt;&amp;#39;Original&amp;#39;&lt;/span&gt;, fontsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 再構成画像&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;imshow(reconstructed[i], cmap&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;gray&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;axis(&lt;span style="color:#e6db74"&gt;&amp;#39;off&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; i &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; axes[&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, i]&lt;span style="color:#f92672"&gt;.&lt;/span&gt;set_title(&lt;span style="color:#e6db74"&gt;&amp;#39;Reconstructed&amp;#39;&lt;/span&gt;, fontsize&lt;span style="color:#f92672"&gt;=&lt;/span&gt;&lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;tight_layout()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;savefig(&lt;span style="color:#e6db74"&gt;&amp;#39;autoencoder_results.png&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; plt&lt;span style="color:#f92672"&gt;.&lt;/span&gt;show()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 6. メイン実行部分&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;def&lt;/span&gt; &lt;span style="color:#a6e22e"&gt;main&lt;/span&gt;():
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# ハイパーパラメータ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; input_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;784&lt;/span&gt; &lt;span style="color:#75715e"&gt;# 28x28&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; hidden_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; latent_dim &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; epochs &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;10&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; learning_rate &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;0.001&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; batch_size &lt;span style="color:#f92672"&gt;=&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# デバイス設定&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; device &lt;span style="color:#f92672"&gt;=&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;device(&lt;span style="color:#e6db74"&gt;&amp;#39;cuda&amp;#39;&lt;/span&gt; &lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;cuda&lt;span style="color:#f92672"&gt;.&lt;/span&gt;is_available() &lt;span style="color:#66d9ef"&gt;else&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;cpu&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;Using device: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;device&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# データ準備&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loader, test_loader &lt;span style="color:#f92672"&gt;=&lt;/span&gt; prepare_data(batch_size)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# モデル初期化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; model &lt;span style="color:#f92672"&gt;=&lt;/span&gt; BinaryAutoEncoder(input_dim, hidden_dim, latent_dim)&lt;span style="color:#f92672"&gt;.&lt;/span&gt;to(device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 損失関数（二値交差エントロピー）&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; criterion &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BCELoss()
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 最適化手法&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; optimizer &lt;span style="color:#f92672"&gt;=&lt;/span&gt; optim&lt;span style="color:#f92672"&gt;.&lt;/span&gt;Adam(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;parameters(), lr&lt;span style="color:#f92672"&gt;=&lt;/span&gt;learning_rate)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 訓練ループ&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;&amp;#34;Training started...&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#66d9ef"&gt;for&lt;/span&gt; epoch &lt;span style="color:#f92672"&gt;in&lt;/span&gt; range(&lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;, epochs &lt;span style="color:#f92672"&gt;+&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;1&lt;/span&gt;):
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; train_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; train(model, train_loader, optimizer, criterion, device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; test_loss &lt;span style="color:#f92672"&gt;=&lt;/span&gt; evaluate(model, test_loader, criterion, device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Epoch [&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epoch&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;/&lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;epochs&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;], &amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Train Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;train_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;, &amp;#39;&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#e6db74"&gt;f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;Test Loss: &lt;/span&gt;&lt;span style="color:#e6db74"&gt;{&lt;/span&gt;test_loss&lt;span style="color:#e6db74"&gt;:&lt;/span&gt;&lt;span style="color:#e6db74"&gt;.4f&lt;/span&gt;&lt;span style="color:#e6db74"&gt;}&lt;/span&gt;&lt;span style="color:#e6db74"&gt;&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# 結果の可視化&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; visualize_results(model, test_loader, device)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; 
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; &lt;span style="color:#75715e"&gt;# モデルの保存&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; torch&lt;span style="color:#f92672"&gt;.&lt;/span&gt;save(model&lt;span style="color:#f92672"&gt;.&lt;/span&gt;state_dict(), &lt;span style="color:#e6db74"&gt;&amp;#39;binary_autoencoder.pth&amp;#39;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; print(&lt;span style="color:#e6db74"&gt;&amp;#34;Model saved!&amp;#34;&lt;/span&gt;)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#66d9ef"&gt;if&lt;/span&gt; __name__ &lt;span style="color:#f92672"&gt;==&lt;/span&gt; &lt;span style="color:#e6db74"&gt;&amp;#39;__main__&amp;#39;&lt;/span&gt;:
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt; main()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h2 id="主要部分の解説"&gt;主要部分の解説&lt;/h2&gt;
&lt;h3 id="1-モデル構造"&gt;1. &lt;strong&gt;モデル構造&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Encoder: &lt;span style="color:#ae81ff"&gt;784&lt;/span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;→&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;→&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt; (次元削減)
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;Decoder: &lt;span style="color:#ae81ff"&gt;32&lt;/span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;→&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;128&lt;/span&gt; &lt;span style="color:#960050;background-color:#1e0010"&gt;→&lt;/span&gt; &lt;span style="color:#ae81ff"&gt;784&lt;/span&gt; (次元復元)
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id="2-重要なポイント"&gt;2. &lt;strong&gt;重要なポイント&lt;/strong&gt;&lt;/h3&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Sigmoidの使用&lt;/strong&gt;: 出力を0-1の範囲に制限&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;BCELoss&lt;/strong&gt;: 二値データに適した損失関数&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;二値化&lt;/strong&gt;: &lt;code&gt;(x &amp;gt; 0.5).float()&lt;/code&gt;で画像を白黒に変換&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="3-損失関数の選択"&gt;3. &lt;strong&gt;損失関数の選択&lt;/strong&gt;&lt;/h3&gt;
&lt;div class="highlight"&gt;&lt;pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;-webkit-text-size-adjust:none;"&gt;&lt;code class="language-python" data-lang="python"&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;&lt;span style="color:#75715e"&gt;# 二値交差エントロピー損失&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span style="display:flex;"&gt;&lt;span&gt;criterion &lt;span style="color:#f92672"&gt;=&lt;/span&gt; nn&lt;span style="color:#f92672"&gt;.&lt;/span&gt;BCELoss()
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;このコードをそのまま実行すれば、MNISTータセットで二値画像のAutoEncoderを訓練できます！&lt;/p&gt;</description></item></channel></rss>