Peter Chng

What is the Transformer KV Cache?

The Key-Value Cache (or simply KV Cache) is a basic optimization technique in decoder-only Transformers which reduces compute at the expense of increased memory utilization. Let’s look at how it works in the context of a decoder-only Transformer.

Background

Recall that:

  1. The input to a transformer is a sequence of tokens (or a batch of these)
  2. Each layer in a decoder-only transformer has an attention layer and a feedforward layer
  3. Almost every layer/operation in a transformer is per-token, with the exception of (parts of) the attention layer.
  4. The attention layer has multiple heads, and usually d_model = d_head * n_heads
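To make point 4 concrete, here's a minimal NumPy sketch of one common way a (T, d_model) activation is viewed as n_heads heads of size d_head. The sizes are arbitrary, and the head-major layout is an assumption; frameworks differ in exactly where the head axis goes.

```python
import numpy as np

T, n_heads, d_head = 3, 4, 16
d_model = n_heads * d_head  # 64

x = np.random.default_rng(0).standard_normal((T, d_model))

# Split the feature axis into heads: (T, d_model) -> (n_heads, T, d_head).
heads = x.reshape(T, n_heads, d_head).transpose(1, 0, 2)
print(heads.shape)  # (4, 3, 16)

# Concatenating the heads back along the feature axis recovers the original.
merged = heads.transpose(1, 0, 2).reshape(T, d_model)
print(np.array_equal(merged, x))  # True
```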

During generation, one token is conceptually generated per forward pass: it is appended to the input, and the input (now one token longer) is passed back into the model for another forward pass. However, doing it this naive way is wasteful because:

  • It recomputes the key, value, and attention rows for all previous tokens, even though they don't change from one step to the next.
  • The other per-token parts of the network (such as the feedforward layers) are also rerun on every previous token, wasting computation.
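To put rough numbers on that waste, here's a back-of-the-envelope count of how many token positions go through the per-token parts of the network when generating new_tokens tokens from a prompt of prompt_len tokens. The function names are hypothetical, and the count deliberately ignores the attention score computation, whose cost still grows with sequence length even with caching.

```python
def naive_token_positions(prompt_len: int, new_tokens: int) -> int:
    """Token positions processed when the full sequence is re-fed every step."""
    # Pass i (0-based) sees the prompt plus the i tokens generated so far.
    return sum(prompt_len + i for i in range(new_tokens))


def cached_token_positions(prompt_len: int, new_tokens: int) -> int:
    """Token positions processed when earlier keys/values are cached."""
    if new_tokens == 0:
        return 0
    # The first pass processes the whole prompt; every later pass
    # only processes the single most recently generated token.
    return prompt_len + (new_tokens - 1)


print(naive_token_positions(10, 100))   # 5950
print(cached_token_positions(10, 100))  # 109
```

The naive count grows quadratically with the number of generated tokens, while the cached count grows linearly.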

The basic operations of a single attention head are as follows (regularization techniques like dropout have been omitted for brevity):

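import torch.nn.functional as F

# Body of a single attention head's forward(self, x). Assumes self.query,
# self.key and self.value are nn.Linear(C, d_head) layers and self.tril is a
# lower-triangular buffer of ones at least T x T in size.
B, T, C = x.shape
# q, k, v computation is all per-token
q = self.query(x)  # (B, T, d_head)
k = self.key(x)    # (B, T, d_head)
v = self.value(x)  # (B, T, d_head)

att = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T) pre-softmax logits
att = att.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal self-attention mask
att = F.softmax(att, dim=-1)  # each row now sums to 1

out = att @ v  # (B, T, d_head)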

When a new token is appended to the input x:

  1. There is one row added to each of q, k, v
  2. att will go from size (T, T) to (T+1, T+1), meaning there is an additional row and column added.
  3. The additional row represents the attention from the new token to all previous tokens (and to itself).
  4. Since att is masked to be lower triangular, the additional column is all zero except for the entry in the last row.
  5. The additional row in att results in one additional row in the output out, computed by multiplying that row of att with v: out[-1, :] = att[-1, :] @ v.
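These claims are easy to check numerically. The following NumPy sketch is a re-implementation of the single-head code above with random weights (not the original PyTorch module); it runs causal attention on two tokens, then on three, and compares the overlapping rows:

```python
import numpy as np


def causal_attention(x, Wq, Wk, Wv):
    """Single-head causal self-attention for one sequence x of shape (T, C)."""
    T = x.shape[0]
    q, k, v = x @ Wq, x @ Wk, x @ Wv               # per-token projections
    att = q @ k.T * k.shape[-1] ** -0.5            # (T, T) pre-softmax logits
    mask = np.tril(np.ones((T, T), dtype=bool))
    att = np.where(mask, att, -np.inf)             # causal mask
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)    # row-wise softmax
    return att, att @ v


rng = np.random.default_rng(0)
C, d_head = 8, 4
Wq, Wk, Wv = (rng.standard_normal((C, d_head)) for _ in range(3))
x = rng.standard_normal((3, C))                      # three token embeddings

att_T, out_T = causal_attention(x[:2], Wq, Wk, Wv)   # first two tokens
att_T1, out_T1 = causal_attention(x, Wq, Wk, Wv)     # after appending a third

# Previous rows of att and out are unchanged...
print(np.allclose(att_T1[:2, :2], att_T), np.allclose(out_T1[:2], out_T))
# ...and the new column is zero everywhere except the last row.
print(att_T1[:2, 2])
```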

Visualization

Here’s what the Query, Key, Value, Attention Matrix, and Output Values look like after passing an initial two tokens to our single attention head (d_head = 4 here):

[Figure: Query, Key, Value, Attention Matrix, and Output Values after two tokens]

Now after appending one more token, they look like this:

[Figure: the same matrices after appending a third token]

Notice that the Query, Key, and Value matrices just have an additional row, and previous rows were unaffected. Also note that the attention matrix has just one additional row, and consequently the output also has only one extra row. All previous rows were unaffected due to the self-attention mask.

The same holds if we append yet another token:

[Figure: the same matrices after appending a fourth token]

KV Cache

Now we can begin to understand how the KV cache will work: For each new token generated, we don’t need to pass in the entire sequence and thus can avoid recomputing the entire attention matrix. We only need to operate on the new token in the following manner:

  1. Compute new q, k, v rows for only the new token.
  2. The new q row is used immediately and is not needed for any later token. (This is why there is no query cache.)
  3. Append the new key and value rows to the existing K and V caches.
  4. Compute the new att row by multiplying the new q row with k_cache.transpose(), then applying the scaling and softmax.
  5. Compute the new output row by multiplying the new att row with v_cache, i.e. att_row @ v_cache.
  6. The output (which is just for the latest token) is passed to the next layer.
  7. Subsequent layers can proceed in the same way, because every other operation is per-token and we only care about the latest token.
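Here's a minimal NumPy sketch of those steps for a single head. CachedHead and its step method are hypothetical names for illustration; the check at the end compares the cached, one-token-at-a-time outputs against full causal attention recomputed over the whole sequence:

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    z = np.exp(z - z.max())
    return z / z.sum()


class CachedHead:
    """Hypothetical single attention head that keeps its own K/V cache."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        d_head = Wk.shape[1]
        self.k_cache = np.empty((0, d_head))  # one row per token seen so far
        self.v_cache = np.empty((0, d_head))

    def step(self, x_new):
        """Process one new token embedding x_new of shape (C,); return its output row."""
        # Step 1: q, k, v for the new token only.
        q, k, v = x_new @ self.Wq, x_new @ self.Wk, x_new @ self.Wv
        # Step 3: append k and v to the caches. (Step 2: q is used below, then dropped.)
        self.k_cache = np.vstack([self.k_cache, k])
        self.v_cache = np.vstack([self.v_cache, v])
        # Step 4: new attention row = softmax of scaled q . (every cached key).
        # No mask is needed: the cache only holds this token and earlier ones.
        att_row = softmax(self.k_cache @ q * q.shape[-1] ** -0.5)
        # Step 5: new output row = attention-weighted sum of cached values.
        return att_row @ self.v_cache


# Check against full causal attention recomputed over the whole sequence.
rng = np.random.default_rng(0)
C, d_head, T = 8, 4, 5
Wq, Wk, Wv = (rng.standard_normal((C, d_head)) for _ in range(3))
x = rng.standard_normal((T, C))

head = CachedHead(Wq, Wk, Wv)
cached_out = np.stack([head.step(x[t]) for t in range(T)])

q, k, v = x @ Wq, x @ Wk, x @ Wv
logits = np.where(np.tril(np.ones((T, T), dtype=bool)), q @ k.T * d_head**-0.5, -np.inf)
att = np.exp(logits - logits.max(axis=-1, keepdims=True))
full_out = (att / att.sum(axis=-1, keepdims=True)) @ v
print(np.allclose(cached_out, full_out))
```

Growing the cache with np.vstack copies it on every step; a real implementation would more likely preallocate a buffer of n_context rows and write each new key and value into it.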

This trades increased memory usage for less repeated computation, and the trade is worthwhile because without this optimization we'd be wasting cycles recomputing the key, value, and attention matrices for every previous token at every step.

Size of the KV Cache

This can be calculated as follows:

  1. There are n_layers blocks in the transformer
  2. There is one multi-headed attention layer in each block
  3. Each multi-headed attention layer has n_heads heads, each with key and value vectors of size d_head
  4. We need a cache for both K and V
  5. The maximum context length is n_context
  6. Each element takes n_bytes bytes, e.g. 4 for FP32
  7. The inference batch size is batch_size

Thus the total size would be:

kv_cache_size = n_layers * n_heads * 2 * n_context * d_head * n_bytes * batch_size

But since d_model = n_heads * d_head (usually), this reduces to:

kv_cache_size = n_layers * d_model * 2 * n_context * n_bytes * batch_size
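Both forms of the formula are easy to encode directly. A minimal sketch follows; the small configuration in the usage lines (12 layers, 12 heads of size 64, 1024-token context, FP16, batch size 1) is chosen purely for illustration:

```python
def kv_cache_size(n_layers, n_heads, d_head, n_context, n_bytes, batch_size):
    """Total KV cache size in bytes (per-head form of the formula)."""
    # Factor of 2: one cache for K and one for V.
    return n_layers * n_heads * 2 * n_context * d_head * n_bytes * batch_size


def kv_cache_size_from_d_model(n_layers, d_model, n_context, n_bytes, batch_size):
    """Same quantity, assuming d_model = n_heads * d_head."""
    return n_layers * d_model * 2 * n_context * n_bytes * batch_size


# Illustrative config: 12 layers, 12 heads of size 64 (d_model = 768).
a = kv_cache_size(12, 12, 64, 1024, 2, 1)
b = kv_cache_size_from_d_model(12, 768, 1024, 2, 1)
print(a == b, a / 2**20)  # True 36.0
```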

Example: OPT-30B (taken from this X post)

  • n_bytes = 2 (FP16)
  • n_layers = 48
  • d_model = 7168
  • n_context = 1024
  • batch_size = 128

This comes out to 180,388,626,432 bytes or ~180 GB.
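Splitting the product into a per-token cost and a token count makes the arithmetic easier to check (same numbers as above, just regrouped):

```latex
\begin{aligned}
\underbrace{48 \cdot 7168 \cdot 2 \cdot 2}_{n_{\text{layers}} \cdot d_{\text{model}} \cdot 2 \cdot n_{\text{bytes}}} &= 1{,}376{,}256 \text{ bytes per cached token} \\
\underbrace{1024 \cdot 128}_{n_{\text{context}} \cdot \text{batch\_size}} &= 131{,}072 \text{ cached tokens} \\
1{,}376{,}256 \cdot 131{,}072 &= 180{,}388{,}626{,}432 \text{ bytes} = 168 \text{ GiB} \approx 180 \text{ GB}
\end{aligned}
```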

Relation to MHA, MQA and GQA

(Reference)

These are different types of attention that trade off model quality against the size of the key-value cache by reducing the number of KV heads.

  1. MHA: Multi-headed Attention: The “normal” way of doing things: There is one query, key and value head per attention head.
  2. MQA: Multi-Query Attention: There is only a single key and single value head which are shared across all n_heads query heads.
  3. GQA: Grouped-Query Attention: The middle ground: There are fewer than n_heads (but more than one) key and value heads, so each key/value head is shared by a group of query heads.
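The only mechanical difference between the three variants is which key/value head each query head reads from. Below is a minimal NumPy sketch of causal attention with shared KV heads, assuming the common convention that query head h uses KV head h // (n_heads // n_kv_heads); the function name is hypothetical. Setting n_kv_heads = n_heads gives MHA, and n_kv_heads = 1 gives MQA.

```python
import numpy as np


def grouped_query_attention(q, k, v):
    """Causal attention where KV heads are shared by groups of query heads.

    q: (n_heads, T, d_head); k, v: (n_kv_heads, T, d_head).
    """
    n_heads, T, d_head = q.shape
    n_kv_heads = k.shape[0]
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group = n_heads // n_kv_heads
    # Repeat each KV head `group` times so query head h reads KV head h // group.
    k = np.repeat(k, group, axis=0)  # (n_heads, T, d_head)
    v = np.repeat(v, group, axis=0)
    att = q @ k.transpose(0, 2, 1) * d_head ** -0.5  # (n_heads, T, T)
    att = np.where(np.tril(np.ones((T, T), dtype=bool)), att, -np.inf)
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)
    return att @ v  # (n_heads, T, d_head)


rng = np.random.default_rng(0)
n_heads, n_kv_heads, T, d_head = 8, 2, 5, 4
q = rng.standard_normal((n_heads, T, d_head))
k = rng.standard_normal((n_kv_heads, T, d_head))
v = rng.standard_normal((n_kv_heads, T, d_head))
out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 5, 4)
```

Only k and v, with n_kv_heads heads each, would need to go into the cache. np.repeat is used here for clarity; it materializes copies of the shared heads, which an optimized implementation would likely avoid.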

By reducing the number of key and value heads, we reduce the memory consumed by the KV cache proportionally: in the size formula, n_heads is replaced by the number of KV heads.
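To see the effect in numbers, the sketch below applies the KV-head form of the formula to the OPT-30B-sized dimensions from the example above. It assumes, for illustration only, that d_model = 7168 is split into 56 heads of d_head = 128; the GQA and MQA rows are hypothetical variants with the same dimensions, not configurations OPT-30B actually uses.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, n_context, n_bytes, batch_size):
    """KV cache size in bytes; only the key/value heads are cached."""
    # The number of query heads doesn't appear: queries are never cached.
    return n_layers * n_kv_heads * 2 * n_context * d_head * n_bytes * batch_size


# Assumed split: d_model = 7168 = 56 heads * d_head 128 (illustrative).
for name, n_kv_heads in [("MHA", 56), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(48, n_kv_heads, 128, 1024, 2, 128)
    print(f"{name} ({n_kv_heads} KV heads): {size / 2**30:.0f} GiB")
```

Under these assumptions it prints 168, 24, and 3 GiB respectively; the MHA figure is the same 180,388,626,432 bytes computed above.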