Conversation
Have you tried asking Claude to help split this into a stack of PRs that can be reviewed/landed independently? I did see your comment/apology to reviewers, but honestly I still think nobody is going to review this PR in its entirety. So are you asking for an uncareful scan and a stamp, or do you want to break out the important pieces of the code that you want careful review on?
@wconstab While I understand how intimidating it can be to review a huge PR, I would like to initially deliver the package as a whole instead of letting people see only incremental changes (if that's possible at all). Maybe I would like to achieve
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    *,
What's the rule for differentiating the positional args and the keyword args?
I'm inspired by https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md#avoid-multiple-positional-arguments. Here I'm moving parallel_dims to kwarg as well.
> limit the number of positional arguments to <= 1 and use keyword arguments for the rest

What's the reason for not making them all kwargs?
I don't know for sure. Likely because some functions have a "main" arg that is always there and doesn't introduce ambiguity or error-proneness. E.g. if a function only takes one arg, like parallelize(model), maybe it's fine? You can imagine that later on, when people add more and more optional kwargs to the function, the model part doesn't need to change.
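
As a minimal sketch of the convention discussed in this thread (one "main" positional argument, everything else keyword-only via a bare `*`): the function below is a hypothetical stand-in, and its parameter names are assumptions for illustration, not torchtitan's actual signature.

```python
def parallelize(model, *, parallel_dims=None, compile_enabled=False):
    """One 'main' positional argument; everything after `*` is keyword-only."""
    # Hypothetical body: just echo the arguments back so the call is observable.
    return {
        "model": model,
        "parallel_dims": parallel_dims,
        "compile_enabled": compile_enabled,
    }


# The main argument can be passed positionally without ambiguity.
result = parallelize("toy-model", parallel_dims=(2, 4))

# A second positional argument is rejected at call time, so new optional
# kwargs can be added later without breaking or confusing existing callers.
try:
    parallelize("toy-model", (2, 4))
except TypeError as exc:
    error = str(exc)
```

The bare `*` is what makes accidental argument swaps a `TypeError` instead of a silent bug.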
    train_spec: TrainSpec,
    def register_torchtitan_model_from_model_spec(
    model_spec: ModelSpec,
    model_name: str,
model_name should be part of model_spec
Model name refers to another thing: model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM"). Maybe need a more descriptive name here, but that could be done in a separate PR
wwwjn left a comment
Mainly took a look at components, config, experiments/rl, models, train.py and trainer.py
    parallel_dims: ParallelDims,
    dump_folder: str = "./outputs",
    pp_schedule: str = "1F1B",
    ft_enable: bool = False,
    ft_replica_id: int = 0,
    config_dict: dict[str, Any] | None = None,
    tag: str | None = None,
I am also worried about the kwargs becoming a loophole through which people pass configurations around.
An alternative approach is to require each component to define these shared configurations, and to resolve them when constructing the root configuration (Trainer.Config).
Yeah, this is the top issue I put under "Longer-term issues" in the PR summary, which I couldn't handle entirely in this initial PR.
First, we need to figure out the boundary between "shared config" and "runtime kwargs". We can use more shared config, but that is "utilizing" (a.k.a. "abusing") the power of Python config in a way that makes it harder to transform into a pure YAML solution, which may be OK.
More importantly, we need to reconsider whether the current function calling structure makes sense at all. Current metrics logging is limited and hard to customize -- e.g. in MoE, how do we log the number of tokens each expert processes? In this sense, such problems reflect design flaws we have in torchtitan -- in the past, these were overlooked because JobConfig was used everywhere. I think this is one of the good things about this refactor.
    register_model_converter(Float8LinearConverter, "quantize.linear.float8")
    register_model_converter(Float8GroupedMMConverter, "quantize.grouped_mm.float8")
I see this removes the converter names (quantize.grouped_mm.float8 etc.) - what does the command-line API for this look like now?
We lose CLI capability for adjusting this, because there is no string attached to each converter anymore.
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386) that replaced TOML configs with Python dataclass configs and changed the CLI from CONFIG_FILE + --model.name to --module + --config. Updates the CI commands accordingly. Also fixes a runtime crash where aliased buffers (registered for user-facing API compat by #321) were being passed to the compiled graph, which only expects the canonical (deduplicated) set. The deepseek_v3 test is commented out as it's also disabled in torchtitan's own CI. Authored with Claude.
Torchtitan merged a BC-breaking config system refactor (pytorch/torchtitan#2386) that replaced TOML configs with Python dataclass configs and changed the CLI from CONFIG_FILE + --model.name to --module + --config. Updates the CI commands accordingly. Also fixes a runtime crash where aliased buffers (registered for user-facing API compat by #321) were being passed to the compiled graph, which only expects the canonical (deduplicated) set. The deepseek_v3 test is commented out as it's also disabled in torchtitan's own CI. Authored with Claude. stack-info: PR: #325, branch: xmfan/stack/26
### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in pytorch/torchtitan#2386.

### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```

loss on par with before

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
Brings in pytorch/torchtitan's BC-breaking config refactor (PR pytorch#2386):
- TOML-based config -> Python dataclass registry
- JobConfig -> Trainer.Config with nested component configs
- TrainSpec/ModelProtocol -> ModelSpec/BaseModel/Module protocol
- New common model components (Linear, Embedding, RMSNorm, etc.)

For Qwen3: accepted upstream's new flat structure as base. Our weight sharing modules (weight_sharing.py, factorized_embedding.py, LoRA modules) preserved in model/ subdirectory for re-integration.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ry (pytorch#2386)

**NOTE**: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains the latest release right before this PR was merged.

# author's note

This refactor mainly tries to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` was leaked everywhere
- not easy to iterate and experiment on model architecture and training components

The main changes are:
- Strict encapsulation, even at the cost of a (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction for cross-component visibility.
- Each `Configurable` component owns its own `Config`, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP.
  - This is partly inspired by repos like [AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's [ML API Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)), GitHub issues (e.g. pytorch#1055), and offline discussions (with @Chillee, @ailzhang etc.).
  - Similar functionality can alternatively be achieved in other ways, e.g. `_target_` in [Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/), but there are opinions against coupling with Hydra's other offerings. See pytorch#1415
- Main entry point switches from TOML files to Python functions (a.k.a. `config_registry.py` in each model).
  - TOML has the constraint that everything needs to be registered explicitly before it can be used, e.g. our quantization components need to be registered with string names. Python's language-level implicit registration is what we believe to be more minimal, and it should be fairly easy to extend/modify to support TOML/YAML when users build upon / fork torchtitan.
  - That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express with TOML/YAML, thus creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is to stay conservative on the usage of such power.
- We still use [tyro](https://github.com/brentyi/tyro) to convert config dataclass to CLI, still with the limitation that users need to construct customized config classes, all the way from root level (`Trainer.Config` now, `JobConfig` in the past).
  - If CLI is not needed, new trainer (or any high-level) config is not required.
  - To support "polymorphic construction" from CLI without the hassle, check out [chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please raise issues or help fix
- moves ft to experiments, continuing the effort in pytorch#2311

Remaining work
- [AutoParallel CI failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386) seems caused by the way RoPE is authored, and needs change in autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386) `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386) is the same as pytorch#2312 around dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in pytorch#2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` to subfolders, as we are having more contents to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in pytorch#2357, but somehow keep the every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in code (cc @shuhuayu)
  - having to set / no validation on rope dim == decoder dim // attention n_heads
  - consolidate `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex`
  - address pytorch#2417

Longer-term issues
- More careful design about what to put in config vs. runtime build kwargs. (thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into Model config. This could be avoided by python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in `Module`
  - The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init.
  - This may require refactor of current TransformerBlock and its config. E.g. `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer: Although I believe the changes in this PR come naturally in a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logical manner. I apologize for the giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic `JobConfig` parser with **typed Python dataclass configs**, a **`Configurable` base class** pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager` that layered CLI overrides on top. While simple, this had several friction points:

1. **No type safety at config boundaries.** TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., `training.stpes`) silently becomes a default value.
4. **Flat namespace.** All config sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields.
5. **Experiment extension was ad-hoc.** Experiments that needed custom config fields (e.g., SimpleFSDP's `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`.
6. **Model args were disconnected from model code.** A `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern:

```python
class Configurable:
    @dataclass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @dataclass(kw_only=True, slots=True) on every Config
        ...
```

Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling `config.build()` constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None  # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )


def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config  # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls `ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml
```

```python
# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new config_registry functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable` protocol, nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from training recipe, it adds an extra argument.

---

## Numerics Verification

All model configs were verified for numerical equivalence against the main branch (commit `10d8a306`):

NOTE
- only models that can fit on 8 GPUs are tested
- only a subset of parallelism combinations is tested

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | _irope variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR |
| **gpt_oss** debugmodel | --debug.deterministic causes loss to be NaN; o/w first step loss match, minor difference after (likely caused by flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
* support launching custom trainer;
* init trainer components through .build() (pytorch#2386);
* move data to GPU by micro-batch;
* remove rescale_accumulated_loss (pytorch#2206).
…ure (#3012)

Hey there 👋 Looks like #2386 introduced a flat folder structure for each model. Now there is no `model` or `infra` sub-folder in each model's folder. (With one exception: the `flux` model. It decided to go rogue.) But `torchtitan/models/README.md` hasn't been updated to reflect these changes.
…#5457)

### What does this PR do?

Update the trainer API used in Torchtitan Engine after refactor in pytorch/torchtitan#2386.

### Test

```
GRPC_ENABLE_FORK_SUPPORT=0 NCCL_NVLS_ENABLE=0 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_torchtitan.sh
```

loss on par with before

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
…ry (pytorch#2386)

**NOTE**: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains the latest release from right before this PR was merged.

# author's note

This refactor is mainly trying to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` was leaked everywhere
- not easy to iterate and experiment on model architecture and training components

The main changes are:
- Strict encapsulation, even at the cost of a (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction on cross-component visibility.
- Each `Configurable` component owns its own `Config`, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP.
  - This is partly inspired by repos like [AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's [ML API Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)), github issues (e.g. pytorch#1055), and offline discussions (with @Chillee, @ailzhang etc.).
  - Similar functionality can alternatively be achieved in other ways, e.g. `_target_` in [Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/), but there are opinions not to couple with Hydra's other offerings. See pytorch#1415
- Main entry point switches from TOML files to Python functions (a.k.a. `config_registry.py` in each model).
  - TOML has the constraint that everything needs to be registered explicitly before it can be used, e.g. our quantization components need to be registered with string names. Python's language-level implicit registration is what we believe to be more minimal, and should be fairly easy to extend/modify to support TOML/YAML when users build upon / fork torchtitan.
  - That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express with TOML/YAML, thus creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is to stay conservative on the usage of such power.
- We still use [tyro](https://github.com/brentyi/tyro) to convert config dataclasses to CLI, still with the limitation that users need to construct customized config classes, all the way from root level (`Trainer.Config` now, `JobConfig` in the past).
  - If CLI is not needed, a new trainer (or any high-level) config is not required.
  - To support "polymorphic construction" from CLI without the hassle, check out [chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please raise issues or help fix
- moves ft to experiments, continuing the effort in pytorch#2311

Remaining work
- [AutoParallel CI failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386) seems caused by the way RoPE is authored, and needs a change in autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386) `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386) is the same as pytorch#2312 around dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in pytorch#2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` to subfolders, as we are having more contents to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in pytorch#2357, but somehow keep the every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in code (cc @shuhuayu)
  - having to set / no validation on rope dim == decoder dim // attention n_heads
  - consolidate `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex`
- address pytorch#2417

Longer-term issues
- More careful design about what to put in config vs. runtime build kwargs. (thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into Model config. This could be avoided by python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in `Module`
  - The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init.
  - This may require a refactor of the current TransformerBlock and its config. E.g. `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logical manner. I apologize for the giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic `JobConfig` parser with **typed Python dataclass configs**, a **`Configurable` base class** pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager` that layered CLI overrides on top. While simple, this had several friction points:

1. **No type safety at config boundaries.** TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., `training.stpes`) silently becomes a default value.
2. **Flat namespace.** All config sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields.
3. **Experiment extension was ad-hoc.** Experiments that needed custom config fields (e.g., SimpleFSDP's `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`.
4. **Model args were disconnected from model code.** A `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern:

```python
class Configurable:
    @dataclass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @dataclass(kw_only=True, slots=True) on every Config
        ...
```

Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling `config.build()` constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None  # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py
def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config  # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls `ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml
```

```python
# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new `config_registry` functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable` protocol, nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from training recipe, it adds an extra argument.

---

## Numerics Verification

All model configs were verified for numerical equivalence against the main branch (commit `8b006830`):

NOTE
- only models that can fit on 8 GPUs are tested
- only a subset of parallelism combinations is tested

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | _irope variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR |
| **gpt_oss** debugmodel | `--debug.deterministic` causes loss to be NaN; o/w first step loss match, minor difference after (likely caused by flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
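Co-authored-by: zhangwei1177<zhangwei1177@huawei.com>

# message auto-generated for no-merge-commit merge:
!331 merge fix/enable-mtp-cp-on-master into master

[fix]: cp+mtp: enable MTP context-parallel patch on master torchtitan

Created-by: zhangwei1177
Commit-by: zhangwei1177
Merged-by: cann-robot

Description:

## Description

When context parallelism (`context_parallel_degree > 1`) and multi-token prediction (`num_mtp_modules > 0`) are both enabled on DeepSeek-V4, training crashes in `apply_rotary_emb` during the first forward step with `RuntimeError: shape '[1, 2047, 1, 32]' is invalid for input of size 65536`. `positions` has `num_mtp_modules` more tokens than the backbone sequence, so the shape passed to RoPE's `view` does not match. The repository already contains a patch for CP+MTP, `mtp_context_parallel.py`, but it never actually took effect: training went through torchtitan's default, MTP-unaware CP input preparation path.

The root cause has two layers. First, `torchtitan_npu/__init__.py` never imported the patch, so its monkey-patch of `prepare_context_parallel_input` was never installed. Second, the patch read `num_mtp_modules` through `Trainer.job_config`, but torchtitan's config system refactor (pytorch/torchtitan#2386, already included in the version pinned by this repository) renamed that attribute to `Trainer.config`. The probe therefore failed silently, `num_mtp_modules` was treated as 0, and execution fell back to the default path.

This PR changes only two files. It imports `mtp_context_parallel` at the very top of `_apply_patches()` in `__init__.py` (before the init_distributed patch, which pulls in `torchtitan.trainer`), so that the monkey-patch is installed before the trainer binds the symbol. It also changes the probe in `mtp_context_parallel.py` to read `self.config`, guarded by `hasattr(jc.training, "num_mtp_modules")` to avoid accidentally matching an unrelated `.config` attribute on other objects in the call stack.

## Type
- [x] Bug fix
- [ ] New feature
- [ ] Refactor (a code change that neither adds a feature nor fixes a bug)
- [ ] Changes to the build process or auxiliary tools
- [ ] Documentation update

## Checklist:
- [x] My code follows this project's code style
- [x] I have tested my code myself
- [ ] I have updated the corresponding documentation
- [x] I have used the correct type tag in the title (e.g. `feat`, `fix`, `refactor`, `docs`, `test`)

## How to test
In `config_registry.py`, set `num_mtp_modules` to 1 and `context_parallel_degree` to 2.

## Additional information
Comparison on A3 of the deepseek_v4_285b_debug_4_layers model, 2 cards / cp=1 / mtp=1 versus 4 cards / cp=2 / mtp=1:

See merge request: cann/torchtitan-npu!331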
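The numbers in the reported `RuntimeError` can be checked with plain arithmetic. The sketch below only re-derives the mismatch from the error message; reading the last dimension (32) as the per-position width and the second dimension as the sequence length is an assumption based on the shape in the message, not on the patched code.

```python
import math

# Values copied from the error message in the description.
target_shape = (1, 2047, 1, 32)
input_numel = 65536

# Number of elements the requested view would need.
expected_numel = math.prod(target_shape)

# Assumption: the last dimension is the per-position width, so the input
# actually carries input_numel / 32 positions.
actual_positions = input_numel // target_shape[-1]
extra_positions = actual_positions - target_shape[1]

print(expected_numel, actual_positions, extra_positions)
```

Under that reading, the input holds one more position than the view expects, which is consistent with the description's statement that `positions` is longer than the backbone sequence by `num_mtp_modules` (set to 1 in the test setup).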