Accelerate CPU-based LLM Inference with a Vector Index on the Output Embeddings (up to 30%).
This is an updated and expanded version of the blog post Accelerate GPT Output Embedding computations with a Vector Index.
Motivation
The input and output embedding layers of transformer-based LLMs are massive. For example, GPT-2's embeddings make up 38.6M (50257 x 768) out of 124M total parameters (~30%). To generate text with these models, we use these embeddings to calculate the output probabilities from which we sample the next token.
A popular sampling method is top-k sampling, which samples from the k most probable tokens and so prevents the selection of highly unlikely tokens. For GPT-2-based inference with top-k sampling, we're only interested in the 50 (Hugging Face default for k) most likely tokens out of 50,257 possible tokens, which is less than 0.1%. That's a tiny fraction, but is there a way to get these 50 tokens without calculating logits for all 50k tokens?
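As a minimal sketch of top-k sampling itself (plain Python, not the code used in this post; the function name `top_k_sample` and the toy logits are made up for illustration), note that the full logits list is still needed here, which is exactly the cost the vector index tries to avoid:

```python
import math
import random

def top_k_sample(logits, k, rng=random):
    """Sample a token id from the k highest-logit tokens."""
    # Indices of the k largest logits (full sort for clarity, not speed).
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax restricted to the top-k logits; subtract the max for stability.
    m = logits[top[0]]
    weights = [math.exp(logits[i] - m) for i in top]
    # random.choices normalizes the weights internally.
    return rng.choices(top, weights=weights, k=1)[0]

# Toy example with a 6-token vocabulary (hypothetical values).
logits = [2.0, -1.0, 0.5, 3.0, -4.0, 1.0]
token = top_k_sample(logits, k=3)
```

With k=3, only the tokens 3, 0 and 5 can ever be returned, no matter how often we sample.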
This is the same problem faced by vector databases, where vector indices are used to speed up the search process.
Vector Indices
Vector indices come in two main groups:
Clustering based: IVF (Inverted File Index)
IVF clusters points into buckets around a center point. It struggles with the high dimensionality of the embeddings and can fail when the most probable tokens are outside the searched cluster, leading to incorrect predictions.
Graph based: HNSW (Hierarchical Navigable Small World)
HNSW builds a layered proximity graph and navigates through and down the graph to find the closest points. It performs well with the high-dimensional embeddings and offers a higher throughput than IVF. The long index build times aren't an issue for inference and performance can be tuned using the parameters ef and M.
I tried both approaches but only got good results with HNSW.
HNSW Layer
The HNSW layer is built on top of hnswlib and outputs the indices and logits of the most probable k tokens directly from the last hidden state.
To integrate better into the current ecosystem, which expects the full logits rather than only the top-k elements, we can allocate a tensor in the shape of the logits filled with -∞ (since exp(-∞)=0) and write the top-k logits into it at their corresponding indices.
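import hnswlib

class HNSWIndexEmbedding:
    def __init__(self, weight, M=32, ef=100, ef_construction=100):
        # Inner-product space, so the nearest neighbours are the largest logits.
        self.index = hnswlib.Index(space='ip', dim=weight.shape[1])
        self.index.init_index(max_elements=weight.shape[0], M=M, ef_construction=ef_construction)
        self.index.add_items(weight)  # row i of the embedding matrix gets label i
        self.index.set_ef(ef)

    def query(self, x, k):
        indices, distances = self.index.knn_query(x, k)
        # hnswlib's 'ip' distance is 1 - inner product, so this recovers the logits.
        return 1 - distances, indices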
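The -∞ trick described above can be sketched in plain Python as follows. This is only an illustration under my own assumptions: the helper names `dense_logits_from_topk` and `softmax` are hypothetical, lists stand in for tensors, and the toy values are made up.

```python
import math

def dense_logits_from_topk(top_logits, top_indices, vocab_size):
    """Scatter top-k logits into a full-vocabulary row filled with -inf."""
    logits = [-math.inf] * vocab_size
    for idx, value in zip(top_indices, top_logits):
        logits[idx] = value
    return logits

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: the index returned tokens 4 and 1 for a 6-token vocabulary.
dense = dense_logits_from_topk([2.0, 1.0], [4, 1], vocab_size=6)
probs = softmax(dense)
```

All tokens outside the top-k end up with a probability of exactly 0, so downstream sampling code that expects full logits keeps working unchanged.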
Performance
The main application for these vector index based output layers is accelerating on-device inference for CPU based systems.
Graph based vector indices aren't easily parallelizable on the GPU, so for now we focus on CPU inference. I tried the GPU index CAGRA, but couldn't get a performance improvement for small to medium batch sizes (needs further investigation).
The theoretical time complexity of searching an HNSW index is O(m * log(n)), which is better than the O(m * n) of the matrix multiplication (with n being the vocabulary size and m the embedding dimension). In practice, though, the index search is still a lot slower for small n.
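To get a feeling for these two complexities, here is a rough back-of-the-envelope comparison for GPT-2. It assumes that n is the vocabulary size and m the embedding dimension, and it ignores all constant factors (in HNSW these depend on ef and M), so the result is an idealized ratio and not a prediction of real speedups:

```python
import math

def dense_cost(n, m):
    """Multiply-adds for one hidden state against the full (n x m) matrix."""
    return n * m

def hnsw_cost_estimate(n, m):
    """Idealized O(m * log n) estimate; constant factors are ignored."""
    return m * math.log2(n)

n, m = 50257, 768  # GPT-2 vocabulary size and embedding width
ratio = dense_cost(n, m) / hnsw_cost_estimate(n, m)
```

The measured speedups are far below this idealized ratio, because every search step visits many candidates and the memory access pattern of a graph search is much less regular than that of a matrix multiplication.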
Output Layer Performance
This benchmark compares the CPU performance of the HNSW output layer (GPT-2, with parameters k = 50, M = 32, ef = 100/200, ef_construction = 150) to the matrix multiplication between the final hidden state and the output embedding matrix (on Intel Xeon @ 2.20GHz, Google Colab):
Batch size B | Speedup ef = 100 | Speedup ef = 200 |
---|---|---|
1 | 41.2x | 23.0x |
8 | 9.0x | 4.4x |
64 | 4.9x | 2.5x |
256 | 3.9x | 2.0x |
Full Model Performance
This benchmark measures the inference time for text generation (average of 8 runs on an Intel i7-10750H CPU @ 2.60GHz; M=32, ef=150, ef_construction=5000, except for Llama-3.2-3B with ef_construction=500; vec = vector index, MM = matrix multiplication):
Model | Tokens | Vec time | MM time | Speedup |
---|---|---|---|---|
gpt2 | 64 | 1.88s | 2.52s | 1.34 |
gpt2 | 128 | 3.78s | 4.96s | 1.31 |
Llama-3.2-1B | 64 | 17.45s | 21.69s | 1.24 |
Llama-3.2-1B | 128 | 35.36s | 43.44s | 1.23 |
Llama-3.2-3B | 64 | 52.7s | 59.6s | 1.13 |
Llama-3.2-3B | 128 | 100.1s | 114.4s | 1.15 |
The speedups correlate with the size of the embedding matrix relative to the total parameter count:
Model | Embedding Params | Speedup |
---|---|---|
gpt2 | 30% | 32% |
Llama-3.2-1B | 21% | 23% |
Llama-3.2-3B | 12% | 14% |
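One way to reason about this correlation is Amdahl's law: if the output layer takes a fraction f of the runtime and gets s times faster, the whole model gets 1 / ((1 - f) + f / s) times faster. The sketch below uses the embedding parameter share as a stand-in for f, which is an assumption on my side (parameter share and runtime share aren't the same thing), and the 20x layer speedup is a hypothetical value, not a measurement:

```python
def amdahl_speedup(fraction, layer_speedup):
    """Overall speedup when `fraction` of the runtime gets `layer_speedup`x faster."""
    return 1.0 / ((1.0 - fraction) + fraction / layer_speedup)

# Embedding parameter shares from the table above, used as a stand-in
# for the output layer's share of the runtime (an assumption).
embedding_share = {"gpt2": 0.30, "Llama-3.2-1B": 0.21, "Llama-3.2-3B": 0.12}

# Upper bound: the output layer becomes free (infinitely fast).
upper_bound = {name: 1.0 / (1.0 - f) for name, f in embedding_share.items()}

# Estimate with an assumed 20x faster output layer (hypothetical value).
estimate = {name: amdahl_speedup(f, 20.0) for name, f in embedding_share.items()}
```

Under these assumptions, the achievable speedup is capped by the embedding share, which fits the pattern that models with a relatively smaller embedding matrix benefit less.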
You can run this benchmark with this notebook.
Prediction Quality
Due to the approximate nature of HNSW, vector index-based embedding layers can't guarantee the exact top-k elements and may occasionally miss a token. However, since we randomly sample from these tokens, missing one token might not have a significant effect, as long as the most important tokens are included.
Accuracy is influenced by the parameters M and ef, where higher values correspond to better accuracy but also higher search times.
High values of ef_construction also lead to better accuracy. While building the index takes a lot longer, this isn't really a problem since the index only needs to be built once and can be distributed with the model.
Prediction Quality Benchmark
To quantify the accuracy of the vector index (M=32), we compare the exponential sum of the top-k tokens predicted with the vector index to that of the true top-k tokens, and measure the percentage of tokens where this overlap is under a threshold. So a 1% error for the threshold of 0.5 means that for one in every 100 tokens sampled from the vector index, the distribution the token is sampled from is off by over 50% from the true distribution.
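The following sketch shows one plausible reading of this metric; it is not the benchmark code from the notebook. I assume that the overlap is the exponential sum over the logits of the approximate top-k tokens divided by the exponential sum over the true top-k tokens. The function names and toy values are hypothetical.

```python
import math

def exp_sum_overlap(logits, approx_indices, k):
    """Share of the true top-k exp-sum that the approximate top-k recovers."""
    true_indices = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = logits[true_indices[0]]  # subtract the max for numerical stability
    true_sum = sum(math.exp(logits[i] - m) for i in true_indices)
    approx_sum = sum(math.exp(logits[i] - m) for i in approx_indices)
    return approx_sum / true_sum

def error_rate(overlaps, threshold):
    """Fraction of positions whose overlap falls below the threshold."""
    return sum(o < threshold for o in overlaps) / len(overlaps)

# Toy example with a 4-token vocabulary and k=2.
logits = [5.0, 4.0, 1.0, 0.0]
exact = exp_sum_overlap(logits, approx_indices=[0, 1], k=2)   # index found both tokens
missed = exp_sum_overlap(logits, approx_indices=[1, 2], k=2)  # index missed token 0
rate = error_rate([exact, missed], threshold=0.5)
```

Missing the single most probable token drops the overlap far more than missing a token near the end of the top-k, which is why the metric weights tokens by their exponentiated logits instead of just counting matches.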
ef_construction | ef | <0.50 | <0.75 | <0.90 |
---|---|---|---|---|
200 | 100 | 1.30% | 3.00% | 9.31% |
200 | 150 | 0.78% | 1.63% | 4.95% |
200 | 200 | 0.55% | 1.09% | 3.38% |
5000 | 100 | 0.48% | 1.25% | 4.47% |
5000 | 150 | 0.18% | 0.45% | 1.66% |
5000 | 200 | 0.10% | 0.22% | 0.92% |
Subjectively, I didn't notice any degradation in the quality of the generated text.
Training
Training on only the top-k tokens doesn't really work, because the softmax cross-entropy loss maximizes the probability of the correct token by pushing down the logits of all other tokens.
Having only a few tokens to push down results in slow convergence and higher loss. Only with k values of around 10% of the vocab size does the training work comparably well to the full vocab, but then there aren't any performance benefits from using a vector index. Additionally, since vector indices aren't differentiable, they must be rebuilt periodically as the embeddings change, making them even less suitable for training.
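To illustrate why so few tokens receive a training signal, recall that the gradient of the cross-entropy loss with respect to the logits is softmax(logits) minus the one-hot target. The sketch below (plain Python with made-up toy logits, not the training code behind these experiments) compares the full vocabulary to a top-2 version where all other logits are set to -∞:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_grad(logits, target):
    """Gradient of -log softmax(logits)[target] with respect to the logits."""
    probs = softmax(logits)
    return [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]

full = [2.0, 1.0, 0.5, 0.0, -0.5]
# Keep only the top-2 logits; the rest are masked to -inf.
masked = [2.0, 1.0, -math.inf, -math.inf, -math.inf]

grad_full = cross_entropy_grad(full, target=0)
grad_masked = cross_entropy_grad(masked, target=0)
```

In the full version every wrong token gets a positive gradient that pushes its logit down, while in the masked version the gradient of every token outside the top-k is exactly 0.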
Conclusion
Using a vector index on the output embeddings can significantly speed up inference on CPU based systems, especially for models with large vocabularies. With the ever increasing size of LLMs, the benefits of larger vocabularies and an increased use of inference on edge devices, the use of vector indices might offer promising increases in performance.
The next steps are a hacky integration into llama.cpp and better benchmarks.
Source code for this blog post is in this git repo: martinloretzzz/vector-index-layer
You can comment on this blog post here.
For updates follow me on 𝕏/Twitter: martinloretzzz
I'm available for opportunities focused on solving AGI and all the challenges along the way, especially interested in architectures, frameworks and performance optimizations like this one. Feel free to reach out: work@martinloretz.com.