Skip to content

Commit d912829

Browse files
committed
Update on "Add tan_cuda for complex dtypes"
Differential Revision: [D21572209](https://our.internmc.facebook.com/intern/diff/D21572209) [ghstack-poisoned]
2 parents 3efbdcf + 90ac830 commit d912829

16 files changed

Lines changed: 165 additions & 66 deletions

File tree

.circleci/config.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2883,6 +2883,14 @@ workflows:
28832883
build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_profiling-test"
28842884
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:9a3986fa-7ce7-4a36-a001-3c9bef9892e2"
28852885
resource_class: large
2886+
- pytorch_linux_test:
2887+
name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test
2888+
requires:
2889+
- setup
2890+
- pytorch_linux_xenial_py3_6_gcc5_4_build
2891+
build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test"
2892+
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:9a3986fa-7ce7-4a36-a001-3c9bef9892e2"
2893+
resource_class: large
28862894
- pytorch_linux_test:
28872895
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_legacy_test
28882896
requires:

.circleci/verbatim-sources/workflows-pytorch-ge-config-tests.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_profiling-test"
1515
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:9a3986fa-7ce7-4a36-a001-3c9bef9892e2"
1616
resource_class: large
17+
- pytorch_linux_test:
18+
name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test
19+
requires:
20+
- setup
21+
- pytorch_linux_xenial_py3_6_gcc5_4_build
22+
build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test"
23+
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:9a3986fa-7ce7-4a36-a001-3c9bef9892e2"
24+
resource_class: large
1725
- pytorch_linux_test:
1826
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_legacy_test
1927
requires:

BUILD.bazel

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
load("@rules_proto//proto:defs.bzl", "proto_library")
22
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_proto_library", "cc_test")
33
load("//third_party:substitution.bzl", "template_rule")
4-
load("//:tools/build_variables.bzl", "torch_cpp_srcs", "libtorch_core_sources", "libtorch_distributed_sources", "libtorch_extra_sources")
4+
load("//:tools/build_variables.bzl", "torch_cpp_srcs", "libtorch_core_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "jit_core_sources")
55
load("//tools/rules:cu.bzl", "cu_library")
66
load("//tools/config:defs.bzl", "if_cuda")
77
load("//:aten.bzl", "intern_build_aten_ops")
@@ -1979,13 +1979,7 @@ cc_library(
19791979
"torch/csrc/cuda/python_nccl.cpp",
19801980
"torch/csrc/cuda/nccl.cpp",
19811981
],
1982-
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + [
1983-
"torch/csrc/jit/frontend/error_report.cpp",
1984-
"torch/csrc/jit/frontend/lexer.cpp",
1985-
"torch/csrc/jit/frontend/function_schema_parser.cpp",
1986-
"torch/csrc/jit/frontend/strtod.cpp",
1987-
"torch/csrc/jit/frontend/source_range.cpp",
1988-
"torch/csrc/jit/frontend/schema_type_parser.cpp",
1982+
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + [
19891983
":generated_code",
19901984
],
19911985
copts = TORCH_COPTS + if_cuda(["-DUSE_CUDA=1"]),

aten/src/ATen/CMakeLists.txt

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@ if(USE_ROCM)
3434
endif()
3535

3636
# NB: If you edit these globs, you'll have to update setup.py package_data as well
37+
file(GLOB_RECURSE ATen_CORE_HEADERS "core/*.h")
38+
file(GLOB_RECURSE ATen_CORE_SRCS "core/*.cpp")
39+
file(GLOB_RECURSE ATen_CORE_TEST_SRCS "core/*_test.cpp")
40+
EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
41+
3742
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec256/*.h" "quantized/*.h")
3843
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp")
39-
add_subdirectory(core)
4044
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
4145
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")
4246
file(GLOB cuda_nvrtc_stub_h "cuda/nvrtc_stub/*.h")
@@ -89,6 +93,26 @@ file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
8993
# XNNPACK
9094
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
9195

96+
# Add files needed from jit folders
97+
list(APPEND ATen_CORE_HEADERS
98+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/source_range.h
99+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/function_schema_parser.h
100+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/lexer.h
101+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/strtod.h
102+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/parse_string_literal.h
103+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/schema_type_parser.h
104+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/error_report.h
105+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/tree.h
106+
)
107+
list(APPEND ATen_CORE_SRCS
108+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/error_report.cpp
109+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/function_schema_parser.cpp
110+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/lexer.cpp
111+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/strtod.cpp
112+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/schema_type_parser.cpp
113+
${Caffe2_SOURCE_DIR}/torch/csrc/jit/frontend/source_range.cpp
114+
)
115+
92116
add_subdirectory(quantized)
93117
set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${native_xnnpack} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp})
94118
if(AT_MKL_ENABLED)

aten/src/ATen/core/CMakeLists.txt

Lines changed: 0 additions & 36 deletions
This file was deleted.

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Tensors
2020
is_storage
2121
is_complex
2222
is_floating_point
23+
is_nonzero
2324
set_default_dtype
2425
get_default_dtype
2526
set_default_tensor_type

test/distributed/test_distributed.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def init(cls):
203203
os.unlink(os.path.join(barrier_dir, f_name))
204204

205205
@classmethod
206-
def sync(cls, wait_for=None, timeout=5):
206+
def sync(cls, wait_for=None, timeout=10):
207207
if wait_for is None:
208208
wait_for = dist.get_world_size()
209209
cls.barrier_id += 1
@@ -455,7 +455,6 @@ def _test_group_override_backend(self, initializer):
455455
@require_backends_available({"gloo", "nccl"})
456456
@require_world_size(3)
457457
@skip_if_lt_x_gpu(2)
458-
@skip_if_rocm
459458
def test_backend_group(self):
460459
self._test_group_override_backend(self._init_group_test)
461460

test/test_dataloader.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,9 +1312,23 @@ def test_duplicating_data_with_drop_last(self):
13121312

13131313
self.assertEqual(scanned_data.size(), scanned_data.unique().size())
13141314

1315+
def _test_sampler(self, **kwargs):
1316+
indices = range(2, 12) # using a regular iterable
1317+
dl = DataLoader(self.dataset, sampler=indices, batch_size=2, **kwargs)
1318+
self.assertEqual(len(dl), 5)
1319+
for i, (input, _target) in enumerate(dl):
1320+
self.assertEqual(len(input), 2)
1321+
self.assertEqual(input, self.data[i * 2 + 2:i * 2 + 4])
1322+
1323+
def test_sampler(self):
1324+
self._test_sampler()
1325+
self._test_sampler(num_workers=4)
1326+
if not NO_MULTIPROCESSING_SPAWN and torch.multiprocessing._supports_context:
1327+
self._test_batch_sampler(num_workers=4, multiprocessing_context='spawn')
1328+
13151329
def _test_batch_sampler(self, **kwargs):
13161330
# [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
1317-
batches = []
1331+
batches = [] # using a regular iterable
13181332
for i in range(0, 20, 5):
13191333
batches.append(tuple(range(i, i + 2)))
13201334
batches.append(tuple(range(i + 2, i + 5)))

test/test_jit_simple.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
if __name__ == '__main__':
66
run_tests()
7-
import test_jit_py3
8-
suite = unittest.findTestCases(test_jit_py3)
9-
unittest.TextTestRunner().run(suite)
7+
if not PY2:
8+
import test_jit_py3
9+
suite = unittest.findTestCases(test_jit_py3)
10+
unittest.TextTestRunner().run(suite)

tools/build_variables.bzl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,34 @@ libtorch_generated_sources = [
3939
"torch/csrc/autograd/VariableTypeManual.cpp",
4040
]
4141

42+
# copied from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/CMakeLists.txt
43+
jit_core_headers = [
44+
"torch/csrc/utils/memory.h",
45+
"torch/csrc/WindowsTorchApiMacro.h",
46+
"torch/csrc/jit/frontend/source_range.h",
47+
"torch/csrc/jit/serialization/source_range_serialization.h",
48+
"torch/csrc/jit/frontend/lexer.h",
49+
"torch/csrc/jit/frontend/strtod.h",
50+
"torch/csrc/jit/frontend/parser_constants.h",
51+
"torch/csrc/jit/frontend/function_schema_parser.h",
52+
"torch/csrc/jit/frontend/parse_string_literal.h",
53+
"torch/csrc/jit/frontend/schema_type_parser.h",
54+
"torch/csrc/jit/frontend/error_report.h",
55+
"torch/csrc/jit/frontend/tree.h",
56+
"torch/custom_class.h",
57+
"torch/custom_class_detail.h",
58+
"torch/library.h",
59+
]
60+
61+
jit_core_sources = [
62+
"torch/csrc/jit/frontend/error_report.cpp",
63+
"torch/csrc/jit/frontend/function_schema_parser.cpp",
64+
"torch/csrc/jit/frontend/lexer.cpp",
65+
"torch/csrc/jit/frontend/schema_type_parser.cpp",
66+
"torch/csrc/jit/frontend/strtod.cpp",
67+
"torch/csrc/jit/frontend/source_range.cpp",
68+
]
69+
4270
# copied from https://github.com/pytorch/pytorch/blob/master/tools/cpp_build/torch/CMakeLists.txt
4371
libtorch_core_sources = [
4472
"torch/csrc/autograd/anomaly_mode.cpp",
@@ -238,7 +266,6 @@ libtorch_core_jit_sources = [
238266

239267
libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources
240268

241-
242269
libtorch_extra_sources = libtorch_core_jit_sources + [
243270
"torch/csrc/autograd/VariableTypeManual.cpp",
244271
"torch/csrc/jit/api/module_save.cpp",

0 commit comments

Comments
 (0)