sklearn决策树可视化
过去,关于sklearn决策树可视化的教程大部分都是基于Graphviz(一个图形可视化软件)的。Graphviz的安装比较麻烦,并不是通过pip install就能搞定的,因为要安装底层的依赖库。现在,自版本0.21以后,scikit-learn也自带可视化工具了,它就是sklearn.tree.plot_tree()假设决策树模型(clf)已经训练好了,画图的代码如下:def tree1(cl
·
过去,关于sklearn决策树可视化的教程大部分都是基于Graphviz(一个图形可视化软件)的。
Graphviz的安装比较麻烦,并不是通过pip install就能搞定的,因为要安装底层的依赖库。
现在,自版本0.21以后,scikit-learn也自带可视化工具了,它就是sklearn.tree.plot_tree()
假设决策树模型(clf)已经训练好了,画图的代码如下:
def tree1(clf):
fig = plt.figure()
tree.plot_tree(clf)
fig.savefig(os.path.join(fig_dir, "tree1.png"))
没有设置图像的相关参数,画出的树结构看不清树节点的信息。
设置字体大小,把文字调大一点:
def tree2(clf):
fig = plt.figure()
tree.plot_tree(clf, fontsize=8)
fig.savefig(os.path.join(fig_dir, "tree2.png"))
文字是放大了,树节点也随着增大了,但是画面很拥挤。
那把画布调大一点:
def tree3(clf):
fig = plt.figure(figsize=(35, 10))
tree.plot_tree(clf, fontsize=8)
fig.savefig(os.path.join(fig_dir, "tree3.png"))
大功告成!
下面的代码包含数据读取、模型训练和画图,有注释,就不展开了。
关注【小猫AI】公众号,回复tree
可以获取训练模型的数据哦。
# -*- coding: utf-8 -*-
"""
Description : sklearn决策树可视化(scikit-learn==0.24.2)。
Authors : wapping
CreateDate : 2022/2/7
"""
import os
import pandas as pd
from sklearn import tree
from matplotlib import pyplot as plt
def read_data(fp):
"""加载训练数据。"""
data = pd.read_csv(fp, header=None)
x = data[[0, 1]] # 第0,1列为特征
y = data[[2]] # 第2列为标签
return x, y
def tree1(clf):
# 没有设置图像的相关参数,画出的树结构看不清树节点的信息
fig = plt.figure()
tree.plot_tree(clf)
fig.savefig(os.path.join(fig_dir, "tree1.png"))
def tree2(clf):
# 设置字体大小,树节点放大了,但是很拥挤
fig = plt.figure()
tree.plot_tree(clf, fontsize=8)
fig.savefig(os.path.join(fig_dir, "tree2.png"))
def tree3(clf):
# 同时设置字体大小和图像的大小,树结构正常显示
fig = plt.figure(figsize=(35, 10))
tree.plot_tree(clf, fontsize=8)
fig.savefig(os.path.join(fig_dir, "tree3.png"))
if __name__ == '__main__':
fig_dir = "data/plot_tree" # 保存图片的目录
data_path = "data/plot_tree_data.csv" # 训练树模型的数据
os.makedirs(fig_dir, exist_ok=True)
# 读取训练数据
x, y = read_data(data_path)
# 训练决策树分类器
clf = tree.DecisionTreeClassifier(min_samples_leaf=100, random_state=666)
clf = clf.fit(x, y)
# 画树结构并保存图片
tree1(clf)
tree2(clf)
tree3(clf)
更多推荐
已为社区贡献1条内容
所有评论(0)