
Support saving a custom Megatron model (e.g., a pruned variant) back to Hugging Face format #2036

@kevalmorabia97

Description


Is your feature request related to a problem? Please describe.
Currently, a Hugging Face (HF) model can be converted to Megatron, but saving the Megatron model back to HF requires the model architecture to be exactly the same as the original. Only the weights are allowed to change.

When we use ModelOpt to prune a GPT/Mamba model, the result has fewer layers and smaller hidden, FFN, attention, etc. dimensions. While we can easily save this as a Megatron checkpoint, we want to be able to save it as a Hugging Face checkpoint as well.
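To see why the existing export path rejects a pruned model, here is a minimal sketch that derives the linear-weight shapes of one decoder layer from a config. It assumes a simplified Llama/Qwen-style layer (grouped-query attention plus a gated MLP); the projection names and all numbers are illustrative assumptions, not read from any real checkpoint.

```python
def layer_weight_shapes(cfg: dict) -> dict:
    """Return (out_features, in_features) for each linear weight in one layer.

    Assumes a simplified GQA attention block plus a gated (SwiGLU-style) MLP.
    """
    h = cfg["hidden_size"]
    q_dim = cfg["num_attention_heads"] * cfg["head_dim"]
    kv_dim = cfg["num_key_value_heads"] * cfg["head_dim"]
    ffn = cfg["intermediate_size"]
    return {
        "q_proj": (q_dim, h),
        "k_proj": (kv_dim, h),
        "v_proj": (kv_dim, h),
        "o_proj": (h, q_dim),
        "gate_proj": (ffn, h),
        "up_proj": (ffn, h),
        "down_proj": (h, ffn),
    }


def shape_mismatches(original: dict, pruned: dict) -> list:
    """Names of per-layer weights whose shapes differ between two configs."""
    a, b = layer_weight_shapes(original), layer_weight_shapes(pruned)
    return sorted(name for name in a if a[name] != b[name])


# Hypothetical configs: only the FFN width and the layer count are pruned.
original = {"hidden_size": 1024, "intermediate_size": 3072,
            "num_attention_heads": 16, "num_key_value_heads": 8,
            "head_dim": 128, "num_hidden_layers": 28}
pruned = dict(original, intermediate_size=2048, num_hidden_layers=24)

print(shape_mismatches(original, pruned))  # ['down_proj', 'gate_proj', 'up_proj']
```

Every weight in that list, plus every layer slot beyond the pruned layer count, has no matching tensor in an HF model built from the original config. This is why the HF config has to be rewritten before the pruned weights can be saved.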

Describe the solution you'd like

Native support in the `bridge.save_hf_pretrained()` API.
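As a rough sketch of what native support would need internally, the config-rewrite step of the workaround below can be expressed as one mapping table from HF config attributes to Megatron-Core config attributes. `HF_FROM_MCORE` and `sync_hf_config` are hypothetical names, not part of the Megatron-Bridge API; the field pairs are copied from the workaround, and `SimpleNamespace` objects stand in for the real config classes. Per-layer lists such as `layer_types` still need separate handling.

```python
from types import SimpleNamespace

# Hypothetical table: HF config attribute -> Megatron-Core config attribute.
# The pairs mirror the hand-written assignments in the workaround.
HF_FROM_MCORE = {
    "hidden_size": "hidden_size",
    "intermediate_size": "ffn_hidden_size",
    "num_attention_heads": "num_attention_heads",
    "head_dim": "kv_channels",
    "num_key_value_heads": "num_query_groups",
    "mamba_num_heads": "mamba_num_heads",
    "mamba_head_dim": "mamba_head_dim",
    "moe_intermediate_size": "moe_ffn_hidden_size",
    "moe_shared_expert_intermediate_size": "moe_shared_expert_intermediate_size",
    "num_experts": "num_moe_experts",
    "n_routed_experts": "num_moe_experts",
    "num_hidden_layers": "num_layers",
}


def sync_hf_config(hf_cfg, mcore_cfg):
    """Copy (pruned) dimensions from mcore_cfg onto hf_cfg in place.

    Only attributes that already exist on hf_cfg are written, so one table
    can serve dense, MoE, and Mamba/hybrid configs alike.
    """
    for hf_name, mcore_name in HF_FROM_MCORE.items():
        if hasattr(hf_cfg, hf_name):
            setattr(hf_cfg, hf_name, getattr(mcore_cfg, mcore_name))
    # Derived field, computed as in the workaround. Assumes shared experts
    # exist whenever the HF config defines n_shared_experts.
    if hasattr(hf_cfg, "n_shared_experts"):
        hf_cfg.n_shared_experts = (
            mcore_cfg.moe_shared_expert_intermediate_size // mcore_cfg.moe_ffn_hidden_size
        )
    return hf_cfg


# Toy stand-ins for the real config objects (illustrative values only).
hf = SimpleNamespace(hidden_size=1024, intermediate_size=3072, num_attention_heads=16,
                     head_dim=128, num_key_value_heads=8, num_hidden_layers=28)
mcore = SimpleNamespace(hidden_size=1024, ffn_hidden_size=2048, num_attention_heads=16,
                        kv_channels=128, num_query_groups=8, num_layers=24)
sync_hf_config(hf, mcore)
print(hf.intermediate_size, hf.num_hidden_layers)  # 2048 24
```

Driving the rewrite from a table rather than a chain of `if hasattr(...)` blocks keeps the mapping in one place, so supporting a new architecture would mostly mean adding rows.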

Describe alternatives you've considered

The current workaround rewrites the HF config by hand, saves a dummy HF model with the pruned architecture, and then uses a new bridge built from that dummy model to export the pruned weights:

```python
from megatron.bridge import AutoBridge
from transformers import AutoConfig, AutoModelForCausalLM

output_hf_path = "./Qwen3-0.6B-pruned"  # hypothetical output directory

bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-0.6B")
provider = bridge.to_megatron_provider()
provider.finalize()
model = provider.provide_distributed_model(wrap_with_ddp=False)
unwrapped_model = model[0].module

kept_layer_nums = range(1, 25)  # 1-indexed
# Schematic pruning step: stands in for ModelOpt pruning (modelopt.torch.prune);
# the real call signature and arguments differ.
modelopt.prune(model, num_layers=24, hidden_size=1024, kept_layer_nums=kept_layer_nums)

# Save artifacts, overwrite config, create dummy weights, use new bridge to save pruned model weights
bridge.hf_pretrained.save_artifacts(output_hf_path)
hf_cfg = AutoConfig.from_pretrained(output_hf_path)
mcore_cfg = unwrapped_model.config

# Copy the pruned dimensions from the Megatron-Core config into the HF config
hf_cfg.hidden_size = mcore_cfg.hidden_size
hf_cfg.intermediate_size = mcore_cfg.ffn_hidden_size
hf_cfg.num_attention_heads = mcore_cfg.num_attention_heads
hf_cfg.head_dim = mcore_cfg.kv_channels
hf_cfg.num_key_value_heads = mcore_cfg.num_query_groups
# Architecture-specific fields (Mamba / MoE) are only set if the HF config has them
if hasattr(hf_cfg, "mamba_num_heads"):
    hf_cfg.mamba_num_heads = mcore_cfg.mamba_num_heads
if hasattr(hf_cfg, "mamba_head_dim"):
    hf_cfg.mamba_head_dim = mcore_cfg.mamba_head_dim
if hasattr(hf_cfg, "moe_intermediate_size"):
    hf_cfg.moe_intermediate_size = mcore_cfg.moe_ffn_hidden_size
if hasattr(hf_cfg, "moe_shared_expert_intermediate_size"):
    hf_cfg.moe_shared_expert_intermediate_size = (
        mcore_cfg.moe_shared_expert_intermediate_size
    )
if hasattr(hf_cfg, "num_experts"):
    hf_cfg.num_experts = mcore_cfg.num_moe_experts
if hasattr(hf_cfg, "n_routed_experts"):
    hf_cfg.n_routed_experts = mcore_cfg.num_moe_experts
if hasattr(hf_cfg, "n_shared_experts"):
    hf_cfg.n_shared_experts = (
        mcore_cfg.moe_shared_expert_intermediate_size // mcore_cfg.moe_ffn_hidden_size
    )
# Keep per-layer entries only for kept layers (enumerate is 0-indexed, kept_layer_nums is 1-indexed)
if hasattr(hf_cfg, "layer_types"):
    hf_cfg.layer_types = [
        lt for i, lt in enumerate(hf_cfg.layer_types) if i + 1 in kept_layer_nums
    ]
hf_cfg.num_hidden_layers = mcore_cfg.num_layers

# Save dummy pruned HF model to get the correct bridge for saving pruned weights
AutoModelForCausalLM.from_config(hf_cfg).save_pretrained(output_hf_path)
pruned_bridge = AutoBridge.from_hf_pretrained(output_hf_path)
pruned_bridge.save_hf_weights(model, output_hf_path)
```
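The `layer_types` step in the workaround relies on `kept_layer_nums` being 1-indexed while Python's `enumerate` is 0-indexed. A small helper that makes this conversion explicit, and checks the result against the pruned layer count, might look like the following. `filter_per_layer` is a hypothetical name, and the layer-type strings in the example are illustrative only.

```python
def filter_per_layer(values, kept_layer_nums, num_layers):
    """Keep the entries of a per-layer config list that belong to kept layers.

    kept_layer_nums is 1-indexed (as in the workaround); list positions are
    0-indexed, hence the `n - 1` below.
    """
    kept = sorted(set(kept_layer_nums))
    if len(kept) != num_layers:
        raise ValueError(
            f"{len(kept)} kept layers, but the pruned model has {num_layers}"
        )
    if kept and (kept[0] < 1 or kept[-1] > len(values)):
        raise ValueError("kept_layer_nums is out of range for this per-layer list")
    return [values[n - 1] for n in kept]


# Hypothetical hybrid layer pattern with 6 layers, pruned down to 4.
types = ["mamba", "attention", "mamba", "mlp", "mamba", "attention"]
print(filter_per_layer(types, [1, 2, 4, 6], num_layers=4))
# ['mamba', 'attention', 'mlp', 'attention']
```

Failing loudly on a count mismatch catches the case where `kept_layer_nums` and the pruned `num_layers` drift apart, which would otherwise produce an HF config whose per-layer list silently disagrees with `num_hidden_layers`.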

Metadata


Labels: area:ckpt (Checkpoint conversion, loading, export, and save paths), feature (New capabilities, enhancements, or enablement work), waiting-on-maintainers (Waiting on maintainers to respond)

Type: none
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests
