Large numeric divergence for torch compile vs eager in bf16 #168126
Description
🐛 Describe the bug
Context
I'm observing large numeric divergence (greatest relative difference of about 117924) between torch.compile and eager execution when running the code below with bfloat16. For context, this layer comes from the Boltz repo and is one of the building blocks in an AlphaFold-like model.
I'd like help understanding where the numeric divergence comes from and, ideally, finding a way to apply torch.compile selectively so that we get speedups without results differing much from eager.
Repro
import torch
from torch import Tensor, nn


class TriangleMultiplicationOutgoing(nn.Module):
    def __init__(self, dim: int = 128) -> None:
        super().__init__()
        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)
        self.norm_out = nn.LayerNorm(dim)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        x = self.norm_in(x)
        x_in = x
        x = self.p_in(x) * self.g_in(x).sigmoid()
        x = x * mask.unsqueeze(-1)
        a, b = torch.chunk(x.float(), 2, dim=-1)
        x = torch.einsum("bikd,bjkd->bijd", a, b)
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
        x = x + x_in
        return x


if __name__ == "__main__":
    for dtype in [torch.float32, torch.bfloat16]:
        print(f"Testing with dtype: {dtype}")
        with torch.autocast(device_type="cuda", dtype=dtype):
            torch.manual_seed(42)
            x = torch.randn(16, 128, 128, 128, device="cuda")
            mask = torch.randint(0, 2, (16, 128, 128), device="cuda")
            eager_layer = TriangleMultiplicationOutgoing().cuda()
            compiled_layer = torch.compile(TriangleMultiplicationOutgoing().cuda(), fullgraph=True)
            # Copy weights from reference to optimized to ensure identical parameters
            with torch.no_grad():
                for param, ref_param in zip(compiled_layer.parameters(), eager_layer.parameters()):
                    param.data.copy_(ref_param.data)
            out_eager = eager_layer(x, mask)
            out_compiled = compiled_layer(x, mask)
            torch.testing.assert_close(out_eager, out_compiled)
            print(f"Passed with dtype: {dtype}")
Findings
- I think it's related to TorchInductor because the repro passes when I set torch.compile(..., backend='aot_eager').
- tlparse log: dedicated_log_torch_trace_kl7ekd1z.log
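Since the repro passes with `aot_eager`, one way to narrow this down is to check which of the two outputs is closer to a higher-precision reference rather than only comparing them to each other. The helper below is a minimal sketch; `max_errors` is a hypothetical name, and the float64 reference run described in the comments is an assumed workflow, not something taken from the log.

```python
import torch


def max_errors(out: torch.Tensor, ref: torch.Tensor) -> tuple[float, float]:
    """Return (max absolute error, max relative error) of `out` against `ref`."""
    out64, ref64 = out.double(), ref.double()
    abs_err = (out64 - ref64).abs()
    # Clamp the denominator so exact zeros in the reference do not produce inf/nan.
    rel_err = abs_err / ref64.abs().clamp_min(1e-12)
    return abs_err.max().item(), rel_err.max().item()


# Hypothetical usage with the repro's tensors:
#   ref_layer = a float64 copy of eager_layer, run outside autocast on x.double()
#   print("eager   :", max_errors(out_eager, ref))
#   print("compiled:", max_errors(out_compiled, ref))
```

If the compiled output turns out to be closer to the reference, a plausible explanation is that Inductor keeps fused intermediates in float32 where eager rounds to bfloat16 after each op. To my knowledge `torch._inductor.config.emulate_precision_casts = True` asks Inductor to mimic eager's casts, but treat that as an assumption to verify against the installed version.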
Error logs
This is the output I get when I execute my repro:
(repro) jamin@jamin-dev:~/deep-affinity$ python repro.py
Testing with dtype: torch.float32
/home/jamin/miniconda3/envs/repro/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:312: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
Passed with dtype: torch.float32
Testing with dtype: torch.bfloat16
Traceback (most recent call last):
File "/home/jamin/deep-affinity/repro.py", line 47, in <module>
torch.testing.assert_close(out_eager, out_compiled)
File "/home/jamin/miniconda3/envs/repro/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1589, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 33347983 / 33554432 (99.4%)
Greatest absolute difference: 0.018128156661987305 at index (7, 15, 43, 81) (up to 1e-05 allowed)
Greatest relative difference: 117924.3515625 at index (15, 23, 89, 84) (up to 1.3e-06 allowed)
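The limits printed above (1e-05 absolute, 1.3e-06 relative) match the default tolerances `torch.testing.assert_close` uses for float32. That suggests the outputs were float32 tensors and were held to float32 tolerances even in the bfloat16 autocast run. The check is documented as `|actual - expected| <= atol + rtol * |expected|`, and the reported relative difference divides by `|expected|`, so near-zero expected values inflate it. The sketch below reproduces both in plain Python with made-up numbers (none are taken from the log).

```python
def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """Elementwise closeness test in the form documented for assert_close."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def relative_difference(actual: float, expected: float) -> float:
    """Relative difference as reported in the error message: |a - e| / |e|."""
    return abs(actual - expected) / abs(expected)


# A bfloat16-sized rounding error fails the float32 defaults
# but passes the bfloat16 defaults (rtol=1.6e-2, atol=1e-5).
print(is_close(1.005, 1.0, rtol=1.3e-6, atol=1e-5))
print(is_close(1.005, 1.0, rtol=1.6e-2, atol=1e-5))

# A modest absolute error next to a near-zero expected value
# produces a very large relative difference.
print(relative_difference(1e-2, 1e-7))
```

Passing tolerances explicitly, for example `torch.testing.assert_close(out_eager, out_compiled, rtol=1.6e-2, atol=1e-5)`, compares the outputs at bfloat16 precision; whether the repro then passes is not something the log above can confirm.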
Versions
PyTorch version: 2.9.1+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-1043-gcp-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 13.0.88
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 580.95.05
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 26
On-line CPU(s) list: 0-25
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 13
Socket(s): 1
Stepping: 8
BogoMIPS: 5399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 624 KiB (13 instances)
L1i cache: 416 KiB (13 instances)
L2 cache: 26 MiB (13 instances)
L3 cache: 105 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-25
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas==13.0.0.19
[pip3] nvidia-cuda-cupti==13.0.48
[pip3] nvidia-cuda-nvrtc==13.0.48
[pip3] nvidia-cuda-runtime==13.0.48
[pip3] nvidia-cudnn-cu13==9.13.0.50
[pip3] nvidia-cufft==12.0.0.15
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.3.29
[pip3] nvidia-cusparse==12.6.2.49
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.27.7
[pip3] nvidia-nvjitlink==13.0.39
[pip3] nvidia-nvtx==13.0.39
[pip3] torch==2.9.1+cu130
[pip3] torchvision==0.24.1+cu130
[pip3] triton==3.5.1
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas 13.0.0.19 pypi_0 pypi
[conda] nvidia-cuda-cupti 13.0.48 pypi_0 pypi
[conda] nvidia-cuda-nvrtc 13.0.48 pypi_0 pypi
[conda] nvidia-cuda-runtime 13.0.48 pypi_0 pypi
[conda] nvidia-cudnn-cu13 9.13.0.50 pypi_0 pypi
[conda] nvidia-cufft 12.0.0.15 pypi_0 pypi
[conda] nvidia-curand 10.4.0.35 pypi_0 pypi
[conda] nvidia-cusolver 12.0.3.29 pypi_0 pypi
[conda] nvidia-cusparse 12.6.2.49 pypi_0 pypi
[conda] nvidia-cusparselt-cu13 0.8.0 pypi_0 pypi
[conda] nvidia-nccl-cu13 2.27.7 pypi_0 pypi
[conda] nvidia-nvjitlink 13.0.39 pypi_0 pypi
[conda] nvidia-nvtx 13.0.39 pypi_0 pypi
[conda] torch 2.9.1+cu130 pypi_0 pypi
[conda] torchvision 0.24.1+cu130 pypi_0 pypi
[conda] triton 3.5.1 pypi_0 pypi
cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo @bdhirsh @bobrenjc93 @aorenste @chenyang78