Fix head_size in NeMo to HF checkpoint converters for width pruned model support#11230

Merged
kevalmorabia97 merged 2 commits intoNVIDIA-NeMo:mainfrom
eagle705:fix-nemo-to-hf-llama-converter
Nov 15, 2024
Conversation

@eagle705 (Collaborator) commented Nov 8, 2024

What does this PR do?

  • Update the llama checkpoint converter to support width-pruned models by deriving head_size from kv_channels.
  • When hidden_size is width-pruned, the attention head size can no longer be inferred as hidden_size // head_num, so it must be taken from kv_channels.

Changelog

  • Update the llama checkpoint converter's attention head dimension to use kv_channels:
  • head_size = model.cfg.get("kv_channels") or (hidden_size // head_num)

Usage

    hidden_size = model.cfg.hidden_size
    head_num = model.cfg.num_attention_heads
    num_layers = model.cfg.num_layers
    ffn_hidden_size = model.cfg.ffn_hidden_size
    num_query_groups = model.cfg.get("num_query_groups", head_num)  # different num_query_groups for 70B

-   head_size = hidden_size // head_num
+   head_size = model.cfg.get("kv_channels") or (hidden_size // head_num)  # equivalent to HF's head_dim
    heads_per_group = head_num // num_query_groups
    qkv_total_dim = head_num + 2 * num_query_groups
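To illustrate why the fallback matters, here is a minimal standalone sketch (with illustrative, hypothetical config values, not taken from the PR) of the converter's head-size logic: for a width-pruned model, hidden_size shrinks while each head keeps its original channel count, so the old formula silently computes the wrong head size.

```python
def resolve_head_size(cfg):
    """Mirror of the converter's fallback: prefer kv_channels if set,
    otherwise fall back to hidden_size // num_attention_heads."""
    return cfg.get("kv_channels") or (cfg["hidden_size"] // cfg["num_attention_heads"])

# Unpruned Llama-3-8B-style config: kv_channels unset, the fallback is correct.
unpruned = {"hidden_size": 4096, "num_attention_heads": 32}
assert resolve_head_size(unpruned) == 128

# Width-pruned config (illustrative values): hidden_size was cut to 3072
# while each head kept 128 channels, so the old formula would wrongly
# yield 96 and the exported weights would be reshaped incorrectly.
pruned = {"hidden_size": 3072, "num_attention_heads": 32, "kv_channels": 128}
assert resolve_head_size(pruned) == 128
assert pruned["hidden_size"] // pruned["num_attention_heads"] == 96
```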

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

@kevalmorabia97

Additional Information

Pruned values in HuggingFace's config.json should be updated accordingly.
I added head_dim to HF's config.json with the same value as kv_channels in NeMo.

@eagle705 eagle705 force-pushed the fix-nemo-to-hf-llama-converter branch from 318009a to c6975a1 Compare November 8, 2024 13:27
@kevalmorabia97 (Collaborator)

You need to sign off the commit (git commit -s -m "MESSAGE"), otherwise NeMo won't allow the merge: https://github.com/NVIDIA/NeMo/pull/11230/checks?check_run_id=32714934933

@eagle705 (Collaborator, Author) commented Nov 8, 2024

FYI @kevalmorabia97: when I add the tokenizer-related args, I hit the errors below.

I put a manually modified config file in hf_input_path and wrote the converted checkpoint to the same path (--hf_output_path).

  • Script
python /opt/NeMo/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py \
    --input_name_or_path ${base_dir}/KD/distill_trainings/megatron_llama_distill/checkpoints/distilled_4b_model_extracted \
    --output_path ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst/pytorch_model.bin \
    --hf_input_path ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst \
    --hf_output_path ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst
  • Error msg
converting layer 31
done layer 31
[NeMo I 2024-11-08 13:34:37 convert_llama_nemo_to_hf:228] Weights saved to /<base_dir>/ckpt/distilled_4b_model_llama3_8b_inst/pytorch_model.bin
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.55it/s]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizer'.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/opt/NeMo/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py", line 275, in <module>
[rank0]:     replace_hf_weights_and_tokenizer(
[rank0]:   File "/opt/NeMo/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py", line 249, in replace_hf_weights_and_tokenizer
[rank0]:     tokenizer = LlamaTokenizer.from_pretrained(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 2213, in from_pretrained
[rank0]:     return cls._from_pretrained(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 2447, in _from_pretrained
[rank0]:     tokenizer = cls(*init_inputs, **init_kwargs)
[rank0]:   File "/opt/NeMo/nemo/lightning/io/mixin.py", line 577, in wrapped_init
[rank0]:     original_init(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/tokenization_llama.py", line 169, in __init__
[rank0]:     self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/tokenization_llama.py", line 199, in get_spm_processor
[rank0]:     with open(self.vocab_file, "rb") as f:
[rank0]: TypeError: expected str, bytes or os.PathLike object, not NoneType
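The traceback boils down to a slow/fast tokenizer mismatch: LlamaTokenizer is the slow, sentencepiece-based class and requires a tokenizer.model vocab file, while Llama-3 checkpoints ship only tokenizer.json for PreTrainedTokenizerFast, so vocab_file resolves to None and open(None) raises the TypeError. A minimal sketch of a pre-flight check for this condition (check_slow_tokenizer_assets is a hypothetical helper, not part of NeMo or transformers):

```python
import os
import tempfile

def check_slow_tokenizer_assets(checkpoint_dir):
    """Return the sentencepiece vocab path, or None if the checkpoint only
    ships a fast tokenizer (tokenizer.json) -- the Llama-3 case above."""
    vocab_file = os.path.join(checkpoint_dir, "tokenizer.model")
    return vocab_file if os.path.isfile(vocab_file) else None

with tempfile.TemporaryDirectory() as ckpt:
    # Simulate a Llama-3-style checkpoint: fast tokenizer only, no tokenizer.model.
    open(os.path.join(ckpt, "tokenizer.json"), "w").close()
    # LlamaTokenizer.from_pretrained would fail on this checkpoint.
    assert check_slow_tokenizer_assets(ckpt) is None
```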

@eagle705 eagle705 force-pushed the fix-nemo-to-hf-llama-converter branch from 1e9c039 to fac2270 Compare November 8, 2024 13:56
@eagle705 (Collaborator, Author) commented Nov 8, 2024

Thanks @kevalmorabia97, I applied the sign-off to the commit.

@kevalmorabia97 kevalmorabia97 changed the title update head_size to kv_channels for width pruning support Fix head_size in NeMo to HF checkpoint converters for width pruned model support Nov 8, 2024
cuichenx
cuichenx previously approved these changes Nov 8, 2024
@cuichenx (Collaborator) left a comment:

LGTM

suiyoubi
suiyoubi previously approved these changes Nov 8, 2024
@suiyoubi suiyoubi added Run CICD and removed Run CICD labels Nov 8, 2024
Signed-off-by: Joosung <joosungy@nvidia.com>
@eagle705 eagle705 dismissed stale reviews from suiyoubi and cuichenx via 5695859 November 8, 2024 14:37
@eagle705 eagle705 force-pushed the fix-nemo-to-hf-llama-converter branch from fac2270 to 5695859 Compare November 8, 2024 14:37
@kevalmorabia97 (Collaborator)

@cuichenx can you please comment on the tokenizer issue posted above? Does it require some other parameters to be updated in the tokenizer loading logic?

kevalmorabia97
kevalmorabia97 previously approved these changes Nov 8, 2024
@cuichenx (Collaborator) commented Nov 8, 2024

can you try setting hf_output_path to a different dir?

@eagle705 (Collaborator, Author) commented Nov 9, 2024

> can you try setting hf_output_path to a different dir?

@cuichenx I tried a different dir, but it didn't work 🤔. I got the same error message.

  • modified script
python /opt/NeMo/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py \
    --input_name_or_path ${base_dir}/KD/distill_trainings/megatron_llama_distill/checkpoints/distilled_4b_model_extracted \
    --output_path ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst/pytorch_model.bin \
    --hf_input_path ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst \
    --hf_output_path ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst_2 \
    --input_tokenizer ${base_dir}/ckpt/Meta-Llama-3-8B-Instruct \
    --hf_output_tokenizer ${base_dir}/ckpt/distilled_4b_model_llama3_8b_inst_2

@kevalmorabia97 (Collaborator) commented Nov 12, 2024

@cuichenx the same tokenizer issue occurs with a different output path as well. Can you suggest how to fix it?
The weight conversion works fine, and pruning never touches the tokenizer, so this doesn't seem related to ModelOpt pruning.

@cuichenx (Collaborator)

I'm not sure. @tdene added support for the tokenizer. Do you have any idea?

@tdene (Contributor) commented Nov 13, 2024

Discussed offline. My comments on how to fix the tokenizer-related issue are as follows:

Change the script explanation to

"""
Script to convert a llama checkpoint in nemo (mcore path) into a HuggingFace checkpoint.
This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.

1) Generate only HF weights from a nemo file:

    python convert_llama_nemo_to_hf.py \
    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
    --output_path /path/to/pytorch_model.bin
    
2) Generate the full HF model folder

    python convert_llama_nemo_to_hf.py \
    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
    --output_path /path/to/pytorch_model.bin \
    --hf_input_path /path/to/input_hf_folder \
    --hf_output_path /path/to/output_hf_folder

3) Generate the full HF model folder with a custom tokenizer

    python convert_llama_nemo_to_hf.py \
    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
    --output_path /path/to/pytorch_model.bin \
    --hf_input_path /path/to/input_hf_folder \
    --hf_output_path /path/to/output_hf_folder \
    --input_tokenizer /path/to/custom_nemo_tokenizer.model \
    --hf_output_tokenizer /path/to/output_tokenizer

    Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Llama2 70b).
    However, this option makes the conversion script significantly slower.
"""

Change L248 through 257 to

    if tokenizer_path:
        try:
            tokenizer = LlamaTokenizer.from_pretrained(
                tokenizer_path,
                local_files_only=True,
                legacy=False,
            )
            tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer)
            fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer)
            tokenizer_length = len(fast_tokenizer)
            model.resize_token_embeddings(tokenizer_length)
        except:
            tokenizer = None
            logger.warning("Could not load custom tokenizer, proceeding with default tokenizer")

Change L263 through 266 to

    if tokenizer_path and (tokenizer is not None):
        fast_tokenizer.save_pretrained(output_hf_tokenizer)
        tokenizer.save_pretrained(output_hf_tokenizer)
        logging.info(f"Tokenizer saved to {output_hf_tokenizer}")

Signed-off-by: Joosung <joosungy@nvidia.com>

Check notice: Code scanning / CodeQL

Except block handles 'BaseException'

The bare except: in the suggested tokenizer-loading change directly handles BaseException. Flagged lines:

    fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer)
    tokenizer_length = len(fast_tokenizer)
    model.resize_token_embeddings(tokenizer_length)
    except:
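The CodeQL notice can be addressed by catching only the exceptions the tokenizer load is expected to raise; a minimal sketch of that pattern (load_or_none and fake_loader are hypothetical stand-ins, not NeMo code):

```python
def load_or_none(loader, path):
    """Try a tokenizer-style loader; fall back to None on expected failures.
    Catching specific exception types (instead of a bare `except:`) lets
    KeyboardInterrupt and SystemExit, which subclass BaseException, propagate."""
    try:
        return loader(path)
    except (OSError, TypeError, ValueError) as e:
        print(f"Could not load custom tokenizer ({e}), proceeding with default tokenizer")
        return None

# Hypothetical loader that fails the way LlamaTokenizer does on a
# fast-only checkpoint (see the traceback earlier in this thread).
def fake_loader(path):
    raise TypeError("expected str, bytes or os.PathLike object, not NoneType")

assert load_or_none(fake_loader, "/tmp/ckpt") is None
```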
@kevalmorabia97 kevalmorabia97 enabled auto-merge (squash) November 15, 2024 11:07
@kevalmorabia97 kevalmorabia97 merged commit ed244d9 into NVIDIA-NeMo:main Nov 15, 2024
HuiyingLi pushed a commit to HuiyingLi/NeMo that referenced this pull request Nov 15, 2024
…del support (NVIDIA-NeMo#11230)

* update attn head_size to kv_channels for width pruning support

Signed-off-by: Joosung <joosungy@nvidia.com>

* Update llama ckpt converter usage about tokenizer args

Signed-off-by: Joosung <joosungy@nvidia.com>

---------

Signed-off-by: Joosung <joosungy@nvidia.com>
Co-authored-by: Joosung <joosungy@nvidia.com>
XuesongYang pushed a commit to paarthneekhara/NeMo that referenced this pull request Jan 18, 2025 (same commit message and sign-offs as the cherry-pick above).
youngeunkwon0405 pushed a commit to youngeunkwon0405/NeMo that referenced this pull request Feb 10, 2025 (same commit message as the cherry-pick above, additionally Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>).
5 participants