评价此页

DebugMode:记录已分派的操作和数值调试#

作者:Pian Pawakapan, Shangdi Yu

您将学到什么
  • 如何捕获 eager 模式和 torch.compile 运行时的已分派算子

  • 如何使用 DebugMode 中的张量哈希和堆栈跟踪来精确定位数值偏差

先决条件
  • PyTorch 2.10 或更高版本

概述#

DebugMode (torch.utils._debug_mode.DebugMode) 是一种 TorchDispatchMode,它会拦截 PyTorch 运行时调用并生成操作的分层日志。当您需要了解在 eager 模式、torch.compile 模式下或需要精确定位两次运行之间的数值偏差时,它特别有用。

主要功能

  • 运行时日志记录 – 记录已分派的操作和 TorchInductor 编译的 Triton 内核。

  • 张量哈希 – 为输入/输出附加确定性哈希值,以便通过对比运行差异来定位数值偏差。

  • 分派钩子 (Dispatch hooks) – 允许注册自定义钩子以对调用进行注释

注意

本教程描述的是一项原型功能。原型功能通常处于反馈和测试的早期阶段,可能会发生变更。

快速入门#

下面的代码片段捕获了一个小型 eager 工作负载并打印调试字符串

from torch._inductor.decomposition import decomps_to_exclude
import torch
from torch.utils._debug_mode import DebugMode

def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)

with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())
DebugMode output:
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]
    aten::relu(t: f32[8, 8])  ->  t: f32[8, 8]
    aten::mm(t: f32[8, 8], t: f32[8, 8])  ->  t: f32[8, 8]

获取更多元数据#

对于大多数调查,您需要启用堆栈跟踪、张量 ID 和张量哈希。这些功能提供的元数据可将操作关联回模型代码。

DebugMode.log_tensor_hashes 会为每次调用在日志中添加哈希值。hash_tensor 哈希函数使用 torch.hash_tensor,它对所有元素相同的张量返回 0。norm 哈希函数使用 p=1norm。通过这两个函数(特别是 norm),数值上的张量接近度与哈希值的接近度相关,因此具有较好的可解释性。默认的 hash_fnnorm

with (
    DebugMode(
        # record_stack_trace is only supported for eager in pytorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"], # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(
    debug_mode.debug_string(show_stack_trace=True)
)
DebugMode output with more metadata:
    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:59 in run_once, code: x = torch.randn(8, 8)
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t$0: f32[8, 8]  # {'hash': (50.11742821428925,)}

    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:60 in run_once, code: y = torch.randn(8, 8)
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t$1: f32[8, 8]  # {'hash': (71.43443126231432,)}

    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:61 in run_once, code: return torch.mm(torch.relu(x), y)
    aten::relu(t$0: f32[8, 8])  ->  t$2: f32[8, 8]  # {'input_hash': (((50.11742821428925,),), {}), 'hash': (26.50899614393711,)}
    aten::mm(t$2: f32[8, 8], t$1: f32[8, 8])  ->  t$3: f32[8, 8]  # {'input_hash': (((26.50899614393711,), (71.43443126231432,)), {}), 'hash': (136.20053039491177,)}

每一行遵循 op(args) -> outputs 格式。当启用 record_ids 时,张量后缀为 $<id>,DTensor 标记为 dt

记录 Triton 内核#

尽管 Triton 内核不是通过分派调用的,但 DebugMode 具有自定义逻辑来记录它们的输入和输出。

Inductor 生成的 Triton 内核会显示为 [triton] 前缀。预处理/后处理哈希注释会报告每次内核调用前后的缓冲区哈希值,这在隔离错误内核时非常有用。

def f(x):
    return torch.mm(torch.relu(x), x.T)

x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_inputs=True,
    )
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:320: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Triton in DebugMode logs:
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((7.583102658390999,), {'dtype': None}), 'hash': 7.583102658390999}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((7.583102658390999, None), {}), 'hash': 7.583102658390999}
    aten::_local_scalar_dense(t: f64[])  ->  7.583102658390999  # {'input_hash': ((7.583102658390999,), {}), 'hash': None}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((5.458498023450375,), {'dtype': None}), 'hash': 5.458498023450375}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((5.458498023450375, None), {}), 'hash': 5.458498023450375}
    aten::_local_scalar_dense(t: f64[])  ->  5.458498023450375  # {'input_hash': ((5.458498023450375,), {}), 'hash': None}
    [triton] triton_poi_fused_relu_0(in_ptr0=t: f32[3, 3], out_ptr0=t: f32[3, 3], xnumel=9)
    # pre-kernel hashes: {in_ptr0: 7.583102658390999, out_ptr0: 5.458498023450375}
    # post-kernel hashes: {in_ptr0: 7.583102658390999, out_ptr0: 4.270898416638374}

    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((7.583102658390999,), {'dtype': None}), 'hash': 7.583102658390999}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((7.583102658390999, None), {}), 'hash': 7.583102658390999}
    aten::_local_scalar_dense(t: f64[])  ->  7.583102658390999  # {'input_hash': ((7.583102658390999,), {}), 'hash': None}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((4.270898416638374,), {'dtype': None}), 'hash': 4.270898416638374}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((4.270898416638374, None), {}), 'hash': 4.270898416638374}
    aten::_local_scalar_dense(t: f64[])  ->  4.270898416638374  # {'input_hash': ((4.270898416638374,), {}), 'hash': None}
    aten::mm.out(t: f32[3, 3], t: f32[3, 3], out=t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((4.270898416638374, 7.583102658390999), {'out': 9.534462571144104}), 'hash': 10.67264581733616}

使用张量哈希进行数值调试#

如果您在不同模式之间发现数值偏差,可以使用 DebugMode 找到数值偏差的根源。在下面的示例中,您可以看到 eager 模式和编译模式下的所有张量哈希值都是相同的。如果有任何哈希值不同,那便是数值偏差产生的地方。

def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)

inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())
Eager mode:
    aten::relu(t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((12.58324533700943,), {}), 'hash': 6.811520010232925}
    aten::permute(t: f32[4, 4], [1, 0])  ->  t: f32[4, 4]  # {'input_hash': ((12.58324533700943, [None, None]), {}), 'hash': 12.58324533700943}
    aten::mm(t: f32[4, 4], t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((6.811520010232925, 12.58324533700943), {}), 'hash': 18.071126878261566}
Compiled aot_eager mode:
    aten::relu(t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((12.58324533700943,), {}), 'hash': 6.811520010232925}
    aten::permute(t: f32[4, 4], [1, 0])  ->  t: f32[4, 4]  # {'input_hash': ((12.58324533700943, [None, None]), {}), 'hash': 12.58324533700943}
    aten::mm(t: f32[4, 4], t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((6.811520010232925, 12.58324533700943), {}), 'hash': 18.071126878261566}

现在让我们看一个张量哈希值不同的示例。我故意写了一个错误的分解,将余弦分解为正弦。这将导致数值偏差。

from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func

def wrong_decomp(x):
    return torch.sin(x)

decomp_table = {}
decomp_table[torch.ops.aten.cos.default] = wrong_decomp

backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table
)

def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z

x = torch.randn(3, 3)
with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string(show_stack_trace=True))
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())
Eager:
    aten::relu(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.443918265402317,), {}), 'hash': 2.8230451717972755}
    aten::cos(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.443918265402317,), {}), 'hash': 6.760322332382202}
    aten::add.Tensor(t: f32[3, 3], t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((2.8230451717972755, 6.760322332382202), {}), 'hash': 9.583367466926575}

Compiled with wrong decomposition:
    aten::relu(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.443918265402317,), {}), 'hash': 2.8230451717972755}
    aten::sin(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.443918265402317,), {}), 'hash': 4.617258846759796}
    aten::add.Tensor(t: f32[3, 3], t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((2.8230451717972755, 4.617258846759796), {}), 'hash': 7.440304018557072}

在 eager 日志中,我们有 aten::cos,但在编译日志中,我们有 aten::sin。此外,输出哈希在 eager 模式和编译模式下也不同。对比两个日志会显示,第一个数值偏差出现在 aten::cos 调用中。

自定义分派钩子#

钩子允许您使用自定义元数据(如 GPU 内存使用情况)对每个调用进行注释。log_hook 返回一个与调试字符串内联呈现的映射。

MB = 1024 * 1024.0

def memory_hook(func, types, args, kwargs, result):
    mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0
    peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0
    torch.cuda.reset_peak_memory_stats() if torch.cuda.is_available() else None
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}

with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())
DebugMode output with memory usage:
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '14.128 MB'}
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
    aten::relu(t: f32[8, 8])  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
    aten::mm(t: f32[8, 8], t: f32[8, 8])  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}

模块边界#

record_nn_module=True 会插入 [nn.Mod] 标记,显示哪个模块执行了每组操作。截至 PyTorch 2.10,它仅在 eager 模式下有效,但对编译模式的支持正在开发中。

class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(4, 4)
            self.l2 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.l2(self.l1(x))

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))

mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with stack traces and module boundaries:")
print(debug_mode.debug_string(show_stack_trace=True))
DebugMode output with stack traces and module boundaries:
    [nn.Mod] Bar
      [nn.Mod] Bar.abc
        [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
        [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])

结论#

在本教程中,我们了解了 DebugMode 如何为您提供轻量级、仅限运行时的 PyTorch 实际执行情况视图,无论您运行的是 eager 代码还是编译后的图。通过组合张量哈希、Triton 日志记录和自定义分派钩子,您可以快速追踪数值差异。这对于调试两次运行之间的位级等效性特别有用。

脚本总运行时间: (0 分 2.695 秒)