[AOTAutograd] Fix static_input_indices not offset when effect tokens are prepended #175904
wmhst7 wants to merge 3 commits into pytorch:main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175904
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (4 Unrelated Failures)
As of commit 36ac958 with merge base ca7ffb7:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Fix line-too-long lint error in graph_capture_wrappers.py and add a test verifying that static_input_indices are correctly offset when effect tokens are prepended to inputs.
@pytorchbot merge
Merge failed
Reason: This PR needs a label. To add a label, you can comment to pytorchbot.
@pytorchbot label "topic: not user facing"
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
Summary
When effectful ops (e.g., `with_effects`) are present, `handle_effect_tokens_fn()` prepends effect token placeholders to the input args. However, `static_input_indices` in `ViewAndMutationMeta` is computed before this prepending and is not adjusted afterwards. As a result, the indices point to the wrong inputs, which can cause problems such as unnecessary CUDA graph re-recording.
Problem
In `handle_effect_tokens_fn()`, effect tokens are prepended to args. But `meta.static_input_indices` is not offset by `num_tokens`. When these indices are later used (e.g., by CUDAGraph's `check_invariants`), they point to the wrong inputs:
Before tokens: args=[activation, weight], static_input_indices=[1] → weight ✓
After tokens: args=[token, activation, weight], static_input_indices=[1] → activation ✗
Expected: static_input_indices=[2] (offset by num_tokens=1) → weight ✓
Impact
Fix
Offset `static_input_indices` by `num_tokens` after prepending effect tokens in the forward-only (`trace_joint=False`) path.
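The effect of the offset can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code from this PR: `prepend_tokens` and `offset_static_indices` are hypothetical helper names, and strings stand in for the real tensors and for the metadata stored in `ViewAndMutationMeta`.

```python
from typing import Any, List


def prepend_tokens(args: List[Any], num_tokens: int) -> List[Any]:
    """Hypothetical stand-in for token prepending: tokens go before the real inputs."""
    tokens = [f"token{i}" for i in range(num_tokens)]
    return tokens + list(args)


def offset_static_indices(static_input_indices: List[int], num_tokens: int) -> List[int]:
    """Shift every static input index past the prepended tokens."""
    return [i + num_tokens for i in static_input_indices]


args = ["activation", "weight"]
static_input_indices = [1]  # computed before prepending; points at "weight"
num_tokens = 1

new_args = prepend_tokens(args, num_tokens)

# Without the offset, index 1 now selects the activation.
stale = [new_args[i] for i in static_input_indices]

# With the offset, index 2 selects the weight again.
fixed_indices = offset_static_indices(static_input_indices, num_tokens)
fixed = [new_args[i] for i in fixed_indices]

print(stale)          # ['activation']
print(fixed_indices)  # [2]
print(fixed)          # ['weight']
```

The sketch mirrors the before/after example from the Problem section: the same index list selects a different input once a token is placed in front, and adding `num_tokens` to each index restores the intended selection.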
Unit Test
Added `test_static_input_indices_with_effect_tokens` in `test/functorch/test_aotdispatch.py`, which checks that `static_input_indices` are `>= num_tokens` after effect tokens are prepended (i.e., no index incorrectly points to a token input).
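The invariant that the test checks can be illustrated without PyTorch. The following is only a sketch of the condition, not the test added in this PR: `indices_skip_tokens` is a hypothetical helper, and the input layout (one token followed by a weight and an activation) is assumed for illustration.

```python
from typing import List


def indices_skip_tokens(static_input_indices: List[int], num_tokens: int) -> bool:
    """Return True if no static input index lands on a prepended token slot."""
    return all(i >= num_tokens for i in static_input_indices)


num_tokens = 1
# Assumed layout after prepending: [token, weight, activation]
stale_indices = [0]  # computed before the token was prepended
fixed_indices = [i + num_tokens for i in stale_indices]

stale_ok = indices_skip_tokens(stale_indices, num_tokens)  # index 0 hits the token
fixed_ok = indices_skip_tokens(fixed_indices, num_tokens)  # index 1 hits the weight
```

Note that this condition only detects stale indices that land on a token slot; an un-offset index that is already `>= num_tokens` would pass the check while still referring to a different input.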
cc @yanboliang