Tensorflow Keras: AutoEncoder 実装例

Tensorflow Keras を利用したAutoEncoderの単純な実装例について,常に参照できるように備忘録として示す。本ページでは MNIST を対象とした画像の再構成を扱う。モデル構築の実装例を可能な限りシンプルに示すため,入力データの前処理やハイパーパラメータ調整については扱わない。

開発環境

  • keras: 2.10.0
  • matplotlib: 3.5.3
  • python: 3.10.4
  • tensorflow-gpu: 2.4.0

ソースコード

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.datasets import mnist
%matplotlib inline
import matplotlib.pyplot as plt

# データ作成
# データの読み込み
(X_train, _), (X_test, _) = mnist.load_data()
# 前処理
height, width = 28, 28
input_shape = height * width
X_train, X_test = X_train.astype('float32') / 255, X_test.astype('float32') / 255    # [0, 1]
X_train, X_test = X_train.reshape(-1, input_shape), X_test.reshape(-1, input_shape)  # 1次元化

# モデル作成
latent_var = 64  # 潜在変数次元数
inputs = Input(shape=(input_shape, ))
# エンコーダー
encoded = Dense(latent_var, activation='relu')(inputs)
# デコーダー
decoded = Dense(input_shape, activation='sigmoid')(encoded)
# AutoEncoder
autoencoder = Model(inputs=inputs, outputs=decoded)
# モデルコンパイル
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
print(autoencoder.summary())

# 学習
history = autoencoder.fit(X_train, X_train, validation_split=0.2, epochs=10)
# 学習過程の可視化
plt.figure(figsize=(4, 4))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper right')
plt.show()

# 予測
pred = autoencoder.predict(X_test)
# 予測の可視化
num_figures = 5  # 表示画像数
for p, t in zip(pred[:num_figures], X_test[:num_figures]):
    fig, ax = plt.subplots(1, 2, figsize=(3, 4))
    ax[0].imshow(p.reshape(28, 28)*255, cmap='gray')
    ax[1].imshow(t.reshape(28, 28)*255, cmap='gray')
    ax[0].set_title('Prediction')
    ax[1].set_title('True')
    plt.show()

パッケージのインポート

「matplotlib」は学習・予測結果の可視化に用いる。AutoEncoderに利用する各種モジュールは tensorflow から keras モジュールを呼び出す。

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.datasets import mnist
%matplotlib inline
import matplotlib.pyplot as plt

データ作成

keras で用意されているMNIST出たセットを利用する。今回のAutoEncoderでは畳み込みではなく全層結合を利用するので,入力するデータは [0, 1] の範囲に正規化したうえで1次元データに変換しておく。

# データの読み込み
(X_train, _), (X_test, _) = mnist.load_data()

# 前処理
height, width = 28, 28
input_shape = height * width
X_train, X_test = X_train.astype('float32') / 255, X_test.astype('float32') / 255    # [0, 1]
X_train, X_test = X_train.reshape(-1, input_shape), X_test.reshape(-1, input_shape)  # 1次元化

モデル作成

keras の Functional API 形式でモデルを構築。第1層の全層結合で28*28=784次元のデータを64次元まで圧縮し,2層目の全層結合で64次元のデータを784次元まで復元するような構造。

latent_var = 64  # 潜在変数次元数
inputs = Input(shape=(input_shape, ))
# エンコーダー
encoded = Dense(latent_var, activation='relu')(inputs)
# デコーダー
decoded = Dense(input_shape, activation='sigmoid')(encoded)
# AutoEncoder
autoencoder = Model(inputs=inputs, outputs=decoded)
# モデルコンパイル
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
print(autoencoder.summary())
# コード実行時の出力

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 784)]             0         
_________________________________________________________________
dense_2 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_3 (Dense)              (None, 784)               50960     
=================================================================
Total params: 101,200
Trainable params: 101,200
Non-trainable params: 0
_________________________________________________________________
None

学習

定義した AutoEncoder について,学習データ(X_train)と正答データ(X_train)を与え,fitメソッドを利用して学習を行う。下記の例では訓練回数は100回(epochs=10),バリデーションデータ割合は20%(validation_split=0.2)に指定している。

history = autoencoder.fit(X_train, X_train, validation_split=0.2, epochs=10)
# コード実行時の出力

Epoch 1/10
188/188 [==============================] - 1s 5ms/step - loss: 0.3713 - val_loss: 0.1768
Epoch 2/10
188/188 [==============================] - 1s 5ms/step - loss: 0.1669 - val_loss: 0.1408
.
.
.
Epoch 10/10
188/188 [==============================] - 1s 5ms/step - loss: 0.0812 - val_loss: 0.0808

学習過程の可視化

history変数内に保存されている学習過程の情報を用いて,学習データ(’loss’)/バリデーションデータ(’val_loss’)に対する損失の変動を可視化する。

plt.figure(figsize=(4, 4))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper right')
plt.show()

予測

学習済み AutoEncoder の predictメソッドを用いてテストデータ(X_test)について予測値(画像次元の圧縮-再構成画像)を求める。

pred = autoencoder.predict(X_test)

予測の可視化

テストデータ(X_test)と再構成画像(pred)を並べて可視化することで再構成精度を可視化する。

num_figures = 5  # 表示画像数
for p, t in zip(pred[:num_figures], X_test[:num_figures]):
    fig, ax = plt.subplots(1, 2, figsize=(3, 4))
    ax[0].imshow(p.reshape(28, 28)*255, cmap='gray')
    ax[1].imshow(t.reshape(28, 28)*255, cmap='gray')
    ax[0].set_title('Prediction')
    ax[1].set_title('True')
    plt.show()

64次元にまで圧縮されたにもかかわらず,再構成された画像(画像左)は入力画像(画像右)に近い出力であることが確認できる。

画像ノイズ除去

本ページの目的からは逸れるが AutoEncoder の利用例としてしばしば画像のノイズ除去が挙げられる。簡単にノイズ除去効果の検証結果を示す。なお,以下で利用する AutoEncoder モデルは上記の流れで学習済みであることを前提とする。

ノイズ画像作成

X_test画像にランダムノイズを加えることで,ノイズ付き画像を作成する。

import numpy as np

# データの読み込み
(_, _), (X_test, _) = mnist.load_data()

# ランダムノイズの追加
noise = np.random.randint(0, 64, (28, 28))
X_test = np.where(X_test + noise > 255, 255, X_test + noise)

# 前処理
height, width = 28, 28
input_shape = height * width
X_test = X_test.astype('float32') / 255   # [0, 1]
X_test = X_test.reshape(-1, input_shape)  # 1次元化

ノイズ画像可視化

num_figures = 5  # 表示画像数
for nt in X_test[:num_figures]:
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(nt.reshape(height, width) * 255, cmap='gray')
    ax.set_title('Noized X_test')
    plt.show()

ノイズが追加された画像例。

ノイズ除去

# ノイズ画像を入力し画像を再構成
pred = autoencoder.predict(X_test)

# 予測の可視化
num_figures = 5  # 表示画像数
for p, t in zip(pred[:num_figures], X_test[:num_figures]):
    fig, ax = plt.subplots(1, 2, figsize=(3, 4))
    ax[0].imshow(p.reshape(height, width)*255, cmap='gray')
    ax[1].imshow(t.reshape(height, width)*255, cmap='gray')
    ax[0].set_title('Reconstructed')
    ax[1].set_title('True')
    plt.show()

ノイズ画像(画像右)に対して再構成画像(画像左)ではノイズが除去されている様子が確認できる。

コメント