Paged Attention and Chunked Prefill for LLM Inference

Large Language Models (LLMs) impose significant memory and compute demands during inference, particularly when handling long sequences or large batch sizes. To address these challenges, the vLLM system introduces two key optimization techniques, Paged Attention [1] and Chunked Prefill [2], that target memory efficiency at different stages of inference. This post examines these techniques, clarifying their roles, mechanisms, and interplay. We demonstrate that they can be jointly employed to optimize both the prefill and decoding phases without conflict. Through tensor shape analysis and execution flow, we illustrate how these strategies improve throughput and memory utilization.

Memory Bottlenecks in LLM Inference

Inference in autoregressive LLMs typically involves two phases:

  1. Prefill Phase: The model processes a prompt sequence, computes hidden states, and builds key-value caches for subsequent decoding.
  2. Decoding Phase: The model generates tokens autoregressively, using previously computed key-value caches.

Both phases can be memory-intensive. In the prefill phase, peak memory usage can spike when processing long sequences. In the decoding phase, the cumulative storage of key-value tensors can saturate memory, particularly in multi-user, high-throughput settings.
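As a minimal sketch of how the two phases treat the key-value cache (dummy tensors stand in for a real model here), prefill populates the cache for the entire prompt at once, while decoding appends one entry per generated token:

import torch

# Toy dimensions for illustration only.
B, L_prompt, n_head, d_head = 1, 6, 2, 4

# Prefill: compute and cache keys/values for the whole prompt in one pass.
prompt_states = torch.randn(B, L_prompt, n_head, d_head)
k_cache = prompt_states.clone()   # stand-in for the projected keys
v_cache = prompt_states.clone()   # stand-in for the projected values

# Decode: every generated token appends exactly one key/value entry.
for step in range(3):
    new_kv = torch.randn(B, 1, n_head, d_head)
    k_cache = torch.cat([k_cache, new_kv], dim=1)
    v_cache = torch.cat([v_cache, new_kv], dim=1)
    print("cache length after decode step", step + 1, "->", k_cache.shape[1])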

Paged Attention

Paged Attention is a memory management strategy that organizes key-value tensors into fixed-size blocks. Instead of pre-allocating memory proportional to the maximum possible sequence length, Paged Attention dynamically allocates and pages these blocks into memory as needed during attention score computation.

Mechanism

Let the input sequence have length L, batch size B, hidden size H, and number of heads n_head. Without Paged Attention, key-value tensors are stored as:

Keys, Values ∈ ℝ^(B × L × n_head × d_head)

Paged Attention partitions these tensors along the sequence dimension into blocks of size block_size. Each block stores:

Block Tensor ∈ ℝ^(B × block_size × n_head × d_head)

During attention computation for the query at position t, only the blocks containing key-value pairs for positions 0 through t need to be resident in memory.
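The effect on tensor shapes can be seen with a toy reshape. In vLLM the blocks live in a pooled cache indexed through a block table rather than a contiguous reshape; this sketch also assumes L is a multiple of block_size:

import torch

B, L, n_head, d_head, block_size = 2, 12, 2, 4, 4

keys = torch.randn(B, L, n_head, d_head)                        # unpaged layout
blocks = keys.reshape(B, L // block_size, block_size, n_head, d_head)
print(blocks.shape)   # torch.Size([2, 3, 4, 2, 4]) -> 3 blocks of 4 positions each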

Example Illustration

Consider the following configuration:

  • Sequence length L = 12
  • Batch size B = 2
  • Hidden size H = 8
  • Number of heads n_head = 2
  • Head dimension d_head = 4 (since H = n_head × d_head)
  • Block size = 4

The key-value tensor shapes without Paged Attention would be:

Keys, Values ∈ ℝ^(2 × 12 × 2 × 4)

With Paged Attention, the sequence is divided into:

  • Block 1: Positions 0–3 → Shape (2, 4, 2, 4)
  • Block 2: Positions 4–7 → Shape (2, 4, 2, 4)
  • Block 3: Positions 8–11 → Shape (2, 4, 2, 4)

When computing attention for position t = 9, the model dynamically loads the following blocks (reproduced in the sketch after this list):

  • Block 1: Positions 0–3
  • Block 2: Positions 4–7
  • First two positions of Block 3: Positions 8–9
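A minimal sketch of this block selection, assuming contiguous logical blocks; blocks_for_position is a hypothetical helper for illustration, not a vLLM API, and its indices are 0-based whereas the list above numbers blocks from 1:

block_size = 4

def blocks_for_position(t: int, block_size: int):
    # Blocks holding key-value pairs for positions 0..t (inclusive, causal).
    last_block = t // block_size
    valid_in_last_block = t % block_size + 1
    return list(range(last_block + 1)), valid_in_last_block

blocks, valid = blocks_for_position(9, block_size)
print(blocks, valid)   # [0, 1, 2] 2 -> three blocks, only two positions of the last one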

Chunked Prefill

Chunked Prefill is an optimization technique applied during the prefill phase to control peak memory usage. It divides the input sequence into fixed-size chunks and processes them sequentially, limiting memory allocation for hidden state and key-value computation.

Mechanism

Given a sequence of length L, the sequence is partitioned into chunks of size chunk_size. For each chunk:

  1. The model computes the hidden states for tokens within the chunk.
  2. It computes and caches the key-value tensors corresponding to these tokens.

Chunked Prefill still performs full causal attention for every token: each token’s query attends to all preceding tokens, including those whose key-value pairs were cached from prior chunks.

However, during prefill, the model does not compute attention outputs for future decoding tokens; it only prepares the key-value cache required for the decoding phase.
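A minimal sketch of this chunked pass, with dummy hidden states standing in for the transformer layers (the actual attention math is omitted here and shown in the combined example later):

import torch

B, L, H, chunk_size = 2, 12, 8, 6

hidden = torch.randn(B, L, H)          # stand-in for the embedded prompt
k_cache, v_cache = [], []

for start in range(0, L, chunk_size):
    end = min(start + chunk_size, L)
    chunk_hidden = hidden[:, start:end]        # peak activation: (B, chunk_size, H)
    # A real model would run the transformer layers on this chunk; we reuse the
    # hidden states as stand-ins for the projected keys and values.
    k_cache.append(chunk_hidden)
    v_cache.append(chunk_hidden)

k_cache = torch.cat(k_cache, dim=1)            # (B, L, H) once all chunks are processed
v_cache = torch.cat(v_cache, dim=1)
print(k_cache.shape)                           # torch.Size([2, 12, 8])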

Example

Using the same configuration:

  • Sequence length L = 12
  • Batch size B = 2
  • Hidden size H = 8
  • Number of heads n_head = 2
  • Head dimension d_head = 4
  • Chunk size = 6

The sequence is divided into:

  • Chunk 1: Positions 0–5
  • Chunk 2: Positions 6–11

For each chunk:

  • Hidden state tensor shape: (2, 6, 8)
  • Key-value tensor shape: (2, 6, 2, 4)

Attention Computation during Prefill:

For the token at position t = 4 in Chunk 1, attention is computed over positions 0–4 (all earlier tokens in Chunk 1 plus the token itself).
For the token at position t = 7 in Chunk 2, attention is computed over (see the sketch after this list):

  • Cached key-values from Chunk 1 (positions 0–5)
  • Positions 6–7 within Chunk 2 (the preceding token plus the token itself)
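A toy calculation of how many key-value pairs each query touches under this scheme (positions are 0-based and include the token itself):

chunk_size = 6

for t in (4, 7):
    chunk_idx = t // chunk_size
    cached_from_prior_chunks = chunk_idx * chunk_size
    within_current_chunk = t - cached_from_prior_chunks + 1   # includes the token itself
    print(f"t={t}: {cached_from_prior_chunks} cached + {within_current_chunk} in-chunk "
          f"= {t + 1} key-value pairs")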

Combined Example and Execution Flow

Let:

  • Sequence length L = 12
  • Batch size B = 2
  • Hidden size H = 8
  • Number of heads n_head = 2
  • Head dimension d_head = 4
  • Chunk size = 6
  • Block size = 4

Prefill Phase

The model processes:

  • Chunk 1 (positions 0–5):
    • Hidden states: (2, 6, 8)
    • Key-value tensors cached: (2, 6, 2, 4)
  • Chunk 2 (positions 6–11):
    • Hidden states: (2, 6, 8)
    • Key-value tensors cached: (2, 6, 2, 4)

Each token in Chunk 2 attends to all cached key-value pairs from Chunk 1 and preceding tokens in Chunk 2.

Decoding Phase

During decoding, suppose the model is generating token at position t = 12:

  • The model computes attention for the query at t = 12 over all cached key-value tensors from positions 0–11.
  • Key-value tensors are organized into fixed-size blocks, and Paged Attention dynamically loads the necessary blocks during attention computation (a toy decode step is sketched after the table):

Block   Positions   Shape
1       0–3         (2, 4, 2, 4)
2       4–7         (2, 4, 2, 4)
3       8–11        (2, 4, 2, 4)
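A minimal sketch of that decode step, assuming the cache was already filled during prefill and reusing the same tensor for keys and values to keep the example short:

import torch

B, n_head, d_head, block_size, num_blocks = 2, 2, 4, 4, 3

kv_cache = torch.randn(B, num_blocks, block_size, n_head, d_head)   # filled during prefill
q = torch.randn(B, n_head, d_head)                                  # query for position t = 12

# Page in all three blocks and flatten them back into a sequence of 12 positions.
keys = kv_cache.reshape(B, num_blocks * block_size, n_head, d_head)
values = keys                                                        # simplified: keys = values

scores = torch.einsum("bhd,bthd->bht", q, keys) / (d_head ** 0.5)
probs = torch.softmax(scores, dim=-1)
context = torch.einsum("bht,bthd->bhd", probs, values)
print(context.shape)   # torch.Size([2, 2, 4])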


Tensor Shapes

Component                             Tensor Shape
Key/Value tensor (no paging)          (B, L, n_head, d_head)
Key/Value tensor in Paged Attention   (B, block_size, n_head, d_head) per block
Key/Value tensor in Chunked Prefill   (B, chunk_size, n_head, d_head) per chunk

The following sections outline how block_size and chunk_size are used in vLLM’s attention backends.

Paged Attention and block_size:

Paged Attention manages the Key-Value (KV) cache by dividing it into fixed-size blocks, allowing dynamic memory allocation and efficient handling of varying sequence lengths. The block_size parameter determines the number of tokens each block can store. In vLLM, this is evident in the attention backend implementations.

Example from flash_attn.py:

# Each block can contain up to block_size tokens.
block_tables: Optional[torch.Tensor]

In this snippet, block_tables is a tensor that maps sequences to their respective blocks in the KV cache, with each block capable of holding up to block_size tokens. This structure enables efficient memory usage by allocating only the necessary amount of memory based on the sequence length.
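A toy illustration of the idea behind such a block table; the free-list allocator below is hypothetical and far simpler than vLLM’s actual block manager:

import torch

block_size = 4
num_physical_blocks = 16
free_blocks = list(range(num_physical_blocks))        # naive free-list of physical blocks

def allocate_block_table(seq_len: int) -> torch.Tensor:
    # One entry per logical block; each entry holds a physical block id.
    num_logical_blocks = (seq_len + block_size - 1) // block_size
    return torch.tensor([free_blocks.pop(0) for _ in range(num_logical_blocks)])

print(allocate_block_table(12))   # tensor([0, 1, 2])  -> 3 blocks for 12 tokens
print(allocate_block_table(5))    # tensor([3, 4])     -> 2 blocks for 5 tokens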

Chunked Prefill and chunk_size:

Chunked Prefill processes long input sequences by dividing them into smaller chunks, reducing peak memory usage during the prefill phase. The chunk_size parameter defines the size of these chunks. In vLLM, the attention backends handle chunking to ensure that memory constraints are respected during computation.

Example from pallas.py:

# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size

Here, the code ensures that chunk_size is an even number and calculates num_chunks, the total number of chunks needed to process the batch. This approach allows the model to handle large sequences by processing manageable chunks sequentially, thereby optimizing memory usage.

Integration in Attention Computation:

The interplay between chunk size and block size exposes a tunable trade-off between latency and memory efficiency. A smaller chunk size reduces peak memory during prefill but may increase latency due to more sequential chunk processing. A smaller block size reduces persistent memory overhead but may increase the number of memory page operations, potentially affecting throughput.

During attention computation, especially for long sequences or large batches, vLLM’s attention backends use both block_size and chunk_size to manage memory: the KV cache is organized into blocks of size block_size, while input sequences are processed in chunks of size chunk_size. This dual approach lets the model handle extensive inputs without exceeding memory limits. Adaptive strategies that adjust chunk and block sizes based on sequence length, batch size, and memory pressure may further improve resource utilization.

Example from flashinfer.py:

# Get the number of valid blocks based on sequence length.
block_table_bound = seq_len // self.block_size + 1 if seq_len % self.block_size != 0 else seq_len // self.block_size

This snippet calculates the number of valid blocks required for a given sequence length, ensuring that the KV cache is allocated appropriately. By managing the KV cache in blocks and processing input in chunks, vLLM effectively balances memory usage and computational efficiency during attention operations.
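To make the block-size trade-off concrete, here is a back-of-the-envelope sketch of internal fragmentation in the KV cache; the model dimensions below (32 layers, 32 heads, head dimension 128, fp16) are assumed for illustration and are not taken from this post:

def kv_cache_bytes(seq_len, block_size, n_layer=32, n_head=32, d_head=128,
                   dtype_bytes=2, kv_per_token=2):
    # Bytes actually reserved when the sequence is rounded up to whole blocks.
    num_blocks = (seq_len + block_size - 1) // block_size
    return num_blocks * block_size * n_layer * n_head * d_head * dtype_bytes * kv_per_token

for block_size in (8, 16, 32):
    wasted = kv_cache_bytes(100, block_size) - kv_cache_bytes(100, 1)
    print(f"block_size={block_size}: ~{wasted / 2**20:.0f} MiB reserved but unused "
          f"for a 100-token sequence")

Larger blocks waste more memory at the tail of each sequence but require fewer block-table entries and lookups, which is the trade-off described above.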

Sample code structure with Paged Attention and Chunked Prefill

To provide further clarity on the operational dynamics of Paged Attention and Chunked Prefill, the following code walkthrough models the mechanisms used in the vLLM system. The Python code below captures the essential logic behind how the key-value (KV) cache is organized into blocks and how input sequences are processed in chunks.

This example is based on the configuration discussed earlier:

  • Batch size B = 2
  • Sequence length L = 12
  • Hidden size H = 8
  • Number of attention heads n_head = 2
  • Head dimension d_head = 4
  • Block size = 4
  • Chunk size = 6

import torch

# Configuration
B = 2                     # Batch size
L = 12                    # Sequence length
H = 8                     # Hidden size
n_head = 2                # Number of heads
d_head = H // n_head      # Head dimension
block_size = 4            # Paged Attention block size
chunk_size = 6            # Chunked Prefill size

# Initialize key-value cache using block table
num_blocks = (L + block_size - 1) // block_size

# Each sequence in batch has its own block table mapping token idx to physical block
block_table = torch.arange(num_blocks).repeat(B, 1)  # [B, num_blocks]
kv_cache = torch.randn(B, num_blocks, block_size, n_head, d_head)  # KV store (dummy values for illustration)

# Dummy queries for current chunk
Q = torch.randn(B, chunk_size, n_head, d_head)

# Prefill processing loop (Chunked Prefill)
for b in range(B):
    for chunk_start in range(0, L, chunk_size):
        chunk_end = min(chunk_start + chunk_size, L)
        curr_chunk_len = chunk_end - chunk_start
        q = Q[b, :curr_chunk_len]  # [chunk_len, n_head, d_head]

        # Accumulate attention outputs
        attn_out = []

        # Compute attention over previous tokens using Paged Attention blocks
        for t in range(chunk_start, chunk_end):
            # For query at position t, get all keys from positions [0, t]
            t_block_idx = t // block_size

            context_keys = []
            context_values = []

            # Page in required blocks
            for blk_id in range(t_block_idx + 1):  # include current block
                physical_block = block_table[b, blk_id]
                kv_block = kv_cache[b, physical_block]   # [block_size, n_head, d_head]
                context_keys.append(kv_block)
                context_values.append(kv_block)  # simplified: keys = values

            # Flatten context into a single tensor [T', n_head, d_head]
            K = torch.cat(context_keys, dim=0)[:t+1]  # causal
            V = torch.cat(context_values, dim=0)[:t+1]

            q_t = q[t - chunk_start]  # query vector at position t

            # Scaled dot-product attention, computed independently per head
            scores = torch.einsum("hd,thd->ht", q_t, K) / (d_head ** 0.5)  # [n_head, t+1]
            probs = torch.softmax(scores, dim=-1)                          # softmax over positions
            ctx = torch.einsum("ht,thd->hd", probs, V)                     # [n_head, d_head]

            attn_out.append(ctx)

        # Stack outputs: [chunk_len, n_head, d_head]
        attn_out = torch.stack(attn_out, dim=0)
        print(f"Chunk [{chunk_start}:{chunk_end}] attention output shape:", attn_out.shape)

Operational Flow

  • The block_table holds logical-to-physical mappings of KV blocks for each sequence in the batch. Each block can store block_size tokens.

  • The KV cache (kv_cache) is allocated as [B, num_blocks, block_size, n_head, d_head].

  • For each chunk (driven by chunk_size), queries are processed causally: each query at position t attends to positions 0 through t (all previous tokens plus itself), using the paged blocks.

  • Only the relevant blocks (from block 0 up to block ⌊t / block_size⌋) are paged in.

  • Attention is computed with the standard scaled dot-product mechanism using concatenated keys and values.

  • Paged Attention ensures that the key-value cache is organized into fixed-size blocks. Each token’s query accesses the minimum required set of blocks based on its position t. This reduces persistent memory footprint and supports dynamic allocation.

  • Chunked Prefill ensures that the input sequences are processed in chunks of size 6. For each chunk, attention scores are computed incrementally over all previously processed tokens, including those in earlier chunks. This limits peak memory usage during the prefill phase.

  • Dynamic Paging is set up such that for each token t in a chunk, only the blocks containing keys/values from positions 0 to t are loaded, ensuring efficient memory utilization.

Conclusion

Paged Attention and Chunked Prefill are orthogonal strategies for optimizing LLM inference in memory-constrained environments. Paged Attention reduces persistent memory consumption during attention score computation by dynamically paging key-value blocks. Chunked Prefill mitigates peak memory usage during the prefill phase by processing input sequences in chunks. Ongoing explorations around kernel-level optimizations, fused attention-score computation across blocks, and dynamic eviction policies for the block table aim to further improve inference throughput under multi-user, long-sequence workloads.

References


  1. vLLM Paged Attention — vLLM. https://docs.vllm.ai/en/latest/design/kernel/paged_attention.html#vllm-paged-attention. Accessed 27 Mar. 2025.

  2. Optimization and Tuning — vLLM. https://docs.vllm.ai/en/latest/performance/optimization.html#chunked-prefill. Accessed 27 Mar. 2025.


If you found this useful, please cite this post using

Senthilkumar Gopal. (Dec 2024). Paged Attention and Chunked Prefill for LLM Inference. sengopal.me. https://sengopal.me/posts/paged-attention-and-chunked-prefill-for-llm-inference

or

@article{gopal2024pagedattentionandchunkedprefillforllminference,
  title   = {Paged Attention and Chunked Prefill for LLM Inference},
  author  = {Senthilkumar Gopal},
  journal = {sengopal.me},
  year    = {2024},
  month   = {Dec},
  url     = {https://sengopal.me/posts/paged-attention-and-chunked-prefill-for-llm-inference}
}