TrainingWeightWrapperTensor base class; subclasses for FP8/MXFP8 with grouped_mm and linear overrides#3968
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3968
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 6bcfb53 with merge base 4ae435e ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
d0025e8 to
248405c
Compare
248405c to
7fb1c2c
Compare
|
|
||
| class GroupedMMConfig(AOBaseConfig): | ||
| """Base configuration for grouped matrix multiplication. Not intended to be used directly.""" | ||
| class TrainingBaseConfig(AOBaseConfig): |
There was a problem hiding this comment.
the name is very generic, how about TrainingOpBaseConfig to clarify this is for a single op
|
|
||
| @dataclass | ||
| class FP8GroupedMMConfig(GroupedMMConfig): | ||
| class FP8GroupedMMConfig(TrainingBaseConfig): |
There was a problem hiding this comment.
Float8 instead of Fp8, to match PyTorch naming for float8?
| @register_as_pytree_constant | ||
| @dataclass | ||
| class MXFP8GroupedMMConfig(GroupedMMConfig): | ||
| class MXFP8TrainingConfig(TrainingBaseConfig): |
| @classmethod | ||
| def __torch_function__(cls, func, types, args, kwargs={}): | ||
| # grouped_mm op override | ||
| if func.__name__ == cls.grouped_mm_func_name: |
There was a problem hiding this comment.
this is confusing, can this just state the op directly since we are already inside the float8 wrapper?
| @classmethod | ||
| def __torch_function__(cls, func, types, args, kwargs={}): | ||
| # grouped_mm op override | ||
| if func.__name__ == cls.grouped_mm_func_name: |
| ) | ||
|
|
||
| # linear op override | ||
| elif func.__name__ in cls.mm_func_names: |
There was a problem hiding this comment.
just put the ops here? making the code reader jump around to know which ops go here is confusing
|
looks good, I care about cleaning up the |
vkuzo
left a comment
There was a problem hiding this comment.
lg if CI passes and you are sure this does not regress anything
7fb1c2c to
cf26d1d
Compare
e1743e3 to
3ff088b
Compare
|
addressed comments, will land once CI green |
3ff088b to
49ab85c
Compare
5610df4 to
79de41d
Compare
3d258d2 to
89c1a8d
Compare
89c1a8d to
6bcfb53
Compare
…torchao (#2520) ## Summary - Refactor MXFP8 model converter - Previous: 2 separate converters (linear, grouped_mm) - New: 1 unified converter for linear and grouped_mm ops - Details: pytorch/ao#3968 - Add `pad_token_groups_for_grouped_mm` config option to use dynamic per group padding kernels for MXFP8 grouped mm in torchao, so we can delete padding code from torchtitan (context: #2255) - torchao PR stack (must land first): pytorch/ao#4021 ## Tests - TODO: manually test this change prior to landing and update PR
…torchao (pytorch#2520) ## Summary - Refactor MXFP8 model converter - Previous: 2 separate converters (linear, grouped_mm) - New: 1 unified converter for linear and grouped_mm ops - Details: pytorch/ao#3968 - Add `pad_token_groups_for_grouped_mm` config option to use dynamic per group padding kernels for MXFP8 grouped mm in torchao, so we can delete padding code from torchtitan (context: pytorch#2255) - torchao PR stack (must land first): pytorch/ao#4021 ## Tests - TODO: manually test this change prior to landing and update PR
…torchao (pytorch#2520) ## Summary - Refactor MXFP8 model converter - Previous: 2 separate converters (linear, grouped_mm) - New: 1 unified converter for linear and grouped_mm ops - Details: pytorch/ao#3968 - Add `pad_token_groups_for_grouped_mm` config option to use dynamic per group padding kernels for MXFP8 grouped mm in torchao, so we can delete padding code from torchtitan (context: pytorch#2255) - torchao PR stack (must land first): pytorch/ao#4021 ## Tests - TODO: manually test this change prior to landing and update PR
| ScaleCalculationMode.RCEIL, | ||
| ], | ||
| ) | ||
| def test_linear_compile( |
There was a problem hiding this comment.
@danielvegamyhre was this moved over? i can't find it
There was a problem hiding this comment.
Sort of - linear tests were replaced with training test cases here - "shared_experts" FQN is a linear layer, and there is a compile boolean parameterization as well, so it tests mxfp8 linear in eager and with compile.
…torchao (pytorch#2520) ## Summary - Refactor MXFP8 model converter - Previous: 2 separate converters (linear, grouped_mm) - New: 1 unified converter for linear and grouped_mm ops - Details: pytorch/ao#3968 - Add `pad_token_groups_for_grouped_mm` config option to use dynamic per group padding kernels for MXFP8 grouped mm in torchao, so we can delete padding code from torchtitan (context: pytorch#2255) - torchao PR stack (must land first): pytorch/ao#4021 ## Tests - TODO: manually test this change prior to landing and update PR
Tensor subclass changes
TrainingWeightWrapperBaseTensor: base Common logic for FSDP, torch_dispatch, subclass initialization etc is in this base class (not to be used directly)Autograd function changes
_to_mxfp8_then_scaled_mmautograd func to support linear op overrides. Supportswgrad_with_hpas well.Other
MXLinearandMXLinearConfigso we don't have two diverging ways of doing mxfp8 dense training. This also removes MXFP4 training support but nobody is using this as far as we know so not creating tech debt is preferable.Tests
./test/prototype/moe_training/test_everything.shpytest test/prototype/mx_formats/test_mx_linear.pypytest test/prototype/mx_formats/test_mx_tensor.py