
[primTorch] Adds random operations #78026

Closed
mruberry wants to merge 15 commits into master from primtorch_random

Conversation

@mruberry
Collaborator

@mruberry mruberry commented May 20, 2022

This PR...

Issues Found

  • #78058
  • #78054
  • #78053
  • #78050
  • #77932

Testing

  • disables stride consistency checks in test_ops and test_meta pending resolution of RFC: [primTorch] Stride-agnostic Operator Semantics #78050
  • skips chalf in reference tests (addressing [primTorch] many primTorch test xfails are due to chalf #78054)
  • splits test_python_reference_consistency into one test for the context where torch.foo is torch.foo, and another for when torch.foo is refs.foo
  • updates test names to be more natural and consistent:
    • test_python_reference_errors -> test_python_ref_errors
    • test_python_reference_consistency -> test_python_ref and test_python_ref_torch_fallback
    • test_python_reference_meta_functions -> test_python_ref_meta
    • test_reference_testing -> test_numpy_ref
  • updates test_python_ref and test_python_ref_torch_fallback to check that the reference is more accurate than the torch op if the reference and torch op results are not close; a warning is raised when this occurs (addressing PrimTorch's test_ops.py reference_consistency testing is worse than test_decomps.py testing #77687)
  • adds reference inputs for broadcast_tensors
  • Updates the "fill_" OpInfo to "fill", adding a NumPy reference and making it an elementwise unary operator
  • Adds 1D no element sample inputs to the cat OpInfo and updates the NumPy reference to handle them and type promotion correctly
  • Adds reference inputs for elementwise ternary operations, like clamp
  • Adds a NumPy reference for clamp
  • Adds reference inputs to where's OpInfo
  • Makes softplus an elementwise unary OpInfo
  • Removes the great majority of Python reference OpInfo skips and xfails due to the above test changes
  • Adds Python reference OpInfos for fill, dropout, clamp, broadcast_tensors, and where
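
The accuracy tie-breaker described above (accept a mismatch if the reference is closer to a higher-precision computation than the torch op is) can be sketched in pure Python. The helper name and the comparison against `math.sin` are illustrative assumptions, not the actual test code.

```python
import math

def closer_to_precise(ref_result, torch_result, precise_result):
    # Hypothetical sketch: measure each result's distance from a
    # higher-precision "precise" computation and require that the
    # reference is at least as accurate as the torch operator.
    ref_distance = abs(ref_result - precise_result)
    torch_distance = abs(torch_result - precise_result)
    return ref_distance <= torch_distance

# Two fictional approximations of sin(1), compared against math.sin(1)
precise = math.sin(1.0)
assert closer_to_precise(0.8414710, 0.8414712, precise)
```

When both results are close to the operator's tolerance, the test passes outright; this comparison only runs in the mismatch case.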

Prims

  • adds the fill, empty_strided, and uniform prims
  • removes the empty, empty_like, full, and full_like prims -- these are now references that use empty_strided and fill
  • renames the "concatenate" and "select" prims to "cat" and "where", respectively, to be consistent with PyTorch
  • extends the _elementwise_meta operation to accept tensors that don't participate in type promotion, like the cond tensor in where
  • fixes a bug in the stride propagation of broadcast_in_dim
  • moves some error checks from prims.cat and prims.where to refs.cat and refs.where, respectively, consistent with our new policy of doing as much error checking in the ref as possible
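
For context on the empty_strided prim and the broadcast_in_dim stride fix: the default (contiguous) layout computes each dimension's stride as the product of all dimension sizes to its right. A minimal pure-Python sketch (the helper name is hypothetical, not the prim implementation):

```python
def contiguous_strides(shape):
    # Hypothetical helper: for a contiguous tensor, the stride of each
    # dimension is the number of elements in all dimensions to its right.
    strides = []
    acc = 1
    for dim_size in reversed(shape):
        strides.append(acc)
        acc *= dim_size
    return tuple(reversed(strides))

# A (2, 3, 4) tensor steps 12 elements per row-of-rows, 4 per row, 1 per column
assert contiguous_strides((2, 3, 4)) == (12, 4, 1)
```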

Utils

  • adds the canonicalize_device, extract_shape, and extract_shape_from_varargs helpers
  • adds the elementwise_unary_scalar_wrapper -- this allows elementwise unary operators to take and return scalar values (e.g. refs.sin(1) will return 0.84...)
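
The scalar wrapper's behavior can be sketched without torch: Python numbers are routed to a scalar implementation and returned as numbers, while other inputs go to the usual tensor path. This is an illustrative sketch -- the signature and the list-based "tensor" stand-in are assumptions, not the actual utility.

```python
import math

def elementwise_unary_scalar_wrapper(tensor_fn, scalar_fn):
    # Hypothetical sketch: if the argument is a Python number, evaluate
    # the scalar analogue and return a number; otherwise defer to the
    # tensor implementation.
    def wrapped(a):
        if isinstance(a, (int, float)):
            return scalar_fn(a)
        return tensor_fn(a)
    return wrapped

# Lists stand in for tensors; math.sin is the scalar path.
ref_sin = elementwise_unary_scalar_wrapper(
    lambda xs: [math.sin(x) for x in xs], math.sin
)
assert abs(ref_sin(1) - 0.8414709848078965) < 1e-12  # scalar in, scalar out
```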

Refs

  • adds the fill, broadcast_tensors, clamp, empty_strided, ones, zeros, and uniform references
  • adds the nn.functional.dropout reference
  • fixes refs.cat to handle 1D tensors with no elements, consistent with eager mode
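
As an illustration of the dropout reference's semantics: each element is zeroed with probability p and the survivors are scaled by 1 / (1 - p) so the expected value is preserved. The sketch below operates on plain Python lists (the function name and list-based signature are assumptions for illustration, not the actual ref):

```python
import random

def dropout_ref(values, p=0.5, training=True):
    # Hypothetical sketch: zero each element with probability p and
    # scale the survivors by 1 / (1 - p) to preserve the expected value.
    if not training or p == 0.0:
        return list(values)
    if p == 1.0:
        return [0.0 for _ in values]
    scale = 1.0 / (1.0 - p)
    return [v * scale if random.random() >= p else 0.0 for v in values]

# With p=0.5 each surviving element is doubled
out = dropout_ref([1.0, 2.0, 3.0], p=0.5)
assert all(v == 0.0 or v in (2.0, 4.0, 6.0) for v in out)
```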

@facebook-github-bot
Contributor

facebook-github-bot commented May 20, 2022


❌ 1 New Failures

As of commit 0516660 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge) (1/1)

Step: "Test"

2022-05-22T23:52:26.5585933Z   File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.12-py3.7-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
2022-05-22T23:52:26.5586600Z     device = xm.xla_device()
2022-05-22T23:52:26.5587370Z   File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.12-py3.7-linux-x86_64.egg/torch_xla/core/xla_model.py", line 232, in xla_device
2022-05-22T23:52:26.5588004Z     devkind=devkind if devkind is not None else None)
2022-05-22T23:52:26.5588868Z   File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.12-py3.7-linux-x86_64.egg/torch_xla/core/xla_model.py", line 137, in get_xla_supported_devices
2022-05-22T23:52:26.5589461Z     xla_devices = _DEVICES.value
2022-05-22T23:52:26.5590236Z   File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.12-py3.7-linux-x86_64.egg/torch_xla/utils/utils.py", line 32, in value
2022-05-22T23:52:26.5590835Z     self._value = self._gen_fn()
2022-05-22T23:52:26.5591815Z   File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.12-py3.7-linux-x86_64.egg/torch_xla/core/xla_model.py", line 19, in <lambda>
2022-05-22T23:52:26.5592504Z     _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
2022-05-22T23:52:26.5593369Z RuntimeError: tensorflow/compiler/xla/xla_client/xrt_local_service.cc:56 : Check failed: tensorflow::NewServer(server_def, &server_) == ::tensorflow::Status::OK() (UNKNOWN: Could not start gRPC server vs. OK)
2022-05-22T23:52:26.8021092Z Traceback (most recent call last):
2022-05-22T23:52:26.8021873Z   File "/var/lib/jenkins/workspace/xla/test/test_mp_save.py", line 63, in <module>
2022-05-22T23:52:26.8022218Z     xmp.spawn(_mp_fn, args=(temp_file,))
2022-05-22T23:52:26.8022914Z   File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.12-py3.7-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 395, in spawn
2022-05-22T23:52:26.8023252Z     start_method=start_method)
2022-05-22T23:52:26.8023642Z   File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
2022-05-22T23:52:26.8024237Z     while not context.join():
2022-05-22T23:52:26.8024590Z   File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 154, in join
2022-05-22T23:52:26.8024864Z     exit_code=exitcode
2022-05-22T23:52:26.8025186Z torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with exit code 17



msg = f"Reference result was farther ({ref_distance}) from the precise \
computation than the torch result was ({torch_distance})!"
self.assertTrue(ref_distance <= torch_distance, msg=msg)
Collaborator

I don't think ref_distance is always weakly less than torch_distance, it's ok to be larger with some tolerance?

Collaborator Author

I was also thinking about this but wasn't happy with any immediate ideas I had -- cool if I add a TODO comment?

test/test_ops.py Outdated
# Reports numerical accuracy discrepancies
if ex is not None:
msg = "Test passed because the reference was more accurate than the torch operator."
print(msg)
Collaborator

should this be a warning? Pytest hides stdout of passing tests, so it will be hard to see for people using pytest

Collaborator Author

Sure -- warning it is
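
The agreed-upon change can be sketched with the standard warnings module; unlike stdout, warnings from passing tests appear in pytest's summary. The function name is hypothetical, but the message string is the one from the snippet above.

```python
import warnings

def report_accuracy_discrepancy(ex):
    # Hypothetical sketch: surface the discrepancy as a warning rather
    # than a print, so runners that capture stdout still show it.
    if ex is not None:
        msg = "Test passed because the reference was more accurate than the torch operator."
        warnings.warn(msg)

# The warning fires only when a comparison error was recorded
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    report_accuracy_discrepancy(AssertionError("results not close"))
assert len(caught) == 1
```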

scalar_tensor = None
number = None
for arg in args:
for arg in args_:
Collaborator

this would potentially set tensor to something with wrong dtype (from args_with_different_types)

Collaborator Author

Great catch - fixed

*args, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
*args,
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
args_with_different_dtypes: Tuple[TensorLikeType, ...] = None,
Collaborator

args_with_fixed_types?

Collaborator Author

Yeah that's way better -- fixed



def _select_aten(pred: Tensor, a: Tensor, b: Tensor) -> Tensor:
def _where_aten(pred: Tensor, a: Tensor, b: Tensor) -> Tensor:
Collaborator

out of curiosity, why do we need this helper, and not just use torch.where in make_prim?

Collaborator Author

Great point -- I was just on automatic mode -- fixed!


def _empty_like_aten(
a: Tensor, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
def _empty_strided_aten(
Collaborator

same question here?

Collaborator Author

Yep -- fixed

Collaborator

@ngimel ngimel left a comment

Great

else:
value = 3

return ({'value': value}, {'value': value})
Collaborator

why 2 tuple elements?

Collaborator Author

This is super weird and we could sugar over this, but it's because we sometimes pass different arguments to the NumPy op, so we have "torch kwargs" and "NumPy kwargs" here
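
The two-tuple convention described in the reply can be sketched as follows, using the `value` kwarg from the snippet above. The function name and signature are assumptions for illustration; the point is only that the first dict goes to the torch operator and the second to its NumPy reference.

```python
def sample_kwargs(device, dtype, input):
    # Hypothetical sketch of the two-tuple convention: "torch kwargs"
    # first, "NumPy kwargs" second, since the two ops can take
    # differently named or differently typed arguments.
    if dtype is int:
        value = 2
    else:
        value = 3
    torch_kwargs = {"value": value}
    numpy_kwargs = {"value": value}
    return (torch_kwargs, numpy_kwargs)

torch_kwargs, numpy_kwargs = sample_kwargs("cpu", float, None)
assert torch_kwargs == numpy_kwargs == {"value": 3}
```

Here both dicts happen to match; they diverge when the NumPy reference spells an argument differently.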

Mike Ruberry added 2 commits May 21, 2022 21:24
@mruberry
Collaborator Author

@pytorchbot merge this please

@github-actions
Contributor

Hey @mruberry.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@suo
Member

suo commented May 22, 2022

@pytorchbot revert -m "This broke trunk: https://hud.pytorch.org/pytorch/pytorch/commit/043cf1f9c746b4dda2c404ba6c76c6ccad5e2cbe" -c landrace

@suo
Member

suo commented May 22, 2022

Actually this looks like the proper classification is "nosignal"--only slow tests broke.

@mruberry
Collaborator Author

@pytorchbot merge this please

@github-actions
Contributor

Hey @mruberry.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

self = torch.clamp(self, lo, hi)
return (self / (1 - self)).log()
self = refs.clamp(self, lo, hi)
return refs.log(refs.true_divide(self, refs.sub(1, self)))
Contributor

@mruberry Given that the context manager exists now, we should prefer using the torch API calls as this ensures that the decomposition in question is using the limited API supported by torch and not the expanded API from refs. Is this just to work around the local problem that ref consistency tests don't work? I'd much rather we dupe the tests in that case.

Collaborator Author

It's that the meta tests weren't working; I did duplicate the consistency tests
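
The decomposition in the diff above -- clamp, then log(x / (1 - x)) -- can be checked against a scalar sketch. The function name is hypothetical and this operates on Python floats, not tensors:

```python
import math

def logit_ref(x, eps=None):
    # Hypothetical scalar sketch of the logit decomposition discussed
    # above: optionally clamp to [eps, 1 - eps], then log(x / (1 - x)).
    if eps is not None:
        x = min(max(x, eps), 1.0 - eps)
    return math.log(x / (1.0 - x))

# logit(0.5) is exactly 0; out-of-range inputs are clamped first
assert logit_ref(0.5) == 0.0
assert logit_ref(2.0, eps=0.25) == logit_ref(0.75)
```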

@mruberry mruberry mentioned this pull request May 23, 2022
facebook-github-bot pushed a commit that referenced this pull request May 24, 2022
Pull Request resolved: #78026
Approved by: https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d4345ed0a6c06b1e489e41c219f94d26d3014ce6

Reviewed By: seemethere

Differential Revision: D36610393

Pulled By: mruberry

fbshipit-source-id: 415e532ab647ab8425f9064796704f6c44115f0e
swang392 pushed a commit that referenced this pull request May 25, 2022
Pull Request resolved: #78026
Approved by: https://github.com/ngimel
pytorchmergebot pushed a commit that referenced this pull request Jun 27, 2022
Ref: #69991

Probably started working since : #78026
Pull Request resolved: #80277
Approved by: https://github.com/zou3519
facebook-github-bot pushed a commit that referenced this pull request Jun 30, 2022
Summary:
Ref: #69991

Probably started working since : #78026

Pull Request resolved: #80277
Approved by: https://github.com/zou3519

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/1b18c2e93cb5ae96314247e97f3040fda36b6356

Reviewed By: b0noI

Differential Revision: D37495906

fbshipit-source-id: 25dfcb5f8bbe61e5ff2da1c59810a6ebed1850c3
@github-actions github-actions bot deleted the primtorch_random branch February 16, 2024 01:56
6 participants