[JIT] Fix the clamp special case and gradient problem on None, add None to JIT#9596
wanchaol wants to merge 7 commits into pytorch:master
Conversation
wanchaol force-pushed from 18c9009 to 9e9d917.
facebook-github-bot left a comment:
@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
wanchaol force-pushed from e085b8e to da2ac8b.
Does the following run correctly?

@zou3519 Yes, here is the output: I will also add more subtle test cases like this one to try to capture every corner case.
zdevito left a comment:
This looks good! I have just minor comments. If it is possible to remove the special-casing of clamp in the python frontend then we should.
wanchaol force-pushed from 7610648 to bd853b1.
wanchaol force-pushed from bd9a1cc to fad128b.
…JIT (#9596)
Summary:
Supersedes #8925
This PR fixes #8502: it fixes the gradient problem for clamp when None is passed to the function, and adds support for NoneLiteral and NoneType in script to enable the clamp tests. Corner cases like the following now work:
```python
@torch.jit.script
def func():
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.clamp(x, None, 0)  # max = 0
    y = torch.clamp(x, min=None, max=0)
```
In both the JIT and ATen, we use Scalar(NAN) as a sentinel value when None is passed to clamp; this is how we currently support the None type in the JIT, and it solves the gradient problem when the user explicitly passes None to clamp.
On the JIT side, we create a tensor(NAN) and an undefined tensor when we encounter None while matching the function schema; later, the interpreter translates it to Scalar(NAN) if needed.
Ideally we would not need clamp_min and clamp_max in ATen native/autograd and could support only clamp after this change, but since a bunch of other operators (e.g. Activation.cpp, Loss.cpp) use clamp_min in several places, the functions remain available; all Python invocations, however, will call only clamp instead of clamp_min/clamp_max (with clamp calling the underlying th_max/th_min).
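The NaN-sentinel pattern described above can be sketched in plain, torch-free Python (the helper names `to_sentinel` and `clamp` here are hypothetical, for illustration only): an explicit None bound is replaced by NaN, and the clamp kernel treats a NaN bound as "no bound" because every comparison against NaN evaluates to False:

```python
NAN = float("nan")

def to_sentinel(bound):
    # Mirrors the schema-matching step: an explicit None becomes Scalar(NAN).
    return NAN if bound is None else float(bound)

def clamp(x, min=None, max=None):
    # Comparisons against NaN are always False, so a NaN bound never
    # triggers clamping -- the element passes through unchanged, which
    # is exactly the "no bound" behavior we want for None.
    lo, hi = to_sentinel(min), to_sentinel(max)
    if x < lo:   # False when lo is NaN
        return lo
    if x > hi:   # False when hi is NaN
        return hi
    return x

print(clamp(5.0, None, 0.0))   # max-only clamp -> 0.0
print(clamp(-2.0, 0.0, None))  # min-only clamp -> 0.0
print(clamp(0.5, 0.0, 1.0))    # both bounds    -> 0.5
```

This also hints at why the gradient works out: with a NaN bound, the element is never replaced by the bound, so the pass-through branch (gradient 1) is taken for every element on that side.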
zdevito jamesr66a
Pull Request resolved: pytorch/pytorch#9596
Reviewed By: zdevito
Differential Revision: D8940839
Pulled By: wanchaol
fbshipit-source-id: c543a867b82e0ab8c99384773b173fdde2605d28