Add a way to do power of 2 scaling #2256
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2256
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures as of commit 5c64718 with merge base a776b1f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack-info: PR: #2256, branch: drisspg/stack/57
    assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnu is supported"
    scale = torch.exp2(torch.round(torch.log2(scale)))

    return scale.to(dtype=torch.float32)
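For readers following the thread, the rounding step in the snippet above can be sketched on a single Python float. This is only an illustration, not the PR's code: the helper name `round_scale_to_power_of_2` is hypothetical, and it uses the standard library instead of `torch` so it runs anywhere. Like `torch.round`, Python's `round` rounds exact halves to the nearest even integer.

```python
import math


def round_scale_to_power_of_2(scale: float) -> float:
    """Round a positive scale to the nearest power of two (nearest in log2 space).

    Hypothetical scalar analogue of exp2(round(log2(scale))) from the diff.
    """
    if not (scale > 0.0 and math.isfinite(scale)):
        raise ValueError("scale must be a positive, finite number")
    # log2 gives the (fractional) exponent; rounding it snaps to an integer
    # exponent, and 2.0 ** exponent rebuilds the scale as an exact power of two.
    return 2.0 ** round(math.log2(scale))


if __name__ == "__main__":
    for s in (0.3, 1.0, 3.0, 5.0):
        print(s, "->", round_scale_to_power_of_2(s))
```

Note that "nearest" here is measured on the exponent, so 3.0 maps to 4.0 rather than 2.0 because log2(3.0) is about 1.585.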
IMO this is a really great way to express this; this switch back is the only spooky part.
Hmm, so the API to use power-of-2 scales for inference would be to use float8_e8m0 as the scale dtype, which is all exponent bits and so can only represent powers of 2, is that right? This is clever, but it requires a step of indirection that may confuse users. IMO it would be better to keep the API consistent with training, where it is just a config option, round_scales_to_powers_of_2.
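To make the "all exponent bits" point concrete, here is a minimal sketch of an E8M0 encode/decode pair in plain Python. It assumes the OCP Microscaling E8M0 layout (8 exponent bits, no sign or mantissa, bias 127, the all-ones byte reserved for NaN); the function names are hypothetical and this is not how `torch` implements the dtype.

```python
import math

E8M0_BIAS = 127   # assumed exponent bias of the E8M0 format
E8M0_NAN = 0xFF   # assumed NaN encoding (all bits set)


def e8m0_decode(byte: int) -> float:
    """Decode one E8M0 byte: every non-NaN value is 2 ** (byte - bias)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("byte must be in [0, 255]")
    if byte == E8M0_NAN:
        return math.nan
    return 2.0 ** (byte - E8M0_BIAS)


def e8m0_encode(value: float) -> int:
    """Encode an exact power of two; anything else is not representable."""
    if not (value > 0.0 and math.isfinite(value)):
        raise ValueError("value must be positive and finite")
    # frexp returns value = mantissa * 2 ** exponent with 0.5 <= mantissa < 1,
    # so an exact power of two always has mantissa == 0.5.
    mantissa, exponent = math.frexp(value)
    if mantissa != 0.5:
        raise ValueError("value is not an exact power of two")
    byte = (exponent - 1) + E8M0_BIAS
    if not 0 <= byte < E8M0_NAN:
        raise ValueError("exponent out of range for E8M0")
    return byte
```

Under these assumptions, a scale such as 3.0 has no E8M0 encoding at all, which is why choosing this dtype for the scale implies power-of-2 scaling.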
Referencing #2182 to link to issue
Stacked PRs:
Fixes: #2182
Add a way to do power of 2 scaling