评价此页

在 PyTorch 中使用变长注意力 (Variable Length Attention)#

作者: Angel Li

在本教程中,我们将介绍一个变长注意力 API。该 API 名为 varlen_attn,是 PyTorch 中的一个自定义算子,这意味着它也可以使用 torch.compile 进行编译。

注意: | 本教程目前需要您使用 PyTorch 的 nightly 构建版本。该算子目前仅适用于 A100 及更新型号机器上的 NVIDIA CUDA 环境。支持的数据类型包括 BF16 和 FP16。

您将学到什么#

  • 变长注意力及其与缩放点积注意力 (SDPA) 的区别

  • 探索在简单的 Transformer 注意力层中如何使用 varlen_attn 的示例

先决条件#

  • PyTorch v2.10.0.dev 或更高版本

  • NVIDIA A100 GPU 或更新型号

  • 对注意力和我们目前提供的功能有基础了解。请参考以下教程以获取关于 FlexAttentionSDPA 的更多详细信息。

变长注意力概述#

在普通的 SDPA 中,序列通常被假定为固定长度。在实践中,这意味着输入张量通常需要被填充 (padded) 到批次中的相同长度。然而,这会通过存储这些填充内容以及执行不必要的计算而浪费内存和算力。变长注意力通过将批次中的张量打包 (packing) 在一起并本质上合并批次维度,从而处理不同长度的序列。

然而,我们仍然需要维护文档之间的边界。为此,我们计算查询 (query) 和键/值 (key/value) 的累积序列位置,以标记文档的结束。在下图中,文档 1 长度为 7 个 token,文档 2 长度为 10 个 token,依此类推,因此 cu_seq_lens = [0, 7, 17, ...]

Diagram comparing two approaches for handling variable-length sequences in attention. Left side labeled 'PADDING (SDPA)' shows a 2D batch of 4 samples stacked vertically, each padded to match the longest sequence (length 10). Sample 1 has length 7 with 3 padding tokens, sample 2 has length 10 with no padding, sample 3 has length 8 with 2 padding tokens, and sample 4 has length 5 with 5 padding tokens. The vertical axis represents batch size and horizontal axis represents sequence length. Right side labeled 'PACKING (VARLEN)' shows the same 4 samples concatenated into a single 1D sequence with no padding. Arrows indicate boundaries at positions 7, 17, 25, and 30. Below shows cu_seq_lens: [0, 7, 17, 25, 30] representing cumulative sequence lengths, and max_seqlen: 10.

填充与打包示意图#

请注意,NestedTensor 是另一种通过打包张量来实现变长的方法(参见教程此处)。

定义#

以下是 varlen_attn 的定义,它返回注意力计算后的输出张量。

def varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    *,
    return_aux: AuxRequest | None = None,
    scale: float | None = None,
    window_size: tuple[int, int] = (-1, -1),
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

querykeyvalue 对应打包输入中的 qkvcu_seq_qcu_seq_k 分别是查询和键/值的累积索引。它们标记了输入中分隔文档的逻辑边界。max_qmax_k 分别是查询和键的最大序列长度。return_aux 指定返回哪些辅助输出(例如 lse)。scale 是一个可选的缩放因子,应用于 softmax 之前的注意力分数。window_size 是一个 (left, right) 元组,用于控制滑动窗口注意力:使用 (-1, -1) 进行全注意力(默认),(-1, 0) 进行因果注意力,或使用 (W, 0) 进行大小为 W 的滑动窗口因果注意力。

关于因果掩码的说明:window_size 设置为 (-1, 0) 时,将应用因果掩码,这意味着 token 只能关注之前的 token。对于双向(全)注意力,请使用默认的 (-1, -1)

在 torchtitan(PyTorch 的预训练框架)中,我们统一将 window_size = (-1, 0),以防止模型“作弊”并过快地人为降低损失。

示例#

让我们通过一个简单的示例来讲解如何在训练 Transformer 模型时使用 varlen_attn

从输入批次中创建 varlen_attn 所需的元数据#

给定一个输入批次,我们如何构建 varlen_attn 所需的元数据?更具体地说,我们如何计算累积序列索引?

辅助函数 create_varlen_metadata 根据 input_batch 和标记文档结束的序列结束 token ID,返回所需的 cu_seqlensmax_seqlen

import torch


def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):
    batch_size, seq_len = input_batch.shape
    device = input_batch.device
    cu_seqlens_list, all_seq_lengths = [], []
    offset = 0

    for b in range(batch_size):
        tokens = input_batch[b]
        eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)

        # we use the position of the eos tokens to mark the end of documents
        sample_cu_seqlens = torch.cat(
            [
                torch.tensor([0], dtype=torch.int32, device=device),
                eos_positions + 1,
                torch.tensor([seq_len], dtype=torch.int32, device=device),
            ]
        )
        sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)

        seq_lengths = torch.diff(sample_cu_seqlens)
        all_seq_lengths.append(seq_lengths)

        cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset
        cu_seqlens_list.append(cu_seqlens_adjusted)

        offset += seq_len

    packed_cu_seqlens = torch.cat(
        cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]
    )

    max_seqlen = 0
    if len(all_seq_lengths) > 0:
        all_seq_lengths = torch.cat(all_seq_lengths)
        max_seqlen = all_seq_lengths.max().item()

    return packed_cu_seqlens, max_seqlen

使用 varlen_attn 实现注意力模块#

让我们探索如何在注意力模块中使用 varlen_attn。我们像往常一样定义一个注意力模块,但在 forward 方法中,调用新的 varlen_attn 自定义算子。

此函数需要我们在之前使用 create_varlen_metadata 计算的 cu_seq 索引和 max_len,以标记不同文档的边界。

在调用 varlen_attn 之前,我们还需要打包输入,使其形状为 (total tokens, dim)。回想一下,变长注意力允许我们合并 batch_size 维度,以便能够连续排列输入样本。

import torch
import torch.nn as nn
from torch.nn.attention.varlen import varlen_attn


class SimpleVarlenAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(
        self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int
    ) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape
        x_packed = x.view(batch_size * seq_len, -1)  # pack x into (total_tokens, dim)

        qkv = self.qkv_proj(x_packed)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_heads, self.head_dim)
        v = v.view(-1, self.num_heads, self.head_dim)

        attn_out = varlen_attn(
            query=q,
            key=k,
            value=v,
            cu_seq_q=cu_seq,
            cu_seq_k=cu_seq,
            max_q=max_len,
            max_k=max_len,
            window_size=(-1, 0),
        )
        attn_out = attn_out.view(-1, self.embed_dim)
        attn_out = self.out_proj(attn_out)
        return attn_out.view(batch_size, seq_len, self.embed_dim)

我们还可以将 torch.compilevarlen_attn 一起使用,并定义

compiled_varlen_attn: ClassVar[Callable] = torch.compile(
    varlen_attn, mode="max-autotune-no-cudagraphs"
)

我们可以在 Attention 前向传播中调用 compiled_varlen_attn 代替 varlen_attn,其余部分保持不变。

创建一个 Transformer#

现在,我们可以在一个简单的 Transformer 中使用此 SimpleVarlenAttention 模块。

class SimpleVarlenTransformer(nn.Module):
    """
    simple 1 layer transformer with varlen attention
    """

    def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.attention = SimpleVarlenAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(
        self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int
    ) -> torch.Tensor:
        x = self.tok_embeddings(tokens)
        x = x + self.attention(x, cu_seq, max_len)
        x = self.norm(x)
        return x

运行训练步骤#

现在我们准备将所有部分整合在一起!让我们使用 SimpleVarlenTransformer 运行一个训练步骤。我们定义模型,使用 create_varlen_metadata 计算 cu_seqmax_len,并执行前向和反向传播。

def main():
    torch.manual_seed(42)

    batch_size = 3
    seq_len = 64
    vocab_size = 1000
    embed_dim = 128
    num_heads = 4
    eos_id = 2
    num_docs = 3
    device = "cuda"
    dtype = torch.bfloat16

    model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(
        device=device, dtype=dtype
    )

    # create input_batch tokens
    input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    for b in range(batch_size):
        # getting random positions to cut the input into multiple documents
        doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))
        for pos in doc_positions:
            input_batch[b, pos] = eos_id  # insert eos token to simulate end of sample
        input_batch[b, -1] = eos_id

    cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)
    print(
        f"cu_seq: {cu_seq}, max_len: {max_len}"
    )  # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40

    # fwd pass
    output = model(input_batch, cu_seq, max_len)
    print(f"output shape: {output.shape}")  # (3, 64, 128)

    # bwd pass
    loss = output.mean()
    loss.backward()


if __name__ == "__main__":
    main()
cu_seq: tensor([  0,  32,  47,  64,  92, 103, 128, 168, 177, 192], device='cuda:0',
       dtype=torch.int32), max_len: 40
output shape: torch.Size([3, 64, 128])

结论#

在本教程中,我们介绍了如何在 PyTorch 中使用 varlen_attn API 来高效处理不同长度的序列而无需填充。我们探索了如何创建必要的元数据(包括累积序列索引),实现了一个带有变长注意力的简单 Transformer 注意力层,并运行了一个完整的训练步骤。

这种方法消除了在填充 token 上的计算浪费,并为处理不同长度文档的模型提供了更高效的训练和推理能力。

脚本运行总时长: (0 分 0.464 秒)