KV Cache
Below is a clean Transformer KV cache from scratch in Python/PyTorch, including:
- a simple dynamic cache,
- a faster preallocated cache,
- cached causal self-attention,
- a tiny usage example for autoregressive decoding.
import torch
import torch.nn as nn
import torch.nn.functional as F
class DynamicKVCache:
"""
Simple KV cache that grows by concatenating new keys/values.
Shapes:
key/value: [batch, n_heads, seq_len, head_dim]
"""
def __init__(self):
self.keys = None
self.values = None
def append(self, key: torch.Tensor, value: torch.Tensor):
"""
Add new key/value tensors to the cache.
key/value shape:
[batch, n_heads, new_seq_len, head_dim]
"""
if self.keys is None:
self.keys = key
self.values = value
else:
self.keys = torch.cat([self.keys, key], dim=2)
self.values = torch.cat([self.values, value], dim=2)
return self.keys, self.values
def get(self):
return self.keys, self.values
def reset(self):
self.keys = None
self.values = None
@property
def seq_len(self):
if self.keys is None:
return 0
return self.keys.shape[2]
This version is easy to understand but inefficient for long decoding: torch.cat allocates a new tensor and copies the entire cache on every appended token.
A better implementation preallocates memory.
class StaticKVCache:
"""
Preallocated KV cache for efficient autoregressive decoding.
Stores tensors of shape:
keys: [batch, n_heads, max_seq_len, head_dim]
values: [batch, n_heads, max_seq_len, head_dim]
"""
def __init__(
self,
batch_size: int,
n_heads: int,
max_seq_len: int,
head_dim: int,
dtype=torch.float16,
device="cuda",
):
self.batch_size = batch_size
self.n_heads = n_heads
self.max_seq_len = max_seq_len
self.head_dim = head_dim
self.dtype = dtype
self.device = device
self.keys = torch.empty(
batch_size,
n_heads,
max_seq_len,
head_dim,
dtype=dtype,
device=device,
)
self.values = torch.empty(
batch_size,
n_heads,
max_seq_len,
head_dim,
dtype=dtype,
device=device,
)
self.cur_pos = 0
def append(self, key: torch.Tensor, value: torch.Tensor):
"""
Append new key/value states.
key/value shape:
[batch, n_heads, new_seq_len, head_dim]
"""
bsz, n_heads, new_seq_len, head_dim = key.shape
assert bsz == self.batch_size
assert n_heads == self.n_heads
assert head_dim == self.head_dim
assert self.cur_pos + new_seq_len <= self.max_seq_len
start = self.cur_pos
end = self.cur_pos + new_seq_len
self.keys[:, :, start:end, :] = key
self.values[:, :, start:end, :] = value
self.cur_pos = end
return self.get()
def get(self):
"""
Return the filled portion of the cache.
"""
return (
self.keys[:, :, : self.cur_pos, :],
self.values[:, :, : self.cur_pos, :],
)
def reset(self):
self.cur_pos = 0
@property
def seq_len(self):
return self.cur_pos
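A quick sketch of how it behaves (sizes are arbitrary): a multi-token prefill and a single-token append both just advance cur_pos inside the preallocated buffers.

cache = StaticKVCache(
    batch_size=2, n_heads=4, max_seq_len=16, head_dim=8,
    dtype=torch.float32, device="cpu",
)
# Prefill a 5-token chunk, then append one more token.
cache.append(torch.randn(2, 4, 5, 8), torch.randn(2, 4, 5, 8))
keys, values = cache.append(torch.randn(2, 4, 1, 8), torch.randn(2, 4, 1, 8))
print(keys.shape, cache.seq_len)  # torch.Size([2, 4, 6, 8]) 6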
Now here is a self-attention module that uses the KV cache.
class CachedSelfAttention(nn.Module):
def __init__(self, embed_dim: int, n_heads: int):
super().__init__()
assert embed_dim % n_heads == 0
self.embed_dim = embed_dim
self.n_heads = n_heads
self.head_dim = embed_dim // n_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def _split_heads(self, x: torch.Tensor):
"""
[batch, seq_len, embed_dim]
-> [batch, n_heads, seq_len, head_dim]
"""
bsz, seq_len, _ = x.shape
x = x.view(bsz, seq_len, self.n_heads, self.head_dim)
x = x.transpose(1, 2)
return x
def _merge_heads(self, x: torch.Tensor):
"""
[batch, n_heads, seq_len, head_dim]
-> [batch, seq_len, embed_dim]
"""
bsz, n_heads, seq_len, head_dim = x.shape
x = x.transpose(1, 2)
x = x.contiguous().view(bsz, seq_len, n_heads * head_dim)
return x
def forward(
self,
x: torch.Tensor,
cache=None,
use_cache: bool = False,
):
"""
x shape:
[batch, seq_len, embed_dim]
During training/prefill:
seq_len may be large.
During decoding:
seq_len is usually 1.
"""
bsz, seq_len, _ = x.shape
q = self._split_heads(self.q_proj(x))
k = self._split_heads(self.k_proj(x))
v = self._split_heads(self.v_proj(x))
if use_cache:
assert cache is not None
k, v = cache.append(k, v)
total_seq_len = k.shape[2]
attn_scores = torch.matmul(q, k.transpose(-2, -1))
attn_scores = attn_scores / (self.head_dim ** 0.5)
if not use_cache:
# Standard causal mask for full-sequence training/prefill.
causal_mask = torch.triu(
torch.ones(seq_len, total_seq_len, device=x.device, dtype=torch.bool),
diagonal=1,
)
attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
else:
# During token-by-token decoding, q is usually length 1.
# The cache already only contains previous + current tokens,
# so no causal mask is needed in the common single-token case.
if seq_len > 1:
# If appending multiple tokens at once, apply an offset causal mask.
past_len = total_seq_len - seq_len
q_positions = torch.arange(
past_len, past_len + seq_len, device=x.device
)[:, None]
k_positions = torch.arange(total_seq_len, device=x.device)[None, :]
causal_mask = k_positions > q_positions
attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
attn_probs = F.softmax(attn_scores, dim=-1)
out = torch.matmul(attn_probs, v)
out = self._merge_heads(out)
out = self.out_proj(out)
return out
Example usage:
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 2
embed_dim = 64
n_heads = 4
head_dim = embed_dim // n_heads
max_seq_len = 128
attn = CachedSelfAttention(embed_dim, n_heads).to(device)
cache = StaticKVCache(
batch_size=batch_size,
n_heads=n_heads,
max_seq_len=max_seq_len,
head_dim=head_dim,
dtype=torch.float32,
device=device,
)
# Pretend we are decoding one token at a time.
for t in range(10):
x_t = torch.randn(batch_size, 1, embed_dim, device=device)
y_t = attn(
x_t,
cache=cache,
use_cache=True,
)
print(
f"step={t}, input={x_t.shape}, output={y_t.shape}, cached_seq_len={cache.seq_len}"
)
Output at each step:
step=0, input=torch.Size([2, 1, 64]), output=torch.Size([2, 1, 64]), cached_seq_len=1
step=1, input=torch.Size([2, 1, 64]), output=torch.Size([2, 1, 64]), cached_seq_len=2
...
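Prefill works the same way: passing a whole prompt with use_cache=True fills the cache in one call (this exercises the offset causal mask branch), and decoding then continues from the prompt. A small sketch reusing the attn and cache objects defined above:

cache.reset()
prompt = torch.randn(batch_size, 6, embed_dim, device=device)
y_prompt = attn(prompt, cache=cache, use_cache=True)
print(y_prompt.shape, cache.seq_len)  # torch.Size([2, 6, 64]) 6

next_tok = torch.randn(batch_size, 1, embed_dim, device=device)
y_next = attn(next_tok, cache=cache, use_cache=True)
print(y_next.shape, cache.seq_len)    # torch.Size([2, 1, 64]) 7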
The key idea is this:
# Without cache, every generation step recomputes K and V for all previous tokens.
K = k_proj(all_tokens)
V = v_proj(all_tokens)
# With cache, only compute K and V for the new token.
new_K = k_proj(new_token)
new_V = v_proj(new_token)
# Then append them.
cached_K, cached_V = cache.append(new_K, new_V)
In autoregressive decoding, the query is usually only for the newest token, but it attends to all cached keys and values:
attn_scores = Q_new @ K_cached.transpose(-2, -1)
attn_probs = softmax(attn_scores)
output = attn_probs @ V_cached
That turns generation from repeatedly recomputing attention states for the whole prefix into computing only the new token's projections and attending over the stored prefix.
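One way to convince yourself the cache is consistent is to compare incremental cached decoding against a single full causal forward pass over the same sequence; the outputs should match up to floating-point tolerance. A minimal sketch (attn_check and check_cache are new names introduced just for this test):

torch.manual_seed(0)
attn_check = CachedSelfAttention(embed_dim, n_heads).to(device)
seq = torch.randn(batch_size, 5, embed_dim, device=device)

# Full forward pass with the standard causal mask.
full_out = attn_check(seq, use_cache=False)

# Token-by-token decoding with a fresh cache.
check_cache = StaticKVCache(
    batch_size, n_heads, max_seq_len, head_dim,
    dtype=torch.float32, device=device,
)
step_outs = [
    attn_check(seq[:, t : t + 1, :], cache=check_cache, use_cache=True)
    for t in range(seq.shape[1])
]
inc_out = torch.cat(step_outs, dim=1)

print(torch.allclose(full_out, inc_out, atol=1e-5))  # True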