Skip to content

[bug] Torch FSDP2 does not work #2785

@pavelgein

Description

@pavelgein

Problem

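Training with Torch FSDP2 (`use_torch_fsdp2=True`) fails to start: model setup raises a `TypeError` before the first training iteration (see Logs below).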

Minimal repro

Here is a slightly modified example script:


import os
import torch
from typing_extensions import TypedDict, Unpack

from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.core.distributed import DistributedDataParallelConfig

from megatron.bridge import AutoBridge
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.finetune_utils import default_peft_config, default_squad_config

from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    DistributedInitConfig,
    LoggerConfig,
    RNGConfig,
    TokenizerConfig,
    TrainingConfig,
    ValidationConfig,
)


# Sequence length shared by the dataset config (`default_squad_config`) and
# the model provider (`cfg.model.seq_length`) in this script.
MAX_LENGTH = 1024


def _sft_common() -> ConfigContainer:
    """Create a base SFT (Supervised Fine-Tuning) ConfigContainer with common defaults.

    The returned ConfigContainer is a template for full SFT (no PEFT). The
    caller MUST set `cfg.model` and `cfg.tokenizer.tokenizer_model` before use.

    Defaults set here:
    - Dataset from `default_squad_config(seq_length=MAX_LENGTH)`
    - Optimizer/scheduler from `distributed_fused_adam_with_cosine_annealing`
      with max_lr=1e-6, min_lr=0.0 and 65 warmup iterations
    - 10000 training iterations, global batch size 32, micro batch size 1
    - bf16 mixed precision, gradients not reduced in fp32
    - Torch FSDP2 enabled through `DistributedInitConfig(use_torch_fsdp2=True)`
    - Checkpoints saved to and loaded from
      `<cwd>/nemo_experiments/default/checkpoints` in "torch_dist" format
    - `pretrained_checkpoint` left as None
    - No PEFT (full parameter training)

    Returns:
        ConfigContainer: Base configuration template for full SFT.
    """
    # Default output directories, rooted at the current working directory.
    run_output_dir = os.path.join(os.getcwd(), "nemo_experiments", "default")
    checkpoint_dir = os.path.join(run_output_dir, "checkpoints")
    tensorboard_dir = os.path.join(run_output_dir, "tb_logs")

    # Optimizer and scheduler. Reference settings kept from the original
    # script (presumably an llmfoundry-style config — TODO confirm):
    #
    #   scheduler:
    #     name: cosine_with_warmup
    #     t_warmup: 65ba  # 300 for non-tools
    #     alpha_f: 0.1
    #   optimizer:
    #     name: decoupled_adamw
    #     lr: 1.0e-6
    #     betas: [0.9, 0.95]
    #     eps: 1.0e-12
    #     weight_decay: 0.0
    #
    # NOTE(review): adam_beta2 (0.98) and weight_decay (0.1) below differ from
    # the reference values above (0.95 and 0.0) — confirm this is intentional.
    opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing(
        lr_warmup_iters=65,
        lr_decay_iters=None,  # Defaults to train_iters during validation
        max_lr=1e-6,  # Lower LR for full fine-tuning
        min_lr=0.0,
        adam_beta1=0.9,
        adam_beta2=0.98,  # Common for fine-tuning
        weight_decay=0.1,  # alpha in llmfoundry
        clip_grad=2.0,
    )

    # main_grads_dtype / main_params_dtype are not overridden here.
    opt_cfg.use_precision_aware_optimizer = True

    cfg = ConfigContainer(
        # Model - MUST be set by each recipe before use
        model=None,  # type: ignore[arg-type]
        # Training config
        train=TrainingConfig(
            train_iters=10000,
            global_batch_size=32,
            micro_batch_size=1,
        ),
        validation=ValidationConfig(
            eval_interval=100,
            eval_iters=32,
        ),
        # Optimizer and scheduler
        optimizer=opt_cfg,
        scheduler=scheduler_cfg,
        # DDP config - minimal settings, model-specific configs can override
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=False,
        ),
        # Dataset config - SQuAD defaults at the script-wide sequence length
        dataset=default_squad_config(seq_length=MAX_LENGTH),
        # Logger config
        logger=LoggerConfig(
            log_interval=1,
            tensorboard_dir=tensorboard_dir,
            log_timers_to_tensorboard=True,
            log_throughput=True,
            log_throughput_to_tensorboard=True,
            log_l2_norm_grad_to_tensorboard=True,
            throughput_window_size=10,
        ),
        # Tokenizer - placeholder, each recipe should set tokenizer_model
        tokenizer=TokenizerConfig(
            tokenizer_type="HuggingFaceTokenizer",
            tokenizer_model=None,  # Must be set by each recipe
        ),
        # Checkpoint config with pretrained_checkpoint support
        checkpoint=CheckpointConfig(
            save_interval=100,
            save=checkpoint_dir,
            load=checkpoint_dir,
            pretrained_checkpoint=None,  # Set to load from pretrained weights
            ckpt_format="torch_dist",  # alternative: ckpt_format="fsdp_dtensor"
            fully_parallel_save=True,
        ),
        # RNG config
        rng=RNGConfig(seed=5678),
        # Distributed init config - Torch FSDP2 is the setting under test;
        # alternative: DistributedInitConfig(use_megatron_fsdp=True)
        dist=DistributedInitConfig(use_torch_fsdp2=True),
        comm_overlap=None,
        # Mixed precision - bf16, gradients not reduced in fp32
        mixed_precision=MixedPrecisionConfig(
            bf16=True,
            grad_reduce_in_fp32=False,
        ),
        # No PEFT for full SFT
        peft=None,
    )

    return cfg


def new_qwen3_8b_sft_config(
    model_path: str = "/mnt/models/Qwen3-8B",
    checkpoint_save_dir: str = "/mnt/p.geyn/nemo_exps",
) -> ConfigContainer:
    """Return a full SFT config for Qwen3 8B.

    Starts from `_sft_common()` and fills in the model provider, tokenizer,
    parallelism, recompute, checkpoint and DDP settings.

    Parallelism set here: TP=1, PP=1, CP=1, sequence parallelism disabled.

    Args:
        model_path: Path of the Hugging Face model passed to
            `AutoBridge.from_hf_pretrained`; also used as the tokenizer model.
        checkpoint_save_dir: Directory assigned to `cfg.checkpoint.save`.
            `cfg.checkpoint.load` keeps the value set by `_sft_common()`.

    Returns:
        ConfigContainer: The populated configuration.
    """
    cfg = _sft_common()

    # Model config - the provider is built without loading weights
    cfg.model = AutoBridge.from_hf_pretrained(model_path).to_megatron_provider(load_weights=False)

    cfg.model.seq_length = MAX_LENGTH

    # Tokenizer
    cfg.tokenizer.tokenizer_model = model_path

    # Parallelism settings
    cfg.model.tensor_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_layout = None
    cfg.model.pipeline_dtype = None
    cfg.model.virtual_pipeline_model_parallel_size = None
    cfg.model.context_parallel_size = 1
    cfg.model.sequence_parallel = False

    # Global batch size override
    cfg.train.global_batch_size = 8
    # Set pad_seq_to_mult for context parallelism. With context_parallel_size
    # fixed to 1 above this branch does not run; it is kept so the padding
    # follows if the CP size is raised.
    if cfg.model.context_parallel_size > 1:
        cfg.dataset.packed_sequence_specs.pad_seq_to_mult = cfg.model.context_parallel_size * 2

    # Training config
    cfg.validation.eval_interval = 30
    cfg.train.manual_gc = False
    cfg.train.manual_gc_interval = 0

    # TE (Transformer Engine)
    cfg.model.transformer_impl = "transformer_engine"

    # CUDA Graph - cuda_graph_impl is "none"; scope and warmup steps are still set
    cfg.model.use_te_rng_tracker = True
    cfg.model.cuda_graph_impl = "none"
    cfg.model.cuda_graph_scope = "full"
    cfg.model.cuda_graph_warmup_steps = 3

    # Kernel selections
    cfg.model.attention_backend = 'auto'
    cfg.model.cross_entropy_loss_fusion = True
    cfg.model.cross_entropy_fusion_impl = "native"

    # Memory saving (recompute & offloading)
    cfg.model.recompute_granularity = "full"
    # NOTE(review): recompute_modules receives the same string as
    # recompute_method; this looks like a copy-paste — confirm the intended value.
    cfg.model.recompute_modules = "uniform"
    cfg.model.recompute_method = "uniform"
    cfg.model.fine_grained_activation_offloading = False
    cfg.model.recompute_num_layers = cfg.model.num_layers
    cfg.model.offload_modules = None

    # FP8 / MXFP8: not touched here; cfg.mixed_precision stays as configured
    # by _sft_common(). Optimizer precision settings are likewise left as-is.

    # Checkpoint config - only the save interval and save directory are overridden
    cfg.checkpoint.save_interval = 50
    cfg.checkpoint.save = checkpoint_save_dir
    # To fine-tune from pretrained weights, set cfg.checkpoint.pretrained_checkpoint
    # to the directory containing the pretrained model.

    # DDP config
    cfg.ddp.grad_reduce_in_fp32 = False
    cfg.ddp.overlap_grad_reduce = False
    cfg.ddp.overlap_param_gather = False
    cfg.ddp.check_for_nan_in_grad = True
    # NOTE(review): confirm that the distributed optimizer is supported
    # together with use_torch_fsdp2=True.
    cfg.ddp.use_distributed_optimizer = True

    return cfg



def main() -> None:
    """Build the Qwen3 8B full-SFT config and pass it to `pretrain`."""
    cfg = new_qwen3_8b_sft_config()

    # Copy the model provider's vocab size onto the tokenizer config.
    cfg.tokenizer.vocab_size = cfg.model.vocab_size

    pretrain(cfg, forward_step)


if __name__ == "__main__":
    main()

Expected behavior

Training starts

Affected area

area:training

Regression?

Not sure

Environment

No response

Logs

[rank3]: Traceback (most recent call last):
[rank3]:   File "/workspaces/tgpt-megatron/example_bug.py", line 271, in <module>
[rank3]:     pretrain(cfg, forward_step)
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/utils/decorators.py", line 39, in wrapper
[rank3]:     return func(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/training/pretrain.py", line 99, in pretrain
[rank3]:     _pretrain(state=state, forward_step_func=forward_step_func, callback_manager=callback_manager)
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/training/pretrain.py", line 129, in _pretrain
[rank3]:     setup_output = setup(state, dataset_provider, restart_store=store)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/training/setup.py", line 221, in setup
[rank3]:     model = cfg.model.provide_distributed_model(
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 195, in provide_distributed_model
[rank3]:     model = get_model(
[rank3]:             ^^^^^^^^^^
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 603, in get_model
[rank3]:     model = _ddp_wrap(
[rank3]:             ^^^^^^^^^^
[rank3]:   File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/models/common/unimodal.py", line 233, in _ddp_wrap
[rank3]:     DP(
[rank3]: TypeError: TorchFullyShardedDataParallel.__init__() got an unexpected keyword argument 'pg_collection'

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions