We want to support DeepSeekV3-style FP8 blockwise training in torchao for both dense and MoE models.
Support for dense models (linears)
We can extend the FP8 blockwise training prototype for dense models here, which has the core functionality complete but is not yet optimized for performance.
The work has been broken down into the following tasks, which anyone is free to work on:
- output = input @ weight.t()
- dgrad = grad_output @ weight
- wgrad = grad_output.t() @ input
- torch.compile composability
- quantize_model conversion API performing module swap of nn.Linear to FP8BlockwiseLinear (wraps autograd func)
- torchao.float8 module

Support for MoE layers (grouped GEMMs)
We can extend the low-precision MoE training code here to support FP8 blockwise training by doing the following:
- output = input @ weight.transpose(-2,-1)
- dgrad = grad_output @ weight
- wgrad = grad_output.transpose(-2,-1) @ input
- torch.compile composability
- quantize_model conversion API performing module swap of nn.Linear to FP8BlockwiseLinear (wraps autograd func)