[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode #161405
IvanKobzarev wants to merge 16 commits into gh/IvanKobzarev/140/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161405
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3e76e04 with merge base 5b90e85.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
During comms reordering, the sink-wait-iterative pass observed that the previous runtime estimations were pretty far off for collectives and mms.

Adding optional usage of:
- c10d.time_estimator for collectives, which is based on the NCCL estimator
- Benchmark mode only for matmuls, as they are highly dependent on the mm backend

The logic is mostly copied from Ruisi's PRs for inductor simple_fsdp #157572. These estimation corrections are in the default `BaseSchedulerNode.estimate_runtime()`.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
    }
elif name == "torch.ops._c10d_functional.all_gather_into_tensor_out.default":
    # TODO: use real all_gather_into_tensor_out
    fn = torch.ops._c10d_functional.all_gather_into_tensor
```
Curious why we use all_gather_into_tensor here for all_gather_into_tensor_out?
I reused the argument parsing from all_gather_into_tensor, as the collective work should be the same. But in the _out variant we do not have the output memory allocation.
```python
fn = torch.ops._c10d_functional.all_to_all_single
# Artificial uniform-split assumption,
# which may not hold in case of uneven sharding.
split_sizes = [in_t.size(0) // pg_size] * pg_size
```
If in_t.size(0) is 5 and pg_size is 2, then split_sizes is [2, 2]? If so, it only covers 4 elements, not 5.
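The reviewer's concern can be checked with a few lines of pure Python (the helper name and sizes here are illustrative, not from the PR): floor division drops the remainder, so the uniform-split assumption undercounts elements under uneven sharding.

```python
def uniform_split_sizes(total: int, pg_size: int) -> list[int]:
    # Artificial uniform-split assumption from the snippet above:
    # every rank is assumed to receive total // pg_size elements.
    return [total // pg_size] * pg_size

sizes = uniform_split_sizes(5, 2)
covered = sum(sizes)
# sizes == [2, 2]; only 4 of the 5 elements are covered, 1 is dropped.
```

Since this is only used to size the estimator's inputs, the error is bounded by pg_size - 1 elements, which is small for large tensors.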
```python
tensor_size_mult = 1.0
if coll == NCCL_COLL.ALL_TO_ALL:
    tensor_size_mult = 2.0 / group_size
```
Why is 2.0 divided by group_size?
Yeah, I just added a hacky multiplier to get closer to the values observed during benchmarking.
The rest of this code is a direct port of the nccl logic, which states: "The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. We aim to estimate the runtime as accurately as possible." Does it make sense to keep that property?
Also, I don't see AllToAll there, for whatever reason: https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
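For intuition on what the multiplier does numerically (my reading of the diff, not an explanation from the PR author, who calls it a hacky empirical fit): the tensor size fed to the bandwidth model is scaled by 2.0 / group_size, so the modeled all_to_all traffic shrinks as the group grows. The helper name below is hypothetical.

```python
def effective_all_to_all_bytes(tensor_bytes: float, group_size: int) -> float:
    # Hacky multiplier from the diff above: scale the tensor size
    # by 2.0 / group_size before feeding it to the bandwidth model.
    tensor_size_mult = 2.0 / group_size
    return tensor_bytes * tensor_size_mult

# e.g. a 64 MiB tensor on an 8-rank group is modeled as 16 MiB of traffic.
```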
```python
if name == "torch.ops._c10d_functional.all_gather_into_tensor.default":
    fn = torch.ops._c10d_functional.all_gather_into_tensor
    return fn, {
        "input": in_t,
        "group_size": pg_size,
        "group_name": pg_name,
    }
elif name == "torch.ops._c10d_functional.all_gather_into_tensor_out.default":
    # TODO: use real all_gather_into_tensor_out
    fn = torch.ops._c10d_functional.all_gather_into_tensor
    return fn, {
        "input": in_t,
        "group_size": pg_size,
        "group_name": pg_name,
    }
elif name == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
```
I think the generic way you generate input nodes below would be better here as well.
```python
args = snode.node.inputs
args = snode.node.fill_non_provided_args(
    [*args, *snode.node.constant_args], snode.node.kwargs
)
kwargs = snode.node.kwargs
flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
```
These are the same utilities we could have used above for getting the nccl inputs.
```python
num_iters = 3
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
cpu_start = time.time()
start_event.record(torch.cuda.current_stream())
for _ in range(num_iters):
    fn(*args, **kwargs)
end_event.record(torch.cuda.current_stream())
cpu_end = time.time()
torch.cuda.synchronize()
cpu_time = cpu_end - cpu_start
total_op_time = start_event.elapsed_time(end_event) - cpu_time
mean_op_time_ms = total_op_time / num_iters
del flat_args
mean_op_time_ns = mean_op_time_ms * 1e6
cache.put(cache_key, mean_op_time_ns)
return mean_op_time_ns
```
Could we reuse benchmark_gpu here?
```python
    return value
```

```python
def estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
```
Have you tried benchmark_fused_nodes? It should already accomplish what this function does.
```python
super().__init__()
V.graph.scheduler = self
self.backends: dict[torch.device, BaseScheduling] = {}
self.estimate_runtime_cache = EstimateRuntimeCache()
```
Can we make this a machine-local cache? See: pytorch/torch/_inductor/fx_passes/pad_mm.py, lines 250 to 268 in 1f820de.
ruisizhang123
left a comment
I found there is a config.estimate_op_runtime, which allows users to pass a customized op-estimation function to inductor: pytorch/torch/_inductor/comms.py, lines 1227 to 1236 in b7e207c.

It might make more sense to expose runtime_estimations_use_nccl_lib_estimations and runtime_estimations_mms_benchmark as a callable function passed through estimate_op_runtime?
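For illustration of the reviewer's suggestion: a custom estimator in the shape config.estimate_op_runtime accepts is just a callable from a scheduler node to an estimated runtime. The node stub, constants, and dispatch below are hypothetical stand-ins, not PyTorch APIs; only `estimate_op_runtime` itself comes from the inductor config.

```python
from dataclasses import dataclass


@dataclass
class FakeNode:
    """Stand-in for a scheduler node; a real estimator would receive
    an inductor node and inspect its op and tensor sizes."""
    op_name: str
    num_bytes: int


def my_estimate_op_runtime(node: FakeNode) -> float:
    # Route collectives to one model and leave everything else on a
    # flat fallback, mirroring the PR's split between the NCCL-based
    # estimator for collectives and benchmarking for matmuls.
    if "all_gather" in node.op_name or "reduce_scatter" in node.op_name:
        bandwidth_bytes_per_ns = 50.0  # assumed effective bus bandwidth
        return node.num_bytes / bandwidth_bytes_per_ns
    return 1000.0  # flat fallback estimate in ns (made-up constant)


# In inductor this would be registered roughly as:
#   torch._inductor.config.estimate_op_runtime = my_estimate_op_runtime
```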
```python
# for built-in estimation function, pass in "default"; for user-defined
# estimation function, pass in the function handle
estimate_op_runtime = "default"
```

```python
runtime_estimations_mms_benchmark: bool = False
```
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: command failed (details for Dev Infra team). Raised by workflow job.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (pytorch#161405)
Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: pytorch#161405
Approved by: https://github.com/eellison
Stack from ghstack (oldest at bottom):

During comms reordering, the sink-wait-iterative pass observed that the previous runtime estimations were pretty far off for collectives and mms.

Adding optional usage of:
- c10d.time_estimator for collectives, which is based on the NCCL estimator
- Benchmark mode only for matmuls, as they are highly dependent on the mm backend

These estimation corrections are in the default `BaseSchedulerNode.estimate_runtime()`.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

Differential Revision: D81152294