12/12/2024

Accelerate GPT Output Embedding Computations with a Vector Index

Motivation

The input and output embeddings of transformer-based LLMs are massive. For example, GPT-2's embeddings make up 38.6M parameters (50,257 × 768) out of 124M total (~31%).

To generate text with these models, the output embedding is used to compute the logits that define the probability distribution from which the next token is sampled. A popular sampling method is top-k sampling, which samples only from the k most probable tokens, preventing the selection of highly unlikely ones.

For this sampling strategy we're only interested in the 50 most likely tokens (the Hugging Face default for k) out of 50,257 possible tokens - less than 0.1%. Is it possible to get just these 50 tokens without calculating logits for all 50k tokens?
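For reference, here's a minimal sketch of standard top-k sampling, where the logits for all 50,257 tokens are computed with a full matrix multiplication against the output embedding (this uses the Hugging Face transformers GPT-2 classes; the prompt is just an example):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
hidden = model.transformer(input_ids).last_hidden_state[:, -1]  # (1, 768)
logits = hidden @ model.lm_head.weight.T                        # (1, 50257)

# Keep only the 50 most probable tokens, then sample from them.
topk_logits, topk_indices = torch.topk(logits, k=50, dim=-1)
probs = torch.softmax(topk_logits, dim=-1)
next_token = topk_indices.gather(-1, torch.multinomial(probs, num_samples=1))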

This is essentially the nearest-neighbor search problem faced by vector databases, which use vector indices to accelerate the search.

Vector Indices

Vector indices come in two main groups:

  • Clustering based: IVF (Inverted File Index)
  • Graph based: HNSW (Hierarchical Navigable Small World)

IVF struggles with the high dimensionality of the embeddings and can fail when the most probable tokens are outside the searched cluster, leading to incorrect predictions.

HNSW performs well with the high-dimensional embeddings and offers higher throughput. The long index build times aren't an issue for inference, and performance can be tuned via the index parameters ef and M.

I tried both, but only got good results for HNSW.

HNSW Layer

The HNSW layer is built on top of hnswlib and outputs the indices and logits of the k most probable tokens directly from the last hidden state.

The time complexity of a search on an HNSW index is O(log(n)), compared to the O(n) matrix multiplication over all n token embeddings.

import hnswlib

class HNSWIndexEmbedding():
    def __init__(self, weight, M=32, ef=100, ef_construction=100):
        # weight: (vocab_size, hidden_dim) output embedding matrix as a numpy array.
        # Build an HNSW index using inner product similarity (logit = dot product).
        self.index = hnswlib.Index(space='ip', dim=weight.shape[1])
        self.index.init_index(max_elements=weight.shape[0], M=M, ef_construction=ef_construction)
        self.index.add_items(weight)
        self.index.set_ef(ef)

    def query(self, x, k):
        # hnswlib returns labels (token ids) and distances; for the 'ip' space the
        # distance is 1 - dot product, so convert it back to logits.
        indices, distances = self.index.knn_query(x, k)
        return 1 - distances, indices
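
A usage sketch (assuming GPT-2's output embedding from Hugging Face transformers, converted to a numpy array; the random hidden states are placeholders for real model activations):

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Build the index once from the (50257 x 768) output embedding matrix.
weight = model.lm_head.weight.detach().numpy()
hnsw_embedding = HNSWIndexEmbedding(weight, M=32, ef=100, ef_construction=150)

# Query with last hidden states instead of multiplying by the full matrix.
hidden = torch.randn(4, 768).numpy()  # placeholder for real last hidden states
logits, indices = hnsw_embedding.query(hidden, k=50)  # both of shape (4, 50)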

Performance

The performance measurements were done on the CPU for GPT-2's output embedding, comparing the HNSW layer (k = 50, M = 32, ef = 100/200, ef_construction = 150) to the matrix multiplication between the final hidden state and the output embedding matrix (torch/CPU):

B (batch size) | Speedup (ef = 100) | Speedup (ef = 200)
1              | 41.2x              | 23.0x
8              | 9.0x               | 4.4x
64             | 4.9x               | 2.5x
256            | 3.9x               | 2.0x
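
A rough sketch of how such a per-layer comparison can be timed (shapes and parameters are illustrative; HNSWIndexEmbedding is the class defined above, and the random weight matrix stands in for GPT-2's real output embedding):

import time
import torch

B, d, k = 64, 768, 50
weight = torch.randn(50257, d)  # stand-in for GPT-2's output embedding
hidden = torch.randn(B, d)

index = HNSWIndexEmbedding(weight.numpy(), M=32, ef=100, ef_construction=150)

# Baseline: full matmul plus exact top-k over all 50,257 logits.
# (A proper benchmark would average many runs.)
t0 = time.perf_counter()
torch.topk(hidden @ weight.T, k, dim=-1)
t_matmul = time.perf_counter() - t0

# HNSW layer: approximate top-k directly from the index.
t0 = time.perf_counter()
index.query(hidden.numpy(), k)
t_hnsw = time.perf_counter() - t0

print(f"speedup: {t_matmul / t_hnsw:.1f}x")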

The end-to-end speedup for text generation with a GPT-2 model is between 2% and 6% (batch size = 4, CPU, full transformer model).

Prediction Quality

Due to the approximate nature of HNSW, this index-based embedding layer can't guarantee the exact top-k elements and may occasionally miss a token. However, since we randomly sample from these tokens, missing one token might not be a major issue, as long as the most important ones are included.

The accuracy of these top-k tokens can be influenced by the HNSW parameters M and ef, where higher values correspond to better accuracy but also longer search times.
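
As a sketch of how this accuracy could be quantified (not part of the original implementation), one can measure the recall of the approximate top k against the exact top k from the full matrix multiplication:

import torch

def topk_recall(index, weight, hidden, k=50):
    # Fraction of the exact top-k token ids also returned by the HNSW index.
    # weight, hidden: torch tensors; index: an HNSWIndexEmbedding instance.
    exact = torch.topk(hidden @ weight.T, k, dim=-1).indices.numpy()
    _, approx = index.query(hidden.numpy(), k)
    hits = [len(set(e) & set(a)) for e, a in zip(exact, approx)]
    return sum(hits) / (k * len(hits))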

I couldn't detect any degradation in the quality of the generated text, but more work is needed to quantify that.

Next Steps

There are several ways to move forward:

  • Get model training working (an estimate of the sum of the remaining logits might be needed for the softmax).
  • Move everything to the GPU. Since HNSW isn’t easily parallelizable, something like CAGRA might be needed.
  • Perform evaluations on quality and performance.
  • Implement top-p sampling (see the sketch below for a possible starting point).
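
A hedged sketch of how top-p (nucleus) sampling could be applied to the k logits returned by the HNSW layer; since the index only returns k candidates, p is effectively restricted to the probability mass within those tokens:

import torch

def sample_top_p(logits, indices, p=0.9):
    # logits, indices: (B, k) numpy arrays as returned by HNSWIndexEmbedding.query.
    logits = torch.as_tensor(logits, dtype=torch.float32)
    indices = torch.as_tensor(indices.astype("int64"))
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, order = torch.sort(probs, dim=-1, descending=True)
    # Keep the smallest set of tokens whose cumulative probability exceeds p.
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return indices.gather(-1, order.gather(-1, choice))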

If you're interested in working on any of these, check out the git repo and consider contributing.

Accelerate.

You can comment on this blog post here.

For updates follow me on 𝕏/twitter: martinloretzzz

I'm available for opportunities focused on solving AGI and all the challenges along the way, especially interested in architectures, frameworks and performance optimizations like this one. Feel free to reach out: work@martinloretz.com.
