[float8] add _auto_filter_for_recipe for float8 training#1319

Merged
danielvegamyhre merged 3 commits into
pytorch:mainfrom
danielvegamyhre:auto_filter
Jul 1, 2025

@danielvegamyhre danielvegamyhre commented Jun 18, 2025

Contributor

Fixes #1207

Problem

  • float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see float8 rowwise vanilla TP low throughput #1207).
  • Root-cause analysis in float8 rowwise vanilla TP low throughput #1207 found that the attention.wk and attention.wv layers were so small that float8 rowwise conversion caused a ~40% slowdown for those layers, which nullified the perf benefits of fp8 rowwise conversion on the larger linears.
  • This is because the default filter_fqns for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

Solution

This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria:

  1. dims not divisible by 16 (hardware requirement for float8)
  2. dim sizes below thresholds that may result in worse perf for that given recipe, using simple heuristics based on the per-recipe perf tables in the torchao float8 documentation.
  3. fqn matches one of the user defined filter_fqns

It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns.
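The three criteria above can be sketched in plain Python. This is a minimal illustration, not the torchao implementation: `LinearShape`, `should_convert`, and `tensorwise_too_small` are hypothetical names, the substring FQN match is an assumption, and the only threshold used is the tensorwise one quoted later in this thread (K <= 4096 and N <= 1024). Rowwise thresholds would need to come from the torchao performance tables.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class LinearShape:
    """Stand-in for nn.Linear; the weight has shape (N, K) = (out_features, in_features)."""

    in_features: int  # K
    out_features: int  # N


def tensorwise_too_small(k: int, n: int) -> bool:
    # Tensorwise threshold quoted in this PR's review discussion.
    return k <= 4096 and n <= 1024


def should_convert(
    shape: LinearShape,
    fqn: str,
    filter_fqns: Sequence[str],
    too_small: Callable[[int, int], bool] = tensorwise_too_small,
) -> bool:
    """Return True if the layer should be converted to float8 (hypothetical helper)."""
    k, n = shape.in_features, shape.out_features
    # 1. Both dims must be divisible by 16.
    if k % 16 != 0 or n % 16 != 0:
        return False
    # 2. Recipe-specific size heuristic.
    if too_small(k, n):
        return False
    # 3. User-defined FQN filter (substring matching is an assumption here).
    if any(pattern in fqn for pattern in filter_fqns):
        return False
    return True


print(should_convert(LinearShape(8192, 8192), "layers.0.feed_forward.w1", ["output"]))
print(should_convert(LinearShape(4096, 1024), "layers.0.small_proj", ["output"]))
```

Passing the size heuristic in as a callable keeps the recipe-specific part swappable, which mirrors the idea that the thresholds differ between tensorwise and rowwise scaling.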

Results

Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline).

Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16:

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 18, 2025
@danielvegamyhre danielvegamyhre changed the title [WIP] [float8] add float auto_filter_for_recipe [float8] add float auto_filter_for_recipe Jun 18, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft June 18, 2025 22:13
@danielvegamyhre danielvegamyhre changed the title [float8] add float auto_filter_for_recipe [WIP] [float8] add float auto_filter_for_recipe Jun 18, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review June 23, 2025 21:14
@danielvegamyhre danielvegamyhre changed the title [WIP] [float8] add float auto_filter_for_recipe [float8] add float auto_filter_for_recipe Jun 24, 2025
@danielvegamyhre danielvegamyhre changed the title [float8] add float auto_filter_for_recipe [float8] add float8 _auto_filter_for_recipe Jun 24, 2025
@danielvegamyhre danielvegamyhre changed the title [float8] add float8 _auto_filter_for_recipe [float8] add _auto_filter_for_recipe for float8 training Jun 24, 2025
@danielvegamyhre

Contributor Author

cc @tianyu @vkuzo for review + thoughts on if this would be useful to add as the default module filter for float8 in torchtitan

@tianyu-l tianyu-l left a comment

Contributor

Sounds good to me. Thank you for the studies and efforts!

Let's also modify helper message to reflect this change
https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L504

self.enabled = False

float8_config: Float8 = job_config.float8
self.float8_config: Float8 = job_config.float8

Contributor

Having both self.float8_config and self.config sounds confusing.
Can we define self.filter_fn in __init__() so that we don't need self.float8_config or self.filter_fqns?

@danielvegamyhre danielvegamyhre Jun 27, 2025

Contributor Author

Makes sense, updated.

from torchao.float8 import _auto_filter_for_recipe

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
filter_fn = _auto_filter_for_recipe(

@tianyu-l tianyu-l Jun 25, 2025

Contributor

how about mx quantization? would it also suffer from the issue / benefit from auto filtering?

Contributor Author

Probably, but we don't have finalized perf numbers to reference to make an autofilter function for it (like the one added here https://github.com/pytorch/ao/pull/2312/files). We should add an auto filter option like this for mxfp8 once we can though.

@vkuzo

vkuzo commented Jun 25, 2025

Contributor

I think it's better to have this off by default and make it easy to enable, to keep the defaults dead simple. Some challenges with this filtering are that it is not aware of M, it is not aware of the underlying hardware, and it will behave unexpectedly on the debug model. How about we just make this easy to enable and add documentation recommending to enable it?

@danielvegamyhre

danielvegamyhre commented Jun 27, 2025

Contributor Author

I think it's better to have this off by default and make it easy to enable, to keep the defaults dead simple. Some challenges with this filtering are that it is not aware of M, it is not aware of the underlying hardware, and it will behave unexpectedly on the debug model. How about we just make this easy to enable and add documentation recommending to enable it?

Makes sense. How about this API to enable the auto filter:

torchtitan/train.py ... --float8.filter_fqns="auto_filter"

toml:

[float8]
filter_fqns = ["auto_filter"]

What do you think? This string could theoretically be part of a FQN but I think it's unlikely and we could document it clearly.

@danielvegamyhre danielvegamyhre force-pushed the auto_filter branch 5 times, most recently from 78d91f3 to 89044d4 Compare June 27, 2025 06:44
Comment thread docs/float8.md Outdated
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: use `--float8.filter_fqns="auto_filter"` to enable automatic module filtering, which will automatically not convert linear layers that are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.

Contributor

nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter
nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.

Contributor Author

nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.

Made the name more explicit: auto_filter_small_kn

nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter

I agree. I updated it so the API is to include the "auto_filter_small_kn" flag as one of the FQNs rather than as the only one. That way, the rest of the specified FQNs are processed as usual for filtering.
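A minimal sketch of how a comma-separated `--float8.filter_fqns` value could be split into the flag and the remaining FQNs. `AUTO_FILTER_SMALL_KN_FLAG` appears in this PR's diff, but `parse_filter_fqns` is a hypothetical helper written for illustration and is not torchtitan's config parser.

```python
AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"


def parse_filter_fqns(raw: str) -> tuple[bool, list[str]]:
    """Split a comma-separated value into (use_auto_filter, remaining_fqns)."""
    # Drop surrounding whitespace and ignore empty entries such as a trailing comma.
    entries = [item.strip() for item in raw.split(",") if item.strip()]
    use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in entries
    # Everything except the flag is treated as an ordinary FQN filter.
    fqns = [item for item in entries if item != AUTO_FILTER_SMALL_KN_FLAG]
    return use_auto_filter, fqns


print(parse_filter_fqns("auto_filter_small_kn,output"))
print(parse_filter_fqns("attention.wk,attention.wv"))
```

With this shape, a user can combine the flag with model-specific names, for example `"auto_filter_small_kn,output"`, and both take effect.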

@tianyu-l tianyu-l left a comment

Contributor

LGTM, please address remaining comments and questions

Comment thread docs/float8.md Outdated
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: add `"auto_filter_low_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering, which will automatically not convert linear layers whose K,N dimensions are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.

Contributor

whose K,N dimensions are not large enough to benefit from float8 training

Could you educate me more on what are K,N dimensions and why float8 doesn't benefit much if K,N are not large enough?

Users might also have such doubts, so might be good to explain a bit more in the short manual.

@danielvegamyhre danielvegamyhre Jun 29, 2025

Contributor Author

Sure, I'll add some more info to this doc. Basically, the K and N dimensions refer to the GEMM between the inputs and weights of a linear layer: (M,K) @ (K,N) = (M,N). So in this context, the linear layer has shape (K,N). (Technically, the weight is stored as (N,K) in row-major order and is then transposed for the matmul X @ W^T.)

Our microbenchmarking shows there are certain size thresholds for the linear layer K and N, below which the performance of fp8 linear was always worse than bf16. Basically, the GEMMs have to be big enough that the speedup from using FP8 tensorcores is greater than the overhead of creating dynamically quantized inputs.

The thresholds are different for tensorwise scaling vs rowwise scaling - you can check out these performance tables to get an idea of when it makes sense to convert a linear layer to float8 or not: https://github.com/pytorch/ao/tree/main/torchao/float8#performance

For example, for tensorwise scaling, if K <= 4096 and N <= 1024, all of our benchmarks showed worse performance than bf16, for all tested values of M (from 1024 to 16384).

It's possible that for very large values of M, beyond what we tested, the perf change could be positive. However, this auto filter is not intended to be universally optimal in all cases - it's just a simple way for users to avoid hitting this common footgun that causes fp8 to seemingly perform worse than bf16, without needing to do manual layer analysis + cross-referencing with our performance tables to manually filter out layers.

For the best results, users should still do layer analysis and not rely on this heuristic based auto filter that doesn't account for M.
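A back-of-the-envelope way to see why small K and N hurt: compare the matmul work against the number of elements that must be dynamically cast. The model below is an assumption made for illustration. It counts the three GEMMs of a linear layer (forward, grad_input, grad_weight) at 2·M·K·N FLOPs each, and assumes the input, weight, and grad_output are each cast once. It does not measure torchao, and it ignores hardware effects.

```python
def gemm_flops(m: int, k: int, n: int) -> int:
    """FLOPs for one (m, k) @ (k, n) matmul: a multiply and an add per term."""
    return 2 * m * k * n


def cast_elements(m: int, k: int, n: int) -> int:
    """Elements cast if input (m, k), weight (n, k), and grad_output (m, n) are each cast once."""
    return m * k + k * n + m * n


def flops_per_cast_element(m: int, k: int, n: int) -> float:
    """Matmul work available to amortize each cast element (assumed cost model)."""
    # Forward, grad_input, and grad_weight GEMMs all cost 2*m*k*n FLOPs.
    return 3 * gemm_flops(m, k, n) / cast_elements(m, k, n)


for k, n in [(4096, 1024), (8192, 8192)]:
    print(f"K={k}, N={n}: {flops_per_cast_element(16384, k, n):.0f} FLOPs per cast element")
```

Under this model the ratio simplifies to 6 / (1/M + 1/K + 1/N), so for a fixed M it shrinks as K or N shrink, which matches the qualitative point that small layers leave too little matmul work to pay for quantization.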

Contributor Author

Updated docs

Comment thread docs/float8.md Outdated
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: add `"auto_filter_low_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering, which will automatically not convert linear layers whose K,N dimensions are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.

Contributor

let's be consistent with low_kn vs. small_kn

Contributor Author

thanks, fixed

Comment on lines +118 to +123
# remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe
fqns = [
    fqn
    for fqn in float8_config.filter_fqns
    if fqn != AUTO_FILTER_SMALL_KN_FLAG
]

Contributor

can we use list.remove(x)

Contributor Author

done
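For reference, the list comprehension and `list.remove(x)` are not strictly equivalent. The sketch below uses only built-in list behavior; the sample `configured` value is made up for illustration.

```python
AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"

configured = ["auto_filter_small_kn", "output"]  # illustrative config value

# Comprehension: builds a new list, drops every occurrence of the flag,
# and leaves the original list untouched.
filtered_copy = [fqn for fqn in configured if fqn != AUTO_FILTER_SMALL_KN_FLAG]

# list.remove: mutates in place, removes only the first occurrence, and
# raises ValueError if the flag is absent, so copy and guard first.
fqns = list(configured)
if AUTO_FILTER_SMALL_KN_FLAG in fqns:
    fqns.remove(AUTO_FILTER_SMALL_KN_FLAG)

print(filtered_copy, fqns, configured)
```

Copying before calling `remove` avoids silently modifying the list stored on the job config, in case other code reads it later.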

Comment on lines +139 to +140
filter_fn = partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)
return filter_fn

Contributor

maybe merge the two lines

Contributor Author

done
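The merged form binds `filter_fqns` with `functools.partial` and returns the result directly. A minimal sketch follows. The name `module_filter_fn` appears in this diff, but the signature and substring-matching body shown here are assumptions, and `build_filter_fn` is a hypothetical wrapper.

```python
from functools import partial
from typing import Callable, Sequence


def module_filter_fn(module: object, fqn: str, filter_fqns: Sequence[str]) -> bool:
    """Assumed behavior: convert a module unless its FQN contains a filtered name."""
    return not any(pattern in fqn for pattern in filter_fqns)


def build_filter_fn(filter_fqns: Sequence[str]) -> Callable[[object, str], bool]:
    # Bind filter_fqns once so callers only pass (module, fqn).
    return partial(module_filter_fn, filter_fqns=list(filter_fqns))


filter_fn = build_filter_fn(["attention.wk", "attention.wv"])
print(filter_fn(None, "layers.0.attention.wk"))
print(filter_fn(None, "layers.0.feed_forward.w1"))
```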

@tianyu-l tianyu-l left a comment

Contributor

Thanks for the explanation on K,N values and their impact on performance.

The table in the link seems to be for tensorwise, where K and N have relatively symmetric roles. I would imagine for rowwise, K and N would have different impact on perf -- e.g. the bigger N would incur higher overhead on quantization.

I also realized why we don't mention much on M -- it is a dimension we have no control over when converting the params.

But anyways, providing a filter fn based on empirical results makes sense.

Please make sure CI passes before merge.

@danielvegamyhre

Contributor Author

Confirmed that the test failure is unrelated to this change; it comes from an ft async checkpointing test:

tests/unit_tests/test_checkpoint.py::TestCheckpointManager::test_ft_async_save_calls_async_wait FAILED [ 12%]

@danielvegamyhre danielvegamyhre merged commit b0902b2 into pytorch:main Jul 1, 2025
6 of 7 checks passed
wwwjn pushed a commit that referenced this pull request Jul 1, 2025
Fixes #1207 

## Problem
- float8 rowwise + vanilla TP in torchtitan had flat perf with respect
to bfloat16 (see #1207).
- Root-cause analysis in #1207 found that the attention.wk and attention.wv
layers were so small that float8 rowwise conversion caused a ~40%
slowdown for those layers, which nullified the perf benefits of fp8
rowwise conversion on the larger linears.
- This is because the default `filter_fqns` for float8 model conversion
are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise
recipe.

### Solution 

This has been a footgun for various users as well (including Poolside),
so I created an "auto filter" (pytorch/ao#2410)
which automatically filters Linears for a given float8 recipe, by
checking for the following criteria:

1. dims not divisible by 16 (hardware requirement for float8)
2. dim sizes below thresholds that may result in worse perf **for that
given recipe**, using simple heuristics based on the linked recipe perf
tables above.
3. fqn matches one of the user defined `filter_fqns`

It prevents users from hitting this common footgun, while also
preserving the flexibility to define their model-specific fqns.

## Results 

Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement
for async TP (over bf16 TP baseline).

Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC,
local batch size 16:

- [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597 TPS
- [fp8 rowwise WITH attention.wk, attention.wv
converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS
- [fp8 rowwise WITHOUT attention.wk, attention.wv
converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS
- [fp8 rowwise + async TP WITH attention.wk, attention.wv
converted](https://fburl.com/mlhub/76q4mel9) = ~625 TPS
- [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv
converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
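
The relative gains can be checked directly from the approximate TPS figures listed above. This is only arithmetic on the reported numbers; the labels are shortened for readability and `percent_gain` is a helper defined here.

```python
BASELINE_TPS = 597  # bfloat16 baseline reported above

runs = {
    "fp8 rowwise, wk/wv converted": 600,
    "fp8 rowwise, wk/wv not converted": 660,
    "fp8 rowwise + async TP, wk/wv converted": 625,
    "fp8 rowwise + async TP, wk/wv not converted": 695,
}


def percent_gain(tps: float, baseline: float = BASELINE_TPS) -> float:
    """Percentage throughput change relative to the bf16 baseline."""
    return 100.0 * (tps - baseline) / baseline


for name, tps in runs.items():
    print(f"{name}: {percent_gain(tps):+.1f}% vs bf16")
```

Since the inputs are rounded TPS values, the resulting percentages are approximate as well.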
mori360 pushed a commit to mori360/torchtitan that referenced this pull request Jul 8, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026