
[Graph Partition][Flex Attention] analyze symints from subgraph inputs and outputs #152878

Closed

BoyuanFeng wants to merge 8 commits into main from bf/partition-fa

Conversation


@BoyuanFeng (Contributor) commented May 5, 2025

Flex Attention may have symints in its subgraph inputs and outputs. The existing code implicitly captures these symints but does not explicitly store them in TritonTemplateBuffer. This leads to errors when analyzing the symints used by a Flex Attention TritonTemplateBuffer. This PR fixes the issue.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
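
For illustration, a minimal sketch of the idea (assumed, not the actual PR code): the template buffer records its subgraph inputs and outputs explicitly, and get_free_symbol_uses walks them so that any captured symints become visible to later analyses such as graph partitioning. The field and method names follow the diff excerpt later in this thread; the standalone class and the IRNode helper calls are assumptions.

```python
from typing import Optional, Union

import sympy


class TritonTemplateBufferSketch:
    """Sketch only: mirrors the subgraph_inps/subgraph_outs fields from the diff."""

    def __init__(self, subgraph_inps=None, subgraph_outs=None):
        # Subgraph inputs may be IR nodes or plain sympy expressions (symints);
        # subgraph outputs are IR nodes.
        self.subgraph_inps: Optional[list[Optional[Union["IRNode", sympy.Expr]]]] = subgraph_inps
        self.subgraph_outs: Optional[list[Optional["IRNode"]]] = subgraph_outs

    def get_free_symbol_uses(self) -> set[sympy.Symbol]:
        # Collect every free symbol reachable from the subgraph inputs and
        # outputs, so symints captured by the subgraph are reported explicitly.
        symbols: set[sympy.Symbol] = set()
        for inp in self.subgraph_inps or []:
            if isinstance(inp, sympy.Expr):
                symbols |= inp.free_symbols            # a symint passed directly
            elif inp is not None:
                symbols |= inp.get_free_symbol_uses()  # assumed IRNode helper
        for out in self.subgraph_outs or []:
            if out is not None:
                symbols |= out.get_free_symbol_uses()  # outputs checked for completeness
        return symbols
```

Presumably this is what lets the graph-partition analysis account for symints used inside the flex attention subgraph instead of erroring on symbols it cannot locate.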

@BoyuanFeng added the ciflow/trunk, topic: not user facing, and module: inductor labels on May 5, 2025

pytorch-bot bot commented May 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152878

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit efe6660 with merge base 6f6fac6:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@BoyuanFeng requested review from drisspg and eellison on May 7, 2025 at 16:02
@drisspg (Contributor) left a comment


Curious; Does the output buffer ever actually have symints that weren't in the inputs? For flex

self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None

def get_free_symbol_uses(
@drisspg (Contributor):

Maybe a doc block

@BoyuanFeng (Contributor, Author):

probably share the same doc as get_free_symbol_uses for other IRNode.

@BoyuanFeng (Contributor, Author), quoting the question above:

> Curious; Does the output buffer ever actually have symints that weren't in the inputs? For flex

In general, the output buffer may have captured symints that were not in the inputs. I don't have a concrete example for flex attention yet, but I still included the output buffer for completeness.
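
As a hedged illustration (hypothetical, not taken from the PR or from flex attention): an output whose size expression contains a symbol that never appears in any input size, such as an unbacked symint from a data-dependent op. Comparing the two symbol sets shows why the outputs are worth scanning for completeness.

```python
import sympy

# Hypothetical size expressions: the inputs use s0, while the output's first
# dimension is a fresh unbacked symbol u0 (e.g. produced by a data-dependent op).
s0 = sympy.Symbol("s0", integer=True, positive=True)
u0 = sympy.Symbol("u0", integer=True, positive=True)

input_sizes = [s0, sympy.Integer(128)]
output_sizes = [u0, sympy.Integer(128)]

input_syms = set().union(*(e.free_symbols for e in input_sizes))
output_syms = set().union(*(e.free_symbols for e in output_sizes))

# u0 is reachable only via the output buffer, so an analysis that looked at
# inputs alone would miss it.
print(output_syms - input_syms)  # -> {u0}
```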

@BoyuanFeng (Contributor, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@BoyuanFeng mentioned this pull request on Apr 24, 2025
The github-actions bot deleted the bf/partition-fa branch on June 17, 2025 at 02:22
