pythonで3-D散布図

matplotlibを利用して3-D散布図を描画する方法についてまとめています。

単純に3-D散布図を描画

import pandas as pd 
%matplotlib notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.datasets import load_breast_cancer

# サンプルデータセットを読み込み
cancer = load_breast_cancer()

# データフレーム形式に変更
X = pd.DataFrame(cancer.data, columns=cancer.feature_names) # 特徴量
y = pd.DataFrame(cancer.target, columns=["Target"]) # ターゲット
df = pd.concat([X, y], axis=1)    # 特徴量とターゲットの水平結合

# 3-D散布図の描画描画領域を作成
fig = plt.figure()
ax = Axes3D(fig)

# Targetの属性=0の散布図を描画
ax.scatter(df[df.columns[0]],
           df[df.columns[1]],
           df[df.columns[2]])

# 軸ラベルを表示
ax.set_xlabel(df.columns[0], fontsize=10)    # x軸ラベル
ax.set_ylabel(df.columns[1], fontsize=10)    # y軸ラベル
ax.set_zlabel(df.columns[2], fontsize=10)    # z軸ラベル

plt.show() # グラフを描画

ターゲットの属性ごとにマーカー形状・色を変更して描画

# 3-D散布図の描画描画領域を作成
fig = plt.figure()
ax = Axes3D(fig)

# Targetの属性=0の散布図を描画
ax.scatter(df.query('Target == 0')[df.columns[0]],
           df.query('Target == 0')[df.columns[1]],
           df.query('Target == 0')[df.columns[2]],
           color="b",
           marker="o",
           label="Target == 0")

# Targetの属性=1の散布図を描画
ax.scatter(df.query('Target == 1')[df.columns[0]],
           df.query('Target == 1')[df.columns[1]],
           df.query('Target == 1')[df.columns[2]],
           color="r",
           marker="^",
           label="Target == 1")

# 軸ラベルを表示
ax.set_xlabel(df.columns[0], fontsize=10)    # x軸ラベル
ax.set_ylabel(df.columns[1], fontsize=10)    # y軸ラベル
ax.set_zlabel(df.columns[2], fontsize=10)    # z軸ラベル

plt.legend() # 凡例の追加
plt.show() # グラフを描画

各マーカー上にインデックス番号を描画

# 3-D散布図の描画描画領域を作成
fig = plt.figure()
ax = Axes3D(fig)

# Targetの属性=0の散布図を描画
ax.scatter(df.query('Target == 0')[df.columns[0]],
           df.query('Target == 0')[df.columns[1]],
           df.query('Target == 0')[df.columns[2]],
           color="b",
           marker="o",
           label="Target == 0")

# Targetの属性=1の散布図を描画
ax.scatter(df.query('Target == 1')[df.columns[0]],
           df.query('Target == 1')[df.columns[1]],
           df.query('Target == 1')[df.columns[2]],
           color="r",
           marker="^",
           label="Target == 1")

# インデックス番号を描画
for i in df.index:
    ax.text(df[df.columns[0]][i],df[df.columns[1]][i],df[df.columns[2]][i],str(i+1))

# 軸ラベルを表示
ax.set_xlabel(df.columns[0], fontsize=10)    # x軸ラベル
ax.set_ylabel(df.columns[1], fontsize=10)    # y軸ラベル
ax.set_zlabel(df.columns[2], fontsize=10)    # z軸ラベル

plt.legend() # 凡例の追加
plt.show() # グラフを描画

コメント