integration of new mxfp8 casting cuda kernel#2564
Conversation
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
651b912 to
3f88897
Compare
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2564
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cc00ef6 with merge base 95d13d5 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
3f88897 to
20847a7
Compare
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
20847a7 to
b1ed196
Compare
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
b1ed196 to
bb930b6
Compare
|
Didn't read any code yet: I think updating to just accept generically strided inputs and writing out row major (required for mm kernels) is good, is that what you did? |
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
bb930b6 to
b1e237f
Compare
Yep, that's correct |
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
b1e237f to
c9c9da0
Compare
| from tqdm import tqdm | ||
|
|
||
| from torchao.prototype.mx_formats import MXLinearConfig | ||
| from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice |
There was a problem hiding this comment.
sorry, MXFP8Dim1CastKernelChoice? since for dim0 we are always using torch.compile
There was a problem hiding this comment.
It was originally MXFP8Dim1CastKernelChoice but here we discussed naming it MXFP8CastKernelChoice for potentially adding support for dim0 and dim0+dim1 casts as well. I don't have strong opinions either way, I went ahead and changed it back to MXFP8Dim1CastKernelChoice
| CUBLAS = "cublas" | ||
|
|
||
|
|
||
| class MXFP8CastKernelChoice(Enum): |
There was a problem hiding this comment.
add some comments on what the options are?
| a, | ||
| rowwise=False, | ||
| colwise=True, | ||
| scaling_mode="floor", |
There was a problem hiding this comment.
TODO for later to allow choice of scaling modes
There was a problem hiding this comment.
Added todo on the custom op itself, with an explanation why we currently are using a string param
c9c9da0 to
7b1f899
Compare
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
7b1f899 to
cc00ef6
Compare
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it.
Stacked PRs:
integration of new mxfp8 casting cuda kernel
Summary
Integrating kernel added in #2513. Custom op wrapper was recently added in #2543. Remaining code to migrate from my private repo is in this PR:
mxfp8_quantize_cudacustom op.mxfp8_quantize_cudacustom optriton_scale_swizzlekernel to accept both row major and col major inputs (since cuda kernel writes scale in col major, to avoid uncoalesced global accesses)MXFP8Dim1CastKernelChoiceenum and replace all uses of boolean flaguse_fp8_dim1_cast_triton_kernelwith it. (Default to Triton for now)Test plan
pytest test/prototype/mx_formats/test_mx_linear.py -k eager_vs_hppytest test/prototype/mx_formats/test_mx_linear.py -k compileNext steps
./test/prototype/mx_formats/test_mx_dtensor.sh:RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode(). This is a known issue, will follow up on it.