fix jit trace error for model forward sequence is not aligned with jit.trace tuple input sequence, update related doc#19891

Merged
sgugger merged 4 commits into huggingface:main from sywangyi:jit_error
Nov 3, 2022
Conversation

@sywangyi
Contributor

Signed-off-by: Wang, Yi A yi.a.wang@intel.com

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sywangyi
Contributor Author

@liangan1 @jianan-gu @yao-matrix please help review

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 26, 2022

The documentation is not available anymore as the PR was closed or merged.

@sywangyi sywangyi force-pushed the jit_error branch 2 times, most recently from da2172f to a603c98 Compare October 26, 2022 12:40
Collaborator

@sgugger sgugger left a comment


This PR breaks the JIT eval in pretty much all cases unless on PyTorch nightly, unless label smoothing is activated.
Label smoothing is a training technique, so it's very unlikely a user will combine a JIT eval with label smoothing (when just running evaluation). This PR therefore can't be accepted as it is.

Collaborator


This means the whole eval will fail in 99.99% of cases when users do not use label smoothing (it's a training technique)

Contributor Author


The fix for the jit trace failure is pytorch/pytorch#81623, which is expected to be merged in PyTorch 1.14.0; I added it to the doc to let users know. Meanwhile, we will catch the trace failure and fall back to the original path.

Collaborator

@sgugger sgugger left a comment


I still do not understand the main problem: it looks like JIT does not support dictionary inputs which are used in every model in Transformers. Classification models are not the only ones using the labels key, all task-specific models do... and a model that a user wants to evaluate will very likely have a dataset with labels. The proposed workaround to use label smoothing makes no sense for an evaluation.

It looks like this integration was maybe merged too quickly and doesn't actually work, or are there models that can be evaluated with it?

Collaborator


This is the logic that is in self.autocast_smart_context_manager() if I'm not mistaken, so let's not change anything here.

Contributor Author


Hi @sgugger, the key point is that we need to set "cache_enabled=False" here; self.autocast_smart_context_manager() does not have such an argument. Otherwise, torch + bf16/fp16 + jit.trace will fail here.

Collaborator

@sgugger sgugger Nov 1, 2022


Can we add that argument to the method then, please? It would make this more readable.

@sywangyi
Contributor Author

sywangyi commented Oct 28, 2022

I still do not understand the main problem: it looks like JIT does not support dictionary inputs which are used in every model in Transformers. Classification models are not the only ones using the labels key, all task-specific models do... and a model that a user wants to evaluate will very likely have a dataset with labels. The proposed workaround to use label smoothing makes no sense for an evaluation.

It looks like this integration was maybe merged too quickly and doesn't actually work, or are there models that can be evaluated with it?

Yes. All the cases containing "labels" will fail in jit.trace, while other cases like QnA could pass. It's a PyTorch limitation: jit.trace only supports tuple input for now. Intel has committed a PR (pytorch/pytorch#81623) for this, which is expected to be released in PyTorch 1.14 (I also added it in the doc).

If we would like jit.trace to succeed for such cases, the other option is to modify the model like below, making the forward parameter order match the tuple input order:

--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -731,11 +731,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     )
     def forward(
         self,
+        labels: Optional[torch.LongTensor] = None,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,

"label smoothing" is just a way to work around the jit.trace failure, since it happens to pop the labels from the input.
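The ordering problem can be illustrated without PyTorch: when tracing flattens the example dict into a positional tuple (as tuple-only jit.trace effectively requires), values bind to parameters by position, not by name. A minimal sketch, where `forward` is a hypothetical stand-in and not a real model:

```python
# Stand-in for a model forward; records what each parameter received.
def forward(input_ids=None, attention_mask=None, labels=None):
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# A batch whose key order differs from the signature order.
batch = {"labels": "LBL", "input_ids": "IDS", "attention_mask": "MASK"}

# Keyword arguments (what the model normally sees): values land correctly.
ok = forward(**batch)

# Positional in dict order (what tuple-only tracing does): values misbind.
bad = forward(*batch.values())

print(ok["labels"])   # "LBL"  -- labels reached the right parameter
print(bad["labels"])  # "MASK" -- the labels slot got the attention mask
```

This is why only the dict key order relative to the signature matters, not the presence of "labels" as such.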

@sgugger
Collaborator

sgugger commented Oct 28, 2022

We are not going to make a breaking change in the parameter order of every model. So basically the jit eval functionality added in #17753 does not work and has never worked for any model that contains labels, can you confirm?

Since it is an evaluation function, I fail to see the point of having it in Transformers until PyTorch supports it.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@sywangyi sywangyi force-pushed the jit_error branch 2 times, most recently from ee330ed to 8217c5e Compare October 31, 2022 04:01
@jianan-gu
Contributor

jianan-gu commented Oct 31, 2022

The key point of the jit error cases met here is that jit cannot handle the case where the forward parameter order does not match the dataset input order; it is not specific to whether there are "labels" or not. To improve PyTorch jit's ability to solve this, we landed pytorch/pytorch#81623 in PyTorch.

For model inference with jit, for now, there are many cases that natively get the benefits, like models running the question-answering example mentioned above.

For the failing cases, we catch the exception here to make it fall back to the original path and use logging to notify users. These failed cases should work once a PyTorch release contains this feature (expected in the next release).

Besides, bringing "label smoothing" here with jit is not that reasonable, since it would be confusing for users.

@sywangyi
Contributor Author

sywangyi commented Oct 31, 2022

Hi @sgugger, to make it clear, I filed an issue, #19973, to record the problem I met. I also agree that "label smoothing" is a training technique, and I have removed it from the inference part. This PR fixes the error listed in #19973.

Collaborator

@sgugger sgugger left a comment


You are a bit beating around the bush here: are there any models with a head where this feature can be used right now without hacks? I understand support in PyTorch is coming in the next version for dictionaries, but I think this feature was just added too early. Can the doc explicitly mention that the feature requires a nightly install?


@sywangyi
Contributor Author

sywangyi commented Nov 2, 2022

You are a bit beating around the bush here: are there any models with a head where this feature can be used right now without hacks? I understand support in PyTorch is coming in the next version for dictionaries, but I think this feature was just added too early. Can the doc explicitly mention that the feature requires a nightly install?

Hi @sgugger,
For PyTorch >= 1.14.0 (the nightly version is 1.14.0), jit could benefit any model for predict and eval.
For PyTorch < 1.14.0, jit could benefit models like question answering, whose forward parameter order matches the tuple input order in jit.trace. If we meet a case like text classification, whose forward parameter order does not match the tuple input order during evaluation, the jit trace will fail; we catch the exception to make it fall back and use logging to notify users.
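The catch-and-fall-back behaviour described above can be sketched without PyTorch as follows; `trace_with_fallback` and `broken_trace` are illustrative names for this sketch, not the actual Trainer code:

```python
import logging

logger = logging.getLogger(__name__)

def trace_with_fallback(model, trace_fn):
    """Try to jit-trace the model; on failure, warn and fall back to eager.

    `trace_fn` stands in for the actual torch.jit.trace call so the
    pattern can be shown without PyTorch installed.
    """
    try:
        return trace_fn(model), True
    except Exception as exc:
        logger.warning("failed to use PyTorch jit mode due to: %s", exc)
        return model, False

def broken_trace(model):
    # Simulates the pre-1.14 failure when forward order != tuple input order.
    raise RuntimeError("forward sequence is not aligned with tuple input sequence")

eager_model = object()
# Trace fails: the original eager model is returned and a warning is logged.
model_used, jit_ok = trace_with_fallback(eager_model, broken_trace)
# Trace succeeds (identity stand-in): the "traced" result is used.
model_used2, jit_ok2 = trace_with_fallback(eager_model, lambda m: m)
```

Either way the evaluation proceeds; the user only loses the jit speedup, and the log explains why.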

@sywangyi sywangyi changed the title fix jit trace error for classification usecase, update related doc fix jit trace error for forward sequence is not aligned with jit.trace tuple input sequence, update related doc Nov 2, 2022
@sywangyi sywangyi changed the title fix jit trace error for forward sequence is not aligned with jit.trace tuple input sequence, update related doc fix jit trace error for model forward sequence is not aligned with jit.trace tuple input sequence, update related doc Nov 2, 2022
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Collaborator

@sgugger sgugger left a comment


Thanks for the clarification. Could you add all of this to the documentation?
Also have one last question on the actual code.

Comment on lines +1262 to +1267
else:
    jit_model = torch.jit.trace(
        jit_model,
        example_kwarg_inputs={key: example_batch[key] for key in example_batch},
        strict=False,
    )
Collaborator


I don't fully understand this path, as example_batch is not a dict here, but you are using it as one.

Contributor Author


I just met the case where example_batch is a class like transformers.feature_extraction_utils.BatchFeature in the audio-classification example. Because BatchFeature implements APIs like __getitem__(), I can operate on it like a dict.
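A minimal stand-in shows why implementing `__getitem__` plus key iteration is enough for the dict-style comprehension used in the traced call; `BatchLike` below is illustrative only, not the actual transformers class:

```python
class BatchLike:
    """Toy object that supports dict-style access like BatchFeature does."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        # Iterating the object yields its keys, like a dict.
        return iter(self._data)

    def keys(self):
        return self._data.keys()

batch = BatchLike({"input_values": [0.1, 0.2], "attention_mask": [1, 1]})

# The exact comprehension from the reviewed snippet works unchanged:
kwargs = {key: batch[key] for key in batch}
```

Since the comprehension only iterates keys and indexes by key, any object with these two methods behaves like a dict for this purpose.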

Collaborator


Understood!

@sywangyi
Contributor Author

sywangyi commented Nov 2, 2022

Thanks for the precision. Could you add all of this to the documentation? Also have one last question on the actual code.

Which document would you recommend adding this to, since it's not CPU-specific?

@sgugger
Collaborator

sgugger commented Nov 2, 2022

which document would you recommend to add this

Every time the jit eval is mentioned.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Collaborator

@sgugger sgugger left a comment


Thanks for adding the documentation. Just left a few last nits.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Collaborator

@sgugger sgugger left a comment


Thanks for iterating on this.

@sgugger sgugger merged commit 2564f0c into huggingface:main Nov 3, 2022
mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
…t.trace tuple input sequence, update related doc (huggingface#19891)

* fix jit trace error for classification usecase, update related doc

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add implementation in torch 1.14.0

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update_doc

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update_doc

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@sywangyi sywangyi deleted the jit_error branch November 19, 2025 04:42