
[JIT] Fix the clamp special case and gradient problem on None, add None to JIT #9596

Closed
wanchaol wants to merge 7 commits into pytorch:master from wanchaol:clamp_none

Conversation

Collaborator

@wanchaol wanchaol commented Jul 19, 2018

Supersedes #8925

This PR fixes #8502. It fixes the gradient problem for clamp when None is passed to the function, and adds support for NoneLiteral and NoneType in script so that the clamp tests can be enabled. We can now handle corner cases like:

```python
@torch.jit.script
def func():
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.clamp(x, None, 0)  # max = 0
    y = torch.clamp(x, min=None, max=0)
```
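
For instance, a quick way to exercise the gradient fix (a minimal sketch, not taken from this PR's test suite) is to backpropagate through a one-sided clamp and check that the gradient flows only where the input passed through unclamped:

```python
import torch

x = torch.randn(3, 3, requires_grad=True)
y = torch.clamp(x, min=None, max=0)  # min=None means "no lower bound"
y.sum().backward()

# Gradient should be 1 where x passed through (x < 0) and 0 where it was
# clamped; the convention exactly at x == 0 may differ between versions.
assert torch.equal(x.grad, (x < 0).float())
```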

In both the JIT and ATen, we use Scalar(NAN) as a sentinel value when None is passed for a clamp bound. This is how we currently support the None type in the JIT and how we solve the gradient problem when the user explicitly passes None into clamp.
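
To illustrate the sentinel idea, here is a minimal Python sketch (a hypothetical helper, not the actual ATen kernel): NaN encodes a missing bound, which is safe because NaN can never be a meaningful clamp limit:

```python
import math

def clamp_scalar(v, lo=math.nan, hi=math.nan):
    # A NaN bound means "this argument was None", so skip that comparison.
    if not math.isnan(lo):
        v = max(v, lo)
    if not math.isnan(hi):
        v = min(v, hi)
    return v

assert clamp_scalar(5.0, hi=4.0) == 4.0    # no min: behaves like min = -inf
assert clamp_scalar(-3.0, lo=0.0) == 0.0   # no max: behaves like max = +inf
```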

On the JIT side, we create a tensor(NAN) or an undefined tensor when we encounter None while matching the function schema; later, the interpreter translates it to Scalar(NAN) as needed.
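
One way to observe what the None arguments are lowered to (the exact output varies across PyTorch versions) is to print the scripted graph:

```python
import torch

@torch.jit.script
def clamp_max_only(x):
    return torch.clamp(x, min=None, max=0)

# The graph shows the constants that schema matching substituted for None.
print(clamp_max_only.graph)
```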

Ideally, after this change we would no longer need clamp_min and clamp_max in ATen native/autograd and could support only clamp. However, since a number of other operators (e.g. Activation.cpp, Loss.cpp) use clamp_min in several places, those functions remain available; all Python invocations now call only clamp instead of clamp_min/clamp_max (with clamp calling the underlying th_max/th_min).
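
The reason clamp can subsume clamp_min/clamp_max for Python callers is that a one-sided clamp is just an elementwise max or min. A quick sanity check (a sketch, not part of this PR):

```python
import torch

x = torch.randn(5)

# clamp with only a lower bound == elementwise max; only an upper bound == min.
assert torch.equal(torch.clamp(x, min=0), torch.max(x, torch.zeros_like(x)))
assert torch.equal(torch.clamp(x, max=0), torch.min(x, torch.zeros_like(x)))
```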

@zdevito @jamesr66a

@wanchaol wanchaol added the oncall: jit label on Jul 19, 2018
@wanchaol wanchaol requested a review from jamesr66a July 19, 2018 19:01
@wanchaol wanchaol force-pushed the clamp_none branch 3 times, most recently from 18c9009 to 9e9d917 on July 20, 2018 16:47
Contributor

@facebook-github-bot facebook-github-bot left a comment


@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@wanchaol wanchaol force-pushed the clamp_none branch 2 times, most recently from e085b8e to da2ac8b on July 24, 2018 17:39

Contributor

zou3519 commented Jul 24, 2018

Does the following run correctly?

```python
import torch
x = torch.arange(0, 10)
y = x.clamp(min=None, max=4)
print(y)
```

Collaborator Author

@zou3519 Yes, here is the output:

```
In [2]: import torch
   ...: x = torch.arange(0, 10)
   ...: y = x.clamp(min=None, max=4)
   ...: print(y)
tensor([0, 1, 2, 3, 4, 4, 4, 4, 4, 4])
```

I will also add more subtle test cases like this one to try to capture every corner case.
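
For example, tests along these lines (a hypothetical sketch, not the PR's actual test cases) would exercise None on either bound of an integer tensor:

```python
import torch

x = torch.arange(0, 10)

# One-sided clamps with None should match the equivalent elementwise min/max.
assert torch.equal(x.clamp(min=None, max=4), torch.min(x, torch.tensor(4)))
assert torch.equal(x.clamp(min=3, max=None), torch.max(x, torch.tensor(3)))
assert torch.equal(x.clamp(min=2, max=7), torch.clamp(x, 2, 7))
```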

Contributor

@zdevito zdevito left a comment


This looks good! I have just minor comments. If it is possible to remove the special-casing of clamp in the Python frontend, then we should.


@wanchaol wanchaol force-pushed the clamp_none branch 2 times, most recently from 7610648 to bd853b1 on July 27, 2018 02:46

@wanchaol wanchaol force-pushed the clamp_none branch 2 times, most recently from bd9a1cc to fad128b on July 27, 2018 07:33

Contributor

@zdevito zdevito left a comment


Looks good now

@wanchaol wanchaol deleted the clamp_none branch July 28, 2018 05:56
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 28, 2018
…JIT (#9596)

Pull Request resolved: pytorch/pytorch#9596

Reviewed By: zdevito

Differential Revision: D8940839

Pulled By: wanchaol

fbshipit-source-id: c543a867b82e0ab8c99384773b173fdde2605d28
jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
…JIT (pytorch#9596)

goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
…JIT (pytorch#9596)

@ezyang ezyang added the merged label Jun 26, 2019

Labels

oncall: jit (Add this issue/PR to JIT oncall triage queue)


Development

Successfully merging this pull request may close these issues.

[JIT] Don't support None in the script

5 participants