torch.compile fails for torch.func.grad with nested function: f(f(x)) #169783

@bysdxt

🐛 Describe the bug

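torch.compile(torch.func.grad(h), fullgraph=True) fails when h applies a custom torch.autograd.Function twice, as f(f(x)), but succeeds when it is applied once, as f(x). Eager torch.func.grad(h) and torch.compile(h, fullgraph=True) both work in either case.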

import os
# os.environ["TORCHDYNAMO_VERBOSE"] = "1"
# os.environ["TORCH_LOGS"] = "+dynamo"
import platform
import sys
import typing
import math
import torch

param_pi2 = math.pi * 2
param_reciprocal_2pi = 1 / param_pi2
class impl(torch.autograd.Function):
    @staticmethod
    def forward(x):
        pi_2_x = torch.mul(torch.frac(x), param_pi2) # 2 * pi * frac(x); sin/cos match 2 * pi * x by periodicity
        # x - sin(2 * pi * x) / (2 * pi)
        y = torch.sub(x, torch.sin(pi_2_x), alpha=param_reciprocal_2pi)
        return y, pi_2_x
    @staticmethod
    def setup_context(ctx, inputs, output):
        _, pi_2_x = output
        ctx.save_for_backward(pi_2_x)
    @staticmethod
    def backward(ctx, *grad_outputs):
        grad, _ = grad_outputs
        pi_2_x, = ctx.saved_tensors
        return torch.mul(grad, torch.sub(1, torch.cos(pi_2_x)))

def f(x: torch.Tensor) -> torch.Tensor:
    y, _ = impl.apply(x) # type: ignore
    return y
def g(x: torch.Tensor) -> torch.Tensor:
    # return f(x) # ok
    return f(f(x)) # fail if torch.compile for torch.func.grad
def h(x: torch.Tensor) -> torch.Tensor:
    return g(x).sum()

if __name__ == '__main__':
    print(platform.platform())
    print(platform.machine(), platform.architecture())
    print(sys.version)
    print(torch.__version__)
    x = torch.tensor([-1.9, -1.6, -1.4, -1.1, -0.9, -0.6, -0.4, -0.1, 0.1, 0.4, 0.6, 0.9, 1.1, 1.4, 1.6, 1.9], dtype=torch.float32)

    # t = h # ok
    # t = torch.func.grad(h) # ok
    # t = torch.compile(h,fullgraph=True) # ok
    t = torch.compile(torch.func.grad(h),fullgraph=True) # fail for torch.func.grad with f(f(x))
    # t = g # ok
    # t = torch.compile(g,fullgraph=True) # ok

    print(t(x))

Code file: test.py
Normal log: log1.txt
Log with the following set before importing torch:

os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCH_LOGS"] = "+dynamo"

log2.txt
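As a cross-check on what the gradient should be, here is a minimal eager-only sketch. It is not part of the original report: `f_ref`, `h_ref`, and `expected_grad` are hypothetical helper names, the custom `torch.autograd.Function` is replaced by plain differentiable ops, and float64 is used instead of the report's float32 so the comparison is tight. It compares `torch.func.grad` on the nested function against the chain-rule closed form f'(f(x)) * f'(x), where f'(x) = 1 - cos(2 * pi * x).

```python
import math
import torch

TWO_PI = 2 * math.pi

def f_ref(x: torch.Tensor) -> torch.Tensor:
    # Same formula as impl.forward in the report, written with plain ops
    # (assumption: no custom autograd.Function is involved here).
    return x - torch.sin(TWO_PI * torch.frac(x)) / TWO_PI

def h_ref(x: torch.Tensor) -> torch.Tensor:
    # Nested application f(f(x)), reduced to a scalar for torch.func.grad.
    return f_ref(f_ref(x)).sum()

def expected_grad(x: torch.Tensor) -> torch.Tensor:
    # Chain rule: d/dx f(f(x)) = f'(f(x)) * f'(x), with f'(t) = 1 - cos(2*pi*t).
    def f_prime(t: torch.Tensor) -> torch.Tensor:
        return 1 - torch.cos(TWO_PI * t)
    return f_prime(f_ref(x)) * f_prime(x)

x = torch.tensor([-1.9, -1.6, -1.4, -1.1, -0.9, -0.6, -0.4, -0.1,
                  0.1, 0.4, 0.6, 0.9, 1.1, 1.4, 1.6, 1.9], dtype=torch.float64)

grad_ref = torch.func.grad(h_ref)(x)  # eager reference, no torch.compile
print(torch.allclose(grad_ref, expected_grad(x)))
```

If the compiled path is fixed or worked around, its output could be compared against `grad_ref` in the same way.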

Versions

Collecting environment information...
PyTorch version: 2.9.1+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 581.29
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 3950X 16-Core Processor
CPU family: 23
Model: 113
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 0
BogoMIPS: 7000.02
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr virt_ssbd arat umip rdpid
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 128 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 2 MiB (4 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.3.3
[pip3] nvidia-cublas==13.0.0.19
[pip3] nvidia-cuda-cupti==13.0.48
[pip3] nvidia-cuda-nvrtc==13.0.48
[pip3] nvidia-cuda-runtime==13.0.48
[pip3] nvidia-cudnn-cu13==9.13.0.50
[pip3] nvidia-cufft==12.0.0.15
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.3.29
[pip3] nvidia-cusparse==12.6.2.49
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.27.7
[pip3] nvidia-nvjitlink==13.0.39
[pip3] nvidia-nvtx==13.0.39
[pip3] torch==2.9.1+cu130
[pip3] torchvision==0.24.1+cu130
[pip3] triton==3.5.1
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo

Labels: module: dynamo, oncall: pt2, triaged
