[SAC] Centralize selective AC policy and remove per-model op save lists#2357
Merged
Conversation
tianyu-l
added a commit
that referenced
this pull request
Feb 23, 2026
…ry (#2386) **NOTE**: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a latest release right before this PR is merged. # author's note This refactor is mainly trying to address two issues: - bad encapsulation: previously a monolithic `JobConfig` is leaked everywhere - not easy to iterate and experiment on model architecture and training components The main changes are: - Strict encapsulation, even at the cost of (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction on cross-components visibility. - Each `Configurable` component owns its own `Config`, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP. - This is partly inspired by repos like [AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's [ML API Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)), github issues (e.g. #1055), and offline discussions (with @Chillee, @ailzhang etc.). - Similar functionality can be alternatively achieved by other ways, e.g. `_target_` in [Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/), but there are opinions not to couple with Hydra's other offerings. See #1415 - Main entry point switches from TOML files to Python functions (a.k.a. `config_registry.py` in each model). - TOML has the constraint that everything needs to be registered explicitly before it can be used, e.g. our quantization components need to be registered with string names. Python's language level implicit registration is what we believe to be more minimal, and should be fairly easy to extended/modified to support TOML/YAML when users builds upon / fork torchtitan. - That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express with TOML/YAML, thus creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is to stay conservative on the usage of such power. - We still uses [tyro](https://github.com/brentyi/tyro) to convert config dataclass to CLI, still with the limitation that users need to construct customized config classes, all the way from root level (`Trainer.Config` now, `JobConfig` in the past). - If CLI is not needed, new trainer (or any high-level) config is not required. - To support "polymorphic construction" from CLI without the hassle, check out [chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction). This PR also - updates the docs -- there might be remaining outdated docs, please raise issues or help fix - moves ft to experiments, continuing the effort in #2311 Remaining work - [AutoParallel CI failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386) seems caused by the way RoPE is authored, and needs change in autoparallel. (cc @xmfan) - being fixed in meta-pytorch/autoparallel#321 - [CompilerToolkit CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386) `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` cc @yiming0416 please help take a look - [SimpleFSDP CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386) is the same as #2312 around dynamic shape for for-loop MoE experts computation. (cc @pianpwk) - being fixed in #2399 - Fix integration scripts for MAST, Zoomer, etc. - organize docs from `docs/` to subfolders, as we are having more contents to cover in general - generate and store serialized configs (maybe not in the repo) - continue SAC refactor in #2357, but somehow keep the every-other-mm policy (cc @mori360) - refactor RoPE in general, at least resolving the following TODOs in code (cc @shuhuayu) - having to set / no validation on rope dim == decoder dim // attention n_heads - consolidate `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex` - address #2417 Longer-term issues - More careful design about what to put config vs. runtime build kwargs. (thanks @ailzhang) - ModelSpec not serializable. There may be multiple solutions, but we can potentially consolidate `model.py` and `parallelize.py` by - sharing AC, compile, DP application across all Decoder models - putting per-module TP/CP/EP sharding plan inside model itself - Right now `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into Model config. This could be avoided by python logic either in config construction time, or in trainer. - Refactor `init_weights` into `Module.Config` instead of staying in `Module` - The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init. - This may require refactor of current TransformerBlock and its config. E.g. `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md)) Note to reviewer: Although I believe the changes in this PR come naturally in a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logic manner. I apologize for the giant PR. # claude-generated summary ## Summary This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic `JobConfig` parser with **typed Python dataclass configs**, a **`Configurable` base class** pattern, and a **`config_registry`** module per model. **270 files changed, 10,025 insertions, 11,418 deletions.** --- ## Motivation The previous system used TOML files parsed by a custom `ConfigManager` that layered CLI overrides on top. While simple, this had several friction points: 1. **No type safety at config boundaries.** TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., `training.stpes`) silently becomes a default value. 4. **Flat namespace.** All config sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields. 5. **Experiment extension was ad-hoc.** Experiments that needed custom config fields (e.g., SimpleFSDP's `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`. 6. **Model args were disconnected from model code.** A `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them. --- ## What Changed ### 1. `Configurable` Base Class A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern: ```python class Configurable: @DataClass(kw_only=True, slots=True) class Config: def build(self, **kwargs): return self._owner(config=self, **kwargs) def __init_subclass__(cls, **kwargs): # Auto-wires Config.build() -> cls(config=..., **kwargs) # Enforces @DataClass(kw_only=True, slots=True) on every Config ``` Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling `config.build()` constructs the owning class. ### 2. `Trainer.Config` Replaces `JobConfig` The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs: ```python class Trainer(Stateful, Configurable): @DataClass(kw_only=True, slots=True) class Config(Configurable.Config): model_spec: ModelSpec | None = None # set by config_registry, suppressed from CLI job: JobConfig = ... training: TrainingConfig = ... parallelism: ParallelismConfig = ... optimizer: OptimizersContainer.Config = ... lr_scheduler: LRSchedulersContainer.Config = ... checkpoint: CheckpointManager.Config = ... dataloader: BaseDataLoader.Config = ... metrics: MetricsProcessor.Config = ... # ... etc. ``` Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config. ### 3. `config_registry.py` Replaces TOML Files Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances: ```python # torchtitan/models/llama3/config_registry.py def llama3_debugmodel() -> Trainer.Config: return Trainer.Config( job=JobConfig(description="Llama 3 debug training", ...), model_spec=model_registry("debugmodel"), optimizer=OptimizersContainer.Config(lr=8e-4), training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10), dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"), # ... ) def llama3_debugmodel_float8() -> Trainer.Config: config = llama3_debugmodel() config.model_converters = ModelConvertersContainer.Config( converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)] ) return config ``` ### 4. `TrainSpec` -> `ModelSpec` `TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`. ### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree: ```python # Before (main): flat args.py @DataClass class ModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 # ... 20+ flat fields # After (this PR): nested Config in model class class Llama3Model(Decoder): @DataClass(kw_only=True, slots=True) class Config(Decoder.Config): layer: Llama3TransformerBlock.Config # contains attention + FFN configs rope: RoPE.Config # contains RoPE-specific params ``` ### 6. `train.py` Split The monolithic `train.py` (~800 lines) is split into: - `train.py` (~60 lines): thin entry point that calls `ConfigManager.parse_args()` and `config.build()` - `trainer.py` (~850 lines): the `Trainer` class with training loop logic ### 7. Experiment Extension via Inheritance Experiments extend the config system through dataclass subclassing instead of runtime config merging: ```python # torchtitan/experiments/simple_fsdp/configs.py @DataClass(kw_only=True, slots=True) class SimpleFSDPConfig(Trainer.Config): compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig) ``` Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields. --- ## UX Comparison ### Launching Training ```bash # Before (main) CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.py" ./run_train.sh # After (this PR) MODEL=llama3 CONFIG=llama3_8b ./run_train.sh ``` ### CLI Overrides ```bash # Before (main) CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \ --training.steps 100 --parallelism.tensor_parallel_degree 2 # After (this PR) ./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2 # (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh) ``` CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree. ### Defining a New Model Config ```bash # Before: create a new TOML file, copy-paste sections, edit values cp train_configs/debug_model.toml train_configs/my_experiment.toml vim train_configs/my_experiment.toml # After: write a Python function that mutates an existing config def my_experiment() -> Trainer.Config: config = llama3_debugmodel() config.training.steps = 100 config.optimizer.lr = 1e-4 return config ``` ### Adding Experiment-Specific Config Fields ```python # Before (main): custom_config_module in TOML + runtime _merge_configs # Requires: TOML key pointing to a Python module, dynamic dataclass creation # After (this PR): dataclass inheritance @DataClass(kw_only=True, slots=True) class MyExperimentConfig(Trainer.Config): my_custom_field: str = "default" ``` ### Float8 / Quantization Configuration ```python # Before (main): TOML section # [quantize.linear.float8] # enable_fsdp_float8_all_gather = true # precompute_float8_dynamic_scale_for_fsdp = true # After (this PR): typed config object model_converters=ModelConvertersContainer.Config( converters=[ Float8LinearConverter.Config( enable_fsdp_float8_all_gather=True, precompute_float8_dynamic_scale_for_fsdp=True, ), ], ), ``` --- ## Limitations and Trade-offs ### 1. Configs are no longer declarative text files TOML files were readable by anyone without Python knowledge. The new config_registry functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python. ### 2. Steeper learning curve for contributors Adding a new model now requires understanding the `Configurable` protocol, nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry. ### 3. Config serialization is more complex TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing. ### 4. tyro dependency The CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here. ### 5. `@dataclass(slots=True)` constraints The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator. ### 6. Two-level indirection for model selection The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from training recipe, it adds an extra argument. --- ## Numerics Verification All model configs were verified for numerical equivalence against the main branch (commit `10d8a306`): NOTE - only models that can fit on 8 GPUs are tested - only subset of parallelism combination are tested | Model | Status | Notes | |-------|--------|-------| | llama3 (debugmodel, 8B) | Bitwise match | | | llama3 (debugmodel_flex_attn) | Bitwise match | | | qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | | | deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm | | llama4 debugmodel | Bitwise match | _irope variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR | | **gpt_oss** debugmodel | --debug.deterministic causes loss to be NaN; o/w first step loss match, minor difference after (likely caused by flex?) | | | flux | Bitwise match | | --- ## Migration Guide | Old (main) | New (this PR) | |---|---| | `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` | | `--job.config_file path.toml` | `--module llama3 --config llama3_8b` | | `train_configs/*.toml` | `config_registry.py` functions | | `TrainSpec` | `ModelSpec` | | `ModelArgs` / `args.py` | Nested `Model.Config` dataclass | | `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` | | `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` | | `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
132bbc6 to
674852a
Compare
tianyu-l
reviewed
Mar 23, 2026
tianyu-l
approved these changes
Mar 24, 2026
yiming0416
added a commit
that referenced
this pull request
Mar 25, 2026
With #2357 landed, we remove the duplicated `_get_default_sac_save_ops()` from the graph trainer's `passes.py` and replaces it with the shared `_get_save_ops()` from `torchtitan.distributed.activation_checkpoint`
pytorch-bot Bot
pushed a commit
that referenced
this pull request
Mar 27, 2026
…ts (#2357) ### Summary - Remove layer-frequency selective activation checkpointing (`selective_ac_option` and `_layer_sac_count`) — per-op SAC is now the only selective mode - Centralize the op save list into `default_activation_checkpoint_policy()` in `activation_checkpoint.py`, removing duplicated `_op_sac_save_list` sets from per-model `parallelize.py` files (llama3, llama4, deepseek_v3, qwen3, gpt_oss, graph_trainer) - Remove the `op_sac_save_list` parameter from `apply_ac` — models no longer need to pass their own op sets - Build the centralized policy from `get_default_op_list()` (upstream PyTorch) plus explicit compute ops (SDPA, FlexAttention, inductor, varlen_attn) and communication ops (reduce_scatter, all_to_all, deepep, hybridep), with conditional resolution for optional dependencies - Use `@lru_cache` with `cache_hash` on the policy factory for dynamo recompilation avoidance and AOTAutograd cache compatibility - Add `--activation_checkpoint.mode full` to PP integration tests (`InterleavedZeroBubble`, `ZBVZeroBubble`, `PipelineScheduleMulti`) since they relied on layer_sac - Clean deepep imports, now we import from `torchtitan.distirbuted.deepep.deepep` or `torchtitan.distirbuted.deepep.hybridep`, to keep them symmetrical. ### Test Added `test_force_recompute_mm_fqns`: verifies that `per_op_sac_force_recompute_mm_shapes_by_fqns` controls exactly which matmuls are recomputed vs stored during backward. Uses a TorchDispatchMode tracker to count aten.mm calls per weight tensor
weifengpy
pushed a commit
to weifengpy/torchtitan
that referenced
this pull request
Mar 27, 2026
) With pytorch#2357 landed, we remove the duplicated `_get_default_sac_save_ops()` from the graph trainer's `passes.py` and replaces it with the shared `_get_save_ops()` from `torchtitan.distributed.activation_checkpoint`
acisseJZhong
pushed a commit
that referenced
this pull request
Mar 31, 2026
…ts (#2357) ### Summary - Remove layer-frequency selective activation checkpointing (`selective_ac_option` and `_layer_sac_count`) — per-op SAC is now the only selective mode - Centralize the op save list into `default_activation_checkpoint_policy()` in `activation_checkpoint.py`, removing duplicated `_op_sac_save_list` sets from per-model `parallelize.py` files (llama3, llama4, deepseek_v3, qwen3, gpt_oss, graph_trainer) - Remove the `op_sac_save_list` parameter from `apply_ac` — models no longer need to pass their own op sets - Build the centralized policy from `get_default_op_list()` (upstream PyTorch) plus explicit compute ops (SDPA, FlexAttention, inductor, varlen_attn) and communication ops (reduce_scatter, all_to_all, deepep, hybridep), with conditional resolution for optional dependencies - Use `@lru_cache` with `cache_hash` on the policy factory for dynamo recompilation avoidance and AOTAutograd cache compatibility - Add `--activation_checkpoint.mode full` to PP integration tests (`InterleavedZeroBubble`, `ZBVZeroBubble`, `PipelineScheduleMulti`) since they relied on layer_sac - Clean deepep imports, now we import from `torchtitan.distirbuted.deepep.deepep` or `torchtitan.distirbuted.deepep.hybridep`, to keep them symmetrical. ### Test Added `test_force_recompute_mm_fqns`: verifies that `per_op_sac_force_recompute_mm_shapes_by_fqns` controls exactly which matmuls are recomputed vs stored during backward. Uses a TorchDispatchMode tracker to count aten.mm calls per weight tensor
acisseJZhong
pushed a commit
that referenced
this pull request
Mar 31, 2026
With #2357 landed, we remove the duplicated `_get_default_sac_save_ops()` from the graph trainer's `passes.py` and replaces it with the shared `_get_save_ops()` from `torchtitan.distributed.activation_checkpoint`
TXacs
pushed a commit
to McmillanTAC/torchtitan
that referenced
this pull request
Apr 13, 2026
…ry (pytorch#2386) **NOTE**: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a latest release right before this PR is merged. # author's note This refactor is mainly trying to address two issues: - bad encapsulation: previously a monolithic `JobConfig` is leaked everywhere - not easy to iterate and experiment on model architecture and training components The main changes are: - Strict encapsulation, even at the cost of (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction on cross-components visibility. - Each `Configurable` component owns its own `Config`, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP. - This is partly inspired by repos like [AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's [ML API Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)), github issues (e.g. pytorch#1055), and offline discussions (with @Chillee, @ailzhang etc.). - Similar functionality can be alternatively achieved by other ways, e.g. `_target_` in [Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/), but there are opinions not to couple with Hydra's other offerings. See pytorch#1415 - Main entry point switches from TOML files to Python functions (a.k.a. `config_registry.py` in each model). - TOML has the constraint that everything needs to be registered explicitly before it can be used, e.g. our quantization components need to be registered with string names. Python's language level implicit registration is what we believe to be more minimal, and should be fairly easy to extended/modified to support TOML/YAML when users builds upon / fork torchtitan. - That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express with TOML/YAML, thus creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is to stay conservative on the usage of such power. - We still uses [tyro](https://github.com/brentyi/tyro) to convert config dataclass to CLI, still with the limitation that users need to construct customized config classes, all the way from root level (`Trainer.Config` now, `JobConfig` in the past). - If CLI is not needed, new trainer (or any high-level) config is not required. - To support "polymorphic construction" from CLI without the hassle, check out [chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction). This PR also - updates the docs -- there might be remaining outdated docs, please raise issues or help fix - moves ft to experiments, continuing the effort in pytorch#2311 Remaining work - [AutoParallel CI failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386) seems caused by the way RoPE is authored, and needs change in autoparallel. (cc @xmfan) - being fixed in meta-pytorch/autoparallel#321 - [CompilerToolkit CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386) `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` cc @yiming0416 please help take a look - [SimpleFSDP CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386) is the same as pytorch#2312 around dynamic shape for for-loop MoE experts computation. (cc @pianpwk) - being fixed in pytorch#2399 - Fix integration scripts for MAST, Zoomer, etc. - organize docs from `docs/` to subfolders, as we are having more contents to cover in general - generate and store serialized configs (maybe not in the repo) - continue SAC refactor in pytorch#2357, but somehow keep the every-other-mm policy (cc @mori360) - refactor RoPE in general, at least resolving the following TODOs in code (cc @shuhuayu) - having to set / no validation on rope dim == decoder dim // attention n_heads - consolidate `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex` - address pytorch#2417 Longer-term issues - More careful design about what to put config vs. runtime build kwargs. (thanks @ailzhang) - ModelSpec not serializable. There may be multiple solutions, but we can potentially consolidate `model.py` and `parallelize.py` by - sharing AC, compile, DP application across all Decoder models - putting per-module TP/CP/EP sharding plan inside model itself - Right now `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into Model config. This could be avoided by python logic either in config construction time, or in trainer. - Refactor `init_weights` into `Module.Config` instead of staying in `Module` - The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init. - This may require refactor of current TransformerBlock and its config. E.g. `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md)) Note to reviewer: Although I believe the changes in this PR come naturally in a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logic manner. I apologize for the giant PR. # claude-generated summary ## Summary This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic `JobConfig` parser with **typed Python dataclass configs**, a **`Configurable` base class** pattern, and a **`config_registry`** module per model. **270 files changed, 10,025 insertions, 11,418 deletions.** --- ## Motivation The previous system used TOML files parsed by a custom `ConfigManager` that layered CLI overrides on top. While simple, this had several friction points: 1. **No type safety at config boundaries.** TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., `training.stpes`) silently becomes a default value. 4. **Flat namespace.** All config sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields. 5. **Experiment extension was ad-hoc.** Experiments that needed custom config fields (e.g., SimpleFSDP's `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`. 6. **Model args were disconnected from model code.** A `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them. --- ## What Changed ### 1. `Configurable` Base Class A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern: ```python class Configurable: @DataClass(kw_only=True, slots=True) class Config: def build(self, **kwargs): return self._owner(config=self, **kwargs) def __init_subclass__(cls, **kwargs): # Auto-wires Config.build() -> cls(config=..., **kwargs) # Enforces @DataClass(kw_only=True, slots=True) on every Config ``` Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling `config.build()` constructs the owning class. ### 2. `Trainer.Config` Replaces `JobConfig` The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs: ```python class Trainer(Stateful, Configurable): @DataClass(kw_only=True, slots=True) class Config(Configurable.Config): model_spec: ModelSpec | None = None # set by config_registry, suppressed from CLI job: JobConfig = ... training: TrainingConfig = ... parallelism: ParallelismConfig = ... optimizer: OptimizersContainer.Config = ... lr_scheduler: LRSchedulersContainer.Config = ... checkpoint: CheckpointManager.Config = ... dataloader: BaseDataLoader.Config = ... metrics: MetricsProcessor.Config = ... # ... etc. ``` Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config. ### 3. `config_registry.py` Replaces TOML Files Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances: ```python # torchtitan/models/llama3/config_registry.py def llama3_debugmodel() -> Trainer.Config: return Trainer.Config( job=JobConfig(description="Llama 3 debug training", ...), model_spec=model_registry("debugmodel"), optimizer=OptimizersContainer.Config(lr=8e-4), training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10), dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"), # ... ) def llama3_debugmodel_float8() -> Trainer.Config: config = llama3_debugmodel() config.model_converters = ModelConvertersContainer.Config( converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)] ) return config ``` ### 4. `TrainSpec` -> `ModelSpec` `TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`. ### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree: ```python # Before (main): flat args.py @DataClass class ModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 # ... 20+ flat fields # After (this PR): nested Config in model class class Llama3Model(Decoder): @DataClass(kw_only=True, slots=True) class Config(Decoder.Config): layer: Llama3TransformerBlock.Config # contains attention + FFN configs rope: RoPE.Config # contains RoPE-specific params ``` ### 6. `train.py` Split The monolithic `train.py` (~800 lines) is split into: - `train.py` (~60 lines): thin entry point that calls `ConfigManager.parse_args()` and `config.build()` - `trainer.py` (~850 lines): the `Trainer` class with training loop logic ### 7. Experiment Extension via Inheritance Experiments extend the config system through dataclass subclassing instead of runtime config merging: ```python # torchtitan/experiments/simple_fsdp/configs.py @DataClass(kw_only=True, slots=True) class SimpleFSDPConfig(Trainer.Config): compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig) ``` Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields. --- ## UX Comparison ### Launching Training ```bash # Before (main) CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.py" ./run_train.sh # After (this PR) MODEL=llama3 CONFIG=llama3_8b ./run_train.sh ``` ### CLI Overrides ```bash # Before (main) CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \ --training.steps 100 --parallelism.tensor_parallel_degree 2 # After (this PR) ./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2 # (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh) ``` CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree. ### Defining a New Model Config ```bash # Before: create a new TOML file, copy-paste sections, edit values cp train_configs/debug_model.toml train_configs/my_experiment.toml vim train_configs/my_experiment.toml # After: write a Python function that mutates an existing config def my_experiment() -> Trainer.Config: config = llama3_debugmodel() config.training.steps = 100 config.optimizer.lr = 1e-4 return config ``` ### Adding Experiment-Specific Config Fields ```python # Before (main): custom_config_module in TOML + runtime _merge_configs # Requires: TOML key pointing to a Python module, dynamic dataclass creation # After (this PR): dataclass inheritance @DataClass(kw_only=True, slots=True) class MyExperimentConfig(Trainer.Config): my_custom_field: str = "default" ``` ### Float8 / Quantization Configuration ```python # Before (main): TOML section # [quantize.linear.float8] # enable_fsdp_float8_all_gather = true # precompute_float8_dynamic_scale_for_fsdp = true # After (this PR): typed config object model_converters=ModelConvertersContainer.Config( converters=[ Float8LinearConverter.Config( enable_fsdp_float8_all_gather=True, precompute_float8_dynamic_scale_for_fsdp=True, ), ], ), ``` --- ## Limitations and Trade-offs ### 1. Configs are no longer declarative text files TOML files were readable by anyone without Python knowledge. The new config_registry functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python. ### 2. Steeper learning curve for contributors Adding a new model now requires understanding the `Configurable` protocol, nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry. ### 3. Config serialization is more complex TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing. ### 4. tyro dependency The CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here. ### 5. `@dataclass(slots=True)` constraints The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator. ### 6. Two-level indirection for model selection The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from training recipe, it adds an extra argument. --- ## Numerics Verification All model configs were verified for numerical equivalence against the main branch (commit `10d8a306`): NOTE - only models that can fit on 8 GPUs are tested - only subset of parallelism combination are tested | Model | Status | Notes | |-------|--------|-------| | llama3 (debugmodel, 8B) | Bitwise match | | | llama3 (debugmodel_flex_attn) | Bitwise match | | | qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | | | deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm | | llama4 debugmodel | Bitwise match | _irope variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR | | **gpt_oss** debugmodel | --debug.deterministic causes loss to be NaN; o/w first step loss match, minor difference after (likely caused by flex?) | | | flux | Bitwise match | | --- ## Migration Guide | Old (main) | New (this PR) | |---|---| | `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` | | `--job.config_file path.toml` | `--module llama3 --config llama3_8b` | | `train_configs/*.toml` | `config_registry.py` functions | | `TrainSpec` | `ModelSpec` | | `ModelArgs` / `args.py` | Nested `Model.Config` dataclass | | `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` | | `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` | | `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
TXacs
pushed a commit
to McmillanTAC/torchtitan
that referenced
this pull request
Apr 13, 2026
…ts (pytorch#2357) ### Summary - Remove layer-frequency selective activation checkpointing (`selective_ac_option` and `_layer_sac_count`) — per-op SAC is now the only selective mode - Centralize the op save list into `default_activation_checkpoint_policy()` in `activation_checkpoint.py`, removing duplicated `_op_sac_save_list` sets from per-model `parallelize.py` files (llama3, llama4, deepseek_v3, qwen3, gpt_oss, graph_trainer) - Remove the `op_sac_save_list` parameter from `apply_ac` — models no longer need to pass their own op sets - Build the centralized policy from `get_default_op_list()` (upstream PyTorch) plus explicit compute ops (SDPA, FlexAttention, inductor, varlen_attn) and communication ops (reduce_scatter, all_to_all, deepep, hybridep), with conditional resolution for optional dependencies - Use `@lru_cache` with `cache_hash` on the policy factory for dynamo recompilation avoidance and AOTAutograd cache compatibility - Add `--activation_checkpoint.mode full` to PP integration tests (`InterleavedZeroBubble`, `ZBVZeroBubble`, `PipelineScheduleMulti`) since they relied on layer_sac - Clean deepep imports, now we import from `torchtitan.distirbuted.deepep.deepep` or `torchtitan.distirbuted.deepep.hybridep`, to keep them symmetrical. ### Test Added `test_force_recompute_mm_fqns`: verifies that `per_op_sac_force_recompute_mm_shapes_by_fqns` controls exactly which matmuls are recomputed vs stored during backward. Uses a TorchDispatchMode tracker to count aten.mm calls per weight tensor
TXacs
pushed a commit
to McmillanTAC/torchtitan
that referenced
this pull request
Apr 13, 2026
) With pytorch#2357 landed, we remove the duplicated `_get_default_sac_save_ops()` from the graph trainer's `passes.py` and replaces it with the shared `_get_save_ops()` from `torchtitan.distributed.activation_checkpoint`
ACharacterInASimulation
pushed a commit
to ACharacterInASimulation/torchtitan
that referenced
this pull request
Apr 21, 2026
…ry (pytorch#2386) **NOTE**: This PR is a large refactor of the codebase. https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a latest release right before this PR is merged. # author's note This refactor is mainly trying to address two issues: - bad encapsulation: previously a monolithic `JobConfig` is leaked everywhere - not easy to iterate and experiment on model architecture and training components The main changes are: - Strict encapsulation, even at the cost of (hopefully temporary) bloated interface when calling subcomponents (e.g. validator). We should try to find the right abstraction on cross-components visibility. - Each `Configurable` component owns its own `Config`, which builds the owner component. It achieves modularization via polymorphism and inheritance, both classic concepts in OOP. - This is partly inspired by repos like [AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's [ML API Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)), github issues (e.g. pytorch#1055), and offline discussions (with @Chillee, @ailzhang etc.). - Similar functionality can be alternatively achieved by other ways, e.g. `_target_` in [Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/), but there are opinions not to couple with Hydra's other offerings. See pytorch#1415 - Main entry point switches from TOML files to Python functions (a.k.a. `config_registry.py` in each model). - TOML has the constraint that everything needs to be registered explicitly before it can be used, e.g. our quantization components need to be registered with string names. Python's language level implicit registration is what we believe to be more minimal, and should be fairly easy to extended/modified to support TOML/YAML when users builds upon / fork torchtitan. - That said, Python config provides much more power, e.g. one can use arbitrary logic to create (the config of) a component, which is hard to express with TOML/YAML, thus creating extra difficulty when users want to migrate to their own favorite config system. The only thing we can do is to stay conservative on the usage of such power. - We still uses [tyro](https://github.com/brentyi/tyro) to convert config dataclass to CLI, still with the limitation that users need to construct customized config classes, all the way from root level (`Trainer.Config` now, `JobConfig` in the past). - If CLI is not needed, new trainer (or any high-level) config is not required. - To support "polymorphic construction" from CLI without the hassle, check out [chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction). This PR also - updates the docs -- there might be remaining outdated docs, please raise issues or help fix - moves ft to experiments, continuing the effort in pytorch#2311 Remaining work - [AutoParallel CI failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386) seems caused by the way RoPE is authored, and needs change in autoparallel. (cc @xmfan) - being fixed in meta-pytorch/autoparallel#321 - [CompilerToolkit CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386) `TypeError: forward() missing 1 required positional argument: 'fwd_rng_state_2'` cc @yiming0416 please help take a look - [SimpleFSDP CI failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386) is the same as pytorch#2312 around dynamic shape for for-loop MoE experts computation. (cc @pianpwk) - being fixed in pytorch#2399 - Fix integration scripts for MAST, Zoomer, etc. - organize docs from `docs/` to subfolders, as we are having more contents to cover in general - generate and store serialized configs (maybe not in the repo) - continue SAC refactor in pytorch#2357, but somehow keep the every-other-mm policy (cc @mori360) - refactor RoPE in general, at least resolving the following TODOs in code (cc @shuhuayu) - having to set / no validation on rope dim == decoder dim // attention n_heads - consolidate `apply_rotary_emb_complex` and `apply_rotary_emb_single_complex` - address pytorch#2417 Longer-term issues - More careful design about what to put config vs. runtime build kwargs. (thanks @ailzhang) - ModelSpec not serializable. There may be multiple solutions, but we can potentially consolidate `model.py` and `parallelize.py` by - sharing AC, compile, DP application across all Decoder models - putting per-module TP/CP/EP sharding plan inside model itself - Right now `BaseModel.update_from_config` violates encapsulation by passing the Trainer config into Model config. This could be avoided by python logic either in config construction time, or in trainer. - Refactor `init_weights` into `Module.Config` instead of staying in `Module` - The benefit is that param init can be configurable; o/w we are coupling module implementation and its weight init. - This may require refactor of current TransformerBlock and its config. E.g. `weight_init_std` may need to be put in config, with `__post_init__` determining its value. (See related complaints / discussions on `__post_init__` by [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md)) Note to reviewer: Although I believe the changes in this PR come naturally in a bundle, you may (or may not) find the stack of 16 commits easier to review, as I tried to split the changes in some logic manner. I apologize for the giant PR. # claude-generated summary ## Summary This PR refactors torchtitan's configuration and training infrastructure in 15 incremental, backwards-incompatible commits. The central change replaces TOML config files and a monolithic `JobConfig` parser with **typed Python dataclass configs**, a **`Configurable` base class** pattern, and a **`config_registry`** module per model. **270 files changed, 10,025 insertions, 11,418 deletions.** --- ## Motivation The previous system used TOML files parsed by a custom `ConfigManager` that layered CLI overrides on top. While simple, this had several friction points: 1. **No type safety at config boundaries.** TOML values are strings/ints/floats parsed at runtime. A typo in a key name (e.g., `training.stpes`) silently becomes a default value. 4. **Flat namespace.** All config sections (`[model]`, `[training]`, `[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class. Every component received the full `JobConfig` even when it only needed a few fields. 5. **Experiment extension was ad-hoc.** Experiments that needed custom config fields (e.g., SimpleFSDP's `compile.graph_passes` or FaultTolerant's `fault_tolerance.*`) required a `custom_config_module` TOML key and a runtime `_merge_configs` call to graft new fields onto `JobConfig`. 6. **Model args were disconnected from model code.** A `ModelArgs` dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that bundled model + parallelization + loss was registered separately, with no type-level link between them. --- ## What Changed ### 1. `Configurable` Base Class A new `Configurable` base class (`torchtitan/config/configurable.py`) establishes a universal pattern: ```python class Configurable: @DataClass(kw_only=True, slots=True) class Config: def build(self, **kwargs): return self._owner(config=self, **kwargs) def __init_subclass__(cls, **kwargs): # Auto-wires Config.build() -> cls(config=..., **kwargs) # Enforces @DataClass(kw_only=True, slots=True) on every Config ``` Every configurable component (Trainer, model, optimizer, tokenizer, dataloader, checkpoint manager, metrics, validators, quantization converters, ...) follows this pattern. Calling `config.build()` constructs the owning class. ### 2. `Trainer.Config` Replaces `JobConfig` The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested dataclass that aggregates typed sub-configs: ```python class Trainer(Stateful, Configurable): @DataClass(kw_only=True, slots=True) class Config(Configurable.Config): model_spec: ModelSpec | None = None # set by config_registry, suppressed from CLI job: JobConfig = ... training: TrainingConfig = ... parallelism: ParallelismConfig = ... optimizer: OptimizersContainer.Config = ... lr_scheduler: LRSchedulersContainer.Config = ... checkpoint: CheckpointManager.Config = ... dataloader: BaseDataLoader.Config = ... metrics: MetricsProcessor.Config = ... # ... etc. ``` Each sub-config is the `Config` class of the component that consumes it (e.g., `CheckpointManager.Config` is defined inside `CheckpointManager`). Components receive only their own config, not the entire training config. ### 3. `config_registry.py` Replaces TOML Files Each model defines a `config_registry.py` with functions that return complete `Trainer.Config` instances: ```python # torchtitan/models/llama3/config_registry.py def llama3_debugmodel() -> Trainer.Config: return Trainer.Config( job=JobConfig(description="Llama 3 debug training", ...), model_spec=model_registry("debugmodel"), optimizer=OptimizersContainer.Config(lr=8e-4), training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10), dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"), # ... ) def llama3_debugmodel_float8() -> Trainer.Config: config = llama3_debugmodel() config.model_converters = ModelConvertersContainer.Config( converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)] ) return config ``` ### 4. `TrainSpec` -> `ModelSpec` `TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds only model-specific concerns (model config, parallelization function, loss function, state dict adapter). All training-level concerns (optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`. ### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy Model hyperparameters move from a flat `ModelArgs` dataclass into a nested `Config` hierarchy that mirrors the module tree: ```python # Before (main): flat args.py @DataClass class ModelArgs: dim: int = 4096 n_layers: int = 32 n_heads: int = 32 # ... 20+ flat fields # After (this PR): nested Config in model class class Llama3Model(Decoder): @DataClass(kw_only=True, slots=True) class Config(Decoder.Config): layer: Llama3TransformerBlock.Config # contains attention + FFN configs rope: RoPE.Config # contains RoPE-specific params ``` ### 6. `train.py` Split The monolithic `train.py` (~800 lines) is split into: - `train.py` (~60 lines): thin entry point that calls `ConfigManager.parse_args()` and `config.build()` - `trainer.py` (~850 lines): the `Trainer` class with training loop logic ### 7. Experiment Extension via Inheritance Experiments extend the config system through dataclass subclassing instead of runtime config merging: ```python # torchtitan/experiments/simple_fsdp/configs.py @DataClass(kw_only=True, slots=True) class SimpleFSDPConfig(Trainer.Config): compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig) ``` Their `config_registry.py` returns the subclassed config type, and `tyro` auto-generates CLI parsing for the extended fields. --- ## UX Comparison ### Launching Training ```bash # Before (main) CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.py" ./run_train.sh # After (this PR) MODEL=llama3 CONFIG=llama3_8b ./run_train.sh ``` ### CLI Overrides ```bash # Before (main) CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \ --training.steps 100 --parallelism.tensor_parallel_degree 2 # After (this PR) ./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2 # (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh) ``` CLI override syntax is unchanged (`--section.field value`), but `tyro` now provides typed `--help` output generated from the dataclass tree. ### Defining a New Model Config ```bash # Before: create a new TOML file, copy-paste sections, edit values cp train_configs/debug_model.toml train_configs/my_experiment.toml vim train_configs/my_experiment.toml # After: write a Python function that mutates an existing config def my_experiment() -> Trainer.Config: config = llama3_debugmodel() config.training.steps = 100 config.optimizer.lr = 1e-4 return config ``` ### Adding Experiment-Specific Config Fields ```python # Before (main): custom_config_module in TOML + runtime _merge_configs # Requires: TOML key pointing to a Python module, dynamic dataclass creation # After (this PR): dataclass inheritance @DataClass(kw_only=True, slots=True) class MyExperimentConfig(Trainer.Config): my_custom_field: str = "default" ``` ### Float8 / Quantization Configuration ```python # Before (main): TOML section # [quantize.linear.float8] # enable_fsdp_float8_all_gather = true # precompute_float8_dynamic_scale_for_fsdp = true # After (this PR): typed config object model_converters=ModelConvertersContainer.Config( converters=[ Float8LinearConverter.Config( enable_fsdp_float8_all_gather=True, precompute_float8_dynamic_scale_for_fsdp=True, ), ], ), ``` --- ## Limitations and Trade-offs ### 1. Configs are no longer declarative text files TOML files were readable by anyone without Python knowledge. The new config_registry functions are Python code, which requires understanding imports, function calls, and dataclass construction. For users who only need to tweak hyperparameters, the CLI override syntax (`--training.steps 100`) works the same, but understanding the full config requires reading Python. ### 2. Steeper learning curve for contributors Adding a new model now requires understanding the `Configurable` protocol, nested `Config` dataclass hierarchy, and the `config_registry` pattern. The old approach of copying a TOML file and editing values had a lower barrier to entry. ### 3. Config serialization is more complex TOML files were trivially serializable and diffable. The new system supports `to_dict()` + JSON serialization, but configs containing callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully round-tripped. The `model_spec` field is excluded from serialization and suppressed from CLI parsing. ### 4. tyro dependency The CLI parsing now depends on `tyro`, a third-party library. While `tyro` is well-maintained and provides typed CLI generation from dataclasses, it is an additional dependency that must be kept compatible with the dataclass patterns used here. ### 5. `@dataclass(slots=True)` constraints The `Configurable` base class enforces `@dataclass(kw_only=True, slots=True)` on all Config classes. While this provides memory efficiency and prevents accidental attribute assignment, `slots=True` prevents dynamic attribute addition and makes multiple inheritance with other slotted classes more constrained. Each Config subclass in a deep hierarchy must repeat the `@dataclass(kw_only=True, slots=True)` decorator. ### 6. Two-level indirection for model selection The old system required one identifier: `--job.config_file path/to/file.toml`. The new system requires two: `--module llama3 --config llama3_8b`. While this separates model identity from training recipe, it adds an extra argument. --- ## Numerics Verification All model configs were verified for numerical equivalence against the main branch (commit `8b006830`): NOTE - only models that can fit on 8 GPUs are tested - only subset of parallelism combination are tested | Model | Status | Notes | |-------|--------|-------| | llama3 (debugmodel, 8B) | Bitwise match | | | llama3 (debugmodel_flex_attn) | Bitwise match | | | qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | | | deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm | | llama4 debugmodel | Bitwise match | _irope variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR | | **gpt_oss** debugmodel | --debug.deterministic causes loss to be NaN; o/w first step loss match, minor difference after (likely caused by flex?) | | | flux | Bitwise match | | --- ## Migration Guide | Old (main) | New (this PR) | |---|---| | `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` | | `--job.config_file path.toml` | `--module llama3 --config llama3_8b` | | `train_configs/*.toml` | `config_registry.py` functions | | `TrainSpec` | `ModelSpec` | | `ModelArgs` / `args.py` | Nested `Model.Config` dataclass | | `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` | | `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` | | `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
ACharacterInASimulation
pushed a commit
to ACharacterInASimulation/torchtitan
that referenced
this pull request
Apr 21, 2026
…ts (pytorch#2357) ### Summary - Remove layer-frequency selective activation checkpointing (`selective_ac_option` and `_layer_sac_count`) — per-op SAC is now the only selective mode - Centralize the op save list into `default_activation_checkpoint_policy()` in `activation_checkpoint.py`, removing duplicated `_op_sac_save_list` sets from per-model `parallelize.py` files (llama3, llama4, deepseek_v3, qwen3, gpt_oss, graph_trainer) - Remove the `op_sac_save_list` parameter from `apply_ac` — models no longer need to pass their own op sets - Build the centralized policy from `get_default_op_list()` (upstream PyTorch) plus explicit compute ops (SDPA, FlexAttention, inductor, varlen_attn) and communication ops (reduce_scatter, all_to_all, deepep, hybridep), with conditional resolution for optional dependencies - Use `@lru_cache` with `cache_hash` on the policy factory for dynamo recompilation avoidance and AOTAutograd cache compatibility - Add `--activation_checkpoint.mode full` to PP integration tests (`InterleavedZeroBubble`, `ZBVZeroBubble`, `PipelineScheduleMulti`) since they relied on layer_sac - Clean deepep imports, now we import from `torchtitan.distirbuted.deepep.deepep` or `torchtitan.distirbuted.deepep.hybridep`, to keep them symmetrical. ### Test Added `test_force_recompute_mm_fqns`: verifies that `per_op_sac_force_recompute_mm_shapes_by_fqns` controls exactly which matmuls are recomputed vs stored during backward. Uses a TorchDispatchMode tracker to count aten.mm calls per weight tensor
ACharacterInASimulation
pushed a commit
to ACharacterInASimulation/torchtitan
that referenced
this pull request
Apr 21, 2026
) With pytorch#2357 landed, we remove the duplicated `_get_default_sac_save_ops()` from the graph trainer's `passes.py` and replaces it with the shared `_get_save_ops()` from `torchtitan.distributed.activation_checkpoint`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
selective_ac_optionand_layer_sac_count) — per-op SAC is now the only selective modedefault_activation_checkpoint_policy()inactivation_checkpoint.py, removing duplicated_op_sac_save_listsets from per-modelparallelize.pyfiles (llama3, llama4, deepseek_v3, qwen3, gpt_oss, graph_trainer)op_sac_save_listparameter fromapply_ac— models no longer need to pass their own op setsget_default_op_list()(upstream PyTorch) plus explicit compute ops (SDPA, FlexAttention, inductor, varlen_attn) and communication ops (reduce_scatter, all_to_all, deepep, hybridep), with conditional resolution for optional dependencies@lru_cachewithcache_hashon the policy factory for dynamo recompilation avoidance and AOTAutograd cache compatibility--activation_checkpoint.mode fullto PP integration tests (InterleavedZeroBubble,ZBVZeroBubble,PipelineScheduleMulti) since they relied on layer_sactorchtitan.distirbuted.deepep.deepeportorchtitan.distirbuted.deepep.hybridep, to keep them symmetrical.Test
Added
test_force_recompute_mm_fqns: verifies thatper_op_sac_force_recompute_mm_shapes_by_fqnscontrols exactly which matmuls are recomputed vs stored during backward. Uses a TorchDispatchMode tracker to count aten.mm calls per weight tensor