
[DTensor] Fix _to_copy to reduce Partial before non-linear dtype conversions#172696

Open
stmcgovern wants to merge 2 commits into pytorch:main from
stmcgovern:fix/dtensor-partial-to-copy-dtype

Conversation

@stmcgovern
Collaborator

@stmcgovern stmcgovern commented Jan 17, 2026

Fixes #172684
Updated to use single_dim_strategy.
Type conversion to int/bool on Partial(sum) incorrectly preserved the Partial placement, producing wrong results: trunc(a+b) != trunc(a) + trunc(b).

This adds a custom strategy for _to_copy that checks if the dtype conversion is linear for the reduce operation before preserving Partial.
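A quick plain-Python illustration of why the reduce must happen before a truncating cast (values chosen for illustration):

```python
import math

# Two ranks each hold 0.6 in a Partial("sum") local shard.
local_shards = [0.6, 0.6]

# Correct: all-reduce first, then cast.
correct = math.trunc(sum(local_shards))            # trunc(1.2) == 1

# Wrong: preserving Partial through the cast truncates each shard first.
wrong = sum(math.trunc(s) for s in local_shards)   # 0 + 0 == 0

print(correct, wrong)  # 1 0
```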

This PR is offered in support of the Partial correctness stabilization efforts.

@pytorch-bot

pytorch-bot bot commented Jan 17, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172696

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit f03d781 with merge base 007b6a4:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (dtensor) release notes category label Jan 17, 2026
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 20, 2026
@skpark-rh
Collaborator

Shouldn't this have a test?

@stmcgovern stmcgovern force-pushed the fix/dtensor-partial-to-copy-dtype branch from b083109 to a8083cd Compare January 23, 2026 16:55
target_dtype = cast(torch.dtype | None, op_schema.kwargs_schema.get("dtype", None))

strategies = []
for strategy in first_input_strategy.strategies:
Contributor

@wconstab wconstab Jan 23, 2026

not specific to this PR, but in general we're trying to move away from the style of 'find existing input strategies, and modify/filter/tweak them' and instead use 'single mesh dim' strategies. The way a single-dim strategy would work is you just worry about a single mesh dim and you get TensorMetas for your tensor input arguments and then list out all the viable placements. Infra will take care of expanding this combinatorially to the real mesh.

I think the single_dim strategy equivalent of your PR would look something like

@register_single_dim_strategy(
    aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"])
)
def _to_copy_single_dim_strategy(
    op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType
) -> list[list[Placement | _ShardingPlaceholder]]:
    strategies = []
    input = args_schema[0]
    # don't bother adding the all-replicate strategy, infra adds it for you
    for i in range(input.dim()):
        strategies.append([_ShardingPlaceholder(i), _ShardingPlaceholder(i)])
    for partial in (Partial("sum"), Partial("avg"), ...):
        if dtype_allows_partial:  # pseudocode: the cast must commute with the reduce op
            strategies.append([partial, partial])
    return strategies

cc @anshul-si - maybe we should add some helper for nicely iterating the partial types for cases like these?

Contributor

@stmcgovern were you planning to continue on this PR as single-dim, or do you want me to review this further as-is?

Collaborator Author

yes, I'll switch it over to single-dim. Got sidetracked with the squeeze issues.


else:
    # Unknown reduce_op, be conservative
    return True
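The conservative fallback above is the tail of a linearity check: the cast must commute with the reduce op for Partial to be preserved. A minimal sketch of such a predicate (the helper name and exact rules are illustrative, not the PR's actual code):

```python
import torch

def cast_commutes_with_reduce(reduce_op: str, dst: torch.dtype) -> bool:
    """Illustrative check: may a Partial(reduce_op) placement survive a cast to dst?"""
    if reduce_op in ("max", "min"):
        # Truncating casts are monotone, so they commute with max/min;
        # casting to bool (x != 0) is not monotone, so exclude it.
        return dst is not torch.bool
    if reduce_op in ("sum", "avg"):
        # Truncation to int/bool breaks additivity: trunc(a+b) != trunc(a)+trunc(b)
        return dst.is_floating_point or dst.is_complex
    # Unknown reduce_op: be conservative, reduce the Partial first
    return False

print(cast_commutes_with_reduce("sum", torch.int32))    # False: must reduce first
print(cast_commutes_with_reduce("sum", torch.float16))  # True
print(cast_commutes_with_reduce("max", torch.int32))    # True
```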
Contributor

I suppose we should just work out the cases and be complete, other than NormPartial which we plan to deprecate. cc @anshul-si who is helping with 'all the partial stuff'

Collaborator Author

@stmcgovern stmcgovern Mar 6, 2026

this should be addressed now since caller iterates Partial.ALL_REDUCE_OPS

@stmcgovern stmcgovern force-pushed the fix/dtensor-partial-to-copy-dtype branch from a8083cd to 39e73a1 Compare March 6, 2026 00:07
@stmcgovern stmcgovern requested a review from wconstab March 6, 2026 00:25
aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"])
)(propagate_single_input_strategy)
)
def _to_copy_single_dim_strategy(
Contributor

@anshul-si does this conflict with your pointwise rewrite?

Contributor

No it shouldn't

def test_to_copy_partial_reduces_for_nonlinear_cast(self):
    # (reduce_op, target_dtype, expect_partial)
    cases = [
        ("sum", torch.int32, False),  # truncation breaks additivity
Contributor

@pianpwk does this address the thing you found with strategy validation claiming incorrect rules on various ops that decompose to to_copies?

wconstab
wconstab previously approved these changes Mar 7, 2026
@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 7, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@stmcgovern stmcgovern force-pushed the fix/dtensor-partial-to-copy-dtype branch from 39e73a1 to eed4c51 Compare March 7, 2026 18:28
@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor labels Mar 7, 2026
@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 7, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@stmcgovern stmcgovern force-pushed the fix/dtensor-partial-to-copy-dtype branch from eed4c51 to e41f21a Compare March 9, 2026 14:43
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2026
@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 9, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorch-auto-revert

@pytorchbot revert -m "Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable" -c autorevert

This PR is attributed to have caused regression in:

Please investigate and fix the issues.
@claude Can you please read this revert comment, follow the links and read the errors, to then give a brief diagnostics on the cause of the error? If you judge the error to be legitimate reason for a revert, please provide brief guidance on how the author could fix it.

@claude

claude bot commented Mar 10, 2026

Claude finished @pytorch-auto-revert[bot]'s task in 3m 57s


Revert Diagnostics for PR #172696

  • Read CI failure logs for test_cached_state_dict
  • Read the PR's changed files to understand the modifications
  • Diagnose root cause and provide guidance

Failure Summary

The test test_cached_state_dict fails consistently with:

RuntimeError: start (0) + length (2) exceeds dimension size (1).

at _fsdp_param.py:806 in reset_sharded_param(), triggered by model.cpu() → FSDPModule._apply() → reset_sharded_param().

Root Cause Analysis

This is a legitimate revert. The PR changed aten._to_copy.default from using propagate_single_input_strategy (a simple 1:1 pass-through of all input placements) to a new register_single_dim_strategy handler. While the new strategy correctly handles the dtype-conversion correctness issue (the goal of the PR), it introduces a subtle incompatibility with FSDP's padded sharding.

Here's the chain of events:

  1. FSDP pads sharded parameters for uneven sharding. For nn.Linear(3, 3, bias=False) with world_size=4 and Shard(0), the weight [3, 3] is chunked into sizes [1,3], [1,3], [1,3], [0,3], and the last rank's shard is padded to [1, 3]. FSDP records padded_sharded_param_size = [1, 3].

  2. model.cpu() calls Module._apply(lambda t: t.cpu()), which calls .cpu() on the DTensor parameter. This dispatches through aten._to_copy.default.

  3. With the old propagate_single_input_strategy: The strategy is a direct 1:1 copy — input Shard(0) → output Shard(0), preserving the local tensor identity. The local tensor (still padded) gets moved to CPU with no redistribution. Its shape stays [1, 3], matching padded_sharded_param_size.

  4. With the new _to_copy_single_dim_strategy: The single-dim strategy infrastructure computes strategies based on TensorMeta (which reflects the global tensor shape [3, 3]). When the strategy expansion evaluates options and redistribution costs, it may find that the optimal strategy involves a redistribution (all-gather + re-shard). After this redistribution, the new local tensor may have a different size than FSDP's pre-computed padded_sharded_param_size, because the DTensor re-sharding logic doesn't know about FSDP's custom padding convention.

  5. reset_sharded_param() then tries to narrow() the (now differently-sized) local tensor to fit padded_sharded_param_size, and fails with the dimension mismatch.

How to Fix

The author has a few options:

  1. Keep _to_copy as propagate_single_input_strategy for device-only transfers, and only apply the single-dim strategy when a dtype change is actually requested. This could be done by checking kwargs_schema.get("dtype") at registration time, or by registering the single-dim strategy conditionally, or by falling back to propagate_single_input_strategy when there's no dtype change.

  2. Ensure the single-dim strategy produces a zero-cost match for the existing input placement. The issue is that the strategy expansion may assign redistribution costs that cause the optimizer to pick a different strategy than simply preserving the input's Shard(0). The strategy function itself produces [_ShardingPlaceholder(0), _ShardingPlaceholder(0)] which should match, but the cost computation in the single-dim infrastructure may differ from the old approach.

  3. Register _to_copy with both approaches: keep propagate_single_input_strategy for the general case (which FSDP relies on for device transfers), and add a separate override or wrapper that applies the dtype-aware filtering only when dtype is actually being changed.

Option 1 is likely the simplest and safest fix — the PR's correctness concern only applies when there's a non-linear dtype conversion, so device-only _to_copy calls shouldn't need the new strategy at all.

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
…ype conversions (#172696)"

This reverts commit 46cd90c.

Reverted #172696 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](#172696 (comment)))
@pytorchmergebot
Collaborator

@stmcgovern your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Mar 10, 2026
@pytorch-bot pytorch-bot bot dismissed wconstab’s stale review March 10, 2026 01:08

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@wconstab
Contributor

@claude can you explain the fsdp issue in more detail? give an example of the sequence of operations in the before (working) and after (broken) case including sizes/metadata at each step? Is FSDP abusing an odd behavior of the old rule or is the new rule overly simple?

cc @weifengpy in case you have any suggestion

@claude

claude bot commented Mar 10, 2026

Claude finished @wconstab's task in 8m 39s


  • Read PR's changed files and FSDP code
  • Read DTensor strategy infrastructure
  • Read the single-dim strategy pipeline
  • Read the failing test
  • Trace through concrete example with sizes
  • Post detailed analysis

Detailed Analysis: Before (Working) vs After (Broken)

Concrete scenario: nn.Linear(3, 3, bias=False) with fully_shard(), world_size=4, Shard(0) on the FSDP mesh.


Setup: How FSDP Creates the Sharded Parameter

Global weight: [3, 3]

FSDP shards on dim 0 with _chunk_with_empty (_fsdp_param.py:342):

| Rank | Raw chunk      | Padded to chunk[0].size()  |
|------|----------------|----------------------------|
| 0    | [1, 3]         | [1, 3] (already full)      |
| 1    | [1, 3]         | [1, 3]                     |
| 2    | [1, 3]         | [1, 3]                     |
| 3    | [0, 3] (empty) | [1, 3] (padded with zeros) |
  • padded_sharded_param_size = [1, 3] for all ranks
  • sharded_size = [1, 3] for ranks 0-2, [0, 3] for rank 3

The local tensor of the DTensor is narrow(dim=0, start=0, length=actual_length) of the padded tensor. So rank 3 has a [0, 3] local tensor but it's backed by [1, 3] storage.
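The chunking behavior above can be reproduced with plain torch (shapes only; no distributed setup needed):

```python
import torch

# FSDP-style uneven sharding: chunk a [3, 3] weight across 4 ranks on dim 0.
weight = torch.arange(9.0).reshape(3, 3)
world_size = 4
chunks = list(torch.chunk(weight, world_size, dim=0))

# torch.chunk only returns 3 chunks here; ranks past the end get empty
# shards, which FSDP then pads back up to chunk[0].size() == [1, 3].
while len(chunks) < world_size:
    chunks.append(weight.new_empty(0, 3))

print([tuple(c.shape) for c in chunks])  # [(1, 3), (1, 3), (1, 3), (0, 3)]
```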

The DTensor's _spec.tensor_meta.shape is the global shape [3, 3] (_fsdp_param.py:312).


Before (Working): propagate_single_input_strategy

When model.cpu() calls param.cpu() on the DTensor, aten._to_copy.default dispatches.

  1. Strategy construction: propagate_single_input_strategy (_tensor_ops.py:54-89) iterates over the input's existing OpStrategy.strategies and creates a 1:1 copy of each. No is_tensor_shardable check. Whatever the input already has is passed through as-is.

  2. Input's OpStrategy: The DTensor parameter has placement (Shard(0),). The strategies list contains an OpSpec with Shard(0) as the placement.

  3. Strategy selection: _select_min_cost_strategy finds the Shard(0) → Shard(0) strategy. The input already has Shard(0), so redistribute_cost = 0 and needs_redistribute = False.

  4. Dispatch: _to_copy runs locally on each rank's local tensor. No redistribution occurs.

  5. Result: The output DTensor has:

    • Same placement: (Shard(0),)
    • Local tensor: the cpu-copied local tensor with the same shape as the input's local tensor
    • Rank 3 still has local shape [0, 3], backed by the narrow of the padded [1, 3] storage
  6. reset_sharded_param(): Checks local_tensor.size() != padded_sharded_param_size. On rank 3: [0, 3] != [1, 3], so it re-pads. narrow(dim=0, start=0, length=0) — length is 0, and padded_sharded_param_size[0] = 1. Padding works because length (0) + start (0) < dimension_size (1). Success.


After (Broken): _to_copy_single_dim_strategy via register_single_dim_strategy

  1. Strategy construction: _to_copy_single_dim_strategy (_tensor_ops.py:126-139) generates:

    • [_ShardingPlaceholder(0), _ShardingPlaceholder(0)] (shard on dim 0)
    • [_ShardingPlaceholder(1), _ShardingPlaceholder(1)] (shard on dim 1)
    • [Partial("sum"), Partial("sum")], [Partial("avg"), Partial("avg")], etc. (no dtype change for .cpu())
    • Plus auto-inserted: [Replicate(), Replicate()]
  2. Expansion via expand_to_full_mesh_op_strategy (utils.py:356-559): Expands these single-dim strategies to the full mesh via Cartesian product. For each combination, it calls is_tensor_shardable:

    is_tensor_shardable(inp.shape, s, allow_unbacked_sharding=True)

    Here inp.shape comes from input_args_strategy[0].tensor_meta.shape — the global shape [3, 3].

  3. The Shard(0) strategy gets filtered out: is_tensor_shardable([3, 3], Shard(0), mesh_size=4) checks shape[0] < num_shards, i.e., 3 < 4 → returns False (utils.py:221). The strategy is skipped via continue at line 528.

    The allow_uneven_sharding escape hatch at line 523-524 (allow_uneven_sharding and inp.strategies[0].output_spec.placements == s.placements) would save this — but _to_copy_single_dim_strategy is registered with allow_uneven_sharding=False (the default).

  4. Only Replicate() survives: The all-Replicate strategy passes is_tensor_shardable (no sharding to check), while Shard(1) is rejected for the same reason as Shard(0) (shape[1] = 3 < mesh_size 4). So the only valid strategy is (Replicate(),).

  5. Strategy selection: _select_min_cost_strategy picks the Replicate → Replicate strategy, because it's the only one. This means the input needs to be redistributed from Shard(0) to Replicate.

  6. Redistribution triggers an all-gather: The sharded tensor is all-gathered across ranks to produce a full [3, 3] tensor on each rank. Then _to_copy runs on the replicated tensor. The output DTensor has placement (Replicate(),) with local shape [3, 3].

  7. swap_tensors replaces the parameter: nn.Module._apply uses swap_tensors to replace the old DTensor parameter with the new one.

  8. reset_sharded_param() (_fsdp_param.py:762-827): Now the new parameter is a Replicate DTensor with local shape [3, 3]. But padded_sharded_param_size is still [1, 3]. The check at line 800: [3, 3] != [1, 3] → enters the padding branch. It tries:

    padded_local_tensor = local_tensor.new_zeros([1, 3])  # new padded buffer
    padded_local_tensor.narrow(dim=0, start=0, length=3).copy_(local_tensor)
    #                                               ^^^ length=3 but dim size=1!

    RuntimeError: start (0) + length (3) exceeds dimension size (1) — the same error as in CI (the exact length differs with the CI run's shapes).

    But actually, for rank 3 it's even worse: length comes from local_tensor.size(shard_dim) which is 3 (the replicated local), but padded_sharded_param_size is [1, 3].
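The failing narrow in step 8 reproduces in isolation (sizes taken from the walkthrough above):

```python
import torch

# FSDP's padded buffer is [1, 3], but after the unwanted redistribute the
# local tensor is the full replicated [3, 3].
local_tensor = torch.zeros(3, 3)
padded = local_tensor.new_zeros(1, 3)   # padded_sharded_param_size

try:
    padded.narrow(0, 0, local_tensor.size(0)).copy_(local_tensor)
except RuntimeError as e:
    print(e)  # start (0) + length (3) exceeds dimension size (1)
```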


Is FSDP Abusing the Old Rule or Is the New Rule Overly Simple?

Both, to some extent:

FSDP has a legitimate expectation: FSDP manages its own padding, and relies on device-transfer ops (.cpu(), .cuda()) preserving the exact same sharding layout. This is a reasonable contract — a device transfer shouldn't change the distributed placement of a tensor. The old propagate_single_input_strategy honored this implicitly by never questioning the input's placements.

The new rule is overly simple in its shardability filtering: The fundamental problem is that expand_to_full_mesh_op_strategy uses is_tensor_shardable(global_shape, spec) to filter strategies, and this rejects Shard(0) because global_dim_size (3) < mesh_size (4). But the tensor is already sharded that way at runtime — FSDP deliberately created an uneven shard. The allow_uneven_sharding=True flag would fix this by adding the escape hatch inp.strategies[0].output_spec.placements == s.placements — which says "if the input already has these placements, allow it even if the shardability check fails."
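A minimal model of the shardability filter described above (assumed semantics, simplified from the described utils.py check; not the real signature):

```python
# Simplified stand-in for DTensor's is_tensor_shardable on one mesh dim:
# a dim is considered shardable only if its global size covers every shard.
def is_tensor_shardable(shape, shard_dim, num_shards):
    return shape[shard_dim] >= num_shards

# The FSDP case from the walkthrough: global [3, 3] on a 4-way mesh.
print(is_tensor_shardable([3, 3], 0, 4))  # False: Shard(0) filtered out
print(is_tensor_shardable([3, 3], 1, 4))  # False: Shard(1) filtered too
print(is_tensor_shardable([8, 3], 0, 4))  # True: evenly shardable
```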


How to Fix

Option 1 (Simplest): Add allow_uneven_sharding=True to the _to_copy registration:

@register_single_dim_strategy(
    aten._to_copy.default,
    schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]),
    allow_unbacked_sharding=True,
    allow_uneven_sharding=True,  # FSDP creates uneven shards
)

This would preserve the Shard(0) strategy when the input already has Shard(0), while still correctly filtering non-linear dtype conversions with Partial.

Option 2: Check whether FSDP should instead use even-only sharding (not viable — FSDP's whole design allows sharding tensors whose dim is smaller than world_size by creating zero-size shards).

Option 3: Make propagate_single_input_strategy a "priority" fast path that register_single_dim_strategy consults first. More invasive, less clear benefit.

Option 1 seems like the correct fix. The _to_copy op has no reason to reject uneven sharding — it's a pointwise operation that doesn't change shapes.


@weifengpy
Contributor

> is_tensor_shardable

Claude's suggestion on option 1 seems reasonable to me. @anshul-si brought up whether is_tensor_shardable is reasonable for all ops and added allow_uneven_sharding

Rewrite _to_copy strategy using register_single_dim_strategy. The old
propagate_single_input_strategy unconditionally passed through all
placements, which is incorrect when the dtype conversion does not
commute with the reduce op (e.g. trunc(0.6+0.6)=1 ≠ trunc(0.6)+trunc(0.6)=0).
FSDP creates uneven shards (dim < world_size) with padding.
Without allow_uneven_sharding, the single_dim_strategy infra
rejects Shard placements via is_tensor_shardable, forcing a
redistribution to Replicate that breaks reset_sharded_param().
@stmcgovern stmcgovern force-pushed the fix/dtensor-partial-to-copy-dtype branch from e41f21a to f03d781 Compare March 10, 2026 14:04
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 10, 2026
@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

This PR needs to be approved by an authorized maintainer before merge.

@stmcgovern
Collaborator Author

@wconstab Would you please re-review/approve? I added the recommended FSDP fix. The failed CI test is an unrelated timeout.

@stmcgovern stmcgovern requested a review from wconstab March 11, 2026 14:44

Labels

ci-no-td (Do not run TD on this PR) · Merged · open source · release notes: distributed (dtensor) · Reverted · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

[DTensor][Partial] Type conversion to int/bool incorrectly preserves Partial(sum)

8 participants