Skip to content

Commit 2f85a57

Browse files
committed
Update on "Implement correction argument in torch.masked.{std,var}"
This makes the signature of `torch.masked.std` and `var` more consistent with the global namespace variant and also updates the sample inputs to repurpose the existing `sample_inputs_std_var` inputs which fully exercise the `correction` argument. [ghstack-poisoned]
2 parents (d5b9c6b + 901ab87); commit 2f85a57

299 files changed

Lines changed: 5445 additions & 2391 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
790f1cdcea0359619adfc9ec37b91883748d1854
1+
842e178a488722720b6eb1e9cb508439e8e1ecd9

.github/requirements/conda-env-Linux-X64

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cffi=1.15.1
2-
cmake=3.22.1
2+
cmake=3.22.*
33
mkl=2022.1.0
44
mkl-include=2022.1.0
55
ninja=1.10.2

.github/requirements/conda-env-macOS-ARM64

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
numpy=1.22.3
22
pyyaml=6.0
33
setuptools=61.2.0
4-
cmake=3.22.1
4+
cmake=3.22.*
55
cffi=1.15.1
66
typing_extensions=4.3.0
77
dataclasses=0.8

.github/requirements/conda-env-macOS-X64

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mkl-include=2021.2.0
33
numpy=1.18.5
44
pyyaml=5.3
55
setuptools=46.0.0
6-
cmake=3.22.1
6+
cmake=3.22.*
77
cffi=1.15.1
88
typing_extensions=4.3.0
99
dataclasses=0.8

.jenkins/pytorch/test.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,8 @@ test_inductor_distributed() {
252252

253253
test_inductor() {
254254
python tools/dynamo/verify_dynamo.py
255-
python test/run_test.py --include test_modules test_ops --verbose
255+
python test/run_test.py --include test_modules test_ops test_ops_gradients --verbose
256256
PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include inductor/test_torchinductor --include inductor/test_torchinductor_opinfo --verbose
257-
# TODO: investigate "RuntimeError: CUDA driver API confirmed a leak"
258-
# seen in test_ops_gradients.py
259-
# pytest test/test_ops_gradients.py --verbose -k "not _complex and not test_inplace_grad_acos_cuda_float64"
260257
}
261258

262259
test_inductor_huggingface() {

.lintrunner.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,22 @@ exclude_patterns = [
101101
'torch/csrc/**',
102102
'torch/_dynamo/**/*.py',
103103
'torch/_inductor/**/*.py',
104+
'torch/_functorch/aot_autograd.py',
105+
'torch/_functorch/benchmark_utils.py',
106+
'torch/_functorch/compile_utils.py',
107+
'torch/_functorch/compilers.py',
108+
'torch/_functorch/eager_transforms.py',
109+
'torch/_functorch/fx_minifier.py',
110+
'torch/_functorch/partitioners.py',
111+
'torch/_functorch/make_functional.py',
112+
'torch/_functorch/top_operators_github_usage.py',
113+
'torch/_functorch/vmap.py',
104114
'torch/distributed/elastic/agent/server/api.py',
105115
'torch/testing/_internal/**',
106116
'torch/distributed/fsdp/fully_sharded_data_parallel.py',
107117
'torch/distributed/distributed_c10d.py',
108118
# TODO(suo): these exclusions were added just to get lint clean on master.
109119
# Follow up to do more target suppressions and remove them.
110-
'torch/distributed/fsdp/flatten_params_wrapper.py',
111120
'torch/ao/quantization/fx/convert.py',
112121
'torch/ao/quantization/_dbr/function_fusion.py',
113122
'test/test_datapipe.py',

aten/src/ATen/CPUGeneratorImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
127127
using detail::CPUGeneratorImplState;
128128
using detail::CPUGeneratorImplStateLegacy;
129129

130-
static_assert(std::is_pod<CPUGeneratorImplStateLegacy>::value, "CPUGeneratorImplStateLegacy is not a PODType");
131-
static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
130+
static_assert(std::is_standard_layout<CPUGeneratorImplStateLegacy>::value, "CPUGeneratorImplStateLegacy is not a PODType");
131+
static_assert(std::is_standard_layout<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
132132

133133
static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
134134
static const size_t size_current = sizeof(CPUGeneratorImplState);
@@ -207,7 +207,7 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
207207
using detail::CPUGeneratorImplState;
208208

209209
static const size_t size = sizeof(CPUGeneratorImplState);
210-
static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
210+
static_assert(std::is_standard_layout<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");
211211

212212
auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
213213
auto rng_state = state_tensor.data_ptr();

aten/src/ATen/cpu/vec/vec256/vec256.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,7 @@ inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
256256
return _mm256_permute2x128_si256(reversed, reversed, 1);
257257
}
258258

259-
template<>
260-
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
259+
inline __m256i flip8(const __m256i & v) {
261260
const __m256i mask_int8 = _mm256_set_epi8(
262261
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
263262
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
@@ -266,6 +265,15 @@ inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
266265
return _mm256_permute2x128_si256(reversed, reversed, 1);
267266
}
268267

268+
template<>
269+
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
270+
return flip8(v);
271+
}
272+
273+
template<>
274+
inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
275+
return flip8(v);
276+
}
269277

270278
#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
271279

aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ template <> class Vectorized<c10::complex<double>> {
185185
return _mm256_div_pd(log(), log10_);
186186
}
187187
Vectorized<c10::complex<double>> log1p() const {
188-
AT_ERROR("not supported for complex numbers");
188+
return map(std::log1p);
189189
}
190190
Vectorized<c10::complex<double>> asin() const {
191191
// asin(x)

aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ template <> class Vectorized<c10::complex<float>> {
221221
return _mm256_div_ps(log(), log10_);
222222
}
223223
Vectorized<c10::complex<float>> log1p() const {
224-
AT_ERROR("not supported for complex numbers");
224+
return map(std::log1p);
225225
}
226226
Vectorized<c10::complex<float>> asin() const {
227227
// asin(x)

0 commit comments

Comments (0)