
Generate: Fix modern llm generate calls with synced_gpus#34095

Merged
gante merged 5 commits into huggingface:main from gante:prepare_sync_gpus
Oct 12, 2024

Conversation

@gante
Contributor

@gante gante commented Oct 11, 2024

What does this PR do?

Step 5 in #32685
Fixes #32885
Fixes #32603
Fixes #32641

Modern LLMs, i.e. LLMs that support our cache classes, currently fail when the input has a batch size > 1 and synced_gpus = True.

On main, this is what happens with synced_gpus:

  1. cache_position stops being updated when generation finishes on a given device, causing cache indexing errors on that device (the cache keeps growing because we keep running dummy forward passes)
  2. if we instead keep updating cache_position, then slicing input_ids goes out of bounds during the dummy computations (input_ids stops being updated, so it stops growing)

This PR updates model input preparation so that generation works correctly despite the dummy forward passes described above.
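To make the two failure modes concrete, here is a toy, self-contained simulation of the bookkeeping (simulate_rank and all its variables are invented for illustration; this is not the transformers implementation): a rank that finishes early keeps running dummy forward passes, so its cache keeps growing while its frozen cache_position drifts out of sync.

```python
def simulate_rank(steps_until_done: int, total_steps: int) -> list:
    """Toy model of one rank's generation loop under synced_gpus=True.

    Returns the steps at which cache_position no longer points at the
    last cache entry, i.e. where cache indexing would go wrong.
    """
    input_len = 4          # prompt length
    cache_len = input_len  # KV cache length grows on every forward pass
    cache_position = input_len - 1
    drifting_steps = []
    for step in range(total_steps):
        done = step >= steps_until_done
        # Every rank, finished or not, runs a forward pass so collective
        # ops stay in sync -- the cache grows even on dummy steps.
        cache_len += 1
        if not done:
            cache_position += 1  # on main, frozen once the rank is done
        # cache_position must track the cache length; if it is frozen
        # while the cache grows, indexing drifts on the finished rank.
        if cache_position != cache_len - 1:
            drifting_steps.append(step)
    return drifting_steps

# A rank that finishes at step 3 of 6 drifts on every dummy step:
print(simulate_rank(steps_until_done=3, total_steps=6))  # → [3, 4, 5]
# A rank that never finishes early stays consistent:
print(simulate_rank(steps_until_done=6, total_steps=6))  # → []
```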

💛 Please note that, because of the efforts in #32685, updating model input preparation requires an update in a single function, as opposed to an update per model 💛


Test script (run with 2+ GPUs) that fails before this PR (from this comment):

import transformers
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    # dist.initialize_dist('gpu')

    name = 'meta-llama/Meta-Llama-3-8B-Instruct'
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    pad_token_id = tokenizer.eos_token_id
    model = transformers.AutoModelForCausalLM.from_pretrained(name)

    # rank = dist.get_global_rank()

    model.to(f'cuda:{rank}')

    if rank == 0:
        content = 'Write one short sentence.'
    else:
        content = 'Write one long paragraph.'

    messages = [
        {
            'role': 'user',
            'content': content,
        }
    ]

    tokenized_messages = tokenizer.apply_chat_template(messages, return_tensors='pt')

    # left-pad the tokenized prompt with (4096 - 20) pad tokens so both
    # ranks feed long, differently-sized prompts through generate
    padded_messages = torch.cat(
        [
            torch.LongTensor((4096 - 20) * [pad_token_id]),
            tokenized_messages[0],  # [seq]
        ],
        dim=0,
    )
    padded_messages = padded_messages.unsqueeze(0)
    padded_messages = padded_messages.to(f'cuda:{rank}')
    attention_mask = ~(padded_messages == pad_token_id)
    attention_mask = attention_mask.to(f'cuda:{rank}')
    output = model.generate(input_ids=padded_messages, attention_mask=attention_mask, synced_gpus=True, max_new_tokens=200)

    print(tokenizer.decode(output[0]))

def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

@gante
Contributor Author

gante commented Oct 11, 2024

@SunMarc this should help with FSDP + generate 🤗

Comment on lines -4165 to -4174
# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False
past_key_values = model_kwargs.get("past_key_values", None)
if isinstance(past_key_values, DynamicCache) or (
    isinstance(past_key_values, EncoderDecoderCache)
    and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
    if past_key_values.get_seq_length() == 0:
        start_from_empty_dynamic_cache = True

Contributor Author

@gante gante Oct 11, 2024


Simplifies logic in assisted generation: see the new is_first_iteration variable and its uses :)
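As a rough illustration of that simplification (a hedged sketch; generate_loop and its returned strings are invented for illustration and are not the transformers internals), the first-iteration special-casing can be tracked with a plain boolean instead of inspecting the cache object for emptiness:

```python
def generate_loop(num_steps: int) -> list:
    """Sketch: track the first generation step with a boolean flag."""
    outputs = []
    # Replaces the old pattern of checking whether the DynamicCache
    # (or EncoderDecoderCache) still has sequence length 0.
    is_first_iteration = True
    for _ in range(num_steps):
        if is_first_iteration:
            outputs.append("prefill")  # e.g. first step processes the prompt
        else:
            outputs.append("decode")   # e.g. later steps process one token
        is_first_iteration = False
    return outputs

print(generate_loop(3))  # → ['prefill', 'decode', 'decode']
```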

Collaborator

@ArthurZucker ArthurZucker left a comment


The decoupling between input preparation for generation and the modeling code is very nice.
I don't know how well we test this; if the slow CIs were failing before, then it's already covered and good to go!

@ringohoffman

This fixes the error I was seeing here:

Thank you so much!

@gante
Contributor Author

gante commented Oct 12, 2024

@ArthurZucker I don't think this is being tested!

@SunMarc -- I couldn't find any related test, but the multi-GPU tests have a more elaborate setup, so I could be missing something. Can you confirm?

Meanwhile, I'm merging since this PR unblocks users. If there is no test, I'll open a follow-up PR :)

@gante gante merged commit 37ea040 into huggingface:main Oct 12, 2024
@gante gante deleted the prepare_sync_gpus branch October 12, 2024 15:45
@SunMarc
Member

SunMarc commented Oct 14, 2024

@SunMarc -- I couldn't find any related test, but the multi-GPU tests have a more elaborate setup, so I could be missing something. Can you confirm?

I'm not aware of any tests related to multi-GPU and generate with synced_gpus=True. I will have a look at this since we also need to add them for DeepSpeed and FSDP! cc @muellerzr

@jiayuanmark

jiayuanmark commented Apr 16, 2025

Does this fully address the issue with generate calls with synced_gpus?

When use_cache=False, I encountered a mismatched tensor dimension due to:

cc: @gante

@gante
Contributor Author

gante commented Apr 18, 2025

@jiayuanmark the snippet in the PR header, which serves as a base test case, runs without problems 🤔

Would you be able to create a minimal reproducer for your issue, and open a new issue?



Development

Successfully merging this pull request may close these issues.

Multi GPU generate with llama shape error
Shape mismatch when generating with multiple processes

5 participants