Fix jit trace error when the model forward sequence is not aligned with the jit.trace tuple input sequence; update related doc #19891
Conversation
@liangan1 @jianan-gu @yao-matrix please help review
The documentation is not available anymore as the PR was closed or merged. |
Force-pushed from da2172f to a603c98 (compare)
sgugger
left a comment
This PR breaks JIT eval in pretty much all cases unless you are on a PyTorch nightly or label smoothing is activated.
Label smoothing is a training technique, so it's very unlikely a user will combine a jit eval with label smoothing (when just running evaluation). This PR therefore can't be accepted as it is.
src/transformers/trainer.py
Outdated
This means the whole eval will fail in 99.99% of cases when users do not use label smoothing (it's a training technique).
The fix for the jit trace failure is pytorch/pytorch#81623; I added it to the doc to let users know, and it is expected to be merged in PyTorch 1.14.0. In the meantime, we catch the trace failure and fall back to the original path.
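The catch-and-fallback behavior described here can be sketched as follows. This is illustrative only: `jit_trace_with_fallback` and `trace_fn` are made-up names, not the actual Trainer code.

```python
import logging

logger = logging.getLogger(__name__)


def jit_trace_with_fallback(model, trace_fn):
    """Try to JIT-trace `model`; on any trace failure, warn and fall
    back to the original eager model, mirroring the fallback path
    the PR adds (names here are hypothetical)."""
    try:
        return trace_fn(model)
    except (RuntimeError, TypeError) as e:
        logger.warning("JIT trace failed, falling back to eager mode: %s", e)
        return model


# A trace function that fails (e.g. because the forward sequence is not
# aligned with the tuple input sequence) returns the original model.
def failing_trace(model):
    raise RuntimeError("forward sequence not aligned with tuple input sequence")


eager_model = object()
assert jit_trace_with_fallback(eager_model, failing_trace) is eager_model
```

With this shape, evaluation never hard-fails: a successful trace is used, and any trace error silently (but loudly logged) degrades to eager execution.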
I still do not understand the main problem: it looks like JIT does not support dictionary inputs, which are used in every model in Transformers. Classification models are not the only ones using the labels key, all task-specific models do... and a model that a user wants to evaluate will very likely have a dataset with labels. The proposed workaround to use label smoothing makes no sense for an evaluation.
It looks like this integration may have been merged too quickly and doesn't actually work, or are there models that can be evaluated with it?
src/transformers/trainer.py
Outdated
This is the logic that is in self.autocast_smart_context_manager() if I'm not mistaken, so let's not change anything here.
Hi sgugger, the key point is that we need to set `cache_enabled=False` here; `self.autocast_smart_context_manager()` does not have such an argument. Otherwise, torch + bf16/fp16 + jit.trace will fail here.
Can we add that argument to the method then, please? It would make this more readable.
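For illustration, a minimal sketch of tracing under CPU autocast with the weight cache disabled, assuming a PyTorch version where `torch.autocast` accepts `cache_enabled` (the real logic lives in `autocast_smart_context_manager()`; the toy `Linear` model here is just a stand-in):

```python
import torch

model = torch.nn.Linear(4, 2).eval()
example = torch.randn(1, 4)

# cache_enabled=False is the key point of the comment above: with the
# autocast weight cache enabled, bf16/fp16 + torch.jit.trace fails.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=False):
    with torch.no_grad():
        traced = torch.jit.trace(model, example, strict=False)
        out = traced(example)

assert out.shape == (1, 2)
```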
Yes, all the cases containing "labels" will fail in jit.trace, while other cases like QnA pass. It's a PyTorch limitation: jit.trace only supports tuple input for now. Intel has committed a PR (pytorch/pytorch#81623) for this, which is expected to be released in PyTorch 1.14 (I also added it to the doc). If we want jit.trace to succeed for such cases, the other option is to modify the model like below, making the forward input sequence match the tuple input sequence:

```diff
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -731,11 +731,11 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     )
     def forward(
         self,
+        labels: Optional[torch.LongTensor] = None,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
```

"Label smoothing" is just a smart way to work around the jit.trace failure, since it happens to pop the labels from the input.
We are not going to make a breaking change in the parameter order of every model. So basically the jit eval functionality added in #17753 does not work and has never worked for any model which contains labels, can you confirm? Since it is an evaluation function, I fail to see the point of having it in Transformers until PyTorch supports it.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Force-pushed from ee330ed to 8217c5e (compare)
The key point of the jit error cases met here is that jit cannot handle the case where the dictionary forward parameter order does not match the dataset input order; it is not specific to whether there are "labels" or not. To improve PyTorch's jit ability and solve this issue, we landed pytorch/pytorch#81623 in PyTorch. For model inference with jit, there are already many cases that natively get the benefits, like models running the question-answering example mentioned above. For the failed jit inference cases, we catch the exception here to fall back, and use logging to notify users; these cases should work once a PyTorch release contains this feature (expected in the next release). Besides, bringing "label smoothing" here with jit is not reasonable, since it would be confusing for users.
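To make the ordering mismatch concrete, here is a small illustrative helper (not part of the PR) that reorders a dict-style batch into a forward signature's positional order before handing it to `torch.jit.trace` as a tuple, which is what trace requires before pytorch/pytorch#81623 lands:

```python
import inspect


def batch_to_ordered_tuple(forward_fn, batch):
    """Reorder a dict batch into the positional order of `forward_fn`'s
    signature. torch.jit.trace (pre-#81623) binds inputs positionally,
    so the tuple must follow the signature order, not the batch's
    key order. Helper name is hypothetical."""
    order = [p for p in inspect.signature(forward_fn).parameters if p in batch]
    return tuple(batch[name] for name in order)


# A forward with "labels" last, and a batch whose key order differs:
def forward(input_ids, attention_mask=None, labels=None):
    pass


batch = {"labels": 2, "input_ids": 0, "attention_mask": 1}
assert batch_to_ordered_tuple(forward, batch) == (0, 1, 2)
```

Passing `tuple(batch.values())` directly would bind `labels` to `input_ids`, which is exactly the misalignment the PR title describes.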
sgugger
left a comment
You are beating around the bush a bit here: are there any models with a head where this feature can be used right now without hacks? I understand support for dictionaries is coming in the next PyTorch version, but I think this feature was just added too early. Can the doc explicitly mention that the feature requires a nightly install?
src/transformers/trainer.py
Outdated
Can we add that argument to the method then, please? It would make this more readable.
Hi, sgugger
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sgugger
left a comment
Thanks for the precision. Could you add all of this to the documentation?
Also have one last question on the actual code.
```python
else:
    jit_model = torch.jit.trace(
        jit_model,
        example_kwarg_inputs={key: example_batch[key] for key in example_batch},
        strict=False,
    )
```
I don't fully understand this path here, as example_batch is not a dict here, but you are using it as one.
I just met the case where example_batch is a class like transformers.feature_extraction_utils.BatchFeature in the audio-classification example. Because BatchFeature implements APIs like __getitem__(), I can operate it like a dict.
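A minimal stand-in (illustrative only, not the real `BatchFeature`) shows why the dict comprehension in the traced path works on such objects:

```python
class MiniBatch:
    """Dict-like container implementing __getitem__/__iter__/keys(),
    similar in spirit to transformers' BatchFeature. This class is a
    hypothetical sketch for illustration."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def keys(self):
        return self._data.keys()


example_batch = MiniBatch({"input_values": [0.1, 0.2], "labels": [1]})
# The same comprehension used for example_kwarg_inputs works here,
# even though example_batch is not a plain dict:
kwargs = {key: example_batch[key] for key in example_batch}
assert kwargs == {"input_values": [0.1, 0.2], "labels": [1]}
```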
Which document would you recommend adding this to, since it's not CPU-specific?
Every time the jit eval is mentioned.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sgugger
left a comment
Thanks for adding the documentation. Just left a few last nits.
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sgugger
left a comment
Thanks for iterating on this.
…t.trace tuple input sequence, update related doc (huggingface#19891)

* fix jit trace error for classification usecase, update related doc
* add implementation in torch 1.14.0
* update_doc
* update_doc

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.