Conversational Pipeline returns <|im_end|> in the assistant's output. #28801

@OfficialDelta

Description

System Info

  • transformers version: 4.37.2
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.26.1
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: DEEPSPEED
    - use_cpu: False
    - debug: True
    - num_processes: 8
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - deepspeed_config: {'deepspeed_config_file': '/workspace/zero3.json', 'zero3_init_flag': True}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.2.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@Narsil
@Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to run inference on a custom fine-tuned Mixtral-8x7B-Instruct-v0.1 model. The fine-tuning dataset I generated uses the ChatML format for tokenizing the data, and when I run inference, the conversational pipeline returns the literal <|im_end|> text at the end of the assistant's reply.

Here is a minimal working example:

import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline
)
from peft import PeftModelForCausalLM

# load mixtral quantized because inferencing on a single GPU
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", 
        trust_remote_code=True, quantization_config=bnb_config,
)

# load the custom LoRA adapter for the fine-tuned chatml model
lora_model = PeftModelForCausalLM.from_pretrained(model, '/workspace/chatml-lora-checkpoint')

# load the tokenizer with the custom chatml format
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mixtral-8x7B-Instruct-v0.1')
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.pad_token = tokenizer.eos_token

# finally, load the pipeline and try inferencing
generator = pipeline("conversational", model=lora_model, tokenizer=tokenizer)

output = generator([
    {
        'role': 'user',
        'content': 'Hello, how are you today?'
    }
])

print(output)

Output:

Conversation id: 7dc0e9fd-9d79-49c8-b4e1-a01b6ed63c98
user: Hello, how are you today?
assistant: I'm an artificial intelligence. How can I assist you today?<|im_end|>

After troubleshooting, I noticed this in the postprocess function of the conversational pipeline:

def postprocess(self, model_outputs, clean_up_tokenization_spaces=True):
    output_ids = model_outputs["output_ids"]
    answer = self.tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
    )
    conversation = model_outputs["conversation"]
    conversation.add_message({"role": "assistant", "content": answer})
    return conversation

The decoded answer is produced with skip_special_tokens=True. So, to solve this issue, I considered adding <|im_end|> as a special token. However, the model itself wasn't trained with it as a single token, and <|im_end|> was originally encoded as multiple tokens.
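Since registering <|im_end|> as a special token isn't viable here, a workaround is to strip the marker from the decoded answer in post-processing. A minimal sketch (the helper name is mine, not part of the pipeline), which also handles a variable number of preceding newlines:

```python
import re

# Hypothetical helper: remove a trailing stop marker, plus any run of
# whitespace/newlines around it, from the decoded answer.
def strip_stop_marker(answer: str, stop: str = "<|im_end|>") -> str:
    return re.sub(r"\s*" + re.escape(stop) + r"\s*$", "", answer)

print(strip_stop_marker("Hello!<|im_end|>"))        # "Hello!"
print(strip_stop_marker("Hello!\n\n<|im_end|>\n"))  # "Hello!"
```

Because the match is on the decoded string rather than on token IDs, it is unaffected by how the tokenizer splits the marker in different contexts.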

Before coming across this issue, I wanted the model to treat <|im_end|> as a custom stopping token. While implementing this, I realized that my model sometimes outputs <|im_end|> as \n<|im_end|> or \n\n<|im_end|> (with a variable number of \n's), and each variant is tokenized differently from <|im_end|> by itself.

print({
    'no new line': tokenizer('<|im_end|>', add_special_tokens=False)['input_ids'],
    'one new line': tokenizer('\n<|im_end|>', add_special_tokens=False)['input_ids'],
    'two new lines': tokenizer('\n\n<|im_end|>', add_special_tokens=False)['input_ids']
})
{
 'no new line': [523, 28766, 321, 28730, 416, 28766, 28767],
 'one new line': [28705, 13, 28789, 28766, 321, 28730, 416, 28766, 28767],
 'two new lines': [28705, 13, 13, 28789, 28766, 321, 28730, 416, 28766, 28767]
}

Notice how, with newlines, the 523 token becomes 28789, preceded by 28705 and some number of 13's. This means that matching the marker at the token level is nearly impossible if the goal is to ignore the end token during post-processing regardless of preceding newlines. The main way to make it work, as far as I can tell, would be custom logic that handles the newline tokens when matching.
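To make the mismatch concrete, here is a pure-Python check using the exact ID sequences printed above: suffix matching on token IDs fails as soon as a newline precedes the marker, even though the decoded text clearly ends with <|im_end|>.

```python
# Token IDs reported above for the same <|im_end|> string in three contexts
no_newline   = [523, 28766, 321, 28730, 416, 28766, 28767]
one_newline  = [28705, 13, 28789, 28766, 321, 28730, 416, 28766, 28767]
two_newlines = [28705, 13, 13, 28789, 28766, 321, 28730, 416, 28766, 28767]

def ends_with_ids(seq, suffix):
    return len(seq) >= len(suffix) and seq[-len(suffix):] == suffix

# ID-level matching only works in the no-newline case:
# the leading "<" token (523) becomes 28789 after a newline.
print(ends_with_ids(no_newline, no_newline))    # True
print(ends_with_ids(one_newline, no_newline))   # False
print(ends_with_ids(two_newlines, no_newline))  # False
```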

In order to work around this for my early stopping, I took the easy way out and decoded the tail of the input_ids to see if it contained my custom stop token:

import torch

from transformers import StoppingCriteria, StoppingCriteriaList

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1, tokenizer=None):
        super().__init__()
        self.stops = stops
        self.ENCOUNTERS = encounters
        self.tokenizer = tokenizer

        assert tokenizer is not None, "Tokenizer is required"

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        stop_count = 0
        for input_ids_list in input_ids:
            for stop in self.stops:
                length = len(stop) + 5  # buffer for tokens preceding the stop string
                
                if len(input_ids_list) < length:
                    continue

                last_elements = input_ids_list[-length:]
                decoded_elements = self.tokenizer.decode(last_elements)

                if stop in decoded_elements:
                    stop_count += 1

        if stop_count >= self.ENCOUNTERS:
            return True

        return False

stop_words = ["<|im_end|>"]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words, tokenizer=tokenizer)])

The code above works, but it doesn't feel like the best way to solve this.
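For reference, the same string-matching idea can be shaped a little more cleanly by sizing the decoded tail from the stop string itself rather than a hard-coded +5 buffer, and by keeping the criterion duck-typed so the logic can be exercised without a real tokenizer. This is only a sketch of the approach; in practice it would subclass StoppingCriteria as above, and the class and helper names are mine:

```python
class StopOnDecodedTail:
    """Stop when the decoded tail of any sequence contains a stop string.

    The tail length is an over-estimate: one token per character of the
    stop string, plus a small margin for merged whitespace tokens.
    """
    def __init__(self, tokenizer, stops, margin=4):
        self.tokenizer = tokenizer
        self.stops = list(stops)
        self.margin = margin

    def __call__(self, input_ids, scores=None, **kwargs):
        for seq in input_ids:
            ids = list(seq)
            for stop in self.stops:
                tail = ids[-(len(stop) + self.margin):]
                if stop in self.tokenizer.decode(tail):
                    return True
        return False

# Demo with a toy tokenizer that maps one token ID to one character
class CharTokenizer:
    def decode(self, ids):
        return "".join(chr(i) for i in ids)

crit = StopOnDecodedTail(CharTokenizer(), ["<|im_end|>"])
print(crit([[ord(c) for c in "Hello!\n<|im_end|>"]]))  # True
print(crit([[ord(c) for c in "Hello!"]]))              # False
```

Matching on decoded text sidesteps the newline-dependent token IDs entirely, at the cost of an extra decode call per generation step.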

Expected behavior

I would like a way to strip the <|im_end|> text from the end of the assistant's output, despite the tokenization differences introduced by preceding newlines.

Labels: WIP