[LocalTensor] Cache DeviceMesh.get_coordinate results in LocalTensorMode #173836
wconstab wants to merge 7 commits into gh/wconstab/511/base
Conversation
The get_coordinate method was being called repeatedly with the same DeviceMesh during operations like DTensor.from_local, recomputing the same coordinate mapping each time. This PR adds a per-mode cache keyed by mesh id to avoid the redundant computation. In profiling of sharding rule validation, get_coordinate accounted for ~86% of from_local call time; with caching, from_local latency dropped from 4.55 ms to 0.76 ms (an 83% reduction). Authored with Claude.
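A minimal sketch of the caching idea described above, using stand-in classes rather than the real LocalTensorMode and DeviceMesh (FakeMesh, LocalTensorModeSketch, and _compute_coordinate are hypothetical names for illustration): the coordinate is computed once per mesh and then served from a dict owned by the mode, so the cache's lifetime is tied to the mode itself.

```python
class FakeMesh:
    """Stand-in for DeviceMesh; computing the coordinate is assumed expensive."""

    def __init__(self, coord):
        self._coord = coord
        self.compute_calls = 0  # count how often the expensive path runs

    def _compute_coordinate(self):
        self.compute_calls += 1  # expensive in the real DeviceMesh
        return self._coord


class LocalTensorModeSketch:
    """Sketch of a mode object holding a per-mode coordinate cache."""

    def __init__(self):
        # Cache keyed by mesh identity; discarded when the mode goes away,
        # so stale entries cannot outlive the mode that created them.
        self._coord_cache = {}

    def get_coordinate(self, mesh):
        key = id(mesh)  # per-mode cache keyed by mesh id
        if key not in self._coord_cache:
            self._coord_cache[key] = mesh._compute_coordinate()
        return self._coord_cache[key]


mode = LocalTensorModeSketch()
mesh = FakeMesh((1, 2))
for _ in range(100):
    mode.get_coordinate(mesh)
print(mesh.compute_calls)  # prints 1: computed once, then cached
```

Keying by the mesh object's identity (rather than its contents) keeps lookups cheap and avoids requiring DeviceMesh to be hashable by value; the trade-off is that the cache must not outlive the meshes it references, which holding it on the mode addresses.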
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173836
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 1ad1fb9 with merge base 4b0f7fb.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge |
Merge failed. Reason: This PR needs a label before it can be merged. To add a label, you can comment to pytorchbot. Raised by workflow job.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ode (pytorch#173836) Pull Request resolved: pytorch#173836 Approved by: https://github.com/dzmitry-huba