Below is a clean Transformer KV-cache from scratch in Python/PyTorch, including:

  1. a simple dynamic cache,
  2. a faster preallocated cache,
  3. cached causal self-attention,
  4. a tiny usage example for autoregressive decoding.

First, the imports and the simple dynamic cache:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicKVCache:
    """
    Simple KV cache that grows by concatenating new keys/values.

    Shapes:
        key/value: [batch, n_heads, seq_len, head_dim]
    """

    def __init__(self):
        self.keys = None
        self.values = None

    def append(self, key: torch.Tensor, value: torch.Tensor):
        """
        Add new key/value tensors to the cache.

        key/value shape:
            [batch, n_heads, new_seq_len, head_dim]
        """
        if self.keys is None:
            self.keys = key
            self.values = value
        else:
            self.keys = torch.cat([self.keys, key], dim=2)
            self.values = torch.cat([self.values, value], dim=2)

        return self.keys, self.values

    def get(self):
        return self.keys, self.values

    def reset(self):
        self.keys = None
        self.values = None

    @property
    def seq_len(self):
        if self.keys is None:
            return 0
        return self.keys.shape[2]

This version is easy to understand but inefficient for long decoding: each call to torch.cat allocates a new tensor and copies the entire existing cache, so the cost of appending grows with every generated token.
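
To make the reallocation visible, here is a tiny sketch (the shapes are arbitrary). The storage pointer of the cached keys usually changes after every append, because torch.cat builds a brand-new tensor each step and copies the old contents into it:

cache = DynamicKVCache()

for t in range(4):
    # One new token per step: [batch=1, n_heads=2, new_seq_len=1, head_dim=8].
    k = torch.randn(1, 2, 1, 8)
    v = torch.randn(1, 2, 1, 8)
    keys, values = cache.append(k, v)

    # The pointer typically differs from step to step: torch.cat allocated a
    # fresh tensor and copied everything cached so far into it.
    print(f"step={t}, cached_seq_len={cache.seq_len}, keys_ptr={keys.data_ptr()}")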

A better implementation preallocates the full buffer up front; a rough timing comparison of the two appears right after the StaticKVCache class below.

class StaticKVCache:
    """
    Preallocated KV cache for efficient autoregressive decoding.

    Stores tensors of shape:
        keys:   [batch, n_heads, max_seq_len, head_dim]
        values: [batch, n_heads, max_seq_len, head_dim]
    """

    def __init__(
        self,
        batch_size: int,
        n_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype=torch.float16,
        device="cuda",
    ):
        self.batch_size = batch_size
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device

        self.keys = torch.empty(
            batch_size,
            n_heads,
            max_seq_len,
            head_dim,
            dtype=dtype,
            device=device,
        )

        self.values = torch.empty(
            batch_size,
            n_heads,
            max_seq_len,
            head_dim,
            dtype=dtype,
            device=device,
        )

        self.cur_pos = 0

    def append(self, key: torch.Tensor, value: torch.Tensor):
        """
        Append new key/value states.

        key/value shape:
            [batch, n_heads, new_seq_len, head_dim]
        """
        bsz, n_heads, new_seq_len, head_dim = key.shape

        assert bsz == self.batch_size
        assert n_heads == self.n_heads
        assert head_dim == self.head_dim
        assert self.cur_pos + new_seq_len <= self.max_seq_len

        start = self.cur_pos
        end = self.cur_pos + new_seq_len

        self.keys[:, :, start:end, :] = key
        self.values[:, :, start:end, :] = value

        self.cur_pos = end

        return self.get()

    def get(self):
        """
        Return the filled portion of the cache.
        """
        return (
            self.keys[:, :, : self.cur_pos, :],
            self.values[:, :, : self.cur_pos, :],
        )

    def reset(self):
        self.cur_pos = 0

    @property
    def seq_len(self):
        return self.cur_pos
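
As a rough sanity check of the difference, here is a hedged micro-benchmark of single-token appends into each cache on CPU. Absolute numbers depend entirely on hardware and sizes, but the dynamic version's per-step copy grows with the cache length while the static version's write stays constant:

import time


def time_single_token_appends(cache, steps, bsz=1, n_heads=8, head_dim=64):
    # Append one token at a time, as in autoregressive decoding.
    start = time.perf_counter()
    for _ in range(steps):
        k = torch.randn(bsz, n_heads, 1, head_dim)
        v = torch.randn(bsz, n_heads, 1, head_dim)
        cache.append(k, v)
    return time.perf_counter() - start


steps = 2048

dynamic_cache = DynamicKVCache()
static_cache = StaticKVCache(
    batch_size=1,
    n_heads=8,
    max_seq_len=steps,
    head_dim=64,
    dtype=torch.float32,
    device="cpu",
)

print("dynamic:", time_single_token_appends(dynamic_cache, steps))
print("static: ", time_single_token_appends(static_cache, steps))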

Now here is a self-attention module that uses the KV cache.

class CachedSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, n_heads: int):
        super().__init__()

        assert embed_dim % n_heads == 0

        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def _split_heads(self, x: torch.Tensor):
        """
        [batch, seq_len, embed_dim]
        -> [batch, n_heads, seq_len, head_dim]
        """
        bsz, seq_len, _ = x.shape

        x = x.view(bsz, seq_len, self.n_heads, self.head_dim)
        x = x.transpose(1, 2)

        return x

    def _merge_heads(self, x: torch.Tensor):
        """
        [batch, n_heads, seq_len, head_dim]
        -> [batch, seq_len, embed_dim]
        """
        bsz, n_heads, seq_len, head_dim = x.shape

        x = x.transpose(1, 2)
        x = x.contiguous().view(bsz, seq_len, n_heads * head_dim)

        return x

    def forward(
        self,
        x: torch.Tensor,
        cache=None,
        use_cache: bool = False,
    ):
        """
        x shape:
            [batch, seq_len, embed_dim]

        During training/prefill:
            seq_len may be large.

        During decoding:
            seq_len is usually 1.
        """
        bsz, seq_len, _ = x.shape

        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        if use_cache:
            assert cache is not None
            k, v = cache.append(k, v)

        total_seq_len = k.shape[2]

        attn_scores = torch.matmul(q, k.transpose(-2, -1))
        attn_scores = attn_scores / (self.head_dim ** 0.5)

        if not use_cache:
            # Standard causal mask for full-sequence training/prefill.
            causal_mask = torch.triu(
                torch.ones(seq_len, total_seq_len, device=x.device, dtype=torch.bool),
                diagonal=1,
            )

            attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))

        else:
            # During token-by-token decoding, q is usually length 1.
            # The cache already only contains previous + current tokens,
            # so no causal mask is needed in the common single-token case.
            if seq_len > 1:
                # If appending multiple tokens at once, apply an offset causal mask.
                past_len = total_seq_len - seq_len

                q_positions = torch.arange(
                    past_len, past_len + seq_len, device=x.device
                )[:, None]

                k_positions = torch.arange(total_seq_len, device=x.device)[None, :]

                causal_mask = k_positions > q_positions

                attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))

        attn_probs = F.softmax(attn_scores, dim=-1)

        out = torch.matmul(attn_probs, v)
        out = self._merge_heads(out)
        out = self.out_proj(out)

        return out

Example usage:

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 2
embed_dim = 64
n_heads = 4
head_dim = embed_dim // n_heads
max_seq_len = 128

attn = CachedSelfAttention(embed_dim, n_heads).to(device)

cache = StaticKVCache(
    batch_size=batch_size,
    n_heads=n_heads,
    max_seq_len=max_seq_len,
    head_dim=head_dim,
    dtype=torch.float32,
    device=device,
)

# Pretend we are decoding one token at a time.
for t in range(10):
    x_t = torch.randn(batch_size, 1, embed_dim, device=device)

    y_t = attn(
        x_t,
        cache=cache,
        use_cache=True,
    )

    print(
        f"step={t}, input={x_t.shape}, output={y_t.shape}, cached_seq_len={cache.seq_len}"
    )

Output shape each step:

step=0, input=torch.Size([2, 1, 64]), output=torch.Size([2, 1, 64]), cached_seq_len=1
step=1, input=torch.Size([2, 1, 64]), output=torch.Size([2, 1, 64]), cached_seq_len=2
...
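
The module also supports appending several tokens at once (the offset-mask branch above), which is exactly what a prompt prefill does: run the whole prompt through in one call, then decode one token at a time. A minimal sketch reusing attn and cache from the example (the prompt length and inputs are arbitrary random tensors):

# Prefill: process the whole prompt in one forward pass, caching its K/V.
cache.reset()

prompt_len = 8
prompt = torch.randn(batch_size, prompt_len, embed_dim, device=device)
y_prompt = attn(prompt, cache=cache, use_cache=True)
print("after prefill:", y_prompt.shape, cache.seq_len)  # [2, 8, 64], 8

# Decode: each new token attends to the cached prompt plus all earlier tokens.
for t in range(3):
    x_t = torch.randn(batch_size, 1, embed_dim, device=device)
    y_t = attn(x_t, cache=cache, use_cache=True)
    print("decode step", t, y_t.shape, cache.seq_len)  # [2, 1, 64], then 9, 10, 11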

The key idea is this:

# Without cache, every generation step recomputes K and V for all previous tokens.
K = k_proj(all_tokens)
V = v_proj(all_tokens)

# With cache, only compute K and V for the new token.
new_K = k_proj(new_token)
new_V = v_proj(new_token)

# Then append them.
cached_K, cached_V = cache.append(new_K, new_V)

In autoregressive decoding, the query is usually only for the newest token, but it attends to all cached keys and values:

attn_scores = Q_new @ K_cached.transpose(-2, -1)
attn_probs = softmax(attn_scores)
output = attn_probs @ V_cached

That is the payoff of the cache: instead of repeatedly recomputing keys and values for the whole prefix at every step, generation only computes the new token's projections and attends over the stored prefix.