I recently had an idea for a new query pattern for vector databases. I don't know if it makes any sense to embed into a DB, but I need to sketch out the idea to determine if it's computationally viable. Let's start with a basic pattern for querying vector databases:
Here a query vector $q$ is used to search a database $D$ of vectors, and the database returns the vector $v \in D$ that minimizes $d(q, v)$. My idea is very simple: what if instead of querying for a single vector we queried with a set of vectors ($Q$) and minimized $\sum_{q \in Q} d(q, v)$?
A naive solution may involve calculating the sum of distances between each vector in $D$ and the query set $Q$, maintaining a capped priority queue of the top $k$ vectors, and producing the results after scanning the full DB. Even if we add an approximate nearest neighbors index to reduce the total data scanned, this operation is still linear with respect to the size of $Q$. Not Good!
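A minimal sketch of that naive scan, assuming plain $L_2$ distance, NumPy arrays standing in for the DB and query set, and a `heapq`-based capped priority queue (all names here are illustrative, not a real DB API):

```python
import heapq
import numpy as np

def naive_multi_query(D, Q, k):
    """Scan every vector in D, scoring each by its summed L2 distance
    to all vectors in Q; keep only the best k in a capped heap."""
    heap = []  # max-heap via negated scores, capped at k entries
    for idx, v in enumerate(D):
        score = np.linalg.norm(Q - v, axis=1).sum()  # O(|Q|) per candidate
        if len(heap) < k:
            heapq.heappush(heap, (-score, idx))
        elif -heap[0][0] > score:
            heapq.heapreplace(heap, (-score, idx))
    # Return (score, index) pairs, best (smallest sum) first.
    return sorted((-s, i) for s, i in heap)

rng = np.random.default_rng(0)
D = rng.normal(size=(1000, 4))
Q = rng.normal(size=(16, 4))
print(naive_multi_query(D, Q, k=5))
```

Even with an index pruning the scan of `D`, each surviving candidate still pays the `O(|Q|)` inner loop.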
We can do much better by querying the DB for vectors in the neighborhood of some theoretical vector $\mu$ which minimizes the sum of distances to all vectors in $Q$. The vectors nearest to $\mu$ should be the same as the vectors that minimize $\sum_{q \in Q} d(q, v)$.

This is a hefty claim. Intuitively I believe this is true with $L_2$ distance. I am less sure it holds for other metrics. This property most likely does not hold exactly when we consider that we'll have an index built from arbitrary partitions of $D$. Regardless, we're deep into approximation territory.
The vector $\mu$ is dependent on $Q$ and the distance metric we're minimizing. Thus, if we can calculate $\mu$ quickly, we can implement a function which performs a single DB query regardless of the size of $Q$. To enable this, we must be able to find $\mu$ efficiently for several common distance metrics.
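As a sketch, the whole query path then collapses to two steps. `BruteForceDB` and its `nearest` method are hypothetical stand-ins for a database's ordinary single-vector k-NN search, and the mean is used as $\mu$ (which the next section derives for squared $L_2$ distance):

```python
import numpy as np

class BruteForceDB:
    """Stand-in for a vector DB: exact single-vector k-NN by L2 distance."""
    def __init__(self, vectors):
        self.vectors = vectors

    def nearest(self, v, k):
        dists = np.linalg.norm(self.vectors - v, axis=1)
        return np.argsort(dists)[:k]

def multi_vector_query(db, Q, k):
    # Collapse the query set into one representative vector, then
    # issue a single k-NN query -- cost independent of len(Q).
    mu = Q.mean(axis=0)
    return db.nearest(mu, k)

rng = np.random.default_rng(1)
db = BruteForceDB(rng.normal(size=(1000, 4)))
Q = rng.normal(size=(16, 4))
print(multi_vector_query(db, Q, k=5))
```

For squared $L_2$ this collapse is exact: the sum of squared distances from a candidate to all of $Q$ differs from its distance to the mean only by a constant, so both rankings agree.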
When squared $L_2$ distance is the query metric, we find $\mu$ quite easily. Let's represent $Q$ as an $n \times d$ matrix of the query vectors and $\mu$ as a candidate vector in $d$ dimensions. The sum of distances is given by:

$$f(\mu) = \sum_{i=1}^{n} \lVert Q_i - \mu \rVert^2 = \sum_{i=1}^{n} \sum_{j=1}^{d} \left(Q_{ij} - \mu_j\right)^2$$

After setting each partial derivative $\frac{\partial f}{\partial \mu_j} = \sum_{i=1}^{n} -2\left(Q_{ij} - \mu_j\right) = 0$, we have the optimal value for each $\mu_j$ as the mean of the values of the $j$-th dimension from the vectors in $Q$.
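A quick numerical spot check of this (not a proof; random data and illustrative helper names):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 16, 4
Q = rng.normal(size=(n, d))
mu = Q.mean(axis=0)  # per-dimension mean, the claimed optimum

def sum_sq_dist(v):
    # f(v) = sum_i ||Q_i - v||^2
    return ((Q - v) ** 2).sum()

# No random perturbation of mu should lower the loss.
for _ in range(1000):
    v = mu + rng.normal(scale=0.1, size=d)
    assert sum_sq_dist(mu) < sum_sq_dist(v)

# Stronger identity: f(v) = n * ||v - mu||^2 + f(mu) for any v,
# so f is minimized exactly at mu.
v = rng.normal(size=d)
assert np.isclose(sum_sq_dist(v), n * ((v - mu) ** 2).sum() + sum_sq_dist(mu))
print("mean minimizes the sum of squared distances")
```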
Cosine distance is a bit more challenging because the distance between some query vector in $Q$ and a candidate $\mu$ depends on the magnitude of $\mu$:

$$f(\mu) = \sum_{i=1}^{n} \left(1 - \frac{Q_i \cdot \mu}{\lVert Q_i \rVert \, \lVert \mu \rVert}\right)$$
If we normalize all vectors in our database so that $\lVert v \rVert = 1$, the distance calculation is simplified dramatically. Now we just need to calculate the dot product of $Q_i$ and $\mu$. Using unit vectors also leads us to a useful relationship between the $L_2$ distance and cosine distance. When normalized, they share the same minimum. The method used to calculate $\mu$ for $L_2$ distance can be replicated for cosine distance!
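The shared minimum follows from a two-line expansion. For unit vectors $a$ and $b$:

```latex
\lVert a - b \rVert^2
  = \lVert a \rVert^2 - 2\,a \cdot b + \lVert b \rVert^2
  = 2 - 2\,a \cdot b
  = 2\,(1 - \cos\theta)
```

On the unit sphere, squared $L_2$ distance is exactly twice the cosine distance, so any vector minimizing one minimizes the other; computing the per-dimension mean and then normalizing it yields the cosine optimum.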
To verify these results, I'll demonstrate that SGD doesn't improve the loss relative to these (much cheaper) methods. It would be very bad if we actually had to do SGD to find $\mu$.
Example 1. Using SGD and optimizing for cosine distance, SGD doesn't improve loss compared to the simple method (normalize + avg. val of dimension).
I fiddled with torch a good bit, changing the learning rate etc. I do not think this matters in this problem, but perhaps there's an edge case in higher dimensions I'm not seeing.
import torch
from torch import optim

def torch_cosine_f(A, B):
    # Sum of cosine distances between each row of A and the vector B.
    return (1 - torch.mv(A, B) / (A.norm(dim=1) * B.norm())).sum()

A = torch.normal(0, 1, size=(16, 4))
A = A / A.norm(dim=1, p=2, keepdim=True)  # normalize each query vector
A_mean = torch.mean(A, axis=0)
A_mean = A_mean / A_mean.norm()           # the simple method's solution
B = torch.nn.Parameter(A_mean.clone(), requires_grad=True)

sgd = optim.SGD([B], lr=0.05)
for _ in range(2**6):
    sgd.zero_grad()
    loss = torch_cosine_f(A, B)
    loss.backward()
    sgd.step()

# initial (loss): [-0.2979, -0.5264, -0.2663, 0.7505], (10.76275)
# sol'n SGD (loss): [-0.2979, -0.5264, -0.2663, 0.7505], (10.76275)

Example 2. Using SGD and optimizing for squared distance, SGD reduces absolute loss!? Not quite. Notice that the resulting vector is just a scaled version of the vector produced by the simple method (normalize + avg. val of dimension).
def torch_l2_f(A, B):
    # Sum of squared L2 distances between each row of A and the row vector B.
    return torch.cdist(A, B, p=2.0).pow(2).sum()

B = torch.nn.Parameter(A_mean.clone().unsqueeze(0), requires_grad=True)
sgd = optim.SGD([B], lr=0.05)
for _ in range(2**6):
    sgd.zero_grad()
    loss = torch_l2_f(A, B)
    loss.backward()
    sgd.step()

# initial (loss): [-0.2979, -0.5264, -0.2663, 0.7505], (21.52550)
# sol'n SGD (loss): [-0.0975, -0.1723, -0.0872, 0.2457], (14.28570)

So where to go from here? This method is computationally viable, but I'm not sure that it makes sense in practice. Is this useful for RAG? No clue. Is this useful in any context? Again, no clue.