注意
跳转至页面底部 下载完整示例代码。
在 PyTorch 中使用变长注意力 (Variable Length Attention)#
作者: Angel Li
在本教程中,我们将介绍一个变长注意力 API。该 API 名为 varlen_attn,是 PyTorch 中的一个自定义算子,这意味着它也可以使用 torch.compile 进行编译。
您将学到什么#
变长注意力及其与缩放点积注意力 (SDPA) 的区别
探索在简单的 Transformer 注意力层中如何使用
varlen_attn的示例
先决条件#
PyTorch v2.10.0.dev 或更高版本
NVIDIA A100 GPU 或更新型号
对注意力和我们目前提供的功能有基础了解。请参考以下教程以获取关于 FlexAttention 和 SDPA 的更多详细信息。
变长注意力概述#
在普通的 SDPA 中,序列通常被假定为固定长度。在实践中,这意味着输入张量通常需要被填充 (padded) 到批次中的相同长度。然而,这会通过存储这些填充内容以及执行不必要的计算而浪费内存和算力。变长注意力通过将批次中的张量打包 (packing) 在一起并本质上合并批次维度,从而处理不同长度的序列。
然而,我们仍然需要维护文档之间的边界。为此,我们计算查询 (query) 和键/值 (key/value) 的累积序列位置,以标记文档的结束。在下图中,文档 1 长度为 7 个 token,文档 2 长度为 10 个 token,依此类推,因此 cu_seq_lens = [0, 7, 17, ...]。
填充与打包示意图#
请注意,NestedTensor 是另一种通过打包张量来实现变长的方法(参见教程此处)。
定义#
以下是 varlen_attn 的定义,它返回注意力计算后的输出张量。
def varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
*,
return_aux: AuxRequest | None = None,
scale: float | None = None,
window_size: tuple[int, int] = (-1, -1),
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
query、key 和 value 对应打包输入中的 q、k 和 v。cu_seq_q 和 cu_seq_k 分别是查询和键/值的累积索引。它们标记了输入中分隔文档的逻辑边界。max_q 和 max_k 分别是查询和键的最大序列长度。return_aux 指定返回哪些辅助输出(例如 lse)。scale 是一个可选的缩放因子,应用于 softmax 之前的注意力分数。window_size 是一个 (left, right) 元组,用于控制滑动窗口注意力:使用 (-1, -1) 进行全注意力(默认),(-1, 0) 进行因果注意力,或使用 (W, 0) 进行大小为 W 的滑动窗口因果注意力。
关于因果掩码的说明:当 window_size 设置为 (-1, 0) 时,将应用因果掩码,这意味着 token 只能关注之前的 token。对于双向(全)注意力,请使用默认的 (-1, -1)。
在 torchtitan(PyTorch 的预训练框架)中,我们统一将 window_size = (-1, 0),以防止模型“作弊”并过快地人为降低损失。
示例#
让我们通过一个简单的示例来讲解如何在训练 Transformer 模型时使用 varlen_attn。
从输入批次中创建 varlen_attn 所需的元数据#
给定一个输入批次,我们如何构建 varlen_attn 所需的元数据?更具体地说,我们如何计算累积序列索引?
辅助函数 create_varlen_metadata 根据 input_batch 和标记文档结束的序列结束 token ID,返回所需的 cu_seqlens 和 max_seqlen。
import torch
def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):
batch_size, seq_len = input_batch.shape
device = input_batch.device
cu_seqlens_list, all_seq_lengths = [], []
offset = 0
for b in range(batch_size):
tokens = input_batch[b]
eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)
# we use the position of the eos tokens to mark the end of documents
sample_cu_seqlens = torch.cat(
[
torch.tensor([0], dtype=torch.int32, device=device),
eos_positions + 1,
torch.tensor([seq_len], dtype=torch.int32, device=device),
]
)
sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)
seq_lengths = torch.diff(sample_cu_seqlens)
all_seq_lengths.append(seq_lengths)
cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset
cu_seqlens_list.append(cu_seqlens_adjusted)
offset += seq_len
packed_cu_seqlens = torch.cat(
cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]
)
max_seqlen = 0
if len(all_seq_lengths) > 0:
all_seq_lengths = torch.cat(all_seq_lengths)
max_seqlen = all_seq_lengths.max().item()
return packed_cu_seqlens, max_seqlen
使用 varlen_attn 实现注意力模块#
让我们探索如何在注意力模块中使用 varlen_attn。我们像往常一样定义一个注意力模块,但在 forward 方法中,调用新的 varlen_attn 自定义算子。
此函数需要我们在之前使用 create_varlen_metadata 计算的 cu_seq 索引和 max_len,以标记不同文档的边界。
在调用 varlen_attn 之前,我们还需要打包输入,使其形状为 (total tokens, dim)。回想一下,变长注意力允许我们合并 batch_size 维度,以便能够连续排列输入样本。
import torch
import torch.nn as nn
from torch.nn.attention.varlen import varlen_attn
class SimpleVarlenAttention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(
self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int
) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
x_packed = x.view(batch_size * seq_len, -1) # pack x into (total_tokens, dim)
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
attn_out = varlen_attn(
query=q,
key=k,
value=v,
cu_seq_q=cu_seq,
cu_seq_k=cu_seq,
max_q=max_len,
max_k=max_len,
window_size=(-1, 0),
)
attn_out = attn_out.view(-1, self.embed_dim)
attn_out = self.out_proj(attn_out)
return attn_out.view(batch_size, seq_len, self.embed_dim)
我们还可以将 torch.compile 与 varlen_attn 一起使用,并定义
compiled_varlen_attn: ClassVar[Callable] = torch.compile(
varlen_attn, mode="max-autotune-no-cudagraphs"
)
我们可以在 Attention 前向传播中调用 compiled_varlen_attn 代替 varlen_attn,其余部分保持不变。
创建一个 Transformer#
现在,我们可以在一个简单的 Transformer 中使用此 SimpleVarlenAttention 模块。
class SimpleVarlenTransformer(nn.Module):
"""
simple 1 layer transformer with varlen attention
"""
def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):
super().__init__()
self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
self.attention = SimpleVarlenAttention(embed_dim, num_heads)
self.norm = nn.LayerNorm(embed_dim)
def forward(
self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int
) -> torch.Tensor:
x = self.tok_embeddings(tokens)
x = x + self.attention(x, cu_seq, max_len)
x = self.norm(x)
return x
运行训练步骤#
现在我们准备将所有部分整合在一起!让我们使用 SimpleVarlenTransformer 运行一个训练步骤。我们定义模型,使用 create_varlen_metadata 计算 cu_seq 和 max_len,并执行前向和反向传播。
def main():
torch.manual_seed(42)
batch_size = 3
seq_len = 64
vocab_size = 1000
embed_dim = 128
num_heads = 4
eos_id = 2
num_docs = 3
device = "cuda"
dtype = torch.bfloat16
model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(
device=device, dtype=dtype
)
# create input_batch tokens
input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
for b in range(batch_size):
# getting random positions to cut the input into multiple documents
doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))
for pos in doc_positions:
input_batch[b, pos] = eos_id # insert eos token to simulate end of sample
input_batch[b, -1] = eos_id
cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)
print(
f"cu_seq: {cu_seq}, max_len: {max_len}"
) # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40
# fwd pass
output = model(input_batch, cu_seq, max_len)
print(f"output shape: {output.shape}") # (3, 64, 128)
# bwd pass
loss = output.mean()
loss.backward()
if __name__ == "__main__":
main()
cu_seq: tensor([ 0, 32, 47, 64, 92, 103, 128, 168, 177, 192], device='cuda:0',
dtype=torch.int32), max_len: 40
output shape: torch.Size([3, 64, 128])
结论#
在本教程中,我们介绍了如何在 PyTorch 中使用 varlen_attn API 来高效处理不同长度的序列而无需填充。我们探索了如何创建必要的元数据(包括累积序列索引),实现了一个带有变长注意力的简单 Transformer 注意力层,并运行了一个完整的训练步骤。
这种方法消除了在填充 token 上的计算浪费,并为处理不同长度文档的模型提供了更高效的训练和推理能力。
脚本运行总时长: (0 分 0.464 秒)