support delayed scaling of weight in float8 all-gather #312
vkuzo wants to merge 3 commits into `gh/vkuzo/27/base`
Conversation
Summary:

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. add `WeightWithDelayedFloat8CastTensor`; note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager mode numerics

Next up (in separate PRs) will be training run validation for numerics, and taking a look at performance.

Test Plan:

```
./test/test_everything.sh
```
ghstack-source-id: f1707c1 Pull Request resolved: #312
what are the optional tensors?
```diff
 all_amax_tensors = torch.cat(
-    fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+    fp8_amax_x_tensor_list
+    + fp8_amax_w_tensor_list
+    + fp8_amax_dL_dY_tensor_list
```
should we only do this if we are using fp8 all-gather?
that could make sense, I'd love to see the data to see if this is going to matter for performance. Focusing on numerics for now; I was hoping for performance to be tackled in future PRs.
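For context, a minimal single-process sketch (a simplification, not the PR's actual code) of why the amax lists are concatenated: flattening every amax buffer into one tensor means the sync needs a single collective instead of one all-reduce per tensor. `all_reduce_max` below is a stand-in for `torch.distributed.all_reduce(..., op=ReduceOp.MAX)`.

```python
import torch

def sync_amax_batched(amax_buffers, all_reduce_max):
    """Concatenate per-tensor amax buffers, run one collective, scatter back.

    `all_reduce_max` stands in for torch.distributed.all_reduce with
    ReduceOp.MAX; here it is any fn mapping a 1-D tensor to the
    element-wise max across ranks.
    """
    flat = torch.cat(amax_buffers)        # one small 1-D tensor
    reduced = all_reduce_max(flat)        # a single collective instead of N
    for buf, val in zip(amax_buffers, torch.split(reduced, 1)):
        buf.copy_(val)                    # write reduced results back in place

# usage: simulate a second rank whose flat amax tensor is [2.0, 2.0]
rank0 = [torch.tensor([1.0]), torch.tensor([3.0])]
fake_allreduce = lambda t: torch.maximum(t, torch.tensor([2.0, 2.0]))
sync_amax_batched(rank0, fake_allreduce)
# rank0 is now [tensor([2.]), tensor([3.])]
```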
```python
        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)

class WeightWithDelayedFloat8CastTensor(torch.Tensor):
```
[no change needed] I wish there was a way to share some more code with the dynamic version
yeah, me too. Looking at the code below, really the only code which could be shared is `fsdp_post_all_gather`; everything else would have to have if/else branches for delayed vs dynamic.
```python
    def __repr__(self):
        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"

    def fsdp_pre_all_gather(self, mesh):
```
confirming that the FSDP part looks good
ghstack-source-id: c83e4df Pull Request resolved: #312
```python
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func == torch.ops.aten.detach.default:
```
mostly just a nit, but any reason to special-case detach here? Alternatively, you could set it up so that every view op automatically propagates subclass-ness in the same way.
If this is something I wrote, I think it was just something I saw in some other subclasses. Having every view op propagate subclass-ness in the same way sounds good to me.
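A hypothetical minimal subclass (names here are illustrative, not from this PR) showing the alternative being discussed: instead of special-casing `aten.detach`, unwrap the inputs, call the op, and re-wrap every tensor output, so detach and every other view op propagate subclass-ness uniformly.

```python
import torch
from torch.utils._pytree import tree_map

class WrapperTensor(torch.Tensor):
    """Toy wrapper subclass; uniformly propagates itself through all ops."""

    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self._inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        # unwrap all WrapperTensor args and run the real op on plain tensors
        unwrap = lambda t: t._inner if isinstance(t, WrapperTensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        # re-wrap every tensor output, so views/detach stay WrapperTensor
        rewrap = lambda t: WrapperTensor(t) if isinstance(t, torch.Tensor) else t
        return tree_map(rewrap, out)

t = WrapperTensor(torch.ones(2, 2))
assert isinstance(t.detach(), WrapperTensor)  # no detach special case needed
assert isinstance(t.view(4), WrapperTensor)   # other view ops propagate too
```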
weifengpy left a comment
stamping for the fsdp part

documenting 2 open questions (not blockers for this PR):

- should we merge `WeightWithDelayedFloat8CastTensor` and `WeightWithDynamicFloat8CastTensor` into one class and add if/else branches to unify the logic around `__torch_dispatch__` and `fsdp_pre_all_gather`/`fsdp_post_all_gather`? we unified `Float8Linear` already
- compare perf between `sync_float8_amax_and_scale_history` and `precompute_float8_dynamic_scale_for_fsdp`. If they are similar, people would not need to worry about numeric problems from delayed scaling
I'm open to it if someone is interested in doing that in a follow-up PR. I'm not sure it will be better than what we have now, though.
yes, that would be great! I think we can do this in follow-up PRs. Note that delayed scaling is theoretically faster than dynamic scaling (fewer memory reads), but performance is not optimized across the stack yet. I think it's good to have options and allow people to optimize different settings in parallel. Eventually, if there is clear data that only one of these is needed, we can delete the one that isn't.
ghstack-source-id: cdc9d96 Pull Request resolved: #312
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

This pull request has been merged in de93990.
```python
    def fsdp_pre_all_gather(self, mesh):
        # initialize if needed
        # TODO(before land): ensure settings are consistent between Float8Linear and here
        ...
            self._amax_buffer,
            self._amax_history_buffer,
            self._scale_buffer,
            "max",  # TODO(before land): read this from parent
```
Stack from ghstack (oldest at bottom):

- -> #312 support delayed scaling of weight in float8 all-gather
- #311 `swap_linear_with_dynamic` from fsdp2 eager test case

Summary:

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. add `WeightWithDelayedFloat8CastTensor`; note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager mode numerics

Next up (in separate PRs) will be training run validation for numerics, and taking a look at performance.
Test Plan:

```
./test/test_everything.sh
```
Differential Revision: D59685258