Support pure eval loop given a model and ConfigContainer #2033

@kevalmorabia97

Is your feature request related to a problem? Please describe.

For various ModelOpt optimizations (quantization, pruning), we need to run forward-only inference on a small calibration dataset. Currently this requires a somewhat hacky workaround; we would like it to be natively supported in Megatron Bridge via an API similar to NeMo 2's llm.validate.

Describe the solution you'd like

A simpler version of megatron.bridge.training.eval.evaluate_and_print_results where we don't have to manually construct and pass state, forward_step, data_iterator, etc.:

bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-0.6B")
provider = bridge.to_megatron_provider()
provider.finalize()
model = provider.provide_distributed_model(wrap_with_ddp=False)

cfg = ConfigContainer(
    model=provider,
    train=TrainingConfig(
        micro_batch_size=1,
        global_batch_size=32,
        train_iters=32,
        eval_iters=32,
        skip_train=True,
    ),
    dataset=HFDatasetConfig(...),  # Alternatively GPTDatasetConfig
    tokenizer=TokenizerConfig(
        tokenizer_type="HuggingFaceTokenizer",
        tokenizer_model="Qwen/Qwen3-0.6B",
    ),
    # Unused, but currently still required; omitting them raises an error
    optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
    scheduler=SchedulerConfig(lr_decay_style="constant"),
    logger=LoggerConfig(),
    checkpoint=CheckpointConfig(),
)

lm_loss = simpler_eval(model, cfg)
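
For illustration only, the return-value contract of such a helper could look like the sketch below. It is framework-agnostic plain Python, not Megatron Bridge code: `mean_eval_loss`, `forward_loss`, and `batches` are hypothetical names, and `forward_loss` stands in for a no-grad forward pass that returns a scalar loss for one batch. The point is only that the caller passes a model-like callable and an iteration budget, and gets the mean loss back.

```python
from typing import Callable, Iterable, Iterator, TypeVar

Batch = TypeVar("Batch")


def mean_eval_loss(
    forward_loss: Callable[[Batch], float],
    batches: Iterable[Batch],
    eval_iters: int,
) -> float:
    """Run up to `eval_iters` forward-only steps and return the mean loss.

    Hypothetical helper: `forward_loss` is assumed to perform a forward pass
    without gradient tracking and return a scalar loss for one batch.
    """
    if eval_iters <= 0:
        raise ValueError("eval_iters must be positive")
    iterator: Iterator[Batch] = iter(batches)
    total, steps = 0.0, 0
    for _ in range(eval_iters):
        try:
            batch = next(iterator)
        except StopIteration:
            break  # dataset shorter than eval_iters; average what was seen
        total += float(forward_loss(batch))
        steps += 1
    if steps == 0:
        raise ValueError("no batches were available for evaluation")
    return total / steps


# Toy usage: the "model" scores a batch by its mean absolute value.
toy_batches = [[1.0, -1.0], [2.0, 2.0], [0.0, 4.0]]
toy_loss = mean_eval_loss(
    lambda b: sum(abs(x) for x in b) / len(b), toy_batches, eval_iters=2
)
```

A real implementation would additionally need to build the data iterator from the config and reduce the loss across data-parallel ranks; those parts are deliberately left out here.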

Describe alternatives you've considered

A somewhat hacky approach (let me know if this can already be simplified):

bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-0.6B")
provider = bridge.to_megatron_provider()
provider.finalize()
model = provider.provide_distributed_model(wrap_with_ddp=False)

cfg = ConfigContainer(
    model=provider,
    train=TrainingConfig(
        micro_batch_size=1,
        global_batch_size=32,
        train_iters=32,
        eval_iters=32,
        skip_train=True,
    ),
    dataset=HFDatasetConfig(  # Alternatively GPTDatasetConfig
        dataset_name="cnn_dailymail",
        dataset_dict=DatasetDict({"train": ["sample1", "sample2", ...]}),
        process_example_fn=lambda example, tokenizer: {"input": example, "output": ""},
        seq_length=1024,
        dataloader_type="batch",
        num_workers=1,
        do_validation=False,
        do_test=False,
        val_proportion=None,
        split_val_from_train=False,
        rewrite=False,
    ),
    tokenizer=TokenizerConfig(
        tokenizer_type="HuggingFaceTokenizer",
        tokenizer_model="Qwen/Qwen3-0.6B",
    ),
    # Unused, but currently still required; omitting them raises an error
    optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False),
    scheduler=SchedulerConfig(lr_decay_style="constant"),
    logger=LoggerConfig(),
    checkpoint=CheckpointConfig(),
)


# Current Hacky Approach
from megatron.bridge.training.eval import evaluate_and_print_results
from megatron.bridge.training.gpt_step import forward_step

runtime_config_update(cfg)

state = GlobalState()
state.cfg = cfg

dataset_provider = get_dataset_provider(cfg.dataset)

def _train_valid_test_datasets_provider(
    train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig
):
    return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer)

train_data_iterator, _, _ = setup_data_iterators(
    cfg=cfg,
    train_state=state.train_state,
    model_length=len(model),
    train_valid_test_datasets_provider=_train_valid_test_datasets_provider,
    dp_group=get_data_parallel_group(),
)

evaluate_and_print_results(  # Request: have this function also return the validation loss
    state,
    prefix="iteration 1",
    forward_step_func=forward_step,
    data_iterator=train_data_iterator,
    model=model,
    config=cfg,
    verbose=True,
    write_to_tensorboard=False,
)

Happy to contribute the above approach as an API if it looks right, or to adopt a better approach if you have one.

Labels: area:training (Training loop, callbacks, and runtime integration), feature (New capabilities, enhancements, or enablement work), waiting-on-maintainers (Waiting on maintainers to respond)