TorchFunctionMode shouldn't be active when checking guards #172088

@ezyang

Description

🐛 Describe the bug

This program:

import torch
from torch.overrides import TorchFunctionMode

class CounterMode(TorchFunctionMode):
    def __init__(self):
        super().__init__()  # initialize TorchFunctionMode state
        self.count = 0

    def __torch_function__(self, func, types, args=(), kwargs=None):
        self.count += 1
        return func(*args, **(kwargs or {}))

@torch.compile
def f(x):
    return x.sin()

with CounterMode():
    f(torch.randn(3))

Fails with:

  AssertionError: Guard failed on the same frame it was created. This is a bug - please create an issue.
  Guard fail reason: 0/0: ___get_torch_function_mode_stack_at(0).count == 5
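The failure mode can be reproduced without torch at all. This is a minimal sketch (names like `Counter` and `compile_like` are illustrative, not Dynamo internals): the guard snapshots a mutable counter owned by user code, and the very act of running the compiled region under the mode bumps that counter, so the guard is already stale when it is first checked.

```python
# Pure-Python illustration of why the guard fails on the very frame
# that created it. `compile_like` stands in for Dynamo: it snapshots
# mutable user state when building the guard, but executing the region
# under the mode mutates that same state before the guard is checked.
class Counter:
    def __init__(self):
        self.count = 0

def compile_like(counter):
    snapshot = counter.count                  # guard records the count seen at trace time
    guard = lambda: counter.count == snapshot
    counter.count += 1                        # running ops under the mode bumps the count
    return guard

c = Counter()
guard = compile_like(c)
assert not guard()                            # guard is stale on the same frame
```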

I care about this because we've already stubbed our toe on it with TorchFunctionMetadataMode in the make_fx-over-torch.compile-region use case.

I feel there is something deeply wrong with how guards are handled for torch function modes. The system invariant being violated is that we must NEVER execute arbitrary user code during compilation. Running guards with the real user torch function mode active does exactly that, which is bad. So concretely, torch function modes should be DISABLED when we do guard evaluation; it is Dynamo's responsibility to have already inlined the torch function calls into the guards, so nothing needs to trigger the mode at guard-check time.
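The proposed behavior can be sketched in plain Python. Everything here is an illustrative placeholder (`mode_stack`, `eval_guard` are not real Dynamo names; in PyTorch the actual mechanism would presumably be something like wrapping guard evaluation in `torch._C.DisableTorchFunction`): guards run with the mode stack emptied, and the stack is restored afterwards.

```python
# Hedged sketch of the proposed fix: evaluate guards with all torch
# function modes popped off the stack, so guard code can never re-enter
# a user __torch_function__ handler.
from contextlib import contextmanager

mode_stack = []   # stand-in for the thread-local torch function mode stack

@contextmanager
def modes_disabled():
    saved = mode_stack[:]      # snapshot the active modes
    mode_stack.clear()         # no user mode can fire in here
    try:
        yield
    finally:
        mode_stack[:] = saved  # restore the stack exactly as it was

def eval_guard(guard_fn):
    with modes_disabled():     # NEVER run arbitrary user code during guard checks
        return guard_fn()
```

Under this scheme, a guard like `___get_torch_function_mode_stack_at(0).count == 5` would never be generated in the first place; identity/type checks on the stack entries would be evaluated with the modes inert.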

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo @mlazos

Versions

main

Labels

PT2-Bug-Bash, module: dynamo, oncall: pt2, triaged
