这两天在machinelearningmastery.com上学习Python实现CART(Classify And Regression Tree),把分类树从头到尾学习实现了一遍,虽然不是什么难事,还是想记录一下,就当增强增强记忆也好。
分类树逻辑上即为一些连环判断的组合,以Binary Tree的结构承载这个流程,以存在于非叶节点的数据的属性+值为判断条件,以存于各叶节点的值为判断结果。下图即为一个简单的决策树逻辑(图片来源:machinelearningmastery.com)。
算法实现分一下几个部分: 1、Gini函数 2、树内各节点的分割 3、树的建立 4、预测结果
Gini函数: Gini指数作为loss function用来衡量分组后数据“纯净性”(原文用的purity)的指标,判断数据正确分类的程度:
树内各节点的分割: 要将数据分类,首先要知道根据什么指标进行分类,即对于每一步的判断条件,应当找出最适合分类的属性及该属性下最适合的值——get_split()将来自要分割节点的所有数据的所有属性和所有值进行遍历分割,分别计算各拟分组的Gini指数,取能获得最小Gini指数的分割方式对该节点进行分割;test_split()即为每次遍历是根据给定的属性和值对数据集进行分割。
#Split dataset into groups by specific attribute and valuedef test_split(dataset,index,value): left, right = [], [] for row in dataset: if row[index] < value: left.append(row) else: right.append(row) return left, right#Split dataset into groups for every splitable node def get_split(dataset): dimen = len(dataset[0])-1 b_index, b_value, b_gini, b_group = 999, 999, 999, None class_values = list(set([row[-1] for row in dataset])) for index in range(dimen): for row in dataset: group = test_split(dataset,index,row[index]) gini = gini_index(group,class_values) if gini < b_gini: b_index, b_value, b_gini, b_group = index, row[index], gini, group return {'index':b_index, 'value':b_value, 'gini':b_gini, 'groups':b_group}树的建立: 在知道如何对每个节点进行合适分割之后,就要开始用递归的方式调用split()函数不断分割节点来建立整棵树。
考虑递归中的基本情况和需要递归的情况:
1、基本情况(节点分割结束,变为叶节点):分割后的节点没有左子节点或右子节点,;当本次分割后树的深度超出最大深度(max_depth,给定);当本次分割后子节点的数据量小于最小分类后数据量(min_size,给定)或子节点已经被完全正确分类(节点内的所有数据为同一类)。
2、需要递归的情况(子节点继续作为下一个父节点调用分割函数)。
#Make a node a terminal def to_terminal(group): result = [row[-1] for row in group] return max(set(result),key=result.count)#Split the whole tree by iterationdef split(node,max_depth,min_size,depth): left, right = node['groups'] del(node['groups']) #check for no left or right if not left or not right: node['left'] = node['right'] = to_terminal(left + right) return #check for max depth if depth >= max_depth: node['left'], node['right'] = to_terminal(left), to_terminal(right) return #process the left if len(left) <= min_size or len(set(row[-1] for row in left)) <= 1: #check for min size and already splited correctly node['left'] = to_terminal(left) else: node['left'] = get_split(left) split(node['left'],max_depth,min_size,depth+1) #process the right if len(right) <= min_size or len(set(row[-1] for row in right)) <= 1: #check for min size and already splited correctly node['right'] = to_terminal(right) else: node['right'] = get_split(right) split(node['right'],max_depth,min_size,depth+1)#Build a whole decision treedef build_tree(dataset,max_depth,min_size): root = get_split(dataset) split(root,max_depth,min_size,1) return root其中的to_terminal()函数实现将该节点变为叶节点,逻辑为以该节点数据中最大比例的该类作为叶节点的值。build_tree()为封装的建树函数,返回树的根节点。
预测结果: 在训练数据建好决策树之后,对测试数据利用决策树进行预测分类,逻辑即为利用存储在各非叶结点中的一系列判断条件进行从根节点到叶节点的预测:
#Predict the results of a set of data by trained decision treedef predict(node,row): if row[node['index']] < node['value']: if isinstance(node['left'],dict): return predict(node['left'],row) else: return node['left'] else: if isinstance(node['right'],dict): return predict(node['right'],row) else: return node['right']最后将包括训练和预测的函数全部封装到一个decision_tree()函数中,实现算法。
#The (Classify) Decision Tree Algorithmdef decision_tree(train_data,test_data,max_depth,min_size): tree_root = build_tree(train_data,max_depth,min_size) predicted = [] for row in test_data: predicted.append(predict(tree_root,row)) return predicted学习与代码参考:machinelearningmastery.com
新闻热点
疑难解答