KD Tree

A KD-tree (k-dimensional tree) is a binary tree that partitions k-dimensional space by splitting on a different coordinate axis at each level. It supports fast nearest-neighbor and range queries in low dimensions, and it is the data structure behind scipy.spatial.cKDTree as well as many computer-graphics systems.
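For contrast, the query a KD-tree speeds up can always be answered by checking every point, at O(nk) cost per query. The following is a minimal sketch of that linear scan, which is also handy as a correctness oracle when testing a tree. `brute_force_nearest` is an illustrative name, not a library function.

```python
def brute_force_nearest(points, target):
    """Return (point, squared_distance) of the point closest to target.

    Illustrative baseline: scans every point, O(n * k) per query.
    """
    best_point, best_d2 = None, float("inf")
    for p in points:
        # Squared Euclidean distance; no square root needed for comparisons.
        d2 = sum((a - b) ** 2 for a, b in zip(p, target))
        if d2 < best_d2:
            best_point, best_d2 = p, d2
    return best_point, best_d2


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
print(brute_force_nearest(pts, (9, 2)))  # ((8, 1), 2)
```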

Complexity

Time        Average    Worst
Access      Θ(log n)   O(n)
Search      Θ(log n)   O(n)
Insertion   Θ(log n)   O(n)
Deletion    Θ(log n)   O(n)

Space (worst): O(n)
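The worst-case column reflects unbalanced trees. Building from the median, as this page's `build` does, keeps the tree balanced, but a plain incremental insert that descends like a binary-search-tree insert and never rebalances can degrade into a chain, after which a search costs O(n). The sketch below assumes exactly that insert-without-rebalancing strategy; `insert`, `height`, and `tree_from` are hypothetical helpers, not part of the implementation on this page.

```python
import random


class Node:
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis):
        self.point, self.axis = point, axis
        self.left = self.right = None


def insert(root, point, k):
    """Hypothetical insert with no rebalancing; returns the (possibly new) root."""
    if root is None:
        return Node(point, 0)
    node = root
    while True:
        # Smaller on this node's axis goes left; ties and larger go right.
        if point[node.axis] < node.point[node.axis]:
            if node.left is None:
                node.left = Node(point, (node.axis + 1) % k)
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = Node(point, (node.axis + 1) % k)
                return root
            node = node.right


def height(root):
    """Nodes on the longest root-to-leaf path (iterative, so deep chains are safe)."""
    best = 0
    stack = [(root, 1)] if root is not None else []
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, depth + 1))
    return best


def tree_from(points, k=2):
    root = None
    for p in points:
        root = insert(root, p, k)
    return root


n = 1024
diagonal = [(i, i) for i in range(n)]  # already sorted on every axis
shuffled = diagonal[:]
random.Random(42).shuffle(shuffled)

print("sorted order:", height(tree_from(diagonal)))  # 1024: every insert goes right
print("random order:", height(tree_from(shuffled)))  # far below n; exact value depends on the shuffle
```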

How it works

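At depth d, a node splits on axis d mod k, so in 2-D the levels alternate x, y, x, y, and so on.

Build: sort the points along the current axis, store the median point at the node, and recurse on the points before and after it. Splitting at the median keeps the tree balanced.

Nearest-neighbor query: descend to the leaf region that contains the target, keeping the closest point seen so far. On the way back up, check whether the sphere centered on the target, with radius equal to the best distance found so far, crosses the node's splitting hyperplane; only then can the other side hold a closer point, so only then is it searched. This gives O(log n) queries on average for small k, but pruning degrades sharply once k exceeds roughly 20 (the curse of dimensionality): the search sphere almost always crosses the splitting planes, and the query ends up visiting most of the tree.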
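To make the dimensionality claim concrete, one can count how many nodes a nearest-neighbor query touches. This sketch reuses the median-split build and the hyperplane pruning rule on uniformly random points; the helpers `nearest_counting` and `average_visits`, the point count, and the seed are illustrative assumptions. In 2-D the count should stay small, while in 20-D the pruning test almost never fires and the search inspects nearly every node.

```python
import random


class Node:
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis):
        self.point, self.axis = point, axis
        self.left = self.right = None


def build(points, depth=0):
    """Median-split construction: the splitting axis is depth mod k."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    node = Node(points[mid], axis)
    node.left = build(points[:mid], depth + 1)
    node.right = build(points[mid + 1:], depth + 1)
    return node


def nearest_counting(node, target, best, counter):
    """Nearest-neighbor search that also counts visited nodes in counter[0]."""
    if node is None:
        return best
    counter[0] += 1
    d2 = sum((a - b) ** 2 for a, b in zip(node.point, target))
    if best is None or d2 < best[1]:
        best = (node.point, d2)
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest_counting(near, target, best, counter)
    # Search the far side only if the splitting hyperplane is inside the best radius.
    if diff * diff < best[1]:
        best = nearest_counting(far, target, best, counter)
    return best


def average_visits(k, n=1000, queries=30, seed=0):
    """Mean number of nodes visited per query on n uniform points in [0, 1)^k."""
    rng = random.Random(seed)
    pts = [tuple(rng.random() for _ in range(k)) for _ in range(n)]
    root = build(pts)
    total = 0
    for _ in range(queries):
        target = tuple(rng.random() for _ in range(k))
        counter = [0]
        nearest_counting(root, target, None, counter)
        total += counter[0]
    return total / queries


for k in (2, 20):
    print(f"k={k}: {average_visits(k):.1f} nodes visited per query (of 1000)")
```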

Python implementation

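```python
class Node:
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis):
        self.point, self.axis = point, axis  # stored point and its splitting axis
        self.left = self.right = None


def build(points, depth=0):
    """Build a balanced KD-tree by splitting at the median of axis depth % k."""
    if not points:
        return None
    k = len(points[0])
    axis = depth % k
    # sorted() returns a copy, so the caller's list is not reordered.
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    node = Node(points[mid], axis)
    node.left = build(points[:mid], depth + 1)       # coordinates <= median on this axis
    node.right = build(points[mid + 1:], depth + 1)  # coordinates >= median on this axis
    return node


def _dist2(a, b):
    """Squared Euclidean distance (no square root needed for comparisons)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(node, target, best=None):
    """Return (point, squared_distance) of the stored point closest to target."""
    if node is None:
        return best
    d = _dist2(node.point, target)
    if best is None or d < best[1]:
        best = (node.point, d)
    # Signed distance from the target to this node's splitting hyperplane.
    diff = target[node.axis] - node.point[node.axis]
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best = nearest(near, target, best)
    # The far side can only hold a closer point if the hyperplane
    # lies within the current best radius.
    if diff * diff < best[1]:
        best = nearest(far, target, best)
    return best


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(pts)
print(nearest(tree, (9, 2)))   # ((8, 1), 2): the point and its squared distance
```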
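The introduction also mentions range queries, which the implementation does not cover. A minimal sketch of an axis-aligned box search over the same median-split tree follows; `range_search` and its inclusive `lo`/`hi` corner convention are assumptions for illustration. It relies on the build invariant that the left subtree holds coordinates <= the node's value on its axis and the right subtree holds coordinates >= it, so a whole side can be skipped when the box lies entirely on the other side of the splitting plane.

```python
class Node:
    __slots__ = ("point", "axis", "left", "right")

    def __init__(self, point, axis):
        self.point, self.axis = point, axis
        self.left = self.right = None


def build(points, depth=0):
    """Median-split construction, as in the implementation on this page."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    node = Node(points[mid], axis)
    node.left = build(points[:mid], depth + 1)
    node.right = build(points[mid + 1:], depth + 1)
    return node


def range_search(node, lo, hi, out=None):
    """Collect points p with lo[i] <= p[i] <= hi[i] on every axis i."""
    if out is None:
        out = []
    if node is None:
        return out
    p, axis = node.point, node.axis
    if all(l <= x <= h for l, x, h in zip(lo, p, hi)):
        out.append(p)
    # Left subtree: coordinates <= p[axis]; skip it if the box starts above p[axis].
    if lo[axis] <= p[axis]:
        range_search(node.left, lo, hi, out)
    # Right subtree: coordinates >= p[axis]; skip it if the box ends below p[axis].
    if hi[axis] >= p[axis]:
        range_search(node.right, lo, hi, out)
    return out


pts = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build(pts)
print(sorted(range_search(tree, (4, 1), (8, 5))))  # [(5, 4), (7, 2), (8, 1)]
```

As with nearest-neighbor search, this pruning works well in low dimensions; in the worst case a range query still visits every node.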

Trade-offs
