Is your feature request related to a problem? Please describe.
For various ModelOpt optimizations (quantization, pruning), we need to run forward-only inference on a small calibration dataset. Currently this requires a somewhat hacky workaround, but we would like it to be natively supported in Megatron Bridge via an API similar to NeMo 2.0's `llm.validate` API.
Describe the solution you'd like
A simpler version of `megatron.bridge.training.eval.evaluate_and_print_results` that does not require manually constructing and passing `state`, `forward_step`, `data_iterator`, etc. For example:
```python
from megatron.bridge import AutoBridge
from megatron.core.optimizer import OptimizerConfig
# ConfigContainer, TrainingConfig, HFDatasetConfig, TokenizerConfig, SchedulerConfig,
# LoggerConfig and CheckpointConfig are Megatron Bridge training config classes.

bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-0.6B")
provider = bridge.to_megatron_provider()
provider.finalize()
model = provider.provide_distributed_model(wrap_with_ddp=False)

cfg = ConfigContainer(
    model=provider,
    train=TrainingConfig(
        micro_batch_size=1,
        global_batch_size=32,
        train_iters=32,
        eval_iters=32,
        skip_train=True,
    ),
    dataset=HFDatasetConfig(...),  # Alternatively GPTDatasetConfig
    tokenizer=TokenizerConfig(
        tokenizer_type="HuggingFaceTokenizer",
        tokenizer_model="Qwen/Qwen3-0.6B",
    ),
    # Unused, but currently still required; otherwise setup errors out
    optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
    scheduler=SchedulerConfig(lr_decay_style="constant"),
    logger=LoggerConfig(),
    checkpoint=CheckpointConfig(),
)

# Proposed API (does not exist yet): run forward-only evaluation and return the LM loss
lm_loss = simpler_eval(model, cfg)
```
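To make the intended semantics of the proposed API concrete, here is a minimal framework-free sketch of the loop such a function could run: pull up to `eval_iters` batches, run a forward step on each, and return the token-weighted mean loss. The helper name `forward_only_eval` and the assumption that each forward step returns a `(summed token loss, token count)` pair are illustrative only, not existing Megatron Bridge APIs. A real implementation would also put the model in eval mode, run under `torch.no_grad()`, and all-reduce the loss and token counts across data-parallel ranks; those parts are omitted here.

```python
from typing import Callable, Iterable, Tuple, TypeVar

Batch = TypeVar("Batch")


def forward_only_eval(
    forward_fn: Callable[[Batch], Tuple[float, int]],
    batches: Iterable[Batch],
    eval_iters: int,
) -> float:
    """Run up to `eval_iters` forward-only steps and return the token-weighted mean loss.

    `forward_fn` is a hypothetical stand-in for a forward step plus loss function:
    it takes one batch and returns (summed token loss, number of tokens).
    """
    data_iterator = iter(batches)
    total_loss = 0.0
    total_tokens = 0
    for _ in range(eval_iters):
        try:
            batch = next(data_iterator)
        except StopIteration:
            # The calibration set may be smaller than eval_iters; stop early.
            break
        loss_sum, num_tokens = forward_fn(batch)
        total_loss += loss_sum
        total_tokens += num_tokens
    if total_tokens == 0:
        raise ValueError("no tokens were evaluated")
    # Weight by tokens rather than batches so short batches do not skew the mean.
    return total_loss / total_tokens
```

Weighting by token count rather than averaging per-batch losses keeps the result independent of how the calibration samples happen to be split into micro-batches.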
Describe alternatives you've considered
A somewhat hacky workaround (let me know if this can already be simplified):
```python
from datasets import DatasetDict
from megatron.bridge import AutoBridge
from megatron.core.optimizer import OptimizerConfig
# ConfigContainer, TrainingConfig, HFDatasetConfig, TokenizerConfig, SchedulerConfig,
# LoggerConfig and CheckpointConfig are Megatron Bridge training config classes.

bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-0.6B")
provider = bridge.to_megatron_provider()
provider.finalize()
model = provider.provide_distributed_model(wrap_with_ddp=False)

cfg = ConfigContainer(
    model=provider,
    train=TrainingConfig(
        micro_batch_size=1,
        global_batch_size=32,
        train_iters=32,
        eval_iters=32,
        skip_train=True,
    ),
    dataset=HFDatasetConfig(  # Alternatively GPTDatasetConfig
        dataset_name="cnn_dailymail",
        dataset_dict=DatasetDict({"train": ["sample1", "sample2", ...]}),
        process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
        seq_length=1024,
        dataloader_type="batch",
        num_workers=1,
        do_validation=False,
        do_test=False,
        val_proportion=None,
        split_val_from_train=False,
        rewrite=False,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_type="HuggingFaceTokenizer",
        tokenizer_model="Qwen/Qwen3-0.6B",
    ),
    # Unused, but currently still required; otherwise setup errors out
    optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
    scheduler=SchedulerConfig(lr_decay_style="constant"),
    logger=LoggerConfig(),
    checkpoint=CheckpointConfig(),
)
```
```python
# Current hacky approach
from megatron.bridge.training.eval import evaluate_and_print_results
from megatron.bridge.training.gpt_step import forward_step
from megatron.core.parallel_state import get_data_parallel_group
# runtime_config_update, GlobalState, get_dataset_provider and setup_data_iterators
# also come from Megatron Bridge (import paths omitted here).

runtime_config_update(cfg)
state = GlobalState()
state.cfg = cfg

dataset_provider = get_dataset_provider(cfg.dataset)


def _train_valid_test_datasets_provider(
    train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig
):
    return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer)


# Only the train split exists (do_validation=False), so its iterator is used for eval
train_data_iterator, _, _ = setup_data_iterators(
    cfg=cfg,
    train_state=state.train_state,
    model_length=len(model),
    train_valid_test_datasets_provider=_train_valid_test_datasets_provider,
    dp_group=get_data_parallel_group(),
)

evaluate_and_print_results(  # Ideally this would also return the validation loss
    state,
    prefix="iteration 1",
    forward_step_func=forward_step,
    data_iterator=train_data_iterator,
    model=model,
    config=cfg,
    verbose=True,
    write_to_tensorboard=False,
)
```
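For the ModelOpt use case motivating this request, the loss value itself is often not needed: ModelOpt's `mtq.quantize(model, config, forward_loop)` takes a callable that only pushes calibration batches through the model. A minimal sketch, assuming a `simpler_eval(model, cfg)` function like the one proposed in this issue, of adapting such an eval API to that callable shape. `make_forward_loop` is a hypothetical helper, not an existing API.

```python
from typing import Any, Callable


def make_forward_loop(
    eval_fn: Callable[[Any, Any], float], cfg: Any
) -> Callable[[Any], None]:
    """Wrap an eval function (e.g. the proposed simpler_eval) as a forward_loop(model) callable."""

    def forward_loop(model: Any) -> None:
        # Calibration only needs the forward passes; the returned loss is discarded.
        eval_fn(model, cfg)

    return forward_loop
```

The resulting callable could then be passed as the `forward_loop` argument when calibrating, so the same API would serve both validation (using the returned loss) and ModelOpt calibration (ignoring it).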
Happy to contribute the above approach as an API if it looks right, or to implement a better approach if you have one in mind.