[dynamo] change error_on_graph_break/fullgraph semantics#161747

Closed
williamwen42 wants to merge 8 commits into gh/williamwen42/284/base from gh/williamwen42/284/head

@williamwen42
Member

@williamwen42 williamwen42 commented Aug 28, 2025

Stack from ghstack (oldest at bottom):

This PR implements the semantics change to torch._dynamo.error_on_graph_break:

  • ~torch.compile now has a new error_on_graph_break kwarg that serves as a lower-priority toggle for erroring/continuing on graph breaks~
  • error_on_graph_break is a new internal torch.compile setting that is lower-priority than fullgraph. It allows the user to toggle erroring/continuing on graph breaks.
  • error_on_graph_break does nothing when fullgraph=True
  • error_on_graph_break does NOT guarantee a single graph

Followup [DONE]: need to change the programming model docs to reflect the 3 graph break modes for compilation:

  • fullgraph=True: enforce one graph, no graph breaks, cannot be toggled
  • fullgraph=False, error_on_graph_break=True: errors on graph breaks; error_on_graph_break can be toggled during compile time
  • fullgraph=False, error_on_graph_break=False: resumes tracing on graph breaks; error_on_graph_break can be toggled during compile time
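
As a rough illustration of the priority rule in the list above, the three modes can be modeled as a small lookup. This is only a sketch: `GraphBreakMode` and `resolve_mode` are hypothetical names invented for this example, not part of PyTorch, and the code does not call Dynamo at all.

```python
from enum import Enum


class GraphBreakMode(Enum):
    """Hypothetical labels for the three graph break modes listed above."""

    SINGLE_GRAPH = "enforce one graph; cannot be toggled"
    ERROR = "error on graph breaks; toggleable during compile time"
    RESUME = "resume tracing on graph breaks; toggleable during compile time"


def resolve_mode(fullgraph: bool, error_on_graph_break: bool) -> GraphBreakMode:
    """Hypothetical helper: fullgraph takes priority over error_on_graph_break."""
    if fullgraph:
        # error_on_graph_break does nothing when fullgraph=True.
        return GraphBreakMode.SINGLE_GRAPH
    if error_on_graph_break:
        return GraphBreakMode.ERROR
    return GraphBreakMode.RESUME


# Enumerate all four flag combinations; two of them collapse into one mode.
for fullgraph in (True, False):
    for error_on_graph_break in (True, False):
        mode = resolve_mode(fullgraph, error_on_graph_break)
        print(f"fullgraph={fullgraph}, error_on_graph_break={error_on_graph_break}: {mode.name}")
```

Four flag combinations map to three modes because both `fullgraph=True` rows resolve to the same single-graph mode.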

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @mlazos

@pytorch-bot

pytorch-bot Bot commented Aug 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161747

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7a755fe with merge base 734ce8e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mlazos
Contributor

mlazos commented Aug 29, 2025

Why does this need to be a kwarg to compile? Couldn't this just be a custom decorator? I think this will be really confusing for users.

@guilhermeleobas
Collaborator

guilhermeleobas commented Aug 29, 2025

I don't get the difference between error_on_graph_break=True and error_on_graph_break=False. When fullgraph=False, wouldn't Dynamo already handle the graph break and split the graph? Got it.

@williamwen42
Member Author

@mlazos the kwarg is not strictly necessary, but it does allow a user to specify the initial setting without an additional decorator (e.g. torch.compile(fn, error_on_graph_break=True)).

Here's a table summary of fullgraph vs. error_on_graph_break:

[Screenshot: table summarizing fullgraph vs. error_on_graph_break]

Comment thread test/dynamo/test_decorators.py
Comment thread torch/__init__.py Outdated
in the function that it will optimize. If True, then we require that the entire function be
capturable into a single graph. If this is not possible (that is, if there are graph breaks),
then this will raise an error.
error_on_graph_break (bool): If `fullgraph` is set, then this arg does nothing
Contributor

Once we have some comprehensive documentation on this, it might be worth linking to an example that shows when users would want to use this setting.

@mlazos mlazos self-requested a review August 29, 2025 21:33
@mlazos
Contributor

mlazos commented Aug 29, 2025

My thinking on this is that if this isn't a super common use case, it shouldn't be a kwarg due to the confusion with fullgraph (which I think comparatively is pretty common). If this is a power-user API, I think a decorator makes more sense. If you're okay with this, could we hold off on the kwarg until we see how often it would be needed (e.g. seeing people add the decorator to the top-level function)?

@williamwen42
Member Author

@mlazos The idea going forward is that fullgraph=True is for power users/frameworks that really want to enforce no graph breaks/one graph guarantee. error_on_graph_break is a setting for the average user to disallow most graph breaks, but allow ones that are difficult to deal with.
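
To make the "disallow most graph breaks, but allow the difficult ones" use case concrete, here is a toy model of the toggle. It is a sketch under loud assumptions: `ToyTracer`, `toggle`, `op`, `graph_break`, and `GraphBreakError` are all hypothetical names made up for this illustration. The code does not use Dynamo and does not reflect how Dynamo is implemented.

```python
from contextlib import contextmanager


class GraphBreakError(RuntimeError):
    """Raised by the toy tracer when a graph break is not allowed."""


class ToyTracer:
    """Toy stand-in for a tracer; NOT Dynamo's implementation."""

    def __init__(self, fullgraph: bool, error_on_graph_break: bool):
        self.fullgraph = fullgraph
        self.error_on_graph_break = error_on_graph_break
        self.graphs = [[]]  # captured graphs, each a list of op names

    @contextmanager
    def toggle(self, value: bool):
        # Scoped override of the lower-priority setting during "compilation".
        previous = self.error_on_graph_break
        self.error_on_graph_break = value
        try:
            yield
        finally:
            self.error_on_graph_break = previous

    def op(self, name: str) -> None:
        self.graphs[-1].append(name)

    def graph_break(self, reason: str) -> None:
        # fullgraph wins: when it is set, the toggle has no effect.
        if self.fullgraph or self.error_on_graph_break:
            raise GraphBreakError(reason)
        self.graphs.append([])  # resume tracing in a new graph


tracer = ToyTracer(fullgraph=False, error_on_graph_break=True)
tracer.op("matmul")
with tracer.toggle(False):
    tracer.graph_break("break that is hard to remove")  # tolerated here
tracer.op("relu")
try:
    tracer.graph_break("unexpected break")  # errors outside the toggle
except GraphBreakError as exc:
    print("error:", exc)
print(len(tracer.graphs), "graphs captured")
```

In this model the tolerated break still splits the trace into two graphs, which matches the point that error_on_graph_break does not guarantee a single graph.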

williamwen42 added a commit that referenced this pull request Aug 29, 2025
@williamwen42 williamwen42 added the module: compile ux UX issues related to torch.compile label Aug 29, 2025
@williamwen42 williamwen42 added the keep-going Don't stop on first failure, keep running tests until the end label Aug 29, 2025
williamwen42 added a commit that referenced this pull request Sep 2, 2025
@williamwen42
Member Author

@mlazos I removed error_on_graph_break as a kwarg to torch.compile - we can add it later again if deemed helpful.

williamwen42 added a commit that referenced this pull request Sep 2, 2025
Collaborator

Based on the source code, this one should fail due to support.gc_collect()

def test_highly_nested_subclass(self):
    # Issues 25395 and 35983: test that the trashcan mechanism works
    # correctly for OrderedDict: deleting a highly nested OrderDict
    # should not crash Python.
    OrderedDict = self.OrderedDict
    deleted = []
    with torch._dynamo.set_fullgraph(fullgraph=False):
        class MyOD(OrderedDict):
            def __del__(self):
                deleted.append(self.i)
    obj = None
    for i in range(100):
        obj = MyOD([(None, obj)])
        obj.i = i
    del obj
    support.gc_collect()
    self.assertEqual(deleted, list(reversed(range(100))))

Comment thread test/dynamo/test_modes.py
return func(*args, **kwargs)

# test e2e, with Inductor, as smoketest.
@torch.compile(fullgraph=True, backend="inductor")
Contributor

Why this test? I'm curious haha

Member Author

@williamwen42 williamwen42 Sep 3, 2025

It was a test that was added after set_fullgraph was merged - in that change, fullgraph=True went through the same convert_frame.py path as fullgraph=False, which increments counters["frames"]["total"]. This PR reverted that change so that the counter is no longer incremented.

@williamwen42
Member Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 4, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Pull Request resolved: pytorch#161747
Approved by: https://github.com/mlazos
ghstack dependencies: pytorch#161739
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
@github-actions github-actions Bot deleted the gh/williamwen42/284/head branch October 5, 2025 02:17
Khanaksahu pushed a commit to Khanaksahu/pytorch-fork that referenced this pull request Nov 17, 2025

Labels

ciflow/inductor, ciflow/trunk, keep-going, Merged, module: compile ux, module: dynamo, module: inductor, release notes: dynamo

5 participants