[mxfp8 moe training][docs] add tutorial for training with MXFP8 expert parallel#3752
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3752
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 49737db with merge base 30fcb15 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
806eb90 to
a54d592
Compare
302d0b4 to
4270ba0
Compare
a54d592 to
8282223
Compare
8282223 to
3f8c54a
Compare
3f8c54a to
761c09d
Compare
761c09d to
cbe10e7
Compare
021de8c to
9d79704
Compare
9d79704 to
67b3b96
Compare
67b3b96 to
0c0e066
Compare
|
will need a rebase on #3769 |
c449c6d to
cc3891f
Compare
0c0e066 to
4c607e4
Compare
| ^^^^^^^^^^^^^ | ||
|
|
||
| 1. (Recommended) Create a new virtual environment with conda or venv. | ||
| 2. `Install torchao <https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation>`__ nightly build (required for CUDA 12.8+ support). |
There was a problem hiding this comment.
Do you need torch too? Which versions of all these dependencies? For example, 2.10 or later?
| | 131072 | 2048 | 1408 | 8 | 3.278 | 2.913 | 1.13x | 4.934 | 3.881 | 1.27x | 1.21x | | ||
| +----------+-------+--------------+---------------+---------------+----------------+---------------+---------------+----------------+---------------+-----------------+ | ||
|
|
||
| As shown, using MXFP8 for all-to-all communications achieves **1.14-1.25x total speedup** versus only quantizing directly before the grouped GEMMs. |
There was a problem hiding this comment.
Can you add a conclusion? Like what did the user learn in the tutorial and where else to find more info?
There was a problem hiding this comment.
good point, done!
| # 🔥 happens in BF16 as well. | ||
| # 🔥 In the backward pass, the incoming upstream gradients will be the MXTensor outputs of the | ||
| # 🔥 MXFP8 all-to-all combine backward pass, so this unpermute autograd func accepts MXTensor | ||
| # 🔥 inputs and performs the reordering in MXFP8. |
There was a problem hiding this comment.
ok haha i was trying to draw the reader's eye to the most important parts... will de-emojify
|
How is the code tested? |
@ezyang we have unit tests for this pattern of chaining the autograd functions (with eager and compile), and we also integrated into Torchtitan and did large scale convergence + performance testing on an external cloud partner's B200 cluster (joint blog post on this coming soon!) |
…t parallel stack-info: PR: #3752, branch: danielvegamyhre/stack/127
Stacked PRs:
[mxfp8 moe training][docs] add tutorial for training with MXFP8 expert parallel