Skip to content

[Dynamo] Imporve-graph-break-skip-logs#167067

Closed
parsshar-RH wants to merge 5 commits intopytorch:mainfrom
parsshar-RH:improve-graph-break-skip-frame
Closed

[Dynamo] Imporve-graph-break-skip-logs#167067
parsshar-RH wants to merge 5 commits intopytorch:mainfrom
parsshar-RH:improve-graph-break-skip-frame

Conversation

@parsshar-RH
Copy link
Collaborator

@parsshar-RH parsshar-RH commented Nov 5, 2025

Fixes #150477

Summary:

  • Added frame information (function name, file, line number) to all graph break/skip messages
  • Standardized message format: "torch.compile will skip tracing the frame ( line ) and fall back to eager. Reason: "

Impacts:

module: dynamo

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167067

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit dd201aa with merge base c940b1f (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Lucaskabela
Copy link
Contributor

cc @williamwen42 to help take a look at this

Copy link
Member

@williamwen42 williamwen42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for submitting! I made a few comments.

)

@make_logging_test(graph_breaks=True)
def test_skipped_frame_no_verbose_traceback(self, records):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this test differ from test_assert_failure_in_generic_ctx_mgr?

def test_skipped_frame_no_verbose_traceback(self, records):
def fn(x):
with GenericCtxMgr():
assert x is None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just do a regular graph break here? Since there's some special handling of asserts that causes such a test to possibly fail internally (see comment above test_assert_failure_in_generic_ctx_mgr)


self.assertEqual(len(records), 1)

message = records[0].getMessage()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer to use assertExpectedInline[Munged] here in order to more clearly see what the error message looks like (otherwise, can you provide a reason why this would not be a good idea?)

Comment on lines +1892 to +1903
if config.verbose:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=True,
)
else:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=False,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if config.verbose:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=True,
)
else:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=False,
)
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=config.verbose,
)

Comment on lines +892 to +899
frame_info = (
f"{getattr(self.f_code, 'co_name', '<unknown>')} "
f"({getattr(self.f_code, 'co_filename', '<unknown>')} "
f"line {getattr(self.f_code, 'co_firstlineno', 0)})"
)
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we refactor this code (and the above) into one function?

Comment on lines +4628 to +4630
f"{getattr(self.f_code, 'co_name', '<unknown>')} "
f"({getattr(self.f_code, 'co_filename', '<unknown>')} "
f"line {getattr(self.f_code, 'co_firstlineno', 0)})"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing this code appear a lot - can we move it to a single function in exc.py?


result = torch.compile(f2, backend="eager")(torch.randn(3))
expected = f2(torch.randn(3))
self.assertEqual(result.shape, expected.shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check for UX here (i.e. using assertExpectedInline[Munged])? And same for the added cases below.

@parsshar-RH
Copy link
Collaborator Author

Hi @williamwen42

Thank you for the detailed review and valuable suggestions.
I will incorporate the changes and push an updated commit shortly.

Thanks again for your time and feedback.

@parsshar-RH
Copy link
Collaborator Author

@williamwen42

I have implemented the suggested changes.
Please feel free to let me know if any additional updates are needed.

Thanks!

@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 7, 2025
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
f"torch.compile cannot be run in context: {functorch_subclass_name}"
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format_skip_frame_message?

raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
frame_info = format_frame_info(tx.f_code)
msg = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use format_skip_frame_message?

def format_skip_frame_message(code: types.CodeType, reason: str) -> str:
frame_info = format_frame_info(code)
return (
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a test that checks for this string - I think this might not be logged to the graph_breaks logging artifact (I believe it is logged to the torch._dynamo logger?). Can you make sure this is covered by the tests you added?

@parsshar-RH
Copy link
Collaborator Author

@williamwen42

I have implemented the suggested changes.
Please aware me if you have more suggestions.

Thanks!

@parsshar-RH
Copy link
Collaborator Author

@williamwen42
Looks like some unrelated CI checks are failing.
Please review.

Thanks in advance!

Copy link
Member

@williamwen42 williamwen42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lint error looks real - can you move the nested graph break tests you added to test_error_messages?

@parsshar-RH
Copy link
Collaborator Author

@williamwen42

I have moved the nested graph break tests to test_error_messages.py

@parsshar-RH
Copy link
Collaborator Author

@williamwen42

Lint error is resolved.
But this failing CI check looks unrelated.
Please review and suggest.

Thanks

@parsshar-RH
Copy link
Collaborator Author

@williamwen42
Please review and merge the PR.

@williamwen42
Copy link
Member

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 13, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 4, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@williamwen42
Copy link
Member

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased improve-graph-break-skip-frame onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout improve-graph-break-skip-frame && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the improve-graph-break-skip-frame branch from 683d9ec to dd201aa Compare November 13, 2025 21:51
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Nov 13, 2025
@parsshar-RH
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 13, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@williamwen42
Copy link
Member

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

meta-codesync bot pushed a commit to pytorch/benchmark that referenced this pull request Nov 14, 2025
Summary:
Fixes #150477

### Summary:

- Added frame information (function name, file, line number) to all graph break/skip messages
- Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>"

### Impacts:
module: dynamo

X-link: pytorch/pytorch#167067
Approved by: https://github.com/williamwen42

Reviewed By: jeanschmidt

Differential Revision: D87036500

fbshipit-source-id: 62281bad4609b8ea3557f7139695678bed0679cb
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
Fixes pytorch#150477

### Summary:

- Added frame information (function name, file, line number) to all graph break/skip messages
- Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>"

### Impacts:
module: dynamo

Pull Request resolved: pytorch#167067
Approved by: https://github.com/williamwen42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[dynamo] improve graph break message causing skipped frame

6 participants