A decision tree algorithm (ID3) was introduced in a previous article. A brief review: at every step, ID3 selects the best feature to split the data, and "best" is judged by information gain. Once the data has been split on a feature, that feature is never used again in later splits, which can make ID3 split the data too quickly. ID3 also cannot handle continuous features. Here is another algorithm that addresses these problems:
CART classification and regression trees
CART, short for Classification And Regression Trees, can handle both classification and regression tasks.
Building a CART tree is similar to building an ID3 decision tree, so we go straight to the construction process. First, as with ID3, each node of the tree is a small dictionary-like structure containing the following four elements (a minimal sketch of such a node follows the list):
- The feature to split on
- The feature value to split on
- The right subtree (it can also be a single value when no further splitting is needed)
- The left subtree, same as the right subtree
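This node structure appears as the Tree class in the code later in this article. Here is a minimal sketch matching the fields that code uses (col, value, trueBranch, falseBranch for the four elements above, plus results, data and summary, which the leaf and pruning code relies on); the class in the full source may differ slightly:

class Tree:
    # A node of the CART tree: an internal node stores the split (col, value)
    # and its two branches; a leaf stores the label counts (results) and the
    # rows that reached it (data).
    def __init__(self, col=-1, value=None, trueBranch=None, falseBranch=None,
                 results=None, data=None, summary=None):
        self.col = col                  # index of the feature to split on
        self.value = value              # feature value to split on
        self.trueBranch = trueBranch    # right ("true") subtree
        self.falseBranch = falseBranch  # left ("false") subtree
        self.results = results          # label counts, only set on leaves
        self.data = data                # rows stored at a leaf (used by pruning)
        self.summary = summary          # impurity / sample-count info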
The process is as follows:
- Find the best feature and value to split on
- If the data set cannot be split further, turn it into a leaf node
- Otherwise, split the data set into two parts
- Repeat steps 1, 2, 3 on the first subset to create the right subtree
- Repeat steps 1, 2, 3 on the second subset to create the left subtree

This is clearly a recursive algorithm.
Splitting the data set: filter the rows by a feature value and return two subsets.
def splitDatas(rows, value, column):
    # Split rows into two lists on the given column and return (list1, list2).
    # Numeric values split on >= value; categorical values split on == value.
    list1 = []
    list2 = []
    if isinstance(value, int) or isinstance(value, float):
        for row in rows:
            if row[column] >= value:
                list1.append(row)
            else:
                list2.append(row)
    else:
        for row in rows:
            if row[column] == value:
                list1.append(row)
            else:
                list2.append(row)
    return list1, list2
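For example, with a few hypothetical rows (label in the last column), a numeric split on column 0 works like this:

rows = [[6.0, 'A', 'yes'],
        [2.5, 'B', 'no'],
        [7.1, 'A', 'yes']]
ge, lt = splitDatas(rows, 5, 0)
print(ge)   # [[6.0, 'A', 'yes'], [7.1, 'A', 'yes']]  (column 0 >= 5)
print(lt)   # [[2.5, 'B', 'no']]                      (column 0 < 5)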
Dividing data points
Creating a binary decision tree is essentially a recursive partitioning of the input space.
The splitting criterion is the Gini impurity, 1 - Σ pᵢ², where pᵢ is the fraction of rows with label i. The code is as follows:
# gini()
def gini(rows):
    # Calculate the Gini impurity of a set of rows
    length = len(rows)
    results = calculateDiffCount(rows)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp
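The helper calculateDiffCount is not shown in this excerpt; it counts how many rows carry each label. A minimal sketch, assuming the label is the last element of each row (the version in the full source may differ):

def calculateDiffCount(rows):
    # Count the occurrences of each label; the label is assumed to be
    # the last element of every row.
    results = {}
    for row in rows:
        label = row[-1]
        results[label] = results.get(label, 0) + 1
    return results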
Build a tree
def buildDecisionTree(rows, evaluationFunction=gini):
    # Build the decision tree recursively and return it;
    # stop recursing when the best gain is 0.
    currentGain = evaluationFunction(rows)
    column_length = len(rows[0])
    rows_length = len(rows)
    best_gain = 0.0
    best_value = None
    best_set = None

    # choose the split with the best gain
    for col in range(column_length - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            p = len(list1) / rows_length
            gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)

    dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}

    # stop or keep splitting
    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value=best_value[1], trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY)
    else:
        return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)
The code above first finds the best split of the data set: for every candidate feature and value it computes the gain, i.e. the reduction in Gini impurity, currentGain - p * gini(list1) - (1 - p) * gini(list2), where p is the fraction of rows that fall into the first subset. It then splits the data on the best candidate and recursively builds the whole tree.
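A tiny worked example of this gain computation, using hypothetical rows (the label is the last element of each row):

rows = [[1, 'yes'], [2, 'yes'], [8, 'no'], [9, 'no']]
# gini(rows) = 1 - (0.5**2 + 0.5**2) = 0.5
# Splitting on column 0 with value 8 (row[0] >= 8 vs. row[0] < 8)
# gives two pure subsets, each with gini = 0, so the gain is
# 0.5 - 0.5 * 0 - 0.5 * 0 = 0.5 -- the best possible split here.
list1, list2 = splitDatas(rows, 8, 0)
print(gini(rows), gini(list1), gini(list2))   # 0.5 0.0 0.0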
Pruning
When learning a decision tree, the tree sometimes grows too many branches, and some of them need to be removed to reduce over-fitting. Controlling the complexity of the decision tree in order to avoid overfitting is called pruning. Post-pruning first grows a complete decision tree from the training set and then examines the non-leaf nodes bottom-up, using a test set to decide whether the subtree rooted at a node should be replaced by a leaf node. The code is as follows:
def prune(tree, miniGain, evaluationFunction=gini):
    # When the gain of a node is smaller than miniGain,
    # merge its trueBranch and falseBranch into a single leaf.
    if tree.trueBranch.results is None:
        prune(tree.trueBranch, miniGain, evaluationFunction)
    if tree.falseBranch.results is None:
        prune(tree.falseBranch, miniGain, evaluationFunction)

    if tree.trueBranch.results is not None and tree.falseBranch.results is not None:
        len1 = len(tree.trueBranch.data)
        len2 = len(tree.falseBranch.data)
        len3 = len(tree.trueBranch.data + tree.falseBranch.data)
        p = float(len1) / (len1 + len2)
        gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)
        if gain < miniGain:
            tree.data = tree.trueBranch.data + tree.falseBranch.data
            tree.results = calculateDiffCount(tree.data)
            tree.trueBranch = None
            tree.falseBranch = None
When the gain at a node is smaller than the given miniGain threshold, its two child nodes are merged into a single leaf.
Finally, the driver code that builds and prunes the tree (loadCSV, test_data and classify come from the full source; the pruning threshold below is only illustrative):
if __name__ == '__main__':
    dataSet = loadCSV()
    decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
    prune(decisionTree, 0.4)                  # 0.4 is an illustrative miniGain value
    r = classify(test_data, decisionTree)
    print(r)
Printing decisionTree shows the structure of the tree we built. We can then feed some test data through it and check whether it is classified correctly.
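The classify helper used above is not shown in this excerpt. A minimal sketch of how it could walk the tree, assuming the Tree fields used earlier (the version in the full source may differ):

def classify(data, tree):
    # Walk down the tree until a leaf is reached, then return its label counts.
    if tree.results is not None:
        return tree.results
    v = data[tree.col]
    if isinstance(v, int) or isinstance(v, float):
        branch = tree.trueBranch if v >= tree.value else tree.falseBranch
    else:
        branch = tree.trueBranch if v == tree.value else tree.falseBranch
    return classify(data, branch)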
See GitHub: CART for the complete code and data set.
Conclusion:
- CART decision tree
- Split data set
- Recursively create a tree
A Python implementation of the CART decision tree