
Tiling bug fix #167771

Closed
eellison wants to merge 5 commits into gh/eellison/865/base from gh/eellison/865/head

Conversation

@eellison
Contributor

@eellison eellison commented Nov 13, 2025

Stack from ghstack (oldest at bottom):

Fix for #166653.

Two fixes:

  • We were inducing a split for broadcasted loads, e.g. (x // 16). While a split of 16 here makes the load coalesced in one of the tile variables, the load is already in cache, so it is not worth splitting; the split would also make the other tile variable load from memory that isn't in cache.
  • Add a slight penalty term for uncoalesced memory. This prevents tiling for loads that are a small percentage of the overall kernel.
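The second fix above can be sketched roughly as follows. This is a minimal illustration, not the actual inductor code, and the threshold value is an assumption for illustration only:

```python
# Hypothetical sketch: skip tiling when uncoalesced accesses are only a
# small fraction of the kernel's total memory traffic.
UNCOALESCED_FRACTION_THRESHOLD = 0.05  # assumed value, not from the PR

def worth_tiling(uncoalesced_bytes: int, total_bytes: int) -> bool:
    """Return True only if uncoalesced traffic is a meaningful share."""
    if total_bytes == 0:
        return False
    return uncoalesced_bytes / total_bytes >= UNCOALESCED_FRACTION_THRESHOLD
```

With such a term, a kernel whose uncoalesced loads account for, say, 1% of traffic would no longer be tiled just to coalesce them.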

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

[ghstack-poisoned]
eellison added a commit that referenced this pull request Nov 13, 2025
ghstack-source-id: 198bd40
Pull Request resolved: #167771
@pytorch-bot

pytorch-bot bot commented Nov 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167771

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 5664cd8 with merge base 5a3930a:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Nov 13, 2025

The label module: inductor is only applicable to issues and has been removed. Please only use this label on issues.


@eellison eellison requested a review from v0i0 November 13, 2025 22:10
@eellison eellison added the topic: not user facing label Nov 13, 2025
[ghstack-poisoned]
eellison added a commit that referenced this pull request Nov 13, 2025
ghstack-source-id: ef8b139
Pull Request resolved: #167771


byte_multipler = 0
total_score = 0
for buf_name in buf_names:
Contributor


more for me to understand, but why would there be more than one buffer per access?

Contributor Author


this is just the format of normalized reads/writes:

class FusedNormalizedReadsWrites:
    """
    Normalized reads and writes for nodes in the same FusedSchedulerNode.
    """

    index_vars: OrderedSet[sympy.Symbol]
    reduce_vars: OrderedSet[sympy.Symbol]
    reads: dict[sympy.Expr, OrderedSet[str]]
    writes: dict[sympy.Expr, OrderedSet[str]]
    var_ranges: dict[sympy.Symbol, int]

it contains a mapping from a sympy memory expression to all buffers accessed with that expression
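A minimal sketch of that mapping, using plain strings and sets in place of sympy expressions and OrderedSet (the buffer names are hypothetical):

```python
# reads maps each normalized index expression to every buffer accessed
# with that expression, which is why the scoring loop iterates over
# multiple buf_names per access.
reads: dict[str, set[str]] = {}
accesses = [("buf0", "x // 16"), ("buf1", "x // 16"), ("buf2", "x")]
for buf_name, index_expr in accesses:
    reads.setdefault(index_expr, set()).add(buf_name)
```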

[ghstack-poisoned]
eellison added a commit that referenced this pull request Nov 14, 2025
ghstack-source-id: c428a99
Pull Request resolved: #167771
pytorchmergebot pushed a commit that referenced this pull request Nov 14, 2025
ghstack-source-id: 0102654
Pull Request resolved: #167771
"""
Try to find the variable that this index is broadcast over.
A broadcast pattern is one where consecutive values of a variable
access the same memory location (e.g., x // 10).
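The pattern described in the docstring can be checked numerically. This is an illustrative stand-in, not the actual detection code:

```python
def looks_broadcast(index_fn, n: int = 8) -> bool:
    """True if consecutive variable values hit the same memory location."""
    addrs = [index_fn(v) for v in range(n)]
    return any(a == b for a, b in zip(addrs, addrs[1:]))
```

For example, x // 10 repeats the same address for consecutive x, while the identity index never does.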
Contributor


In general, can x % 10 also be a broadcast?

x % 10 vs. x // 10 just picks a different dimension.

Contributor Author


In this case, x % 10 will be read as coalesced, so it should still work the same.

Contributor


yea, but for stride * (x % 10), it's not coalesced

Contributor Author


That won't be considered coalesced. see

def find_coalesced_var(

Contributor


I think the main thing that confuses me is that the code treats stride * (x % 10) and stride * (x // 10) differently, while they are both broadcasting.

Contributor Author

@eellison eellison Nov 17, 2025


x % 2048 is not broadcasting; it's only with a very small modulo that it is broadcasting. In this case we're treating both coalesced and broadcasting the same, so it shouldn't matter.
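The small-vs-large modulo distinction can be illustrated by counting the distinct addresses adjacent threads touch (illustrative only, not the actual heuristic):

```python
def working_set(num_threads: int, index_fn) -> int:
    """Distinct addresses touched across a block of adjacent threads."""
    return len({index_fn(t) for t in range(num_threads)})
```

A tiny modulo like x % 10 keeps the working set cache-resident, while x % 2048 over 1024 threads touches 1024 distinct locations and behaves like an ordinary load.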

Contributor


import torch

@torch.compile
def f(x, y):
    return x[::2, None] + y[None, ::4]

x = torch.randn(1024, device="cuda")
y = torch.randn(2048, device="cuda")
f(x, y)

generates:

@triton.jit
def triton_poi_fused_add_slice_unsqueeze_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 262144
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x1 = xindex // 512
    x0 = (xindex % 512)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x1), None, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (4*x0), None, eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x2), tmp2, None)

Although the generated code replaces xindex // 512 with x1 and xindex % 512 with x0.

Contributor

@shunting314 shunting314 Nov 17, 2025


Maybe a more complex example could trigger the case where xindex // 512 and xindex % 512 show up in the memory address expression directly, but the tiny example above already shows the idea.
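For reference, the index decomposition the kernel performs can be reproduced directly:

```python
def decompose(xindex: int, inner: int = 512):
    # mirrors x1 = xindex // 512 and x0 = xindex % 512 in the generated kernel
    return xindex // inner, xindex % inner
```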

jsuarez5341 pushed a commit to PufferAI/pytorch that referenced this pull request Nov 15, 2025

Pull Request resolved: pytorch#167771
Approved by: https://github.com/v0i0

Khanaksahu pushed a commit to Khanaksahu/pytorch that referenced this pull request Nov 17, 2025
ghstack-source-id: 0102654
Pull Request resolved: pytorch/pytorch#167771
@eellison
Contributor Author

@pytorchbot revert -c "weird"

@pytorch-bot

pytorch-bot bot commented Nov 17, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -m/--message

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst,autorevert}

Try @pytorchbot --help for more info.

@eellison
Contributor Author

@pytorchbot revert -m "needs one fix" -c weird

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@eellison your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Nov 17, 2025
This reverts commit 7ede33b.

Reverted #167771 on behalf of https://github.com/eellison due to needs one fix.
@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td labels Nov 17, 2025
[ghstack-poisoned]
@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@eellison
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 4 checks: Lint / Test collect_env (with_torch, linux.24_04.4x), inductor / inductor-cpu-test / test (dynamic_cpu_inductor_torchbench, 2, 2, linux.8xlarge.amx), inductor / inductor-cpu-test / test (cpu_inductor_torchbench, 2, 2, linux.8xlarge.amx), inductor / inductor-test / test (inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025

Pull Request resolved: pytorch#167771
Approved by: https://github.com/v0i0
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
ghstack-source-id: dae7771
Pull Request resolved: pytorch/pytorch#167771
@github-actions github-actions bot deleted the gh/eellison/865/head branch December 18, 2025 02:18

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor Reverted topic: not user facing topic category
