分类算法之决策树:例子讲解+实战案例(附源码)
????点击关注|设为星标|干货速递????目录介绍1. 从一个例子(贷款前,评估个人能否偿还)出发,怎样决策。2. 分析算法原理、思路形成的过程。3. 扩展决策树衍生的高级版本进行简要的介绍。原理...
👆点击关注|设为星标|干货速递👆
目录介绍
1. 从一个例子(贷款前,评估个人能否偿还)出发,怎样决策。
2. 分析算法原理、思路形成的过程。
3. 扩展决策树衍生的高级版本进行简要的介绍。
原理
在已知的条件中,选取一个条件作为树根,然后再看是否还需要其他判断条件。
如果需要的话,继续构建一个分支来判断第二个条件,以此类推。
最终形成的这颗树上,所有的叶子节点都是要输出的类别信息,所有的非叶子节点都是特征信息。
理想情况:
决策树上的每一个叶子节点都是一个纯粹的分类。
实际:
决策树实现的时候采用贪心算法,来寻找一个最近的最优解。
几个版本的决策树的比较:
例子
例子:贷款前,评估个人能否偿还???
(图片来源于网络)
比如要参考的有:1. 年收入情况、2. 是否房产、3. 婚姻情况,根据这三个决策条件进行决定。
优缺点
优点:
1. 非常直观,可解释性极强。
2. 预测速度比较快(由条件判断即可)。
3. 既可以处理离散值也可以出来连续值,还可以处理缺失值。
缺点:
1. 容易过拟合。
2. 需要处理样本不均衡的问题。
3. 样本的变化会引发树结构巨变。
关于剪枝:
剪枝方法:1. 预剪枝和2. 后剪枝。
目的:去掉不必要的节点路径,防止过拟合。
1. 预剪枝:在决策树构建之初就设定一个阈值,当分裂节点的熵阈值小于设定值的时候就不再进行分裂了。
2. 后剪枝:在决策树已经构建完成后,再根据设定的条件来判断是否要合并一些中间节点,使用叶子节点来代替。
备注:通常都是采用后剪枝。
实战代码
咱们以鸢尾花数据集进行实战,首先进行导入sklearn库,以及加载数据集
# sklearn 数据集
from sklearn import datasets
# 引入决策树算法包
from sklearn.tree import DecisionTreeClassifier
# 矩阵运算库numpy
import numpy as np
# 设置随机种子,可以保证每次产生的随机数是一样的
np.random.seed(0)
# 获取鸢尾花数据集
iris = datasets.load_iris()
iris_x = iris.data # 数据部分
iris_y = iris.target # 类别部分
设置数的最大深度为4
# 从150条数据中选择140条作为训练集,10条作为测试集。
# permutation接收一个数作为参数(这里为数据集长度150)。产生一个0-149乱序一维数组
randomarr = np.random.permutation(len(iris_x))
iris_x_train = iris_x[randomarr[:-10]] # 训练集数据
iris_y_train = iris_y[randomarr[:-10]] # 训练集标签
iris_x_test = iris_x[randomarr[-10:]] # 测试集数据
iris_y_test = iris_y[randomarr[-10:]] # 训练集标签
clf = DecisionTreeClassifier(max_depth=4)
# 调用该对象的训练方法,主要接收两个参数:训练数据集及其类别标签
clf.fit(iris_x_train,iris_y_train)
这样决策树算法会生成一个树形的判定模型。
展示决策树算法生成的模型
# 引入画图相关的包
from IPython.display import Image
from sklearn import tree
# dot是一个程序化生成流程图的简单语言
import pydotplus
dot_data = tree.export_graphviz(
clf,
out_file=None,
feature_names = iris.feature_names,
class_names = iris.target_names,
filled = True,
rounded = True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
决策树模型图
经过运行上面的画图代码,生成了该图。
可以看到每一次的判定条件以及基尼系数,还有能够落入此决策的样本数量和分类的类别。
使用模型对测试数据进行测试:
# 调用预测方法,主要接收一个参数:测试数据集
iris_y_predict = clf.predict(iris_x_test)
# 计算各预测样本预测的概率值
probility = clf.predict_proba(iris_x_test)
# 计算出准确率
score = clf.score(iris_x_test,iris_y_test,sample_weight=None)
# 输出预测结果
print("iris_y_predict=",iris_y_predict)
# 输出原始结果
print("iris_y_test=",iris_y_test)
# 输出准确率
print("Accuracy:",score)
可以看到第二个测试样本预测错误了,其他的都预测准确了,准确率在90%。
扩展内容
随机森林:使用bagging方案构建了多棵决策树,然后对所有的决策树结果进行平均计算以获得最终结果。
GBDT:GBDT构建的构建的多棵树之间是有联系的,每个分类器在上一轮分类器的残差基础上进行训练。
XGBoost:优化了GBDl里面的求解过程,并加入了很多工程上的优化项目。
END
创作不易,路过的小伙伴右下角点赞和再看,鼓励一下
觉得不错的话,可以分享给其他小伙伴
干货文章推荐:
2. 参考ggplot2,Seaborn将迎来超大版本更新!
分享
收藏
点赞
在看
更多推荐
所有评论(0)