[float8] add _auto_filter_for_recipe for float8 training#1319
Conversation
tianyu-l
left a comment
There was a problem hiding this comment.
Sounds good to me. Thank you for the studies and efforts!
Let's also modify helper message to reflect this change
https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L504
| self.enabled = False | ||
|
|
||
| float8_config: Float8 = job_config.float8 | ||
| self.float8_config: Float8 = job_config.float8 |
There was a problem hiding this comment.
Having both self.float8_config and self.config sounds confusing.
Can we define self.filter_fn in __init__() so that we don't need self.float8_config or self.filter_fqns?
There was a problem hiding this comment.
Makes sense, updated.
| from torchao.float8 import _auto_filter_for_recipe | ||
|
|
||
| # Mutates the model inplace replacing instances of nn.Linear with Float8Linear | ||
| filter_fn = _auto_filter_for_recipe( |
There was a problem hiding this comment.
how about mx quantization? would it also suffer from the issue / benefit from auto filtering?
There was a problem hiding this comment.
Probably, but we don't have finalized perf numbers to reference to make an autofilter function for it (like the one added here https://github.com/pytorch/ao/pull/2312/files). We should add an auto filter option like this for mxfp8 once we can though.
|
I think it's better to have this off by default and make it easy to enable, to keep the defaults dead simple. Some challenges with this filtering is that it is not aware of |
Makes sense. How about this API to enable the auto filter: torchtitan/train.py ... --float8.filter_fqns="auto_filter"toml: [float8]
filter_fqns = ["auto_filter"]What do you think? This string could theoretically be part of a FQN but I think it's unlikely and we could document it clearly. |
78d91f3 to
89044d4
Compare
| * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. | ||
| * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward. | ||
| * `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. | ||
| * **Auto-filter**: use `--float8.filter_fqns="auto_filter"` to enable automatic module filtering, which will automatically not convert linear layers that are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. |
There was a problem hiding this comment.
nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter
nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.
There was a problem hiding this comment.
nit 2: would be good to make the flag name more specific, for example auto_filter_low_kn instead of auto_filter. I guess this applies to torchao as well, sorry for not catching in initial review.
Made the name more explicit: auto_filter_small_kn
nit 1: would be good to enable the user to filter out module foo and then filter out other modules with the auto filter
I agree, I updated it so the API is to just include "auto_filter_small_kn" flag as one of the FQNs, instead of the only one. This way, the rest of the FQNs specified are processed as usual for filtering.
tianyu-l
left a comment
There was a problem hiding this comment.
LGTM, please address remaining comments and questions
| * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. | ||
| * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward. | ||
| * `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. | ||
| * **Auto-filter**: add `"auto_filter_low_kn"` as one of the `--float8.filter_fqns=...` to to enable automatic module filtering, which will automatically not convert linear layers whose K,N dimensions are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. |
There was a problem hiding this comment.
whose K,N dimensions are not large enough to benefit from float8 training
Could you educate me more on what are K,N dimensions and why float8 doesn't benefit much if K,N are not large enough?
Users might also have such doubts, so might be good to explain a bit more in the short manual.
There was a problem hiding this comment.
Sure, I'll add some more info to this doc, but basically the K and N dimensions are referring to the GEMM operation between the inputs and weights of a linear layer => (M,K) @ (K,N) = (M,N). So in this context, the linear layer has shape K,N. (technically, the weight is N,K row-major then is transposed for the matmul X @ W^T).
Our microbenchmarking shows there are certain size thresholds for the linear layer K and N, below which the performance of fp8 linear was always worse than bf16. Basically, the GEMMs have to be big enough that the speedup from using FP8 tensorcores is greater than the overhead of creating dynamically quantized inputs.
The threshholds are different for tensorwise scaling vs rowwise scaling - you can check out these performance tables to get an idea of when it makes sense to convert a linear layer to float8 or not: https://github.com/pytorch/ao/tree/main/torchao/float8#performance
For example, for tensorwise scaling, if K <= 4096 and N <= 1024, all of our benchmarks showed worse performance than bf16, for all tested values of M (from 1024 to 16384).
It's possible for very large values of M, beyond what we tested, the perf change could be positive. However, this auto filter is not intended to be universally optiminal in all cases - it's just a simple way users can avoid hitting this common footgun that causes fp8 to seemingly perform worse than bf16, without needing to do manual layer analysis + cross-referencing with our performance tables to manually filter out layers.
For the best results, users should still do layer analysis and not rely on this heuristic based auto filter that doesn't account for M.
There was a problem hiding this comment.
Updated docs
| * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. | ||
| * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward. | ||
| * `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. | ||
| * **Auto-filter**: add `"auto_filter_low_kn"` as one of the `--float8.filter_fqns=...` to to enable automatic module filtering, which will automatically not convert linear layers whose K,N dimensions are not large enough to benefit from float8 training. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. |
There was a problem hiding this comment.
let's be consistent with low_kn vs. small_kn
There was a problem hiding this comment.
thanks, fixed
| # remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe | ||
| fqns = [ | ||
| fqn | ||
| for fqn in float8_config.filter_fqns | ||
| if fqn != AUTO_FILTER_SMALL_KN_FLAG | ||
| ] |
There was a problem hiding this comment.
can we use list.remove(x)
| filter_fn = partial(module_filter_fn, filter_fqns=float8_config.filter_fqns) | ||
| return filter_fn |
There was a problem hiding this comment.
maybe merge the two lines
8a3a3de to
2befae4
Compare
There was a problem hiding this comment.
Thanks for the explanation on K,N values and their impact on performance.
The table in the link seems for tensorwise where K and N have relatively symmetric roles. I would imagine for rowwise, K and N would have different impact on perf -- e.g. the bigger N would incur higher overhead on quantization.
I also realized why we don't mention much on M -- it is a dimension we have no control over when converting the params.
But anyways, providing a filter fn based on empirical results makes sense.
Please make sure CI passes before merge.
d35cf32 to
9c79cfd
Compare
|
Confirmed test failure is unrelated to this change, it is a ft async checkpointing test: |
Fixes #1207 ## Problem - float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see #1207). - RCA In #1207 found attention.wk and attention.wv layers were so small that float8 rowwise conversion resulted in approx ~40% slowdown for those layers, which nullified the perf benefits from fp8 rowwise conversion on larger linears. - This is because the default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe. ### Solution This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria: 1. dims not divisible by 16 (hardware requirement for float8) 2. dim sizes below thresholds that may result in worse perf **for that given recipe**, using simple heuristics based on the linked recipe perf tables above. 3. fqn matches one of the user defined `filter_fqns` It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns. ## Results Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline). Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: - [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597TPS - [fp8 rowwise WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS - [fp8 rowwise WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS - [fp8 rowwise + async TP WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/76q4mel9 ) = ~625 TPS - [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
Fixes pytorch#1207 ## Problem - float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see pytorch#1207). - RCA In pytorch#1207 found attention.wk and attention.wv layers were so small that float8 rowwise conversion resulted in approx ~40% slowdown for those layers, which nullified the perf benefits from fp8 rowwise conversion on larger linears. - This is because the default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe. ### Solution This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria: 1. dims not divisible by 16 (hardware requirement for float8) 2. dim sizes below thresholds that may result in worse perf **for that given recipe**, using simple heuristics based on the linked recipe perf tables above. 3. fqn matches one of the user defined `filter_fqns` It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns. ## Results Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline). Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: - [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597TPS - [fp8 rowwise WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS - [fp8 rowwise WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS - [fp8 rowwise + async TP WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/76q4mel9 ) = ~625 TPS - [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
Fixes pytorch#1207 ## Problem - float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see pytorch#1207). - RCA In pytorch#1207 found attention.wk and attention.wv layers were so small that float8 rowwise conversion resulted in approx ~40% slowdown for those layers, which nullified the perf benefits from fp8 rowwise conversion on larger linears. - This is because the default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe. ### Solution This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria: 1. dims not divisible by 16 (hardware requirement for float8) 2. dim sizes below thresholds that may result in worse perf **for that given recipe**, using simple heuristics based on the linked recipe perf tables above. 3. fqn matches one of the user defined `filter_fqns` It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns. ## Results Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline). Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: - [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597TPS - [fp8 rowwise WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS - [fp8 rowwise WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS - [fp8 rowwise + async TP WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/76q4mel9 ) = ~625 TPS - [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
Fixes pytorch#1207 ## Problem - float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see pytorch#1207). - RCA In pytorch#1207 found attention.wk and attention.wv layers were so small that float8 rowwise conversion resulted in approx ~40% slowdown for those layers, which nullified the perf benefits from fp8 rowwise conversion on larger linears. - This is because the default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe. ### Solution This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria: 1. dims not divisible by 16 (hardware requirement for float8) 2. dim sizes below thresholds that may result in worse perf **for that given recipe**, using simple heuristics based on the linked recipe perf tables above. 3. fqn matches one of the user defined `filter_fqns` It prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns. ## Results Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline). Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: - [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597TPS - [fp8 rowwise WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS - [fp8 rowwise WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS - [fp8 rowwise + async TP WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/76q4mel9 ) = ~625 TPS - [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
Fixes #1207
Problem
filter_fqnsfor float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.Solution
This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe, by checking for the following criteria:
filter_fqnsIt prevents users from hitting this common footgun, while also preserving the flexibility to define their model-specific fqns.
Results
Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement for async TP (over bf16 TP baseline).
Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16: