
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)

Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024

Author: Driss Guessous

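Summary

In this tutorial, we introduce a new torch.nn.functional function that is helpful for implementing Transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.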

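Overview

At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper "Attention Is All You Need". While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.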
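To make the definition concrete, here is a minimal sketch of the naive version written with existing PyTorch functions and checked against the built-in function. The helper name `naive_sdpa`, the tensor shapes, and the tolerance are illustrative assumptions, not part of the tutorial.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Hypothetical reference: softmax(Q @ K^T / sqrt(d)) @ V, no mask, no dropout."""
    scale = 1.0 / math.sqrt(query.size(-1))
    # Attention scores between every query position and every key position
    scores = query @ key.transpose(-2, -1) * scale
    # Normalize over the key dimension
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 8) for _ in range(3))

naive_out = naive_sdpa(q, k, v)
builtin_out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(naive_out, builtin_out, atol=1e-5))
```

The naive version materializes the full attention weight matrix, which is the intermediate that fused kernels are designed to avoid.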

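Fused implementations

For CUDA tensor inputs, the function will dispatch into one of the following implementations:

- FlashAttention
- Memory-efficient attention
- A PyTorch implementation defined in C++ (the math backend)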

Note

This tutorial requires PyTorch 2.0.0 or later.

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[ 3.8814e-01,  7.3777e-01,  5.3059e-01,  4.4033e-01,  4.2172e-03,
          -7.2611e-01, -2.1513e-03,  1.9442e-02],
         [-2.9045e-02,  1.3793e-01,  3.8818e-01,  2.3693e-01, -5.3844e-01,
          -4.3346e-01,  1.8052e-01,  2.3583e-02],
         [ 6.2640e-02,  4.4714e-01,  5.7041e-01, -1.2344e-01, -9.4141e-01,
          -9.1985e-01, -1.5386e-02,  3.7205e-01]],

        [[ 8.6307e-01, -3.0465e-01, -2.0338e-01,  6.0832e-01,  9.1712e-01,
          -1.0877e+00,  4.9930e-01, -1.7761e-01],
         [ 8.4066e-01, -3.4873e-02,  8.2936e-01, -1.0034e-01,  1.3865e-03,
          -1.4349e+00, -8.5389e-02, -5.4365e-02],
         [ 8.0448e-01, -3.1879e-01, -4.3490e-01,  6.9458e-01,  1.1321e+00,
          -9.6985e-01,  6.0291e-01, -1.3440e-01]]], device='cuda:0')

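Explicit Dispatcher Control

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through the implementations and measure performance.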

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2273.691 microseconds
The math implementation runs in 87493.669 microseconds
The flash attention implementation runs in 2274.499 microseconds
The memory efficient implementation runs in 4362.119 microseconds

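Hardware dependence

Depending on the machine you ran the code above on and what hardware is available, your results might differ.

- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, FlashAttention or memory-efficient attention might fail.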
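When results look surprising, it can help to record which device and compute capability the benchmark ran on. The sketch below only reports what the device offers; the tutorial does not state which compute capabilities each backend requires, so no such threshold is assumed here. The helper name `describe_sdpa_hardware` is hypothetical.

```python
import torch


def describe_sdpa_hardware():
    """Hypothetical helper: describe the device the benchmarks would run on."""
    if not torch.cuda.is_available():
        return "cpu"
    # Compute capability of the current CUDA device, for example (8, 0)
    major, minor = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    return f"{name} (compute capability {major}.{minor})"


hardware = describe_sdpa_hardware()
print(hardware)
```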

Causal Self Attention

Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

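NestedTensor and Dense tensor support

SDPA supports both NestedTensor and dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors Tutorial.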
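As a rough illustration of what padding costs, the sketch below builds a NestedTensor from three variable-length sequences and compares the number of stored elements with the padded dense equivalent. The sequence lengths and embedding size are hypothetical values chosen for the example.

```python
import torch

torch.manual_seed(0)
embed_dim = 4
seq_lens = [2, 5, 3]  # hypothetical variable sequence lengths

# One (seq_len, embed_dim) tensor per sequence; no padding is stored
sequences = [torch.randn(n, embed_dim) for n in seq_lens]
nt = torch.nested.nested_tensor(sequences)

# The dense equivalent pads every sequence to the longest one in the batch
padded = torch.nested.to_padded_tensor(nt, 0.0)

stored_elements = sum(t.numel() for t in nt.unbind())
padded_elements = padded.numel()
print(padded.shape)
print(stored_elements, padded_elements)
```

Here the padded tensor holds 60 elements while the sequences themselves hold only 40, so a third of the dense batch is padding.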

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:256: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
Random NT runs in 632.878 microseconds
Random Dense runs in 950.426 microseconds

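Using SDPA with torch.compile

With the release of PyTorch 2.0, a new feature called torch.compile() was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.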

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in  424.494 microseconds
The compiled module runs in  545.335 microseconds

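The exact execution time depends on the machine; on mine, the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.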

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
/usr/local/lib/python3.10/dist-packages/torch/profiler/profiler.py:272: UserWarning: Warning: Profiler clears events at the end of each cycle.Only events from the current cycle will be reported.To keep events across cycles, set acc_events=True.
  _warn_once(
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention        17.55%       2.225ms        77.60%       9.838ms       9.838ms       0.000us         0.00%      10.831ms      10.831ms             1
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.702ms       100.90%      10.702ms      10.702ms             1
                                           aten::linear         1.05%     132.643us        35.76%       4.533ms      90.667us       0.000us         0.00%       8.020ms     160.403us            50
                                           aten::matmul         2.07%     262.726us        32.39%       4.107ms      82.132us       0.000us         0.00%       8.020ms     160.403us            50
                                               aten::mm        10.28%       1.304ms        28.14%       3.567ms      71.344us       7.797ms        73.51%       8.020ms     160.403us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.581ms        52.61%       5.581ms     223.227us            25
                     aten::scaled_dot_product_attention         1.77%     224.244us        15.93%       2.020ms      80.809us       0.000us         0.00%       2.810ms     112.414us            25
              aten::_scaled_dot_product_flash_attention         2.45%     310.646us        14.17%       1.796ms      71.840us       0.000us         0.00%       2.810ms     112.414us            25
                         aten::_flash_attention_forward         2.84%     360.127us        10.43%       1.323ms      52.900us       2.810ms        26.49%       2.810ms     112.414us            25
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.810ms        26.49%       2.810ms     112.414us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 12.678ms
Self CUDA time total: 10.607ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.742ms       101.21%      10.742ms      10.742ms             1
                              Compiled Causal Attention         8.56%       1.108ms        84.72%      10.964ms      10.964ms       0.000us         0.00%      10.613ms      10.613ms             1
                             Torch-Compiled Region: 0/0        10.21%       1.321ms        73.49%       9.511ms     380.423us       0.000us         0.00%      10.613ms     424.540us            25
                                       CompiledFunction         8.06%       1.043ms        60.42%       7.819ms     312.756us       0.000us         0.00%      10.613ms     424.540us            25
## Call CompiledFxGraph fdsccik44p4muv2w4335ry2gjcei...        13.85%       1.792ms        52.36%       6.776ms     271.038us       0.000us         0.00%      10.613ms     424.540us            25
                                               aten::mm         8.03%       1.039ms        12.12%       1.569ms      31.379us       7.790ms        73.40%       7.790ms     155.806us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.574ms        52.51%       5.574ms     222.946us            25
              aten::_scaled_dot_product_flash_attention         1.95%     252.356us        12.96%       1.677ms      67.065us       0.000us         0.00%       2.823ms     112.927us            25
                         aten::_flash_attention_forward         2.73%     352.869us         9.49%       1.228ms      49.108us       2.823ms        26.60%       2.823ms     112.927us            25
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.823ms        26.60%       2.823ms     112.927us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 12.941ms
Self CUDA time total: 10.613ms

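The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden, and there is little left for compilation to remove.

In reality, your module does not normally consist of a singular CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT training on the Shakespeare dataset.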

Using SDPA with attn_bias subclasses

# As of PyTorch 2.3, a new submodule, ``torch.nn.attention.bias``, provides tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# It contains the following two utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# Note: the current ``is_causal`` argument of ``torch.nn.functional.scaled_dot_product_attention``
# is equivalent to using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both are the same type, ``torch.nn.attention.bias.CausalBias``,
# which subclasses ``torch.Tensor``.

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper left bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# With upper left bias, the 0th token in the query is aligned to the 0th token in the key.
# With lower right bias, the last token in the query is aligned to the last token in the key,
# so the mask entry at ``[-1][-1]`` is True even if the sequence lengths of q and k differ.

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])
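The two alignments printed above can be reproduced with plain boolean masks, which may make the difference easier to see. This is a minimal sketch assuming a 2 x 10 score matrix as in the example; the variable names and the small head dimension are illustrative, and the masks are ordinary tensors rather than `CausalBias` objects.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len_q, seq_len_kv, head_dim = 2, 10, 4

# True means "this query position may attend to this key position".
# Upper left: the diagonal starts at the top-left corner of the score matrix.
upper_left_mask = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool).tril(diagonal=0)
# Lower right: the diagonal is shifted so that it ends at the bottom-right corner.
lower_right_mask = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool).tril(
    diagonal=seq_len_kv - seq_len_q
)

query = torch.randn(1, 1, seq_len_q, head_dim)
key = torch.randn(1, 1, seq_len_kv, head_dim)
value = torch.randn(1, 1, seq_len_kv, head_dim)

out_upper_left = F.scaled_dot_product_attention(query, key, value, attn_mask=upper_left_mask)
out_lower_right = F.scaled_dot_product_attention(query, key, value, attn_mask=lower_right_mask)

print(upper_left_mask.int())
print(lower_right_mask.int())
```

With the upper left mask, the first query token can attend only to the first key token, so its output is exactly the first value vector. With the lower right mask, the same query token sees all keys except the last one.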

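Conclusion

In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process, we showed how the profiling tools can be used to explore the performance characteristics of a user-defined module.

Total running time of the script: (0 minutes 6.186 seconds)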