[DTensor] Make Replicate->Partial cost > 0 #172282

Closed

wconstab wants to merge 17 commits into gh/wconstab/495/base from gh/wconstab/495/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Jan 12, 2026

Stack from ghstack (oldest at bottom):

The cost of doing this conversion is actually nonzero: it involves
dispatching some operators. Currently the cost differs depending on which
type of Partial is involved, since each defines its own 'partition' function,
but in general it could be a scaling operation.

It's helpful to express this as non-free in the cost model because
otherwise it is likely that a suboptimal op sharding strategy will be
selected, on the basis that converting one partial through replicate to
another partial is just as cheap as staying in replicate.

Before this PR, when multiplying Partial("max") * Replicate, the strategies:

  - [Partial(sum), Replicate, Partial(sum)] has cost 22.82 (Pmax ->
    Replicate -> Psum)
  - [Replicate, Replicate, Replicate] has cost 22.82 (Pmax ->
    Replicate)

And we would select whichever appears first in the strategy list.
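
To make the tie-break concrete, here is a minimal sketch of the idea (the helper and constant are hypothetical, not the actual DTensor cost-model code): once the Replicate -> Partial 'partition' step carries a small nonzero cost, the Replicate-only strategy wins whenever the comm costs match.

```python
# Hypothetical sketch, not DTensor's real API: charge a small nonzero cost
# for the Replicate -> Partial 'partition' step so ties are broken.
REPLICATE_TO_PARTIAL_COST = 0.1  # assumed placeholder value


def strategy_cost(transitions):
    """Sum per-transition costs; each transition is (src, dst, comm_cost)."""
    total = 0.0
    for src, dst, comm_cost in transitions:
        total += comm_cost
        if src == "Replicate" and dst.startswith("Partial"):
            total += REPLICATE_TO_PARTIAL_COST  # previously free, now nonzero
    return total


# Strategy A: Pmax -> Replicate -> Psum (extra local partition step)
cost_a = strategy_cost([("Partial(max)", "Replicate", 22.82),
                        ("Replicate", "Partial(sum)", 0.0)])
# Strategy B: Pmax -> Replicate (stay replicated)
cost_b = strategy_cost([("Partial(max)", "Replicate", 22.82)])

assert cost_b < cost_a  # B now wins instead of "whichever appears first"
```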

@pytorch-bot

pytorch-bot Bot commented Jan 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172282

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit 739240d with merge base 7754b55:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wconstab wconstab requested review from fmassa and zpcore January 13, 2026 00:01
@fmassa
Member

fmassa commented Jan 13, 2026

I think modelling the compute cost associated with a given redistribution is a good thing to have.
We do it in AutoParallel through the "compute cost" part, and it was important for cases like the additional copies, etc., that are involved in a given redistribution.

I'm just not sure if we should bundle comm and compute cost in the same place.
Additionally, if we do add a compute cost, maybe we should be more thorough about it and model the compute cost more accurately?

Wdyt?

@wconstab
Contributor Author

I'm just not sure if we should bundle comms and compute cost in the same place.
Additionally, if we do add a compute cost maybe we should be more thorough on it and model the compute cost more accurately?

@weifengpy had the same reaction, to not bundle them.

Do you think it makes sense to model them separately, but then offer a bundled 'total cost' API? I think for DTensor purposes, I just want the total cost. Not sure if we would want to separate the costs for some reason in DTensor too?
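
To illustrate what a "separate but bundled" shape could look like, here is a minimal sketch (the RedistributeCost type and field names are hypothetical, not DTensor's actual cost-model API): keep comm and compute costs distinct, and expose a summed total for callers like DTensor that only want one number to minimize.

```python
# Hypothetical sketch: separate fields, bundled total for strategy selection.
from dataclasses import dataclass


@dataclass
class RedistributeCost:
    comm_cost: float = 0.0     # estimated collective time
    compute_cost: float = 0.0  # local kernels, e.g. the Partial 'partition' op

    @property
    def total(self) -> float:
        return self.comm_cost + self.compute_cost


candidates = [
    RedistributeCost(comm_cost=22.82, compute_cost=0.1),  # extra local partition
    RedistributeCost(comm_cost=22.82, compute_cost=0.0),  # stay replicated
]
best = min(candidates, key=lambda c: c.total)  # DTensor would just minimize total
```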

@fmassa
Member

fmassa commented Jan 13, 2026

I think that if we start adding additional one-off costs (like the extra division that happens in this case), then for consistency we should also model additional copies that might happen, see what we do in autoparallel for an example.

And the reason to model it more carefully is that the break-even value you are adding probably won't be enough for all use cases. But then, we will have to model the compute cost taking different GPU architectures into account (as they have different bandwidth).

And if we start discussing modelling those costs more accurately, we should also model the communication costs for different GPUs / interconnects (as the current cost model hard-codes A100 GPUs)

IMO, I think we should improve our cost models across the board, but it might require a bit of discussion about what we will want to model, as it can quickly grow a bit in scope.

@wconstab
Contributor Author

@fmassa I would be happy to model num copies and memory bandwidth instead of hardcoding a constant.

The first question, though: do you want separate comm and compute costs through separate APIs? In AutoParallel as well as in my PR, the costs are just summed.

@weifengpy
Contributor

comm cost >> local compute cost is what I thought. Stragglers are the major bottlenecks, not msg size or CPU overhead.

For an immediately landable version, I was just proposing using local compute cost to break the tie when comm cost is on par.

Modeling local compute & data movement cost sounds totally reasonable, but even with that, I was still thinking comm cost >> local cost, and would prefer using local cost to break the tie.
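
A minimal sketch of that tie-break ordering, assuming a hypothetical pick_strategy helper and treating "on par" as exactly equal comm cost (neither of which reflects an existing DTensor API):

```python
# Hypothetical sketch of the tie-break proposal: comm cost dominates the
# ordering; local compute cost only decides between strategies whose comm
# costs are equal.
def pick_strategy(strategies):
    """strategies: list of (name, comm_cost, compute_cost) tuples.
    Lexicographic min: minimize comm cost first, then use local compute
    cost purely as a tie-breaker."""
    return min(strategies, key=lambda s: (s[1], s[2]))


chosen = pick_strategy([
    ("[P(sum), R, P(sum)]", 22.82, 0.1),  # needs an extra local 'partition' kernel
    ("[R, R, R]",           22.82, 0.0),  # same comm cost, no extra compute
])
assert chosen[0] == "[R, R, R]"  # comm costs tie, lower local compute wins
```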

@wconstab
Contributor Author

@weifengpy I agree that generally comm cost >> compute cost. However, for extremely small message sizes this may not be true. I think accurately modeling the compute part and summing them together gives the redistribute planner the most accurate signal.

Are you suggesting that we should not include the compute cost in the cost value, but explicitly only consider compute cost as a separate 'tie break' step during min cost strategy selection? I feel like this way is actually more confusing / complex and I am not sure it adds value.

@weifengpy
Contributor

weifengpy commented Jan 14, 2026

However for extremely small message sizes this may not be true

Sorry, I should have mentioned my justification. We regressed the perf a lot by adding grad clipping (communicating a scalar value) to a recommendation workload. I realized it's the sync point (or stragglers) that matters more than msg size. Another example: we never achieved the expected perf gain when reducing the msg size by switching from bf16 to fp8 (50% msg size but only <15% perf gain). It's still because of stragglers. That makes me reach the extreme thought that comm cost 0 (no comm) -> comm cost 0.01 (comm a scalar value) are fundamentally different, no matter what the local compute cost is.

@wconstab
Contributor Author

Sorry, I should have mentioned my justification. We regressed the perf a lot by adding grad clipping (communicating a scalar value) to a recommendation workload. I realized it's the sync point (or stragglers) that matters more than msg size. Another example: we never achieved the expected perf gain when reducing the msg size by switching from bf16 to fp8 (50% msg size but only <15% perf gain). It's still because of stragglers. That makes me reach the extreme thought that comm cost 0 (no comm) -> comm cost 0.01 (comm a scalar value) are fundamentally different, no matter what the local compute cost is.

I think this is compatible with my framing.

IIUC your point is that any comm, even a tiny one, can lead to a cost larger than its bandwidth-time computation would suggest. I agree with this.

Overall, the cost model for an operator can include contributions from these sources, depending on which operator it is and how many kernels it calls:

  - fixed latency (model the overhead of doing a kernel launch)
  - bandwidth / compute time (based on num elements and either tflops or memory bw or comms bw)
  - straggler overhead (could be modeled as some % increase on top of the above, although I am not too confident on how to model this effectively)

For DTensor, I still like having all of this summed up into one number - think of it as modeling 'redistribute time' - and minimizing over that for strategy selection.
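
As a rough illustration only, here is a sketch of collapsing those contributions into one summed number; the constants and the function itself are made-up placeholders, not measured values or DTensor internals.

```python
# Hypothetical sketch: one summed "redistribute time" estimate per transition.
def transition_cost(num_bytes, bw_bytes_per_s, num_kernels,
                    launch_overhead_s=5e-6, straggler_factor=1.10):
    """Combine bandwidth time, fixed launch latency, and a straggler margin."""
    bandwidth_time = num_bytes / bw_bytes_per_s         # comm or memory-bound time
    fixed_latency = num_kernels * launch_overhead_s     # per-kernel launch overhead
    return (bandwidth_time + fixed_latency) * straggler_factor
```

Strategy selection would then just take the minimum of these summed values across candidate redistribution paths.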

Contributor

@sanketpurandare sanketpurandare left a comment


Makes sense to me. IIUC,
The goal is to avoid selecting strategies that introduce unnecessary placement conversions, especially ones involving Partial, when an equally-good (or better) Replicate-only strategy exists.
The reason being, even if two strategies are equivalent for the current op’s output, introducing a Partial can impose real downstream costs because a later consumer may have to “reduce (finish)” that Partial to satisfy its own valid strategies. This is exactly the kind of “hidden future tax” a local cost model can miss.

@wconstab
Contributor Author

wconstab commented Mar 3, 2026

Abandoning for now. Don't need the short-term fix because I banned mixed partials, I think. Still want the long-term improvement of more accurate cost models, but that wasn't done in this PR anyway.

@wconstab wconstab closed this Mar 3, 2026
@github-actions github-actions Bot deleted the gh/wconstab/495/head branch April 3, 2026 02:24