
add position_ids #984

Closed

jiqing-feng wants to merge 7 commits into huggingface:main from jiqing-feng:main

Conversation

@jiqing-feng
Contributor

Some models like gpt2 and llama have "position_ids" in their generation inputs. We can add "position_ids" in the model's config to fix it.

@echarlaix @fxmarty Would you please help to review it? Thanks!

@fxmarty
Contributor

fxmarty commented Apr 18, 2023

Hi @jiqing-feng , just to be sure, is this concerning the ONNX export or ONNX Runtime integration? Or both? In the ONNX export, we should give the option, yes.

@echarlaix
Collaborator

Currently position_ids is ignored during ORT inference for causal LM (https://github.com/huggingface/optimum/blob/v1.8.2/optimum/onnxruntime/modeling_decoder.py#L683), so when enabling it as an input for the ONNX export, we should enable it for ORTModelForCausalLM as well.

@jiqing-feng
Contributor Author

Hi @echarlaix @fxmarty, thanks for your comments. I have added position_ids to ORTModelForCausalLM. Would you please help to review it? Thanks.

@fxmarty
Contributor

fxmarty commented Apr 21, 2023

Hi @jiqing-feng , could you comment a bit on your use case? I'm not sure what the benefit of this is in the ORT integration.

While attention_mask is updated in _update_model_kwargs_for_generation, position_ids is not. Thus, in my understanding, if we were to pass position_ids to the generate() method, it would be kept constant at each prepare_inputs_for_generation call during generation, e.g. here (since we'd go into this piece of controlflow), which is something we don't want.

Am I misunderstanding something @gante ?
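The controlflow referenced above can be sketched as follows. This is an editorial, simplified rendition of GPT-2-style prepare_inputs_for_generation in transformers (details vary by model and version), illustrating why a user-supplied position_ids would be reused unchanged at every step:

```python
import torch

# Simplified sketch of GPT-2-style prepare_inputs_for_generation
# (editorial rendition; the real transformers code has more branches).
def prepare_inputs_for_generation(input_ids, past_key_values=None,
                                  attention_mask=None, **kwargs):
    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # position_ids is recomputed from the (growing) attention_mask
        # at each generation step.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values is not None:
            position_ids = position_ids[:, -1:]
    # If the caller passed position_ids explicitly, the branch above is
    # skipped and the same tensor is reused at every generation step.
    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "position_ids": position_ids}
```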

@gante

gante commented Apr 21, 2023

@fxmarty What I write next is from a transformers point of view :)

At the moment, position_ids is a required input to support left padding. For legacy reasons, we consider it an optional input the same way attention_mask is an optional input -- if not passed, we assume all tokens are valid, and position_ids is a torch.arange. In .generate(), we delegate its creation to the model's prepare_inputs_for_generation, where it is computed from the attention_mask (e.g. here).

This means that even though position_ids is never passed into .generate() itself, it is passed from .generate() to the model and updated at each iteration.
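To illustrate the left-padding point with a toy example (an editorial addition, not part of the original thread): the default torch.arange assumption is wrong for left-padded rows, which is why position_ids has to be derived from the attention mask.

```python
import torch

# Two sequences; the first is left-padded with two pad tokens.
attention_mask = torch.tensor([[0, 0, 1, 1],
                               [1, 1, 1, 1]])

# What a model assumes when position_ids is not passed:
default_positions = torch.arange(attention_mask.shape[1]).expand_as(attention_mask)

# Positions derived from the mask, so real tokens still start at 0:
derived = (attention_mask.cumsum(-1) - 1).clamp(min=0)

print(default_positions)  # tensor([[0, 1, 2, 3], [0, 1, 2, 3]])  wrong for padded row
print(derived)            # tensor([[0, 0, 0, 1], [0, 1, 2, 3]])
```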

@fxmarty
Contributor

fxmarty commented Apr 21, 2023

> In .generate(), we delegate its creation to the model's prepare_inputs_for_generation, where it is computed from the attention_mask (e.g. here).

So this means that executing this controlflow is absolutely necessary, right?

@KexinFeng

> Hi @jiqing-feng , could you comment a bit on your use case? I'm not sure what's the benefit of this in the ORT integration.
>
> While attention_mask is updated in _update_model_kwargs_for_generation, position_ids is not. Thus, in my understanding, if we are to pass a position_ids to the generate() method, it would be kept as a constant at each prepare_inputs_for_generation call in the generation, e.g. here (since we'd go into this piece of controlflow). Which is something we don't wish.
>
> Am I misunderstanding something @gante ?

@fxmarty Yeah, I was about to request exactly the same feature. As @gante mentioned above, position_ids (along with attention_mask) is a necessary input to transformer-based models like gpt2 to deal with left-padded input. More specifically, my feature request comes from the context described in #972.

> In .generate(), we delegate its creation to the model's prepare_inputs_for_generation, where it is computed from the attention_mask (e.g. here).
>
> So this means that executing this controlflow is absolutely necessary, right?

I don't think the controlflow of computing position_ids from attention_mask is necessary, if that is what you were referring to. As long as the traced transformer model (*.onnx) can take position_ids as an effective input, it should be good enough.

@fxmarty
Contributor

fxmarty commented Apr 24, 2023

cc @echarlaix @michaelbenayoun

@michaelbenayoun (Member) left a comment

On my side, I am OK with adding this as long as we make sure we do not:

  • Introduce any breaking changes
  • Break .generate: it should work both with and without position_ids (except for the left padding case, of course)

Comment on lines +242 to +256
```python
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
    if self.use_past_in_inputs:
        common_inputs = {"input_ids": {0: "batch_size"}}
        self.add_past_key_values(common_inputs, direction="inputs")
        common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
        common_inputs["position_ids"] = {0: "batch_size"}
    else:
        common_inputs = {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "position_ids": {0: "batch_size", 1: "sequence_length"},
        }
    return common_inputs
```

Member

I guess we can just make LlamaOnnxConfig inherit from GPT2OnnxConfig and override the NORMALIZED_CONFIG_CLASS class attribute?

```python
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
```
Member

self.decoder is an ORTDecoder, right?
If so, we also need to update its forward method to handle position_ids.

Comment on lines +661 to +668
```python
        position_ids=position_ids,
    )
else:
    outputs = self.decoder_with_past(
        input_ids=input_ids[:, -1:],
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
```
Member

Same comment.

@jiqing-feng
Contributor Author

Hi @michaelbenayoun , thanks for your comment. I have added position_ids handling to ORTDecoder's forward method. Could you please review it? Thanks! cc @fxmarty @gante

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@KexinFeng

Hi everyone, @jiqing-feng @fxmarty @gante,
are there any updates on this feature?

@jiqing-feng
Contributor Author

> Hi everyone, @jiqing-feng @fxmarty @gante, are there any updates on this feature?

cc @michaelbenayoun

@fxmarty
Contributor

fxmarty commented Sep 8, 2023

@jiqing-feng @KexinFeng I now realize I misunderstood your argument and motivation for the PR. This appears to me to be a major bug in Optimum which should be fixed ASAP.

@fxmarty
Contributor

fxmarty commented Sep 14, 2023

Closing in favor of #1381

@fxmarty closed this Sep 14, 2023