
[FX] torch.fx.symbolic_trace patching improvements and math.* support #50793

Closed
jansel wants to merge 10 commits into pytorch:master from jansel:mathsupport202101

Conversation

Contributor

@jansel jansel commented Jan 20, 2021

This contains some improvements and refactoring to how patching is done in `torch.fx.symbolic_trace`.

  1. Functions from `math.*` are now supported without needing to call `torch.fx.wrap()`. `wrap()` actually errors on some of these functions because they are written in C and don't have `__code__`, requiring use of the string version. `math` usage is relatively common; for example, [BERT uses math.sqrt here](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/single.py#L16). Both `math.sqrt()` and `from math import sqrt` (copying to module namespace) are supported. When modules are called, FX now searches the module's global scope to find methods to patch. (See the first sketch after this list.)

  2. [Guarded behind env `FX_PATCH_GETITEM=1`] Fixes a failed trace of [PositionalEmbedding from BERT](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py#L24), which failed to trace with the error `TypeError: slice indices must be integers or None or have an __index__ method` (a `Proxy()` is getting passed into `Tensor.__getitem__`). See #50710 ("Quantized ops fail if Tensor.__getitem__ has previously been patched") for why this is disabled by default. (See the second sketch below.)

  3. Support for automatically wrapping methods that may have been copied to a different module scope via an import like `from foo import wrapped_function`. This also isn't exposed in `torch.fx.wrap`, but is used to implement `math.*` support.
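To make item 1 concrete, here is a minimal sketch of a trace this change enables; the module and shapes are illustrative, not taken from the PR's test suite. Previously, both call styles would require `torch.fx.wrap("sqrt")` (the string form, since C functions have no `__code__`):

```python
import math
from math import sqrt  # copied into this module's namespace via from-import

import torch
import torch.fx


class ScaledScores(torch.nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # During tracing, scores.size(-1) is a Proxy; math.sqrt is patched
        # automatically and recorded as a call_function node instead of
        # raising a TypeError.
        scale = math.sqrt(scores.size(-1))
        # The from-import alias is found by searching this module's global
        # scope, so it traces the same way.
        return scores / scale + sqrt(scores.size(-1))


traced = torch.fx.symbolic_trace(ScaledScores())
print(traced.graph)
```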

Test Plan: Added unittests to check each feature
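And a sketch of opting into the `Tensor.__getitem__` patching from item 2; the positional table here is illustrative, and the flag is set before importing torch so it is visible whenever the trace checks it:

```python
import os

os.environ["FX_PATCH_GETITEM"] = "1"  # opt in; disabled by default (see #50710)

import torch
import torch.fx

# A concrete tensor (not a registered buffer), standing in for BERT's
# precomputed positional-encoding table; shape is illustrative.
PE = torch.zeros(1, 512, 64)


class PositionalSlice(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.size(1) is a Proxy during tracing; slicing the real tensor PE
        # with it goes through Tensor.__getitem__, which without the flag
        # fails with "slice indices must be integers or None or have an
        # __index__ method".
        return PE[:, : x.size(1)]


traced = torch.fx.symbolic_trace(PositionalSlice())
print(traced.graph)
```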

codecov Bot commented Jan 21, 2021

Codecov Report

Merging #50793 (be60db1) into master (4bbff92) will decrease coverage by 0.31%.
The diff coverage is 89.10%.

@@            Coverage Diff             @@
##           master   #50793      +/-   ##
==========================================
- Coverage   80.99%   80.68%   -0.32%     
==========================================
  Files        1916     1916              
  Lines      209608   209666      +58     
==========================================
- Hits       169780   169168     -612     
- Misses      39828    40498     +670     

Contributor

@facebook-github-bot facebook-github-bot left a comment

@jansel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

@jamesr66a jamesr66a left a comment

Can you split this off into separate patches? It's hard to tell from the GitHub diff view which changes are related to which high-level concepts

Comment thread: torch/fx/symbolic_trace.py

Collaborator

@jamesr66a jamesr66a left a comment

Nice

Comment thread: torch/fx/symbolic_trace.py
```python
# this captures both `math.sqrt()` and `from math import sqrt` automatically
_autowrap_search : List[dict] = [math.__dict__]
_autowrap_function_ids : Set[int] = {id(value) for name, value in math.__dict__.items()
                                     if not name.startswith("_") and callable(value)}
```
Collaborator

One thing I tried to convey in our meeting but don't know if I got through was: can we make this configurable from the symbolic_trace/Tracer.trace APIs? E.g., take a parameter of List[module] or List[str] that refers to modules that the user wants to be patched automatically. This could be [math] by default. This way, the behavior is obvious from the API and is also configurable.

Contributor Author

You prefer an option to Tracer() as opposed to extending the wrap() API?

Contributor Author

I moved it to an option to Tracer()
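For reference, a sketch of what that option looks like in use. The parameter name `autowrap_modules` matches the current torch.fx `Tracer` API, which this option evolved into; treat the exact name at the time of this PR as an assumption:

```python
import math

import torch
import torch.fx


class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / math.sqrt(x.size(-1))


# autowrap_modules defaults to (math,); extra modules whose global
# functions should be patched during tracing can be added to the tuple.
tracer = torch.fx.Tracer(autowrap_modules=(math,))
graph = tracer.trace(Scale())
print(graph)
```

A user module of numeric helpers could be passed the same way, e.g. `autowrap_modules=(math, my_helpers)`, where `my_helpers` is a hypothetical module name.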

Contributor

@facebook-github-bot facebook-github-bot left a comment

@jansel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@facebook-github-bot facebook-github-bot commented
@jansel merged this pull request in a66851a.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026

[FX] torch.fx.symbolic_trace patching improvements and math.* support (pytorch#50793)

Summary:
This contains some improvements and refactoring to how patching is done in `torch.fx.symbolic_trace`.

1) Functions from `math.*` are now supported without needing to call `torch.fx.wrap()`.  `wrap()` actually errors on some of these functions because they are written in C and don't have `__code__`, requiring use of the string version.  `math` usage is relatively common, for example [BERT uses math.sqrt here](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/single.py#L16).  Both `math.sqrt()` and `from math import sqrt` (copying to module namespace) are supported.  When modules are called FX now searches the module's global scope to find methods to patch.

2) [Guarded behind `env FX_PATCH_GETITEM=1`] Fixes a failed trace of [PositionalEmbedding from BERT](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py#L24), which failed to trace with the error `TypeError: slice indices must be integers or None or have an __index__ method` (a Proxy() is getting passed into `Tensor.__getitem__`).  See pytorch#50710 for why this is disabled by default.

3) Support for automatically wrapping methods that may have been copied to a different module scope via an import like `from foo import wrapped_function`.  This also isn't exposed in `torch.fx.wrap`, but is used to implement `math.*` support.

Pull Request resolved: pytorch#50793

Test Plan: Added unittests to check each feature

Reviewed By: jamesr66a

Differential Revision: D25999788

Pulled By: jansel

fbshipit-source-id: f1ce11a69b7d97f26c9e2741c6acf9c513a84467