Tkinterでアヤメの種類を判定する機械学習アプリを作成する

ここのところ,pythonの標準GUIライブラリ「Tkinter」を用いたアプリ開発の練習をしています。今回は機械学習アルゴリズム(ランダムフォレスト)を利用したクラス分類アプリ開発を想定し,アヤメ(iris)の花弁(petal)やガク(sepal)の長さ・幅からその種類を判定する簡単なアプリの開発事例について紹介します。

読者のみなさんのアプリ開発の参考になれば幸いです。

(本ページにて紹介しているコードは github にて公開しています。)

開発環境

今回の開発で使用した環境です。scikit-learn モジュールは iris データセットの取得及び,ランダムフォレストアルゴリズムのために利用しています。

  • python: 3.7.9
  • scikit-learn: 0.23.2

機械学習アルゴリズムの事前学習

今回開発するアプリでは,指定された花弁やガクの長さ・幅を入力データとし,ランダムフォレストアルゴリズムによる分類を行うことで,アヤメの種類を判定するという処理を実行します。

分類処理のコアとなる機械学習アルゴリズムはそのままでは利用できず,アヤメの種類と花弁・ガクの情報の関係を事前に学習しておく必要があります。今回は機械学習アルゴリズムとして,scikit-learn モジュールに用意されている RandomForestClassifier を採用し1,同じく scikit-learn モジュールに用意されているアヤメ(iris)データセットを対象とした教師あり分類学習を行います2

学習後のアルゴリズムのオブジェクトは,オブジェクトの状態を保存・復元するためのpickleモジュールを利用して保存しておきます。ここで用意したアルゴリズムを開発するアプリで復元することで,機械学習アルゴリズムを活用するアプリの開発を行います。

scikit-learn モジュールの iris データセットや,ランダムフォレストアルゴリズムを対象とした学習はよく知られており,説明を行うサイトも複数存在するため,詳しい説明はここでは割愛します。

import pickle
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# irisデータセットを読み込む
iris = load_iris()                   # インスタンスを生成
feature = iris.data                  # 特徴量(花弁・ガクの長さ・幅)を取得
target = iris.target.reshape(-1, 1)  # クラス(アヤメの種類)を取得

# RandomForestClassifierのインスタンスを生成
RandomForestIris = RandomForestClassifier()

# 学習を実行
RandomForestIris.fit(feature, target)

# モデルを保存
with open('model/RandomForest_Iris.pickle', mode='wb') as fp:
    pickle.dump(RandomForestIris, fp)

アプリ開発スクリプト全容

早速ですが,開発したアプリのスクリプトの全容です。IrisClassifier クラスを作成し,メソッドを利用することで処理を実現する,オブジェクト指向型プログラミングでアプリを開発しています。

各メソッドの細かな説明は後述しますが,ウィジェットの生成や分類処理は個別のメソッドで記述し,launch() メソッドでそれらの処理を実行することでアプリの起動を行っています。

import pickle
import tkinter as tk
from tkinter import ttk


class IrisClassifier:

    def __init__(self):

        self.root = tk.Tk()                                    # トップレベルウィンドウ
        self.classes = ['Setosa', 'Versicolour', 'Virginica']  # アヤメ種別のリスト

        # 事前学習済みモデルを読み込む
        with open('model/RandomForest_Iris.pickle', mode='rb') as fp:
            self.model = pickle.load(fp)

        # 分類結果を表示するためのtk.StringVar()
        self.predicted_class = tk.StringVar()

        # 入力フィールドの値をまとめたdictionary
        self.feature = {'sepal_length': tk.DoubleVar(value=3.0),
                        'sepal_width': tk.DoubleVar(value=3.0),
                        'petal_length': tk.DoubleVar(value=3.0),
                        'petal_width': tk.DoubleVar(value=3.0)}

    def launch(self):
        """
        アプリ起動用メソッド
        """

        self.call_window()                 # ウィンドウ設定を行う
        self.call_input_fields()           # 特徴量を指定する入力フィールドウィジェットを呼び出す
        self.call_classification_button()  # 分類処理を実行するボタンウィジェットを呼び出す
        self.call_result_label()           # 分類結果を表示するラベルウィジェットを呼び出す
        self.root.mainloop()               # アプリの起動状態を維持する

    def call_window(self):
        """
        ウィンドウの設定を行う
        """

        self.root.title('Iris Classification App')      # タイトルを変更する
        self.root.geometry('275x300')                   # ウィンドウサイズを指定する
        self.root.resizable(height=False, width=False)  # ウィンドウサイズ変更を不可にする

    def call_input_fields(self):
        """
        アヤメの特徴量を指定するための入力フィールドを呼び出す
        """

        # ウィジェット配置のためのLabelFrameを作成
        lf = ttk.LabelFrame(self.root, text='Features', padding=(10, 10))
        lf.pack(fill=tk.X, padx=5, pady=5)

        # 特徴量指定のための入力フィールドを作成
        for i, key in enumerate(self.feature.keys(), 0):
            tk.Label(lf, text=key, anchor='e', width=15).grid(row=i, column=0)
            tk.Label(lf, text=' : ').grid(row=i, column=1)
            tk.Entry(lf, textvariable=self.feature[key], justify='right', width=10).grid(row=i, column=2)

    def call_classification_button(self):
        """
        指定の特徴量をもとにアヤメ種別の分類処理を開始するボタンを呼び出す
        """

        tk.Button(self.root, text='Classify', command=self.classification).pack(fill=tk.X, padx=5, pady=5)

    def call_result_label(self):
        """
        分類結果を表示するラベルを呼び出す
        """

        # ウィジェット配置のためのFrameを作成
        f = tk.Frame(self.root, relief='solid', bd=1)
        f.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # ラベルを作成
        tk.Label(f, text='Predicted class').pack(anchor='center')
        tk.Label(f,
                 textvariable=self.predicted_class,
                 bg='white',
                 font=('', 30),
                 foreground='#ff0000').pack(anchor='center', expand=True, fill=tk.BOTH)

    def classification(self):
        """
        アヤメの分類処理を実行し,分類結果を示すtk.StringVar()を更新するメソッド
        """

        # 入力データを作成する
        input_data = [[self.feature['sepal_length'].get(),
                       self.feature['sepal_width'].get(),
                       self.feature['petal_length'].get(),
                       self.feature['petal_width'].get()]]

        # 分類を行う
        predict = self.model.predict(input_data)

        # tk.StringVar()の更新
        self.predicted_class.set(self.classes[int(predict)])


if __name__ == '__main__':

    app = IrisClassifier()  # IrisClassifierのインスタンスを生成する
    app.launch()            # launch()メソッドでアプリを起動する

__init__コンストラクタ

クラスのインスタンス生成時に実行されるコンストラクタです。アプリのウィンドウやアヤメ種類のリスト,分類処理に利用する事前学習済みモデル,特徴量の入力値や分類結果の表示に対応する tk.DoubleVar() / tk.StringVar() がここで宣言されています。

tk.StringVar() は,分類結果を表示するラベル3のテキスト4に対応しています。self.feature の辞書にまとめられた tk.DolubleVar() はアヤメの特徴量を示す入力フィールド5の値に対応しています6

    def __init__(self):

        self.root = tk.Tk()                                    # トップレベルウィンドウ
        self.classes = ['Setosa', 'Versicolour', 'Virginica']  # アヤメ種類のリスト

        # 事前学習済みモデルを読み込む
        with open('model/RandomForest_Iris.pickle', mode='rb') as fp:
            self.model = pickle.load(fp)

        # 分類結果を表示するためのtk.StringVar()
        self.predicted_class = tk.StringVar()

        # 入力フィールドの値をまとめたdictionary
        self.feature = {'sepal_length': tk.DoubleVar(value=3.0),
                        'sepal_width': tk.DoubleVar(value=3.0),
                        'petal_length': tk.DoubleVar(value=3.0),
                        'petal_width': tk.DoubleVar(value=3.0)}

launch()メソッド

IrisClassifier クラスのメソッドを実行することで,ウィジェットを呼び出し,アプリを起動するためのメソッドです。

各メソッドの概要はコメントの通りです。

    def launch(self):
        """
        アプリ起動用メソッド
        """

        self.call_window()                 # ウィンドウ設定を行う
        self.call_input_fields()           # 特徴量を指定する入力フィールドウィジェットを呼び出す
        self.call_classification_button()  # 分類処理を実行するボタンウィジェットを呼び出す
        self.call_result_label()           # 分類結果を表示するラベルウィジェットを呼び出す
        self.root.mainloop()               # アプリの起動状態を維持する

call_window()メソッド

__init__コンストラクタで呼び出したトップレベルウィンドウ7の設定を行うメソッドです。設定内容はスクリプト内コメントの通りです。

    def call_window(self):
        """
        ウィンドウの設定を行う
        """

        self.root.title('Iris Classification App')      # タイトルを変更する
        self.root.geometry('275x300')                   # ウィンドウサイズを指定する
        self.root.resizable(height=False, width=False)  # ウィンドウサイズ変更を不可にする

call_input_fields()メソッド

アヤメの特徴量を入力するためのフィールドを呼び出すためのメソッドです。ウィジェットの表示領域を視覚的に明確化するために ttk.LabelFrame() を呼び出し,入れ子にする形でその内部に入力フィールドや対応するラベルを配置しています。

入力フィールドや対応を示すラベルは,self.features の keys8 を対象とした for ループを活用して配置を行っています。tk.Entry() の textvariable オプションを tk.DoubleVar() と指定することで,入力フィールドの値を受け取れるように設計しています。

    def call_input_fields(self):
        """
        アヤメの特徴量を指定するための入力フィールドを呼び出す
        """

        # ウィジェット配置のためのLabelFrameを作成
        lf = ttk.LabelFrame(self.root, text='Features', padding=(10, 10))
        lf.pack(fill=tk.X, padx=5, pady=5)

        # 特徴量指定のための入力フィールドを作成
        for i, key in enumerate(self.feature.keys(), 0):
            tk.Label(lf, text=key, anchor='e', width=15).grid(row=i, column=0)
            tk.Label(lf, text=' : ').grid(row=i, column=1)
            tk.Entry(lf, textvariable=self.feature[key], justify='right', width=10).grid(row=i, column=2)

call_classification_button()メソッド

入力フィールドの値を入力データとして,ランダムフォレストアルゴリズムでアヤメの種類を判定する処理を実行するボタンを呼び出すためのメソッドです。

command=self.classification とある通り,ボタンがクリックされると classification() メソッドが実行されます。

def call_classification_button(self):
        """
        指定の特徴量をもとにアヤメ種別の分類処理を開始するボタンを呼び出す
        """

        tk.Button(self.root, text='Classify', command=self.classification).pack(fill=tk.X, padx=5, pady=5)

call_result_label()メソッド

分類結果を表示するためのラベルを呼び出すメソッドです。

ここでは,ウィジェットを配置するためのフレーム9を作成し,その内部に分類結果を表示するラベルを配置しています。

特に,2つ目のラベルでは textvariable オプションに tk.StringVar()10 を与えることで,StringVar() の変更だけで分類結果が更新されるようになっています。

    def call_result_label(self):
        """
        分類結果を表示するラベルを呼び出す
        """

        # ウィジェット配置のためのFrameを作成
        f = tk.Frame(self.root, relief='solid', bd=1)
        f.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # ラベルを作成
        tk.Label(f, text='Predicted class').pack(anchor='center')
        tk.Label(f,
                 textvariable=self.predicted_class,
                 bg='white',
                 font=('', 30),
                 foreground='#ff0000').pack(anchor='center', expand=True, fill=tk.BOTH)

classification()メソッド

花弁やガクの長さ・幅に関する指定の特徴量を対象にアヤメの種類を判定し,その結果を tk.StringVar() に反映するメソッドです。

scikit-learn の RandomForestClassifier() で作成した分類モデルは,入力データのサイズが [4, 1] となっるため,このメソッド内でも同様のサイズになるように入力データを作成しています。

作成された入力データを predict() に与えることでアヤメの種類の予測が行われます。ここの出力は [0, 1, 2] いずれかの値となり,この予測結果を示す値とアヤメの種類をまとめたリスト11 のインデックス番号と対応しています。

そして最後に,tk.StringVar() に予測された種類に対応する文字列を与えることで,自動的に call_result_label() メソッド内のラベルが更新されるようになっています。例えば,分類の予測値が [1] の場合は,self.classes[1] ,つまり’Versicolour’ が種類の判定結果として表示されるといった具合です。

    def classification(self):
        """
        アヤメの分類処理を実行し,分類結果を示すtk.StringVar()を更新するメソッド
        """

        # 入力データを作成する
        input_data = [[self.feature['sepal_length'].get(),
                       self.feature['sepal_width'].get(),
                       self.feature['petal_length'].get(),
                       self.feature['petal_width'].get()]]

        # 分類を行う
        predict = self.model.predict(input_data)

        # tk.StringVar()の更新
        self.predicted_class.set(self.classes[int(predict)])

アプリの起動

スクリプト全容や launch() メソッドの説明でも述べましたが,IrisClassifier クラスのインスタンスを生成し,launch() メソッドを実行することで,ウィンドウやウィジェットの呼び出し等を行い,アプリを起動しています。

動作GIF
起動後のアプリ動作。[Classify]ボタンクリックにより,入力データからアヤメの種類を判定し,その結果を表示している。
if __name__ == '__main__':

    app = IrisClassifier()  # IrisClassifierのインスタンスを生成する
    app.launch()            # launch()メソッドでアプリを起動する

所感

あくまで練習目的のアプリ開発に過ぎないので,実ビジネスに適用可能な程度には到達していないですが,機械学習アルゴリズムを搭載したアプリ開発は達成できたように思います。ビジネス実践においては,欠陥画像の分類や時系列データ予測など,より複雑な処理が要求されますが,今回の開発経験のエッセンスを取り入れて開発に勤しんでいければと思います。

このページが皆さんのアプリ開発の参考になれば幸いです。

  1. 今回の開発例ではランダムフォレストアルゴリズムを採用したが,ロジスティック回帰やニューラルネットワークなど,採用する機械学習アルゴリズムは何でもよい。
  2. 今回の目的は機械学習アルゴリズムを活用するアプリ開発を行うことにあるため,分類精度の良し悪しは考慮せずデフォルト設定で学習を行います。
  3. tk.Label()
  4. textvariable オプション
  5. tk.Entry()
  6. 初期値を3.0に設定しているため,アプリ起動時にはあらかじめ入力フィールドに3.0の値が入力されている。
  7. self.root = tk.Tk()
  8. dictionary 型におけるindex
  9. tk.Frame()
  10. self.predicted_class変数
  11. self.classes = [‘Setosa’, ‘Versicolour’, ‘Virginica’]

コメント