「RとStanで始めるベイズ統計モデリングによるデータ分析入門」「実践編第5部第3章 状態空間モデルによる予測と補完」を対象に,公開されているR,Stanのコードをpython,pystanのコードへと書き直した一例です。Stanの代わりにpystanを利用しています。
本ページでは公開されていない書籍の内容については一切触れません。理論や詳しい説明は書籍を参照してください。
なお,こちらで紹介しているコードには誤りが含まれる可能性があります。内容やコードについてお気づきの点等ございましたら,ご指摘いただけると幸いです。
(本ページにて紹介しているコードはgithubにて公開しています。)
DataAnalOji.hatena.sample/python_samples/stan/5-3-状態空間モデルによる予測と補完.ipynb at master · Data-Anal-Ojisan/DataAnalOji.hatena.sample
samples for my own blog. Contribute to Data-Anal-Ojisan/DataAnalOji.hatena.sample development by creating an account on GitHub.
分析の準備
パッケージの読み込み
plotSSM関数についてはこちらをご参照ください。
import pystan
import datetime
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
plt.rcParams['font.family'] = 'Meiryo'
# 自作のplotSSM関数を読み込み
from plotSSM import plotSSM
データの読み込み
sales_df_all = pd.read_csv('5-2-1-sales-ts-1.csv')
sales_df_all['date'] = pd.to_datetime(sales_df_all['date'])
sales_df_all.head(n=3)
ローカルレベルモデルによる予測の実行
データの準備
data_list_pred = dict(T=len(sales_df_all),
y=sales_df_all['sales'],
pred_term=20)
モデルの推定
# stanコードの記述(5-3-1-local-level-pred.stan)
stan_code = '''
data {
int T; // データ取得期間の長さ
vector[T] y; // 観測値
int pred_term; // 予測期間の長さ
}
parameters {
vector[T] mu; // 状態の推定値(水準成分)
real<lower=0> s_w; // 過程誤差の標準偏差
real<lower=0> s_v; // 観測誤差の標準偏差
}
model {
// 状態方程式に従い、状態が遷移する
for(i in 2:T) {
mu[i] ~ normal(mu[i-1], s_w);
}
// 観測方程式に従い、観測値が得られる
for(i in 1:T) {
y[i] ~ normal(mu[i], s_v);
}
}
generated quantities{
vector[T + pred_term] mu_pred; // 予測値も含めた状態の推定値
// データ取得期間においては、状態推定値muと同じ
mu_pred[1:T] = mu;
// データ取得期間を超えた部分を予測
for(i in 1:pred_term){
mu_pred[T + i] = normal_rng(mu_pred[T + i - 1], s_w);
}
}
'''
# モデルのコンパイル
stan_model_llp = pystan.StanModel(model_code=stan_code)
# サンプリング
local_level_pred = stan_model_llp.sampling(data=data_list_pred,
seed=1,
n_jobs=1)
参考:収束の確認
# 収束確認用のRhatのプロット関数
def mcmc_rhat(dataframe, column='Rhat', figsize=(5, 10)):
plt.figure(figsize=figsize)
plt.hlines(y=dataframe[column].sort_values().index,
xmin=1,
xmax=dataframe[column].sort_values(),
color='b',
alpha=0.5)
plt.vlines(x=1.05, ymin=0, ymax=len(dataframe[column]), linestyles='--')
plt.plot(dataframe[column].sort_values().values,
dataframe[column].sort_values().index,
marker='.',
linestyle='None',
color='b',
alpha=0.5)
plt.yticks(color='None')
plt.tick_params(length=0)
plt.xlabel(column)
plt.show()
# 各推定結果のデータフレームを作成
summary = pd.DataFrame(local_level_pred.summary()['summary'],
columns=local_level_pred.summary()['summary_colnames'],
index=local_level_pred.summary()['summary_rownames'])
# プロット
mcmc_rhat(summary)
参考:結果の表示
print(local_level_pred.stansummary(pars=["s_w", "s_v", "lp__"],
probs=[0.025, 0.5, 0.975]))
Inference for Stan model: anon_model_034be78dc0fa71b8d5e9e6b8edac86e4.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
s_w 1.3 0.02 0.32 0.79 1.26 2.05 230 1.02
s_v 2.87 6.3e-3 0.26 2.39 2.86 3.42 1740 1.0
lp__ -225.5 1.35 19.48 -264.7 -225.3 -188.2 207 1.03
Samples were drawn using NUTS at Thu Sep 10 22:36:29 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
図示
予測対象期間も含めた日付を用意
date_plot = pd.date_range('2010-1-1', periods=120, freq='D')
生成された乱数を格納
mcmc_sample = local_level_pred.extract()
予測結果の図示
# グラフ描画領域の作成
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1,1,1)
# plotSSM関数によるグラフ作成
plotSSM(mcmc_sample,
time_vec=date_plot,
state_name='mu_pred',
graph_title='予測の結果',
y_label='sales',
axes=ax)
# グラフの描画
plt.show()
欠損があるデータ
データの読み込み
sales_df_NA = pd.read_csv('5-3-1-sales-ts-1-NA.csv')
日付をdatetime型にする
sales_df_NA['date'] = pd.to_datetime(sales_df_NA['date'])
売り上げデータに一部欠損がある
sales_df_NA.head(n=3)
欠損データの取り扱い
NaN値がある行を削除
sales_df_omit_NA = sales_df_NA.dropna(axis=0)
sales_df_omit_NA.head(n=3)
データを取得した期間
len(sales_df_NA)
100
NaN値がどこにあるのかを判別
[(n, not i) for n, i in enumerate(sales_df_NA['sales'].isnull())]
[(0, True),
(1, True),
(2, False),
(3, True),
…省略…
(98, True),
(99, True)]
データがある行番号の取得
sales_df_omit_NA.index
Int64Index([ 0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 48, 49, 50, 51, 52, 53, 54, 55, 56,
57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,
74, 75, 76, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92,
93, 95, 96, 97, 98, 99],
dtype='int64')
ローカルレベルモデルによる補間の実行
データの準備
data_list_interpolation = dict(T=len(sales_df_NA),
len_obs=len(sales_df_omit_NA),
y=sales_df_omit_NA['sales'],
obs_no=sales_df_omit_NA.index+1) # pythonはインデックスが0始まりなので1を加えておく
モデルの推定
# stanコードの記述(5-3-2-local-level-interpolation.stan)
stan_code = '''
data {
int T; // データ取得期間の長さ
int len_obs; // 観測値が得られた個数
vector[len_obs] y; // 観測値
int obs_no[len_obs]; // 観測値が得られた時点
}
parameters {
vector[T] mu; // 状態の推定値(水準成分)
real<lower=0> s_w; // 過程誤差の標準偏差
real<lower=0> s_v; // 観測誤差の標準偏差
}
model {
// 状態方程式に従い、状態が遷移する
for(i in 2:T) {
mu[i] ~ normal(mu[i-1], s_w);
}
// 観測方程式に従い、観測値が得られる
// ただし、「観測値が得られた時点」でのみ実行する
for(i in 1:len_obs) {
y[i] ~ normal(mu[obs_no[i]], s_v);
}
}
'''
# モデルのコンパイル
stan_model_lli = pystan.StanModel(model_code=stan_code)
# サンプリング
local_level_interpolation = stan_model_lli.sampling(data=data_list_interpolation,
seed=1,
iter=4000,
n_jobs=1)
参考:収束の確認
# 各推定結果のデータフレームを作成
summary_lli = pd.DataFrame(
local_level_interpolation.summary()['summary'],
columns=local_level_interpolation.summary()['summary_colnames'],
index=local_level_interpolation.summary()['summary_rownames'])
# プロット
mcmc_rhat(summary_lli)
参考:結果の表示
print(local_level_interpolation.stansummary(pars=["s_w", "s_v", "lp__"],
probs=[0.025, 0.5, 0.975]))
Inference for Stan model: anon_model_28d0b3c36d27d089631f3604cc43a67b.
4 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=8000.
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
s_w 1.3 0.01 0.31 0.83 1.26 2.02 582 1.01
s_v 2.65 4.5e-3 0.29 2.14 2.64 3.26 4023 1.0
lp__ -179.2 0.86 19.47 -219.5 -178.6 -143.1 513 1.01
Samples were drawn using NUTS at Thu Sep 10 22:37:39 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
図示
生成された乱数を格納
mcmc_sample_interpolation = local_level_interpolation.extract()
図示
# グラフ描画領域の作成
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1,1,1)
# plotSSM関数によるグラフ作成
plotSSM(mcmc_sample_interpolation,
time_vec=sales_df_all['date'],
obs_vec=sales_df_all['sales'],
state_name='mu',
graph_title='補間の結果',
y_label='sales',
axes=ax)
# グラフの描画
plt.show()
参考:予測区間
モデルの推定
# stanコードの記述(5-3-3-local-level-interpolation-prediction-interval.stan)
stan_code = '''
data {
int T; // データ取得期間の長さ
int len_obs; // 観測値が得られた個数
vector[len_obs] y; // 観測値
int obs_no[len_obs]; // 観測値が得られた時点
}
parameters {
vector[T] mu; // 状態の推定値(水準成分)
real<lower=0> s_w; // 過程誤差の標準偏差
real<lower=0> s_v; // 観測誤差の標準偏差
}
model {
// 状態方程式に従い、状態が遷移する
for(i in 2:T) {
mu[i] ~ normal(mu[i-1], s_w);
}
// 観測方程式に従い、観測値が得られる
// ただし、「観測値が得られた時点」でのみ実行する
for(i in 1:len_obs) {
y[i] ~ normal(mu[obs_no[i]], s_v);
}
}
generated quantities {
vector[T] y_pred; // 観測値の予測値
for (i in 1:T) {
y_pred[i] = normal_rng(mu[i], s_v);
}
}
'''
# モデルのコンパイル
stan_model_llpi = pystan.StanModel(model_code=stan_code)
# サンプリング
local_level_prediction_interval = stan_model_llpi.sampling(data=data_list_interpolation,
seed=1,
iter=4000,
n_jobs=1)
参考:収束の確認
# 各推定結果のデータフレームを作成
summary_llpi = pd.DataFrame(
local_level_prediction_interval.summary()['summary'],
columns=local_level_prediction_interval.summary()['summary_colnames'],
index=local_level_prediction_interval.summary()['summary_rownames'])
# プロット
mcmc_rhat(summary_llpi)
参考:結果の表示
print(local_level_prediction_interval.stansummary(pars=["s_w", "s_v", "lp__"],
probs=[0.025, 0.5, 0.975]))
Inference for Stan model: anon_model_d708c6aac7351d717bde7cc6d469ec51.
4 chains, each with iter=4000; warmup=2000; thin=1;
post-warmup draws per chain=2000, total post-warmup draws=8000.
mean se_mean sd 2.5% 50% 97.5% n_eff Rhat
s_w 1.32 0.01 0.3 0.84 1.28 2.02 445 1.0
s_v 2.65 4.3e-3 0.29 2.13 2.64 3.25 4452 1.0
lp__ -181.0 0.95 18.79 -218.1 -180.9 -144.9 388 1.0
Samples were drawn using NUTS at Thu Sep 10 22:38:39 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
図示
生成された乱数を格納
mcmc_sample_prediction_interval = local_level_prediction_interval.extract()
# グラフ描画領域の作成
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1,1,1)
# plotSSM関数によるグラフ作成
plotSSM(mcmc_sample_prediction_interval,
time_vec=sales_df_all['date'],
obs_vec=sales_df_all['sales'],
state_name='y_pred',
graph_title='補間の結果:予測分布',
y_label='sales',
axes=ax)
# グラフの描画
plt.show()
pystanの詳細については公式ページを参照してください。
PyStan — pystan 3.10.0 documentation
コメント