Skip to content

[mxpf8] Make mxfp8 dim1 cast kernel configurable#1401

Closed
danielvegamyhre wants to merge 2 commits into
mainfrom
mxcuda
Closed

[mxpf8] Make mxfp8 dim1 cast kernel configurable#1401
danielvegamyhre wants to merge 2 commits into
mainfrom
mxcuda

Conversation

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Jul 15, 2025

Copy link
Copy Markdown
Contributor

Summary

  • We recently developed a CUDA kernel in torchao to perform mxfp8 casting with scaling along dim1, which is ~1.4x faster than the previous Triton implementation, this results in e2e training speedup of 1.5% - 2.5% with torchtitan Llama3 8b with FSDP=4/8: Add CUDA kernel for MXFP8 dim1 casting ao#2513
  • The integration into torchao is finished (integration of new mxfp8 casting cuda kernel ao#2564), so we need to update torchtitan to make the kernel choice for mxfp8 dim1 cast configurable to "triton", "cuda", or "torch".

Test plan

  • Triton: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"
  • Cuda: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"

@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

cc @tianyu-l @vkuzo

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants