patching clamp for one sided clamp #75558

Closed
jjsjann123 wants to merge 6 commits into pytorch:master from jjsjann123:clamp_patch

Conversation

@jjsjann123 (Collaborator)

Fixes #75088

The fix is simply to avoid substituting an arbitrary value for an unspecified clamp bound, as pointed out in #75088 (comment).


facebook-github-bot commented Apr 9, 2022

💊 CI failures summary and remediations

As of commit 17c2c80 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).


facebook-github-bot added the "oncall: jit" label (Add this issue/PR to JIT oncall triage queue) Apr 9, 2022
samdow added the "triaged" label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Apr 11, 2022

ngimel commented Apr 11, 2022

Cool, can you do more systematic testing for extreme values? E.g. min/max look like they should codegen fmin/fmax and thus handle NaNs correctly, but without systematic testing it's hard to be sure. There might be other ops like relu/hardtanh/hardsigmoid that might or might not do the correct thing for NaN/inf propagation, or produce 0s at the non-differentiable points.

@jjsjann123
Copy link
Collaborator Author

Cool, can you do more systematic testing for extreme values? E.g. min/max look like they should codegen fmin/fmax and thus handle NaNs correctly, but without systematic testing it's hard to be sure. There might be other ops like relu/hardtanh/hardsigmoid that might or might not do the correct thing for NaN/inf propagation, or produce 0s at the non-differentiable points.

I think our unary/binary tests do cover all these special numbers.

self.special_values = torch.tensor(
    [float("-inf"), -10, -math.pi,
     -1, -0.5, 0, 1, 0.5,
     math.pi, 10, float("inf"),
     float("nan")], dtype=torch.float, device=dev)

if random_data:
    x = torch.rand(shape, dtype=torch.float32, device="cuda", requires_grad=gradient_check)
    if dtype in self.int_types:
        # prefer a larger variance for integer types
        x = x * 5
    x = x.to(dtype=dtype)
else:
    x = self.special_values.to(dtype=dtype)
try:
    ref = t(x, y)
except Exception:
    # same way as TE checker: if eager mode throws, ignore this test
    return
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
jit_o = t_jit(x, y)
if gradient_check:
    gradcheck(t_jit, [x, y], nondet_tol=1e-5)
elif dtype in self.support_tensor_dtypes:
    self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GUARD)
o = t(x, y)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertTrue(self._compare("failing case {}\n{}\n{}\n{}".format(dtype, operation, x, y), o, jit_o, 1e-2))

if random_data:
    x = (torch.randn(shapex, dtype=torch.float, device="cuda") * 5).to(dtype_arg1)
    y = (torch.randn(shapey, dtype=torch.float, device="cuda") * 5).to(dtype_arg2)
else:
    x = self.special_values.to(dtype=dtype_arg1)
    y = (torch.rand_like(self.special_values) * 5).to(dtype_arg2)

hardtanh is not in our parser yet. But I do agree that we should probably have a generalized helper function for ternary ops as well. I'll update the tests then.
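A generalized special-values helper for ternary ops could look roughly like the following (a hypothetical sketch in plain Python rather than the actual nvfuser test harness; `check_ternary` and `values_match` are illustrative names):

```python
import itertools
import math

# Hypothetical sketch, not the real nvfuser test helper: sweep every
# combination of special values through a ternary op and compare the
# candidate against a reference, treating NaN == NaN as a match.
SPECIAL_VALUES = [float("-inf"), -10.0, -math.pi, -1.0, -0.5, 0.0,
                  0.5, 1.0, math.pi, 10.0, float("inf"), float("nan")]

def values_match(a, b):
    # NaN compares unequal to itself, so handle it explicitly.
    if math.isnan(a) and math.isnan(b):
        return True
    return a == b

def check_ternary(op, ref_op):
    """Return the (inputs, got, want) triples where op disagrees with ref_op."""
    mismatches = []
    for args in itertools.product(SPECIAL_VALUES, repeat=3):
        got, want = op(*args), ref_op(*args)
        if not values_match(got, want):
            mismatches.append((args, got, want))
    return mismatches
```

In the real tests, `op` would be the scripted/fused function and `ref_op` the eager one, mirroring how the existing unary/binary helpers compare against eager output.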


ngimel commented Apr 11, 2022

Apparently they weren't covering these special numbers, or these failures wouldn't have happened?
Another case where I see a discrepancy between nvfuser and eager is when min is greater than max:

In [25]: def fn(x):
    ...:     x=x.clamp(min=1., max=0.5)*.1
    ...:     return x
    ...: 
    ...: a=torch.tensor([1.,float('inf'), 2., float('inf')], device="cuda")
    ...: scripted = torch.jit.script(fn)
    ...: fn(a)
    ...: with torch.jit.fuser("fuser2"):
    ...:     for _ in range(10):
    ...:         scripted(a)
    ...: print(fn(a))
    ...: print(scripted(a))
tensor([0.0500, 0.0500, 0.0500, 0.0500], device='cuda:0')
tensor([0.0500, 0.0500, 0.0500, 0.0500], device='cuda:0')

In [26]: print(fn(x.cuda()))
tensor([0.0500, 0.0500, 0.0500, 0.0500], device='cuda:0')

In [27]: print(scripted(x.cuda()))
tensor([0.1000, 0.1000, 0.1000, 0.0500], device='cuda:0')

In [28]: x
Out[28]: tensor([-0.3980, -0.2727,  0.1300,  2.0310])

Note, although this is arguably an error case, numpy, jax and torch eager all have the same behavior in this case, with nvfuser being an outlier.
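For reference, my reading of the eager semantics (an assumption based on aten's roughly std::min(std::max(x, lo), hi) scalar path, not official documentation): the upper bound is applied last, so it wins whenever min > max, and NaN propagates because both comparisons against NaN are false. A minimal pure-Python sketch of that order of operations:

```python
import math

# Sketch of the assumed eager clamp order of operations: the upper
# bound is applied last, so it wins when lo > hi; NaN passes through
# both branches because comparisons with NaN are always false.
def clamp(x, lo, hi):
    y = lo if x < lo else x      # apply the lower bound first
    return hi if y > hi else y   # then the upper bound, which wins if lo > hi

out = [clamp(v, 1.0, 0.5) * 0.1 for v in [1.0, float("inf"), 2.0, float("inf")]]
print(out)  # every element collapses to 0.5 * 0.1 = 0.05, as in the eager output above
```

Under this reading, nvfuser's tensor([0.1000, 0.1000, 0.1000, 0.0500]) result corresponds to applying the bounds in the opposite order.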

@jjsjann123 (Collaborator, Author)

clamp is a ternary op which is not covered by unary/binary tests and that's what I was promising to update in our tests.

x=x.clamp(min=1., max=0.5) — that's a confusing case... I think we just need to add two lines to regulate the min/max range; I'll update it to match the aten logic. But if these cases are defined behavior, maybe we should update the OpInfo tests so it's backed by CI?


ngimel commented Apr 11, 2022

As far as NVFuser is concerned, clamp is a unary op with kwargs, NVFuser accepts only (Tensor, Scalar, Scalar) overload.


ngimel commented Apr 11, 2022

And yes, adding a test for this behavior (since it's documented) is necessary, I wasn't able to find existing one. cc @mruberry

bool has_high = value_map.count(node->inputs()[2]->unique()) != 0;
Val* high = has_high
    ? *value_map[node->inputs()[2]->unique()]
    : IrBuilder::create<Double>(std::numeric_limits<float>::max());
Review comment (Collaborator):
You still have these very misleading high and low assignments, you should be able to remove them?

jjsjann123 requested a review from ngimel April 11, 2022 23:24
@jjsjann123 (Collaborator, Author)

@pytorchbot merge this

@github-actions (Contributor)

Hey @jjsjann123.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Apr 13, 2022
Summary:
Fixes #75088

The fix is simply to avoid substituting an arbitrary value for an unspecified clamp bound, as pointed out in #75088 (comment).

Pull Request resolved: #75558
Approved by: https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/0203341bbde7cebdb9a04d9f021797b8bea7de2f

Reviewed By: mehtanirav

Differential Revision: D35582770

fbshipit-source-id: d7781249f568c8cf28ecbf4bbce3c4d3b0f947ce
jjsjann123 added a commit to csarofeen/pytorch that referenced this pull request Apr 14, 2022
Fixes pytorch#75558 (comment)

Updated the clamp logic to be consistent with aten. This avoids producing a different result when clamp is given min/max arguments with min > max.
We don't have an issue open for this, and the right behavior is debatable, but better consistency with eager is always(?!) a good thing.
jjsjann123 added a commit to csarofeen/pytorch that referenced this pull request Apr 17, 2022
Fixes pytorch#75088

The solution is just to avoid putting random value for non-specified clamp as pointed out in pytorch#75088 (comment)

Pull Request resolved: pytorch#75558
Approved by: https://github.com/ngimel
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Fixes #75088

The solution is just to avoid putting random value for non-specified clamp as pointed out in pytorch/pytorch#75088 (comment)

Pull Request resolved: pytorch/pytorch#75558
Approved by: https://github.com/ngimel

Labels

cla signed · oncall: jit (Add this issue/PR to JIT oncall triage queue) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

NVFuser produces wrong outputs for extreme values in clamp

5 participants