torch.compile cast to mxfp8 should only require one kernel #1769

@vkuzo

What this cast is doing

  • reshape the tensor to shape (-1, block_size), where block_size is usually 32 or 16
  • for each block, compute a single scale, then cast that block to torch.float8_e4m3fn
  • return the cast elements and the per-block scales

This should all happen in a single kernel, but today torch.compile generates two.

How to reproduce (requires latest main branch)

TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only

Output logs: https://gist.github.com/vkuzo/ce205fde5ae6b0fc223892c8a46560d4 - we currently see two kernels
