
[inductor][bucketing] Fx collectives bucketing of multiple dtypes #162470

Closed
IvanKobzarev wants to merge 12 commits into gh/IvanKobzarev/151/base from gh/IvanKobzarev/151/head

Conversation

@IvanKobzarev
Contributor

@IvanKobzarev IvanKobzarev commented Sep 9, 2025

Stack from ghstack (oldest at bottom):

Bucketing of multiple dtypes, to be processed in one bucketed collective.

The first target is bucketing bf16 and f32 together, but other dtypes can already be used as well.

For now, multi-dtype bucketing is supported only in "custom_ops" mode; the non-custom_ops path needs additional work on the Inductor side.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @ezyang
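The mechanism can be sketched in plain Python (an illustration of the idea only, not the PR's actual FX-graph implementation; `bucket`/`unbucket` and the byte-level representation are hypothetical stand-ins):

```python
import struct

def bucket(tensors):
    """tensors: list of (struct_fmt, values). Concatenate everything into one
    flat byte buffer (one collective) and record per-tensor split sizes."""
    flat = b""
    splits = []
    for fmt, vals in tensors:
        chunk = struct.pack(f"{len(vals)}{fmt}", *vals)
        splits.append(len(chunk))
        flat += chunk
    return flat, splits

def unbucket(flat, splits, fmts):
    """Split the flat buffer back into per-dtype value lists."""
    out, off = [], 0
    for n, fmt in zip(splits, fmts):
        chunk = flat[off:off + n]
        count = n // struct.calcsize(fmt)
        out.append(list(struct.unpack(f"{count}{fmt}", chunk)))
        off += n
    return out

# "e" = 16-bit float, "f" = 32-bit float: two dtypes in one bucket.
grads = [("e", [1.0, 2.0]), ("f", [3.0, 4.0, 5.0])]
flat, splits = bucket(grads)
assert unbucket(flat, splits, ["e", "f"]) == [[1.0, 2.0], [3.0, 4.0, 5.0]]
```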

@pytorch-bot

pytorch-bot bot commented Sep 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162470

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 00ecea8 with merge base 9272437:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


[ghstack-poisoned]
IvanKobzarev added a commit that referenced this pull request Sep 9, 2025
ghstack-source-id: 9ad19c7
Pull Request resolved: #162470
Lowering for aten._to_copy fails on fallback with:
```
  File "/data/users/ivankobzarev/h/pytorch/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 990, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 974, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 1695, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 1420, in codegen_and_compile
    graph.run(*example_inputs)
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/graph.py", line 937, in run
    return super().run(*args)
  File "/data/users/ivankobzarev/h/pytorch/torch/fx/interpreter.py", line 174, in run
    self.env[node] = self.run_node(node)
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/graph.py", line 1624, in run_node
    result = super().run_node(n)
  File "/data/users/ivankobzarev/h/pytorch/torch/fx/interpreter.py", line 256, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/graph.py", line 1233, in call_function
    make_fallback(target, layout_constraint=decided_constraint)
  File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/lowering.py", line 2080, in make_fallback
    assert op not in decompositions or override_decomp, (
torch._inductor.exc.InductorError: AssertionError: both a fallback and a decomp for same op: aten._to_copy.default
```


IvanKobzarev added a commit that referenced this pull request Sep 9, 2025
ghstack-source-id: e47b03e
Pull Request resolved: #162470
IvanKobzarev added a commit that referenced this pull request Sep 9, 2025
@IvanKobzarev IvanKobzarev changed the title from "[inductor][bucketing][WIP] Bucket multidtype" to "[inductor][bucketing] Fx collectives bucketing of multiple dtypes" on Sep 9, 2025
Contributor

@eellison eellison left a comment

Now that the other PRs have landed - mind rebasing?

IvanKobzarev added a commit that referenced this pull request Oct 2, 2025
@IvanKobzarev IvanKobzarev requested a review from eellison October 2, 2025 10:20
IvanKobzarev added a commit that referenced this pull request Oct 2, 2025
IvanKobzarev added a commit that referenced this pull request Oct 2, 2025
Contributor

@eellison eellison left a comment
Looks good!

The only blocking question is about multi-dtype reduce scatter.


```
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all_custom_ops_multidtype"])
```
Contributor

nit: can we update the config to `bucket_mode: Union[Literal[...]]` so we know what the options are?

Contributor Author

We can. It will be a fairly long sequence, since there are 2**num_options combinations (custom_ops +/- multidtype, all/fsdp). We should reduce it to 4 for now.
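A sketch of what such a narrowing could look like (the concrete literal values below are assumptions based on the modes this PR mentions, not the final config; `bucket_collectives` is a hypothetical consumer):

```python
from typing import Literal, Optional

# Hypothetical narrowing of the bucket_mode option discussed above.
BucketMode = Optional[
    Literal[
        "all",
        "all_custom_ops",
        "all_multidtype",
        "all_custom_ops_multidtype",
    ]
]

def bucket_collectives(mode: BucketMode = None) -> bool:
    # A type checker now flags unknown mode strings at call sites.
    return mode is not None and "multidtype" in mode

assert bucket_collectives("all_custom_ops_multidtype") is True
assert bucket_collectives("all") is False
```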

Comment on lines +55 to +56

```
if s == OrderedSet([torch.bfloat16, torch.float]):  # type: ignore[attr-defined]
    return torch.bfloat16  # type: ignore[attr-defined]
```
Contributor

Reason for this special case? We could just always choose the lowest-itemsize dtype.

Contributor Author

Yes, we can pick the lowest dtype.
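The reason the lowest-itemsize choice is safe is that reinterpreting bits never loses information: a 4-byte fp32 element can be carried as two 2-byte bucket elements and recovered bit-exactly. A pure-Python illustration using `struct` (a sketch of the idea, not the PR's code):

```python
import struct

# One fp32 value, reinterpreted as two 16-bit units and back.
f32_bytes = struct.pack("f", 3.14159)
halves = struct.unpack("2H", f32_bytes)            # "view" as two 16-bit lanes
repacked = struct.pack("2H", *halves)
roundtrip = struct.unpack("f", repacked)[0]

# Bit-exact: no precision is lost by the reinterpretation itself.
assert roundtrip == struct.unpack("f", f32_bytes)[0]
```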

```
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
mode: Optional[str] = None,
```
Contributor

Same question - add Literal options?

```
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
mode: Optional[str] = None,
```
Contributor

Here too?

```
"""

group_key_fn = (
    _rs_group_key_multidtype if mode and "multidtype" in mode else _rs_group_key
```
Contributor

I don't think we can do multi-dtype for reduce_scatter, since NCCL is actually doing a reduction.

Contributor Author

Yes, agree - we could only do it with casts and numerics changes, so I will remove this option.
We could only reduce in the joint uppermost dtype, but that is not what we want :)
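A small illustration of why mixing dtypes under a reduction is unsound: summing the 16-bit reinterpretation of fp32 bit patterns (a crude stand-in for what per-element reduction in the bucket dtype would do) is not the same as summing the fp32 values and then viewing the result. Pure-Python sketch, ignoring carries between lanes:

```python
import struct

def view_u16(x: float):
    """Reinterpret an fp32 value's bits as two 16-bit lanes."""
    return struct.unpack("2H", struct.pack("f", x))

a, b = 1.5, 2.5
# "Reduce" in the viewed 16-bit dtype, lane by lane (what a mixed-dtype
# bucket would effectively ask the collective to do):
wrong = tuple((u + v) & 0xFFFF for u, v in zip(view_u16(a), view_u16(b)))
# Correct reduction in fp32, then view:
right = view_u16(a + b)
# The bit patterns disagree: reduction does not commute with reinterpretation.
assert wrong != right
```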

```
rs_ins: list[torch.Tensor],
group_size: int,
dtype: torch.dtype,  # type: ignore[name-defined]
numel_mults: list[int],
```
Contributor

nit: not sure we need this if we always assume we will view as the lowest-bitwidth dtype

Contributor Author

Custom ops do not support List[dtype], so we cannot compute it inside the op. I changed it to passing the multipliers as List[int] instead; we need to know how to split the result according to the different dtypes.
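For example, with a bf16 bucket dtype (itemsize 2) an fp32 tensor occupies numel * (4 // 2) bucket elements, so a plain List[int] of multipliers is enough to compute the split sizes of the bucketed output. Hypothetical numbers, plain-Python illustration:

```python
# Bucket dtype is bf16 (2 bytes per element); two input tensors:
# a bf16 tensor of 3 elements and an fp32 tensor of 5 elements.
bucket_itemsize = 2
dtype_itemsizes = [2, 4]
numels = [3, 5]

# Multiplier = how many bucket elements one input element occupies.
numel_mults = [sz // bucket_itemsize for sz in dtype_itemsizes]
split_sizes = [n * m for n, m in zip(numels, numel_mults)]

assert numel_mults == [1, 2]
assert split_sizes == [3, 10]   # split points in the flat bucket, in bucket elements
```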

IvanKobzarev added a commit that referenced this pull request Oct 7, 2025
IvanKobzarev added a commit that referenced this pull request Oct 7, 2025
@IvanKobzarev IvanKobzarev requested a review from eellison October 7, 2025 16:59
Contributor

@eellison eellison left a comment

Looks good, just one comment.


```
def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype:  # type: ignore[name-defined]
    assert len(dtypes) > 0
    lowest_dtype = dtypes[0]
```
Contributor

nit: `return min(dtypes, key=operator.attrgetter("itemsize"))`
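The suggested one-liner, sketched in a form that runs without PyTorch (`SimpleNamespace` stands in for `torch.dtype`, which also exposes an `itemsize` attribute):

```python
import operator
from types import SimpleNamespace

# Stand-ins for torch.bfloat16 and torch.float32.
bf16 = SimpleNamespace(name="bfloat16", itemsize=2)
f32 = SimpleNamespace(name="float32", itemsize=4)

def pick_bucket_dtype(dtypes):
    assert len(dtypes) > 0
    # Choose the dtype with the smallest per-element size as the bucket dtype.
    return min(dtypes, key=operator.attrgetter("itemsize"))

assert pick_bucket_dtype([f32, bf16]) is bf16
```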

```
_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)


def _dtype_size_bytes(dtype: torch.dtype) -> int:  # type: ignore[name-defined]
```
Contributor

This is `dtype.itemsize`.

```
    rank: int,
) -> list[torch.Tensor]:
    ag_ins = [
        torch._prims.convert_element_type(_ag_in, out_dtype)
```
Contributor

We should only view the dtype here... Could we make this do collective merging of different dtypes only when it does not increase the total bytes transmitted? Potentially in the future we would want to upcast if the latency cost exceeds the cost of the additional bytes - leave for a future change?

Contributor Author

@IvanKobzarev IvanKobzarev Oct 16, 2025

Here the conversion to out_dtypes happens only for the fused dtype-conversion all-gathers.

Contributor

I'm not sure I follow. Why do we do this? We shouldn't need extra casts.

IvanKobzarev added a commit that referenced this pull request Oct 16, 2025
IvanKobzarev added a commit that referenced this pull request Oct 16, 2025
IvanKobzarev added a commit that referenced this pull request Oct 16, 2025
@IvanKobzarev
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 16, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…torch#162470)


Pull Request resolved: pytorch#162470
Approved by: https://github.com/eellison
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
…torch#162470)


Pull Request resolved: pytorch#162470
Approved by: https://github.com/eellison
@github-actions github-actions bot deleted the gh/IvanKobzarev/151/head branch November 17, 2025 02:17

Labels

ciflow/inductor · ciflow/trunk · Merged · module: inductor · oncall: distributed · topic: not user facing

Projects

None yet

Development


3 participants