Attention 的原地 KV 缓存

Attention 模型中的 KV 缓存是指在自回归生成过程中存储先前计算的 Key 和 Value 张量的机制。在仅解码器 Transformer 中,每个新 token 都必须使用 attention 机制关注所有先前的 token。通常,这需要在每个时间步为每个先前的 token 重新计算 Key 和 Value 投影,效率很低。相反,KV 缓存会在首次计算 Key 和 Value 投影后将其存储起来,从而允许模型在未来 token 中重复使用它们而无需重新计算。这显著加快了生成过程。

原地更新 KV 缓存意味着将新的 Key 和 Value 张量直接写入预分配的内存中,索引对应于序列中的当前位置。这有几个优点:它避免了重复的内存分配或复制,减少了计算开销;它还通过启用融合内核和减少内存带宽使用,在硬件加速器上实现了更好的性能。原地更新对于在推理过程中实现高吞吐量和低延迟至关重要,特别是对于部署在实时应用中的大型语言模型。

ONNX opset-24 引入了新功能,以方便表示原地 KV 缓存更新。此图显示了一个示例用例

InPlace KV Cache

  • Attention 运算符的 KV 输入包含整个 KV 缓存张量,其序列长度维度为 max_sequence_length,因此这些输入的大小在自回归迭代之间不会增长。因此,可选的 nonpad_kv_seqlen 输入可用于指示每个样本中有效(非填充)token 的数量,以跳过不必要的计算。

  • KV 缓存更新的逻辑与 Attention 运算符分离。TensorScatter 运算符可用于更新缓存张量,其中当前迭代传入的 Key 和 Value token 将根据 write_indices 分散到缓存张量中。

  • 作为一种优化,后端可以自由地将过去和当前的 Key/Value 张量别名,以避免重复缓存张量并实现原地更新。为了使此优化有效,后端需要确保 TensorScatter 的输入不会被其他运算符后续重用。只有这样,才能安全地将分配给运算符 past_k/v 输入的内存重用于 present_k/v 输出。

  • 相同的计算图可用于自回归模型的预填充和解码阶段。

需要提醒的是,ONNX 表示仍然是函数式表示,其运算符是纯函数。上面描述的图布局是表达原地 KV 缓存更新的一种有用的常见模式,输入/输出别名完全取决于后端实现。