FP8 Blockwise Training Tracker #3290

@danielvegamyhre

We want to support DeepSeekV3-style FP8 blockwise training in torchao for both dense and MoE models.

Support for dense models (linears)

We can extend the FP8 blockwise training prototype for dense models here. Its core functionality is complete, but its performance is unoptimized.

The work has been broken down into the following tasks, which anyone is free to work on:

  • Functionality
    • 1x128 quantization for LHS activations, write to row major layout
    • 128x1 quantization for RHS activations, write to col major layout
    • 128x128 quantization for weights, write to col major layout
    • 1x128 @ 128x128 gemm, use for:
      • output = input @ weight.t()
      • dgrad = grad_output @ weight
    • 1x128 @ 128x1 gemm, use for:
      • wgrad = grad_output.t() @ input
    • Autograd function implementing forward and backward
    • DTensor handling for TP support
    • Custom ops around all custom kernels for torch.compile composability
    • Tests for FSDP, TP
    • quantize_ model conversion API performing a module swap of nn.Linear to FP8BlockwiseLinear (which wraps the autograd function)
    • [P1] fp8 blockwise all-gather for FSDP (would need to ensure weight-shards are divisible by 128x128 blocks, design TBD)
  • Performance
    • all quantization kernels run at 80%+ of peak achievable memory bandwidth on Hopper
      • benchmark scripts for each quantization kernel
    • all gemm kernels run at 60%+ of peak achievable TFLOPs/sec on Hopper
      • benchmark scripts for each gemm
  • Integration into torchtitan
    • Validate that loss convergence is virtually identical to bf16 over 3k+ steps on full-size Llama3 8b/70b
    • Validate e2e throughput (TPS) improvement in same training run as above
  • Documentation
    • README
    • torchao docsite
  • Migrate out of prototype directory, integrate into torchao.float8 module
  • High-level goal and completion criteria:
    • Virtually identical convergence to bf16 when training DSV3 16b/671b on H100s. The length of the training run depends on infra availability, global batch size, and other hyperparameters; loosely speaking, let's run to a validation loss of ~2.7. See the long-term training stability section of this blog for reference.
    • 80%+ of the roofline speedup for all linears in the dense FFN (the first 3 layers of DSV3). These FFNs are large, so this target should be achievable.
      • Calculate the roofline linear-layer speedup using the DSV3 16b and 671b shapes and the training configs from the paper (seq_len=4096, with 1 microbatch propagating through each layer at a time, so the "M" dim of the GEMM is always 4096)
      • Use these specs for roofline estimates
      • Use these model dim and inter_dim values for K and N (and vice versa).
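
To make the 1x128 / 128x1 / 128x128 quantization tasks and the 1x128 @ 128x128 GEMM concrete, here is a minimal pure-Python reference sketch. It assumes amax-based scales mapped onto the FP8 e4m3 max of 448, and it shows how a blockwise-scaled GEMM rescales each 128-wide slice of the K reduction by the product of the LHS and RHS block scales. It does not emulate FP8 rounding or the row/col major memory layouts, and every function name is hypothetical; treat it as a correctness reference for testing kernels, not as the torchao implementation.

```python
# Sketch: blockwise scaling for 1x128, 128x1 and 128x128 blocks, plus a reference
# 1x128 @ 128x128 scaled GEMM. FP8 rounding is NOT emulated: values are only
# scaled and clamped to the e4m3 range. All names here are hypothetical.
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
BLOCK = 128


def block_scales(x, block_rows, block_cols, eps=1e-12):
    """One scale per (block_rows x block_cols) tile of x (a list of row lists)."""
    rows, cols = len(x), len(x[0])
    assert rows % block_rows == 0 and cols % block_cols == 0
    scales = []
    for r0 in range(0, rows, block_rows):
        scale_row = []
        for c0 in range(0, cols, block_cols):
            amax = max(
                abs(x[r][c])
                for r in range(r0, r0 + block_rows)
                for c in range(c0, c0 + block_cols)
            )
            # Map the tile's largest magnitude onto FP8_E4M3_MAX.
            scale_row.append(max(amax, eps) / FP8_E4M3_MAX)
        scales.append(scale_row)
    return scales


def quantize_blockwise(x, block_rows, block_cols):
    """Return (scaled values, scales); dequantize with q[r][c] * s[r // br][c // bc]."""
    s = block_scales(x, block_rows, block_cols)
    q = [
        [
            min(max(v / s[r // block_rows][c // block_cols], -FP8_E4M3_MAX), FP8_E4M3_MAX)
            for c, v in enumerate(row)
        ]
        for r, row in enumerate(x)
    ]
    return q, s


def blockwise_gemm_1x128_128x128(qa, sa, qb, sb):
    """Reference for A @ B with A scaled 1x128 and B scaled 128x128.

    qa: (M, K) with sa: (M, K/128); qb: (K, N) with sb: (K/128, N/128).
    Each 128-wide slice of the K reduction is rescaled by sa[m][kb] * sb[kb][nb].
    """
    M, K, N = len(qa), len(qb), len(qb[0])
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for kb in range(K // BLOCK):
                partial = sum(qa[m][k] * qb[k][n] for k in range(kb * BLOCK, (kb + 1) * BLOCK))
                acc += partial * sa[m][kb] * sb[kb][n // BLOCK]
            out[m][n] = acc
    return out


# Usage: compare the blockwise-scaled GEMM against a plain high-precision matmul.
M, K, N = 2, 256, 128
a = [[math.sin(m * K + k) for k in range(K)] for m in range(M)]
b = [[math.cos(k * N + n) for n in range(N)] for k in range(K)]
qa, sa = quantize_blockwise(a, 1, BLOCK)      # LHS activations: 1x128 blocks
qb, sb = quantize_blockwise(b, BLOCK, BLOCK)  # weight: 128x128 blocks
out = blockwise_gemm_1x128_128x128(qa, sa, qb, sb)
ref = [[sum(a[m][k] * b[k][n] for k in range(K)) for n in range(N)] for m in range(M)]
print(max(abs(out[m][n] - ref[m][n]) for m in range(M) for n in range(N)))
```

Because FP8 rounding is skipped, the scaled GEMM matches the plain matmul up to float error; swapping in real FP8 rounding would turn this into an error-budget check for the quantization kernels.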

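The roofline speedup in the completion criteria can be estimated with a simplified model like the sketch below. Its accounting is an assumption, not torchao's roofline script: each GEMM costs max(compute time, memory time), FP8 GEMMs read 1-byte inputs and write bf16 outputs, and each quantization kernel reads bf16 and writes FP8 for every quantized operand (input as 1x128 and 128x1, weight as 128x128, grad_output for dgrad and, transposed, for wgrad), with scale tensors ignored. Only M=4096 comes from this issue; the hardware numbers and K/N values are hypothetical placeholders to be replaced with the specs and dim/inter_dim values referenced above.

```python
# Simplified roofline sketch for one linear layer (forward + backward).
# Accounting choices and all numbers below are assumptions / placeholders.


def gemm_time(m, k, n, in_bytes, out_bytes, peak_flops, mem_bw):
    """Roofline time (s) of an (m, k) @ (k, n) GEMM: max of compute and memory time."""
    flops = 2 * m * k * n
    bytes_moved = in_bytes * (m * k + k * n) + out_bytes * m * n
    return max(flops / peak_flops, bytes_moved / mem_bw)


def linear_fwd_bwd_time(M, K, N, fp8, peak_bf16, peak_fp8, mem_bw):
    """Forward + backward time for a linear with input (M, K) and weight (N, K)."""
    peak = peak_fp8 if fp8 else peak_bf16
    in_bytes = 1 if fp8 else 2
    t = (
        gemm_time(M, K, N, in_bytes, 2, peak, mem_bw)    # output = input @ weight.t()
        + gemm_time(M, N, K, in_bytes, 2, peak, mem_bw)  # dgrad = grad_output @ weight
        + gemm_time(N, M, K, in_bytes, 2, peak, mem_bw)  # wgrad = grad_output.t() @ input
    )
    if fp8:
        # Quantization kernels: read bf16 (2 B) and write fp8 (1 B) per element.
        numel = 2 * M * K + K * N + 2 * M * N
        t += numel * 3 / mem_bw
    return t


def roofline_speedup(M, K, N, peak_bf16, peak_fp8, mem_bw):
    t_bf16 = linear_fwd_bwd_time(M, K, N, False, peak_bf16, peak_fp8, mem_bw)
    t_fp8 = linear_fwd_bwd_time(M, K, N, True, peak_bf16, peak_fp8, mem_bw)
    return t_bf16 / t_fp8


# Hypothetical placeholder hardware numbers (replace with the linked specs).
PEAK_BF16 = 1.0e15  # FLOP/s
PEAK_FP8 = 2.0e15   # FLOP/s
MEM_BW = 3.0e12     # bytes/s

M = 4096            # seq_len=4096, one microbatch per layer
K, N = 8192, 16384  # placeholder dim / inter_dim (replace with the linked values)

speedup = roofline_speedup(M, K, N, PEAK_BF16, PEAK_FP8, MEM_BW)
target = 0.8 * speedup  # completion criterion: 80%+ of the roofline speedup
print(f"roofline speedup {speedup:.2f}x, target >= {target:.2f}x")
```

Since the K and N values are used "and vice versa", running the same function with K and N swapped covers both linears in the FFN.
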
Support for MoE layers (grouped GEMMs)

We can extend the low-precision MoE training code here to support FP8 blockwise training by doing the following:

  • Functionality
    • Quantization
      • 128x128 quantization compatible with 3d expert weights, write to per-expert col major layout (e.g. shape (E,N,K) with strides (N*K,1,N))
      • Per-token group 1x128 scale conversion where group boundaries are along M
      • Per-token group 128x1 scale conversion where group boundaries are along K/contracting dim
    • GEMMs
      • 1x128 @ 128x128 scaled grouped gemm
        • output = input @ weight.transpose(-2,-1)
        • dgrad = grad_output @ weight
      • 1x128 @ 128x1 scaled grouped gemm
        • wgrad = grad_output.transpose(-2,-1) @ input
    • Autograd function implementing forward and backward with dynamic quant on inputs (see mxfp8 example)
    • DTensor handling for TP support
    • Custom ops around all custom kernels for torch.compile composability
    • Tests for FSDP, TP
    • quantize_ model conversion API performing a module swap of nn.Linear to FP8BlockwiseLinear (which wraps the autograd function)
  • Performance
    • all quantization kernels run at 80%+ of peak achievable memory bandwidth on Hopper
      • benchmark scripts for each quantization kernel
    • all gemm kernels run at 60%+ of peak achievable TFLOPs/sec on Hopper
      • benchmark scripts for each gemm
  • Integration into torchtitan
    • Validate that loss convergence is virtually identical to bf16 over 3k+ steps on full-size DeepSeekV3 671b
    • Validate e2e throughput (TPS) improvement in same training run as above
  • Documentation
    • README
    • torchao docsite
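
The per-expert column-major layout for 3D expert weights, shape (E,N,K) with strides (N*K,1,N), can be illustrated by computing each element's flat offset from those strides. In the sketch below the helper names are hypothetical: each expert occupies its own contiguous N*K chunk, and within that chunk consecutive n values are adjacent (stride 1) while consecutive k values are N elements apart. In PyTorch, one way to obtain these strides from a contiguous (E,N,K) tensor w is w.transpose(-2, -1).contiguous().transpose(-2, -1).

```python
# Sketch: element offsets for the per-expert column-major layout of (E, N, K) weights.
# Helper names are hypothetical.


def per_expert_col_major_strides(E, N, K):
    """Element strides for shape (E, N, K) where each expert's (N, K) matrix is column-major."""
    return (N * K, 1, N)


def pack_per_expert_col_major(w):
    """Flatten nested w[e][n][k] into a 1D buffer using the strides above."""
    E, N, K = len(w), len(w[0]), len(w[0][0])
    se, sn, sk = per_expert_col_major_strides(E, N, K)
    flat = [None] * (E * N * K)
    for e in range(E):
        for n in range(N):
            for k in range(K):
                flat[e * se + n * sn + k * sk] = w[e][n][k]
    return flat


E, N, K = 2, 3, 4
# Store each element's own (e, n, k) index so the packed order is easy to read.
w = [[[(e, n, k) for k in range(K)] for n in range(N)] for e in range(E)]
flat = pack_per_expert_col_major(w)
# Expert 0 starts with its first column (k=0), walking down n.
print(flat[:N])  # [(0, 0, 0), (0, 1, 0), (0, 2, 0)]
```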

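For the per-token-group scales, a 1x128 block lies within a single row, so group boundaries along M never split it. A 128x1 block along the contracting dim spans 128 tokens, however, so it must not straddle two expert groups. The sketch below shows one possible approach, assuming blocks are clipped at each group's end (leaving a partial final block per group) rather than padding each group to a multiple of 128. offs is a hypothetical list of cumulative group end offsets, empty groups are allowed, and the real kernels may make a different choice.

```python
# Sketch: group-aware 128x1 scaling along a token dim split into expert groups.
# Assumption: blocks are clipped at group boundaries instead of padding groups.

BLOCK = 128


def group_blocks(offs, block=BLOCK):
    """Split [0, offs[-1]) into row blocks of <= `block` rows that never cross a group boundary.

    offs holds cumulative group end indices, so group g covers rows
    [offs[g - 1], offs[g]) (with an implicit start of 0 for the first group).
    """
    blocks, start = [], 0
    for end in offs:
        for b0 in range(start, end, block):
            blocks.append((b0, min(b0 + block, end)))
        start = end
    return blocks


def group_128x1_scales(x, offs, fp8_max=448.0, eps=1e-12):
    """One amax-based scale per (group-aware row block) x (column) of x, a list of row lists."""
    cols = len(x[0])
    return [
        [max(max(abs(x[r][c]) for r in range(b0, b1)), eps) / fp8_max for c in range(cols)]
        for b0, b1 in group_blocks(offs)
    ]


offs = [100, 356, 356, 500]  # 4 experts; expert 2 received no tokens
print(group_blocks(offs))  # [(0, 100), (100, 228), (228, 356), (356, 484), (484, 500)]
```

The clipped-block choice keeps every scale local to one expert at the cost of some partial blocks; padding each group to a multiple of 128 would instead give uniform blocks but require extra memory and bookkeeping.
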