mxfp8 training: add TP sharding strategy for dim1 kernel#2436
Conversation
|
Stack from ghstack (oldest at bottom): |
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2436
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Cancelled JobAs of commit 68b683d with merge base c57226b ( NEW FAILURE - The following job has failed:
CANCELLED JOB - The following job was cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| ) | ||
| # TODO(future PR): enable compile here, currently seeing | ||
| # https://www.internalfb.com/phabricator/paste/view/P1851219639 | ||
| # _test_lowp_mlp_tensor_parallelism_base( |
There was a problem hiding this comment.
need to uncomment this and run ./test/prototype/mx_formats/test_mx_dtensor.sh to reproduce
|
@vkuzo This PR seems to have introduced timeouts in torch ao CI on trunk for both CUDA and ROCm nightly runs: https://hud.pytorch.org/hud/pytorch/ao/e675ffd9745e745056cd27a5f64cacad0aebd051/1?per_page=50&name_filter=regression&mergeEphemeralLF=true |
Just fyi we are increasing timeout thresholds in #2549 and #2548 |
|
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
Summary:
Enables mxfp8 training with the dim1 triton kernel and TP, in eager mode. In detail:
Note that compile does not work yet, seeing https://www.internalfb.com/phabricator/paste/view/P1851219639
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: