[FX] torch.fx.symbolic_trace patching improvements and math.* support #50793

jansel wants to merge 10 commits into pytorch:master

Conversation
Codecov Report
@@            Coverage Diff             @@
##           master   #50793      +/-   ##
==========================================
- Coverage   80.99%   80.68%    -0.32%
==========================================
  Files        1916     1916
  Lines      209608   209666       +58
==========================================
- Hits       169780   169168      -612
- Misses      39828    40498      +670
facebook-github-bot left a comment

@jansel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…02101 Conflicts: test/test_fx.py
# this captures both `math.sqrt()` and `from math import sqrt` automatically
_autowrap_search : List[dict] = [math.__dict__]
_autowrap_function_ids : Set[int] = {id(value) for name, value in math.__dict__.items()
                                     if not name.startswith("_") and callable(value)}
One thing I tried to convey in our meeting, but I don't know if it got through, was: can we make this configurable from the `symbolic_trace`/`Tracer.trace` APIs? E.g. take a parameter of `List[module]` or `List[str]` that refers to modules that the user wants to be patched automatically. This could be `[math]` by default. This way, the behavior is obvious from the API and is also configurable.
You prefer an option to Tracer() as opposed to extending the wrap() API?
I moved it to an option to Tracer().
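For reference, a minimal sketch of what the resulting option looks like in use. The parameter name `autowrap_modules` matches current `torch.fx`; treat the exact signature here as an assumption rather than a quote from this diff:

```python
import math

import torch
import torch.fx

class Scale(torch.nn.Module):
    def forward(self, x):
        # x.size(-1) returns a Proxy during tracing, so math.sqrt must be
        # patched for this call to be recorded rather than raise a TypeError
        return x / math.sqrt(x.size(-1))

# autowrap_modules defaults to (math,); it is spelled out here for clarity.
# Functions found in these modules' namespaces are recorded as
# call_function nodes instead of being traced into.
tracer = torch.fx.Tracer(autowrap_modules=(math,))
graph = tracer.trace(Scale())
```

The string form `torch.fx.wrap("name")` remains available for C functions that lack `__code__`.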
…rt (pytorch#50793)

Summary:
This contains some improvements and refactoring to how patching is done in `torch.fx.symbolic_trace`.

1) Functions from `math.*` are now supported without needing to call `torch.fx.wrap()`. `wrap()` actually errors on some of these functions because they are written in C and don't have `__code__`, requiring use of the string version. `math` usage is relatively common; for example, [BERT uses math.sqrt here](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/single.py#L16). Both `math.sqrt()` and `from math import sqrt` (copying into the module namespace) are supported. When modules are called, FX now searches the module's global scope to find methods to patch.

2) [Guarded behind `env FX_PATCH_GETITEM=1`] Fixes a failed trace of [PositionalEmbedding from BERT](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py#L24), which failed to trace with the error `TypeError: slice indices must be integers or None or have an __index__ method` (a Proxy() is getting passed into `Tensor.__getitem__`). See pytorch#50710 for why this is disabled by default.

3) Support for automatically wrapping methods that may have been copied to a different module scope via an import like `from foo import wrapped_function`. This also isn't exposed in `torch.fx.wrap`, but is used to implement `math.*` support.

Pull Request resolved: pytorch#50793

Test Plan: Added unittests to check each feature

Reviewed By: jamesr66a

Differential Revision: D25999788

Pulled By: jansel

fbshipit-source-id: f1ce11a69b7d97f26c9e2741c6acf9c513a84467
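A minimal sketch of feature (1); the module below is a hypothetical stand-in patterned after the BERT attention linked above, not code from this PR:

```python
from math import sqrt  # copied into this module's namespace

import torch
import torch.fx

class Attention(torch.nn.Module):
    def forward(self, q, k):
        # q.size(-1) is a Proxy while tracing; previously sqrt(Proxy)
        # raised a TypeError unless the function was wrapped explicitly
        return (q @ k.transpose(-2, -1)) / sqrt(q.size(-1))

gm = torch.fx.symbolic_trace(Attention())
print(gm.code)  # sqrt is recorded as a call_function node
```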
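And a sketch of the kind of code feature (2) targets, loosely following the PositionalEmbedding module linked above (shapes and names here are hypothetical). The flag is read from the environment, so it is set before tracing; setting it before `torch.fx` is imported is the safe assumption:

```python
import os

os.environ["FX_PATCH_GETITEM"] = "1"  # opt-in; disabled by default, see pytorch#50710

import torch
import torch.fx

class PositionalEmbedding(torch.nn.Module):
    def __init__(self, max_len: int = 512, d_model: int = 64):
        super().__init__()
        self.register_buffer("pe", torch.zeros(1, max_len, d_model))

    def forward(self, x):
        # x.size(1) is a Proxy; slicing with it passes a Proxy into
        # Tensor.__getitem__, which only traces when __getitem__ is patched
        return self.pe[:, :x.size(1)]

gm = torch.fx.symbolic_trace(PositionalEmbedding())
```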