Skip to content

[BC Breaking] Config System Refactor: TOML to Python Dataclass Registry#2386

Merged
tianyu-l merged 17 commits into
mainfrom
config
Feb 23, 2026
Merged

[BC Breaking] Config System Refactor: TOML to Python Dataclass Registry#2386
tianyu-l merged 17 commits into
mainfrom
config

Conversation

@tianyu-l

@tianyu-l tianyu-l commented Feb 17, 2026

Copy link
Copy Markdown
Contributor

NOTE: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a latest release right before this PR is merged.

author's note

This refactor is mainly trying to address two issues:

  • bad encapsulation: previously a monolithic JobConfig is leaked everywhere
  • not easy to iterate and experiment on model architecture and training components

The main changes are:

  • Strict encapsulation, even at the cost of (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction on cross-components visibility.
  • Each Configurable component owns its own Config, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP.
  • Main entry point switches from TOML files to Python functions (a.k.a. config_registry.py in each model).
    • TOML has the constraint that everything needs to be registered explicitly before it can be used, e.g. our quantization components need to be registered with string names. Python's language level implicit registration is what we believe to be more minimal, and should be fairly easy to extended/modified to support TOML/YAML when users builds upon / fork torchtitan.
    • That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express with TOML/YAML, thus creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is to stay conservative on the usage of such power.
  • We still uses tyro to convert config dataclass to CLI, still with the limitation that users need to construct customized config classes, all the way from root level (Trainer.Config now, JobConfig in the past).
    • If CLI is not needed, new trainer (or any high-level) config is not required.
    • To support "polymorphic construction" from CLI without the hassle, check out chz.

This PR also

Remaining work

Longer-term issues

  • More careful design about what to put config vs. runtime build kwargs. (thanks @ailzhang)
  • ModelSpec not serializable. There may be multiple solutions, but we can potentially consolidate model.py and parallelize.py by
    • sharing AC, compile, DP application across all Decoder models
    • putting per-module TP/CP/EP sharding plan inside model itself
  • Right now BaseModel.update_from_config violates encapsulation by passing the Trainer config into Model config. This could be avoided by python logic either in config construction time, or in trainer.
  • Refactor init_weights into Module.Config instead of staying in Module
    • The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init.
    • This may require refactor of current TransformerBlock and its config. E.g. weight_init_std may need to be put in config, with __post_init__ determining its value. (See related complaints / discussions on __post_init__ by chz)

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logic manner. I apologize for the giant PR.

claude-generated summary

Summary

This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic JobConfig parser with typed Python dataclass configs, a Configurable base class pattern, and a config_registry module per model.

270 files changed, 10,025 insertions, 11,418 deletions.


Motivation

The previous system used TOML files parsed by a custom ConfigManager that layered CLI overrides on top. While simple, this had several friction points:

  1. No type safety at config boundaries. TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., training.stpes) silently becomes a default value.
  2. Flat namespace. All config sections ([model], [training], [optimizer], [checkpoint], ...) lived in a single JobConfig class. Every component received the full JobConfig even when it only needed a few fields.
  3. Experiment extension was ad-hoc. Experiments that needed custom config fields (e.g., SimpleFSDP's compile.graph_passes or FaultTolerant's fault_tolerance.*) required a custom_config_module TOML key and a runtime _merge_configs call to graft new fields onto JobConfig.
  4. Model args were disconnected from model code. A ModelArgs dataclass in args.py defined hyperparameters, but the TrainSpec that bundled model + parallelization + loss was registered separately, with no type-level link between them.

What Changed

1. Configurable Base Class

A new Configurable base class (torchtitan/config/configurable.py) establishes a universal pattern:

class Configurable:
    @dataclass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @dataclass(kw_only=True, slots=True) on every Config

Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling config.build() constructs the owning class.

2. Trainer.Config Replaces JobConfig

The monolithic JobConfig is replaced by Trainer.Config, a nested dataclass that aggregates typed sub-configs:

class Trainer(Stateful, Configurable):
    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None    # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.

Each sub-config is the Config class of the component that consumes it (e.g., CheckpointManager.Config is defined inside CheckpointManager). Components receive only their own config, not the entire training config.

3. config_registry.py Replaces TOML Files

Each model defines a config_registry.py with functions that return complete Trainer.Config instances:

# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config

4. TrainSpec -> ModelSpec

TrainSpec is renamed to ModelSpec with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in Trainer.Config.

5. Model Configs: Flat ModelArgs -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat ModelArgs dataclass into a nested Config hierarchy that mirrors the module tree:

# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config                    # contains RoPE-specific params

6. train.py Split

The monolithic train.py (~800 lines) is split into:

  • train.py (~60 lines): thin entry point that calls ConfigManager.parse_args() and config.build()
  • trainer.py (~850 lines): the Trainer class with training loop logic

7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing instead of runtime config merging:

# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)

Their config_registry.py returns the subclassed config type, and tyro auto-generates CLI parsing for the extended fields.


UX Comparison

Launching Training

# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.py" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh

CLI Overrides

# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)

CLI override syntax is unchanged (--section.field value), but tyro now provides typed --help output generated from the dataclass tree.

Defining a New Model Config

# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml

# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config

Adding Experiment-Specific Config Fields

# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"

Float8 / Quantization Configuration

# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),

Limitations and Trade-offs

1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new config_registry functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (--training.steps 100) works the same, but understanding the full config requires reading Python.

2. Steeper learning curve for contributors

Adding a new model now requires understanding the Configurable protocol, nested Config dataclass hierarchy, and the config_registry pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry.

3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system supports to_dict() + JSON serialization, but configs containing callables (e.g., ModelSpec.parallelize_fn) cannot be fully round-tripped. The model_spec field is excluded from serialization and suppressed from CLI parsing.

4. tyro dependency

The CLI parsing now depends on tyro, a third-party library. While tyro is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here.

5. @dataclass(slots=True) constraints

The Configurable base class enforces @dataclass(kw_only=True, slots=True) on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, slots=True prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the @dataclass(kw_only=True, slots=True) decorator.

6. Two-level indirection for model selection

The old system required one identifier: --job.config_file path/to/file.toml. The new system requires two: --module llama3 --config llama3_8b. While this separates model identity from training recipe, it adds an extra argument.


Numerics Verification

All model configs were verified for numerical equivalence against the main branch (commit 10d8a306):

NOTE

  • only models that can fit on 8 GPUs are tested
  • only subset of parallelism combination are tested
Model Status Notes
llama3 (debugmodel, 8B) Bitwise match
llama3 (debugmodel_flex_attn) Bitwise match
qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) Bitwise match
deepseek_v3 (debugmodel, 16B) Close (max diff 0.00014) Pre-existing main branch bug: missing eps in final RMSNorm
llama4 debugmodel Bitwise match _irope variants don't work on main (FlexAttn 'dict' object has no attribute 'BLOCK_SIZE') but now work after this PR
gpt_oss debugmodel --debug.deterministic causes loss to be NaN; o/w first step loss match, minor difference after (likely caused by flex?)
flux Bitwise match

Migration Guide

Old (main) New (this PR)
CONFIG_FILE="path/to/config.toml" ./run_train.sh MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
--job.config_file path.toml --module llama3 --config llama3_8b
train_configs/*.toml config_registry.py functions
TrainSpec ModelSpec
ModelArgs / args.py Nested Model.Config dataclass
custom_config_module + _merge_configs() Subclass Trainer.Config
build_model_converters() free function ModelConvertersContainer.Config.build()
build_metrics_processor() free function MetricsProcessor.Config.build()

@wconstab

Copy link
Copy Markdown
Contributor

Have you tried asking claude to help split this into a stack of PRs that can be reviewed/landed independently? I did see your comment/apology to reviewers, but i still think honestly nobody is going to review this PR in its entirety so are you asking for an uncareful scan and a stamp, or do you want to break out important pieces of the code that you want careful review on?

@tianyu-l

Copy link
Copy Markdown
Contributor Author

@wconstab While I understand how intimidating it could be for reviewing a huge PR, I would like to initially deliver the package as whole instead of letting people only see incremental changes (if it's possible at all). Maybe I would like to achieve

  • [alignment] most reviewers would pick a single model / config registry and play with it and convince themselves the change looks OK in general. The correctness would be guaranteed by my numerics test and CI (will fix at least the core ones)
  • [more careful check] If reviewers get aligned and would like to help review line-by-line, I'm more than happy to split into a stack of PRs.

Comment thread scripts/generate/test_generate.py Outdated
Comment thread tests/integration_tests/features.py
Comment thread tests/unit_tests/test_dataset_checkpointing.py Outdated
Comment thread tests/unit_tests/test_dataset_checkpointing.py Outdated
Comment thread torchtitan/components/optimizer.py Outdated
model: nn.Module,
parallel_dims: ParallelDims,
job_config: JobConfig,
*,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the rule for differentiating the positional args and kw args?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inspired by https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md#avoid-multiple-positional-arguments. Here I'm moving parallel_dims to kwarg as well.

@acisseJZhong acisseJZhong Feb 19, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

limit the number of positional arguments to <= 1 and use keyword arguments for the rest

what's the reason not making them all kwargs?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know for sure. Likely because for some functions, there would always be a "main" arg that is always there and doesn't introduce ambiguity / error-proneness. E.g. if a function only takes one arg, like parallelize(model), maybe it's fine? You can imagine later on when people adds more and more optional kwargs to the function, the model part doesn't need to be changed.

Comment thread torchtitan/distributed/dual_pipe_v.py Outdated
train_spec: TrainSpec,
def register_torchtitan_model_from_model_spec(
model_spec: ModelSpec,
model_name: str,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_name should be part of model_spec

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model name refers to another thing: model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM"). Maybe need a more descriptive name here, but that could be done in a separate PR

Comment thread torchtitan/models/common/attention.py Outdated
Comment thread torchtitan/trainer.py

@wwwjn wwwjn left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly took a look on components, config, experiments/rl, models, train.py and trainer.py

Comment thread run_train.sh Outdated
Comment thread torchtitan/models/qwen3/config_registry.py
train_spec: TrainSpec,
def register_torchtitan_model_from_model_spec(
model_spec: ModelSpec,
model_name: str,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model name refers to another thing: model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM"). Maybe need a more descriptive name here, but that could be done in a separate PR

Comment thread torchtitan/models/flux/flux_datasets.py
Comment thread torchtitan/models/flux/README.md Outdated
Comment thread torchtitan/config/configs.py Outdated
Comment thread torchtitan/config/configs.py Outdated
Comment thread torchtitan/experiments/transformers_modeling_backend/README.md

@fegin fegin left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overall direction looks good to me, will go over again in more detail. The only uncertainty is what's the best way to handle the shared configurations as mentioning in the review below.

Comment thread torchtitan/trainer.py
Comment on lines 407 to 413
parallel_dims: ParallelDims,
dump_folder: str = "./outputs",
pp_schedule: str = "1F1B",
ft_enable: bool = False,
ft_replica_id: int = 0,
config_dict: dict[str, Any] | None = None,
tag: str | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also am worried about the kwargs being a loophole where people passing configurations around.

@fegin fegin Feb 18, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative approach is to require each component define these shared configurations and resolve the shared configurations when constructing the root configuration (Trainer.Config).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is the top issue I put in "Longer-term Issues" in PR summary, which I couldn't handle entirely in this initial PR.

First, we need to figure out the boundary between "shared config" and "runtime kwargs". We can use more shared config, but that is "utilizing" (a.k.a. "abusing") the python config power in a way that makes it harder to transform to pure yaml solution, which may be OK.

More importantly, we need to reconsider if the current function calling structure makes sense at all. Current metrics logging is limited and hard to customize -- e.g. in MoE how to log the number of tokens each expert processes? In this sense, such problems are reflecting the design flaws we have in torchtitan -- in the past, these are omitted due to the usage of JobConfig everywhere. I think this is one of the good things about this refactor.



register_model_converter(Float8LinearConverter, "quantize.linear.float8")
register_model_converter(Float8GroupedMMConverter, "quantize.grouped_mm.float8")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see this removes the converter names (quantize.grouped_mm.float8 etc) - what does the command line API for this look like now?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We lose CLI capability for adjusting this, because there is no string attached to each converter anymore.

Comment thread torchtitan/components/quantization/float8.py
xmfan added a commit to meta-pytorch/autoparallel that referenced this pull request Feb 24, 2026
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386)
that replaced TOML configs with Python dataclass configs and changed the
CLI from CONFIG_FILE + --model.name to --module + --config.

Updates the CI commands accordingly. Also fixes a runtime crash where
aliased buffers (registered for user-facing API compat by #321) were
being passed to the compiled graph, which only expects the canonical
(deduplicated) set. The deepseek_v3 test is commented out as it's also
disabled in torchtitan's own CI.

Authored with Claude.
xmfan added a commit to meta-pytorch/autoparallel that referenced this pull request Feb 24, 2026
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386)
that replaced TOML configs with Python dataclass configs and changed the
CLI from CONFIG_FILE + --model.name to --module + --config.

Updates the CI commands accordingly. Also fixes a runtime crash where
aliased buffers (registered for user-facing API compat by #321) were
being passed to the compiled graph, which only expects the canonical
(deduplicated) set. The deepseek_v3 test is commented out as it's also
disabled in torchtitan's own CI.

Authored with Claude.

stack-info: PR: #325, branch: xmfan/stack/26
xmfan added a commit to meta-pytorch/autoparallel that referenced this pull request Feb 25, 2026
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386)
that replaced TOML configs with Python dataclass configs and changed the
CLI from CONFIG_FILE + --model.name to --module + --config.

Updates the CI commands accordingly. Also fixes a runtime crash where
aliased buffers (registered for user-facing API compat by #321) were
being passed to the compiled graph, which only expects the canonical
(deduplicated) set. The deepseek_v3 test is commented out as it's also
disabled in torchtitan's own CI.

Authored with Claude.

stack-info: PR: #325, branch: xmfan/stack/26
xmfan added a commit to meta-pytorch/autoparallel that referenced this pull request Feb 25, 2026
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386)
that replaced TOML configs with Python dataclass configs and changed the
CLI from CONFIG_FILE + --model.name to --module + --config.

Updates the CI commands accordingly. Also fixes a runtime crash where
aliased buffers (registered for user-facing API compat by #321) were
being passed to the compiled graph, which only expects the canonical
(deduplicated) set. The deepseek_v3 test is commented out as it's also
disabled in torchtitan's own CI.

Authored with Claude.

stack-info: PR: #325, branch: xmfan/stack/26
tianyu-l pushed a commit that referenced this pull request Feb 25, 2026
* support launching custom trainer;
* init trainer components through .build() (#2386);
* move data to GPU by micro-batch;
* remove rescale_accumulated_loss (#2206).
xmfan added a commit to meta-pytorch/autoparallel that referenced this pull request Feb 25, 2026
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386)
that replaced TOML configs with Python dataclass configs and changed the
CLI from CONFIG_FILE + --model.name to --module + --config.

Updates the CI commands accordingly. Also fixes a runtime crash where
aliased buffers (registered for user-facing API compat by #321) were
being passed to the compiled graph, which only expects the canonical
(deduplicated) set. The deepseek_v3 test is commented out as it's also
disabled in torchtitan's own CI.

Authored with Claude.

stack-info: PR: #325, branch: xmfan/stack/26
wuxibin89 pushed a commit to verl-project/verl that referenced this pull request Mar 3, 2026
### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
guillemgt pushed a commit to guillemgt/verl that referenced this pull request Mar 9, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
guillemgt added a commit to guillemgt/verl that referenced this pull request Mar 9, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.

### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
DearFishi pushed a commit to KunlunxinAD/verl that referenced this pull request Mar 20, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
muchanem added a commit to muchanem/weightshared_torchtitan that referenced this pull request Mar 28, 2026
Brings in pytorch/torchtitan's BC-breaking config refactor (PR pytorch#2386):
- TOML-based config -> Python dataclass registry
- JobConfig -> Trainer.Config with nested component configs
- TrainSpec/ModelProtocol -> ModelSpec/BaseModel/Module protocol
- New common model components (Linear, Embedding, RMSNorm, etc.)

For Qwen3: accepted upstream's new flat structure as base.
Our weight sharing modules (weight_sharing.py, factorized_embedding.py,
LoRA modules) preserved in model/ subdirectory for re-integration.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
sijyang pushed a commit to sijyang/verl that referenced this pull request Apr 1, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
…ry (pytorch#2386)

**NOTE**: This PR is a large refactor of the codebase.
https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a
latest release right before this PR is merged.

# author's note

This refactor is mainly trying to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` is leaked
everywhere
- not easy to iterate and experiment on model architecture and training
components

The main changes are:
- Strict encapsulation, even at the cost of (hopefully temporary)
bloated interface when calling subcomponents (e.g. validator). We should
try to find the right abstraction on cross-components visibility.
- Each `Configurable` component owns its own `Config`, which builds the
owner component. It achieves modularization via polymorphism and
inheritance, both classic concepts in OOP.
- This is partly inspired by repos like
[AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's
[ML API
Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)),
github issues (e.g. pytorch#1055),
and offline discussions (with @Chillee, @ailzhang etc.).
- Similar functionality can be alternatively achieved by other ways,
e.g. `_target_` in
[Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/),
but there are opinions not to couple with Hydra's other offerings. See
pytorch#1415
- Main entry point switches from TOML files to Python functions (a.k.a.
`config_registry.py` in each model).
- TOML has the constraint that everything needs to be registered
explicitly before it can be used, e.g. our quantization components need
to be registered with string names. Python's language level implicit
registration is what we believe to be more minimal, and should be fairly
easy to extended/modified to support TOML/YAML when users builds upon /
fork torchtitan.
- That said, Python config provides much more power, e.g. one can use
arbitrary logic to create (the config of) a component, which is hard to
express with TOML/YAML, thus creating extra difficulty when users want
to migrate to their own favorite config system. The only thing we can do
is to stay conservative on the usage of such power.
- We still uses [tyro](https://github.com/brentyi/tyro) to convert
config dataclass to CLI, still with the limitation that users need to
construct customized config classes, all the way from root level
(`Trainer.Config` now, `JobConfig` in the past).
- If CLI is not needed, new trainer (or any high-level) config is not
required.
- To support "polymorphic construction" from CLI without the hassle,
check out
[chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please
raise issues or help fix
- moves ft to experiments, continuing the effort in
pytorch#2311

Remaining work
- [AutoParallel CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386)
seems caused by the way RoPE is authored, and needs change in
autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386)
`TypeError: forward() missing 1 required positional argument:
'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386)
is the same as pytorch#2312 around
dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in pytorch#2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` to subfolders, as we are having more
contents to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in
pytorch#2357, but somehow keep the
every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in
code (cc @shuhuayu)
- having to set / no validation on rope dim == decoder dim // attention
n_heads
- consolidate `apply_rotary_emb_complex` and
`apply_rotary_emb_single_complex`
  - address pytorch#2417

Longer-term issues
- More careful design about what to put config vs. runtime build kwargs.
(thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we
can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by
passing the Trainer config into Model config. This could be avoided by
python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in
`Module`
- The benefit is that param init can be configurable; o/w we are
coupling module implementation and its weight init.
- This may require refactor of current TransformerBlock and its config.
E.g. `weight_init_std` may need to be put in config, with
`__post_init__` determining its value. (See related complaints /
discussions on `__post_init__` by
[chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle,
you may (or may not) find the stack of 16 commits easier to review, as I
tried to split the changes in some logic manner. I apologize for the
giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure
in 15 incremental, backwards-incompatible commits. The central change
replaces TOML config files and a monolithic `JobConfig` parser with
**typed Python dataclass configs**, a **`Configurable` base class**
pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager`
that layered CLI overrides on top. While simple, this had several
friction points:

1. **No type safety at config boundaries.** TOML values are
strings/ints/floats parsed at runtime. A typo in a key name (e.g.,
`training.stpes`) silently becomes a default value.
4. **Flat namespace.** All config sections (`[model]`, `[training]`,
`[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class.
Every component received the full `JobConfig` even when it only needed a
few fields.
5. **Experiment extension was ad-hoc.** Experiments that needed custom
config fields (e.g., SimpleFSDP's `compile.graph_passes` or
FaultTolerant's `fault_tolerance.*`) required a `custom_config_module`
TOML key and a runtime `_merge_configs` call to graft new fields onto
`JobConfig`.
6. **Model args were disconnected from model code.** A `ModelArgs`
dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that
bundled model + parallelization + loss was registered separately, with
no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`)
establishes a universal pattern:

```python
class Configurable:
    @DataClass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @DataClass(kw_only=True, slots=True) on every Config
```

Every configurable component (Trainer, model, optimizer, tokenizer,
dataloader, checkpoint manager, metrics, validators, quantization
converters, ...) follows this pattern. Calling `config.build()`
constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested
dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @DataClass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None    # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it
(e.g., `CheckpointManager.Config` is defined inside
`CheckpointManager`). Components receive only their own config, not the
entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return
complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds
only model-specific concerns (model config, parallelization function,
loss function, state dict adapter). All training-level concerns
(optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a
nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@DataClass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @DataClass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config                    # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls
`ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop
logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing
instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@DataClass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and
`tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.py" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro`
now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml

# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@DataClass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new
config_registry functions are Python code, which requires understanding
imports, function calls, and dataclass construction. For users who only
need to tweak hyperparameters, the CLI override syntax
(`--training.steps 100`) works the same, but understanding the full
config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable`
protocol, nested `Config` dataclass hierarchy, and the `config_registry`
pattern. The old approach of copying a TOML file and editing values had
a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system
supports `to_dict()` + JSON serialization, but configs containing
callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully
round-tripped. The `model_spec` field is excluded from serialization and
suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While
`tyro` is well-maintained and provides typed CLI generation from
dataclasses, it is an additional dependency that must be kept compatible
with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True,
slots=True)` on all Config classes. While this provides memory
efficiency and prevents accidental attribute assignment, `slots=True`
prevents dynamic attribute addition and makes multiple inheritance with
other slotted classes more constrained. Each Config subclass in a deep
hierarchy must repeat the `@dataclass(kw_only=True, slots=True)`
decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file
path/to/file.toml`. The new system requires two: `--module llama3
--config llama3_8b`. While this separates model identity from training
recipe, it adds an extra argument.

---

## Numerics Verification

All model configs were verified for numerical equivalence against the
main branch (commit `10d8a306`):

NOTE
- only models that can fit on 8 GPUs are tested
- only subset of parallelism combination are tested

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) |
Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | _irope variants don't work on main
(FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work
after this PR |
| **gpt_oss** debugmodel | --debug.deterministic causes loss to be NaN;
o/w first step loss match, minor difference after (likely caused by
flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3
CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass
`Trainer.Config` |
| `build_model_converters()` free function |
`ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function |
`MetricsProcessor.Config.build()` |
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
* support launching custom trainer;
* init trainer components through .build() (pytorch#2386);
* move data to GPU by micro-batch;
* remove rescale_accumulated_loss (pytorch#2206).
hann-wang added a commit to AMD-AGI/torchtitan-amd that referenced this pull request Apr 15, 2026
* support launching custom trainer;
* init trainer components through .build() (pytorch#2386);
* move data to GPU by micro-batch;
* remove rescale_accumulated_loss (pytorch#2206).
tianyu-l pushed a commit that referenced this pull request Apr 18, 2026
…ure (#3012)

Hey there 👋 

Looks like in #2386 a flat
folder structure for each model was introduced.
Now there are no `model` or `infra` sub-folder in each model's folder.
(With one exception - `flux` model. It decided to go rogue.)

But the `torchtitan/models/README.md` hasn't been updated to reflect
these changes.
DaizeDong pushed a commit to DaizeDong/verl that referenced this pull request Apr 19, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
…ry (pytorch#2386)

**NOTE**: This PR is a large refactor of the codebase.
https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a
latest release right before this PR is merged.

# author's note

This refactor is mainly trying to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` is leaked
everywhere
- not easy to iterate and experiment on model architecture and training
components

The main changes are:
- Strict encapsulation, even at the cost of (hopefully temporary)
bloated interface when calling subcomponents (e.g. validator). We should
try to find the right abstraction on cross-components visibility.
- Each `Configurable` component owns its own `Config`, which builds the
owner component. It achieves modularization via polymorphism and
inheritance, both classic concepts in OOP.
- This is partly inspired by repos like
[AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's
[ML API
Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)),
github issues (e.g. pytorch#1055),
and offline discussions (with @Chillee, @ailzhang etc.).
- Similar functionality can be alternatively achieved by other ways,
e.g. `_target_` in
[Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/),
but there are opinions not to couple with Hydra's other offerings. See
pytorch#1415
- Main entry point switches from TOML files to Python functions (a.k.a.
`config_registry.py` in each model).
- TOML has the constraint that everything needs to be registered
explicitly before it can be used, e.g. our quantization components need
to be registered with string names. Python's language level implicit
registration is what we believe to be more minimal, and should be fairly
easy to extended/modified to support TOML/YAML when users builds upon /
fork torchtitan.
- That said, Python config provides much more power, e.g. one can use
arbitrary logic to create (the config of) a component, which is hard to
express with TOML/YAML, thus creating extra difficulty when users want
to migrate to their own favorite config system. The only thing we can do
is to stay conservative on the usage of such power.
- We still uses [tyro](https://github.com/brentyi/tyro) to convert
config dataclass to CLI, still with the limitation that users need to
construct customized config classes, all the way from root level
(`Trainer.Config` now, `JobConfig` in the past).
- If CLI is not needed, new trainer (or any high-level) config is not
required.
- To support "polymorphic construction" from CLI without the hassle,
check out
[chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please
raise issues or help fix
- moves ft to experiments, continuing the effort in
pytorch#2311

Remaining work
- [AutoParallel CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386)
seems caused by the way RoPE is authored, and needs change in
autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386)
`TypeError: forward() missing 1 required positional argument:
'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386)
is the same as pytorch#2312 around
dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in pytorch#2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` to subfolders, as we are having more
contents to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in
pytorch#2357, but somehow keep the
every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in
code (cc @shuhuayu)
- having to set / no validation on rope dim == decoder dim // attention
n_heads
- consolidate `apply_rotary_emb_complex` and
`apply_rotary_emb_single_complex`
  - address pytorch#2417

Longer-term issues
- More careful design about what to put config vs. runtime build kwargs.
(thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we
can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by
passing the Trainer config into Model config. This could be avoided by
python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in
`Module`
- The benefit is that param init can be configurable; o/w we are
coupling module implementation and its weight init.
- This may require refactor of current TransformerBlock and its config.
E.g. `weight_init_std` may need to be put in config, with
`__post_init__` determining its value. (See related complaints /
discussions on `__post_init__` by
[chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle,
you may (or may not) find the stack of 16 commits easier to review, as I
tried to split the changes in some logic manner. I apologize for the
giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure
in 15 incremental, backwards-incompatible commits. The central change
replaces TOML config files and a monolithic `JobConfig` parser with
**typed Python dataclass configs**, a **`Configurable` base class**
pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager`
that layered CLI overrides on top. While simple, this had several
friction points:

1. **No type safety at config boundaries.** TOML values are
strings/ints/floats parsed at runtime. A typo in a key name (e.g.,
`training.stpes`) silently becomes a default value.
4. **Flat namespace.** All config sections (`[model]`, `[training]`,
`[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class.
Every component received the full `JobConfig` even when it only needed a
few fields.
5. **Experiment extension was ad-hoc.** Experiments that needed custom
config fields (e.g., SimpleFSDP's `compile.graph_passes` or
FaultTolerant's `fault_tolerance.*`) required a `custom_config_module`
TOML key and a runtime `_merge_configs` call to graft new fields onto
`JobConfig`.
6. **Model args were disconnected from model code.** A `ModelArgs`
dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that
bundled model + parallelization + loss was registered separately, with
no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`)
establishes a universal pattern:

```python
class Configurable:
    @DataClass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @DataClass(kw_only=True, slots=True) on every Config
```

Every configurable component (Trainer, model, optimizer, tokenizer,
dataloader, checkpoint manager, metrics, validators, quantization
converters, ...) follows this pattern. Calling `config.build()`
constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested
dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @DataClass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None    # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it
(e.g., `CheckpointManager.Config` is defined inside
`CheckpointManager`). Components receive only their own config, not the
entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return
complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds
only model-specific concerns (model config, parallelization function,
loss function, state dict adapter). All training-level concerns
(optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a
nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@DataClass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @DataClass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config                    # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls
`ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop
logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing
instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@DataClass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and
`tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.py" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro`
now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml

# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@DataClass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new
config_registry functions are Python code, which requires understanding
imports, function calls, and dataclass construction. For users who only
need to tweak hyperparameters, the CLI override syntax
(`--training.steps 100`) works the same, but understanding the full
config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable`
protocol, nested `Config` dataclass hierarchy, and the `config_registry`
pattern. The old approach of copying a TOML file and editing values had
a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system
supports `to_dict()` + JSON serialization, but configs containing
callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully
round-tripped. The `model_spec` field is excluded from serialization and
suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While
`tyro` is well-maintained and provides typed CLI generation from
dataclasses, it is an additional dependency that must be kept compatible
with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True,
slots=True)` on all Config classes. While this provides memory
efficiency and prevents accidental attribute assignment, `slots=True`
prevents dynamic attribute addition and makes multiple inheritance with
other slotted classes more constrained. Each Config subclass in a deep
hierarchy must repeat the `@dataclass(kw_only=True, slots=True)`
decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file
path/to/file.toml`. The new system requires two: `--module llama3
--config llama3_8b`. While this separates model identity from training
recipe, it adds an extra argument.

---

## Numerics Verification

All model configs were verified for numerical equivalence against the
main branch (commit `8b006830`):

NOTE
- only models that can fit on 8 GPUs are tested
- only subset of parallelism combination are tested

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) |
Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | _irope variants don't work on main
(FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work
after this PR |
| **gpt_oss** debugmodel | --debug.deterministic causes loss to be NaN;
o/w first step loss match, minor difference after (likely caused by
flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3
CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass
`Trainer.Config` |
| `build_model_converters()` free function |
`ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function |
`MetricsProcessor.Config.build()` |
ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
* support launching custom trainer;
* init trainer components through .build() (pytorch#2386);
* move data to GPU by micro-batch;
* remove rescale_accumulated_loss (pytorch#2206).
zwluestc pushed a commit to zwluestc/verl that referenced this pull request May 12, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
xvlincaigou pushed a commit to xvlincaigou/verl that referenced this pull request May 19, 2026
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in
pytorch/torchtitan#2386.


### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```
loss on par with before


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also
update the reference to the submodule commit via `git submodule update
--remote` or `cd recipe && git pull origin main`.
drizzlezyk pushed a commit to hicann/torchtitan-npu that referenced this pull request Jun 12, 2026
Co-authored-by: zhangwei1177<zhangwei1177@huawei.com>



# message auto-generated for no-merge-commit merge:
!331 merge fix/enable-mtp-cp-on-master into master

[fix]: cp+mtp —— enable MTP context-parallel patch on master torchtitan

Created-by: zhangwei1177
Commit-by: zhangwei1177
Merged-by: cann-robot
Description: ## 描述
  在 DeepSeek-V4 上同时开启上下文并行(context_parallel_degree > 1)和多 token 预测(num_mtp_modules > 0)时,训练会在第一步前向的 `apply_rotary_emb` 处崩溃,报 `RuntimeError: shape '[1, 2047, 1, 32]' is invalid for input of size 65536`——`positions`比主干序列多了 `num_mtp_modules` 个 token,导致 RoPE 的 `view` 形状不匹配。仓库里本就有针对 CP+MTP 的补丁 `mtp_context_parallel.py`,但它实际上完全没有生效,训练走的是 torchtitan 默认、不感知 MTP 的 CP 输入准备路径。
  
  根因有两层:一是 `torchtitan_npu/__init__.py` 从未导入该补丁,它对 `prepare_context_parallel_input` 的 monkey-patch 因此从未安装;二是补丁通过 `Trainer.job_config` 读取 `num_mtp_modules`,而 torchtitan在配置系统重构(pytorch/torchtitan#2386,已包含在本仓库 pin 的版本中)里把该属性改名为 `Trainer.config`,导致探测静默失败、`num_mtp_modules` 被当作 0 而回退到默认路径。

  本 PR 仅改两个文件:在 `__init__.py` 的 `_apply_patches()` 最前面(早于会拉起 `torchtitan.trainer` 的 init_distributed 补丁)导入 `mtp_context_parallel`,确保 monkey-patch 先于 trainer 绑定符号生效;并把 `mtp_context_parallel.py` 中的探测改为读取 `self.config`,同时用 `hasattr(jc.training, "num_mtp_modules")` 作护栏,避免误匹配调用栈里其他对象的同名 `.config`。


## 类型
- [x] Bug 修复
- [ ] 新功能
- [ ] 重构(即不是新增功能,也不是修改bug的代码变动)
- [ ] 构建过程或辅助工具的变动
- [ ] 文档内容更新

## Checklist:
- [x] 我的代码遵循这个项目的代码风格
- [x] 我已经自己测试过我的代码
- [ ] 我已经更新了相应的文档
- [x] 我已经在标题中正确使用了类型标签(例如:`feat`, `fix`, `refactor`, `docs`, `test`)

## 如何测试
在config_registry.py文件中将num_mtp_modules设置为1,context_parallel_degree设置为2.

## 其他信息
A3上deepseek_v4_285b_debug_4_layers模型2卡/cp=1/mtp=1与4卡/cp=2/mtp=1的对比结果:
![cp1_vs_cp2_mtp1.png](https://raw.gitcode.com/user-images/assets/9028822/81c2f425-0106-4a37-9c73-19f70c9df996/cp1_vs_cp2_mtp1.png 'cp1_vs_cp2_mtp1.png')


See merge request: cann/torchtitan-npu!331
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

9 participants