Skip to content

[Tiling rewrite pt1] Normalize reads and writes to common iter space#153723

Closed
eellison wants to merge 9 commits intogh/eellison/790/basefrom
gh/eellison/790/head
Closed

[Tiling rewrite pt1] Normalize reads and writes to common iter space#153723
eellison wants to merge 9 commits intogh/eellison/790/basefrom
gh/eellison/790/head

Conversation

@eellison
Copy link
Contributor

@eellison eellison commented May 16, 2025

Stack from ghstack (oldest at bottom):

In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented May 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153723

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e4f4618 with merge base 9258cfc (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
eellison added a commit that referenced this pull request May 16, 2025
@eellison eellison added the topic: not user facing topic category label May 16, 2025
@eellison eellison requested a review from jansel May 16, 2025 14:05
[ghstack-poisoned]
def foo(x, y):
return x + y

foo(torch.rand([4, 4], device="cuda"), torch.rand([4, 4], device="cuda").T)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, May I suggest we mark these cases as requires_cuda or replace the hardcode cuda with GPU_TYPE here? These new test case will also run on XPU and fail with cuda, thanks.

Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test failures?

…iter space"


In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split. 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@eellison eellison requested a review from jansel May 21, 2025 01:14
@eellison eellison mentioned this pull request May 21, 2025
…iter space"


In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split. 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@eellison eellison mentioned this pull request May 22, 2025
[ghstack-poisoned]
This was referenced May 27, 2025
[ghstack-poisoned]
[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 2, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #153730

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team Raised by workflow job

@eellison
Copy link
Contributor Author

eellison commented Jun 3, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Jun 3, 2025
Analyze memory expressions to see if they contain a coalescing symbol.

Pull Request resolved: #153730
Approved by: https://github.com/jansel
ghstack dependencies: #153723
pytorchmergebot pushed a commit that referenced this pull request Jun 3, 2025
Find variables that coalesce the reads and writes and score the total size. If uncoalesced memory expressions are found, look for additional tiling of variables which will coalesce memory accesses.

For instance - for the following expression: `(32*p0) // 2048`, tiling p0 by 64 will make this expression coalesced.

Pull Request resolved: #153748
Approved by: https://github.com/jansel
ghstack dependencies: #153723, #153730
pytorchmergebot pushed a commit that referenced this pull request Jun 4, 2025
This pr uses the coalescing information in generating a tiling. The previous tiling heuristic would have each dependency generate a tiling. Then, we sum up the score for each generated tiling, preferring any 2d tiling over the default. The new tiling heuristics scores each tiling by its global coalesced memory. This gives both a potentially better tiling (especially for more complicated, 3d patterns) as well as information we can use in generating block sizes.

In triton heuristics, for generating 3d tiled reductions, we take the same total block size that the 2d reduction would use, then distribute the block according to whichever block coalesces the most memory.

The motivating kernel is in #149982 which is a 32 element reduction. A smaller version of it is [here](https://gist.github.com/eellison/0fa9396f5479eb4dba09756e3bf6ff2a). We need to run this kernel once in the forward per linear layer on a contiguous tensor, and once in the backward on a transposed tensor.

While the contiguous kernel has coalesced accesses, and is performant on master, the transposed version accesses uncoalesced memory on main and is ~2.8x slower. See, this [full log](https://gist.github.com/eellison/fa644bfd9d0ae11dadb62e17a5d48a83) from the above repro. Now, with this PR, it is only ~1.15x slower. See the [updated log](https://gist.github.com/eellison/0b2b653309494d28cf7b48929a022075).

Pull Request resolved: #153751
Approved by: https://github.com/jansel
ghstack dependencies: #153723, #153730, #153748
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
…ytorch#153723)

In order to take the globally best tiling, we need to normalize all the node read and writes to a common iteration space. This first pr finds a common split among nodes in a fused scheduler node, and then normalizes reads and writes to the common split.

Pull Request resolved: pytorch#153723
Approved by: https://github.com/jansel
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
Analyze memory expressions to see if they contain a coalescing symbol.

Pull Request resolved: pytorch#153730
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153723
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
Find variables that coalesce the reads and writes and score the total size. If uncoalesced memory expressions are found, look for additional tiling of variables which will coalesce memory accesses.

For instance - for the following expression: `(32*p0) // 2048`, tiling p0 by 64 will make this expression coalesced.

Pull Request resolved: pytorch#153748
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153723, pytorch#153730
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
This pr uses the coalescing information in generating a tiling. The previous tiling heuristic would have each dependency generate a tiling. Then, we sum up the score for each generated tiling, preferring any 2d tiling over the default. The new tiling heuristics scores each tiling by its global coalesced memory. This gives both a potentially better tiling (especially for more complicated, 3d patterns) as well as information we can use in generating block sizes.

In triton heuristics, for generating 3d tiled reductions, we take the same total block size that the 2d reduction would use, then distribute the block according to whichever block coalesces the most memory.

The motivating kernel is in pytorch#149982 which is a 32 element reduction. A smaller version of it is [here](https://gist.github.com/eellison/0fa9396f5479eb4dba09756e3bf6ff2a). We need to run this kernel once in the forward per linear layer on a contiguous tensor, and once in the backward on a transposed tensor.

While the contiguous kernel has coalesced accesses, and is performant on master, the transposed version accesses uncoalesced memory on main and is ~2.8x slower. See, this [full log](https://gist.github.com/eellison/fa644bfd9d0ae11dadb62e17a5d48a83) from the above repro. Now, with this PR, it is only ~1.15x slower. See the [updated log](https://gist.github.com/eellison/0b2b653309494d28cf7b48929a022075).

Pull Request resolved: pytorch#153751
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153723, pytorch#153730, pytorch#153748
angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025
Analyze memory expressions to see if they contain a coalescing symbol.

Pull Request resolved: pytorch#153730
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153723
angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025
Find variables that coalesce the reads and writes and score the total size. If uncoalesced memory expressions are found, look for additional tiling of variables which will coalesce memory accesses.

For instance - for the following expression: `(32*p0) // 2048`, tiling p0 by 64 will make this expression coalesced.

Pull Request resolved: pytorch#153748
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153723, pytorch#153730
angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025
This pr uses the coalescing information in generating a tiling. The previous tiling heuristic would have each dependency generate a tiling. Then, we sum up the score for each generated tiling, preferring any 2d tiling over the default. The new tiling heuristics scores each tiling by its global coalesced memory. This gives both a potentially better tiling (especially for more complicated, 3d patterns) as well as information we can use in generating block sizes.

In triton heuristics, for generating 3d tiled reductions, we take the same total block size that the 2d reduction would use, then distribute the block according to whichever block coalesces the most memory.

The motivating kernel is in pytorch#149982 which is a 32 element reduction. A smaller version of it is [here](https://gist.github.com/eellison/0fa9396f5479eb4dba09756e3bf6ff2a). We need to run this kernel once in the forward per linear layer on a contiguous tensor, and once in the backward on a transposed tensor.

While the contiguous kernel has coalesced accesses, and is performant on master, the transposed version accesses uncoalesced memory on main and is ~2.8x slower. See, this [full log](https://gist.github.com/eellison/fa644bfd9d0ae11dadb62e17a5d48a83) from the above repro. Now, with this PR, it is only ~1.15x slower. See the [updated log](https://gist.github.com/eellison/0b2b653309494d28cf7b48929a022075).

Pull Request resolved: pytorch#153751
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153723, pytorch#153730, pytorch#153748
@github-actions github-actions bot deleted the gh/eellison/790/head branch July 4, 2025 02:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants