[Graph Partition][Flex Attention] analyze symints from subgraph inputs and outputs#152878
[Graph Partition][Flex Attention] analyze symints from subgraph inputs and outputs#152878BoyuanFeng wants to merge 8 commits intomainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152878
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ You can merge normally! (1 Unrelated Failure)As of commit efe6660 with merge base 6f6fac6 ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
drisspg
left a comment
There was a problem hiding this comment.
Curious; Does the output buffer ever actually have symints that weren't in the inputs? For flex
| self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None | ||
| self.subgraph_outs: Optional[list[Optional[IRNode]]] = None | ||
|
|
||
| def get_free_symbol_uses( |
There was a problem hiding this comment.
probably share the same doc as get_free_symbol_uses for other IRNode.
In general output buffer may have captured symints that were not in the inputs. I don't have a concrete example for flex attention yet, but still included output buffer for completeness. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Flex Attention may have symints in subgraph inputs and outputs. Existing code implicitly captures these symints but does not explicitly store it in TritonTemplateBuffer. This leads to error when analyzing symints used in Flex Attention as a TritonTemplateBuffer. This PR fixes the issue.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov