
Remove gpu_kernel_with_index #33370

Closed
zasdfgbnm wants to merge 26 commits into pytorch:master from zasdfgbnm:range

Conversation

@zasdfgbnm
Collaborator

@zasdfgbnm zasdfgbnm commented Feb 15, 2020

Although `gpu_kernel_with_index` might look like a fairly general helper function at first glance, it actually isn't.

The problem is not just 32-bit indexing but something more fundamental: `TensorIterator` reorders dims and shapes, so for a non-contiguous tensor such as `torch.empty(5, 5).t()` the index won't be correct. Since the whole point of `TensorIterator` is to manipulate shapes/strides to speed up loops, it is fundamentally impossible to recover the correct linear index without enormous effort.
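The mismatch can be seen with a small pure-Python sketch (the helper name is illustrative, not real PyTorch API): a transposed view of a contiguous 5×5 buffer has strides (1, 5), so a kernel that writes value `k` at memory offset `k` stores the wrong value at each logical position.

```python
# Sketch of why memory-order indexing breaks for torch.empty(5, 5).t():
# the transposed view has strides (1, 5), so logical element (i, j) lives
# at memory offset i*1 + j*5.  A kernel that writes value k at memory
# offset k therefore stores i + 5*j at (i, j), not the row-major index
# 5*i + j that a correct range factory must produce.
ROWS = 5

def memory_offset(i, j, strides=(1, ROWS)):
    # offset of logical element (i, j) in the transposed view
    return i * strides[0] + j * strides[1]

# logical element (0, 1) would receive value 5, while a correct arange
# into the transposed view must put 1 there
print(memory_offset(0, 1))  # -> 5
```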

Currently, the range factories do not fail on an `out=non_contiguous_tensor` only because, luckily, `has_internal_overlap` is simplistic enough to classify everything non-contiguous as `TOO_HARD`.

Since `gpu_kernel_with_index` is not general, we should move it from `Loops.cuh` to `RangeFactories.cu`. And since the kernel is so simple to implement, it makes no sense to use `TensorIterator`, which goes through tons of unnecessary checks such as `compute_dtypes`.
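The per-element computation such a dedicated kernel performs is trivial, which is why none of TensorIterator's machinery is needed. A Python sketch of the logic (the function name is illustrative; the real implementation is a CUDA kernel writing into a contiguous output):

```python
def arange_values(start, step, n):
    # per-element computation of a dedicated range kernel:
    # element i of the contiguous output is start + i * step
    return [start + i * step for i in range(n)]

print(arange_values(0, 2, 5))  # -> [0, 2, 4, 6, 8]
```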

`torch.range` is not tested for 64-bit indexing, and I will file a new PR to remove it (it was supposed to be removed in 0.5).

Benchmark:
The device is a GTX 1650; I don't have a good GPU at home.

Code:

```python
import torch
print(torch.__version__)

for i in range(100):
    torch.randn(1000, device='cuda')
torch.cuda.synchronize()

for i in range(15, 29):
    %timeit torch.arange(2 ** i, device='cuda'); torch.cuda.synchronize()
```
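`%timeit` above is IPython-specific; outside a notebook, a rough stdlib equivalent might look like the following sketch (the CUDA workload is replaced with a cheap placeholder so it runs anywhere):

```python
import timeit

def bench(fn, number=100, repeat=5):
    # rough stand-in for IPython's %timeit: best-of-`repeat` average
    # seconds per call of fn
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number

# placeholder workload standing in for
# torch.arange(2 ** 15, device='cuda'); torch.cuda.synchronize()
per_call = bench(lambda: list(range(2 ** 15)))
print(f"{per_call * 1e6:.1f} µs per loop")
```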

Before:

```
1.5.0a0+c37a9b8
11.9 µs ± 412 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.7 µs ± 309 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
19.6 µs ± 209 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
28.9 µs ± 923 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
48.4 µs ± 1.64 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
85.7 µs ± 1.46 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
162 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
312 µs ± 9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
618 µs ± 15.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.22 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.45 ms ± 97.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.9 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.1 ms ± 378 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

After:

```
1.5.0a0+7960d19
11 µs ± 29.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.4 µs ± 550 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
18.4 µs ± 230 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
27.6 µs ± 10.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.2 µs ± 18.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
83.3 µs ± 5.61 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 373 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
307 µs ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
603 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.2 ms ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.4 ms ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.77 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.51 ms ± 933 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

@dr-ci

dr-ci bot commented Feb 15, 2020

💊 CircleCI build failures summary and remediations

As of commit ec09ca7:

  • 1/2 failures introduced in this PR
  • 1/2 recognized as flaky ❄️

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (1/1)

Step: "Test" (full log | pattern match details)

RuntimeError: test_jit_fuser failed!
 
---------------------------------------------------------------------- 
Ran 46 tests in 11.552s 
 
FAILED (errors=4, skipped=10) 
Traceback (most recent call last): 
  File "run_test.py", line 486, in <module> 
    main() 
  File "run_test.py", line 479, in main 
    raise RuntimeError(message) 
RuntimeError: test_jit_fuser failed! 
 
(base) circleci@PACKER-5E29F737 C:\Users\circleci\project\test>if ERRORLEVEL 1 exit /b 1  
+ cleanup
+ retcode=1
+ set +x

❄️ 1 failure recognized as flaky

The following build failures have been detected as flaky and may not be your fault:

See CircleCI build caffe2_onnx_py2_gcc5_ubuntu16_04_test (1/1)

Step: "Test" (full log | pattern match details) ❄️

Feb 16 04:49:52 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_slice_negative_index /var/lib/jenkins/workspace/scripts/onnx/test.sh: line 57: 29209 Segmentation fault (core dumped) pytest "${args[@]}" --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" --ignore "$top_dir/test/onnx/test_custom_ops.py" --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" "${test_paths[@]}"
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns_opset10::test_is_in_onnx_export PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns_opset10::test_strip_doc_string PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns_opset10::test_validate_dynamic_axes_invalid_input_output_name PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_concat PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_div PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_lstm PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_mul PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_reshape SKIPPED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_slice PASSED [ 99%] 
Feb 16 04:49:51 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_slice_index_exceeds_dim PASSED [ 99%] 
Feb 16 04:49:52 test/onnx/test_utility_funs.py::TestUtilityFuns::test_constant_fold_slice_negative_index /var/lib/jenkins/workspace/scripts/onnx/test.sh: line 57: 29209 Segmentation fault      (core dumped) pytest "${args[@]}" --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" --ignore "$top_dir/test/onnx/test_custom_ops.py" --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" "${test_paths[@]}" 

This comment was automatically generated by Dr. CI.

@zasdfgbnm changed the title from [WIP] Remove gpu_kernel_with_index to Remove gpu_kernel_with_index on Feb 15, 2020
@zasdfgbnm requested a review from ngimel on February 15, 2020 09:25
@ngimel
Collaborator

ngimel commented Feb 15, 2020

The test failures in test_print and test_pickle might be related?

@zasdfgbnm
Collaborator Author

@ngimel I don't know. Let me rebase and see.

@zasdfgbnm
Collaborator Author

@pytorchbot rebase this please

@ngimel
Collaborator

ngimel commented Feb 15, 2020

This generally looks good, but if you want to be awesome, can you implement a @largeTensorTest wrapper in common_device_type.py? It seems like it would be useful; TEST_LARGE_TENSOR is used in a lot of places.
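A minimal sketch of what such a wrapper might look like (hypothetical signature; a real common_device_type.py decorator would query the actual device, e.g. via torch.cuda.get_device_properties, rather than take a callable):

```python
import functools
import unittest

def largeTensorTest(size_bytes, available_bytes):
    """Skip the wrapped test unless the device reports at least `size_bytes`
    of memory.  `available_bytes` is a callable standing in for a real
    device-memory query."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if available_bytes() < size_bytes:
                raise unittest.SkipTest(f"requires {size_bytes} bytes of device memory")
            return fn(*args, **kwargs)
        return wrapper
    return decorator

# a 6 GiB test on a device reporting only 4 GiB is skipped, not failed
@largeTensorTest(6 * 1024 ** 3, available_bytes=lambda: 4 * 1024 ** 3)
def test_huge_alloc():
    return "ran"
```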

@zasdfgbnm
Collaborator Author

@ngimel Sure

@zasdfgbnm
Collaborator Author

@ngimel The failures are related. They are about n-dimensional zero-sized tensors.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator

ngimel commented Feb 15, 2020

test_nn now segfaults.

@zasdfgbnm
Collaborator Author

@ngimel

CUDA out of memory. Tried to allocate 6.00 GiB (GPU 0; 14.85 GiB total capacity; 8.00 GiB already allocated; 1.96 GiB free; 12.13 GiB reserved in total by PyTorch)

It is strange that this test is not skipped on CI; it was correctly skipped locally on my machine. Could it be that the GPU has 32 GB of total RAM, but PyTorch is only allowed to use 16 GB?

@zasdfgbnm
Collaborator Author

@ngimel I see what's wrong: I forgot a not in the condition and was testing against the wrong PyTorch build locally.

@ngimel
Collaborator

ngimel commented Feb 16, 2020

Which job are you looking at? Generally, in the Set up CI environment section you can see the output of nvidia-smi. E.g. here https://app.circleci.com/jobs/github/pytorch/pytorch/4510128 nvidia-smi reports an M60 with 8 GB of memory, and the test fails with the correct report (tried to allocate 4 GB, 7.44 GB or so available), so it's not being skipped correctly?
Non-GPU tests fail with a segfault, not with OOM.

@zasdfgbnm
Collaborator Author

@ngimel Yes, you are right, it is not skipped correctly. I was looking at a build in the wrong directory, in which it was successfully skipped; that's why I said

It is strange that this test is not skipped on CI. It was correctly skipped locally on my machine.

The SIGSEGV could be because some test was previously skipped on both CPU and CUDA (due to the skipIf), but now it is only skipped on CUDA, and the CI machine does not have that much host memory to run it on CPU. Let me check on a computer with 256 GB of RAM.
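The skip pattern in question can be sketched with plain unittest (the memory condition is a placeholder; the real tests use torch's device-aware skip helpers). When the condition guards every device variant, the CPU test never runs; if it only covers CUDA, the CPU variant runs and can exhaust host memory on a small CI machine.

```python
import unittest

ENOUGH_HOST_MEMORY = False  # placeholder for a real host-RAM check

class LargeTensorTests(unittest.TestCase):
    # guarded on every device: the CPU variant is skipped too, so a small
    # CI machine never attempts the huge allocation
    @unittest.skipIf(not ENOUGH_HOST_MEMORY, "needs a large amount of host memory")
    def test_large_cpu(self):
        pass

result = unittest.TestResult()
unittest.TestLoader().loadTestsFromTestCase(LargeTensorTests).run(result)
print(len(result.skipped))  # -> 1
```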

@zasdfgbnm
Collaborator Author

@ngimel Ready

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in 55fa133.

@zasdfgbnm zasdfgbnm deleted the range branch February 18, 2020 04:44
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Pull Request resolved: pytorch#33370

Differential Revision: D19925990

Pulled By: ngimel

fbshipit-source-id: f4a732fe14a5582b35a56618941120d62e82fdce