PytorchでGAN

AIによる画像生成技術の原点ともいえるGAN(Generative Adversarial Network:敵対的生成ネットワーク)

論文や数式を見ると難解に感じますが、Python(PyTorch)のコードに落とし込んでみると、その仕組みは驚くほどシンプルです。

本記事では、MNIST(手書き数字)データセットを使って、ゼロから画像を生成する最も基本的なGANを実装・解説します。

GANの仕組みを超ざっくり解説

GANは、以下の2つのAI(ネットワーク)を競わせながら学習させます。

  • Generator(生成器):ランダムなノイズから、本物そっくりの偽物画像を作る「贋作師」
  • Discriminator(識別器):画像を見て、それが本物か偽物かを見破る「鑑定士」

最初はノイズしか作れないGeneratorですが、Discriminatorに見破られるたびに賢くなり、最終的には人間が見ても本物と見分けがつかない画像を生成できるようになります。

モジュールのインポートとグローバル設定

使用するモジュールと画像サイズ、潜在変数次元数などの指定をします。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# ハイパーパラメータの設定
img_rows = 28
img_cols = 28
channels = 1
img_shape = (channels, img_rows, img_cols) # PyTorchは (C, H, W) 形式
z_dim = 100 # Latent Space(潜在空間)の次元数

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

モデルの定義(Generator と Discriminator)

Generator (生成器) の実装

潜在空間(Latent Space)と呼ばれるランダムなノイズベクトル(ここでは100次元)を受け取り、28×28ピクセルの画像を出力します。最後にTanh関数を通すことで、出力される画像の明るさを-1から1の範囲に収めています。

class Generator(nn.Module):
    def __init__(self, z_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape

        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(128, int(np.prod(img_shape))),
            # Tanh Activation(出力を[-1, 1]の範囲に収める)
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        # 1次元の出力を画像サイズ (Channels, Height, Width) に変形
        return img.view(img.size(0), *self.img_shape)

Discriminator(識別器)の実装

次に、画像を見破る Discriminator です。 入力は 28x28 の画像、出力は「その画像が本物である確率(0%〜100%)」です。画像分類モデルを構築したことがあるかたならおなじみの構造で、最後にSigmoid関数を通すことで、出力を0(完全に偽物)から1(完全に本物)という確率に変換します。

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(), # 画像を1次元ベクトルに変換
            nn.Linear(int(np.prod(img_shape)), 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(128, 1),
            # Sigmoid Activation([0, 1]の確率値として出力)
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)

モデルの構築と最適化アルゴリズムの設定

生成器と識別器を個別に学習するための最適化アルゴリズム設定を行います。ここでは損失関数として真贋判定のためのBinary Cross Entropyを指定しています。

# インスタンス化
generator = Generator(z_dim, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# 損失関数: Binary Cross Entropy Loss
criterion = nn.BCELoss()

# Optimizer(最適化アルゴリズム): Adam
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

学習ループ

コードが長く見えますが、やっていることは単純です。GeneratorはDiscriminatorによる真贋判定をだませるように学習する一方で、DiscriminatorはGeneratorが作成した画像の真贋をうまく見破れるよう学習を行っています。

学習過程の詳細が気になる方はこちらの記事を参照してください:なぜBCE(Binary Cross Entropy)でGANが学習できるのか

実行すると、最初はノイズだった画像が徐々に手書き数字に変化していく様子がプロットされます(しかし、モデルの表現力がシンプルさ重視で高くはないため、ノイズ性が強い画像の出力となります)。

def train(iterations, batch_size, sample_interval):
    # 保存用ディレクトリの作成(なければ作る)
    os.makedirs("checkpoints", exist_ok=True)

    # MNISTデータの読み込みと前処理 ([-1, 1]に正規化)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataloader = DataLoader(
        datasets.MNIST('.', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True, drop_last=True
    )

    losses = []
    iteration_checkpoints = []

    # イテレータの作成
    data_iter = iter(dataloader)

    for iteration in range(iterations+1):
        
        # --- データの準備 ---
        try:
            real_imgs, _ = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            real_imgs, _ = next(data_iter)

        real_imgs = real_imgs.to(device)
        # 教師ラベル: 本物は1, 偽物は0
        valid = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)

        # -------------------------
        #  Train Discriminator
        # -------------------------
        optimizer_D.zero_grad()

        # 本物の画像の判定
        d_loss_real = criterion(discriminator(real_imgs), valid)
        
        # 偽物の画像の生成と判定
        z = torch.randn(batch_size, z_dim).to(device)
        gen_imgs = generator(z)
        d_loss_fake = criterion(discriminator(gen_imgs.detach()), fake)

        # 誤差逆伝播と更新
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        #  Train Generator
        # ---------------------
        optimizer_G.zero_grad()

        # 偽物を「本物(1)」と判定させることが目標
        g_loss = criterion(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Progress Logging & Saving
        # ---------------------
        if (iteration) % sample_interval == 0:
            print(f"{iteration} [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            losses.append((d_loss.item(), g_loss.item()))
            iteration_checkpoints.append(iteration + 1)
            
            # 画像の保存・表示
            sample_images(generator)
            
            # === モデルの保存処理 (ここを追加) ===
            # GeneratorとDiscriminatorの重みをそれぞれ保存
            torch.save(generator.state_dict(), f"checkpoints/generator_{iteration}.pth")
            torch.save(discriminator.state_dict(), f"checkpoints/discriminator_{iteration}.pth")
            print(f"Saved model checkpoints at iteration {iteration}")
            
    return np.array(losses), iteration_checkpoints

# sample_images関数などはそのまま変更なしでOK

def sample_images(generator, rows=4, cols=4):
    generator.eval()
    z = torch.randn(rows * cols, z_dim).to(device)
    gen_imgs = generator(z).detach().cpu()
    gen_imgs = 0.5 * gen_imgs + 0.5 # [-1, 1] -> [0, 1]

    fig, axs = plt.subplots(rows, cols, figsize=(4, 4))
    cnt = 0
    for i in range(rows):
        for j in range(cols):
            axs[i, j].imshow(gen_imgs[cnt, 0, :, :], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.show()

# 実行
losses, checkpoints = train(iterations=10000, batch_size=128, sample_interval=1000)

出力画像例

sample_intervalで指定したエポックごとにランダムな潜在変数から生成された画像が出力されると同時にモデルのチェックポイントが保存されます。※今回は初学者向けに全結合層(Linear)のみの軽量なモデルとしているためノイズが残ります。ここから畳み込み層(Conv2d)などを導入していくと、さらにくっきりとした数字が生成できるようになるでしょう。

学習済み生成器を用いた画像生成

下記のコードを利用すれば学習済みの生成器を用いて画像を生成することが可能です。

# --- 設定 ---
# 読み込みたいモデルのパス
checkpoint_path = "checkpoints/generator_10000.pth" 

# --- 1. モデルの枠組みを用意して、重みをロード ---
# まずモデルの構造(箱)を作る
generator = Generator(z_dim, img_shape).to(device)

# 指定したパスから重み(中身)を読み込みむ
# map_location=device をつけることで、GPUで保存したモデルをCPUでも開けるようしている
generator.load_state_dict(torch.load(checkpoint_path, map_location=device))

# --- 2. Generatorで画像を生成 ---
generator.eval() # 推論モードへ(BatchNormなどを固定)

# ランダムなノイズを1つ作成 (バッチサイズ=1, 次元=100)
z = torch.randn(1, z_dim).to(device)

with torch.no_grad(): # 勾配計算をしない(メモリ節約)
    gen_img = generator(z)

# --- 3. 表示用に整形 ---
# (1, 1, 28, 28) -> (28, 28) に不要な次元を削除してNumpy化
image = gen_img.squeeze().cpu().numpy()

# 画素値を [-1, 1] から [0, 1] に戻す
image = 0.5 * image + 0.5

# --- 4. 表示 ---
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.show()

おわりに

PyTorchで書くGANの実装、いかがだったでしょうか? 今回は分かりやすさを最優先し、全結合層(Linear)のみの非常にシンプルな構造にしました。そのため生成画像にはまだノイズが残っていますが、**「GeneratorとDiscriminatorを競わせて学習させる」**というGANの心臓部は全く同じです。

この基本形をマスターすれば、層を畳み込み(CNN)に変えた高画質な「DCGAN」や、画像の条件付けができる「Conditional GAN」へもスムーズにステップアップできます。

コメント