TrainingWeightWrapperTensor base class; subclasses for FP8/MXFP8 with grouped_mm and linear overrides #3968

Merged
danielvegamyhre merged 4 commits into main from traintensor
Mar 4, 2026

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Feb 28, 2026

Tensor subclass changes

  • TrainingWeightWrapperBaseTensor: base class holding the common logic for FSDP, __torch_dispatch__, subclass initialization, etc. (not to be used directly)
    • A common base class also enables shared model conversion / param wrapping code
  • The MXFP8 and FP8 tensor subclasses inherit from it and override __torch_function__ with their specific grouped_mm and linear overrides, dispatching to the appropriate autograd functions that wrap our kernels
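
A rough, torch-free sketch of the layering described in the bullets above. Every name here (`WrapperBase`, `FP8Wrapper`, `MXFP8Wrapper`, `handle`, `fallback`, and the string return values) is a hypothetical stand-in, not a torchao API. The real classes subclass `torch.Tensor` and route ops through `__torch_function__` to autograd functions; this only models the split between shared base logic and per-format overrides.

```python
# Torch-free model: shared logic lives in a base class, while each format
# subclass decides which ops it overrides. All names are hypothetical.

def grouped_mm(a, b):
    return f"default grouped_mm({a}, {b})"

def linear(x, w):
    return f"default linear({x}, {w})"

def relu(x):
    return f"default relu({x})"


class WrapperBase:
    """Shared logic for all formats; not meant to be used directly."""

    @classmethod
    def handle(cls, func, args=(), kwargs=None):
        raise NotImplementedError("use a format-specific subclass")

    @staticmethod
    def fallback(func, args, kwargs):
        # Ops without an override run unchanged.
        return func(*args, **kwargs)


class FP8Wrapper(WrapperBase):
    @classmethod
    def handle(cls, func, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func.__name__ == "grouped_mm":
            return "fp8 grouped_mm path"
        if func.__name__ == "linear":
            return "fp8 linear path"
        return cls.fallback(func, args, kwargs)


class MXFP8Wrapper(WrapperBase):
    @classmethod
    def handle(cls, func, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func.__name__ == "grouped_mm":
            return "mxfp8 grouped_mm path"
        if func.__name__ == "linear":
            return "mxfp8 linear path"
        return cls.fallback(func, args, kwargs)


print(FP8Wrapper.handle(linear, ("x", "w")))
print(MXFP8Wrapper.handle(grouped_mm, ("a", "b")))
print(MXFP8Wrapper.handle(relu, ("x",)))  # not overridden, so it falls through
```

Because the subclasses share one base, conversion code can treat them uniformly and only needs to pick which subclass to wrap a parameter in.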

Autograd function changes

  • Add new _to_mxfp8_then_scaled_mm autograd func to support linear op overrides. Supports wgrad_with_hp as well.
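
To illustrate what a flag like wgrad_with_hp could control, here is a toy, torch-free sketch. It assumes the flag means "compute the weight gradient from the high-precision operands instead of the quantized ones", which is a reading of the flag name and is not spelled out in this description. `fake_quant` is a hypothetical stand-in that rounds to a coarse grid; it is not an MXFP8 cast, and none of these helpers exist in torchao.

```python
# Toy "quantize, then matmul" linear op with an optional high-precision
# weight gradient. Matrices are plain nested lists; all helpers are hypothetical.

def fake_quant(m, step=0.5):
    """Round every entry to a coarse grid (stand-in for a low-precision cast)."""
    return [[round(v / step) * step for v in row] for row in m]

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def linear_forward(x, w):
    # x: (M, K), w: (N, K) -> out: (M, N), computed from quantized operands.
    return matmul(fake_quant(x), transpose(fake_quant(w)))

def linear_backward(grad_out, x, w, wgrad_with_hp=False):
    # grad_input: (M, K), always computed from quantized operands here.
    grad_input = matmul(fake_quant(grad_out), fake_quant(w))
    if wgrad_with_hp:
        # Weight gradient from the original, unquantized operands.
        grad_weight = matmul(transpose(grad_out), x)
    else:
        grad_weight = matmul(transpose(fake_quant(grad_out)), fake_quant(x))
    return grad_input, grad_weight  # grad_weight: (N, K)


x = [[0.3, 1.2]]       # one token, two input features
w = [[0.7, -0.4]]      # one output feature
grad_out = [[1.3]]

_, wgrad_lp = linear_backward(grad_out, x, w)
_, wgrad_hp = linear_backward(grad_out, x, w, wgrad_with_hp=True)
print("quantized wgrad:     ", wgrad_lp)
print("high-precision wgrad:", wgrad_hp)
```

The two results differ only through the rounding applied to `grad_out` and `x`, which is the trade-off such a flag would expose: more accurate weight gradients at the cost of a slower high-precision matmul.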

Other

  • Delete MXLinear and MXLinearConfig so we don't have two diverging ways of doing mxfp8 dense training. This also removes MXFP4 training support, but as far as we know nobody is using it, so avoiding the tech debt is preferable.

Tests

  • ./test/prototype/moe_training/test_everything.sh
  • pytest test/prototype/mx_formats/test_mx_linear.py
  • pytest test/prototype/mx_formats/test_mx_tensor.py

@pytorch-bot

pytorch-bot Bot commented Feb 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3968

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 6bcfb53 with merge base 4ae435e (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Feb 28, 2026
@danielvegamyhre danielvegamyhre force-pushed the traintensor branch 3 times, most recently from d0025e8 to 248405c Compare March 2, 2026 17:48
@danielvegamyhre danielvegamyhre added module: training quantize_ api training flow moe labels Mar 2, 2026
@danielvegamyhre danielvegamyhre changed the title from "[WIP] unified tensor subclass for training" to "TorchAOTrainingTensor base class; MXFP8TrainingTensor and FP8TrainingTensor subclasses with grouped_mm and linear overrides" Mar 2, 2026
@danielvegamyhre danielvegamyhre requested a review from vkuzo March 2, 2026 18:11

class GroupedMMConfig(AOBaseConfig):
"""Base configuration for grouped matrix multiplication. Not intended to be used directly."""
class TrainingBaseConfig(AOBaseConfig):

the name is very generic, how about TrainingOpBaseConfig to clarify this is for a single op


@dataclass
class FP8GroupedMMConfig(GroupedMMConfig):
class FP8GroupedMMConfig(TrainingBaseConfig):

Float8 instead of Fp8, to match PyTorch naming for float8?

@register_as_pytree_constant
@dataclass
class MXFP8GroupedMMConfig(GroupedMMConfig):
class MXFP8TrainingConfig(TrainingBaseConfig):

MXFP8OpTrainingConfig?

@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
# grouped_mm op override
if func.__name__ == cls.grouped_mm_func_name:

this is confusing, can this just state the op directly since we are already inside the float8 wrapper?

@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
# grouped_mm op override
if func.__name__ == cls.grouped_mm_func_name:

just say the op directly?

)

# linear op override
elif func.__name__ in cls.mm_func_names:

just put the ops here? making the code reader jump around to know which ops go here is confusing

@vkuzo

vkuzo commented Mar 2, 2026

looks good, I care about cleaning up the func.__name__ == cls.grouped_mm_func_name and elif func.__name__ in cls.mm_func_names the most from my nit comments, thank you!
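
For readers skimming the thread, the two styles under discussion can be contrasted with a small torch-free sketch. The class names, the `route` method, and the op names inside `mm_func_names` are assumptions made for illustration; the actual set of ops matched in the PR is not shown in this conversation.

```python
# Two equivalent ways to match ops by name inside a dispatch handler.
# Everything here is a hypothetical model, not the code from the PR.

def grouped_mm(a, b): ...
def mm(a, b): ...
def linear(x, w): ...
def relu(x): ...


class IndirectNames:
    # The reader has to jump to these attributes to learn which ops match.
    grouped_mm_func_name = "grouped_mm"
    mm_func_names = {"mm", "matmul", "linear"}

    @classmethod
    def route(cls, func):
        if func.__name__ == cls.grouped_mm_func_name:
            return "grouped_mm override"
        elif func.__name__ in cls.mm_func_names:
            return "linear override"
        return "fallthrough"


class DirectNames:
    @classmethod
    def route(cls, func):
        # The matched ops are visible at the point of use.
        if func.__name__ == "grouped_mm":
            return "grouped_mm override"
        elif func.__name__ in ("mm", "matmul", "linear"):
            return "linear override"
        return "fallthrough"


for op in (grouped_mm, mm, linear, relu):
    print(op.__name__, "->", DirectNames.route(op))
```

Both variants route identically; the difference is purely in readability, which is the point of the review comments.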

@vkuzo vkuzo left a comment

lg if CI passes and you are sure this does not regress anything

@danielvegamyhre danielvegamyhre changed the title from "TorchAOTrainingTensor base class; MXFP8TrainingTensor and FP8TrainingTensor subclasses with grouped_mm and linear overrides" to "TrainingWeightWrapperTensor base class; subclasses for FP8/MXFP8 with grouped_mm and linear overrides" Mar 2, 2026
@danielvegamyhre danielvegamyhre force-pushed the traintensor branch 5 times, most recently from e1743e3 to 3ff088b Compare March 2, 2026 21:52
@danielvegamyhre

addressed comments, will land once CI green

@danielvegamyhre danielvegamyhre force-pushed the traintensor branch 3 times, most recently from 5610df4 to 79de41d Compare March 3, 2026 19:00
@danielvegamyhre danielvegamyhre force-pushed the traintensor branch 2 times, most recently from 3d258d2 to 89c1a8d Compare March 3, 2026 22:30
@danielvegamyhre danielvegamyhre merged commit b8708a2 into main Mar 4, 2026
25 of 26 checks passed
tianyu-l pushed a commit to pytorch/torchtitan that referenced this pull request Mar 10, 2026
…torchao (#2520)

## Summary
- Refactor MXFP8 model converter
    - Previous: 2 separate converters (linear, grouped_mm)
    - New: 1 unified converter for linear and grouped_mm ops
    - Details: pytorch/ao#3968
- Add `pad_token_groups_for_grouped_mm` config option to use dynamic per
group padding kernels for MXFP8 grouped mm in torchao, so we can delete
padding code from torchtitan (context:
#2255)
- torchao PR stack (must land first):
pytorch/ao#4021

## Tests
- TODO: manually test this change prior to landing and update PR
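
As a rough illustration of what per-group padding for a grouped mm involves, the sketch below rounds each token group's size up to a fixed alignment and recomputes the group end offsets. The helper names are hypothetical, and the alignment of 32 is an assumption based on the MX block size of 32 elements; the torchao kernels enabled by `pad_token_groups_for_grouped_mm` may pad differently and do this on device.

```python
# Hypothetical sketch of per-group padding bookkeeping (pure Python).

def pad_group_sizes(group_sizes, alignment=32):
    """Round each group's token count up to a multiple of `alignment`."""
    return [-(-n // alignment) * alignment for n in group_sizes]

def end_offsets(sizes):
    """Cumulative end offset of each group along the token dimension."""
    offsets, total = [], 0
    for n in sizes:
        total += n
        offsets.append(total)
    return offsets


sizes = [5, 32, 0, 70]            # tokens routed to each expert
padded = pad_group_sizes(sizes)
print(padded)                     # [32, 32, 0, 96]
print(end_offsets(padded))        # [32, 64, 64, 160]
```

Note that in this sketch an empty group stays empty; whether a real kernel pads empty groups is not stated in the source.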
weifengpy pushed a commit to weifengpy/torchtitan that referenced this pull request Mar 27, 2026
…torchao (pytorch#2520)
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
…torchao (pytorch#2520)
ScaleCalculationMode.RCEIL,
],
)
def test_linear_compile(

@danielvegamyhre was this moved over? i can't find it

Contributor Author

Sort of - linear tests were replaced with training test cases here - "shared_experts" FQN is a linear layer, and there is a compile boolean parameterization as well, so it tests mxfp8 linear in eager and with compile.

ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
…torchao (pytorch#2520)
Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: training quantize_ api training flow moe

2 participants