Skip to content

Extends OpInfo architecture with reference inputs, adds them for elementwise binary operators#74280

Closed
mruberry wants to merge 7 commits intomasterfrom
reference_inputs
Closed

Extends OpInfo architecture with reference inputs, adds them for elementwise binary operators#74280
mruberry wants to merge 7 commits intomasterfrom
reference_inputs

Conversation

@mruberry
Copy link
Copy Markdown
Collaborator

@mruberry mruberry commented Mar 16, 2022

This PR extends our OpInfo test architecture with "reference inputs," an optional expansion of typical sample inputs that allows for more thorough testing. Currently only the elementwise binary operations implement an extended set of reference inputs. This PR also cleans up some smaller OpInfo-related issues, including several bugs, and it identified #74279.

A reference inputs function can be specified for an OpInfo by filling in its "reference_inputs_func" metadata. If this is done it's recommended that the reference inputs function first call the sample inputs function, then produce additional sample inputs. See reference_inputs_elementwise_binary for an example of this pattern.

In addition to implementing reference inputs for the elementwise binary operations, this PR improves their consistency and simplifies how their metadata is represented. The great majority now use a generic sample input function, and those that want extensions start by calling the generic sample input function and then adding additional samples. This removes many older sample input functions. The BinaryUfuncInfo subclass also now allows specifying scalar support more precisely, and reference inputs and error inputs are generated based on this metadata to ensure it's correct.

cc @kshitij12345 @pmeier @zou3519 @Chillee

@mruberry mruberry requested a review from ngimel March 16, 2022 03:53
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Mar 16, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/3855feaa2710d9aebd4d975a172f595cdc37699c/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
linux-binary-manywheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build ciflow/all, ciflow/cpu, ciflow/default, ciflow/libtorch, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
windows-binary-libtorch-debug ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-libtorch-release ciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk ✅ triggered
windows-binary-wheel ciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.3-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
pytorch-xla-linux-bionic-py3.7-clang8 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla 🚫 skipped

@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Mar 16, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 9be5c06 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-03-20T23:32:29.1100868Z AssertionError:
2022-03-20T23:32:29.1096568Z   File "/var/lib/jenkins/workspace/test/onnx/test_pytorch_onnx_onnxruntime.py", line 194, in run_model_test
2022-03-20T23:32:29.1097493Z     ort_compare_with_pytorch(ort_outs, output, rtol, atol)
2022-03-20T23:32:29.1098057Z   File "/var/lib/jenkins/workspace/test/onnx/test_pytorch_onnx_onnxruntime.py", line 137, in ort_compare_with_pytorch
2022-03-20T23:32:29.1098457Z     [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
2022-03-20T23:32:29.1098819Z   File "/var/lib/jenkins/workspace/test/onnx/test_pytorch_onnx_onnxruntime.py", line 137, in <listcomp>
2022-03-20T23:32:29.1099177Z     [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
2022-03-20T23:32:29.1099661Z   File "/opt/conda/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1529, in assert_allclose
2022-03-20T23:32:29.1099980Z     verbose=verbose, header=header, equal_nan=equal_nan)
2022-03-20T23:32:29.1100391Z   File "/opt/conda/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 842, in assert_array_compare
2022-03-20T23:32:29.1100686Z     raise AssertionError(msg)
2022-03-20T23:32:29.1100868Z AssertionError: 
2022-03-20T23:32:29.1101126Z Not equal to tolerance rtol=0.001, atol=1e-07
2022-03-20T23:32:29.1101264Z 
2022-03-20T23:32:29.1101353Z Mismatched elements: 1 / 32 (3.12%)
2022-03-20T23:32:29.1103034Z Max absolute difference: 192
2022-03-20T23:32:29.1103414Z Max relative difference: 1.
2022-03-20T23:32:29.1103811Z  x: array([[114,   0, 141,   9,   0,   5, 134,   7],
2022-03-20T23:32:29.1104156Z        [  0, 128,  12,   0,  36,   5,   0,   7],
2022-03-20T23:32:29.1104362Z        [132,   0, 164,   0,  20,   5,   0,   7],
2022-03-20T23:32:29.1104568Z        [  0, 128,   0,  88,   4,   5,   6,   7]], dtype=uint8)
2022-03-20T23:32:29.1104777Z  y: array([[114,   0, 141,   9,   0,   5, 134,   7],

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

Comment thread test/test_binary_ufuncs.py Outdated
l = sample.input
r = sample.args[0]

def to_numpy(x):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why don't you use SampleInput.numpy()? (SampleInput.numpy doesn't handle bfloat16 though)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on using SampleInput.numpy() here. It probably also should handle bfloat16 so we can use it as general method for reference tests.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure; I updated the SampleInput class's .numpy() method instead

x = torch.tensor((), dtype=dtype, device=device)
else:
if dtype.is_floating_point or dtype.is_complex:
# work around torch.randn not being implemented for bfloat16
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

randn is implemented for bfloat16

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not worth updating this legacy function?

# TODO: update make_tensor to support extremal additions and remove this in favor of make_tensor
def _generate_input(self, shape, dtype, device, with_extremal):
if shape == ():
x = torch.tensor((), dtype=dtype, device=device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's extremely conterintuitive to produce 0-element 1d tensor for shape=(), was scalar meant here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just copy-pasting the same function from earlier in the file and moving it closer to where it's used because it's legacy nonsense

Comment on lines +2433 to +2434
x = torch.zeros(shape, dtype=dtype, device=device)
x[torch.randn(*shape) > 0.5] = True
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x = torch.zeros(shape, dtype=dtype, device=device)
x[torch.randn(*shape) > 0.5] = True
x = torch.randint(2, shape, dtype=dtype, device=device)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_tensor also handles this:

Suggested change
x = torch.zeros(shape, dtype=dtype, device=device)
x[torch.randn(*shape) > 0.5] = True
x = make_tensor(shape, dtype=dtype, device=device)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't fix for this PR because this "change" is just moving a pre-existing function that's already marked as needing deletion

self.assertEqual(x / s, x / y)
self.assertEqual(s / x, y / x)

# TODO: update make_tensor to support extremal additions and remove this in favor of make_tensor
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please no

Copy link
Copy Markdown
Collaborator

@pmeier pmeier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a few nits and questions inline. Otherwise LGTM. Still, the PR is massive so I'm not confident that I found everything.

Comment thread test/test_binary_ufuncs.py Outdated
l = sample.input
r = sample.args[0]

def to_numpy(x):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on using SampleInput.numpy() here. It probably also should handle bfloat16 so we can use it as general method for reference tests.

Comment on lines +2433 to +2434
x = torch.zeros(shape, dtype=dtype, device=device)
x[torch.randn(*shape) > 0.5] = True
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_tensor also handles this:

Suggested change
x = torch.zeros(shape, dtype=dtype, device=device)
x[torch.randn(*shape) > 0.5] = True
x = make_tensor(shape, dtype=dtype, device=device)

Comment thread test/test_binary_ufuncs.py
Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
Comment on lines +10382 to +10383
DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
)),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this skip only be for complex64 rather than for all dtypes?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_type_promotion isn't instantiated for different dtypes, so no (although your question makes perfect sense)

rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs)
broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))

sample_kwargs = kwargs.get('sample_kwargs', {})
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems kwargs is only used to extract sample_kwargs out of it. Maybe we can use **sample_kwargs directly?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're suggesting using **kwargs directly instead of extracting sample_kwargs?

Yes I agree. This is a unnecessary refinement that would be useful if an OpInfo wanted separate kwargs for the generically generated elementwise binary samples vs. the bespoke ones, but that case doesn't exist today and may never exist. I simplified it.

DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
),
sample_inputs_func=sample_inputs_igamma_igammac),
# TODO: FIXME, ideally by implemented grad for both inputs
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just marking since this is a large PR and we might forget otherwise.

DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
),
sample_inputs_func=sample_inputs_igamma_igammac),
# TODO: FIXME, ideally by implementing grad for both inputs
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Comment thread torch/testing/_internal/common_methods_invocations.py
@@ -692,6 +718,7 @@ def __init__(self,

# We run the sampling functions without tracking the gradiends of the creation of inputs
self.sample_inputs_func = torch.no_grad()(sample_inputs_func)
self.reference_inputs_func = None if reference_inputs_func is None else torch.no_grad()(reference_inputs_func)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why no_grad?

Copy link
Copy Markdown
Collaborator Author

@mruberry mruberry Mar 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consistency with sample_inputs_func, which is also run in a no_grad() context to prevent the tensors used in SampleInputs, which are often constructed using a few operations, from having a grad history

if self.reference_inputs_func is None:
return self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)

if 'include_conjugated_inputs' in kwargs and kwargs.get('include_conjugated_inputs'):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to do ref testing on conjugated inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It just seemed kinda extra. There's still a test that conjugated sample inputs work. How likely is it that conjugated sample inputs work but conjugated inputs don't work on an expanded set of tensors with different values?

@mruberry
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge this please

@github-actions
Copy link
Copy Markdown
Contributor

Hey @mruberry.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@mruberry mruberry added the topic: not user facing topic category label Mar 21, 2022
@pmeier pmeier deleted the reference_inputs branch March 21, 2022 06:05
facebook-github-bot pushed a commit that referenced this pull request Mar 21, 2022
…entwise binary operators (#74280)

Summary:
This PR extends our OpInfo test architecture with "reference inputs," an optional expansion of typical sample inputs that allows for more thorough testing. Currently only the elementwise binary operations implement an extended set of reference inputs. This PR also cleans up some smaller OpInfo-related issues, including several bugs, and it identified #74279.

A reference inputs function can be specified for an OpInfo by filling in its "reference_inputs_func" metadata. If this is done it's recommended that the reference inputs function first call the sample inputs function, then produce additional sample inputs. See `reference_inputs_elementwise_binary` for an example of this pattern.

In addition to implementing reference inputs for the elementwise binary operations, this PR improves their consistency and simplifies how their metadata is represented. The great majority now use a generic sample input function, and those that want extensions start by calling the generic sample input function and then adding additional samples. This removes many older sample input functions. The BinaryUfuncInfo subclass also now allows specifying scalar support more precisely, and reference inputs and error inputs are generated based on this metadata to ensure it's correct.

cc kshitij12345 pmeier zou3519 Chillee

Pull Request resolved: #74280
Approved by: https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/0aa3c39e5f296dd0871d0f849e295d3b7644ff2e

Reviewed By: seemethere, osalpekar

Differential Revision: D35013032

fbshipit-source-id: b47c40bc34467b8043a8a8cb26fe7c9ef2c6e0f4
shahofblah pushed a commit that referenced this pull request Mar 25, 2022
…entwise binary operators

This PR extends our OpInfo test architecture with "reference inputs," an optional expansion of typical sample inputs that allows for more thorough testing. Currently only the elementwise binary operations implement an extended set of reference inputs. This PR also cleans up some smaller OpInfo-related issues, including several bugs, and it identified #74279.

A reference inputs function can be specified for an OpInfo by filling in its "reference_inputs_func" metadata. If this is done it's recommended that the reference inputs function first call the sample inputs function, then produce additional sample inputs. See `reference_inputs_elementwise_binary` for an example of this pattern.

In addition to implementing reference inputs for the elementwise binary operations, this PR improves their consistency and simplifies how their metadata is represented. The great majority now use a generic sample input function, and those that want extensions start by calling the generic sample input function and then adding additional samples. This removes many older sample input functions. The BinaryUfuncInfo subclass also now allows specifying scalar support more precisely, and reference inputs and error inputs are generated based on this metadata to ensure it's correct.

cc @kshitij12345 @pmeier @zou3519 @Chillee

Pull Request resolved: #74280
Approved by: https://github.com/ngimel
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
…entwise binary operators

This PR extends our OpInfo test architecture with "reference inputs," an optional expansion of typical sample inputs that allows for more thorough testing. Currently only the elementwise binary operations implement an extended set of reference inputs. This PR also cleans up some smaller OpInfo-related issues, including several bugs, and it identified pytorch#74279.

A reference inputs function can be specified for an OpInfo by filling in its "reference_inputs_func" metadata. If this is done it's recommended that the reference inputs function first call the sample inputs function, then produce additional sample inputs. See `reference_inputs_elementwise_binary` for an example of this pattern.

In addition to implementing reference inputs for the elementwise binary operations, this PR improves their consistency and simplifies how their metadata is represented. The great majority now use a generic sample input function, and those that want extensions start by calling the generic sample input function and then adding additional samples. This removes many older sample input functions. The BinaryUfuncInfo subclass also now allows specifying scalar support more precisely, and reference inputs and error inputs are generated based on this metadata to ensure it's correct.

cc @kshitij12345 @pmeier @zou3519 @Chillee

Pull Request resolved: pytorch#74280
Approved by: https://github.com/ngimel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants