Skip to content

Enable per-layer compile with or without MoE#2741

Merged
weifengpy merged 1 commit into
pytorch:mainfrom
weifengpy:per-layer-compile-moe
Apr 15, 2026
Merged

Enable per-layer compile with or without MoE#2741
weifengpy merged 1 commit into
pytorch:mainfrom
weifengpy:per-layer-compile-moe

Conversation

@weifengpy

@weifengpy weifengpy commented Mar 27, 2026

Copy link
Copy Markdown
Contributor

per-layer compile becomes possible after applying fully_shard at layer-level (no more moe level): #2281

apply_compile: consolidate apply_compile_dense and apply_compile_sparse into one.
The only difference was capture_scalar_outputs which is harmless for dense models.
Also removed the _run_experts_grouped_mm separate compile boundary and EP wrapper.

FSDP2 + EP: NGPU=8 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel ./run_train.sh --compile.enable --compile.components model,loss --parallelism.expert_parallel_degree 4

FSDP2 + EP + TP: NGPU=8 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel ./run_train.sh --compile.enable --compile.components model,loss --parallelism.expert_parallel_degree 4 --parallelism.tensor_parallel_degree 2

FSDP2 + EP + TP&ETP: NGPU=8 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel ./run_train.sh --compile.enable --compile.components model,loss --parallelism.expert_parallel_degree 4 --parallelism.tensor_parallel_degree 2 --parallelism.expert_tensor_parallel_degree 2

qwen3-vl: NGPU=8 MODULE=qwen3_vl CONFIG=qwen3_vl_debugmodel_moe ./run_train.sh --compile.enable --compile.components model,loss --parallelism.expert_parallel_degree 4 --training.steps 20

H100 works
Screenshot 2026-04-02 at 16 28 08

Screenshot 2026-04-03 at 00 32 26

A100 also works
Screenshot 2026-04-03 at 00 26 14

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 27, 2026
@wwwjn

wwwjn commented Mar 27, 2026

Copy link
Copy Markdown
Contributor

@laithsakka @pianpwk need to follow up here

@pianpwk

pianpwk commented Mar 30, 2026

Copy link
Copy Markdown
Contributor

@laithsakka @pianpwk need to follow up here

for the A100 case,

I'm pretty sure #2399 should solve this (an unbacked symbol is allocated for slice output sizes at dynamo time, but disappears at inductor time due to the size being computable), but we decided against that to remove the padding entirely. It should be doable by extending #2620 to the A100 forloop path.

@laithsakka

Copy link
Copy Markdown

@laithsakka @pianpwk need to follow up here

for the A100 case,

I'm pretty sure #2399 should solve this (an unbacked symbol is allocated for slice output sizes at dynamo time, but disappears at inductor time due to the size being computable), but we decided against that to remove the padding entirely. It should be doable by extending #2620 to the A100 forloop path.

mmm the decision to remove padding was not related to this problem though right? because this is fixable

num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
# dynamic number of tokens in expert parallel
torch._dynamo.mark_dynamic(x, 0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't need to mark_dynamic, we can remove everything in this function except for the per-layer compile call.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call!

per-layer compiler was working util rebasing on top of recent moe refactor

i might need to do some change over that moe refactor. will publish shortly

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rebased. finally ready for review

@weifengpy weifengpy marked this pull request as draft April 2, 2026 23:34
@weifengpy weifengpy force-pushed the per-layer-compile-moe branch from 20fe4dd to 5874b3f Compare April 3, 2026 04:52
@weifengpy weifengpy changed the title gap for per-layer compiler (with moe) Enable per-layer compilation with or without MoE Apr 3, 2026
@weifengpy weifengpy changed the title Enable per-layer compilation with or without MoE Enable per-layer compile with or without MoE Apr 3, 2026


class TestApplyCompile(unittest.TestCase):
def test_patched_once(self):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no more monkeypatch on _run_experts_grouped_mm
apply_compile now just sets dynamo config and calls in-place .compile(), both inherently idempotent

for layer_id, transformer_block in model.layers.named_children():
transformer_block.compile(backend=compile_config.backend, fullgraph=True)
# pyrefly: ignore [missing-attribute]
model.layers.register_module(layer_id, transformer_block)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed register_module because of using in-place .compile

@weifengpy

Copy link
Copy Markdown
Contributor Author

@weifengpy

Copy link
Copy Markdown
Contributor Author

@weifengpy weifengpy force-pushed the per-layer-compile-moe branch from 5874b3f to b409060 Compare April 3, 2026 07:15
Comment thread torchtitan/models/common/moe.py Outdated
# NOTE: this would incur a synchronization between device and host
num_tokens_per_expert_list = num_tokens_per_expert.tolist()

total_tokens = sum(num_tokens_per_expert_list)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pianpwk : this is applying your fix to unblock A100 #2399

do you want to land your PR first? or ok with me including it in this PR?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this was a workaround needed when we apply padding, which was removed in #2774. Do we still need this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problematic line is num_tokens_per_expert_list = num_tokens_per_expert.tolist(). this exists before/after moe refactoring (or padding removal)

cc @pianpwk if you have a better answer

from claude

def _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert):                                            
      # .tolist() crosses device→host, each value becomes an unbacked symint (u0, u1, ...)
      num_tokens_per_expert_list = num_tokens_per_expert.tolist()                                               
                                                                                                                
      # sum of unbacked symints → another unbacked symint (u_total)                                             
      total_tokens = sum(num_tokens_per_expert_list)                                                            
                                                                                                              
      # The compiler sees: x[:u_total]                                                                          
      # It needs to prove: 0 <= u_total <= x.shape[0]
      # But u_total is unbacked — no concrete value or range info                                               
      # → assertion failure without the guards below                                                          
      torch._check(total_tokens >= 0)                                                                           
      torch._check(total_tokens <= x.shape[0])
                                                                                                                
      x_splits = torch.split(x[:total_tokens], ...)                                                             
   
  The torch._check calls inject symbolic constraints so the compiler can prove the slice is valid.  

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@weifengpy what error are you hitting? you likely have to specify capture_scalar_outputs=True or a similar flag as @xmfan mentioned

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And yeah, I thought #2774 should allow you to delete this variable, and not need the x[:total_tokens] slice, since that was meant to remove padding

@weifengpy weifengpy Apr 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

finally make it work with torch._check. there is a gap on how inductor handle torch._check( == ), I have to workaround with torch._check(<=) and torch._check(>=)

will file the inductor issue in pytorch

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.split decompose into split_with_sizes, and trigger torch._check_with( ==

if we do the fix at pytorch-side: pytorch/pytorch#179311 we won't need torch._check at titan side

cc @pianpwk @tianyu-l

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that sounds better, as it solves other use cases as well

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

attaching Pian's fix: pytorch/pytorch#179315 at pytorch side. should be better than mine. anyway, I will remove those torch._check from titan once pytorch side fix landed

converting to draft for now if we want to wait for pytorch side fix

@laithsakka laithsakka Apr 3, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bobrenjc93 i thought bob suppoorted boolean inputs to graphs in inductor?

@weifengpy weifengpy marked this pull request as ready for review April 3, 2026 07:20
@weifengpy weifengpy requested review from pianpwk and tianyu-l April 3, 2026 07:20
@weifengpy weifengpy force-pushed the per-layer-compile-moe branch 6 times, most recently from 7bfb29d to 0da784a Compare April 14, 2026 02:06
@weifengpy weifengpy marked this pull request as ready for review April 14, 2026 02:17
@weifengpy weifengpy marked this pull request as draft April 14, 2026 05:11
@weifengpy weifengpy force-pushed the per-layer-compile-moe branch from 0da784a to 9f38866 Compare April 14, 2026 06:06
# non_blocking=True is safe in eager, but under torch.compile the
# async D2H transfer can race with the subsequent .tolist()/.item()
# calls, producing stale values and failing unbacked-symint guards.
non_blocking = not torch.compiler.is_compiling()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me to land. Thanks!

@weifengpy weifengpy force-pushed the per-layer-compile-moe branch from 9f38866 to 61e685e Compare April 14, 2026 07:47
@weifengpy weifengpy marked this pull request as ready for review April 14, 2026 07:53
@weifengpy weifengpy force-pushed the per-layer-compile-moe branch 10 times, most recently from 7c7e3ca to 136cda6 Compare April 14, 2026 22:06
Consolidate apply_compile_dense and apply_compile_sparse into a single
apply_compile function. The only difference was capture_scalar_outputs
which is harmless for dense models.

Remove the _run_experts_grouped_mm separate compile boundary and EP
wrapper — no longer needed now that the CI uses cu130+ nightly which
handles the unbacked-symint Eq(u1, u2) constraints in inductor.

Remove the x[:total_tokens] slice in _run_experts_for_loop — padding
was removed in pytorch#2774, so sum(num_tokens_per_expert) == x.shape[0] and
the slice is a no-op.
@weifengpy weifengpy force-pushed the per-layer-compile-moe branch from 136cda6 to cbc0ec8 Compare April 14, 2026 22:18
@weifengpy

Copy link
Copy Markdown
Contributor Author

note: if people use pytorch nightly with cu128 on A100, --compile.enable --compile.components model,loss would fail because pytorch/pytorch#179315 didn't get in cu128 torch nightly

@weifengpy weifengpy merged commit 5242bdf into pytorch:main Apr 15, 2026
24 of 34 checks passed
saforem2 added a commit to saforem2/torchtitan that referenced this pull request Apr 15, 2026
… models

The upstream compile consolidation (PR pytorch#2741) unconditionally sets
torch._dynamo.config.capture_scalar_outputs=True in apply_compile.
This is needed for MoE dynamic shapes but breaks the separately-compiled
loss_fn when loss_parallel + ignore_index in cross_entropy produce
unbacked symbols (zuf0, zuf1), causing PendingUnbackedSymbolNotFound.

Reset the flag after apply_compile in the dense agpt parallelize path.
saforem2 added a commit to saforem2/torchtitan that referenced this pull request Apr 15, 2026
18 configs (11 agpt + 7 MoE), 10 steps each on 24 XPU tiles.
8/18 passed. Documents the 80B compile regression from upstream
PR pytorch#2741 and the blendcorpus cache race fix.
saforem2 added a commit to saforem2/torchtitan that referenced this pull request Apr 15, 2026
…crash

Documents the PendingUnbackedSymbolNotFound error introduced by upstream
PR pytorch#2741's unconditional capture_scalar_outputs=True, and the fix
(resetting the flag for dense models after apply_compile).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot. high priority

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants