What this cast is doing
- reshape the tensor to shape (-1, block_size), where block_size is usually 32 or 16
- for each block, calculate a single scale, then cast that block to torch.float8_e4m3fn
- repeat the above across both dim0 and dim1
What we currently see from inductor is two kernels, one for dim0 and one for dim1:
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only_dim0_dim1
Output: https://gist.github.com/vkuzo/7a9f104872790e58b316c7ba477fcbf5
An MX-compliant 32x32 block of a bfloat16 tensor occupies 2 KiB of memory (32 * 32 elements * 2 bytes = 2048 bytes), so it should easily fit into the shared memory of an SM on a modern GPU. We should explore doing this cast across dim0 and dim1 in a tiled fashion, so that each tile is loaded into shared memory only once.