
[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode #161405

Closed
IvanKobzarev wants to merge 16 commits into gh/IvanKobzarev/140/base from gh/IvanKobzarev/140/head

Conversation

Contributor

@IvanKobzarev IvanKobzarev commented Aug 25, 2025

Stack from ghstack (oldest at bottom):

During comms reordering, the sink-wait iterative pass observed that previous runtime estimations were significantly off for collectives and matmuls.

Adding optional usage of:

  • c10d.time_estimator for collectives, which is based on the NCCL estimator
  • benchmark mode, for matmuls only, as they are highly dependent on the mm backend
  • the logic is mostly copied from Ruisi's PRs for inductor simple_fsdp (#157572)

These estimation corrections are applied in the default `BaseSchedulerNode.estimate_runtime()`; a short configuration sketch follows below.
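For illustration, a minimal sketch of how these optional paths might be switched on (the two flag names are taken from the config discussion further down; treat their exact placement and defaults as assumptions):

```python
import torch._inductor.config as inductor_config

# Use the NCCL-library-based estimator (c10d.time_estimator) for collectives.
inductor_config.runtime_estimations_use_nccl_lib_estimations = True

# Benchmark matmuls instead of relying on analytic estimates, since their
# runtime depends heavily on the chosen mm backend.
inductor_config.runtime_estimations_mms_benchmark = True
```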

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

Differential Revision: D81152294

@pytorch-bot

pytorch-bot Bot commented Aug 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161405

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3e76e04 with merge base 5b90e85:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

IvanKobzarev added a commit to IvanKobzarev/pytorch that referenced this pull request Aug 26, 2025
@IvanKobzarev IvanKobzarev mentioned this pull request Aug 27, 2025
@IvanKobzarev
Contributor Author

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Aug 27, 2025
@IvanKobzarev
Contributor Author

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorch-bot added the oncall: distributed label (Add this issue/PR to distributed oncall triage queue) on Aug 29, 2025
@IvanKobzarev
Contributor Author

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Comment thread torch/_inductor/comm_analysis.py Outdated
}
elif name == "torch.ops._c10d_functional.all_gather_into_tensor_out.default":
    # TODO: use real all_gather_into_tensor_out
    fn = torch.ops._c10d_functional.all_gather_into_tensor
Contributor


Curious why we use all_gather_into_tensor here for all_gather_into_tensor_out?

Contributor Author


I reused the argument parsing from all_gather_into_tensor, as the collective work should be the same.

But in the _out variant we do not have the memory allocation.

Comment thread torch/_inductor/comm_analysis.py Outdated
fn = torch.ops._c10d_functional.all_to_all_single
# Artificial uniform split assumption,
# which may not hold in the case of uneven sharding.
split_sizes = [in_t.size(0) // pg_size] * pg_size
Contributor


If in_t.size(0) is 5 and pg_size is 2, then split_sizes is [2, 2]? If so, it only covers 4 elements, not 5.
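For reference, a hedged sketch of a remainder-aware split; the helper name is hypothetical, and the PR keeps the uniform approximation since it is only used for estimation:

```python
def even_split_sizes(dim0: int, pg_size: int) -> list[int]:
    # Spread the remainder over the first (dim0 % pg_size) ranks so the
    # sizes sum exactly to dim0, e.g. dim0=5, pg_size=2 -> [3, 2].
    base, rem = divmod(dim0, pg_size)
    return [base + (1 if i < rem else 0) for i in range(pg_size)]

assert sum(even_split_sizes(5, 2)) == 5  # the uniform split [2, 2] covers only 4
```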

Comment thread torch/_inductor/comm_analysis.py Outdated

tensor_size_mult = 1.0
if coll == NCCL_COLL.ALL_TO_ALL:
    tensor_size_mult = 2.0 / group_size
Contributor


Why is 2.0 divided by group_size?

Contributor Author


Yeah, I just added a hacky multiplier to get closer to the values observed during benchmarking.
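To make the multiplier concrete, a hedged sketch of the heuristic as described (the function name is hypothetical; the 2.0 factor is an empirical correction fitted against benchmarks, not a value derived from a bandwidth model):

```python
def effective_alltoall_size(tensor_size: float, group_size: int) -> float:
    # e.g. with group_size=8, a 64 MiB tensor is modeled as 64 * 2 / 8 = 16 MiB.
    return tensor_size * 2.0 / group_size
```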

Contributor


The rest of this code is a direct port of the nccl logic

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

Does it make sense to keep that property?

Contributor


Also, I don't see AllToAll there, for whatever reason: https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc

Contributor

@eellison eellison left a comment


A few comments.

Comment thread torch/_inductor/comm_analysis.py Outdated
Comment on lines +191 to +206
if name == "torch.ops._c10d_functional.all_gather_into_tensor.default":
    fn = torch.ops._c10d_functional.all_gather_into_tensor
    return fn, {
        "input": in_t,
        "group_size": pg_size,
        "group_name": pg_name,
    }
elif name == "torch.ops._c10d_functional.all_gather_into_tensor_out.default":
    # TODO: use real all_gather_into_tensor_out
    fn = torch.ops._c10d_functional.all_gather_into_tensor
    return fn, {
        "input": in_t,
        "group_size": pg_size,
        "group_name": pg_name,
    }
elif name == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
Contributor


I think the way you generically generate input nodes below would be better here as well.

Comment thread torch/_inductor/scheduler.py Outdated
Comment on lines +943 to +948
args = snode.node.inputs
args = snode.node.fill_non_provided_args(
    [*args, *snode.node.constant_args], snode.node.kwargs
)
kwargs = snode.node.kwargs
flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
Contributor


These are the same utilities we could have used above for getting the nccl inputs.
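For context, a hedged standalone sketch of those pytree utilities (the snode wiring above is the PR's actual call site):

```python
import torch.utils._pytree as pytree

args = (1, [2, 3])
kwargs = {"alpha": 4}
# tree_flatten turns the nested (args, kwargs) into a flat list of leaves
# plus a spec describing the structure...
flat_args, spec = pytree.tree_flatten((args, kwargs))  # leaves: [1, 2, 3, 4]
# ...and tree_unflatten rebuilds the original structure from that spec.
restored_args, restored_kwargs = pytree.tree_unflatten(flat_args, spec)
assert restored_args == args and restored_kwargs == kwargs
```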

Comment thread torch/_inductor/scheduler.py Outdated
Comment on lines +982 to +998
num_iters = 3
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
cpu_start = time.time()
start_event.record(torch.cuda.current_stream())
for _ in range(num_iters):
    fn(*args, **kwargs)
end_event.record(torch.cuda.current_stream())
cpu_end = time.time()
torch.cuda.synchronize()
cpu_time = cpu_end - cpu_start
total_op_time = start_event.elapsed_time(end_event) - cpu_time
mean_op_time_ms = total_op_time / num_iters
del flat_args
mean_op_time_ns = mean_op_time_ms * 1e6
cache.put(cache_key, mean_op_time_ns)
return mean_op_time_ns
Contributor


Could we reuse benchmark_gpu here?
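For comparison, a minimal sketch of what that reuse could look like, assuming inductor's benchmark_gpu accepts a zero-argument callable and returns a time in milliseconds (the exact import path is an assumption):

```python
from torch._inductor.runtime.benchmarking import benchmarker

def benchmark_snode_ms(fn, args, kwargs) -> float:
    # benchmark_gpu handles warmup, CUDA-event timing and averaging itself,
    # replacing the hand-rolled event / num_iters loop above.
    return benchmarker.benchmark_gpu(lambda: fn(*args, **kwargs))
```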

Comment thread torch/_inductor/scheduler.py Outdated
return value


def estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
Contributor


Have you tried benchmark_fused_nodes? It should already accomplish what this function does.

Comment thread torch/_inductor/scheduler.py Outdated
super().__init__()
V.graph.scheduler = self
self.backends: dict[torch.device, BaseScheduling] = {}
self.estimate_runtime_cache = EstimateRuntimeCache()
Contributor


Can we make this a machine-local cache? See:

@functools.cache
def get_pad_cache() -> torch._inductor.codecache.LocalCache:
    return torch._inductor.codecache.LocalCache()

def get_cached_should_pad(key: str) -> bool:
    return get_pad_cache().lookup(key)  # type: ignore[return-value]

def set_cached_should_pad(key: str, value: bool) -> None:
    return get_pad_cache().set_value(key, value=value)

def get_cached_base_mm_benchmark_time(key: str) -> float:
    return get_pad_cache().lookup(key)  # type: ignore[return-value]

def set_cached_base_mm_benchmark_time(key: str, value: float) -> None:
    return get_pad_cache().set_value(key, value=value)
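Adapted to this PR, a hedged sketch of the same pattern for runtime estimates (the helper names are hypothetical):

```python
import functools
import torch

@functools.cache
def get_runtime_estimate_cache() -> torch._inductor.codecache.LocalCache:
    # Machine-local and persisted across compilations, mirroring get_pad_cache().
    return torch._inductor.codecache.LocalCache()

def get_cached_runtime_estimate(key: str) -> float:
    return get_runtime_estimate_cache().lookup(key)  # type: ignore[return-value]

def set_cached_runtime_estimate(key: str, value: float) -> None:
    get_runtime_estimate_cache().set_value(key, value=value)
```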

Contributor

@ruisizhang123 ruisizhang123 left a comment


I found there is a config.estimate_op_runtime, which allows users to pass a customized op-estimation function to inductor:

def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns estimated op runtime in nanoseconds (ns)
    """
    if config.estimate_op_runtime == "default":
        runtime = snode.get_estimated_runtime()
    else:
        assert callable(config.estimate_op_runtime)
        runtime = config.estimate_op_runtime(snode)
    return runtime

It might make more sense to have runtime_estimations_use_nccl_lib_estimations and runtime_estimations_mms_benchmark as callable functions passed to estimate_op_runtime?
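A minimal sketch of that suggestion; the wrapper below is hypothetical, and only the config.estimate_op_runtime hook quoted above is part of the existing API:

```python
import torch._inductor.config as inductor_config

def custom_estimate_op_runtime(snode) -> float:
    # Start from the built-in analytic estimate; a real implementation would
    # substitute NCCL-estimator or benchmark numbers for collectives / mms here.
    return snode.get_estimated_runtime()

# estimate_op_runtime accepts either "default" or a callable (see above).
inductor_config.estimate_op_runtime = custom_estimate_op_runtime
```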

IvanKobzarev added a commit that referenced this pull request Sep 4, 2025
@IvanKobzarev
Contributor Author

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

IvanKobzarev added a commit that referenced this pull request Sep 4, 2025
@IvanKobzarev
Contributor Author

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Comment thread torch/_inductor/config.py
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"

runtime_estimations_mms_benchmark: bool = False
Contributor


nit: config menu

@IvanKobzarev
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x fae7122e3f48111925a0cd6905383bb1e4923264 returned non-zero exit code 1

Auto-merging test/distributed/test_inductor_collectives.py
Auto-merging torch/_inductor/config.py
Auto-merging torch/_inductor/scheduler.py
Auto-merging torch/_inductor/utils.py
CONFLICT (content): Merge conflict in torch/_inductor/utils.py
error: could not apply fae7122e3f4... [inductor] Runtime estimations: use nccl estimator; mm only benchmark mode
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

IvanKobzarev added a commit that referenced this pull request Sep 8, 2025
@IvanKobzarev
Contributor Author

@IvanKobzarev has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

IvanKobzarev added a commit that referenced this pull request Sep 8, 2025
@IvanKobzarev
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (pytorch#161405)

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: pytorch#161405
Approved by: https://github.com/eellison
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
@github-actions github-actions Bot deleted the gh/IvanKobzarev/140/head branch October 9, 2025 02:10
Khanaksahu pushed a commit to Khanaksahu/pytorch-fork that referenced this pull request Nov 17, 2025

Labels

ciflow/inductor, ciflow/trunk, Merged, module: inductor, oncall: distributed, topic: not user facing
