Skip to content
Closed
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1c8f4c1
WPI: changing as_strided_scatter to deterministic inputs
srossross Sep 24, 2022
839ffa2
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Sep 26, 2022
904fd95
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Sep 26, 2022
63781d3
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Sep 27, 2022
b735e96
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Sep 27, 2022
853b9fe
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 4, 2022
73497a2
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 5, 2022
46036da
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 5, 2022
ea3ea66
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 5, 2022
7df7c65
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 5, 2022
d23afcf
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 7, 2022
9da3db3
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 7, 2022
2320d5e
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 10, 2022
9c10a5f
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 11, 2022
8cc7a77
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 12, 2022
a12555d
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 12, 2022
3add5fd
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 12, 2022
2c7f50e
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 12, 2022
28aba99
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 13, 2022
068108f
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 26, 2022
5cbff47
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 27, 2022
dfc65be
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 28, 2022
6787c1d
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 28, 2022
2906c8a
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 31, 2022
854d431
Update on "WIP: changing as_strided_scatter to deterministic inputs"
srossross Oct 31, 2022
f5f039e
Update on "Changing as_strided_scatter to deterministic inputs"
srossross Nov 1, 2022
4a9ef27
Update on "Changing as_strided_scatter to deterministic inputs"
srossross Nov 2, 2022
4b64fe2
Update on "Changing as_strided_scatter to deterministic inputs"
srossross Nov 7, 2022
0b103b0
Update on "Changing as_strided_scatter to deterministic inputs"
srossross Nov 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,30 @@ def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kw
((1,), (1,), (1,), 0),
((3, 3), (2, 2), (1, 2), 0),
((3, 3), (2, 2), (1, 2), 1),
((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
((3, 3), (2, 2), (2, 1), 0),
# Scatter to larger dimentions
((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the change here? Should the test cases have a comment describing the property they assume?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment above, the removed code was causing non-deterministic behaviour so I modified the strides

# Scatter to larger dimentions with strides inverted
((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0),
]

for input_shape, output_shape, stride, storage_offset in test_cases:
input_t = make_arg(input_shape)
input_src = make_arg(output_shape)
yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)

def error_inputs_as_strided_scatter(op_info, device, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)

# Create a small tensor and try to scatter it out of bounds
input_t = make_arg([4, 4])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment describing what error condition is being tested

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

input_src = make_arg([2, 2])
yield ErrorInput(
SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0),
error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64"
)


def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
inputs = (
(0,),
Expand Down Expand Up @@ -10430,18 +10445,15 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_as_strided_scatter,
error_inputs_func=error_inputs_as_strided_scatter,
skips=(
DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950
DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950
DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'),
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
# AssertionError: Tensor-likes are not close! (new_empty_strided.default)
DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),
DecorateInfo(
unittest.skip("Some stride values write multiple values to the same location e.g. (1,1,1,1)"),
'TestCommon', 'test_compare_cpu'),)),
DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)),
OpInfo('native_layer_norm',
aten_name='native_layer_norm',
ref=reference_native_layer_norm,
Expand Down