Pytorchで実装する深層畳み込みGAN(DCGAN):画像生成の品質向上へ

前回の記事「PyTorchでGAN」では、全結合層(Fully Connected Layer)のみを用いた非常に基本的なGANの実装を紹介しました。

しかし、全結合層だけのモデルでは、画像の特徴である「形状」や「空間的なつながり」を捉えるのが苦手で、「生成画像にノイズが多い」「形が崩れる」といった課題がありました。

そこで今回は、画像の扱いに長けた畳み込み層(Convolutional Layer)を取り入れ、より高品質な画像を生成できるモデルへと改良していきます。一般的に DCGAN (Deep Convolutional GAN)1 と呼ばれるモデル構造をベースに、PyTorchでの実装方法を解説します。

本記事では、MNIST(手書き数字)データセットを使って、ゼロから画像を生成する最も基本的なGANをコード実装と共に解説します。

全結合から畳み込みへ:何が変わるのか?

大きな変更点は、GeneratorとDiscriminatorの内部構造です。

  • Generator(生成器): 以前は全結合層で一気に画素数分の値を出力していましたが、今回はTransposed Convolution(転置畳み込み)を使用します。これは「逆畳み込み」とも呼ばれ、小さな特徴マップを引き伸ばしながら(アップサンプリング)、徐々に画像サイズへと拡大していく手法です。尤も、転地畳み込みは格子状のアーティファクト発生を発生させる懸念があるので、近年では代わりに線形補完による拡大と畳み込み層による処理が使用されることも多いです。
  • Discriminator(識別器): 画像認識でおなじみのConvolution(畳み込み)を使用し、画像の特徴を抽出しながら徐々にサイズを圧縮(ダウンサンプリング)して真贋を判定します。

それでは、具体的な実装を見ていきましょう。

モジュールのインポートとグローバル設定

使用するモジュールと画像サイズ、潜在変数次元数などの指定をします。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# ハイパーパラメータの設定
img_rows = 28
img_cols = 28
channels = 1
img_shape = (channels, img_rows, img_cols) # PyTorchは (C, H, W) 形式
z_dim = 100 # Latent Space(潜在空間)の次元数

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

モデルの定義(Generator と Discriminator)

Generator(生成器)の実装

Generatorは、潜在空間(Latent Space)のランダムベクトル z を入力とし、それを徐々に拡大して 28×28の画像にします。

今回の実装は以下の通りです。

全結合層: 入力ベクトル z を受け取り、それを7×7のサイズで256チャンネル持つデータに変形できるよう、ニューロン数を調整して拡張します。

Reshape: 全結合層の出力を (256, 7, 7) のテンソルに変形します。これが画像の「種」となります。

Transposed Convolutionブロック: 畳み込みの逆操作を行い、特徴マップのサイズを7×7→14×14→28×28と倍々に拡大していきます。

出力層: 最後に Tanh 関数を通し、画素値を ー1~1 の範囲に正規化します

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()

        # 1. 入力ベクトルを 7x7x256 のテンソルに変換するための全結合層
        # PyTorchでは (Batch, Channels, Height, Width) の順になるため (256, 7, 7) を目指します
        self.fc = nn.Linear(z_dim, 256 * 7 * 7)

        self.deconv_blocks = nn.Sequential(
            # 2. Transposed Convolution: 7x7x256 -> 14x14x128
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01, inplace=True),

            # 3. Transposed Convolution: 14x14x128 -> 14x14x64
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01, inplace=True),

            # 4. Transposed Convolution: 14x14x64 -> 28x28x1
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            # 出力層の Tanh Activation
            nn.Tanh()
        )

    def forward(self, z):
        # 全結合層を通してフラットなベクトルを出力
        x = self.fc(z)
        # テンソルを (Batch, Channels, Height, Width) にリシェイプ
        x = x.view(-1, 256, 7, 7)
        # 転置畳み込み層のブロックを通過
        img = self.deconv_blocks(x)
        return img

Discriminator(識別器)の実装

Discriminatorは、Generatorとは逆に、入力された画像(28×28)を畳み込み層によって圧縮していきます。

Convolutionブロック: カーネルサイズ3、ストライド2の畳み込み層を重ねることで、画像サイズを半分ずつに圧縮します(28×28→14×14→7×7→4×4)。同時に、チャンネル数(特徴の数)を増やし、画像の抽象的な特徴を抽出します。

Flatten: 最終的な特徴マップ(4×4×128)を1次元のベクトルに平坦化します。

出力層: 全結合層を経て、最後に Sigmoid 関数を通すことで、その画像が本物である確率(0.0 〜 1.0)を出力します。

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        
        # img_shape は通常 (Channels, Height, Width) = (1, 28, 28)
        channels = img_shape[0]

        self.conv_blocks = nn.Sequential(
            # 1. Convolution: 28x28x1 -> 14x14x32
            nn.Conv2d(channels, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.01, inplace=True),

            # 2. Convolution: 14x14x32 -> 7x7x64
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01, inplace=True),

            # 3. Convolution: 7x7x64 -> 4x4x128 (※Kerasの計算に合わせるため端数は調整されます)
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01, inplace=True),

            # 4. Flatten & Output Layer
            nn.Flatten(),
            # 畳み込み後の特徴マップサイズに合わせて全結合層の入力次元を決定
            # 28 -> 14 -> 7 -> 4 とサイズが変化するため、4*4*128 となります
            nn.Linear(128 * 4 * 4, 1),
            # Sigmoid Activation
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.conv_blocks(img)
        return validity

モデルの構築と最適化アルゴリズムの設定

生成器と識別器を個別に学習するための最適化アルゴリズム設定を行います。ここでは損失関数として真贋判定のためのBinary Cross Entropyを指定しています。詳細が気になる方はこちらの記事を参照してください:なぜBCE(Binary Cross Entropy)でGANが学習できるのか

# インスタンス化
generator = Generator(z_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# 損失関数: Binary Cross Entropy Loss
criterion = nn.BCELoss()

# Optimizer(最適化アルゴリズム): Adam
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

学習ループ

モデルの構造は複雑になりましたが、学習の仕組み自体(損失関数やオプティマイザの設定、学習ループ)は前回のシンプルGANと全く同じです。 PyTorchでは nn.Module を継承してモデルを定義しているため、中身が全結合であろうと畳み込みであろうと、同じ学習コードを利用できるのが大きな利点です。

実行すると、最初はノイズだった画像が徐々に手書き数字に変化していく様子がプロットされます(しかし、モデルの表現力がシンプルさ重視で高くはないため、ノイズ性が強い画像の出力となります)。

def train(iterations, batch_size, sample_interval):
    # 保存用ディレクトリの作成(なければ作る)
    os.makedirs("checkpoints", exist_ok=True)

    # MNISTデータの読み込みと前処理 ([-1, 1]に正規化)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataloader = DataLoader(
        datasets.MNIST('.', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True, drop_last=True
    )

    losses = []
    iteration_checkpoints = []

    # イテレータの作成
    data_iter = iter(dataloader)

    for iteration in range(iterations+1):
        
        # --- データの準備 ---
        try:
            real_imgs, _ = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            real_imgs, _ = next(data_iter)

        real_imgs = real_imgs.to(device)
        # 教師ラベル: 本物は1, 偽物は0
        valid = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)

        # -------------------------
        #  Train Discriminator
        # -------------------------
        optimizer_D.zero_grad()

        # 本物の画像の判定
        d_loss_real = criterion(discriminator(real_imgs), valid)
        
        # 偽物の画像の生成と判定
        z = torch.randn(batch_size, z_dim).to(device)
        gen_imgs = generator(z)
        d_loss_fake = criterion(discriminator(gen_imgs.detach()), fake)

        # 誤差逆伝播と更新
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        #  Train Generator
        # ---------------------
        optimizer_G.zero_grad()

        # 偽物を「本物(1)」と判定させることが目標
        g_loss = criterion(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Progress Logging & Saving
        # ---------------------
        if (iteration) % sample_interval == 0:
            print(f"{iteration} [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            losses.append((d_loss.item(), g_loss.item()))
            iteration_checkpoints.append(iteration + 1)
            
            # 画像の保存・表示
            sample_images(generator)
            
            # === モデルの保存処理 (ここを追加) ===
            # GeneratorとDiscriminatorの重みをそれぞれ保存
            torch.save(generator.state_dict(), f"checkpoints/generator_{iteration}.pth")
            torch.save(discriminator.state_dict(), f"checkpoints/discriminator_{iteration}.pth")
            print(f"Saved model checkpoints at iteration {iteration}")
            
    return np.array(losses), iteration_checkpoints

# sample_images関数などはそのまま変更なしでOK

def sample_images(generator, rows=4, cols=4):
    generator.eval()
    z = torch.randn(rows * cols, z_dim).to(device)
    gen_imgs = generator(z).detach().cpu()
    gen_imgs = 0.5 * gen_imgs + 0.5 # [-1, 1] -> [0, 1]

    fig, axs = plt.subplots(rows, cols, figsize=(4, 4))
    cnt = 0
    for i in range(rows):
        for j in range(cols):
            axs[i, j].imshow(gen_imgs[cnt, 0, :, :], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.show()

# 実行
losses, checkpoints = train(iterations=10000, batch_size=128, sample_interval=1000)

出力画像例

sample_intervalで指定したエポックごとにランダムな潜在変数から生成された画像が出力されると同時にモデルのチェックポイントが保存されます。

前回の記事「PyTorchでGAN」の出力と比較すると、ノイズの少なさや数字らしさが格段に向上していることがわかります。

学習済み生成器を用いた画像生成

下記のコードを利用すれば学習済みの生成器を用いて画像を生成することが可能です。

# --- 設定 ---
# 読み込みたいモデルのパス
checkpoint_path = "checkpoints/generator_10000.pth" 

# --- 1. モデルの枠組みを用意して、重みをロード ---
# まずモデルの構造(箱)を作る
generator = Generator(z_dim, img_shape).to(device)

# 指定したパスから重み(中身)を読み込みむ
# map_location=device をつけることで、GPUで保存したモデルをCPUでも開けるようしている
generator.load_state_dict(torch.load(checkpoint_path, map_location=device))

# --- 2. Generatorで画像を生成 ---
generator.eval() # 推論モードへ(BatchNormなどを固定)

# ランダムなノイズを1つ作成 (バッチサイズ=1, 次元=100)
z = torch.randn(1, z_dim).to(device)

with torch.no_grad(): # 勾配計算をしない(メモリ節約)
    gen_img = generator(z)

# --- 3. 表示用に整形 ---
# (1, 1, 28, 28) -> (28, 28) に不要な次元を削除してNumpy化
image = gen_img.squeeze().cpu().numpy()

# 画素値を [-1, 1] から [0, 1] に戻す
image = 0.5 * image + 0.5

# --- 4. 表示 ---
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.show()

おわりに

今回は、GANの内部構造を畳み込みニューラルネットワーク(CNN)ベースに変更しました。

全結合層だけのモデルと比較して、以下の点が改善されているはずです。

  • 空間的特徴の学習: 画素の並びを意識して学習するため、数字の形状がよりくっきりとする。
  • パラメータの効率化: 全結合層ですべてを接続するよりも、畳み込み層の方が画像処理において効率よくパラメータを使える場合が多い。

この構造は、より高解像度なカラー画像の生成など、本格的な画像生成AIへの第一歩となります。ぜひ手元の環境で動かして、生成される数字のクオリティの違いを体感してみてください

  1. Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434. ↩︎

コメント