Skip to content

Only apply grouped GEMM padding for MXFP8 and FP8 non-HybridEP cases#2620

Closed
danielvegamyhre wants to merge 4 commits into
mainfrom
paddingupdate
Closed

Only apply grouped GEMM padding for MXFP8 and FP8 non-HybridEP cases#2620
danielvegamyhre wants to merge 4 commits into
mainfrom
paddingupdate

Conversation

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Mar 18, 2026

Copy link
Copy Markdown
Contributor

Context

  • BF16 grouped GEMM no longer requires padding, we can remove it from the BF16 path and only use it for FP8 and MXFP8 grouped GEMMs
  • TorchTitan will now only only contain a torch native "rank major to expert major" permutation impl for BF16 grouped GEMM, and not any extra per group padding kernels/logic for FP8/MXFP8 (these will live in torchao, as the quantization library it is a better home for them).

Summary

There are 7 cases to handle:

  • Case 1: BF16 + NoEP
    • (do nothing)
  • Case 2: BF16 + EP
    • Torch native impl handles permute from rank major to expert major (no padding)
  • Case 3: MXFP8 + No EP
    • Handled with pad/unpad kernels in torchao
  • Case 4: MXFP8 + Standard EP
    • torchao permute_and_pad() if token_group_alignment_size > 0, in ExpertParallel implementation
  • Case 5: MXFP8 + HybridEP
    • HybridEP handles token group padding for MXFP8 grouped GEMM as part of the all2all dispatch
  • Case 6: FP8 + No EP
    • Same as case 3
  • Case 7: FP8 + EP
    • Same as case 4

Misc changes

  • Delete kernels.py
  • Delete tests for those kernels
  • Remove pad_token_groups_for_grouped_mm option from MXFP8ConverterConfig, since we can set it correctly automatically
  • Added debug models for float8 and mxfp8 to config registry to speed up future development

Tests

FP8 tests were done with fp8 grouped mm only, not fp8 linear. Using both I get this weird tyro error?

[rank0]:│ model-converters.converters.0:config was not a match because:                │
[rank0]:│ • Default value Config(enable_fsdp_float8_all_gather=False,                  │
[rank0]:│   precompute_float8_dynamic_scale_for_fsdp=False, recipe_name=None,          │
[rank0]:│   filter_fqns=['output', 'router.gate'], emulate=False) with type Config     │
[rank0]:│   does not match type <class 'torchtitan.components.quantization.float8.Floa │
[rank0]:│   t8GroupedMMConverter.Config'>         

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 18, 2026
@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

Have not tested yet because of devgpu issues but if you want to take a look feel free @tianyu-l

@danielvegamyhre danielvegamyhre force-pushed the paddingupdate branch 3 times, most recently from f906b3d to f79fff4 Compare March 20, 2026 04:14
@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

I finished testing @tianyu-l this is ready for review

Comment thread torchtitan/models/common/moe/utils.py Outdated
from torchtitan.tools.utils import _round_up

from .kernels import generate_permute_indices
TOKEN_GROUP_ALIGN_SIZE_M = 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should remove this -- setting global variables is error-prone

we should move logic to parallelize functions for various combinations

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, i am working on this, the changes are straightforward for standard EP, but for Hybrid EP it seems like it will require (1) refactoring the custom ops, DispatchState etc to pass around the quantization type used, or (2) just a module level variable storing the quantization type, similar to _buffer. I think (2) is a less invasive change, wdyt?

Comment thread torchtitan/models/common/moe/utils.py Outdated
Comment on lines +61 to +65
def maybe_align_num_tokens_for_mxfp8(num_tokens: int) -> int:
"""Round up token count only when MXFP8 group alignment is active."""
if TOKEN_GROUP_ALIGN_SIZE_M != MXFP8_GROUP_ALIGNMENT_SIZE:
return num_tokens
return _round_up(num_tokens, MXFP8_GROUP_ALIGNMENT_SIZE)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this logic to hybridep.py, including _round_up (as an inline function) which is currently only used once in this repo

# FP8/MXFP8 require groups to be permuted to expert major order AND padded to
# `alignment_size`.
# Otherwise, we only need to permute to expert major order.
if self.token_group_alignment > 0:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the proper way is to create e.g. FP8ExpertParallel and dispatch to it in parallelize function, instead of making if-else in existing ExpertParallel.

Also the condition should be whether quantization is used, not the token_group_alignment size set from somewhere.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That works, did a refactor



# Source: https://github.com/pytorch/torchtitan/pull/2255
def _generate_permute_indices(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you verify that before vs. after, we get bitwise identical results under same seed and determinism?

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]

def indices_padding_wrapper(func: Callable) -> Callable:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this function any more. Please remove this and simplify

# NOTE: If EP is not used, we need to pad the indices
# to prepare for grouped_mm;
# otherwise, EP will handle the padding.
if (
not isinstance(self.w1, DTensor)
# pyrefly: ignore[not-iterable]
or "ep" not in self.w1.device_mesh.mesh_dim_names
):
run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm)
else:
run_experts_fn = _run_experts_grouped_mm
return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert)

@@ -45,10 +45,9 @@ def backward(ctx, grad_output):

def indices_padding_wrapper(func: Callable) -> Callable:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

num_tokens_per_expert_group,
ep_degree,
num_local_experts,
FLOAT8_GROUP_ALIGNMENT_SIZE,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need two different classes? You could just init with different quantization type, which can be used to determine the alignment size, e.g. based on a static dict.

Comment on lines +35 to 36
FLOAT8_GROUP_ALIGNMENT_SIZE = 16
MXFP8_GROUP_ALIGNMENT_SIZE = 32

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a dict from quantization type to alignment size

Comment on lines +31 to +34
if find_float8_grouped_mm_config(model_converters):
return QuantizationType.FLOAT8
elif config := find_mxfp8_config(model_converters):
if routed_experts_in_fqns(config.fqns):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to modularize into multiple small functions which are not used elsewhere -- we can make everything in a single util function for now

from torchtitan.protocols import ModelConverter


class QuantizationType(Enum):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# as part of the EP implementation.
# Otherwise, if EP is not enabled, we need TorchAO to pad the token groups.
self.pad_token_groups_for_grouped_mm = not parallel_dims.ep_enabled
logger.warning(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why it's a warning? sounds like a comment to me, especially when both hybridEP is used this warning would still be there

group: ProcessGroup,
score_before_experts: bool = True,
non_blocking_expert_capacity_factor: float | None = None,
quantization_type: QuantizationType | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hybridep module doesn't need to know the quantization_type. All it needs to know is pad multiple size.

)


class Float8ExpertParallel(BaseExpertParallel):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you inherit ExpertParallel instead of BaseExpertParallel, which can save a lot of code?

@pianpwk

pianpwk commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

Would it also be possible to remove the for-loop padding/unpadding path (for < SM90) in

def _run_experts_for_loop(
and
def _run_experts_for_loop(
, as part of this work?

I think that should resolve issues #2312 and #2741

@tianyu-l

Copy link
Copy Markdown
Contributor

I can take this over if you are busy @danielvegamyhre

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave a try in #2774

# If EP is enabled, TorchTitan handles the token group padding for MXFP8 grouped GEMM
# as part of the EP implementation.
# Otherwise, if EP is not enabled, we need TorchAO to pad the token groups.
self.pad_token_groups_for_grouped_mm = not parallel_dims.ep_enabled

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens for float8, where you don't have this flag?

@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

thank you @tianyu-l yes i have been busy with an urgent NaN loss situation reported by a customer... will take a look!

@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

Closing in favor of #2774, thanks @tianyu-l !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

torch.compile fails with DeepSeekV3 + SimpleFSDP

4 participants