「RとStanで始めるベイズ統計モデリングによるデータ分析入門」「第5部第2章 ローカルレベルモデル」にて紹介されている「plotSSM関数」のpython実装verになります。
以下のスクリプトを「plotSSM.py」として保存し,起動している.ipynbファイルや.pyファイルと同一のディレクトリに配置してください。その状態で以下のようにplotSSM.pyからplotSSM関数を呼び出すことで使用可能な状態となります。
from plotSSM import plotSSM
plotSSM関数
# -*- coding: utf-8 -*-
"""
@author: data-anal-ojisan
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use("ggplot")
plt.rcParams["font.family"] = "Meiryo"
def plotSSM(mcmc_sample, time_vec, state_name,
graph_title, y_label, axes, obs_vec=None):
# 状態空間モデルを図示する関数
#
# Args:
# mcmc_sample : MCMCサンプル
# time_vec : 時間軸(datetime)のベクトル
# obs_vec : (必要なら)観測値のベクトル
# state_name : 図示する状態の変数名
# graph_title : グラフタイトル
# y_label : y軸のラベル
# axes : サブプロットの描画領域
#
# Returns:
# 生成されたグラフ
# すべての時点の状態の、95%区間と中央値
result_df = pd.DataFrame(np.zeros([mcmc_sample[state_name].shape[1], 3]))
for i in range(mcmc_sample[state_name].shape[1]):
result_df.iloc[i,:] = np.percentile(mcmc_sample[state_name][:,i], q=[2.5, 50, 97.5])
# 列名の変更
result_df.columns = ["lwr", "fit", "upr"]
# 時間軸の追加
result_df['time'] = time_vec
# 観測値の追加
if obs_vec is not None:
if obs_vec.isnull().all(axis=0) == False:
result_df['obs'] = obs_vec
# 図示
axes.plot(result_df['time'],
result_df['fit'],
color='black')
axes.fill_between(x=result_df['time'],
y1=result_df['upr'],
y2=result_df['lwr'],
color='gray',
alpha=0.5)
axes.set_ylabel(y_label)
axes.set_title(graph_title)
# 観測値をグラフに追加
if obs_vec is not None:
if obs_vec.isnull().all(axis=0) == False:
axes.plot(result_df['time'],
result_df['obs'],
marker='.',
linewidth=0,
color='black')
# グラフを返す
return axes
コメント