matplotlibを利用して3-D散布図を描画する方法についてまとめています。
単純に3-D散布図を描画
import pandas as pd
%matplotlib notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.datasets import load_breast_cancer
# サンプルデータセットを読み込み
cancer = load_breast_cancer()
# データフレーム形式に変更
X = pd.DataFrame(cancer.data, columns=cancer.feature_names) # 特徴量
y = pd.DataFrame(cancer.target, columns=["Target"]) # ターゲット
df = pd.concat([X, y], axis=1) # 特徴量とターゲットの水平結合
# 3-D散布図の描画描画領域を作成
fig = plt.figure()
ax = Axes3D(fig)
# Targetの属性=0の散布図を描画
ax.scatter(df[df.columns[0]],
df[df.columns[1]],
df[df.columns[2]])
# 軸ラベルを表示
ax.set_xlabel(df.columns[0], fontsize=10) # x軸ラベル
ax.set_ylabel(df.columns[1], fontsize=10) # y軸ラベル
ax.set_zlabel(df.columns[2], fontsize=10) # z軸ラベル
plt.show() # グラフを描画
ターゲットの属性ごとにマーカー形状・色を変更して描画
# 3-D散布図の描画描画領域を作成
fig = plt.figure()
ax = Axes3D(fig)
# Targetの属性=0の散布図を描画
ax.scatter(df.query('Target == 0')[df.columns[0]],
df.query('Target == 0')[df.columns[1]],
df.query('Target == 0')[df.columns[2]],
color="b",
marker="o",
label="Target == 0")
# Targetの属性=1の散布図を描画
ax.scatter(df.query('Target == 1')[df.columns[0]],
df.query('Target == 1')[df.columns[1]],
df.query('Target == 1')[df.columns[2]],
color="r",
marker="^",
label="Target == 1")
# 軸ラベルを表示
ax.set_xlabel(df.columns[0], fontsize=10) # x軸ラベル
ax.set_ylabel(df.columns[1], fontsize=10) # y軸ラベル
ax.set_zlabel(df.columns[2], fontsize=10) # z軸ラベル
plt.legend() # 凡例の追加
plt.show() # グラフを描画
各マーカー上にインデックス番号を描画
# 3-D散布図の描画描画領域を作成
fig = plt.figure()
ax = Axes3D(fig)
# Targetの属性=0の散布図を描画
ax.scatter(df.query('Target == 0')[df.columns[0]],
df.query('Target == 0')[df.columns[1]],
df.query('Target == 0')[df.columns[2]],
color="b",
marker="o",
label="Target == 0")
# Targetの属性=1の散布図を描画
ax.scatter(df.query('Target == 1')[df.columns[0]],
df.query('Target == 1')[df.columns[1]],
df.query('Target == 1')[df.columns[2]],
color="r",
marker="^",
label="Target == 1")
# インデックス番号を描画
for i in df.index:
ax.text(df[df.columns[0]][i],df[df.columns[1]][i],df[df.columns[2]][i],str(i+1))
# 軸ラベルを表示
ax.set_xlabel(df.columns[0], fontsize=10) # x軸ラベル
ax.set_ylabel(df.columns[1], fontsize=10) # y軸ラベル
ax.set_zlabel(df.columns[2], fontsize=10) # z軸ラベル
plt.legend() # 凡例の追加
plt.show() # グラフを描画
コメント