mx cast to mxfp8 across dim0 and dim1 should be performant #1788

@vkuzo

What this cast is doing

  • reshape the tensor to shape (-1, block_size), where block_size is usually 32 or 16
  • for each block, compute a single scale, then cast that block to torch.float8_e4m3fn
  • do the above along both dim0 and dim1
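
To make the steps above concrete, here is a minimal pure-Python sketch of the per-block scale and cast, applied row-wise and then column-wise. It is an emulation for illustration only, not the torchao implementation: the helper names (block_scale, round_to_e4m3, cast_rows, cast_both_directions) are hypothetical, the power-of-two "floor" scale rule and the saturating clamp at 448 are assumptions, and which direction corresponds to "dim0" versus "dim1" is deliberately left unnamed.

```python
import math

F8E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
F8E4M3_EMAX = 8      # 448 = 1.75 * 2**8
BLOCK_SIZE = 32


def block_scale(block):
    """Assumed rule: power-of-two scale 2**(floor(log2(amax)) - emax)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) == e - 1
    _, e = math.frexp(amax)
    return 2.0 ** (e - 1 - F8E4M3_EMAX)


def round_to_e4m3(v):
    """Emulate rounding to e4m3 (3 mantissa bits), saturating at +-448."""
    if v == 0.0:
        return 0.0
    _, e = math.frexp(abs(v))
    # Spacing between representable values; exponents below -6 are subnormal.
    quantum = 2.0 ** (max(e - 1, -6) - 3)
    q = round(v / quantum) * quantum  # Python's round() rounds ties to even
    return max(-F8E4M3_MAX, min(F8E4M3_MAX, q))


def cast_rows(matrix, block_size=BLOCK_SIZE):
    """Cast each row in contiguous blocks; returns (scaled data, scales)."""
    data, scales = [], []
    for row in matrix:
        out_row, row_scales = [], []
        for i in range(0, len(row), block_size):
            block = row[i:i + block_size]
            s = block_scale(block)
            row_scales.append(s)
            out_row.extend(round_to_e4m3(v / s) for v in block)
        data.append(out_row)
        scales.append(row_scales)
    return data, scales


def cast_both_directions(matrix, block_size=BLOCK_SIZE):
    """Row-wise blocks, then column-wise blocks (via a transpose)."""
    transposed = [list(col) for col in zip(*matrix)]
    return cast_rows(matrix, block_size), cast_rows(transposed, block_size)
```

Note that this naive version reads the input twice, once per direction, which mirrors the two-kernel behavior described in this issue.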

What we currently see from inductor is two kernels, one for dim0 and one for dim1:

TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only_dim0_dim1

Output: https://gist.github.com/vkuzo/7a9f104872790e58b316c7ba477fcbf5

An MX-compliant 32x32 block of a bfloat16 tensor occupies 2 KiB of memory (32 x 32 elements x 2 bytes each), so it should easily fit into the shared memory of an SM on a modern GPU. We should explore doing this cast across dim0 and dim1 in a tiled fashion, so that each tile is loaded into shared memory only once.
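
As a rough illustration of the tiled idea, the pure-Python sketch below visits each tile once and derives both the row-block and the column-block absolute maxima from that single read, alongside the per-tile byte budget. It is not a GPU kernel: the function names (tile_bytes, tiled_amax) are hypothetical, the input dimensions are assumed to be multiples of the tile size, and the one-byte-per-scale figure assumes an 8-bit scale per block.

```python
TILE = 32
BF16_BYTES = 2
FP8_BYTES = 1
SCALE_BYTES = 1  # assumption: one 8-bit scale per block


def tile_bytes(tile=TILE):
    """Byte budget for one tile: bf16 input plus both fp8 outputs with scales."""
    input_bytes = tile * tile * BF16_BYTES
    one_output = tile * tile * FP8_BYTES + tile * SCALE_BYTES
    return {"input": input_bytes, "outputs": 2 * one_output}


def tiled_amax(matrix, tile=TILE):
    """Visit each tile once; return per-row-block and per-column-block amax.

    Assumes both dimensions of `matrix` are multiples of `tile`.
    """
    rows, cols = len(matrix), len(matrix[0])
    row_amax = [[0.0] * (cols // tile) for _ in range(rows)]
    col_amax = [[0.0] * (rows // tile) for _ in range(cols)]
    for r0 in range(0, rows, tile):
        for c0 in range(0, cols, tile):
            # Single "load" of the tile; both reductions reuse it.
            t = [matrix[r][c0:c0 + tile] for r in range(r0, r0 + tile)]
            for i, row in enumerate(t):
                row_amax[r0 + i][c0 // tile] = max(abs(v) for v in row)
            for j, col in enumerate(zip(*t)):
                col_amax[c0 + j][r0 // tile] = max(abs(v) for v in col)
    return row_amax, col_amax
```

In a real kernel the same tile would also be scaled and cast in both directions before being written out; only the reductions are shown here to keep the sketch short.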
