- Sat 28 December 2024
- Large Language Models
- #ml-code, #llm, #vllm, #serving
Paged Attention and Chunked Prefill for LLM Inference
Large Language Models (LLMs) impose significant memory and compute demands during inference, particularly when handling long sequences or large batch sizes. To address these challenges, the vLLM system introduces two key optimization techniques, Paged Attention and Chunked Prefill, that target memory efficiency at different stages of inference. This post examines these techniques, clarifying their roles, mechanisms, and interplay. We demonstrate that they can be jointly employed to optimize both the prefill and decoding phases without conflict. Through tensor shape analysis and execution-flow examples, we illustrate how these strategies improve throughput and memory utilization.
Memory Bottlenecks in LLM Inference
Inference in autoregressive LLMs typically involves two phases:
- Prefill Phase: The model processes a prompt sequence, computes hidden states, and builds key-value caches for subsequent decoding.
- Decoding Phase: The model generates tokens autoregressively, using previously computed key-value caches.
Both phases can be memory-intensive. In the prefill phase, peak memory usage can spike when processing long sequences. In the decoding phase, the cumulative storage of key-value tensors can saturate memory, particularly in multi-user, high-throughput settings.
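As a rough illustration of the decoding-phase pressure, the sketch below estimates KV-cache size for a hypothetical model configuration; the layer count, head geometry, and fp16 cache dtype are assumptions chosen purely for the arithmetic.

```python
# Back-of-the-envelope KV-cache size (all model dimensions here are illustrative assumptions).
def kv_cache_bytes(batch_size, seq_len, n_layers, n_head, d_head, bytes_per_elem=2):
    # 2x for keys and values, cached at every layer for every token in every sequence.
    return 2 * batch_size * seq_len * n_layers * n_head * d_head * bytes_per_elem

# A hypothetical 32-layer model with 32 heads of dimension 128, fp16 cache:
size = kv_cache_bytes(batch_size=8, seq_len=4096, n_layers=32, n_head=32, d_head=128)
print(f"KV cache: {size / 1024**3:.1f} GiB")  # ~16 GiB before a single new token is generated
```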
Paged Attention
Paged Attention is a memory management strategy that organizes key-value tensors into fixed-size blocks. Instead of pre-allocating memory proportional to the maximum possible sequence length, Paged Attention dynamically allocates and pages these blocks into memory as needed during attention score computation.
Mechanism
Let the input sequence have length L, batch size B, hidden size H, number of heads n_head, and head dimension d_head = H / n_head. Without Paged Attention, key-value tensors are stored as:
Keys, Values ∈ ℝ^(B × L × n_head × d_head)
Paged Attention partitions these tensors along the sequence dimension into blocks of size block_size. Each block stores:
Block Tensor ∈ ℝ^(B × block_size × n_head × d_head)
During attention computation at position t, only the relevant blocks containing key-value pairs for positions 0 through t are loaded into memory.
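A minimal sketch of this partitioning, assuming the sequence length is a multiple of block_size; the contiguous reshape below is only for illustration and is not how vLLM physically lays out its cache.

```python
import torch

B, L, n_head, d_head, block_size = 2, 12, 2, 4, 4

keys = torch.randn(B, L, n_head, d_head)        # unpaged layout: (B, L, n_head, d_head)

# Partition the sequence dimension into fixed-size blocks (assumes L % block_size == 0).
num_blocks = L // block_size
paged_keys = keys.view(B, num_blocks, block_size, n_head, d_head)
print(paged_keys.shape)                         # torch.Size([2, 3, 4, 2, 4])
```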
Example Illustration
Consider the following configuration:
- Sequence length L = 12
- Batch size B = 2
- Hidden size H = 8
- Number of heads n_head = 2
- Head dimension d_head = 4 (since H = n_head × d_head)
- Block size = 4
The key-value tensor shapes without Paged Attention would be:
Keys, Values ∈ ℝ^(2 × 12 × 2 × 4)
With Paged Attention, the sequence is divided into:
- Block 1: Positions 0–3 → Shape (2, 4, 2, 4)
- Block 2: Positions 4–7 → Shape (2, 4, 2, 4)
- Block 3: Positions 8–11 → Shape (2, 4, 2, 4)
When computing attention for position t = 9, the model dynamically loads:
- Block 1: Positions 0–3
- Block 2: Positions 4–7
- First two positions of Block 3: Positions 8–9
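The same block bookkeeping can be written as a small helper; the function below is a hypothetical illustration (blocks are 0-indexed here, whereas the list above numbers them from 1).

```python
block_size = 4

def blocks_needed(t, block_size):
    """Blocks required to attend to positions 0..t: full blocks, plus the prefix of the last one."""
    return t // block_size, t % block_size + 1

full, prefix = blocks_needed(t=9, block_size=block_size)
print(f"Blocks 0..{full - 1} in full, plus the first {prefix} positions of block {full}")
# -> Blocks 0..1 in full, plus the first 2 positions of block 2
```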
Chunked Prefill
Chunked Prefill is an optimization technique applied during the prefill phase to control peak memory usage. It divides the input sequence into fixed-size chunks and processes them sequentially, limiting memory allocation for hidden state and key-value computation.
Mechanism
Given a sequence of length L, the sequence is partitioned into chunks of size chunk_size. For each chunk:
- The model computes the hidden states for tokens within the chunk.
- It computes and caches the key-value tensors corresponding to these tokens.
Chunked Prefill performs full attention computation within each chunk, meaning each token’s query attends to all preceding tokens, including those in prior chunks.
However, during prefill, the model does not compute attention outputs for future decoding tokens; it only prepares the key-value cache required for the decoding phase.
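A minimal sketch of this chunked cache-building loop, using toy linear projections as a stand-in for a real transformer layer (the Wk/Wv matrices here are hypothetical, not part of vLLM).

```python
import torch

B, L, H, n_head, d_head = 2, 12, 8, 2, 4
chunk_size = 6
x = torch.randn(B, L, H)                         # embeddings for the full prompt

# Toy projections standing in for a transformer layer's key/value heads.
Wk = torch.randn(H, n_head * d_head)
Wv = torch.randn(H, n_head * d_head)

k_cache, v_cache = [], []
for start in range(0, L, chunk_size):
    chunk = x[:, start:start + chunk_size]       # only this chunk is materialized at once
    k_cache.append((chunk @ Wk).view(B, -1, n_head, d_head))
    v_cache.append((chunk @ Wv).view(B, -1, n_head, d_head))

K = torch.cat(k_cache, dim=1)                    # (B, L, n_head, d_head), ready for decoding
V = torch.cat(v_cache, dim=1)
print(K.shape, V.shape)
```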
Example
Using the same configuration:
- Sequence length L = 12
- Batch size B = 2
- Hidden size H = 8
- Number of heads n_head = 2
- Head dimension d_head = 4
- Chunk size = 6
The sequence is divided into:
- Chunk 1: Positions 0–5
- Chunk 2: Positions 6–11
For each chunk:
- Hidden state tensor shape: (2, 6, 8)
- Key-value tensor shape: (2, 6, 2, 4)
Attention Computation during Prefill:
For the token at position t = 4 in Chunk 1, attention is computed over positions 0–4 within Chunk 1 (the token attends to itself and all preceding positions).
For the token at position t = 7 in Chunk 2, attention is computed over:
- Cached key-values from Chunk 1 (positions 0–5)
- Positions 6–7 within Chunk 2
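As a sketch of the second case, assuming Chunk 1's keys are already cached, the context for the query at position 7 is the concatenation of that cache with the in-chunk keys up to and including position 7.

```python
import torch

B, n_head, d_head, chunk_size = 2, 2, 4, 6

k_chunk1 = torch.randn(B, chunk_size, n_head, d_head)   # cached during Chunk 1 (positions 0-5)
k_chunk2 = torch.randn(B, chunk_size, n_head, d_head)   # computed during Chunk 2 (positions 6-11)

t = 7                                                   # query position inside Chunk 2
in_chunk = t - chunk_size + 1                           # positions 6..7 -> 2 tokens
context_k = torch.cat([k_chunk1, k_chunk2[:, :in_chunk]], dim=1)
print(context_k.shape)                                  # torch.Size([2, 8, 2, 4]) -> positions 0..7
```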
Combined Example and Execution Flow
Let:
- Sequence length L = 12
- Batch size B = 2
- Hidden size H = 8
- Number of heads n_head = 2
- Head dimension d_head = 4
- Chunk size = 6
- Block size = 4
Prefill Phase
The model processes:
- Chunk 1 (positions 0–5):
  - Hidden states: (2, 6, 8)
  - Key-value tensors cached: (2, 6, 2, 4)
- Chunk 2 (positions 6–11):
  - Hidden states: (2, 6, 8)
  - Key-value tensors cached: (2, 6, 2, 4)
Each token in Chunk 2 attends to all cached key-value pairs from Chunk 1 and preceding tokens in Chunk 2.
Decoding Phase
During decoding, suppose the model is generating the token at position t = 12:
- The model computes attention for the query at t = 12 over all cached key-value tensors from positions 0–11.
- Key-value tensors are organized into blocks, and Paged Attention dynamically loads the necessary blocks during attention computation:
| Block | Positions | Shape |
|---|---|---|
| 1 | 0–3 | (2, 4, 2, 4) |
| 2 | 4–7 | (2, 4, 2, 4) |
| 3 | 8–11 | (2, 4, 2, 4) |
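A minimal sketch of this decode step, assuming the three blocks above are already resident in a toy cache; for brevity the same tensor is reused for keys and values, mirroring the simplification in the walkthrough later in this post.

```python
import torch

B, n_head, d_head, block_size, num_blocks = 2, 2, 4, 4, 3
kv_blocks = torch.randn(B, num_blocks, block_size, n_head, d_head)   # cached K (reused as V below)

t = 12                                               # decoding position; positions 0..11 are cached
n_needed = (t + block_size - 1) // block_size        # 3 blocks cover positions 0..11
K = kv_blocks[:, :n_needed].reshape(B, n_needed * block_size, n_head, d_head)

q = torch.randn(B, n_head, d_head)                   # the single query token at position t
scores = torch.einsum("bhd,bthd->bht", q, K) / (d_head ** 0.5)
probs = torch.softmax(scores, dim=-1)
out = torch.einsum("bht,bthd->bhd", probs, K)        # reusing K as V for brevity
print(out.shape)                                     # torch.Size([2, 2, 4])
```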
Tensor Shapes
| Component | Tensor Shape |
|---|---|
| Key/Value tensor (no paging) | (B, L, n_head, d_head) |
| Key/Value tensor in Paged Attention | (B, block_size, n_head, d_head) per block |
| Key/Value tensor in Chunked Prefill | (B, chunk_size, n_head, d_head) per chunk |
A brief outline of how block_size and chunk_size are used in the attention backends is provided below.
Paged Attention and block_size:
Paged Attention manages the Key-Value (KV) cache by dividing it into fixed-size blocks, allowing dynamic memory allocation and efficient handling of varying sequence lengths. The block_size parameter determines the number of tokens each block can store. In vLLM, this is evident in the attention backend implementations.
Example from flash_attn.py:

```python
# Each block can contain up to block_size tokens.
block_tables: Optional[torch.Tensor]
```
In this snippet, block_tables is a tensor that maps sequences to their respective blocks in the KV cache, with each block capable of holding up to block_size tokens. This structure enables efficient memory usage by allocating only the necessary amount of memory based on the sequence length.
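The sketch below is a toy illustration of this idea: a free-list of physical blocks and a per-sequence table that grows one block at a time. It is not vLLM's actual allocator, only the shape of the mapping.

```python
# Toy block-table allocation (illustrative only; not vLLM's implementation).
block_size = 4
free_blocks = list(range(16))            # pool of physical KV-cache blocks
block_tables = {}                        # sequence id -> list of physical block ids

def ensure_capacity(seq_id, seq_len):
    """Grab a fresh physical block only when the sequence crosses a block boundary."""
    table = block_tables.setdefault(seq_id, [])
    if seq_len % block_size == 0:
        table.append(free_blocks.pop(0))
    return table

for pos in range(10):                    # sequence 0 grows to 10 tokens
    ensure_capacity(seq_id=0, seq_len=pos)
print(block_tables[0])                   # 10 tokens fit in 3 physical blocks, e.g. [0, 1, 2]
```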
Chunked Prefill and chunk_size:
Chunked Prefill processes long input sequences by dividing them into smaller chunks, reducing peak memory usage during the prefill phase. The chunk_size parameter defines the size of these chunks. In vLLM, the attention backends handle chunking to ensure that memory constraints are respected during computation.
Example from pallas.py:

```python
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size
```
Here, the code ensures that chunk_size is an even number and calculates num_chunks, the total number of chunks needed to process the batch. This approach allows the model to handle large sequences by processing manageable chunks sequentially, thereby optimizing memory usage.
Integration in Attention Computation:
The interplay between chunk size and block size exposes a tunable trade-off between latency and memory efficiency. A smaller chunk size reduces peak memory during prefill but may increase latency due to more sequential chunk processing. A smaller block size reduces persistent memory overhead but may increase the number of memory page operations, potentially affecting throughput.
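The toy calculator below makes this trade-off concrete under crude assumptions (fp16 activations, peak activation memory roughly linear in chunk size, one forward pass per chunk); the numbers are illustrative, not measurements.

```python
# Crude chunk-size trade-off estimate (all constants are illustrative assumptions).
def prefill_profile(seq_len, chunk_size, hidden, n_layers, bytes_per_elem=2):
    num_chunks = (seq_len + chunk_size - 1) // chunk_size       # sequential passes -> latency
    peak_act = chunk_size * hidden * n_layers * bytes_per_elem  # rough peak activation memory
    return num_chunks, peak_act

for chunk in (512, 2048, 8192):
    passes, peak = prefill_profile(seq_len=8192, chunk_size=chunk, hidden=4096, n_layers=32)
    print(f"chunk_size={chunk:5d}: {passes:2d} passes, ~{peak / 1024**2:.0f} MiB peak activations")
```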
Adaptive strategies that adjust chunk and block sizes based on sequence length, batch size, and memory pressure may further improve resource utilization.
During attention computation, especially in scenarios involving long sequences or large batches, vLLM's attention backends utilize both block_size and chunk_size to manage memory effectively. The KV cache is organized into blocks of size block_size, and input sequences are processed in chunks of size chunk_size. This dual approach ensures that the model can handle extensive inputs without exceeding memory limitations.
Example from flashinfer.py:

```python
# Get the number of valid blocks based on sequence length.
block_table_bound = (seq_len // self.block_size + 1
                     if seq_len % self.block_size != 0
                     else seq_len // self.block_size)
```
This snippet calculates the number of valid blocks required for a given sequence length, ensuring that the KV cache is allocated appropriately. By managing the KV cache in blocks and processing input in chunks, vLLM effectively balances memory usage and computational efficiency during attention operations.
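A quick, standalone check of that formula (with self.block_size replaced by a plain variable and an assumed block size of 16) shows how the bound rounds up whenever the sequence spills into a partial block.

```python
block_size = 16   # assumed value for illustration

for seq_len in (1, 16, 17, 100):
    block_table_bound = (seq_len // block_size + 1
                         if seq_len % block_size != 0
                         else seq_len // block_size)
    print(seq_len, "->", block_table_bound)   # 1 -> 1, 16 -> 1, 17 -> 2, 100 -> 7
```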
Sample code structure with Paged Attention and Chunked Prefill
To provide further clarity on the operational dynamics of Paged Attention and Chunked Prefill, the following code walkthrough models the mechanisms used in the vLLM system. The Python code below captures the essential logic behind how the key-value (KV) cache is organized into blocks and how input sequences are processed in chunks.
This example is based on the configuration discussed earlier:
- Batch size B = 2
- Sequence length L = 12
- Hidden size H = 8
- Number of attention heads n_head = 2
- Head dimension d_head = 4
- Block size = 4
- Chunk size = 6
```python
import torch

# Configuration
B = 2                   # Batch size
L = 12                  # Sequence length
H = 8                   # Hidden size
n_head = 2              # Number of heads
d_head = H // n_head    # Head dimension
block_size = 4          # Paged Attention block size
chunk_size = 6          # Chunked Prefill size

# Initialize the key-value cache and its block table
num_blocks = (L + block_size - 1) // block_size

# Each sequence in the batch has its own block table mapping logical block idx to physical block
block_table = torch.arange(num_blocks).repeat(B, 1)                  # [B, num_blocks]
kv_cache = torch.randn(B, num_blocks, block_size, n_head, d_head)    # KV store

# Dummy queries for the current chunk
Q = torch.randn(B, chunk_size, n_head, d_head)

# Prefill processing loop (Chunked Prefill)
for b in range(B):
    for chunk_start in range(0, L, chunk_size):
        chunk_end = min(chunk_start + chunk_size, L)
        curr_chunk_len = chunk_end - chunk_start
        q = Q[b, :curr_chunk_len]            # [chunk_len, n_head, d_head]

        # Accumulate attention outputs
        attn_out = []

        # Compute attention over previous tokens using Paged Attention blocks
        for t in range(chunk_start, chunk_end):
            # For the query at position t, gather all keys from positions [0, t]
            t_block_idx = t // block_size

            context_keys = []
            context_values = []

            # Page in the required blocks
            for blk_id in range(t_block_idx + 1):          # include the current block
                physical_block = block_table[b, blk_id]
                kv_block = kv_cache[b, physical_block]     # [block_size, n_head, d_head]
                context_keys.append(kv_block)              # simplified: keys = values
                context_values.append(kv_block)

            # Flatten the context into a single tensor [t + 1, n_head, d_head]
            K = torch.cat(context_keys, dim=0)[:t + 1]     # causal: positions 0..t
            V = torch.cat(context_values, dim=0)[:t + 1]

            q_t = q[t - chunk_start]                       # query vector at position t

            # Scaled dot-product attention, computed per head
            scores = torch.einsum("hd,thd->ht", q_t, K) / (d_head ** 0.5)
            probs = torch.softmax(scores, dim=-1)
            ctx = torch.einsum("ht,thd->hd", probs, V)

            attn_out.append(ctx)

        # Stack outputs: [chunk_len, n_head, d_head]
        attn_out = torch.stack(attn_out, dim=0)
        print(f"Chunk [{chunk_start}:{chunk_end}] attention output shape:", attn_out.shape)
```
Operational Flow
- The block_table holds logical-to-physical mappings of KV blocks for each sequence in the batch. Each block can store block_size tokens.
- The KV cache (kv_cache) is allocated as [B, num_blocks, block_size, n_head, d_head].
- For each chunk (driven by chunk_size), queries are processed causally: each query at position t attends to all previous tokens (positions 0 to t) using the paged blocks.
- Only the relevant blocks (from block 0 up to block ⌊t / block_size⌋) are paged in.
- Attention is computed with the standard scaled dot-product mechanism using the concatenated keys and values.
Paged Attention ensures that the key-value cache is organized into fixed-size blocks. Each token’s query accesses the minimum required set of blocks based on its position t. This reduces persistent memory footprint and supports dynamic allocation.
Chunked Prefill ensures that the input sequences are processed in chunks of size 6. For each chunk, attention scores are computed incrementally over all previously processed tokens, including those in earlier chunks. This limits peak memory usage during the prefill phase.
Dynamic paging is set up so that for each token t in a chunk, only the blocks containing keys/values from positions 0 to t are loaded, ensuring efficient memory utilization.
Conclusion
Paged Attention and Chunked Prefill are orthogonal strategies for optimizing LLM inference in memory-constrained environments. Paged Attention reduces persistent memory consumption during attention computation by dynamically paging key-value blocks. Chunked Prefill mitigates peak memory usage during the prefill phase by processing input sequences in chunks. There is also ongoing exploration of kernel-level optimizations, fused attention computation across blocks, and dynamic eviction policies for the block table to further improve inference throughput under multi-user, long-sequence workloads.
References
If you found this useful, please cite this post using
Senthilkumar Gopal. (Dec 2024). Paged Attention and Chunked Prefill for LLM Inference. sengopal.me. https://sengopal.me/posts/paged-attention-and-chunked-prefill-for-llm-inference
or
```
@article{gopal2024pagedattentionandchunkedprefillforllminference,
  title   = {Paged Attention and Chunked Prefill for LLM Inference},
  author  = {Senthilkumar Gopal},
  journal = {sengopal.me},
  year    = {2024},
  month   = {Dec},
  url     = {https://sengopal.me/posts/paged-attention-and-chunked-prefill-for-llm-inference}
}
```