[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d#2998
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2998
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 10 PendingAs of commit d29104e with merge base afe5cab ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…aping to 2d stack-info: PR: #2998, branch: danielvegamyhre/stack/67
8299b35 to
2b607cb
Compare
…aping to 2d stack-info: PR: #2998, branch: danielvegamyhre/stack/67
2b607cb to
3685390
Compare
…aping to 2d stack-info: PR: #2998, branch: danielvegamyhre/stack/67
3685390 to
5d874ed
Compare
d9e910b to
8ec20df
Compare
| def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs): | ||
| fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn | ||
|
|
||
| def inference_fn(*args, **kwargs): |
There was a problem hiding this comment.
why do we need this? forward in training is not run with torch.inference_mode()
There was a problem hiding this comment.
I know, I just wanted to measure just the quantizations ops + grouped gemm needed to produce the forward output, without pre-computing and saving things for backward, so that I can more easily see if the quant ops for fwd or bwd are what is slow without generating a trace to look at. IIRC if I didn't use inference mode, the forward graph still precomputed stuff.
8ec20df to
abb98c1
Compare
abb98c1 to
870b9aa
Compare
|
Crawling up the stack but do we need this if we are adding a kernel higher up? |
Yes, for now, because it is still currently the fastest method when E < 8. New CUDA Kernel for 3d tensors is only faster for E>8 (see relevant PR). I want to change this though and just use one kernel everywhere, would appreciate your thoughts on the kernel design. I checked NCU and it flags some minor (~15%) potential speedups from resolving bank conflicts and uncoalesced global accesses but am not confident that is really the issue, since those are present for the 2d kernel as well. |
870b9aa to
30dd4b2
Compare
30dd4b2 to
d229482
Compare
…aping to 2d stack-info: PR: #2998, branch: danielvegamyhre/stack/67
d229482 to
c36adb9
Compare
…aping to 2d stack-info: PR: #2998, branch: danielvegamyhre/stack/67
c36adb9 to
d29104e
Compare
Stacked PRs:
[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d
to_mx(B_t.contiguous())is unspeakably slow (see perf analysis in previous PR in stack)to_mx(B_t.contiguous()).Next steps
to_mxmethod and the 2d dim1 cast CUDA kernel method. This is likely due to the .contiguous() call required for both methods.Test plan
pytest test/prototype/moe_training/test_training.pyBenchmarks
Before:
After: