
[primTorch] Adds contiguous and expand references #79820

Closed
mruberry wants to merge 21 commits into master from shape_refs

Conversation

mruberry (Collaborator) commented Jun 17, 2022

I also filed #79818 and #80154 while creating this PR.

This PR...

Filed issues

  • #79818
  • #80154

prims

  • Fixes prims.squeeze when called with an unsorted list of dimensions
  • Removes the clone prim

refs

  • adds contiguous (see the usage sketch after these lists)
  • adds expand
  • updates clone to call empty_like and copy_to
  • updates empty to accept a memory_format
  • updates empty_like to accept a memory_format

utils

  • adds helper functions for working with memory formats and, in particular, channels last tensors

tests

  • removes unused clamp sample input functions (mooted by clamp's new reference inputs)
  • extends the reference inputs for clone to include different memory formats
  • creates reference inputs for contiguous
  • xfails operators that depend on clone (including clone) on test_python_ref (see issues)
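
Below is a hedged usage sketch of the new references. It assumes the `torch._refs` entry points mirror the corresponding `torch` APIs; exact signatures may vary by PyTorch version:

```python
import torch
import torch._refs as refs

a = torch.arange(8.0).reshape(2, 2, 2, 1)

# contiguous ref: re-lays out the tensor in the requested memory format
b = refs.contiguous(a, memory_format=torch.channels_last)

# expand ref: broadcasts the size-1 trailing dimension without copying
c = refs.expand(a, 2, 2, 2, 4)

# clone ref: now composed from empty_like + copy_to instead of a clone prim
d = refs.clone(a, memory_format=torch.preserve_format)
```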

facebook-github-bot (Contributor) commented Jun 17, 2022


❌ 2 New Failures

As of commit a5ac174 (more details on the Dr. CI page):

  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build trunk / win-vs2019-cuda11.6-py3 / test (default, 5, 5, windows.8xlarge.nvidia.gpu) (1/2)

Step: "Setup Windows"

2022-07-11T15:19:43.9241465Z ##[error]Process completed with exit code 127.
2022-07-11T15:19:43.8804457Z   PYTORCH_RETRY_TEST_CASES: 1
2022-07-11T15:19:43.8804953Z   PYTORCH_OVERRIDE_FLAKY_SIGNAL: 1
2022-07-11T15:19:43.8805275Z   SHA1: a5ac174a41dce83d542ec146ebb1662a91a5e36b
2022-07-11T15:19:43.8805548Z   TAG: 
2022-07-11T15:19:43.8805742Z   WORKFLOW_ID: 2650220508
2022-07-11T15:19:43.8806431Z   GITHUB_TOKEN: ***
2022-07-11T15:19:43.8806691Z   GHA_WORKFLOW_JOB_ID: 
2022-07-11T15:19:43.8806934Z ##[endgroup]
2022-07-11T15:19:43.9093461Z + python3 -m pip install -r requirements.txt
2022-07-11T15:19:43.9210418Z C:\actions-runner\_work\_temp\5b084cbb-f8db-48fb-9601-a8a544970467.sh: line 2: python3: command not found
2022-07-11T15:19:43.9241465Z ##[error]Process completed with exit code 127.
2022-07-11T15:19:43.9392051Z Prepare all required actions
2022-07-11T15:19:43.9435332Z ##[group]Run ./.github/actions/teardown-win
2022-07-11T15:19:43.9435606Z with:
2022-07-11T15:19:43.9435809Z env:
2022-07-11T15:19:43.9436015Z   GIT_DEFAULT_BRANCH: master
2022-07-11T15:19:43.9436259Z ##[endgroup]
2022-07-11T15:19:43.9574209Z ##[group]Run .github\scripts\wait_for_ssh_to_drain.ps1
2022-07-11T15:19:43.9574593Z .github\scripts\wait_for_ssh_to_drain.ps1
2022-07-11T15:19:43.9604689Z shell: C:\Windows\System32\WindowsPowerShell\v1.0\powershell.EXE -command ". '{0}'"
2022-07-11T15:19:43.9605048Z env:

See GitHub Actions build trunk / linux-focal-rocm5.1-py3.7 / test (default, 2, 2, linux.rocm.gpu) (2/2)

Step: "Test"

2022-07-11T16:29:53.2026965Z RuntimeError: test_nestedtensor failed! Received signal: SIGIOT
2022-07-11T16:29:52.9575177Z   test_to_padded_tensor_output_size_cuda_float16 (__main__.TestNestedTensorDeviceTypeCUDA) ... ok (0.004s)
2022-07-11T16:29:52.9614368Z   test_to_padded_tensor_output_size_cuda_float32 (__main__.TestNestedTensorDeviceTypeCUDA) ... ok (0.004s)
2022-07-11T16:29:52.9648143Z   test_to_padded_tensor_simple_cuda_float16 (__main__.TestNestedTensorDeviceTypeCUDA) ... Memory exception on virtual address 0x7fe2c9fb7000, node id 4 : Page not present
2022-07-11T16:29:52.9649512Z Address does not belong to a known buffer
2022-07-11T16:29:52.9651040Z Memory access fault by GPU node-4 (Agent handle: 0x558cf6282240) on address 0x7fe2c9fb7000. Reason: Page not present or supervisor privilege.
2022-07-11T16:29:53.2015738Z Traceback (most recent call last):
2022-07-11T16:29:53.2016894Z   File "test/run_test.py", line 945, in <module>
2022-07-11T16:29:53.2021157Z     main()
2022-07-11T16:29:53.2022169Z   File "test/run_test.py", line 923, in main
2022-07-11T16:29:53.2025692Z     raise RuntimeError(err_message)
2022-07-11T16:29:53.2026965Z RuntimeError: test_nestedtensor failed! Received signal: SIGIOT
2022-07-11T16:29:55.1723937Z 
2022-07-11T16:29:55.1724750Z real	101m10.992s
2022-07-11T16:29:55.1725593Z user	190m53.059s
2022-07-11T16:29:55.1726481Z sys	24m15.402s
2022-07-11T16:29:55.1862777Z ##[error]Process completed with exit code 1.
2022-07-11T16:29:55.1957248Z ##[group]Run # copy test results back to the mounted workspace, needed sudo, resulting permissions were correct
2022-07-11T16:29:55.1957906Z # copy test results back to the mounted workspace, needed sudo, resulting permissions were correct
2022-07-11T16:29:55.1958498Z docker exec -t "c2926f7b9154ef6e501ef18a49cb7c62ad2eb4ee58d86201df2a868a2e5f66f1" sh -c "cd ../pytorch && sudo cp -R test/test-reports ../workspace/test"
2022-07-11T16:29:55.1995525Z shell: /bin/bash -e {0}
2022-07-11T16:29:55.1995763Z env:

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@mruberry added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Jun 23, 2022
)

shape = list(shape)
shape[1], shape[-1] = shape[-1], shape[1]
Collaborator

this is not correct, you need to permute the shape elements like

shape[1], shape[2], shape[3] = shape[2], shape[3], shape[1]

(and how are tests passing? Even for (2, 2, 2, 2) they shouldn't. Oh, I guess it's because the tests don't actually check strides.)
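
For what it's worth, a minimal sketch of the permutation being described, with hypothetical helper names: the NCHW shape is permuted to NHWC, contiguous strides are computed, and the strides are permuted back with the inverse permutation (this ignores the size-0 caveat raised in the next comment):

```python
def make_contiguous_strides_for(shape):
    # ordinary row-major strides
    strides, multiplier = [], 1
    for l in reversed(shape):
        strides.append(multiplier)
        multiplier *= l
    return tuple(reversed(strides))

def channels_last_2d_strides(shape):
    # shape is NCHW; memory order for torch.channels_last is NHWC
    n, c, h, w = shape
    s = make_contiguous_strides_for((n, h, w, c))
    # inverse permutation: put the NHWC strides back in NCHW positions
    return (s[0], s[3], s[1], s[2])

# e.g. (2, 3, 4, 5) -> (60, 1, 15, 3), which matches
# torch.empty(2, 3, 4, 5, memory_format=torch.channels_last).stride()
```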


shape = list(shape)
shape[1], shape[-1] = shape[-1], shape[1]
strides = list(make_contiguous_strides_for(shape))
Collaborator

there's this weird inconsistency in eager for 0-element tensors: contiguous leaves the strides of the following size-0 dimensions at whatever the stride of the previous dimension was, whereas for channels-last that stride (and all subsequent ones) are zeroed:

In [41]: a=torch.empty(2,2,2,0, memory_format=torch.channels_last)

In [42]: a.stride()
Out[42]: (0, 1, 0, 2)

In [43]: a=torch.empty(2,2,2,0)

In [44]: a.stride()
Out[44]: (4, 2, 1, 1)

so using this approach would produce results that are different from eager.
Another eager inconsistency is that empty tensors are always contiguous, but have to follow precise striding requirements to be channels_last contiguous, and we should probably fix that.
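
To illustrate the second inconsistency with a quick check (hedged; this just restates the claim above in code):

```python
import torch

a = torch.empty(2, 2, 2, 0)  # default strides (4, 2, 1, 1), zero elements
a.is_contiguous()            # True: 0-element tensors always count as contiguous
a.is_contiguous(memory_format=torch.channels_last)  # False: exact NHWC strides are still required
```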

Collaborator Author

Great catch -- I redid the implementation to capture this caveat and I'll add some additional testing


# Short-circuits for tensors with zero or one elements, which
# are trivially non-overlapping and "dense"
if a.numel() < 2:
Collaborator

do we need this special case, or does it follow from is_contiguous=True?

Collaborator Author

Yep, we can remove this special case
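
For context, a small check of why the branch is redundant: zero- and one-element tensors already report as contiguous, so an is_contiguous test earlier in the function covers them (a sketch, not the actual helper):

```python
import torch

for t in (torch.empty(0), torch.empty(1), torch.empty(2, 0, 3)):
    # 0- and 1-element tensors are trivially contiguous, hence trivially
    # non-overlapping and "dense"; a numel() < 2 short-circuit never fires
    # on a tensor the contiguity check would not already accept
    assert t.is_contiguous()
```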

)

shape = list(shape)
shape[1], shape[-1] = shape[-1], shape[1]
Collaborator

same here (and the same for strides in both places, modulo the inverse permutation)


# Short-circuits for tensors of rank one, which are
# non-overlapping and "dense" if their stride is one
if a.ndim == 1:
Collaborator

this also follows from is_contiguous(a)=True so you'd never get here

strides = []
for idx in (-1, -2, 0):
l = shape[idx]
if l != 0:
Collaborator

if and else branches are the same?

Comment on lines +1264 to +1274
multiplier = shape[1]
strides = []
for idx in (-1, -2, 0):
    l = shape[idx]
    if l != 0:
        strides.append(multiplier)
    else:
        strides.append(multiplier)
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
    multiplier = l * multiplier
Collaborator

Suggested change — replace:

    multiplier = shape[1]
    strides = []
    for idx in (-1, -2, 0):
        l = shape[idx]
        if l != 0:
            strides.append(multiplier)
        else:
            strides.append(multiplier)
            # NOTE: intentional divergence from make_contiguous_strides_for
            # This is consistent with eager
        multiplier = l * multiplier

with:

    multiplier = 1
    strides = [0] * 4
    for idx in (1, -1, -2, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]
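
A quick sanity check of the suggestion against the eager outputs quoted earlier (a sketch; `shape` is assumed to be a length-4 NCHW list):

```python
def suggested_strides(shape):
    multiplier = 1
    strides = [0] * 4
    for idx in (1, -1, -2, 0):  # C, W, H, N: innermost to outermost in channels-last order
        strides[idx] = multiplier
        multiplier *= shape[idx]
    return tuple(strides)

suggested_strides([2, 3, 4, 5])  # (60, 1, 15, 3)
suggested_strides([2, 2, 2, 0])  # (0, 1, 0, 2), matching Out[42] above
```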

Comment on lines +1276 to +1278
strides.insert(2, 1)
result = tuple(reversed(strides))
return result
Collaborator

Suggested change — replace:

    strides.insert(2, 1)
    result = tuple(reversed(strides))
    return result

with:

    return strides

yield SampleInput(make_arg((6, 6), noncontiguous=True))

# channels last 2D
yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last})
Collaborator

these are the same as clone?

Collaborator Author

Yep -- combining them -- great catch
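
A sketch of the combination, with hypothetical names in the OpInfo style (import paths vary across PyTorch versions):

```python
from functools import partial

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_methods_invocations import SampleInput

def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs):
    # shared by the clone and contiguous OpInfos, since their inputs coincide
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(make_arg((6, 6), noncontiguous=True))
    yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last})
```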

mruberry (Collaborator Author)

@pytorchbot merge -g

pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

mruberry (Collaborator Author)

@pytorchbot merge

pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here

github-actions (Contributor)

Hey @mruberry.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@mruberry added the topic: not user facing (topic category) label on Jul 11, 2022
facebook-github-bot pushed a commit that referenced this pull request Jul 12, 2022
Summary:
I also filed #79818 and #80154 while creating this PR.

This PR...

**Filed issues**

- #79818
- #80154

**prims**

- Fixes prims.squeeze when called with an unsorted list of dimensions
- Removes the clone prim

**refs**
- adds contiguous
- adds expand
- updates clone to call empty_like and copy_to
- updates empty to accept a memory_format
- updates empty_like to accept a memory_format

**utils**
- adds helper functions for working with memory formats and, in particular, channels last tensors

**tests**

- removes unused clamp sample input functions (mooted by clamp's new reference inputs)
- extends the reference inputs for clone to include different memory formats
- creates reference inputs for contiguous
- xfails operators that depend on clone (including clone) on `test_python_ref` (see issues)

Pull Request resolved: #79820
Approved by: https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/8740c68c41c563de42b011cef42b3de7690e9446

Reviewed By: DanilBaibak

Differential Revision: D37781899

Pulled By: mruberry

fbshipit-source-id: 7c422f216844ac9c3c95ccc97d320faaafd88b0b
pytorchmergebot pushed a commit that referenced this pull request Oct 17, 2022
Add Reference:
- nll_loss

Depends on:
- expand #79820
- advance indexing

Pull Request resolved: #81128
Approved by: https://github.com/mruberry
The github-actions bot deleted the shape_refs branch on February 18, 2024 at 01:51.