[mxfp8] refactor model converter; use token group padding kernels in torchao #2520

Merged: tianyu-l merged 1 commit into main from mar6 on Mar 10, 2026

@danielvegamyhre danielvegamyhre commented Mar 7, 2026

Summary

Tests

  • TODO: manually test this change prior to landing and update PR

@meta-cla (bot) added the CLA Signed label on Mar 7, 2026
@danielvegamyhre

Contributor Author

fyi @tianyu-l @rakkit this will unblock deleting the token group padding logic from torchtitan (for everything including mxfp8)

To clarify, the torchao _to_mxfp8_then_scaled_grouped_mm API still expects the tokens to be grouped by expert, rather than grouped by remote/source rank. It just no longer has alignment requirements: I've added kernels to pad inputs and unpad outputs accordingly.

So in Torchtitan, for _permute, the tokens go from:

  • [from rank0 for e0, from rank0 for e1, from rank1 for e0, from rank1 for e1]

To:

  • [from rank0 for e0, from rank1 for e0, from rank0 for e1, from rank1 for e1]
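
As an illustration of the reordering above, here is a minimal pure-Python sketch. It is not torchtitan's `_permute` implementation: the function names (`expert_major_order`, `permute_tokens`) are hypothetical, and plain lists stand in for tensors.

```python
def expert_major_order(num_ranks, num_experts):
    """Chunk indices that turn a rank-major layout into an expert-major one.

    In the rank-major input, chunk r * num_experts + e holds the tokens
    received from rank r for local expert e.
    """
    order = []
    for e in range(num_experts):
        for r in range(num_ranks):
            order.append(r * num_experts + e)
    return order


def permute_tokens(tokens, tokens_per_chunk, num_ranks, num_experts):
    """Regroup tokens by expert and return them with the per-expert counts."""
    # Start offset of every chunk in the rank-major input.
    starts = [0]
    for n in tokens_per_chunk:
        starts.append(starts[-1] + n)

    out = []
    for c in expert_major_order(num_ranks, num_experts):
        out.extend(tokens[starts[c]:starts[c + 1]])

    # Group sizes that a grouped mm would consume (one group per expert).
    tokens_per_expert = [
        sum(tokens_per_chunk[r * num_experts + e] for r in range(num_ranks))
        for e in range(num_experts)
    ]
    return out, tokens_per_expert


# Illustrative data: 2 ranks, 2 local experts, uneven chunk sizes.
tokens = ["r0e0_a", "r0e0_b", "r0e1_a", "r1e0_a", "r1e1_a", "r1e1_b"]
counts = [2, 1, 1, 2]  # [r0->e0, r0->e1, r1->e0, r1->e1]
permuted, per_expert = permute_tokens(tokens, counts, num_ranks=2, num_experts=2)
```

After the permute, all tokens for e0 are contiguous, followed by all tokens for e1, and `per_expert` gives the group sizes without any alignment applied.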

Then when torch._grouped_mm executes and dispatches to torchao, we pad the groups, and unpad the outputs.
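
To make the pad/unpad step concrete, here is a rough pure-Python sketch of the idea. It does not reproduce the torchao kernels: the helper names, the `alignment` parameter, and its value of 4 are assumptions chosen for readability, and each list element stands in for one token row.

```python
def round_up(n, multiple):
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def pad_token_groups(rows, group_sizes, alignment, pad_value=0.0):
    """Pad each contiguous group so its length is a multiple of `alignment`."""
    padded_rows, padded_sizes = [], []
    start = 0
    for size in group_sizes:
        padded_size = round_up(size, alignment)
        padded_rows.extend(rows[start:start + size])
        padded_rows.extend([pad_value] * (padded_size - size))
        padded_sizes.append(padded_size)
        start += size
    return padded_rows, padded_sizes


def unpad_token_groups(padded_rows, group_sizes, padded_sizes):
    """Drop the padding rows and restore the original packed layout."""
    rows = []
    start = 0
    for size, padded_size in zip(group_sizes, padded_sizes):
        rows.extend(padded_rows[start:start + size])
        start += padded_size
    return rows


# Illustrative data: two expert groups with 3 and 2 token rows.
rows = [1.0, 2.0, 3.0, 4.0, 5.0]
group_sizes = [3, 2]
padded, padded_sizes = pad_token_groups(rows, group_sizes, alignment=4)
restored = unpad_token_groups(padded, group_sizes, padded_sizes)
```

In this sketch the caller only ever sees the unpadded layout, which matches the point of the comment above: padding happens inside the grouped mm dispatch rather than in torchtitan.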

@tianyu-l tianyu-l left a comment

sg, one nit comment


filter_fqns: list[str]
mx_config: Any # MXLinearConfig type when imported
class MXFP8Converter(Configurable):

Inherit QuantizationConverter

Contributor Author

Updated

@tianyu-l tianyu-l merged commit 2b976ee into main Mar 10, 2026
27 of 32 checks passed
@tianyu-l tianyu-l deleted the mar6 branch March 10, 2026 19:59
weifengpy pushed a commit to weifengpy/torchtitan that referenced this pull request Mar 27, 2026
…torchao (pytorch#2520)

## Summary
- Refactor MXFP8 model converter
    - Previous: 2 separate converters (linear, grouped_mm)
    - New: 1 unified converter for linear and grouped_mm ops
    - Details: pytorch/ao#3968
- Add `pad_token_groups_for_grouped_mm` config option to use dynamic per
group padding kernels for MXFP8 grouped mm in torchao, so we can delete
padding code from torchtitan (context:
pytorch#2255)
- torchao PR stack (must land first):
pytorch/ao#4021

## Tests
- TODO: manually test this change prior to landing and update PR
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
Labels: ciflow/8gpu, CLA Signed