Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/static/fp8-rowwise-perf.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 8 additions & 1 deletion torchao/float8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ on using `torchao.float8` in a distributed setting.

# Performance

A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark based speedup estimate on NVIDIA H100:
A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the tables below for a microbenchmark based speedup estimate on NVIDIA H100:

### Tensorwise scaling

<img width="805" alt="float8_speedup" src="https://github.com/user-attachments/assets/5c5f2817-7eb7-4cab-bd03-49fe70cd31a8">

Expand All @@ -135,6 +137,11 @@ To reproduce the raw data for table above, you can run the following script
python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep
```

### Rowwise scaling

<img width="805" alt="float8_rowwise_speedup" src="../../docs/static/fp8-rowwise-perf.png" />


## Derivation

In a bf16 linear, assume all of the time is spent in gemms. In a float8 linear, account for max_abs and casting overhead. We want to know when
Expand Down
Loading