手把手从0实现一颗决策树

手把手从0实现一颗决策树

这是我2019-09-24在知乎上面写的文章,原文跳转点这里,今天面试java不知道为什么面试官感兴趣被提起,提起于是搬运到博客上,纯python实现,除numpy外无第三方库

前段时间数学建模我和我的小伙伴们划了个水,几个没有建模经验的小伙伴临时抱佛脚最后通了个宵,在此感谢我的队友们。抱佛脚的时候笔者对决策树起了兴趣(可惜没有用上),之前看过一点点理论相关的知识,但是没有自己实现过,这几天空下来决定自己从0撸一颗简单的ID3决策树。

由于我也是个新手,因此本篇文章深度有限,大神请绕道^_^,我将着重于代码的实现和整个流程的说明,通过一个案例(案例取自https://blog.csdn.net/mn_kw/article/details/79913786,本文就是对这篇博客中的介绍进行代码实现),带你一步步看懂数据在决策树中的行为,减少理论部分的晦涩知识。快来动手跟我一起实现一颗决策树吧!

一.什么是决策树

现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:
女儿:多大年纪了?
母亲:26。
女儿:长的帅不帅?
母亲:挺帅的。
女儿:收入高不?
母亲:不算很高,中等情况。
女儿:是公务员不?
母亲:是,在税务局上班呢。
女儿:那好,我去见见。
这个女孩的决策过程就是典型的分类树决策。
(声明:此决策树纯属为了写文章而YY的产物,没有任何根据,也不代表任何女孩的择偶倾向,请各位女同胞莫质问我^_^)
例子取自https://blog.csdn.net/mn_kw/article/details/79913786

假设现在有这么一张表,统计了一些人的择偶标准,称为训练样本

img

我们将要对其构建类似如下的判断分支结构,这就叫决策树。之后再有一组样本,我们就可以根据这棵树来判定嫁不嫁啦!

img

看到这里你肯定会有个疑问:为什么先判断高不高?难道真是传说中的一高遮百丑吗?

img

当然不是,我们可是有依据的!这个依据就叫做信息增益(Gain),公式如下,就是两个函数相减而已,看不懂也没关系,下面有例子说明。

img)H(x)叫香农熵,H(D|x)叫条件熵<br>就两个函数相减而已

其中H(x)和H(Y|X)长下面这样

img)p(x)是概率,比如矮占身高的

img里面的H

如本例中:
一开始嫁的个数为6个,占1/2,那么信息熵H(X)为-1/2log1/2-1/2log1/2 = -log1/2=1
现在假如我知道了一个男生的身高信息。
身高有三个可能的取值{矮,中,高}
矮包括{1,2,3,5,6,11,12},嫁的个数为1个,不嫁的个数为6个
中包括{8,9} ,嫁的个数为2个,不嫁的个数为0个
高包括{4,7,10},嫁的个数为3个,不嫁的个数为0个
先回忆一下条件熵的公式如下:
我们先求出公式对应的:
H(Y|X = 矮) = -1/7log1/7-6/7log6/7=0.592
H(Y|X=中) = -1log1-0 = 0
H(Y|X=高) = -1log1-0=0
p(X = 矮) = 7/12,p(X =中) = 2/12,p(X=高) = 3/12
则可以得出条件熵为:
7/120.592+2/120+3/12*0 = 0.345
那么我们知道信息熵与条件熵相减就是我们的信息增益,为
1-0.345=0.655
所以我们可以得出我们在知道了身高这个信息之后,信息增益是0.655

结论
我们可以知道,本来如果我对一个男生什么都不知道的话,作为他的女朋友决定是否嫁给他的不确定性有1.0这么大。
当我们知道男朋友的身高信息后,不确定度减少了0.655.也就是说,身高这个特征对于我们广大女生同学来说,决定嫁不嫁给自己的男朋友是很重要的。
至少我们知道了身高特征后,我们原来没有底的心里(1.0)已经明朗一半多了,减少0.655了(大于原来的一半了)。
那么这就类似于非诚勿扰节目里面的桥段了,请问女嘉宾,你只能知道男生的一个特征。请问你想知道哪个特征。
假如其它特征我也全算了,信息增益是身高这个特征最大。那么我就可以说,孟非哥哥,我想知道男嘉宾的一个特征是身高特征。因为它在这些特征中,对于我挑夫君是最重要的,信息增益是最大的,知道了这个特征,嫁与不嫁的不确定度减少的是最多的。
取自https://blog.csdn.net/mn_kw/article/details/79913786


二.代码实现

1.咱们先来构建我们的数据,我用的是python3(day1)

img

表头

1
datas_header = ['是否帅', '脾气是否好', '是否高', '是否上进', '结果']

数据

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
datas = np.array([['帅', '不好', '矮', '不上进'],
['不帅', '好', '矮', '上进'],
['帅', '好', '矮', '上进'],
['不帅', '爆好', '高', '上进'],
['帅', '不好', '矮', '上进'],
['帅', '不好', '矮', '上进'],
['帅', '好', '高', '不上进'],
['不帅', '好', '中', '上进'],
['帅', '爆好', '中', '上进'],
['不帅', '不好', '高', '上进'],
['帅', '好', '矮', '不上进'],
['帅', '好', '矮', '不上进']])

标签

1
labels = np.array(['不嫁', '不嫁', '嫁', '嫁', '不嫁', '不嫁', '嫁', '嫁', '嫁', '嫁', '不嫁', '不嫁'])

2.明确代码流程,自顶向下设计,自底向上实现

img

3.实现香农熵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def get_shannon_entropy(labels: np.ndarray) -> float:
"""
计算香农熵
即H(X)=−∑(i=1,n)pi * log2pi
:param labels: shape = (n, ),n个标签
:return: float香农熵值
"""
num_labels = len(labels)
labels_dic = {} # 考虑到labels不一定是数字,我们用字典来装
shannon_entropy = 0
for label in labels: # 统计各label个数,本例中最后labels_dic格式为{"嫁":6, "不嫁":6}
if label not in labels_dic.keys():
labels_dic[label] = 0
labels_dic[label] += 1

for i in labels_dic.keys():
prob = labels_dic[i] / num_labels
shannon_entropy += - prob * np.log2(prob)

return shannon_entropy

img

4.实现条件熵(day2)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def get_conditional_entropy(datas: np.ndarray, labels: np.ndarray) -> float:
"""
接下来的条件熵
即H(Y|X)=∑(i=1,n)piH(Y|X=xi) H就是上面的香农熵呀! Y就是嫁不嫁呀! X就是当前的特征呀!
本例中其中一个特征举例:高中矮,则表现为
H(Y|X) = p矮 * H(嫁|X = 矮) + p中 * H(嫁|X = 中) + p高 * H(嫁|X = 高)
:param datas:某个特征下所有人的取值,例如['矮', '矮', '矮', '高', '矮', '矮', '高', '中', '中', '高', '矮', '矮']
:param labels:该特征下嫁或不嫁的结果,例如['不嫁', '不嫁', '嫁', '嫁', '不嫁', '不嫁', '嫁', '嫁', '嫁', '嫁', '不嫁', '不嫁']
:return:
"""
num_data = len(datas)
conditional_entropy = 0
num_class = {} # 先得知道身高特征分成了几类(高中矮3类),每类占多少
for class_val in datas: # for循环完后结果为num_class = {"矮": 7, "中": 2, "高": 3}
if class_val not in num_class.keys():
num_class[class_val] = 0
num_class[class_val] += 1

for i in num_class.keys():
prob = num_class[i] / num_data
index = np.argwhere(datas == i).squeeze() # 以下两句获得datas中'矮'的下标并找出每个'矮'对应的labels值
# print(index.shape)
if index.shape != (): # 这里如果不这么处理会出问题,就是index只有一个数的情况,index = 2,type显示ndarray没有问题,shape却显示()而不是(1,)
i_labels = labels[index] # 造成这句的结果居然是i_labels = 'no' 而不是['no'],于是在后面遍历的时候变成了'n'和'o',这不是我们想要的
else:
i_labels = list([]) # 通过先限定i_labels的类型,使用append的方式添加,这时候i_labels就会是想要的['no'],后面遍历的时候就是遍历数组中的这一个元素
i_labels.append(labels[index])
conditional_entropy += prob * get_shannon_entropy(i_labels)
return conditional_entropy

img

5.获得最佳信息增益gain值和对应特征

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def get_best_gain(datas: np.ndarray, labels: np.ndarray) -> (float, float):
"""
计算信息增益,并找出最佳的信息增益值,作为当前最佳的划分依据
即g(D,X)=H(D)−H(D|X) H就是香农熵呀! H(D|X)就是信息熵呀!
best_gain = max(g(D,X))
:param datas: shape(n, m) n条数据,m种特征
:param labels: shape = (n, ),n个标签
:return:返回最佳特征和对应的gain值
"""
best_gain = 0
best_feature = -1
(num_data, num_feature) = datas.shape
for feature in range(num_feature): #遍历每个特征,找出到底谁的信息增益值最大
current_data = datas[:, feature]
gain = get_shannon_entropy(labels) - get_conditional_entropy(current_data, labels) # 公式
if gain > best_gain:
best_gain = gain
best_feature = feature
return best_feature, best_gain

img

6.建树(day3)

由于要建树,树型结构就需要建立起来,本来也打算存在数据结构中,但是就在这时候看到了csdn上一篇博客里的树是这样用字典来保存

img博客地址https://blog.csdn.net/hongbin_xu/article/details/78516114

我的天!!!惊到我了,学!用字典建树

输入数据变化过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def create_tree(datas_header: list, datas: np.ndarray, labels: np.ndarray) -> dict:
"""
get_best_gain()是给一组datas,labels计算一次最佳划分特征
要构建一棵决策树,那必须要在划分完剩下的数据集继续划分(递归过程),直到以下情况出现:
1.剩下全部结果都是相同的,那么直接作为结果。
如在本例中,假如最佳分组为帅或者不帅,帅对应的labels全部为嫁,则不用继续讨论后面的分组。同理,如果不帅中既有嫁和不嫁,那么需要继续递归
2.遍历完了所有特征,但是还是无法得到唯一标签,则少数服从多数。
假如在本例中,遍历到最后一组特征:是否上进,但上进的组里还是有嫁或者不嫁两种标签,且较多为嫁,那就让她嫁
:param datas_header:
:param datas:
:param labels:
:return:
"""
# 结束条件1
if list(labels).count(labels[0]) == len(labels):
return labels[0]
# 结束条件2
if len(datas) == 0 and len(set(labels)) > 1:
result_num = {}
for result in labels:
if result not in result_num.keys():
result_num[result] = 0
result_num[result] += 1
more = -1
decide = ''
for result, num in result_num.items():
if num > more:
more = num
decide = result
return decide

cur_best_feature_num, _ = get_best_gain(datas, labels)
cur_best_feature_name = datas_header[cur_best_feature_num]

# 首先知道该特征下有什么值 本例中class_val = {'帅', '不帅'}
class_val = set([data[cur_best_feature_num] for data in datas])
trees = {cur_best_feature_name: {}}
for val in class_val: # 逐一找出每个特征值的数据 本例中表现为含'帅'/'不帅'的数据
new_datas = [datas[i] for i in range(len(datas)) if datas[i, cur_best_feature_num] == val] # 用列表生成式,读作:遍历datas每行,找到每行的'是否帅'特征下值为'帅'的行,返回该行
new_labels = [labels[i] for i in range(len(datas)) if datas[i, cur_best_feature_num] == val]

new_datas = np.delete(new_datas, cur_best_feature_num, axis=1) # 删除最佳列,准备进入下一个划分依据,即删掉
new_datas_header = np.delete(datas_header, cur_best_feature_num)
# 递归:去掉该行该列再丢进
trees[cur_best_feature_name][val] = create_tree(list(new_datas_header), new_datas, np.array(new_labels))
return trees

img

7.预测,树建完了,预测时我们就要递归遍历字典(day4)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def predict_result(trees_model: dict, input_data: np, datas_header: list) -> str:
"""
1.找字典中的第一个划分特征 本例中是'是否高'
2.在datas_header中找到高是第几个特征 本例中是第2个特征
3.在input_data中找到这个特征对应的值 比如要预测的数据中是否高(第二个特征)取值为矮
4.找到字典中的'矮'的值,如果为str(嫁不嫁),则直接返回结果,如果是字典,则进行下一个节点预测(递归)
:param trees_model:
:param input_data:
:param datas_header:
:return:
"""
cur_judge = list(trees_model.keys())[0] # '是否高'
num_feature = datas_header.index(cur_judge) # 本例中是第2个特征
cur_val = input_data[num_feature] # 比如要预测的数据中是否高(第二个特征)取值为矮
cur_tree = trees_model[cur_judge]
# print(type(cur_tree[cur_val]))
if type(cur_tree[cur_val]) == np.str_:
return cur_tree[cur_val]
return predict_result(cur_tree[cur_val], input_data, datas_header)

img

8.保存和读取模型文件

1
2
3
4
5
6
7
8
9
10
11
12
# 保存和读取函数
def store_tree(input_tree, filename):
import pickle
with open(filename, 'wb') as f:
pickle.dump(input_tree, f)
f.close()


def restore_tree(filename):
import pickle
with open(filename, 'rb') as f:
return pickle.load(f)

通过这棵树,我们就知道,帅与否并不是影响你勾搭小姐姐主要的因素,这给了作者足够的自信,好了,我去勾搭小姐姐啦!

img

代码已上传github,地址为https://github.com/wjwABCDEFG/wjw-DecisionTree

百度网盘地址为https://pan.baidu.com/s/18O1OpU1hPfHdoAh_qVF1-Q提取码7qdv

零零散散耗时4天,写文章又需要琢磨,感觉说的还不是很清楚,有不懂的可以交流,打算以后有空就实现一些简单的机器学习模型,代码菜鸟,风格不太好,多多包涵,如有错误,欢迎指正。

#
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×