Problem
Training with Torch FSDP2 fails to start
Minimal repro
Here is a slightly modified example script:
import os
import torch
from typing_extensions import TypedDict, Unpack
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.bridge import AutoBridge
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.finetune_utils import default_peft_config, default_squad_config
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.training.config import (
CheckpointConfig,
ConfigContainer,
DistributedInitConfig,
LoggerConfig,
RNGConfig,
TokenizerConfig,
TrainingConfig,
ValidationConfig,
)
# Sequence length passed to both the SQuAD dataset config and the model provider.
MAX_LENGTH = 1024
def _sft_common() -> ConfigContainer:
    """Create a base full-SFT (Supervised Fine-Tuning) ConfigContainer.

    The returned template targets full-parameter SFT (``peft=None``). The caller
    MUST set ``cfg.model`` and ``cfg.tokenizer.tokenizer_model`` before use; both
    are left as ``None`` placeholders here.

    Settings applied by this template:
        - Dataset: ``default_squad_config(seq_length=MAX_LENGTH)``.
        - Optimizer/scheduler: ``distributed_fused_adam_with_cosine_annealing``
          with ``max_lr=1e-6``, 65 warmup iterations and the precision-aware
          optimizer enabled.
        - Training: 10000 iterations, global batch size 32, micro batch size 1.
        - Distributed init: Torch FSDP2 (``use_torch_fsdp2=True``).
        - Mixed precision: bf16 with ``grad_reduce_in_fp32=False``.
        - Checkpoints: saved to and loaded from
          ``<cwd>/nemo_experiments/default/checkpoints`` in ``torch_dist``
          format, with ``pretrained_checkpoint`` unset.

    Returns:
        ConfigContainer: Base configuration template for full SFT.
    """
    # Default output directories, rooted at the current working directory.
    base_output_dir = os.path.join(os.getcwd(), "nemo_experiments")
    run_output_dir = os.path.join(base_output_dir, "default")
    checkpoint_dir = os.path.join(run_output_dir, "checkpoints")
    tensorboard_dir = os.path.join(run_output_dir, "tb_logs")

    # Optimizer and scheduler with a low LR for full SFT.
    # Reference configuration this was adapted from:
    #   scheduler:
    #     name: cosine_with_warmup
    #     t_warmup: 65ba  # 300 for non-tools
    #     alpha_f: 0.1
    #   optimizer:
    #     name: decoupled_adamw
    #     lr: 1.0e-6
    #     betas:
    #       - 0.9
    #       - 0.95
    #     eps: 1.0e-12
    #     weight_decay: 0.0
    # NOTE: adam_beta2 and weight_decay below differ from that reference.
    opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing(
        lr_warmup_iters=65,
        lr_decay_iters=None,  # Defaults to train_iters during validation
        max_lr=1e-6,  # Lower LR for full fine-tuning
        min_lr=0.0,
        adam_beta1=0.9,
        adam_beta2=0.98,  # Common for fine-tuning
        weight_decay=0.1,
        clip_grad=2.0,
    )
    opt_cfg.use_precision_aware_optimizer = True
    # opt_cfg.main_grads_dtype = torch.bfloat16
    # opt_cfg.main_params_dtype = torch.bfloat16

    cfg = ConfigContainer(
        # Model - MUST be set by each recipe before use
        model=None,  # type: ignore[arg-type]
        # Training config
        train=TrainingConfig(
            train_iters=10000,
            global_batch_size=32,
            micro_batch_size=1,
        ),
        validation=ValidationConfig(
            eval_interval=100,
            eval_iters=32,
        ),
        # Optimizer and scheduler
        optimizer=opt_cfg,
        scheduler=scheduler_cfg,
        # DDP config - minimal settings, model-specific configs can override
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=False,
        ),
        # Dataset config - SQuAD at the shared sequence length
        dataset=default_squad_config(seq_length=MAX_LENGTH),
        # Logger config
        logger=LoggerConfig(
            log_interval=1,
            tensorboard_dir=tensorboard_dir,
            log_timers_to_tensorboard=True,
            log_throughput=True,
            log_throughput_to_tensorboard=True,
            log_l2_norm_grad_to_tensorboard=True,
            throughput_window_size=10,
        ),
        # Tokenizer - placeholder, each recipe should set tokenizer_model
        tokenizer=TokenizerConfig(
            tokenizer_type="HuggingFaceTokenizer",
            tokenizer_model=None,  # Must be set by each recipe
        ),
        # Checkpoint config with pretrained_checkpoint support
        checkpoint=CheckpointConfig(
            save_interval=100,
            save=checkpoint_dir,
            load=checkpoint_dir,
            pretrained_checkpoint=None,  # Set to load from pretrained weights
            ckpt_format="torch_dist",
            # ckpt_format="fsdp_dtensor",
            fully_parallel_save=True,
        ),
        # RNG config
        rng=RNGConfig(seed=5678),
        # Distributed init config - Torch FSDP2 is the code path under test
        dist=DistributedInitConfig(use_torch_fsdp2=True),
        # dist=DistributedInitConfig(use_megatron_fsdp=True),
        comm_overlap=None,
        # Mixed precision - bf16
        mixed_precision=MixedPrecisionConfig(
            bf16=True,
            grad_reduce_in_fp32=False,
        ),
        # No PEFT for full SFT
        peft=None,
    )
    return cfg
def new_qwen3_8b_sft_config(
    model_path: str = "/mnt/models/Qwen3-8B",
    checkpoint_save_dir: str = "/mnt/p.geyn/nemo_exps",
) -> ConfigContainer:
    """Return a full SFT config for Qwen3 8B built on top of ``_sft_common``.

    Parallelism as configured here: TP=1, PP=1, CP=1, sequence parallelism off.
    Torch FSDP2 is enabled by the ``_sft_common`` template.

    Args:
        model_path: Hugging Face model directory; used both to build the Megatron
            model provider and as the tokenizer path.
        checkpoint_save_dir: Directory that ``cfg.checkpoint.save`` is set to.

    Returns:
        ConfigContainer: The base SFT template with model, tokenizer, parallelism,
        recompute, checkpoint and DDP settings filled in.
    """
    cfg = _sft_common()

    # Model config (weights are not loaded here, only the provider is created)
    cfg.model = AutoBridge.from_hf_pretrained(model_path).to_megatron_provider(load_weights=False)
    cfg.model.seq_length = MAX_LENGTH

    # Tokenizer
    cfg.tokenizer.tokenizer_model = model_path

    # Parallelism settings
    cfg.model.tensor_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_layout = None
    cfg.model.pipeline_dtype = None
    cfg.model.virtual_pipeline_model_parallel_size = None
    cfg.model.context_parallel_size = 1
    cfg.model.sequence_parallel = False

    # Overrides the global batch size of 32 set in _sft_common
    cfg.train.global_batch_size = 8

    # Set pad_seq_to_mult for context parallelism (no-op while CP == 1)
    if cfg.model.context_parallel_size > 1:
        cfg.dataset.packed_sequence_specs.pad_seq_to_mult = cfg.model.context_parallel_size * 2

    # Training config
    cfg.validation.eval_interval = 30
    cfg.train.manual_gc = False
    cfg.train.manual_gc_interval = 0

    # TE (Transformer Engine)
    cfg.model.transformer_impl = "transformer_engine"

    # CUDA Graph - impl is "none"; scope and warmup steps are still assigned
    cfg.model.use_te_rng_tracker = True
    cfg.model.cuda_graph_impl = "none"
    cfg.model.cuda_graph_scope = "full"
    cfg.model.cuda_graph_warmup_steps = 3

    # Kernel selections
    cfg.model.attention_backend = 'auto'
    cfg.model.cross_entropy_loss_fusion = True
    cfg.model.cross_entropy_fusion_impl = "native"

    # Memory saving (recompute & offloading): full recompute over all layers
    cfg.model.recompute_granularity = "full"
    # NOTE(review): "uniform" looks copied from recompute_method; confirm the
    # value type recompute_modules expects before relying on it.
    cfg.model.recompute_modules = "uniform"
    cfg.model.recompute_method = "uniform"
    cfg.model.fine_grained_activation_offloading = False
    cfg.model.recompute_num_layers = cfg.model.num_layers
    cfg.model.offload_modules = None

    # FP8 & MXFP8 (mixed_precision settings)
    # Note: bf16 mixed precision is set in _sft_common (MixedPrecisionConfig(bf16=True)).
    # FP8 is not enabled; uncomment the lines below to experiment with it.
    # cfg.mixed_precision.fp8_recipe = "tensorwise"
    # cfg.mixed_precision.fp8 = None
    # cfg.mixed_precision.fp8_param_gather = False
    # cfg.mixed_precision.reuse_grad_buf_for_mxfp8_param_ag = False

    # Optimizer precision settings
    # cfg.optimizer.use_precision_aware_optimizer = False
    # cfg.optimizer.main_grads_dtype = torch.float32
    # cfg.optimizer.main_params_dtype = torch.float32
    # cfg.optimizer.exp_avg_dtype = torch.float32
    # cfg.optimizer.exp_avg_sq_dtype = torch.float32

    # Checkpoint config
    cfg.checkpoint.save_interval = 50
    # NOTE(review): only `save` is overridden; `load` keeps the directory set in
    # _sft_common, so save and load point at different locations - confirm intended.
    cfg.checkpoint.save = checkpoint_save_dir
    # cfg.checkpoint.load = "path/to/load"
    # Uncomment below if using a pretrained checkpoint and provide the path to the
    # directory containing the pretrained model for finetuning
    # cfg.checkpoint.pretrained_checkpoint = "/path/to/checkpoint"

    # DDP config
    cfg.ddp.grad_reduce_in_fp32 = False
    cfg.ddp.overlap_grad_reduce = False
    cfg.ddp.overlap_param_gather = False
    cfg.ddp.check_for_nan_in_grad = True
    cfg.ddp.use_distributed_optimizer = True

    return cfg
if __name__ == "__main__":
    # Build the Qwen3 8B full-SFT config defined in this file.
    # cfg = qwen3_8b_finetune_config(seq_length=1024, peft=None, hf_path="/mnt/models")
    cfg = new_qwen3_8b_sft_config()

    # Override training parameters
    # cfg.train.train_iters = 10
    # cfg.scheduler.lr_decay_iters = 10000
    # cfg.model.vocab_size = 8192

    # Copy the model's vocab size onto the tokenizer config before training.
    cfg.tokenizer.vocab_size = cfg.model.vocab_size

    pretrain(cfg, forward_step)
Expected behavior
Training starts
Affected area
area:training
Regression?
Not sure
Environment
No response
Logs
[rank3]: Traceback (most recent call last):
[rank3]: File "/workspaces/tgpt-megatron/example_bug.py", line 271, in <module>
[rank3]: pretrain(cfg, forward_step)
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/utils/decorators.py", line 39, in wrapper
[rank3]: return func(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/training/pretrain.py", line 99, in pretrain
[rank3]: _pretrain(state=state, forward_step_func=forward_step_func, callback_manager=callback_manager)
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/training/pretrain.py", line 129, in _pretrain
[rank3]: setup_output = setup(state, dataset_provider, restart_store=store)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/training/setup.py", line 221, in setup
[rank3]: model = cfg.model.provide_distributed_model(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 195, in provide_distributed_model
[rank3]: model = get_model(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 603, in get_model
[rank3]: model = _ddp_wrap(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspaces/tgpt-megatron/Megatron-Bridge/src/megatron/bridge/models/common/unimodal.py", line 233, in _ddp_wrap
[rank3]: DP(
[rank3]: TypeError: TorchFullyShardedDataParallel.__init__() got an unexpected keyword argument 'pg_collection'
Problem
Training with Torch FSDP2 fails to start
Minimal repro
Here is a slightly modified example script:

```python
import os
import torch
from typing_extensions import TypedDict, Unpack
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.bridge import AutoBridge
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.finetune_utils import default_peft_config, default_squad_config
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    DistributedInitConfig,
    LoggerConfig,
    RNGConfig,
    TokenizerConfig,
    TrainingConfig,
    ValidationConfig,
)

MAX_LENGTH = 1024


def _sft_common() -> ConfigContainer:
    """Create a base SFT (Supervised Fine-Tuning) ConfigContainer with common defaults.
    This function returns a ConfigContainer template with sensible defaults for full SFT
    (not LoRA/DoRA). The caller MUST set `cfg.model` and `cfg.tokenizer.tokenizer_model`
    before use.
    Key differences from pre-training:
    - Uses HFDatasetConfig with SQuAD as default dataset
    - Lower learning rate (5e-6) suitable for full fine-tuning
    - Fewer training iterations (1000)
    - Smaller batch sizes
    - Supports pretrained_checkpoint loading
    - No PEFT (full parameter training)
    Returns:
    ConfigContainer: Base configuration template for full SFT.
    """
    # Default output directories
    base_output_dir = os.path.join(os.getcwd(), "nemo_experiments")
    run_output_dir = os.path.join(base_output_dir, "default")
    checkpoint_dir = os.path.join(run_output_dir, "checkpoints")
    tensorboard_dir = os.path.join(run_output_dir, "tb_logs")
    # Default sequence length for SFT
    seq_length = 1024
    # Packed sequence is enabled by default for training efficiency
    # pad_seq_to_mult should be set to context_parallel_size * 2 if CP > 1
    packed_sequence = True
    pad_seq_to_mult = 1 # Override in model config if context_parallel_size > 1
    # Optimizer and scheduler with lower LR for full SFT
    # scheduler:
    # name: cosine_with_warmup
    # t_warmup: 65ba # 300 for non-tools
    # alpha_f: 0.1
    # optimizer:
    # name: decoupled_adamw
    # lr: 1.0e-6
    # betas:
    # - 0.9
    # - 0.95
    # eps: 1.0e-12
    # weight_decay: 0.0
    opt_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing(
        lr_warmup_iters=65,
        lr_decay_iters=None, # Defaults to train_iters during validation
        max_lr=1e-6, # Lower LR for full fine-tuning
        min_lr=0.0,
        adam_beta1=0.9,
        adam_beta2=0.98, # Common for fine-tuning
        weight_decay=0.1, # alpha in llmfoundry
        clip_grad=2.0,
    )
    opt_cfg.use_precision_aware_optimizer = True
    # opt_cfg.main_grads_dtype = torch.bfloat16
    # opt_cfg.main_params_dtype = torch.bfloat16
    cfg = ConfigContainer(
        # Model - MUST be set by each recipe before use
        model=None, # type: ignore[arg-type]
        # Training config - shorter training for SFT
        train=TrainingConfig(
            train_iters=10000,
            global_batch_size=32,
            micro_batch_size=1,
        ),
        validation=ValidationConfig(
            eval_interval=100,
            eval_iters=32,
        ),
        # Optimizer and scheduler
        optimizer=opt_cfg,
        scheduler=scheduler_cfg,
        # DDP config - minimal settings, model-specific configs can override
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            # grad_reduce_in_fp32=True,
            grad_reduce_in_fp32=False,
        ),
        # Dataset config - uses SQuAD with packed sequences by default
        dataset=default_squad_config(seq_length=MAX_LENGTH),
        # Logger config
        logger=LoggerConfig(
            log_interval=1,
            tensorboard_dir=tensorboard_dir,
            log_timers_to_tensorboard=True,
            log_throughput=True,
            log_throughput_to_tensorboard=True,
            log_l2_norm_grad_to_tensorboard=True,
            throughput_window_size=10,
        ),
        # Tokenizer - placeholder, each recipe should set tokenizer_model
        tokenizer=TokenizerConfig(
            tokenizer_type="HuggingFaceTokenizer",
            tokenizer_model=None, # Must be set by each recipe
        ),
        # Checkpoint config with pretrained_checkpoint support
        checkpoint=CheckpointConfig(
            save_interval=100,
            save=checkpoint_dir,
            load=checkpoint_dir,
            pretrained_checkpoint=None, # Set to load from pretrained weights
            ckpt_format="torch_dist",
            # ckpt_format="fsdp_dtensor",
            fully_parallel_save=True,
        ),
        # RNG config - different seed from pretrain
        rng=RNGConfig(seed=5678),
        # Distributed init config
        dist=DistributedInitConfig(use_torch_fsdp2=True),
        # dist=DistributedInitConfig(use_megatron_fsdp=True),
        comm_overlap=None,
        # Mixed precision - bf16 by default
        mixed_precision=MixedPrecisionConfig(
            bf16=True,
            grad_reduce_in_fp32=False,
        ),
        # No PEFT for full SFT
        peft=None,
    )
    return cfg


def new_qwen3_8b_sft_config() -> ConfigContainer:
    """Return a full SFT config for Qwen3 8B.
    Recommended parallelism: TP=4, PP=1 (1 node, 8 GPUs)
    """
    cfg = _sft_common()
    model_path = "/mnt/models/Qwen3-8B"
    # Model config
    cfg.model = AutoBridge.from_hf_pretrained(model_path).to_megatron_provider(load_weights=False)
    cfg.model.seq_length = MAX_LENGTH
    # Tokenizer
    cfg.tokenizer.tokenizer_model = model_path
    # Parallelism settings
    cfg.model.tensor_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_layout = None
    cfg.model.pipeline_dtype = None
    cfg.model.virtual_pipeline_model_parallel_size = None
    cfg.model.context_parallel_size = 1
    cfg.model.sequence_parallel = False
    # Sequence length (2048 for packed sequences)
    # cfg.model.seq_length = 2048
    # Global batch size is 8 for packed sequences, 128 otherwise
    cfg.train.global_batch_size = 8
    # Set pad_seq_to_mult for context parallelism
    if cfg.model.context_parallel_size > 1:
        cfg.dataset.packed_sequence_specs.pad_seq_to_mult = cfg.model.context_parallel_size * 2
    # Training config
    cfg.validation.eval_interval = 30
    cfg.train.manual_gc = False
    cfg.train.manual_gc_interval = 0
    # TE (Transformer Engine)
    cfg.model.transformer_impl = "transformer_engine"
    # CUDA Graph
    cfg.model.use_te_rng_tracker = True
    cfg.model.cuda_graph_impl = "none"
    cfg.model.cuda_graph_scope = "full"
    cfg.model.cuda_graph_warmup_steps = 3
    # Kernel selections
    cfg.model.attention_backend = 'auto'
    cfg.model.cross_entropy_loss_fusion = True
    cfg.model.cross_entropy_fusion_impl = "native"
    # Memory saving (recompute & offloading)
    cfg.model.recompute_granularity = "full"
    cfg.model.recompute_modules = "uniform"
    cfg.model.recompute_method = "uniform"
    cfg.model.fine_grained_activation_offloading = False
    cfg.model.recompute_num_layers = cfg.model.num_layers
    cfg.model.offload_modules = None
    # FP8 & MXFP8 (mixed_precision settings)
    # Note: mixed_precision="bf16_mixed" is set in _sft_common as default
    # These are defaults for FP8, enable them if using FP8 - FP8 is not enabled by default
    # cfg.mixed_precision.fp8_recipe = "tensorwise" # default, uncomment to enable
    # cfg.mixed_precision.fp8 = None # not enabled by default
    # cfg.mixed_precision.fp8_param_gather = False # default
    # cfg.mixed_precision.reuse_grad_buf_for_mxfp8_param_ag = False # default
    # Optimizer precision settings
    # cfg.optimizer.use_precision_aware_optimizer = False
    # cfg.optimizer.main_grads_dtype = torch.float32
    # cfg.optimizer.main_params_dtype = torch.float32
    # cfg.optimizer.exp_avg_dtype = torch.float32
    # cfg.optimizer.exp_avg_sq_dtype = torch.float32
    # Checkpoint config
    cfg.checkpoint.save_interval = 50
    # cfg.checkpoint.save and cfg.checkpoint.load are set in _sft_common. To override:
    cfg.checkpoint.save = "/mnt/p.geyn/nemo_exps"
    # cfg.checkpoint.load = "path/to/load"
    # Uncomment below if using a pretrained checkpoint and provide path to the directory containing pretrained model for finetuning
    # cfg.checkpoint.pretrained_checkpoint = "/path/to/checkpoint"
    # DDP config
    cfg.ddp.grad_reduce_in_fp32 = False
    cfg.ddp.overlap_grad_reduce = False
    cfg.ddp.overlap_param_gather = False
    cfg.ddp.check_for_nan_in_grad = True
    cfg.ddp.use_distributed_optimizer = True
    return cfg


if __name__ == "__main__":
    # The recipe uses the Llama 3.2 1B model configuration from HuggingFace
    # cfg = qwen3_8b_finetune_config(seq_length=1024, peft=None, hf_path="/mnt/models")
    cfg = new_qwen3_8b_sft_config()
    # Override training parameters
    # cfg.train.train_iters = 10
    # cfg.scheduler.lr_decay_iters = 10000
    # cfg.model.vocab_size = 8192
    cfg.tokenizer.vocab_size = cfg.model.vocab_size
    pretrain(cfg, forward_step)
```

Expected behavior
Training starts
Affected area
area:training
Regression?
Not sure
Environment
No response
Logs