This article originated from my personal public account: TechFlow. Original content is not easy to produce, so please consider following.
Today's article is the 15th in our machine learning series. In the previous article, we talked about optimizations of KMeans and the famous EM algorithm. Today we are going to look at a data structure that is often used in machine learning: the KD-tree.
From segment trees to KD-trees
Before we talk about KD-trees, let's first understand the concept of segment trees. Segment trees are not common in machine learning; as a high-performance data structure, they usually appear in algorithm competitions. In essence, a segment tree is a balanced binary tree that maintains an interval.
For example, here is a classic segment tree:

It is not hard to see from the figure that this segment tree maintains the maximum value of an interval. The root holds 8, the maximum of the entire interval, and the value of each intermediate node is the maximum of all elements in the subtree rooted at it.
With a segment tree, we can compute the maximum value of any continuous interval in O(log n) time. For example, take a look at the following figure:

When we look for the maximum of the marked interval, we just need to find the intermediate nodes whose subtrees cover the interval. We can see that the subtrees of the two nodes boxed in red cover this interval, so the maximum of the entire interval is simply the maximum of these two values. In this way, we reduce a query that would otherwise need O(n) time to O(log n). Not only that, we can also perform updates within O(log n) complexity; that is, we can not only query quickly but also update the elements of the segment.
Of course, segment trees are widely used and have many variations, but we won't go too far here. Those who are interested can look forward to the Wednesday algorithms and data structures column, where they will be covered in a later article. For now, we just need a general impression of what segment trees accomplish.
A segment tree maintains a line segment, that is, the elements of an interval: a one-dimensional sequence. What if we expand the data to more dimensions?
Yes, you guessed right. To some extent, we can think of the KD-tree as a segment tree extended into multi-dimensional space.
The definition of the KD-Tree
Let's take a look at the specific definition of the KD-tree, where K refers to the k dimensions of the space and D naturally refers to dimension; that is, KD-tree means k-dimensional tree.
When we build a segment tree, it is actually a recursive tree-building process. Each time we divide the current segment into two halves, then build the left and right subtrees from the two halves respectively. To get a more intuitive feel for this, we can write some simple pseudocode:
class Node:
    def __init__(self, value, lchild, rchild):
        self.value = value
        self.lchild = lchild
        self.rchild = rchild

def build(arr):
    # Base case: a single element becomes a leaf node
    if len(arr) == 1:
        return Node(arr[0], None, None)
    # Split the segment in half and build both subtrees recursively
    n = len(arr)
    left, right = arr[:n//2], arr[n//2:]
    lchild, rchild = build(left), build(right)
    # An internal node stores the maximum of its subtree
    return Node(max(lchild.value, rchild.value), lchild, rchild)
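Before moving on to the KD-tree, here is a minimal sketch of a range-max query on the tree built above, to complete the segment-tree picture. One assumption: since our simple Node does not store its interval, we pass the interval [lo, hi) that the node covers as extra arguments:

# A sketch of a range-max query; [lo, hi) is the interval the node
# covers, [l, r) is the interval we are asking about.
def query_max(node, lo, hi, l, r):
    # No overlap between the node's interval and the query
    if r <= lo or hi <= l:
        return float('-inf')
    # The node's interval is fully contained in the query
    if l <= lo and hi <= r:
        return node.value
    # Partial overlap: combine the answers of both children
    mid = (lo + hi) // 2
    return max(query_max(node.lchild, lo, mid, l, r),
               query_max(node.rchild, mid, hi, l, r))

For example, with root = build([5, 2, 8, 3]), calling query_max(root, 0, 4, 1, 3) only inspects the nodes covering indices 1 and 2 and returns 8.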
Now back to the KD-tree. Let's take a two-dimensional example: a plane with a number of points.

Let's first pick a dimension to split the data in half; say we pick the X-axis. We sort all the data by X-coordinate, pick the median point, and divide the data into two parts.

The points on the left and right sides of this line go into the two subtrees. For each of the two parts, we switch to another dimension, in this case the Y-axis. Again we sort, find the median point, and split in half once more. We get:
We repeat this process until the points can no longer be divided. For better visibility, we plot all the data with (imprecise) coordinates.
If we think of the space as a generalized interval, this works the same way as a segment tree. The tree we end up with is a perfect binary tree; because we split at the median of the data set each time, we can guarantee that the path from the root to any leaf will not exceed O(log n).
After we fill in the coordinates above, our final KD-tree looks something like this:
Building the KD-Tree
In the process of building the tree, every time we go down a level, we split on a different dimension. The reason is simple: we want the tree to represent the k-dimensional space well, so that we can query quickly along different dimensions.
In some implementations, we compute the variance of each dimension and then select the dimension with the largest variance for splitting. The motivation is natural: a dimension with large variance means the data is spread out along it, and splitting there separates the data more clearly. But I personally don't think it adds much, because computing variances is expensive. So here we go with the simplest method: taking turns.
That is, we start at the root and pick dimension 0 for sorting and partitioning; at depth 1 we pick dimension 1, at depth 2 dimension 2, and so on. When the tree depth reaches K, we take the depth modulo K.
Now that we know this, we can write the tree-building code for the KD-tree. It is very similar to the segment tree code above, with extra handling of dimensions.
# Externally exposed interface
def build_model(self, dataset):
    self.root = self._build_model(dataset)
    # Ignore this for now; we discuss it below
    self.set_father(self.root, None)

# Internal implementation
def _build_model(self, dataset, depth=0):
    if len(dataset) == 0:
        return None
    # Get the current splitting dimension by taking the tree depth modulo K
    axis = depth % self.K
    m = len(dataset) // 2
    # Sort by the current axis
    dataset = sorted(dataset, key=lambda x: x[axis])
    # Split the data in two around the median point
    left, mid, right = dataset[:m], dataset[m], dataset[m+1:]
    # Recursively build the subtrees
    return KDTree.Node(
        mid[axis],
        mid,
        axis,
        depth,
        len(dataset),
        self._build_model(left, depth+1),
        self._build_model(right, depth+1)
    )
So we have built the tree, but subsequent queries need to access a node's parent, so we assign each node a pointer to its parent. We could have done this inside the tree-building code, but it would have been a little more complicated, so I broke it out as a separate function. The root node has no parent, so it is assigned None.
Let's look at set_father, which is a simple recursive traversal of the tree:
def set_father(self, node, father):
    if node is None:
        return
    # Assign the parent pointer
    node.father = father
    # Recurse into the left and right subtrees
    self.set_father(node.lchild, node)
    self.set_father(node.rchild, node)
Fast nearest-neighbor queries
Once the KD-tree is built, we can use it to retrieve the several samples closest to a given sample in a single query. On an evenly distributed data set, we can complete the query in O(log n) time; in special cases it may take longer, but it is still much faster than the naive approach.
It is easy to see that a widespread application of the KD-tree is optimizing the KNN algorithm. As we mentioned in the earlier article introducing KNN, the algorithm has to traverse the whole data set during prediction, compute the distance between every sample and the current one, and select the nearest K, which costs a lot. With a KD-tree, we can find the K nearest samples directly in one query, greatly improving the efficiency of KNN.
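For contrast, here is a sketch of the naive KNN query the previous paragraph describes; the function and parameter names are made up for illustration:

# Naive KNN: scan every training sample and sort by distance,
# which costs O(n log n) per query.
def naive_knn(train_points, x, K, dist):
    return sorted(train_points, key=lambda p: dist(p, x))[:K]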
So how does this query work?
The query is implemented with recursion, so those who are not familiar with recursion may find it difficult at first; you can read the earlier article on recursion.
First of all, we recursively find the leaf node of the KD-tree, that is, the subspace the sample falls in. This lookup is quite easy: essentially, we keep comparing the current sample against each splitting line to see whether it lies to the left or the right. It is the same as an element lookup in a binary search tree:
def iter_down(self, node, data):
    # If it is a leaf node, return it
    if node.lchild is None and node.rchild is None:
        return node
    # If the left child is empty, recurse into the right child
    if node.lchild is None:
        return self.iter_down(node.rchild, data)
    # Similarly, if the right child is empty, recurse into the left child
    if node.rchild is None:
        return self.iter_down(node.lchild, data)
    # If neither is empty, use the splitting line to decide left or right
    axis = node.axis
    next_node = node.lchild if data[axis] <= node.boundary else node.rchild
    return self.iter_down(next_node, data)
The leaf node we find actually represents a small region of the sample space.
Let's walk through the process concretely, and say we want to find the three nearest points. First, we create a candidate set to store our answers. When we reach the leaf, there is only one point in its region, and we add it to the candidate set.

In the figure above, the purple x marks the sample we are querying. After we find the leaf node, we add the current point to the candidate set in either of two cases. The first case is that the candidate set is not yet full, meaning we haven't collected K points yet, where K is the number of neighbors we want, here 3. The second case is that the distance between the current point and the sample is less than the largest distance in the candidate set, in which case we need to update the candidate set.
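These two admission conditions can be written as a small sketch (candidate entries here are assumed to be dicts with a 'distance' key, matching the code further below):

# A point enters the candidate set if the set is not full yet,
# or if it is closer than the current farthest candidate.
def should_add(candidates, dis, K):
    if len(candidates) < K:
        return True
    return dis < max(c['distance'] for c in candidates)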
Once a point has been processed, we mark it as visited. At this point, we need to determine whether the search of the whole tree is over: if the current node is already the root, the traversal is finished and we return the candidate set; otherwise, we need to keep searching. In the figure above, green means the sample has been put into the candidate set, and yellow means it has merely been visited.
Since our search is not over, we continue. To do so, we must check the distance between the sample and the current splitting line to decide whether a possible answer exists on the other side of the line. A leaf node has no other side, so we move one level up to its parent node.
We calculate the distance and look at the candidate set. The candidate set is not full yet, so we add the point and mark it as visited. The node has a splitting line, but there are no nodes on its other side, so we skip it as well.
We go up to its parent node and perform the same check. This time there is still room in the candidate set, so we add the point to the answers:
However, when we check the distance to the splitting line, we find that the sample is closer to the splitting line than to the farthest point in the candidate set, so there may be an answer on the other side of the line:

Here d1 is the distance from the sample to the splitting line, and d2 is the distance from the sample to the farthest point in the candidate set. Since the sample is closer to the splitting line, there may well be an answer on the other side, and we need to search the subtree on the other side all the way down to a leaf node.
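As a concrete illustration with made-up numbers:

# Hypothetical numbers for the d1 < d2 test described above.
sample = (3.0, 4.0)               # the query sample
boundary, axis = 3.5, 0           # a vertical splitting line x = 3.5
d1 = abs(sample[axis] - boundary) # distance to the splitting line: 0.5
d2 = 2.0                          # distance to the farthest candidate
# d1 < d2, so a closer point may exist across the line and we must
# search the subtree on the other side.
need_other_side = d1 < d2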
We find the leaf node and calculate the distance, only to discover that the candidate set is already full and this point is farther than every answer in it, so it cannot become a new answer. We just mark it as visited without adding it to the candidate set, and again continue up to its parent node:
After the comparison at the parent, we find that its distance to the sample is less than the largest distance in the candidate set, so we update the candidate set, removing the answer that was farther away. We then repeat the process until we reach the root node.
After that, the candidate set is not updated again, since there are no closer points; the three points marked green in the figure above are the final answer.
Written out, the logic of this recursion looks roughly like the following Python-style pseudocode (helpers such as cal_dis, max_distance, and get_brother are assumed):
def query(node, data, answers, K):
    # Check whether the node has been visited yet
    if not node.visited:
        # Mark it visited
        node.visited = True
        # Calculate the distance between data and the node's point
        dis = cal_dis(data, node.point)
        # Update the answers if the set is not full or this point is closer
        # (evicting the farthest answer if the set is full)
        if len(answers) < K or dis < max_distance(answers):
            answers.update(node.point)
        # Calculate the distance between data and the splitting line
        dis = cal_dis(data, node.split)
        # If it is less than the maximum distance, the other side may hold answers
        if dis < max_distance(answers):
            # Get the sibling of the current node
            brother = get_brother(node)
            if brother is not None:
                # Descend to a leaf and restart the search from there
                leaf = iter_down(brother, data)
                if leaf is not None:
                    return query(leaf, data, answers, K)
        # Exit if we have reached the root
        if node is root:
            return answers
        # Otherwise recurse into the parent node
        return query(node.father, data, answers, K)
    else:
        if node is root:
            return answers
        return query(node.father, data, answers, K)
The real code is not much different from the pseudocode above, except that where we compared against the maximum distance in the answers, we use a priority queue. The rest is almost the same; I post it here to give you a feel:
def _query_nearest_k(self, node, path, data, topK, K):
    # We use a set to record the visited nodes instead of marking them directly
    if node not in path:
        path.add(node)
        # Compute the Euclidean distance
        dis = KDTree.distance(node.value, data)
        if len(topK) < K or dis < topK[-1]['distance']:
            topK.append({'node': node, 'distance': dis})
            # Keep the K smallest via a priority queue (requires `import heapq`)
            topK = heapq.nsmallest(K, topK, key=lambda x: x['distance'])
        axis = node.axis
        # The splitting line is axis-aligned, so the distance to it
        # is just the coordinate difference
        dis = abs(data[axis] - node.boundary)
        if len(topK) < K or dis <= topK[-1]['distance']:
            brother = self.get_brother(node, path)
            if brother is not None:
                next_node = self.iter_down(brother, data)
                if next_node is not None:
                    return self._query_nearest_k(next_node, path, data, topK, K)
        if node == self.root:
            return topK
        return self._query_nearest_k(node.father, path, data, topK, K)
    else:
        if node == self.root:
            return topK
        return self._query_nearest_k(node.father, path, data, topK, K)
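A natural way to call this is to first descend to the leaf containing the query point and then let _query_nearest_k walk back up. The wrapper below is my own hypothetical sketch, not part of the original code; it assumes the node's sample point is stored in node.value, as the distance call above suggests:

# Hypothetical convenience wrapper around the method above.
def query_nearest_k(self, data, K):
    leaf = self.iter_down(self.root, data)
    topK = self._query_nearest_k(leaf, set(), data, [], K)
    return [item['node'].value for item in topK]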
We can follow this logic, but there is one question: why do we pass in a set to maintain the visited nodes instead of adding a visited field to each node? This is hard to figure out just by reading the code; you have to experiment to understand it. Adding a field to the node is also possible, but if we do that, after every search we have to recurse through the whole tree again and reset every node to False. Otherwise, the next time we search, some nodes will already be marked True, which obviously corrupts the result. Manually restoring these values after each query is expensive, so we use a set for the visited check instead.
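To make the cost concrete, this is the extra reset pass that a per-node visited flag would force after every single query, an O(n) traversal of the whole tree:

# The reset pass a per-node `visited` flag would require after each query.
def reset_visited(node):
    if node is None:
        return
    node.visited = False
    reset_visited(node.lchild)
    reset_visited(node.rchild)
# Passing a fresh set() into each query avoids this traversal entirely.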
The iter_down function is the same leaf-finding function we posted above: it finds the leaf node of the current subtree. If I remember correctly, this is also the first time in this series that we call one recursion from inside another, which can be difficult for beginners. I personally recommend drawing a KD-tree on a piece of paper and doing a manual simulation; then you'll see what the logic is. It's also a great way to think and learn.
Optimization
Once we understand the logic of building and finding the entire KD-tree, let’s consider optimization.
Looking through the code, we can initially find two places that can be optimized. The first is when we build the tree. Every time we recurse, in order to split the data in half, we sort, and each sort costs O(n log n), which is not low. Actually, if you think about it, we don't need to sort at all: we only need the element that would rank n/2-th along some axis. In other words, this is a selection problem, not a sorting problem, so we can use the quickselect method we talked about earlier. With quickselect, we can complete the data split in O(n) expected time.
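As a refresher, here is a minimal quickselect sketch adapted to selecting by one axis; the function name is my own:

import random

# Return the k-th smallest point of `items` along the given axis,
# in O(n) expected time, without fully sorting.
def quickselect(items, k, axis):
    pivot = random.choice(items)[axis]
    lows = [p for p in items if p[axis] < pivot]
    pivots = [p for p in items if p[axis] == pivot]
    if k < len(lows):
        return quickselect(lows, k, axis)
    if k < len(lows) + len(pivots):
        return pivots[0]
    highs = [p for p in items if p[axis] > pivot]
    return quickselect(highs, k - len(lows) - len(pivots), axis)

To split the data set, we would call quickselect(dataset, len(dataset) // 2, axis) to find the median point and then partition the points around it.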
The other place is the priority queue we use to maintain the answers in the candidate set while looking for the K nearest points, so that we can keep updating them. Calling heapq's nsmallest rebuilds the top K from scratch on every update, and this can be optimized too. A good idea is to use a heap directly: by maintaining a bounded max-heap of size K, each update is a single O(log K) push or replace, which is more efficient than repeatedly calling nsmallest.
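A sketch of that bounded-heap idea (heapq implements a min-heap, so we store negated distances to simulate a max-heap; the helper name is made up):

import heapq

# Keep at most K candidates; the root of the heap is the farthest one.
def push_candidate(heap, dis, point, K):
    if len(heap) < K:
        heapq.heappush(heap, (-dis, point))
    elif dis < -heap[0][0]:
        # Closer than the current farthest: replace it in O(log K)
        heapq.heapreplace(heap, (-dis, point))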
Conclusion
At this point, we have mostly finished the principles of the KD-tree; with tree building and querying in place, it can be used to optimize the KNN algorithm. However, our current KD-tree only supports building and querying. What should we do if we want to insert data into or delete data from the set? Build a brand-new tree on every change? Obviously not. But inserting and removing nodes changes the structure of the tree and can throw it out of balance, so what should we do?
Let's leave that as a cliffhanger for now. The answer will come in the next article, so don't miss it if you are interested.
Finally, I have put the complete KD-tree code on Ubuntu Paste. If you want to see the full source code, reply "KD-Tree" on our official account.
That's all for today's article. If you found it rewarding, please follow or share; your support means a lot to me.