What is the Transformer KV Cache?
The Key-Value Cache (or simply KV cache) is a basic optimization technique for decoder-only Transformers: it reduces compute at the expense of increased memory utilization. Let’s look at how it works.
Background
Recall that:
- The input to a transformer is a sequence of tokens (or a batch of these)
- Each layer in a decoder-only transformer has an attention layer and a feedforward layer
- Almost every layer/operation in a transformer is per-token; the only exception is the attention layer (and even then, only parts of it)
- The attention layer has multiple heads, and usually `d_model = d_head * n_heads`
During generation, one token is generated for each forward pass: conceptually, it is appended to the input, and then the input (now with +1 sequence length) is passed back into the model for another forward pass. However, doing it this naive way (sketched in code after the list below) is wasteful because:
- It recomputes key, value, and attention rows for previous tokens, which do not need to be recomputed.
- The other parts of the network, which are per-token, also get re-run on those previous tokens, wasting computation.
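For concreteness, here is a minimal sketch of that naive loop. The names `model` and `prompt_ids` are placeholders: any decoder-only model that maps a `(B, T)` batch of token ids to `(B, T, vocab_size)` logits fits the pattern.

```python
import torch

# Naive autoregressive generation: re-run the full (and growing) sequence every step.
def generate_naive(model, prompt_ids, n_new_tokens):
    ids = prompt_ids                                             # (B, T)
    for _ in range(n_new_tokens):
        logits = model(ids)                                      # recomputes every position, every step
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick for the last position, (B, 1)
        ids = torch.cat([ids, next_id], dim=1)                   # append and go around again with T+1 tokens
    return ids
```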
The basic operations of a single attention head are: (Regularization techniques like dropout have been omitted for brevity)
```python
B, T, C = x.shape   # batch size, sequence length, embedding size (d_model)
# q, k, v computation is all per-token
q = self.query(x)   # (B, T, d_head)
k = self.key(x)     # (B, T, d_head)
v = self.value(x)   # (B, T, d_head)
att = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T) pre-softmax logits
att = att.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal self-attention mask
att = F.softmax(att, dim=-1)
out = att @ v       # (B, T, d_head)
```
When a new token is appended to the input `x`:
- There is one row added to each of `q`, `k`, and `v`.
- `att` grows from size `(T, T)` to `(T+1, T+1)`, meaning one additional row and one additional column. The additional row represents the attention from the new token against all previous tokens.
- Since `att` is masked to be lower triangular, the additional column is all zero except for the entry in the last row.
- The additional row in `att` results in one additional row in the output, which is the new `att` row matrix-multiplied with all columns of `v`, i.e. `att[-1, :] @ v`. (A small check of this appears below.)
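To confirm that previous rows really are unaffected, here is a small self-contained sketch (random weights and `d_model = d_head` for simplicity, not the attention head module from above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_head = 3, 4
Wq, Wk, Wv = (torch.randn(d_head, d_head) for _ in range(3))

def attention(x):                                    # x: (T, d_head)
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    att = q @ k.T * d_head**-0.5
    mask = torch.tril(torch.ones(len(x), len(x)))
    att = att.masked_fill(mask == 0, float('-inf'))  # causal mask
    return F.softmax(att, dim=-1) @ v

x = torch.randn(T + 1, d_head)
out_T = attention(x[:T])                             # first T tokens only
out_T1 = attention(x)                                # all T+1 tokens
print(torch.allclose(out_T, out_T1[:T]))             # True: earlier output rows are unchanged
```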
Visualization
Here’s what the Query, Key, Value, Attention Matrix, and Output Values look like after passing an initial two tokens to our single attention head (`d_head = 4` here):
Now after appending one more token, they look like this:
Notice that the Query, Key, and Value matrices just have an additional row, and previous rows were unaffected. Also note how the attention matrix has one additional row and column, and consequently, the output also only has one extra row. All previous rows were unaffected due to the self-attention mask.
The same holds if we append yet another token:
KV Cache
Now we can begin to understand how the KV cache works: for each new token generated, we don’t need to pass in the entire sequence, and thus we can avoid recomputing the entire attention matrix. We only need to operate on the new token, in the following manner (a minimal code sketch follows the list):
- Compute new `q`, `k`, `v` rows for only the new token.
- The new `q` row is used immediately. (This is why there is no query cache.)
- Append the new key and value rows to the existing K and V caches.
- Compute the new `att` row via a matrix-vector multiplication between the new `q` row and `k_cache.transpose()`.
- Compute the new output row via a matrix-vector multiplication between the new `att` row and `v_cache`.
- The output (which is just for the latest token) is passed to the next layer.
- This can proceed through subsequent layers because we only care about the latest token.
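Here is a minimal single-head sketch of that per-token update, again with random weights and toy shapes rather than the module from earlier; the cache variables and function name are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_head = 4
Wq, Wk, Wv = (torch.randn(d_head, d_head) for _ in range(3))
k_cache = torch.empty(0, d_head)                 # grows by one row per token
v_cache = torch.empty(0, d_head)

def attend_new_token(x_new):                     # x_new: (1, d_head), only the newest token
    global k_cache, v_cache
    q = x_new @ Wq                               # (1, d_head): used once, never cached
    k_cache = torch.cat([k_cache, x_new @ Wk])   # (T, d_head)
    v_cache = torch.cat([v_cache, x_new @ Wv])   # (T, d_head)
    att = q @ k_cache.T * d_head**-0.5           # (1, T): the single new attention row
    att = F.softmax(att, dim=-1)                 # no causal mask needed: the cache holds only past tokens
    return att @ v_cache                         # (1, d_head): output for the newest token only

for _ in range(3):                               # feed tokens one at a time
    out = attend_new_token(torch.randn(1, d_head))
print(k_cache.shape, v_cache.shape)              # both torch.Size([3, 4])
```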
This trades increased memory usage for less repeated computation, which is worthwhile because without this optimization we’d be wasting cycles recomputing the key, value, and attention matrices on every step.
Size of the KV Cache
This can be calculated as follows:
- There are `n_layers` blocks in the transformer
- There is one multi-headed attention layer in each block
- Each multi-headed attention layer has `n_heads` heads, each with `d_head` size for `k, v`
- We need a cache for both K and V
- The maximum context length is `n_context`
- Precision is `n_bytes`, e.g. `4` for FP32
- Inference batch size is `batch_size`
Thus the total size would be:
kv_cache_size = n_layers * n_heads * 2 * n_context * d_head * n_bytes * batch_size
But since `d_model = n_heads * d_head` (usually), this reduces to:
kv_cache_size = n_layers * d_model * 2 * n_context * n_bytes * batch_size
Example OPT-30B (Taken from this X post)
- `n_bytes` = 2 (FP16)
- `n_layers` = 48
- `d_model` = 7168
- `n_context` = 1024
- `batch_size` = 128
This comes out to 180,388,626,432 bytes or ~180 GB.
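As a quick sanity check of that arithmetic:

```python
n_layers, d_model, n_context, n_bytes, batch_size = 48, 7168, 1024, 2, 128
kv_cache_size = n_layers * d_model * 2 * n_context * n_bytes * batch_size
print(kv_cache_size)        # 180388626432 bytes
print(kv_cache_size / 1e9)  # ~180.4 GB
```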
Relation to MHA, MQA and GQA
These are different types of attention that trade off model performance against the size of the key-value cache by reducing the number of KV heads.
- MHA: Multi-Head Attention: The “normal” way of doing things: there is one query, key, and value head per attention head.
- MQA: Multi-Query Attention: There is only a single key head and a single value head, shared across all `n_heads` query heads.
- GQA: Grouped-Query Attention: The middle ground: there are fewer than `n_heads` key and value heads, so each key/value head is shared by more than one query head.
By reducing the number of key, value heads we reduce the amount of memory consumed by the key-value cache.
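As a rough sketch of the savings, the size formula from above can be rewritten in terms of an assumed `n_kv_heads` parameter (the number of key/value heads). The GQA group count of 8 below is just an illustrative choice, and the other numbers reuse the OPT-30B example, assuming the 7168-wide model splits into 56 heads of size 128.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_context, n_bytes, batch_size):
    # One K row and one V row per layer, per KV head, per position, per sequence in the batch.
    return n_layers * n_kv_heads * d_head * 2 * n_context * n_bytes * batch_size

# Same shapes as the OPT-30B example above, varying only the number of KV heads:
mha = kv_cache_bytes(48, 56, 128, 1024, 2, 128)  # MHA: n_kv_heads == n_heads
gqa = kv_cache_bytes(48, 8, 128, 1024, 2, 128)   # GQA: e.g. 8 KV heads (illustrative)
mqa = kv_cache_bytes(48, 1, 128, 1024, 2, 128)   # MQA: a single KV head
print(mha / 1e9, gqa / 1e9, mqa / 1e9)           # ~180.4, ~25.8, ~3.2 GB
```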