[inductor][bucketing] Fx collectives bucketing of multiple dtypes #162470
IvanKobzarev wants to merge 12 commits into gh/IvanKobzarev/151/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162470
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 00ecea8 with merge base 9272437. This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
Lowering for aten._to_copy fails because both a fallback and a decomposition are registered for the op:
```
File "/data/users/ivankobzarev/h/pytorch/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 990, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 974, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 1695, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/compile_fx.py", line 1420, in codegen_and_compile
graph.run(*example_inputs)
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/graph.py", line 937, in run
return super().run(*args)
File "/data/users/ivankobzarev/h/pytorch/torch/fx/interpreter.py", line 174, in run
self.env[node] = self.run_node(node)
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/graph.py", line 1624, in run_node
result = super().run_node(n)
File "/data/users/ivankobzarev/h/pytorch/torch/fx/interpreter.py", line 256, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/graph.py", line 1233, in call_function
make_fallback(target, layout_constraint=decided_constraint)
File "/data/users/ivankobzarev/h/pytorch/torch/_inductor/lowering.py", line 2080, in make_fallback
assert op not in decompositions or override_decomp, (
torch._inductor.exc.InductorError: AssertionError: both a fallback and a decomp for same op: aten._to_copy.default
```
eellison left a comment:
Now that the other PRs have landed, mind rebasing?
eellison left a comment:
Looks good!
The only blocking question is about multi-dtype reduce_scatter.
```
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all_custom_ops_multidtype"])
```
nit: can we update the config to `bucket_mode: Literal[...]` so we know what the options are?

We can. It will be a fairly long list, as we have 2**num_options combinations (custom_ops +/- multidtype, all/fsdp). We should reduce it to 4 for now.
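A minimal sketch of the `Literal`-typed config the nit asks for, assuming four illustrative mode names (the actual option strings in the inductor config may differ):

```python
from typing import Literal, Optional, get_args

# Hypothetical narrowing of bucket_mode; the real option names in the
# inductor config may differ from these four illustrative values.
BucketMode = Literal[
    "all",
    "all_custom_ops",
    "fsdp",
    "all_custom_ops_multidtype",
]

def validate_mode(mode: Optional[str]) -> bool:
    # None means "bucketing disabled"; otherwise the string must be one
    # of the statically declared Literal options.
    return mode is None or mode in get_args(BucketMode)
```

The `Literal` gives type checkers the closed set of options, while `get_args` lets runtime code validate against the same single source of truth.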
```
if s == OrderedSet([torch.bfloat16, torch.float]):  # type: ignore[attr-defined]
    return torch.bfloat16  # type: ignore[attr-defined]
```
Reason for this special case? We could just always choose the lowest-itemsize dtype.

Yes, we can pick the lowest dtype.
```
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
mode: Optional[str] = None,
```
Same question: add Literal options?
```
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
mode: Optional[str] = None,
"""
...
group_key_fn = (
    _rs_group_key_multidtype if mode and "multidtype" in mode else _rs_group_key
```
I don't think we can do multidtype for reduce_scatter, since NCCL is actually doing a reduction.

Yes, agreed; we could only do it with casts and numerics changes, so I will remove this option.
We could only reduce in the joint uppermost dtype, but that is not what we want :)
```
rs_ins: list[torch.Tensor],
group_size: int,
dtype: torch.dtype,  # type: ignore[name-defined]
numel_mults: list[int],
```
nit: not sure we need this if we always have the assumption that we will view as the lowest-bitwidth dtype.

Custom ops do not support List[dtype], so it cannot be computed inside the op. I changed it to pass the multipliers as List[int] instead; we need them to know how to split the result according to the different dtypes.
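The point above can be sketched in plain Python: with the bucket viewed in the lowest-itemsize dtype, each input contributes `numel * (itemsize // bucket_itemsize)` elements, and those integer multipliers are all the op needs to split the flat result (the helper names here are illustrative, not the PR's actual code):

```python
def numel_mults(itemsizes: list[int], bucket_itemsize: int) -> list[int]:
    # e.g. [fp32, bf16] viewed as bf16 -> multipliers [2, 1].
    return [s // bucket_itemsize for s in itemsizes]

def split_points(numels: list[int], mults: list[int]) -> list[int]:
    # Cumulative element offsets of each input inside the flat
    # bucket-dtype buffer, used to split the collective's output.
    points, acc = [], 0
    for n, m in zip(numels, mults):
        acc += n * m
        points.append(acc)
    return points
```

Passing `list[int]` multipliers sidesteps the custom-op schema limitation while still carrying enough information to recover each input's region of the buffer.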
eellison left a comment:
Looks good, just one comment.
```
def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype:  # type: ignore[name-defined]
    assert len(dtypes) > 0
    lowest_dtype = dtypes[0]
```
nit: `return min(dtypes, key=operator.attrgetter("itemsize"))`
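The suggested one-liner can be sketched without torch; `DType` below is a stand-in for `torch.dtype`, which exposes the same `itemsize` attribute:

```python
import operator
from dataclasses import dataclass

# Stand-in for torch.dtype so the sketch runs without torch; real
# torch.dtype objects also expose an .itemsize attribute.
@dataclass(frozen=True)
class DType:
    name: str
    itemsize: int

bfloat16 = DType("bfloat16", 2)
float32 = DType("float32", 4)

def pick_bucket_dtype(dtypes: list[DType]) -> DType:
    # Choose the lowest-itemsize dtype so viewing every input in the
    # bucket dtype never increases the total bytes transmitted.
    assert len(dtypes) > 0
    return min(dtypes, key=operator.attrgetter("itemsize"))
```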
```
_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)

def _dtype_size_bytes(dtype: torch.dtype) -> int:  # type: ignore[name-defined]
```
```
    rank: int,
) -> list[torch.Tensor]:
    ag_ins = [
        torch._prims.convert_element_type(_ag_in, out_dtype)
```
We should only view the dtype here... Could we make this do collective merging of different dtypes only when it does not increase the total bytes transmitted? Potentially in the future we would want to upcast when the latency cost outweighs the cost of the additional bytes; leave that for a future change?

Here we convert to out_dtypes only for the fused convert-dtype all-gathers.

I'm not sure I follow. Why do we do this? We shouldn't need extra casts.
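The distinction under discussion: a dtype view reinterprets the same bits at no kernel cost, while a convert rewrites the values in a new encoding. A torch-free sketch using `struct` (the function names are illustrative, not the PR's API):

```python
import struct

def view_f32_as_bytes(x: float) -> list[int]:
    # Like Tensor.view(torch.uint8): the same bits reinterpreted,
    # no value change and no cast kernel. One float32 -> 4 bytes.
    return list(struct.pack("f", x))

def convert_f32_to_f16(x: float) -> float:
    # Like _prims.convert_element_type: re-encodes the value, which
    # costs a cast kernel and may lose precision.
    return struct.unpack("e", struct.pack("e", x))[0]
```

Bucketing wants the first operation (free bit reinterpretation into the bucket dtype); the second only belongs in a deliberately fused convert-dtype collective.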
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#162470. Approved by: https://github.com/eellison
Stack from ghstack (oldest at bottom):
Bucketing of multiple dtypes to be processed in one bucketed collective.
The first target is to bucket bf16 and fp32, but it can already be used with other dtypes.
For now, multidtype bucketing is only supported in "custom_ops" mode;
non-custom_ops needs additional work on the inductor side.
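As an illustration of the idea (not the PR's actual code), the round trip can be sketched with `struct`, using float32 and float16 to stand in for the fp32/bf16 pair: each tensor is flattened to raw bytes, the blobs are concatenated so a single collective can operate on one buffer, and the result is split and reinterpreted back per dtype:

```python
import struct

def bucket(payloads: list[tuple[str, list[float]]]) -> tuple[bytes, list[int]]:
    # payloads: (struct format, values); "f" = float32, "e" = float16.
    # Pack each input and concatenate into one flat byte buffer that a
    # single collective could operate on.
    blobs = [struct.pack(f"{len(v)}{fmt}", *v) for fmt, v in payloads]
    return b"".join(blobs), [len(b) for b in blobs]

def unbucket(buf: bytes, sizes: list[int], fmts: list[str]) -> list[list[float]]:
    # Split the flat buffer and reinterpret each chunk in its own dtype.
    out, off = [], 0
    for size, fmt in zip(sizes, fmts):
        n = size // struct.calcsize(fmt)
        out.append(list(struct.unpack(f"{n}{fmt}", buf[off:off + size])))
        off += size
    return out
```

The real implementation works on tensors and dtype views rather than byte strings, but the shape of the transformation (flatten, concatenate, one collective, split, reinterpret) is the same.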
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @ezyang