separate out training for fault tolerance #2311

Merged: tushar00jain merged 1 commit into pytorch:main from tushar00jain:pr2311 on Feb 11, 2026

@tushar00jain commented Feb 2, 2026 (Contributor)

Summary:

  • extract the logic in train.py that differed for fault tolerance (ft) into separate functions
  • override these functions in ft's train.py

Test Plan:
Run the lighthouse, then launch two replicas:

RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 2 --quorum_tick_ms 100 --join_timeout_ms 10000

TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0 NGPU=1 ./run_train.sh --parallelism.data_parallel_shard_degree=1 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0


TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=1 NGPU=1 ./run_train.sh --parallelism.data_parallel_shard_degree=1 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 2, 2026
@tushar00jain tushar00jain force-pushed the pr2311 branch 5 times, most recently from d6a30aa to be416e3 Compare February 2, 2026 16:52
@tushar00jain tushar00jain marked this pull request as ready for review February 2, 2026 16:52
@tushar00jain tushar00jain force-pushed the pr2311 branch 3 times, most recently from 751c4bc to f012f36 Compare February 2, 2026 18:55
@tushar00jain tushar00jain force-pushed the pr2311 branch 3 times, most recently from f4dcc0d to 2f05a79 Compare February 4, 2026 15:42
@tushar00jain tushar00jain force-pushed the pr2311 branch 2 times, most recently from 512d506 to 5f93807 Compare February 6, 2026 11:56
@tushar00jain tushar00jain force-pushed the pr2311 branch 3 times, most recently from 89a5d90 to 7022f5e Compare February 6, 2026 16:31
@tushar00jain tushar00jain force-pushed the pr2311 branch 2 times, most recently from 2302d58 to c2f2098 Compare February 7, 2026 10:55

@fegin fegin left a comment

Overall, this PR introduces two extra methods, compute_global_losses and get_dp_info, while removing all the FT logic from the main trainer. I think this is a balanced compromise. @tianyu-l any thoughts?
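
The shape of that compromise can be sketched as follows. This is a minimal illustration and not the code from the PR: only the method names `compute_global_losses` and `get_dp_info` come from the comment above. The class names, signatures, return values, and the replica arithmetic are all assumptions made for the example.

```python
class Trainer:
    """Base trainer: contains no fault-tolerance logic, only overridable hooks."""

    def __init__(self, dp_degree: int, dp_rank: int):
        self.dp_degree = dp_degree
        self.dp_rank = dp_rank

    def get_dp_info(self) -> tuple[int, int]:
        # Hook: (data-parallel degree, data-parallel rank) used to shard data.
        return self.dp_degree, self.dp_rank

    def compute_global_losses(self, local_losses: list[float]) -> tuple[float, float]:
        # Hook: reduce per-rank losses to (average, maximum).
        # A real trainer would use collectives; plain Python stands in here.
        return sum(local_losses) / len(local_losses), max(local_losses)

    def dataloader_shard(self) -> str:
        # The training loop only ever talks to the hooks.
        degree, rank = self.get_dp_info()
        return f"shard {rank} of {degree}"


class FTTrainer(Trainer):
    """Hypothetical fault-tolerant trainer: overrides only the hooks."""

    def __init__(self, dp_degree: int, dp_rank: int, group_size: int, replica_id: int):
        super().__init__(dp_degree, dp_rank)
        self.group_size = group_size
        self.replica_id = replica_id

    def get_dp_info(self) -> tuple[int, int]:
        # Assumed behavior: each replica group reads a distinct slice of the data.
        degree = self.dp_degree * self.group_size
        rank = self.dp_degree * self.replica_id + self.dp_rank
        return degree, rank
```

With this layout the base class never has to branch on whether fault tolerance is enabled; the subclass changes behavior purely by overriding a hook.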

@tushar00jain tushar00jain merged commit 846653f into pytorch:main Feb 11, 2026
38 of 43 checks passed
@tushar00jain tushar00jain deleted the pr2311 branch February 11, 2026 21:30
tianyu-l added a commit that referenced this pull request Feb 23, 2026
…ry (#2386)

**NOTE**: This PR is a large refactor of the codebase.
https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains the
latest release right before this PR was merged.

# author's note

This refactor is mainly trying to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` was leaked
everywhere
- not easy to iterate and experiment on model architecture and training
components

The main changes are:
- Strict encapsulation, even at the cost of a (hopefully temporary)
bloated interface when calling subcomponents (e.g. validator). We should
try to find the right abstraction for cross-component visibility.
- Each `Configurable` component owns its own `Config`, which builds the
owner component. It achieves modularization via polymorphism and
inheritance, both classic concepts in OOP.
- This is partly inspired by repos like
[AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's
[ML API
Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)),
github issues (e.g. #1055),
and offline discussions (with @Chillee, @ailzhang etc.).
- Similar functionality can alternatively be achieved in other ways,
e.g. `_target_` in
[Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/),
but there are opinions against coupling with Hydra's other offerings. See
#1415
- Main entry point switches from TOML files to Python functions (a.k.a.
`config_registry.py` in each model).
- TOML has the constraint that everything needs to be registered
explicitly before it can be used, e.g. our quantization components need
to be registered with string names. We believe Python's language-level
implicit registration is more minimal, and it should be fairly easy to
extend or modify to support TOML/YAML when users build upon or fork
torchtitan.
- That said, Python config provides much more power, e.g. one can use
arbitrary logic to create (the config of) a component, which is hard to
express with TOML/YAML, thus creating extra difficulty when users want
to migrate to their own favorite config system. The only thing we can do
is to stay conservative in the use of such power.
- We still use [tyro](https://github.com/brentyi/tyro) to convert
config dataclasses to a CLI, still with the limitation that users need to
construct customized config classes, all the way from the root level
(`Trainer.Config` now, `JobConfig` in the past).
- If CLI is not needed, new trainer (or any high-level) config is not
required.
- To support "polymorphic construction" from CLI without the hassle,
check out
[chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please
raise issues or help fix
- moves ft to experiments, continuing the effort in
#2311

Remaining work
- [AutoParallel CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386)
seems caused by the way RoPE is authored, and needs change in
autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386)
`TypeError: forward() missing 1 required positional argument:
'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386)
is the same as #2312 around
dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in #2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` into subfolders, as we have more
content to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in
#2357, but somehow keep the
every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in
code (cc @shuhuayu)
  - having to set / no validation on rope dim == decoder dim // attention
    n_heads
  - consolidate `apply_rotary_emb_complex` and
    `apply_rotary_emb_single_complex`
  - address #2417

Longer-term issues
- More careful design about what to put in config vs. runtime build kwargs.
(thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we
can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by
passing the Trainer config into Model config. This could be avoided by
python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in
`Module`
  - The benefit is that param init can be configurable; otherwise we are
    coupling module implementation and its weight init.
  - This may require a refactor of the current TransformerBlock and its
    config. E.g. `weight_init_std` may need to be put in config, with
    `__post_init__` determining its value. (See related complaints /
    discussions on `__post_init__` by
    [chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle,
you may (or may not) find the stack of 16 commits easier to review, as I
tried to split the changes in a logical manner. I apologize for the
giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure
in 15 incremental, backwards-incompatible commits. The central change
replaces TOML config files and a monolithic `JobConfig` parser with
**typed Python dataclass configs**, a **`Configurable` base class**
pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager`
that layered CLI overrides on top. While simple, this had several
friction points:

1. **No type safety at config boundaries.** TOML values are
strings/ints/floats parsed at runtime. A typo in a key name (e.g.,
`training.stpes`) silently becomes a default value.
2. **Flat namespace.** All config sections (`[model]`, `[training]`,
`[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class.
Every component received the full `JobConfig` even when it only needed a
few fields.
3. **Experiment extension was ad-hoc.** Experiments that needed custom
config fields (e.g., SimpleFSDP's `compile.graph_passes` or
FaultTolerant's `fault_tolerance.*`) required a `custom_config_module`
TOML key and a runtime `_merge_configs` call to graft new fields onto
`JobConfig`.
4. **Model args were disconnected from model code.** A `ModelArgs`
dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that
bundled model + parallelization + loss was registered separately, with
no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`)
establishes a universal pattern:

```python
from dataclasses import dataclass


class Configurable:
    @dataclass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            # _owner is the Configurable subclass that declares this Config
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @dataclass(kw_only=True, slots=True) on every Config
        ...
```

Every configurable component (Trainer, model, optimizer, tokenizer,
dataloader, checkpoint manager, metrics, validators, quantization
converters, ...) follows this pattern. Calling `config.build()`
constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested
dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None    # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it
(e.g., `CheckpointManager.Config` is defined inside
`CheckpointManager`). Components receive only their own config, not the
entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return
complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds
only model-specific concerns (model config, parallelization function,
loss function, state dict adapter). All training-level concerns
(optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a
nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config                    # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls
`ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop
logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing
instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and
`tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```
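
One plausible way a `MODEL`/`CONFIG` pair could be resolved to a config function is a module import followed by an attribute lookup. The sketch below is an assumption about the mechanism, not torchtitan's actual loader: `load_config` is a hypothetical helper, and a stand-in `demo_models` package is registered in `sys.modules` so the example runs without torchtitan installed.

```python
import importlib
import sys
import types


def load_config(model: str, config: str, package: str = "torchtitan.models"):
    """Resolve `<package>.<model>.config_registry.<config>` and call it."""
    registry = importlib.import_module(f"{package}.{model}.config_registry")
    factory = getattr(registry, config, None)
    if not callable(factory):
        raise ValueError(f"no config function {config!r} for model {model!r}")
    return factory()


def _install_demo_registry() -> None:
    # Stand-in for a real config_registry.py; returns a dict instead of Trainer.Config.
    def llama3_8b() -> dict:
        return {"model": "llama3", "flavor": "8B", "steps": 1000}

    names = ("demo_models", "demo_models.llama3", "demo_models.llama3.config_registry")
    for name in names:
        sys.modules[name] = types.ModuleType(name)
    sys.modules[names[-1]].llama3_8b = llama3_8b


_install_demo_registry()
config = load_config("llama3", "llama3_8b", package="demo_models")
```

Because the lookup is by function name, a misspelled `CONFIG` fails immediately with an explicit error instead of falling back to defaults.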

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro`
now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml

# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new
config_registry functions are Python code, which requires understanding
imports, function calls, and dataclass construction. For users who only
need to tweak hyperparameters, the CLI override syntax
(`--training.steps 100`) works the same, but understanding the full
config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable`
protocol, nested `Config` dataclass hierarchy, and the `config_registry`
pattern. The old approach of copying a TOML file and editing values had
a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system
supports `to_dict()` + JSON serialization, but configs containing
callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully
round-tripped. The `model_spec` field is excluded from serialization and
suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While
`tyro` is well-maintained and provides typed CLI generation from
dataclasses, it is an additional dependency that must be kept compatible
with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True,
slots=True)` on all Config classes. While this provides memory
efficiency and prevents accidental attribute assignment, `slots=True`
prevents dynamic attribute addition and makes multiple inheritance with
other slotted classes more constrained. Each Config subclass in a deep
hierarchy must repeat the `@dataclass(kw_only=True, slots=True)`
decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file
path/to/file.toml`. The new system requires two: `--module llama3
--config llama3_8b`. While this separates model identity from training
recipe, it adds an extra argument.

---

## Numerics Verification

All model configs were verified for numerical equivalence against the
main branch (commit `10d8a306`):

NOTE
- only models that can fit on 8 GPUs are tested
- only a subset of parallelism combinations is tested

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | `_irope` variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR |
| **gpt_oss** debugmodel | `--debug.deterministic` causes loss to be NaN; otherwise the first-step loss matches, with minor differences afterwards (likely caused by flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
TXacs pushed a commit to McmillanTAC/torchtitan that referenced this pull request Apr 13, 2026
…ry (pytorch#2386)

**NOTE**: This PR is a large refactor of the codebase.
https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains a
latest release right before this PR is merged.

# author's note

This refactor is mainly trying to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` is leaked
everywhere
- not easy to iterate and experiment on model architecture and training
components

The main changes are:
- Strict encapsulation, even at the cost of (hopefully temporary)
bloated interface when calling subcomponents (e.g. validator). We should
try to find the right abstraction on cross-components visibility.
- Each `Configurable` component owns its own `Config`, which builds the
owner component. It achieves modularization via polymorphism and
inheritance, both classic concepts in OOP.
- This is partly inspired by repos like
[AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's
[ML API
Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)),
github issues (e.g. pytorch#1055),
and offline discussions (with @Chillee, @ailzhang etc.).
- Similar functionality can be alternatively achieved by other ways,
e.g. `_target_` in
[Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/),
but there are opinions not to couple with Hydra's other offerings. See
pytorch#1415
- Main entry point switches from TOML files to Python functions (a.k.a.
`config_registry.py` in each model).
- TOML has the constraint that everything needs to be registered
explicitly before it can be used, e.g. our quantization components need
to be registered with string names. Python's language level implicit
registration is what we believe to be more minimal, and should be fairly
easy to extended/modified to support TOML/YAML when users builds upon /
fork torchtitan.
- That said, Python config provides much more power, e.g. one can use
arbitrary logic to create (the config of) a component, which is hard to
express with TOML/YAML, thus creating extra difficulty when users want
to migrate to their own favorite config system. The only thing we can do
is to stay conservative on the usage of such power.
- We still uses [tyro](https://github.com/brentyi/tyro) to convert
config dataclass to CLI, still with the limitation that users need to
construct customized config classes, all the way from root level
(`Trainer.Config` now, `JobConfig` in the past).
- If CLI is not needed, new trainer (or any high-level) config is not
required.
- To support "polymorphic construction" from CLI without the hassle,
check out
[chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please
raise issues or help fix
- moves ft to experiments, continuing the effort in
pytorch#2311

Remaining work
- [AutoParallel CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386)
seems caused by the way RoPE is authored, and needs change in
autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386)
`TypeError: forward() missing 1 required positional argument:
'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386)
is the same as pytorch#2312 around
dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in pytorch#2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` to subfolders, as we are having more
contents to cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in
pytorch#2357, but somehow keep the
every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in
code (cc @shuhuayu)
- having to set / no validation on rope dim == decoder dim // attention
n_heads
- consolidate `apply_rotary_emb_complex` and
`apply_rotary_emb_single_complex`
  - address pytorch#2417

Longer-term issues
- More careful design about what to put config vs. runtime build kwargs.
(thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we
can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by
passing the Trainer config into Model config. This could be avoided by
python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in
`Module`
- The benefit is that param init can be configurable; o/w we are
coupling module implementation and its weight init.
- This may require refactor of current TransformerBlock and its config.
E.g. `weight_init_std` may need to be put in config, with
`__post_init__` determining its value. (See related complaints /
discussions on `__post_init__` by
[chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle,
you may (or may not) find the stack of 16 commits easier to review, as I
tried to split the changes in some logic manner. I apologize for the
giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure
in 15 incremental, backwards-incompatible commits. The central change
replaces TOML config files and a monolithic `JobConfig` parser with
**typed Python dataclass configs**, a **`Configurable` base class**
pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager`
that layered CLI overrides on top. While simple, this had several
friction points:

1. **No type safety at config boundaries.** TOML values are
strings/ints/floats parsed at runtime. A typo in a key name (e.g.,
`training.stpes`) silently becomes a default value.
4. **Flat namespace.** All config sections (`[model]`, `[training]`,
`[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class.
Every component received the full `JobConfig` even when it only needed a
few fields.
5. **Experiment extension was ad-hoc.** Experiments that needed custom
config fields (e.g., SimpleFSDP's `compile.graph_passes` or
FaultTolerant's `fault_tolerance.*`) required a `custom_config_module`
TOML key and a runtime `_merge_configs` call to graft new fields onto
`JobConfig`.
6. **Model args were disconnected from model code.** A `ModelArgs`
dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that
bundled model + parallelization + loss was registered separately, with
no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`)
establishes a universal pattern:

```python
class Configurable:
    @DataClass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @DataClass(kw_only=True, slots=True) on every Config
```

Every configurable component (Trainer, model, optimizer, tokenizer,
dataloader, checkpoint manager, metrics, validators, quantization
converters, ...) follows this pattern. Calling `config.build()`
constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested
dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @DataClass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None    # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it
(e.g., `CheckpointManager.Config` is defined inside
`CheckpointManager`). Components receive only their own config, not the
entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return
complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds
only model-specific concerns (model config, parallelization function,
loss function, state dict adapter). All training-level concerns
(optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a
nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config                    # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls
`ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop
logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing
instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and
`tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro`
now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml
```

```python
# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new
config_registry functions are Python code, which requires understanding
imports, function calls, and dataclass construction. For users who only
need to tweak hyperparameters, the CLI override syntax
(`--training.steps 100`) works the same, but understanding the full
config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable`
protocol, nested `Config` dataclass hierarchy, and the `config_registry`
pattern. The old approach of copying a TOML file and editing values had
a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system
supports `to_dict()` + JSON serialization, but configs containing
callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully
round-tripped. The `model_spec` field is excluded from serialization and
suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While
`tyro` is well-maintained and provides typed CLI generation from
dataclasses, it is an additional dependency that must be kept compatible
with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True,
slots=True)` on all Config classes. While this provides memory
efficiency and prevents accidental attribute assignment, `slots=True`
prevents dynamic attribute addition and makes multiple inheritance with
other slotted classes more constrained. Each Config subclass in a deep
hierarchy must repeat the `@dataclass(kw_only=True, slots=True)`
decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file
path/to/file.toml`. The new system requires two: `--module llama3
--config llama3_8b`. While this separates model identity from training
recipe, it adds an extra argument.

---

## Numerics Verification

The model configs listed below were verified for numerical equivalence
against the main branch (commit `10d8a306`).

NOTE
- Only models that fit on 8 GPUs were tested.
- Only a subset of parallelism combinations was tested.

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | `_irope` variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR |
| **gpt_oss** debugmodel | `--debug.deterministic` causes loss to be NaN; otherwise the first-step loss matches, with minor differences afterwards (likely caused by flex?) | |
| flux | Bitwise match | |
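
The two status labels in the table correspond to two different checks. A minimal sketch of both follows. The loss values are illustrative placeholders, not measurements from this PR, and `compare_losses` is a hypothetical helper.

```python
def compare_losses(reference, candidate):
    """Return (bitwise_match, max_abs_diff) for two per-step loss sequences."""
    if len(reference) != len(candidate):
        raise ValueError("loss sequences must have the same length")
    pairs = list(zip(reference, candidate))
    # float.hex() is an exact textual form, so equal strings mean identical bits.
    bitwise = all(a.hex() == b.hex() for a, b in pairs)
    max_diff = max((abs(a - b) for a, b in pairs), default=0.0)
    return bitwise, max_diff


# Placeholder values chosen to be exactly representable in binary.
main_losses = [8.0, 7.5, 7.25]
pr_losses_same = [8.0, 7.5, 7.25]
pr_losses_close = [8.0, 7.5, 7.25 + 2**-13]

same_result = compare_losses(main_losses, pr_losses_same)
close_result = compare_losses(main_losses, pr_losses_close)
print(same_result, close_result)
```

A "Bitwise match" row corresponds to the first element being `True`; a "Close" row reports the second element instead.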

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |
ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
Summary:
- extract logic in train.py that was different for ft in separate
functions
- override these functions in ft's train.py

Test Plan:
Run lighthouse and two replicas

```bash
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 2 --quorum_tick_ms 100 --join_timeout_ms 10000

TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0 NGPU=1 ./run_train.sh --parallelism.data_parallel_shard_degree=1 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0


TRAIN_FILE=torchtitan.experiments.ft.train CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=1 NGPU=1 ./run_train.sh --parallelism.data_parallel_shard_degree=1 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
```

Co-authored-by: Tushar Jain <tushar00jain@users.noreply.github.com>
ACharacterInASimulation pushed a commit to ACharacterInASimulation/torchtitan that referenced this pull request Apr 21, 2026
…ry (pytorch#2386)

**NOTE**: This PR is a large refactor of the codebase.
https://github.com/pytorch/torchtitan/releases/tag/v0.2.2 contains the
latest release from right before this PR was merged.

# author's note

This refactor mainly tries to address two issues:
- bad encapsulation: previously a monolithic `JobConfig` was leaked
everywhere
- difficulty iterating and experimenting on model architecture and
training components

The main changes are:
- Strict encapsulation, even at the cost of (hopefully temporary)
bloated interface when calling subcomponents (e.g. validator). We should
try to find the right abstraction on cross-components visibility.
- Each `Configurable` component owns its own `Config`, which builds the
owner component. It achieves modularization via polymorphism and
inheritance, both classic concepts in OOP.
- This is partly inspired by repos like
[AXLearn](https://github.com/apple/axlearn) (in particular, @ruomingp's
[ML API
Styles](https://github.com/apple/axlearn/blob/main/docs/ml_api_style.md)),
github issues (e.g. pytorch#1055),
and offline discussions (with @Chillee, @ailzhang etc.).
- Similar functionality can alternatively be achieved in other ways,
e.g. `_target_` in
[Hydra](https://hydra.cc/docs/advanced/instantiate_objects/overview/),
but there are opinions against coupling with Hydra's other offerings. See
pytorch#1415
- Main entry point switches from TOML files to Python functions (a.k.a.
`config_registry.py` in each model).
- TOML has the constraint that everything needs to be registered
explicitly before it can be used, e.g. our quantization components need
to be registered with string names. We believe Python's language-level
implicit registration is more minimal, and it should be fairly easy to
extend or modify to support TOML/YAML when users build upon or fork
torchtitan.
- That said, Python config provides much more power, e.g. one can use
arbitrary logic to create (the config of) a component, which is hard to
express with TOML/YAML, thus creating extra difficulty when users want
to migrate to their own favorite config system. The only thing we can do
is to stay conservative in using such power.
- We still use [tyro](https://github.com/brentyi/tyro) to convert
config dataclasses to a CLI, still with the limitation that users need to
construct customized config classes all the way from the root level
(`Trainer.Config` now, `JobConfig` in the past).
- If CLI is not needed, new trainer (or any high-level) config is not
required.
- To support "polymorphic construction" from CLI without the hassle,
check out
[chz](https://github.com/openai/chz/blob/main/docs/04_command_line.md#polymorphic-construction).

This PR also
- updates the docs -- there might be remaining outdated docs, please
raise issues or help fix
- moves ft to experiments, continuing the effort in
pytorch#2311

Remaining work
- [AutoParallel CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22165425254/job/64091572780?pr=2386)
seems caused by the way RoPE is authored, and needs change in
autoparallel. (cc @xmfan)
  - being fixed in meta-pytorch/autoparallel#321
- [CompilerToolkit CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015737/job/64099486707?pr=2386)
`TypeError: forward() missing 1 required positional argument:
'fwd_rng_state_2'` cc @yiming0416 please help take a look
- [SimpleFSDP CI
failure](https://github.com/pytorch/torchtitan/actions/runs/22168015749/job/64099486149?pr=2386)
is the same as pytorch#2312 around
dynamic shape for for-loop MoE experts computation. (cc @pianpwk)
  - being fixed in pytorch#2399
- Fix integration scripts for MAST, Zoomer, etc.
- organize docs from `docs/` into subfolders, as we have more content to
cover in general
- generate and store serialized configs (maybe not in the repo)
- continue SAC refactor in
pytorch#2357, but somehow keep the
every-other-mm policy (cc @mori360)
- refactor RoPE in general, at least resolving the following TODOs in
code (cc @shuhuayu)
- having to set / no validation on rope dim == decoder dim // attention
n_heads
- consolidate `apply_rotary_emb_complex` and
`apply_rotary_emb_single_complex`
  - address pytorch#2417

Longer-term issues
- More careful design about what to put in config vs. runtime build
kwargs. (thanks @ailzhang)
- ModelSpec not serializable. There may be multiple solutions, but we
can potentially consolidate `model.py` and `parallelize.py` by
  - sharing AC, compile, DP application across all Decoder models
  - putting per-module TP/CP/EP sharding plan inside model itself
- Right now `BaseModel.update_from_config` violates encapsulation by
passing the Trainer config into Model config. This could be avoided by
python logic either in config construction time, or in trainer.
- Refactor `init_weights` into `Module.Config` instead of staying in
`Module`
- The benefit is that param init can be configurable; o/w we are
coupling module implementation and its weight init.
- This may require refactor of current TransformerBlock and its config.
E.g. `weight_init_std` may need to be put in config, with
`__post_init__` determining its value. (See related complaints /
discussions on `__post_init__` by
[chz](https://github.com/openai/chz/blob/main/docs/21_post_init.md))

Note to reviewer:
Although I believe the changes in this PR come naturally in a bundle,
you may (or may not) find the stack of 16 commits easier to review, as I
tried to split the changes in some logical manner. I apologize for the
giant PR.

# claude-generated summary

## Summary

This PR refactors torchtitan's configuration and training infrastructure
in 15 incremental, backwards-incompatible commits. The central change
replaces TOML config files and a monolithic `JobConfig` parser with
**typed Python dataclass configs**, a **`Configurable` base class**
pattern, and a **`config_registry`** module per model.

**270 files changed, 10,025 insertions, 11,418 deletions.**

---

## Motivation

The previous system used TOML files parsed by a custom `ConfigManager`
that layered CLI overrides on top. While simple, this had several
friction points:

1. **No type safety at config boundaries.** TOML values are
strings/ints/floats parsed at runtime. A typo in a key name (e.g.,
`training.stpes`) is silently ignored, so the field keeps its default
value.
2. **Flat namespace.** All config sections (`[model]`, `[training]`,
`[optimizer]`, `[checkpoint]`, ...) lived in a single `JobConfig` class.
Every component received the full `JobConfig` even when it only needed a
few fields.
3. **Experiment extension was ad-hoc.** Experiments that needed custom
config fields (e.g., SimpleFSDP's `compile.graph_passes` or
FaultTolerant's `fault_tolerance.*`) required a `custom_config_module`
TOML key and a runtime `_merge_configs` call to graft new fields onto
`JobConfig`.
4. **Model args were disconnected from model code.** A `ModelArgs`
dataclass in `args.py` defined hyperparameters, but the `TrainSpec` that
bundled model + parallelization + loss was registered separately, with
no type-level link between them.

---

## What Changed

### 1. `Configurable` Base Class

A new `Configurable` base class (`torchtitan/config/configurable.py`)
establishes a universal pattern:

```python
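from dataclasses import dataclass


class Configurable:
    @dataclass(kw_only=True, slots=True)
    class Config:
        def build(self, **kwargs):
            # `_owner` refers to the owning class (wired up below)
            return self._owner(config=self, **kwargs)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Auto-wires Config.build() -> cls(config=..., **kwargs)
        # Enforces @dataclass(kw_only=True, slots=True) on every Config
        ...  # wiring and enforcement logic elided in this summary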
```

Every configurable component (Trainer, model, optimizer, tokenizer,
dataloader, checkpoint manager, metrics, validators, quantization
converters, ...) follows this pattern. Calling `config.build()`
constructs the owning class.

### 2. `Trainer.Config` Replaces `JobConfig`

The monolithic `JobConfig` is replaced by `Trainer.Config`, a nested
dataclass that aggregates typed sub-configs:

```python
class Trainer(Stateful, Configurable):
    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        model_spec: ModelSpec | None = None    # set by config_registry, suppressed from CLI
        job: JobConfig = ...
        training: TrainingConfig = ...
        parallelism: ParallelismConfig = ...
        optimizer: OptimizersContainer.Config = ...
        lr_scheduler: LRSchedulersContainer.Config = ...
        checkpoint: CheckpointManager.Config = ...
        dataloader: BaseDataLoader.Config = ...
        metrics: MetricsProcessor.Config = ...
        # ... etc.
```

Each sub-config is the `Config` class of the component that consumes it
(e.g., `CheckpointManager.Config` is defined inside
`CheckpointManager`). Components receive only their own config, not the
entire training config.

### 3. `config_registry.py` Replaces TOML Files

Each model defines a `config_registry.py` with functions that return
complete `Trainer.Config` instances:

```python
# torchtitan/models/llama3/config_registry.py

def llama3_debugmodel() -> Trainer.Config:
    return Trainer.Config(
        job=JobConfig(description="Llama 3 debug training", ...),
        model_spec=model_registry("debugmodel"),
        optimizer=OptimizersContainer.Config(lr=8e-4),
        training=TrainingConfig(local_batch_size=8, seq_len=2048, steps=10),
        dataloader=HuggingFaceTextDataLoader.Config(dataset="c4_test"),
        # ...
    )

def llama3_debugmodel_float8() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[Float8LinearConverter.Config(enable_fsdp_float8_all_gather=True)]
    )
    return config
```

### 4. `TrainSpec` -> `ModelSpec`

`TrainSpec` is renamed to `ModelSpec` with a narrower scope: it holds
only model-specific concerns (model config, parallelization function,
loss function, state dict adapter). All training-level concerns
(optimizer, LR scheduler, checkpointing, etc.) live in `Trainer.Config`.

### 5. Model Configs: Flat `ModelArgs` -> Nested Dataclass Hierarchy

Model hyperparameters move from a flat `ModelArgs` dataclass into a
nested `Config` hierarchy that mirrors the module tree:

```python
# Before (main): flat args.py
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    # ... 20+ flat fields

# After (this PR): nested Config in model class
class Llama3Model(Decoder):
    @dataclass(kw_only=True, slots=True)
    class Config(Decoder.Config):
        layer: Llama3TransformerBlock.Config  # contains attention + FFN configs
        rope: RoPE.Config                    # contains RoPE-specific params
```

### 6. `train.py` Split

The monolithic `train.py` (~800 lines) is split into:
- `train.py` (~60 lines): thin entry point that calls
`ConfigManager.parse_args()` and `config.build()`
- `trainer.py` (~850 lines): the `Trainer` class with training loop
logic

### 7. Experiment Extension via Inheritance

Experiments extend the config system through dataclass subclassing
instead of runtime config merging:

```python
# torchtitan/experiments/simple_fsdp/configs.py
@dataclass(kw_only=True, slots=True)
class SimpleFSDPConfig(Trainer.Config):
    compile: SimpleFSDPCompileConfig = field(default_factory=SimpleFSDPCompileConfig)
```

Their `config_registry.py` returns the subclassed config type, and
`tyro` auto-generates CLI parsing for the extended fields.

---

## UX Comparison

### Launching Training

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

# After (this PR)
MODEL=llama3 CONFIG=llama3_8b ./run_train.sh
```

### CLI Overrides

```bash
# Before (main)
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --training.steps 100 --parallelism.tensor_parallel_degree 2

# After (this PR)
./run_train.sh --training.steps 100 --parallelism.tensor_parallel_degree 2
# (defaults to MODEL=llama3, CONFIG=llama3_debugmodel via run_train.sh)
```

CLI override syntax is unchanged (`--section.field value`), but `tyro`
now provides typed `--help` output generated from the dataclass tree.

### Defining a New Model Config

```bash
# Before: create a new TOML file, copy-paste sections, edit values
cp train_configs/debug_model.toml train_configs/my_experiment.toml
vim train_configs/my_experiment.toml
```

```python
# After: write a Python function that mutates an existing config
def my_experiment() -> Trainer.Config:
    config = llama3_debugmodel()
    config.training.steps = 100
    config.optimizer.lr = 1e-4
    return config
```

### Adding Experiment-Specific Config Fields

```python
# Before (main): custom_config_module in TOML + runtime _merge_configs
# Requires: TOML key pointing to a Python module, dynamic dataclass creation

# After (this PR): dataclass inheritance
@dataclass(kw_only=True, slots=True)
class MyExperimentConfig(Trainer.Config):
    my_custom_field: str = "default"
```

### Float8 / Quantization Configuration

```python
# Before (main): TOML section
# [quantize.linear.float8]
# enable_fsdp_float8_all_gather = true
# precompute_float8_dynamic_scale_for_fsdp = true

# After (this PR): typed config object
model_converters=ModelConvertersContainer.Config(
    converters=[
        Float8LinearConverter.Config(
            enable_fsdp_float8_all_gather=True,
            precompute_float8_dynamic_scale_for_fsdp=True,
        ),
    ],
),
```

---

## Limitations and Trade-offs

### 1. Configs are no longer declarative text files

TOML files were readable by anyone without Python knowledge. The new
config_registry functions are Python code, which requires understanding
imports, function calls, and dataclass construction. For users who only
need to tweak hyperparameters, the CLI override syntax
(`--training.steps 100`) works the same, but understanding the full
config requires reading Python.

### 2. Steeper learning curve for contributors

Adding a new model now requires understanding the `Configurable`
protocol, nested `Config` dataclass hierarchy, and the `config_registry`
pattern. The old approach of copying a TOML file and editing values had
a lower barrier to entry.

### 3. Config serialization is more complex

TOML files were trivially serializable and diffable. The new system
supports `to_dict()` + JSON serialization, but configs containing
callables (e.g., `ModelSpec.parallelize_fn`) cannot be fully
round-tripped. The `model_spec` field is excluded from serialization and
suppressed from CLI parsing.

### 4. tyro dependency

The CLI parsing now depends on `tyro`, a third-party library. While
`tyro` is well-maintained and provides typed CLI generation from
dataclasses, it is an additional dependency that must be kept compatible
with the dataclass patterns used here.

### 5. `@dataclass(slots=True)` constraints

The `Configurable` base class enforces `@dataclass(kw_only=True,
slots=True)` on all Config classes. While this provides memory
efficiency and prevents accidental attribute assignment, `slots=True`
prevents dynamic attribute addition and makes multiple inheritance with
other slotted classes more constrained. Each Config subclass in a deep
hierarchy must repeat the `@dataclass(kw_only=True, slots=True)`
decorator.

### 6. Two-level indirection for model selection

The old system required one identifier: `--job.config_file
path/to/file.toml`. The new system requires two: `--module llama3
--config llama3_8b`. While this separates model identity from training
recipe, it adds an extra argument.

---

## Numerics Verification

The model configs listed below were verified for numerical equivalence
against the main branch (commit `8b006830`).

NOTE
- Only models that fit on 8 GPUs were tested.
- Only a subset of parallelism combinations was tested.

| Model | Status | Notes |
|-------|--------|-------|
| llama3 (debugmodel, 8B) | Bitwise match | |
| llama3 (debugmodel_flex_attn) | Bitwise match | |
| qwen3 (0.6B, 1.7B, 32B, MoE debugmodel) | Bitwise match | |
| deepseek_v3 (debugmodel, 16B) | Close (max diff 0.00014) | Pre-existing main branch bug: missing `eps` in final RMSNorm |
| llama4 debugmodel | Bitwise match | `_irope` variants don't work on main (FlexAttn `'dict' object has no attribute 'BLOCK_SIZE'`) but now work after this PR |
| **gpt_oss** debugmodel | `--debug.deterministic` causes loss to be NaN; otherwise the first-step loss matches, with minor differences afterwards (likely caused by flex?) | |
| flux | Bitwise match | |

---

## Migration Guide

| Old (main) | New (this PR) |
|---|---|
| `CONFIG_FILE="path/to/config.toml" ./run_train.sh` | `MODEL=llama3 CONFIG=llama3_8b ./run_train.sh` |
| `--job.config_file path.toml` | `--module llama3 --config llama3_8b` |
| `train_configs/*.toml` | `config_registry.py` functions |
| `TrainSpec` | `ModelSpec` |
| `ModelArgs` / `args.py` | Nested `Model.Config` dataclass |
| `custom_config_module` + `_merge_configs()` | Subclass `Trainer.Config` |
| `build_model_converters()` free function | `ModelConvertersContainer.Config.build()` |
| `build_metrics_processor()` free function | `MetricsProcessor.Config.build()` |