Skip to content

Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation#1683

Merged
fegin merged 9 commits into
mainfrom
chienchin/flex_sac_hack2
Sep 25, 2025
Merged

Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation#1683
fegin merged 9 commits into
mainfrom
chienchin/flex_sac_hack2

Conversation

@fegin

@fegin fegin commented Sep 5, 2025

Copy link
Copy Markdown
Contributor

Flex Attention requires compilation via torch.compile to achieve optimal performance. Therefore, torch.compile is always applied to Flex Attention, regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions:

  1. If compile.enable is set to False, SAC will ignore any torch.compile calls within the SAC region.
  2. If compile.enable is True but the transformer block includes a Mixture of Experts (MoE) module.

To address this limitation, this PR separates the SAC wrapping of Attention from MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. Attention module is wrapped with full AC if compile.enable is False.

FIX (workaround) #1631

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 5, 2025
@fegin fegin changed the title [Don't review yet] SAC + Flex refactoring Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation Sep 5, 2025
@fegin fegin requested a review from soulitzer September 5, 2025 22:39
"torch.compile may be invalidated:\n"
"1. If compile.enable is False, SAC will ignore any torch.compile "
"inside the SAC region.\n"
"2. If compile.enable is True but the transformer block contains a MoE module.\n\n"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh what's wrong with this?

Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, for dense models it could cause regression.

@fegin fegin Sep 8, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh what's wrong with this?

MoE is causing a graph break, which invalidates the entire AC block compilation. The AC block will be run under eager. FlexAttention will not be compiled.

Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, for dense models it could cause regression.

That's a good question. I originally kept SAC(TransformerBlock) for dense modules. But it turns out that the memory usage is no better than just SAC(feedforward) + SAC(attention) or even worse. Not sure why. cc., @soulitzer

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If SAC(f(g(x)))'s policy saves the output of g. SAC(f)(SAC(g(x)) is probably strictly better than SAC(f(g(x))) since in eager, it allows us to clear the rematerialized activations of f before recomputing g.
In this case, the last op of the attention is matmul, so there's a chance we fall into this case.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@soulitzer
I think we should separate forward / backward:

For backward, I roughly get that it may help if we "to clear the rematerialized activations of f before recomputing g". However, IIRC the memory peak is on forward & loss computation, not on backward, so it may not be "strictly better".

For forward, I understand that it's possible that SAC(f(g(x))) and SAC(f)(SAC(g(x)) may result in similar set of activations being saved.

the last op of the attention is matmul, so there's a chance we fall into this case.

Plus I don't think
SAC(TransformerBlock) vs. SAC(feedforward) + SAC(attention) is completely analogous to SAC(f(g(x))) vs. SAC(f)(SAC(g(x)), because in TransformerBlock we also have ffn_norm and attention_norm whose output will be saved in SAC(feedforward) + SAC(attention) but not SAC(TransformerBlock)?

Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is I expect we save more with SAC(feedforward) + SAC(attention)
IIUC the op you mentioned come from Attention.wo. But our policy is "save every other matmul", so technically we should be occasionally saving more?
Very concretely,

Since DSV3 16B has only 1 MLP layer https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/deepseek_v3/__init__.py#L87
so according to the policy here https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L94, what happens could be:

SAC(TransformerBlock)

  • the input to every TransformerBlock
  • for the only MLP layer, recompute MLP.w1, save MLP.w2, recompute MLP.w3
  • for every Attention layer, save wq, recompute wkv_a, save wkv_b, recompute wo

SAC(feedforward) + SAC(attention)

  • the input to every attention (results of attention_norm)
  • for the only MLP layer, recompute MLP.w1, save MLP.w2, recompute MLP.w3
  • for every Attention layer, recompute wq, save wkv_a, recompute wkv_b, save wo
  • the input to every feedforward (results of ffn_norm)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tianyu-l I'm confused

Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is I expect we save more with SAC(feedforward) + SAC(attention)

Do you expect that SAC(feedforward) + SAC(attention) saves more memory or SAC(TransformerBlock) saves more memory?

From the experiment, it is SAC(feedforward) + SAC(attention). But you mentioned you expected a regression in the original comment if I do SAC(feedforward) + SAC(attention).

@soulitzer soulitzer Sep 9, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the data. I do notice that the "wrapped together" cases are significantly more memory.

FlexAttention is compiled because the outer SAC is compiled.

So this is without MoEs, i.e., you set num dense layer to the number of total layers, and you are able to compile with fullgraph=True?

If there are mostly non-Dense layers, then I'd still imagine that the graph break in the MoEs would prevent Flex from being compiled.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may miscommunicate. Let me put a summary of how these experiments were done.

  1. This is a 16B model, with one dense layer. All the configuration changes are applied to that dense block only.

  2. If a TransformerBlock has MoE, aka a sparse block, its Attention module is always wrapped separately. So the graph break from MoE doesn't prevent Flex from being compiled.

  3. The different configurations are applied to the only dense block, the TransformerBlock that has FeedForward but not MoE. There should be no graph breaks in this dense block, so even wrapping the FeedForward with Attention together with SAC should compile Flex correctly.

  4. If you check the last configuration in the experiment, I didn't apply SAC nor AC to the dense block. Its memory usage is very similar to other cases where FeedForward and Attention are wrapped together.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.

But it seems not the case?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, thanks for explaining. I guess I'm pretty confused by these results. Going to think about it more tomorrow.

Some thoughts so far:

  • in compile unless the ACs are adjacent the inputs to AC aren't force saved (not sure how I feel about that tbh).
  • Attention (4 mms) happens before MLP (3 mms), so it would be (recompute, save, recompute ,save) for the mm in the first SAC on the Attention (recompute, save, recompute) on the second SAC on the MLP whether or not there is one or two SAC region! (As a test we can change the SAC policy to save all matmuls for example.) I was thinking about
  • Partitioner itself will do some recompute, so maybe its too surprising that SAC results can be same as no AC at all, e.g. perhaps RMSNorm is fusible and thus gets to be recomputed by default, and it also decides to save mms since they are compute intensive.
  • Quite confusingly, wrap TransformerBlock with full AC isn't using the least amount of memory among all these options.

@fegin fegin Sep 15, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After rebasing PyTorch, I could not reproduce the issue. Now wrapping the entire dense block cause a similar memory usage as wrapping feedfoward and attention separately.

So I change the code to wrapping the entire dense block.

cc., @soulitzer

Comment thread torchtitan/distributed/activation_checkpoint.py Outdated
Comment thread torchtitan/distributed/activation_checkpoint.py
Comment thread torchtitan/distributed/activation_checkpoint.py Outdated
"torch.compile may be invalidated:\n"
"1. If compile.enable is False, SAC will ignore any torch.compile "
"inside the SAC region.\n"
"2. If compile.enable is True but the transformer block contains a MoE module.\n\n"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.

But it seems not the case?

save_list=save_list,
),
)
if model_compile_enabled:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a user perspective, I'd hope compile on/off doesn't simultaneously change other settings like this.
Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?
O/w in extreme cases it's possible that with compile, we are seeing slower throughput (due to more recomputation), but it'll be very unintuitive.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than >SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a >regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?

Yes, this needs more investigation. It's also not clear to me why the result is counterintuitive.

Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?

Yes, this is a valid option and I have tested that option. The result doesn't show any downsides and the logic is simpler and UX is much more clear. But I have to re-verify it again as after #1672, the memory usage is different.

@fegin fegin Sep 15, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, after second thought, I think we should wrap Attention with SAC.

  1. It's more close the semantic: when users try to apply SAC to the whole model. I will view full AC as a limitation due to SAC + torch.compile, which people are trying to lift the limitation. So I think if possible, we still want to use SAC as that the flag users turn on. cc., @soulitzer

  2. It's not good for CP to redo the communication.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not good for CP to redo the communication

How are you going to avoid "redo the communication"? Are you going to save all-gather?
If TP is also used, does it mean we have to save TP all-gather as well. Currently it's not in _save_list https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L33

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CP's communication is going to be a custom op, so we don't need to mess it up with TP's all-gather. But the memory saving can be a concern.

@fegin fegin force-pushed the chienchin/flex_sac_hack2 branch from 38bdc0c to 1456a3f Compare September 24, 2025 00:53
@fegin

fegin commented Sep 24, 2025

Copy link
Copy Markdown
Contributor Author

Logic:

if TransformerBlock has MoE:
     if compile model:
          SAC(MoE)
          SAC(attention)
     else:
          SAC(MoE)
          AC(attention)
else:
     if compile model:
          SAC(TransformerBlock)
     else:
          SAC(FeedForward)
          AC(Attention)

cc., @tianyu-l @soulitzer @wwwjn @xmfan

Comment thread torchtitan/distributed/activation_checkpoint.py Outdated
Comment thread torchtitan/distributed/activation_checkpoint.py
Comment thread torchtitan/experiments/vlm/infra/parallelize.py Outdated
Comment on lines +203 to +204
wrap_submodule("moe", full_ac=False)
if model_compile_enabled:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the reason that under compile, we do compile(SAC(MoE), SAC(Attention)) instead of compile(SAC(TransformerBlock)) the graph break in MoE?

IIUC in this case FlexAttention still runs with compile, but the MoE will be purely eager.

@xmfan @soulitzer
Could you remind me of our plan on graph breaks invalidating AC/SAC?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the reason that under compile, we do compile(SAC(MoE), SAC(Attention)) instead of compile(SAC(TransformerBlock)) the graph break in MoE?

Yes.

@xmfan xmfan Sep 24, 2025

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're working on supporting compile(ac(graph break (https://fburl.com/gdoc/e48ltvvq), which should automatically rewrite the code:

either like:
compile(ac(graph break -> ac(compile(before graph break) + graph break region + compile(after graph break))
or:
compile(ac(graph break -> compile(ac(before graph break region + ac(graph break region + ac(compile(after graph break region
or something similar, the proposal doesn't cover the full details yet.

The rewrite + fixing SAC(compile would remove the need to manually wrap them separately like in this PR

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The high-level plan is that if compile(AC(fn with graph breaks. Flip the ordering so that we have AC(compile(partial_graph. pytorch/pytorch#139989 @xmfan plans to work on this.

For SAC w/ graph breaks, we would need Simon's work and ALSO SAC(compile(fn support, plan here udner "proposal 2 impl details" https://docs.google.com/document/d/1nsO52Q74VXxrwOdGxaNEVEedy-6Jra-FyYjWQxnVVxE/edit?tab=t.0 I will likely be working on this (although someone expressed some interest to work on it as well and waiting to hear back on their decision).

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@fegin fegin merged commit 0943771 into main Sep 25, 2025
8 of 13 checks passed
@fegin fegin deleted the chienchin/flex_sac_hack2 branch December 8, 2025 20:08
@fegin fegin restored the chienchin/flex_sac_hack2 branch December 8, 2025 20:08
@fegin fegin deleted the chienchin/flex_sac_hack2 branch December 8, 2025 20:08
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
…ention Compilation (pytorch#1683)

Flex Attention requires compilation via torch.compile to achieve optimal
performance. Therefore, torch.compile is always applied to Flex
Attention, regardless of the compile.enable flag. However, when
Selective Activation Checkpointing (SAC) is enabled, torch.compile may
be bypassed or invalidated under certain conditions:

1. If compile.enable is set to False, SAC will ignore any torch.compile
calls within the SAC region.
2. If compile.enable is True but the transformer block includes a
Mixture of Experts (MoE) module.

To address this limitation, this PR separates the SAC wrapping of
Attention from MoE and FeedForward modules. This separation ensures that
Flex Attention can be compiled successfully even when SAC is enabled.
Attention module is wrapped with full AC if compile.enable is False.

FIX (workaround) pytorch#1631
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
…ention Compilation (pytorch#1683)

Flex Attention requires compilation via torch.compile to achieve optimal
performance. Therefore, torch.compile is always applied to Flex
Attention, regardless of the compile.enable flag. However, when
Selective Activation Checkpointing (SAC) is enabled, torch.compile may
be bypassed or invalidated under certain conditions:

1. If compile.enable is set to False, SAC will ignore any torch.compile
calls within the SAC region.
2. If compile.enable is True but the transformer block includes a
Mixture of Experts (MoE) module.

To address this limitation, this PR separates the SAC wrapping of
Attention from MoE and FeedForward modules. This separation ensures that
Flex Attention can be compiled successfully even when SAC is enabled.
Attention module is wrapped with full AC if compile.enable is False.

FIX (workaround) pytorch#1631
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot. high priority

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.compile Not Applied to FlexAttention with SAC selective_ac_option=op

4 participants