FX tracing improvement #14321
Conversation
1f08935 to 474aa54
Hey, thanks for your PR @michaelbenayoun! It seems there are a few failing tests (1096 😄), could you take a look at them?

Currently looking into it!

Fixed!
sgugger left a comment

I'm not too comfortable with some of the changes in the models, especially XLNet; apart from that, the PR looks good.

In the tests, fx_ready_model_classes seems to always be set to all_model_classes, so maybe it's time to use a boolean flag instead of a list of classes, if we always test all classes?
src/transformers/modeling_utils.py (Outdated)

```diff
- seq_ids = torch.arange(seq_length, device=device)
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))
```

Unrelated to this PR, but constructing a triangular matrix should be a bit simpler IMO (unless I'm missing something)...
Would be nice if we keep the code as is for now to make sure not to break anything here accidentally. Could you also run T5's and Bart's SLOW tests to be sure nothing is broken with the attention mask?
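As a quick sanity check (a minimal standalone sketch, not part of the PR), the two constructions produce the same boolean causal mask:

```python
import torch

batch_size, seq_length = 2, 4

# Current construction: broadcasted comparison of position ids
seq_ids = torch.arange(seq_length)
mask_repeat = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

# Suggested construction: lower-triangular matrix of ones
mask_tril = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool))

assert torch.equal(mask_repeat, mask_tril)
```

Both masks allow position i to attend to positions 0..i, which is why the reviewer suggests the shorter form; keeping the original is purely about not risking a behavior change.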
src/transformers/utils/fx.py (Outdated)

This shouldn't be true, no?
src/transformers/utils/fx.py (Outdated)

```diff
- return super().__len__(self)
+ return super().__len__(self.cache)
```

Shouldn't that be something along these lines?
src/transformers/utils/fx.py (Outdated)

I'm not sure why this is here?
src/transformers/utils/fx.py (Outdated)

I'm not sure I understand how that does what it says it does?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Unstale comment
5d694b0 to 83aedfc
I am planning to try another approach to make both the code easier and the tracing process cleaner; this will allow adding other models as well as limiting the number of bugs.
```python
    if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
]


def disable_fx_test(filename: Path):
```
What do you think of this @sgugger?

The reason I added that is because symbolic_trace checks the model class before trying to trace the model, to make sure it is supported.

Because the tests are copied, if a new model is created from a model supported for symbolic tracing, the test file will contain something like fx_ready = True, which will trigger the torch.fx tests, all of them failing because the model class is not in the list of supported models.

I do not think automatically adding the new model class to the supported models is a good approach, because the model implementation can be changed, so I thought that disabling the test and printing a message was a better option.
```python
with open(filename) as fp:
    content = fp.read()
with open(filename, "w") as fp:
    new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content)
```

Nit, this line should go before the second with.
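To illustrate what the substitution does in isolation (a standalone sketch; the test-file contents below are made up, only the regex comes from the snippet above):

```python
import re

# Hypothetical contents of a copied model test file
content = (
    "class BertModelTest(ModelTesterMixin):\n"
    "    fx_ready = True\n"
)

# Flip the flag so the copied torch.fx tests are disabled by default
new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content)

assert "fx_ready = False" in new_content
assert "fx_ready = True" not in new_content
```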
LysandreJik left a comment

This looks good to me as long as it's 100% backwards compatible.

Pinging @patrickvonplaten and @patil-suraj for a quick look as it touches a lot of different models.
```diff
  # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
- TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
+ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
```
Out of curiosity, is it possible to support many different versions, or are there breaking changes in torch.fx such that we have to support one version at a time?

I can check for torch 1.9; the plan from now on is to support torch 1.10+ as fx became stable starting with this version (still need to validate that with the pytorch team).

And you probably need to change this line from == to >=.
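The difference matters in practice; a small sketch with packaging.version (assuming the check is a plain comparison against TORCH_FX_REQUIRED_VERSION):

```python
from packaging import version

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

# Hypothetical installed torch version, newer than the minimum
installed = version.parse("1.11.0")

# An equality check wrongly rejects any newer release...
assert not (installed == TORCH_FX_REQUIRED_VERSION)
# ...while ">=" accepts 1.10 and everything after it.
assert installed >= TORCH_FX_REQUIRED_VERSION
# Older versions are still rejected either way.
assert version.parse("1.9") < TORCH_FX_REQUIRED_VERSION
```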
```python
print(
    "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
    "for your new model."
)
```

Ideally this would use the logger.

I followed what was done in the script, but can definitely change that to the logger if needed.
```diff
  if self.scale_attn_weights:
-     attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+     attn_weights = attn_weights / (value.size(-1) ** 0.5)
```

Is this backwards compatible?

In my opinion, this doesn't cause any problems. When we do tracing, Python values cause several problems. I don't think there is any reason to change this value to a Python value.

This change seems to cause the failure of mixed-precision training of gpt-2 with the ONNX Runtime backend. Link to the reported issue: #11279.
patil-suraj left a comment

Went through all the modeling changes and it looks good to me!
```diff
- pooled_logits = logits[range(batch_size), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
```

Suggested change:

```diff
- pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
```

We need to make sure the tensor is on the same device, no?
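What the advanced indexing does (a standalone sketch with made-up shapes): for each example in the batch, it picks the logits at that example's last-token position:

```python
import torch

batch_size, seq_len, num_labels = 3, 5, 2
logits = torch.randn(batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([4, 1, 3])  # last non-padding position per example

# Pairwise indexing: row i of the result is logits[i, sequence_lengths[i]]
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]

assert pooled_logits.shape == (batch_size, num_labels)
assert torch.equal(pooled_logits[1], logits[1, 1])
```

On GPU, the index tensor created by torch.arange lives on the CPU by default, which is the device concern raised above.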
```diff
- pooled_logits = logits[range(batch_size), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
```
tests/test_modeling_bert.py (Outdated)

```diff
  all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
- fx_ready_model_classes = all_model_classes
- fx_dynamic_ready_model_classes = all_model_classes
+ fx_ready = True
```

(nit) not a huge fan of the name fx_ready - does that mean fx_compatible?
patrickvonplaten left a comment

Left some comments, but in general this looks good to me as well.
* Change the way tracing happens, enabling dynamic axes out of the box
* Update the tests and modeling xlnet
* Add the non-recording of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors)
* Comments and making tracing work for gpt-j and xlnet
* Refactor things related to num_choices (and batch_size, sequence_length)
* Update fx to work on PyTorch 1.10
* Postpone autowrap_function feature usage for later
* Add copyrights
* Remove unnecessary file
* Fix issue with add_new_model_like
* Apply suggestions
What does this PR do?

This PR significantly improves the way transformers models are traced by the HFTracer (torch.fx). This has 2 major consequences:

Because of these changes, the symbolic_trace signature becomes simpler:

```python
symbolic_trace(model: PreTrainedModel, input_names: Optional[List[str]] = None) -> GraphModule
```

There is no need to specify the batch size, the sequence length, or the number of choices (for multiple-choice) anymore.

The same thing can be said about the HFTracer, which can be instantiated exactly the same way as the regular torch.fx.Tracer.
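A hypothetical usage sketch of the simplified signature (the model and config values are illustrative; assumes a transformers version containing these changes and a compatible torch):

```python
import torch
from transformers import BertConfig, BertModel
from transformers.utils.fx import symbolic_trace

# A tiny randomly initialized model keeps the sketch fast; any supported
# architecture would do.
config = BertConfig(
    num_hidden_layers=1, hidden_size=32, num_attention_heads=2, intermediate_size=64
)
model = BertModel(config)
model.eval()

# No batch size / sequence length / num_choices arguments anymore:
# only the model and the input names are needed.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

assert isinstance(traced, torch.fx.GraphModule)
```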