Conversation
Summary: We are standardizing on `Float8Linear` as the only float8 linear object:

1. The stack ending with #300 moved all of the functionality of `Float8DynamicLinear` to `Float8Linear`. The default settings of `Float8Linear` are to use dynamic scaling.
2. This PR deletes `Float8DynamicLinear` from the codebase and patches the relevant callsites in fbsource.

Test Plan:

```
// all tests pass
./test_everything.sh
// also run all benchmarks and verify correctness
```

ghstack-source-id: 8ab4833
Pull Request resolved: #304
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```diff
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
```
How useful is this benchmark in general?
I haven't used it recently
```diff
-# example: "x:del,w:del,dldy:dyn"
-return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+# example: "x_del_w_del_dldy_dyn"
+return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
```
Why the change, out of curiosity? I think the prior version might be a little more readable.
I should have reverted this. Will follow up in a future PR if that's OK, to make landing this PR easier.
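For context, the two repr formats under discussion can be sketched with a minimal stand-in for the scaling-type enum. The class and member names below mirror the diff, but this is an illustrative mock, not the library's actual code:

```python
from enum import Enum


class TensorScalingType(Enum):
    # Mirrors the enum referenced in the diff; member names are assumptions.
    DELAYED = "delayed"
    DYNAMIC = "dynamic"

    def short_str(self) -> str:
        return "del" if self is TensorScalingType.DELAYED else "dyn"


class ScalingConfig:
    """Illustrative holder for the three per-tensor scaling types."""

    def __init__(self, x, w, dl_dy):
        self.scaling_type_x = x
        self.scaling_type_w = w
        self.scaling_type_dL_dY = dl_dy

    def repr_colon(self) -> str:
        # prior format, e.g. "x:del,w:del,dldy:dyn"
        return (f"x:{self.scaling_type_x.short_str()},"
                f"w:{self.scaling_type_w.short_str()},"
                f"dldy:{self.scaling_type_dL_dY.short_str()}")

    def repr_underscore(self) -> str:
        # new format, e.g. "x_del_w_del_dldy_dyn"
        return (f"x_{self.scaling_type_x.short_str()}_"
                f"w_{self.scaling_type_w.short_str()}_"
                f"dldy_{self.scaling_type_dL_dY.short_str()}")


cfg = ScalingConfig(TensorScalingType.DELAYED,
                    TensorScalingType.DELAYED,
                    TensorScalingType.DYNAMIC)
print(cfg.repr_colon())       # x:del,w:del,dldy:dyn
print(cfg.repr_underscore())  # x_del_w_del_dldy_dyn
```

The colon/comma form keys each tensor visually, while the underscore form is safer in contexts like filenames or experiment names where `:` and `,` are awkward.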
```diff
-m_fp8 = get_float8_linear(
-    linear_type, m_ref, emulate, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+m_fp8 = Float8Linear.from_float(
```
Calling `swap_..` on an `nn.Linear` module returns a model out of place. I think it's fine either way.
I agree, we can make the tests use that if we want in a future PR.
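The out-of-place swap pattern being discussed can be sketched library-free. The tiny classes below are stand-ins for `nn.Module`/`nn.Linear`, and this `Float8Linear.from_float` is a mock of the real constructor, kept only detailed enough to show the recursion:

```python
class Module:
    """Minimal stand-in for nn.Module, just enough to show the swap pattern."""

    def __init__(self):
        self._children = {}

    def add_child(self, name, child):
        self._children[name] = child

    def named_children(self):
        return list(self._children.items())


class Linear(Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features


class Float8Linear(Linear):
    @classmethod
    def from_float(cls, mod):
        # Build a new module from the float one, out of place:
        # the original `mod` is left untouched.
        return cls(mod.in_features, mod.out_features)


def swap_linear_with_float8_linear(module):
    # Recursively replace every plain Linear with a Float8Linear.
    if isinstance(module, Linear) and not isinstance(module, Float8Linear):
        return Float8Linear.from_float(module)
    for name, child in module.named_children():
        module.add_child(name, swap_linear_with_float8_linear(child))
    return module


model = Module()
model.add_child("fc1", Linear(16, 32))
model.add_child("fc2", Linear(32, 8))
model = swap_linear_with_float8_linear(model)
print(type(model._children["fc1"]).__name__)  # Float8Linear
```

The check at the top handles the case where the root module itself is a `Linear`, which is why the function returns the (possibly new) module rather than mutating in place.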
```python
    "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}
# For now, just use Float8Linear with dynamic scaling, which is the
# same behavior as Float8Linear.
```
Float8Dynamic? But it's probably also better to just say this only supports dynamic scaling for all 3 tensors: x, w, dL_dY.
Agreed, let me fix in a future PR to speed up landing this, since this is a minor point.
```diff
     param.grad.div_(dist.get_world_size())
-if module_cls is Float8Linear:
-    sync_float8_amax_and_scale_history(model)
+# TODO(future): add amax syncing once delayed scaling is supported
```
was this just an unused code path?
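For context on the `param.grad.div_(dist.get_world_size())` line in the diff: in this style of manual data parallelism, gradients are all-reduced as a sum across ranks and then divided by the world size to get the mean. A library-free sketch of that reduction (the rank count and gradient values here are made up for illustration):

```python
def allreduce_mean(per_rank_grads):
    """Simulate a sum all-reduce across ranks, then divide by world size,
    mirroring the param.grad.div_(dist.get_world_size()) pattern."""
    world_size = len(per_rank_grads)
    # element-wise sum across ranks (the "all-reduce" step)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    # divide by world size to turn the sum into a mean
    return [s / world_size for s in summed]


# two simulated ranks, each holding gradients for two parameters
grads = allreduce_mean([[1.0, 2.0], [3.0, 4.0]])
print(grads)  # [2.0, 3.0]
```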
```diff
-        return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
-    else:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
```
can we just remove this since this is the default?
agreed in principle, but ideally that would be a separate PR since it's only tangentially related
This pull request has been merged in 8e9623a.
Summary: Addressing a couple of nits that slipped in #304:

* more defaults to dynamic
* undo repr change
* fix comment

Test Plan:

```
./test/test_everything.sh
```

[ghstack-poisoned]
Stack from ghstack (oldest at bottom):

Summary:

We are standardizing on `Float8Linear` as the only float8 linear object:

1. [9/x]: make dynamic scaling default in Float8Linear #300 moved all of the functionality of `Float8DynamicLinear` to `Float8Linear`. The default settings of `Float8Linear` are to use dynamic scaling.
2. This PR deletes `Float8DynamicLinear` from the codebase and patches the relevant callsites in fbsource.

Differential Revision: D59342767