Vector Database from Scratch

Implementing Approximate Nearest Neighbours Oh Yeah (ANNOY)
Author: Sachin Abeywardana

Published: March 22, 2022

Introduction

Now that embeddings are becoming a vital part of search algorithms, the next question is how we do that at scale. There are a lot of vendor Vector Databases popping up, and here we will explore one of the algorithms behind them, ANNOY. We will implement Approximate Nearest Neighbours Oh Yeah from scratch. We will use a synthetic dataset of a million points (N) and 768 dimensions (D). If we have K query points, the run time of brute force search is \(O(KND)\). The ANNOY algorithm aims to bring that down to \(O(KD \log N)\).
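The exact way the synthetic data is generated isn't important; as a minimal sketch (an assumption on my part, using the x and x_new names that appear later), unit-normalised Gaussian vectors will do:

import torch
import torch.nn.functional as F

N, D, K = 1_000_000, 768, 1000

# database vectors and query vectors (assumed to be L2-normalised random vectors)
x = F.normalize(torch.randn(N, D), dim=-1)
x_new = F.normalize(torch.randn(K, D), dim=-1)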

Disclaimer: More often than not you will find that brute force search is fast enough, especially if the number of vectors you have is under 1M. If you have a GPU you can stretch this even further due to the embarrassingly parallel nature of matrix multiplication.

A vector database as generated by Stable Diffusion

ANNOY algorithm

The premise of the algorithm lies in recursively partitioning the space of data points until at most min_leaf data points are left in each sub-space. The following high-level steps show how to construct a tree; the full implementation follows below.

1. Initialise the list of indices to include all data points.
2. Choose a subset of x based on indices.
3. If the number of data points in the subset is less than min_leaf, stop.
4. Choose two data points, completely at random. Store these two.
5. Take the data points closest to the first data point, set indices to those, and go to 2.
6. Take the data points closest to the second data point, set indices to those, and go to 2.

Code
import numpy as np
import torch


class AnnoyTree:
    def __init__(self, max_depth, min_leaf, dim):
        self.max_depth = max_depth
        self.min_leaf = min_leaf
        self.labels = None  # per-datapoint node label, filled in during fit
        self.max_level = 2 ** max_depth
        self.centers = {}  # node label -> the two randomly chosen centers
        self.leaf = {}  # node label -> data points stored at that leaf

    def fit(self, x, idx=None, current_label=0):
        if self.labels is None:
            # first call: every data point starts at the root
            self.labels = np.zeros(len(x), dtype=np.int32)
            idx = self.labels == 0

        next_label = 2 * current_label + 1

        x_subset = x[idx]
        if len(x_subset) <= self.min_leaf or current_label >= self.max_level:
            self.leaf[next_label] = x_subset
            return

        # choose 2 points at random and use them as the two centers
        center_idx = np.random.choice(len(x_subset), 2, replace=False)
        x_centers = x_subset[center_idx]
        self.centers[next_label] = x_centers  # save centers
        # assign each point to whichever center it is most similar to (trick of 2n + 1, 2n + 2)
        self.labels[idx] = next_label + (x_subset @ x_centers.T).argmax(dim=-1)

        # recurse into the left and right children
        self.fit(x, self.labels == next_label, next_label)
        self.fit(x, self.labels == next_label + 1, next_label + 1)

    def predict(self, x: torch.FloatTensor) -> tuple:
        vecs, similarities = zip(*[self.get_closest(row, 0) for row in x])
        return torch.stack(vecs), similarities

    def get_closest(self, x: torch.FloatTensor, idx: int) -> tuple:
        current_index = 2 * idx + 1
        if current_index in self.leaf:
            # at a leaf: brute force search within the stored block
            val, idx = (x @ self.leaf[current_index].T).topk(1)
            return self.leaf[current_index][idx.item()], val.item()
        # otherwise descend towards whichever center is most similar to the query
        closest_index = current_index + (x @ self.centers[current_index].T).argmax(dim=-1)
        return self.get_closest(x, closest_index.item())

fit method

Let’s go through the fit method in the above class.

x_subset = x[idx]
if len(x_subset) <= self.min_leaf or current_label >= self.max_level:
    self.leaf[2 * current_label + 1] = x_subset
    return

The above code saves the subset of data points if the conditions for a leaf are met. This is done so that we can compare query data points against our data at the leaf level. Note that this does mean we have \(O(ND)\) storage costs.

Note that we use the \(2n+1, 2n+2\) trick to make sure that labels never overlap. This also ensures that if we need a parent label we can simply do (current_label - 1) // 2 to get to it. This avoids us needing to store explicit left and right child nodes.
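As a quick illustration of this implicit tree indexing (a small sanity check, not part of the original class):

def children(n: int) -> tuple:
    # labels of the left and right children of node n
    return 2 * n + 1, 2 * n + 2

def parent(c: int) -> int:
    # label of the parent of node c (inverse of the 2n + 1 / 2n + 2 mapping)
    return (c - 1) // 2

assert children(0) == (1, 2)
assert parent(1) == 0 and parent(2) == 0
assert parent(5) == 2 and parent(6) == 2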

center_idx = np.random.choice(len(x_subset), 2, replace=False)
x_centers = x_subset[center_idx]
self.centers[next_label] = x_centers # save centers
self.labels[idx] = next_label + (x_subset @ x_centers.T).argmax(dim=-1)

This code block chooses 2 data points at random and stores them as the centers of the two child nodes. It then assigns each point in the subset its child label via next_label + (x_subset @ x_centers.T).argmax(dim=-1). Since argmax returns 0 or 1 (whichever center the point is most similar to), adding it to \(2n+1\) gives the child label \(2n+1\) or \(2n+2\).
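To make the shapes concrete, here is a small hypothetical example of that assignment step (the sizes and center indices are made up):

import torch

m, D = 5, 768                   # 5 points in the current subset
x_subset = torch.randn(m, D)
x_centers = x_subset[[0, 3]]    # two centers picked by hand here, shape (2, D)

sims = x_subset @ x_centers.T   # shape (m, 2): similarity to each center
child = sims.argmax(dim=-1)     # shape (m,): 0 or 1 per point
labels = 1 + child              # with next_label = 1, each point gets label 1 or 2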

The final two lines simply call the fit method recursively until a stop condition is met.

predict method

The predict method takes the tree constructed in the above step and, for each query vector, compares it against the stored centers to decide which branch to descend, until a leaf node is reached. Once at a leaf node it does a brute force search to get the closest element in that block.
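As a rough back-of-envelope (not from the original post): each query computes two \(D\)-dimensional dot products per level for roughly \(\log_2(N/\text{min\_leaf})\) levels, plus a brute force search over roughly min_leaf points at the leaf. That is about \(O\big(D(\log_2(N/\text{min\_leaf}) + \text{min\_leaf})\big)\) per query, compared to \(O(ND)\) for brute force.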

Code
tree = AnnoyTree(15, 1000, x.shape[1])
tree.fit(x)

Results

As can be seen below, predicting the closest vector for 1000 query vectors takes 344 ms, while a full brute force search takes 16.5 seconds. That's roughly a 50x speed-up.

%%time
vecs, similarities = tree.predict(x_new)
CPU times: user 342 ms, sys: 3.23 ms, total: 345 ms
Wall time: 344 ms
%%time
max_similarity, max_idx = (x_new @ x.T).topk(1, dim=-1)
CPU times: user 33.1 s, sys: 1.39 s, total: 34.5 s
Wall time: 16.5 s

Given the actual maximum similarity below, we can see that the approximate method captures a close vector, but not the best one. If you are wondering why the numbers are relatively small (~0.18), keep in mind that two random vectors are highly likely to be nearly orthogonal, with similarity closer and closer to zero as the number of dimensions grows.
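As a rough back-of-envelope (assuming unit-normalised Gaussian vectors for the synthetic data): the cosine similarity between two such random vectors has a standard deviation of about \(1/\sqrt{D} \approx 0.036\) for \(D = 768\), and the maximum over \(N = 10^6\) candidates is roughly \(\sqrt{2 \ln N / D} \approx 0.19\), which lines up with the values above.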

max_similarity.squeeze()[:10]
tensor([0.1821, 0.1769, 0.1651, 0.1692, 0.1798, 0.1711, 0.1764, 0.1796, 0.1751,
        0.1716])
Code
similarities[:10]
(0.14211197197437286,
 0.10861558467149734,
 0.12875324487686157,
 0.13691002130508423,
 0.12784186005592346,
 0.13487347960472107,
 0.11535458266735077,
 0.11262157559394836,
 0.11334729194641113,
 0.11169558018445969)

Random Forest Approach

In a similar spirit to random forests, we can easily extend a single tree into multiple trees. Below we construct 10 trees. Unlike random forests, where we average results across trees, here we take the data point with maximum similarity across the best data points chosen by each tree.

from tqdm import tqdm

N_TREE = 10
trees = [AnnoyTree(15, 1000, x.shape[1]) for _ in range(N_TREE)]
for tree in tqdm(trees):
    tree.fit(x)
vecs, similarities = zip(*[tree.predict(x_new) for tree in tqdm(trees)])

The similarity scores shown below are better than the numbers obtained from the single tree above.

torch.stack([torch.Tensor(similarity) for similarity in similarities]).amax(dim=0)[:10]
tensor([0.1430, 0.1509, 0.1453, 0.1444, 0.1554, 0.1476, 0.1374, 0.1420, 0.1582,
        0.1419])
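The line above only reports the best similarity per query. To also recover the corresponding vector across the forest, something along these lines should work (a sketch, not from the original post):

sim = torch.stack([torch.Tensor(s) for s in similarities])  # shape (N_TREE, K)
vec = torch.stack(vecs)                                      # shape (N_TREE, K, D)

best_tree = sim.argmax(dim=0)                                # best tree per query, shape (K,)
best_sim = sim.amax(dim=0)                                   # best similarity per query
best_vec = vec[best_tree, torch.arange(vec.shape[1])]        # best vector per query, shape (K, D)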