Separate SAC Wrapping of MoE and Attention Modules to Enable Flex Attention Compilation#1683
Conversation
| "torch.compile may be invalidated:\n" | ||
| "1. If compile.enable is False, SAC will ignore any torch.compile " | ||
| "inside the SAC region.\n" | ||
| "2. If compile.enable is True but the transformer block contains a MoE module.\n\n" |
There was a problem hiding this comment.
oh what's wrong with this?
Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, for dense models it could cause regression.
There was a problem hiding this comment.
oh what's wrong with this?
MoE is causing a graph break, which invalidates the entire AC block compilation. The AC block will be run under eager. FlexAttention will not be compiled.
Also this doesn't sound general -- is it correct that this function will be shared by both dense and sparse models? If so, for dense models it could cause regression.
That's a good question. I originally kept SAC(TransformerBlock) for dense modules. But it turns out that the memory usage is no better than just SAC(feedforward) + SAC(attention) or even worse. Not sure why. cc., @soulitzer
There was a problem hiding this comment.
If SAC(f(g(x)))'s policy saves the output of g. SAC(f)(SAC(g(x)) is probably strictly better than SAC(f(g(x))) since in eager, it allows us to clear the rematerialized activations of f before recomputing g.
In this case, the last op of the attention is matmul, so there's a chance we fall into this case.
There was a problem hiding this comment.
@soulitzer
I think we should separate forward / backward:
For backward, I roughly get that it may help if we "to clear the rematerialized activations of f before recomputing g". However, IIRC the memory peak is on forward & loss computation, not on backward, so it may not be "strictly better".
For forward, I understand that it's possible that SAC(f(g(x))) and SAC(f)(SAC(g(x)) may result in similar set of activations being saved.
the last op of the attention is matmul, so there's a chance we fall into this case.
Plus I don't think
SAC(TransformerBlock) vs. SAC(feedforward) + SAC(attention) is completely analogous to SAC(f(g(x))) vs. SAC(f)(SAC(g(x)), because in TransformerBlock we also have ffn_norm and attention_norm whose output will be saved in SAC(feedforward) + SAC(attention) but not SAC(TransformerBlock)?
Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is I expect we save more with SAC(feedforward) + SAC(attention)
IIUC the op you mentioned come from Attention.wo. But our policy is "save every other matmul", so technically we should be occasionally saving more?
Very concretely,
- MLP has 3 matmul, w1, w2, w3
- Attention (DSV3 16B) has 4 matmul, wq, wkv_a, wkv_b, wo
- MoE has 3 grouped_mm, and 1 matmul from router.gate, but for this one it's not in the "save every other matmul" regime due to https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L90
Since DSV3 16B has only 1 MLP layer https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/deepseek_v3/__init__.py#L87
so according to the policy here https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L94, what happens could be:
SAC(TransformerBlock)
- the input to every
TransformerBlock - for the only MLP layer, recompute
MLP.w1, saveMLP.w2, recomputeMLP.w3 - for every Attention layer, save
wq, recomputewkv_a, savewkv_b, recomputewo
SAC(feedforward) + SAC(attention)
- the input to every attention (results of
attention_norm) - for the only MLP layer, recompute
MLP.w1, saveMLP.w2, recomputeMLP.w3 - for every Attention layer, recompute
wq, savewkv_a, recomputewkv_b, savewo - the input to every feedforward (results of
ffn_norm)
There was a problem hiding this comment.
@tianyu-l I'm confused
Maybe ignore what I typed if it looks too messy lol -- what I wanted to convey is I expect we save more with SAC(feedforward) + SAC(attention)
Do you expect that SAC(feedforward) + SAC(attention) saves more memory or SAC(TransformerBlock) saves more memory?
From the experiment, it is SAC(feedforward) + SAC(attention). But you mentioned you expected a regression in the original comment if I do SAC(feedforward) + SAC(attention).
There was a problem hiding this comment.
Thanks for the data. I do notice that the "wrapped together" cases are significantly more memory.
FlexAttention is compiled because the outer SAC is compiled.
So this is without MoEs, i.e., you set num dense layer to the number of total layers, and you are able to compile with fullgraph=True?
If there are mostly non-Dense layers, then I'd still imagine that the graph break in the MoEs would prevent Flex from being compiled.
There was a problem hiding this comment.
I may miscommunicate. Let me put a summary of how these experiments were done.
-
This is a 16B model, with one dense layer. All the configuration changes are applied to that dense block only.
-
If a TransformerBlock has
MoE, aka a sparse block, its Attention module is always wrapped separately. So the graph break from MoE doesn't prevent Flex from being compiled. -
The different configurations are applied to the only dense block, the TransformerBlock that has
FeedForwardbut notMoE. There should be no graph breaks in this dense block, so even wrapping theFeedForwardwithAttentiontogether with SAC should compile Flex correctly. -
If you check the last configuration in the experiment, I didn't apply SAC nor AC to the dense block. Its memory usage is very similar to other cases where
FeedForwardandAttentionare wrapped together.
There was a problem hiding this comment.
I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?
There was a problem hiding this comment.
Ah I see, thanks for explaining. I guess I'm pretty confused by these results. Going to think about it more tomorrow.
Some thoughts so far:
- in compile unless the ACs are adjacent the inputs to AC aren't force saved (not sure how I feel about that tbh).
- Attention (4 mms) happens before MLP (3 mms), so it would be (recompute, save, recompute ,save) for the mm in the first SAC on the Attention (recompute, save, recompute) on the second SAC on the MLP whether or not there is one or two SAC region! (As a test we can change the SAC policy to save all matmuls for example.) I was thinking about
- Partitioner itself will do some recompute, so maybe its too surprising that SAC results can be same as no AC at all, e.g. perhaps RMSNorm is fusible and thus gets to be recomputed by default, and it also decides to save mms since they are compute intensive.
- Quite confusingly, wrap TransformerBlock with full AC isn't using the least amount of memory among all these options.
There was a problem hiding this comment.
After rebasing PyTorch, I could not reproduce the issue. Now wrapping the entire dense block cause a similar memory usage as wrapping feedfoward and attention separately.
So I change the code to wrapping the entire dense block.
cc., @soulitzer
| "torch.compile may be invalidated:\n" | ||
| "1. If compile.enable is False, SAC will ignore any torch.compile " | ||
| "inside the SAC region.\n" | ||
| "2. If compile.enable is True but the transformer block contains a MoE module.\n\n" |
There was a problem hiding this comment.
I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?
| save_list=save_list, | ||
| ), | ||
| ) | ||
| if model_compile_enabled: |
There was a problem hiding this comment.
From a user perspective, I'd hope compile on/off doesn't simultaneously change other settings like this.
Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?
O/w in extreme cases it's possible that with compile, we are seeing slower throughput (due to more recomputation), but it'll be very unintuitive.
There was a problem hiding this comment.
I think my confusion was:
under compile + dense + FlexAttn, I think SAC(feedforward) + SAC(attention) will occupy more memory than >SAC(TransformerBlock) because technically the former should save more (aka the norm results). If so this will be a >regression (in terms of memory) to dense model under Flex + compile.
But it seems not the case?
Yes, this needs more investigation. It's also not clear to me why the result is counterintuitive.
Do you think we can always do full AC on attention when Flex + SAC is used, eager or compile?
Yes, this is a valid option and I have tested that option. The result doesn't show any downsides and the logic is simpler and UX is much more clear. But I have to re-verify it again as after #1672, the memory usage is different.
There was a problem hiding this comment.
Okay, after second thought, I think we should wrap Attention with SAC.
-
It's more close the semantic: when users try to apply SAC to the whole model. I will view full AC as a limitation due to SAC + torch.compile, which people are trying to lift the limitation. So I think if possible, we still want to use SAC as that the flag users turn on. cc., @soulitzer
-
It's not good for CP to redo the communication.
There was a problem hiding this comment.
It's not good for CP to redo the communication
How are you going to avoid "redo the communication"? Are you going to save all-gather?
If TP is also used, does it mean we have to save TP all-gather as well. Currently it's not in _save_list https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L33
There was a problem hiding this comment.
CP's communication is going to be a custom op, so we don't need to mess it up with TP's all-gather. But the memory saving can be a concern.
8fed32a to
a2570f2
Compare
38bdc0c to
1456a3f
Compare
|
Logic: cc., @tianyu-l @soulitzer @wwwjn @xmfan |
| wrap_submodule("moe", full_ac=False) | ||
| if model_compile_enabled: |
There was a problem hiding this comment.
Is the reason that under compile, we do compile(SAC(MoE), SAC(Attention)) instead of compile(SAC(TransformerBlock)) the graph break in MoE?
IIUC in this case FlexAttention still runs with compile, but the MoE will be purely eager.
@xmfan @soulitzer
Could you remind me of our plan on graph breaks invalidating AC/SAC?
There was a problem hiding this comment.
Is the reason that under compile, we do compile(SAC(MoE), SAC(Attention)) instead of compile(SAC(TransformerBlock)) the graph break in MoE?
Yes.
There was a problem hiding this comment.
We're working on supporting compile(ac(graph break (https://fburl.com/gdoc/e48ltvvq), which should automatically rewrite the code:
either like:
compile(ac(graph break -> ac(compile(before graph break) + graph break region + compile(after graph break))
or:
compile(ac(graph break -> compile(ac(before graph break region + ac(graph break region + ac(compile(after graph break region
or something similar, the proposal doesn't cover the full details yet.
The rewrite + fixing SAC(compile would remove the need to manually wrap them separately like in this PR
There was a problem hiding this comment.
The high-level plan is that if compile(AC(fn with graph breaks. Flip the ordering so that we have AC(compile(partial_graph. pytorch/pytorch#139989 @xmfan plans to work on this.
For SAC w/ graph breaks, we would need Simon's work and ALSO SAC(compile(fn support, plan here udner "proposal 2 impl details" https://docs.google.com/document/d/1nsO52Q74VXxrwOdGxaNEVEedy-6Jra-FyYjWQxnVVxE/edit?tab=t.0 I will likely be working on this (although someone expressed some interest to work on it as well and waiting to hear back on their decision).
…ention Compilation (pytorch#1683) Flex Attention requires compilation via torch.compile to achieve optimal performance. Therefore, torch.compile is always applied to Flex Attention, regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions: 1. If compile.enable is set to False, SAC will ignore any torch.compile calls within the SAC region. 2. If compile.enable is True but the transformer block includes a Mixture of Experts (MoE) module. To address this limitation, this PR separates the SAC wrapping of Attention from MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. Attention module is wrapped with full AC if compile.enable is False. FIX (workaround) pytorch#1631
…ention Compilation (pytorch#1683) Flex Attention requires compilation via torch.compile to achieve optimal performance. Therefore, torch.compile is always applied to Flex Attention, regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions: 1. If compile.enable is set to False, SAC will ignore any torch.compile calls within the SAC region. 2. If compile.enable is True but the transformer block includes a Mixture of Experts (MoE) module. To address this limitation, this PR separates the SAC wrapping of Attention from MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. Attention module is wrapped with full AC if compile.enable is False. FIX (workaround) pytorch#1631
Flex Attention requires compilation via torch.compile to achieve optimal performance. Therefore, torch.compile is always applied to Flex Attention, regardless of the compile.enable flag. However, when Selective Activation Checkpointing (SAC) is enabled, torch.compile may be bypassed or invalidated under certain conditions:
To address this limitation, this PR separates the SAC wrapping of Attention from MoE and FeedForward modules. This separation ensures that Flex Attention can be compiled successfully even when SAC is enabled. Attention module is wrapped with full AC if compile.enable is False.
FIX (workaround) #1631