Optimization techniques for LLM Inference
Why your LLM is slow and some techniques that fix it! ⬇️
When you ask an LLM a question, inference latency is shaped by three layers:
→ Hardware (GPUs, TPUs, LPUs)
→ The model size and architecture
→ Inference engines and strategies
Sampling is memory-bound: during token generation, most of the time goes into loading the entire set of model weights just to produce a single token. Once those weights are loaded, computing the next-token probabilities costs almost nothing.
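A quick back-of-envelope calculation illustrates why (the model size and bandwidth figures below are assumed, round numbers, not measurements):

```python
# Why decoding is memory-bound: every generated token re-reads all weights.
params = 7e9                  # hypothetical 7B-parameter model
bytes_per_weight = 2          # fp16 storage
bandwidth = 1e12              # assumed ~1 TB/s GPU memory bandwidth

weight_bytes = params * bytes_per_weight        # 14 GB to stream per token
time_per_token = weight_bytes / bandwidth       # seconds spent just on loading
print(f"{1 / time_per_token:.0f} tokens/s ceiling")  # ~71 tokens/s
```

The arithmetic, not the matrix multiplies, sets the ceiling — which is why the techniques below focus on moving fewer bytes.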
Here are some of the most popular optimization techniques!
0️⃣ 𝐐𝐮𝐚𝐧𝐭𝐢𝐳𝐚𝐭𝐢𝐨𝐧
→ Store the model’s weights in lower precision (e.g. 4-bit or 8-bit instead of 32-bit)
⇒ 4-8x less memory, 2-4x faster token generation
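A minimal sketch of the idea, using symmetric int8 quantization of a weight matrix (illustrative only — production schemes add per-channel scales, zero-points, etc.):

```python
import numpy as np

def quantize_int8(w):
    scale = np.abs(w).max() / 127.0          # map largest magnitude to 127
    q = np.round(w / scale).astype(np.int8)  # 1 byte per weight instead of 4
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.default_rng(0).normal(size=(4, 4)).astype(np.float32)
q, s = quantize_int8(w)
w_hat = dequantize(q, s)
# int8 storage is 4x smaller than float32; reconstruction error stays tiny
print(np.abs(w - w_hat).max())
```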
1️⃣ 𝐒𝐩𝐞𝐜𝐮𝐥𝐚𝐭𝐢𝐯𝐞 𝐝𝐞𝐜𝐨𝐝𝐢𝐧𝐠
→ A small, fast draft model predicts the next K tokens along with their probabilities
→ The large model scores those same K tokens to get its own probabilities
→ A heuristic compares the two probabilities to accept or reject each draft token
→ Accepted tokens are kept; at the first rejection, generation falls back to the large model
⇒ 2-3x faster token generation
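The accept/reject heuristic can be sketched as a probability-ratio test (toy distributions below are made up; the resampling fall-back step is omitted):

```python
import random

def accept_draft_token(token, draft_p, target_p, rng=random.random):
    # Accept with probability min(1, p_target / p_draft): if the large
    # model likes the token at least as much as the draft did, keep it.
    ratio = target_p[token] / draft_p[token]
    return rng() < min(1.0, ratio)

draft_p  = {0: 0.7, 1: 0.3}   # hypothetical draft-model distribution
target_p = {0: 0.6, 1: 0.4}   # hypothetical large-model distribution

# Token 1: the target assigns it more mass than the draft -> always accepted.
print(accept_draft_token(1, draft_p, target_p))
```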
2️⃣ 𝐊𝐕 𝐂𝐚𝐜𝐡𝐞
→ For each input token, the attention layers compute key and value (KV) vectors
→ Because the decoder is autoregressive, it would otherwise recompute the same KV vectors for the entire prefix at every step, which is wasteful
→ Reuse (cache) previously computed KV vectors instead of recomputing them, resulting in 10-100x faster token generation
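A minimal sketch: append each new token's K and V once, instead of recomputing the whole prefix every decoding step (shapes and values here are placeholders):

```python
import numpy as np

class KVCache:
    def __init__(self):
        self.keys, self.values = [], []

    def step(self, k_new, v_new):
        self.keys.append(k_new)      # cache this token's key...
        self.values.append(v_new)    # ...and its value
        # All past K/V are available with zero recomputation:
        return np.stack(self.keys), np.stack(self.values)

cache = KVCache()
for t in range(3):                   # 3 decoding steps
    K, V = cache.step(np.ones(4) * t, np.ones(4) * t)
print(K.shape)  # (3, 4): one cached key per generated token
```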
3️⃣ 𝐏𝐚𝐠𝐞𝐝𝐀𝐭𝐭𝐞𝐧𝐭𝐢𝐨𝐧 (𝐊𝐕 𝐂𝐚𝐜𝐡𝐞 𝐌𝐞𝐦𝐨𝐫𝐲 𝐌𝐚𝐧𝐚𝐠𝐞𝐦𝐞𝐧𝐭)
→ A naive KV cache manages memory poorly (e.g. it reserves a fixed, contiguous block per request, so a short sequence leaves most of that reservation unused)
→ PagedAttention breaks the cache into small "pages", allocating a new page only when needed ⇒ memory is allocated dynamically, 2-4x higher throughput for all users
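A toy sketch of the allocation idea (page size and the bookkeeping scheme are simplified assumptions, not vLLM's actual implementation):

```python
PAGE_SIZE = 16  # tokens per KV page (hypothetical)

class PagedCache:
    def __init__(self, num_pages):
        self.free = list(range(num_pages))   # pool of physical pages
        self.tables = {}                     # seq_id -> list of page ids

    def append_token(self, seq_id, pos):
        table = self.tables.setdefault(seq_id, [])
        if pos % PAGE_SIZE == 0:             # current page full -> grab one more
            table.append(self.free.pop())
        return table[pos // PAGE_SIZE]       # physical page holding this token

cache = PagedCache(num_pages=8)
for pos in range(20):                        # a 20-token sequence...
    cache.append_token("req-1", pos)
print(len(cache.tables["req-1"]))  # ...occupies only 2 pages, not a max-length reservation
```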
4️⃣ 𝐅𝐥𝐚𝐬𝐡 𝐀𝐭𝐭𝐞𝐧𝐭𝐢𝐨𝐧
→ Most of the compute during inference is matrix multiplication
→ Bottleneck is moving data around, not actual compute
→ Flash Attention moves data into smaller “tiles” that fit inside the chip’s SRAM
⇒ 2-4x faster compute, and frees memory for much longer context windows
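The tiling trick can be sketched in numpy for a single query: process K/V one block at a time with a running ("online") softmax, so the full seq x seq attention matrix is never materialized (real FlashAttention fuses this into one GPU kernel):

```python
import numpy as np

def tiled_attention(q, K, V, tile=4):
    m = float("-inf")                # running max (for stable softmax)
    l = 0.0                          # running softmax normalizer
    acc = np.zeros_like(V[0])        # running weighted sum of values
    for i in range(0, len(K), tile):
        s = K[i:i+tile] @ q          # scores for this tile only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)    # rescale old stats to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[i:i+tile]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
K, V, q = rng.normal(size=(8, 4)), rng.normal(size=(8, 4)), rng.normal(size=4)
ref = np.exp(K @ q - (K @ q).max()); ref /= ref.sum()   # full softmax attention
print(np.allclose(tiled_attention(q, K, V), ref @ V))    # same answer, tile by tile
```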
5️⃣ 𝐂𝐨𝐧𝐭𝐢𝐧𝐮𝐨𝐮𝐬 𝐁𝐚𝐭𝐜𝐡𝐢𝐧𝐠
→ Group incoming requests as they arrive, rather than relying on padding and fixed batch sizes.
→ Chunked prefill: if a prompt is too long, process it in smaller pieces
→ Ragged batching + dynamic scheduling — pack sequences back-to-back; slot in new requests the moment one finishes
⇒ 10-20x higher throughput for all users
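A toy scheduler sketch of the core idea — a waiting request takes a batch slot the moment a running one finishes, instead of waiting for the whole batch to drain (request lengths below are made up):

```python
from collections import deque

def continuous_batching(requests, batch_size):
    waiting = deque(requests)            # (req_id, tokens_left_to_generate)
    running, finished = [], []
    while waiting or running:
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())   # fill free slots immediately
        # one decode step for every running request:
        running = [(rid, left - 1) for rid, left in running]
        still = []
        for rid, left in running:
            (finished if left == 0 else still).append((rid, left))
        running = still                  # finished requests free their slots
    return [rid for rid, _ in finished]

done = continuous_batching([("a", 1), ("b", 3), ("c", 2)], batch_size=2)
print(done)  # "a" finishes after one step, so "c" starts without waiting for "b"
```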
6️⃣ 𝐌𝐨𝐝𝐞𝐥 𝐩𝐫𝐮𝐧𝐢𝐧𝐠
→ Removes unnecessary weights and connections (and sometimes entire layers) to make the model smaller
⇒ ~25% smaller model, 1.3-3x faster token generation
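A minimal sketch of one common variant, magnitude pruning — zero out the smallest-magnitude weights (the 25% sparsity level is just an example; real pipelines usually fine-tune afterward):

```python
import numpy as np

def prune(w, sparsity=0.25):
    threshold = np.quantile(np.abs(w), sparsity)  # cut the bottom fraction
    return np.where(np.abs(w) >= threshold, w, 0.0)

w = np.random.default_rng(0).normal(size=(8, 8))
pruned = prune(w, sparsity=0.25)
# ~25% of weights become zero; sparse kernels can then skip them entirely
print((pruned == 0).mean())
```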
7️⃣ 𝐒𝐦𝐚𝐥𝐥𝐞𝐫 𝐦𝐨𝐝𝐞𝐥𝐬
Often the biggest (and most obvious) latency win is simply using a smaller model that is good enough for the task.