sklearn - 决策树的基本使用
1 使用sklearn内置的数据集
sklearn
提供了sklearn.datasets
模块,用于加载和获取流行的参考数据集。
以接下来要使用的葡萄酒数据集(一个分类数据集)为例,简单介绍sklearn.datasets
模块的使用:
# API参考:sklearn.datasets.load_wine(*, return_X_y=False, as_frame=False)
wine= load_wine() # 加载并返回葡萄酒数据集
wine
在jupyter中直接输出wine会得到以下结果

看起来有点乱,这是因为load_wine()默认返回的是字典

如果想获取特征数据、类标签,可以用以下方法
wine.data # 获取data数据
wine.target # 获取target(即类标签)数据
要了解数据的数据列名称和target类的名称,可以使用feature_names
和 target_names
wine.feature_names # 获取数据列名称
wine.target_names # 获取target类名称

如果觉得这样看起来不够直观,可以改变load_wine()
的参数值,获取DataFrame类型的数据
wine=load_wine(as_frame=True) # as_frame=True是返回的数据是DataFrame
df=pd.concat([wine.data,wine.target],axis=1) # 将data(特征数据)和target(类标签)拼接在一起
df
2 决策树
这里不具体介绍决策树的算法原理,只讲在sklearn中是决策树模型的用法。
决策树(DTs)是一种用于[分类](http://scikit-learn.org.cn/view/89.html#1.10.1 分类)和[回归](http://scikit-learn.org.cn/view/89.html#1.10.2 回归)的非参数有监督学习方法。其目标是创建一个模型,通过学习从数据特性中推断出的简单决策规则来预测目标变量的值
2.1 导入相关的库、加载数据集
这里使用的是葡萄酒数据集
from sklearn import tree # sklearn中的树模型
from sklearn.datasets import load_wine # 葡萄酒数据集
from sklearn.model_selection import train_test_split # 用于划分训练集、测试集
# 1、加载葡萄酒数据集
wine=load_wine()
# 2、划分训练集、测试集
# test_size是测试集的比例,test_size=0.3 代表 训练集、测试集按照7:3的比例划分
Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)
2.2 构建模型
模型构建
# 3、构建模型
# sklearn中几乎所有模型的使用都是3步骤:创建模型对象、fit()、score()
# 创建模型对象
clf=tree.DecisionTreeClassifier(criterion="entropy"
,random_state=30
,splitter='random'
,max_depth=3
,min_samples_leaf=10
,min_samples_split=25
)
# 给模型填充数据:将训练集填充到模型中
clf=clf.fit(Xtrain,Ytrain)
# 在评估模型在测试集上的表现
# score()的返回值是给定测试数据和标签上的平均准确度。
score=clf.score(Xtest,Ytest)
score
tree.DecisionTreeClassifier
的参数非常多,但都有默认值,如果什么也不懂,使用默认值就好😥
class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort='deprecated', ccp_alpha=0.0)
下面简单介绍一下比较重要的几个参数(搬运:sklearn.tree.DecisionTreeClassifier-scikit-learn中文社区)
参数 | 说明 |
---|---|
criterion(n.标准;(评判或作决定的)准则;原则) | {“gini”, “entropy”}, default=”gini” 这个参数是用来选择使用何种方法度量树的切分质量的。当criterion取值为“gini”时采用 基尼不纯度(Gini impurity)算法构造决策树,当criterion取值为 “entropy” 时采用信息增益( information gain)算法构造决策树. |
splitter | {“best”, “random”}, default=”best” 此参数决定了在每个节点上拆分策略的选择。支持的策略是“best” 选择“最佳拆分策略”, “random” 选择“最佳随机拆分策略”。 |
random_state(没有特别的含义,随便填一个整数都行) | int, RandomState instance, default=None 此参数用来控制估计器的随机性。即使分割器设置为“最佳”,这些特征也总是在每个分割中随机排列。当 max_features <n_features 时,该算法将在每个拆分中随机选择max_features ,然后再在其中找到最佳拆分。但是,即使max_features = n_features ,找到的最佳分割也可能因不同的运行而有所不同。 就是这种情况,如果标准的改进对于几个拆分而言是相同的,并且必须随机选择一个拆分。 为了在拟合过程中获得确定性的行为,random_state 必须固定为整数。 |
max_depth | int, default=None 树的最大深度。如果取值为None,则将所有节点展开,直到所有的叶子都是纯净的或者直到所有叶子都包含少于min_samples_split个样本。 |
min_samples_split | int or float, default=2 拆分内部节点所需的最少样本数, 如果取值 int , 则将 min_samples_split 视为最小值。 如果为float,则min_samples_split 是一个分数,而ceil(min_samples_split * n_samples) 是每个拆分的最小样本数。 |
min_samples_leaf | int or float, default=1 在叶节点处所需的最小样本数。 仅在任何深度的分裂点在左分支和右分支中的每个分支上至少留有 min_samples_leaf 个训练样本时,才考虑。 这可能具有平滑模型的效果,尤其是在回归中。 如果为int,则将min_samples_leaf 视为最小值。 如果为float,则min_samples_leaf 是一个分数,而ceil(min_samples_leaf * n_samples) 是每个节点的最小样本数。 |
max_features | int, float or {“auto”, “sqrt”, “log2”}, default=None 寻找最佳分割时要考虑的特征数量: |
min_impurity_decrease | float, default=0.0 如果节点分裂会导致不纯度的减少大于或等于该值,则该节点将被分裂。 |
如果想看看模型用到了那些特征,可以这样做
[*zip(wine.feature_names,clf.feature_importances_)]
可以看出,有些特征是没有用到的(值为0)

确定最佳参数
既然决策树的参数这么多,那如何确定最佳的参数呢?
答案就是利用循环一个个去试😏
import matplotlib.pyplot as plt
test=[]
# 这里循环改变的是max_depth参数的值,确定最佳的max_depth值
for i in range(10):
clf=tree.DecisionTreeClassifier(max_depth=i+1
,criterion="entropy"
,random_state=30
,splitter="random"
)
clf=clf.fit(Xtrain,Ytrain)
score=clf.score(Xtest,Ytest)
test.append(score)
# 画图
# 横坐标是max_depth,纵坐标是模型的分数,
plt.plot(range(1,11),test,color="blue",label="max_depth")
plt.legend()
plt.show()

2.3 可视化
决策树的一个特点是可视化,让人可以直观地看到树是什么样子的。
这里使用graphviz
模块画决策树(注意:graphviz是一个软件,不是简单地pip install graphviz
就能安装的,需要去官网下载,然后设置环境变量:具体操作参考 解决failed to execute [‘dot’, ‘-Tsvg’], make sure the Graphviz executables are on your systems)
import graphviz # 用于画决策树的模块
# filled 代表是否给树的结点填充颜色
# rounded 代表结点是否为圆角矩形
dot_data=tree.export_graphviz(clf
,feature_names=wine.feature_names
,class_names=["红酒","白酒","啤酒"]
,filled=True
,rounded=True
)
graph=graphviz.Source(dot_data)
graph
小技巧:在逗号写在参数名前面,例如
,filled=True
,如果不想要这个参数,可以直接把这一行注释掉
一颗决策树就这样画好了
参考: