Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 #11475
Conversation
src/transformers/file_utils.py
This one is a bit tricky since it's a recent addition in PyTorch, so it can only be loaded if the PyTorch version is right.
Do you know if we want pt-1.8.0 or higher? I think it should be 1.8.0; we can adjust later if need be.
And of course it'd impact the `if isinstance` below, so we should probably split the `isinstance` check in two parts and condition the `isinstance(x, torch.fx.Proxy)` check on the PyTorch version.
Typically we do it by implementing `is_torch_fx_available`; see the whole bunch of similar helpers in file_utils.py.
So:

```python
if is_torch_fx_available():
    import torch.fx

...

if is_torch_fx_available() and isinstance(x, torch.fx.Proxy):
    return True
```

and the version you get from:

```python
if version.parse(torch.__version__) >= version.parse("1.8"):
```

So now you can implement `is_torch_fx_available` in file_utils.py and import it here.
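For illustration, a minimal sketch of such an availability helper. This is an assumption-laden toy: the real `is_torch_fx_available` in file_utils.py inspects the installed `torch.__version__` via `packaging.version`, whereas here the version string is a parameter so the gating logic is visible without importing torch.

```python
# Hedged sketch mirroring the is_*_available pattern from file_utils.py.
# Assumption: only the ">= 1.8" gate matters; the real helper reads
# torch.__version__ with packaging.version instead of taking an argument.

def _parse_version(version_string):
    # Drop a local suffix like "+cu111" and keep the numeric release parts.
    core = version_string.split("+")[0]
    return tuple(int(part) for part in core.split(".") if part.isdigit())

def is_torch_fx_available(torch_version):
    # torch.fx shipped with PyTorch 1.8, so gate on that release.
    return _parse_version(torch_version) >= _parse_version("1.8")
```

With this gate in place, the `isinstance(x, torch.fx.Proxy)` check can safely live behind the availability test on older PyTorch installs.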
Do we know if the fx-friendly version is slower, and thus whether we need both?
And here, and for the import at the top, we need to add the `if is_torch_fx_available()` check.
stas00
left a comment
So torch.fx can't work with the modular interface and it has to be replaced with the functional one? Asking since you replaced many `CrossEntropyLoss()` instances with `F.cross_entropy`.
Let's also:
If we are going to use this combo a lot, then down the road we could consider combining these two checks with a helper, so we won't need to repeat the code:
`if is_torch_fx_proxy(input_ids): ...`
but it's probably perfect as it is for now.
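A sketch of what such a helper could look like. The name `is_torch_fx_proxy` is the reviewer's suggestion; this version folds the availability check into a try/except so call sites stay one line. It is not the actual implementation that later landed in transformers.

```python
# Hedged sketch of the suggested is_torch_fx_proxy helper.
def is_torch_fx_proxy(x):
    try:
        import torch.fx  # only importable on torch >= 1.8
    except ImportError:
        return False
    return isinstance(x, torch.fx.Proxy)
```

A call site then reduces from two conditions to one: `if is_torch_fx_proxy(input_ids): return True`.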
…pdated the models that were causing utils/check_copies.py to complain
stas00
left a comment
Looks ready! Great work, @michaelbenayoun
sgugger
left a comment
Thanks for adding this experimental feature. My main problem with the PR is the tests: we should have all those new tests refactored into one common test, which will also make it easier to add tracing support for new architectures in the future.
```python
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
```
We don't merge empty asserts (except in tests), so please add a helpful error message :-)
These are a verbatim copy from the original PyTorch implementation, but yes, it'd definitely be helpful to improve them.
@michaelbenayoun, if you think this code is a keeper, then let's do better error handling.
tests/test_modeling_bert.py
```python
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_tracing_for_causal_lm(
```
This looks like it could be a common test instead of repeating the same thing for all classes. I suggest adding a `test_tracing` attribute in the common tester and a common test for tracing that loops through `self.all_model_classes` in the test_modeling_common file.
My understanding was that this is experimental, and as we start using this side of the library we will generalize and improve things; hence the more relaxed approach. Same for tests: I thought it was good to start with unique tests, because the workarounds are unique, and then come up with common tests over time as more models are ported. @michaelbenayoun, one way to approach this puzzle is to create common tests for what's the same in all of them, and if something is unique to a given model, then have just that tested in that model's test file. If you need help with that, please don't hesitate to ask.

Even for experimental features like model parallelism, we are using common tests. This should not be different IMO.
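The common-test pattern under discussion could look roughly like this. A torch-free sketch: `fx_ready_model_classes` follows the attribute name the PR ended up using, while `DummyModel` and the config dict are stand-ins, not real library classes.

```python
# Sketch of a common tracing test: one method in the shared tester mixin
# loops over every class declared traceable, instead of one hand-written
# test per model class.

class DummyModel:
    def __init__(self, config):
        self.config = config

class ModelTesterMixinSketch:
    fx_ready_model_classes = (DummyModel,)  # models known to symbolically trace

    def test_torch_fx(self):
        checked = []
        for model_class in self.fx_ready_model_classes:
            model = model_class(config={"hidden_size": 32})
            # Real code would call symbolic_trace(model) here and compare
            # the traced module's outputs against the eager outputs.
            checked.append(type(model).__name__)
        return checked
```

Each model's test file then only declares which classes are traceable, and the shared mixin supplies the test body.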
…ing of modules defined in the forward pass, making the whole nn.functional renaming not needed anymore
…les instantiated in the forward pass
```python
import os

import torch
import torch.nn.functional as F
```
So this one is no longer needed, is it?
No, it shouldn't be needed anymore; I was wondering if we should keep it or not.
Probably remove it then, if tests pass.
We are going to get rid of F after we merge this; we will use nn.functional everywhere consistently.
@sgugger, Michael merged the custom tests into common_tests and significantly simplified the mods to the models, yay! So it looks ready for your review whenever you have a chance. Thank you!
sgugger
left a comment
This is great! I just have a few last comments on documentation/naming and this should be good to be merged!
```python
class CustomProxy(Proxy):
    def __init__(self, node: Node, tracer: Optional[Tracer] = None):
```
A small docstring on what that object does would go a long way for the code's maintainability in the future :-)
What does this custom proxy do that the torch.fx one does not?
Also, should it be HFProxy instead of CustomProxy?
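For context on what a Proxy subclass is for, here is a torch-free toy showing the core idea: a proxy stands in for a tensor and records each operation into a graph instead of executing it. This illustrates the mechanism only; it is not HuggingFace's CustomProxy/HFProxy, and every name in it is invented.

```python
class TinyProxy:
    """Toy stand-in for torch.fx.Proxy: records ops instead of computing them."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph  # shared list of recorded operations

    def _record(self, op, other):
        out = TinyProxy(f"v{len(self.graph)}", self.graph)
        other_name = other.name if isinstance(other, TinyProxy) else other
        self.graph.append((op, self.name, other_name, out.name))
        return out

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

graph = []
x = TinyProxy("x", graph)
y = (x + 1) * 2  # nothing is computed; two nodes land in the graph
```

A custom proxy like the PR's can additionally answer shape queries and other tensor-like behavior during tracing, which plain `torch.fx.Proxy` does not provide.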
```python
class CustomTracer(Tracer):
    def __init__(self, batch_size=1, seqlen=[128, 128], num_choices=-1):
```
Same here for the docstring + HFTracer ?
```python
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    batch_size: int = 1,
    seqlen: Union[int, List[int]] = [128, 128],
```
Use sequence_length to go with batch_size please.
```
    model (:obj:`PretrainedModel`): The model to trace.
    input_names (:obj:`Optional[List[str]]`): The names of the inputs of the traced model.
        If input_names is None, the model dummy_inputs keys are used instead.
    batch_size (:obj:`int`): The batch size of the traced model inputs.
    seqlen (:obj:`Union[int, List[int]]`): The sequence length of the traced model inputs.
        For Seq2Seq models with differents sequence length between the encoder and the decoder inputs, seqlen must
        be [encoder_sequence_length, decoder_sequence_length].
    num_choices (:obj:`int`): The number of possible choices for MultipleChoice task.
```
If we detail the args, let's use the same style as in other places then:

```diff
-    model (:obj:`PretrainedModel`): The model to trace.
-    input_names (:obj:`Optional[List[str]]`): The names of the inputs of the traced model.
-        If input_names is None, the model dummy_inputs keys are used instead.
-    batch_size (:obj:`int`): The batch size of the traced model inputs.
-    seqlen (:obj:`Union[int, List[int]]`): The sequence length of the traced model inputs.
-        For Seq2Seq models with differents sequence length between the encoder and the decoder inputs, seqlen must
-        be [encoder_sequence_length, decoder_sequence_length].
-    num_choices (:obj:`int`): The number of possible choices for MultipleChoice task.
+    model (:obj:`PretrainedModel`):
+        The model to trace.
+    input_names (:obj:`List[str]`, `optional`):
+        The names of the inputs of the traced model. If unset, the model dummy_inputs keys are used instead.
+    batch_size (:obj:`int`, `optional`, defaults to 1):
+        The batch size of the traced model inputs.
+    sequence_length (:obj:`int` or :obj:`List[int]]`):
+        The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
+        lengths between the encoder and the decoder inputs, this must be
+        :obj:`[encoder_sequence_length, decoder_sequence_length]`.
+    num_choices (:obj:`int`, `optional`, defaults to -1):
+        The number of possible choices for a multiple choice task.
```
Also, what is `dummy_inputs` in that docstring?
```python
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
```
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Very nice implementation! Love the common tests, it's clean.
Before adding other models, I would prioritize writing a small piece of documentation (which can probably live under "Advanced guides" in the transformers docs) explaining what this is and how it can be used; this will help understandability and maintainability in the future. (This can be done in a future PR.)
Of course, it looks as if this is still early in development, so no need for very thorough docs; just enough to get a grasp of what's happening without necessarily playing with the code first.
```python
    def test_torch_fx(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict)

    def test_torch_fx_output_loss(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)
```
How heavy are those tests in terms of memory/processing power? It makes me think of the Keras tests that we eventually had to disable, as they took more than a couple of minutes per model class, which wasn't feasible for CI.
```
3.71s call tests/test_modeling_bert.py::BertModelTest::test_torch_fx
0.71s call tests/test_modeling_electra.py::ElectraModelTest::test_torch_fx
0.62s call tests/test_modeling_t5.py::T5ModelTest::test_torch_fx
3.56s call tests/test_modeling_bert.py::BertModelTest::test_torch_fx_output_loss
0.86s call tests/test_modeling_electra.py::ElectraModelTest::test_torch_fx_output_loss
0.58s call tests/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss
```

It's interesting that BERT is 6x slower than T5. Any idea why?
Yes, I would say that's because there are more model classes for BERT than for T5.
Ah! the answer was simple then ;) Thank you!
Sorry for jumping in.

Well, I initially wanted this in order to be able to try https://github.com/flexflow/FlexFlow, which requires symbolic tracing, but I haven't had a chance to do so yet.

Got it, thanks for the explanation.
… and T5 (huggingface#11475)

Symbolic tracing feature for BERT, ELECTRA and T5

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This would also be helpful for quantizing models using FX Graph Mode Quantization, which automates the quantization process in PyTorch.

Are these updates still functional currently? No modeling_fx_utils.py can be seen in the source code directory.
What does this PR do?
This PR provides a function called `symbolic_trace` which enables symbolic tracing for models of the library using the new and still experimental torch.fx feature. Our models can't be symbolically traced directly using torch.fx, so this is a wrapper function that overcomes various issues. This new feature makes it possible to perform many kinds of transformations on the graph.
It's also needed for projects like https://github.com/flexflow/FlexFlow/
As an experiment, currently only three models are supported: BERT, ELECTRA and T5 (support for other models will follow soon).
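A usage sketch, hedged: the module path (`transformers.modeling_fx_utils`) and the keyword names follow this PR's description and review (`sequence_length` is the name the review asked for), but the helper moved and the API changed in later transformers releases, so the snippet guards the import and degrades gracefully when this PR-era API is absent.

```python
def trace_small_bert():
    # Guarded import: return None instead of failing when transformers,
    # torch, or this PR-era symbolic_trace API is not installed.
    try:
        from transformers import BertConfig, BertModel
        from transformers.modeling_fx_utils import symbolic_trace
    except Exception:
        return None
    model = BertModel(BertConfig(num_hidden_layers=1, hidden_size=32,
                                 num_attention_heads=2, intermediate_size=64))
    # Keyword names as described in this PR; assumed, not guaranteed stable.
    return symbolic_trace(
        model,
        input_names=["input_ids", "attention_mask"],
        batch_size=1,
        sequence_length=128,
    )

traced = trace_small_bert()
```

The returned object, when tracing succeeds, is a `torch.fx.GraphModule` whose `graph` can then be transformed or quantized.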