Skip to content

Commit 33c69df

Browse files
authored
Merge branch 'master' into ngimel/set_device_revert
2 parents 3f913d1 + baa0679 commit 33c69df

195 files changed

Lines changed: 4214 additions & 1903 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
5b07d6c9c6c14cf88fc545415d63021456874744
1+
29757104250dd088386fef1ec3d70ed0b0c1be8a

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ include_patterns = [
623623
exclude_patterns = [
624624
'aten/src/ATen/test/**',
625625
'c10/cuda/CUDAFunctions.h',
626+
'c10/cuda/CUDACachingAllocator.cpp',
626627
]
627628
command = [
628629
'python3',

aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,11 +446,15 @@ Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<
446446
}
447447

448448
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
449-
return (*this == other) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
449+
auto eq = (*this == other); // compares real and imag individually
450+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
451+
return (eq.real() & eq.imag()) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
450452
}
451453

452454
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
453-
return (*this != other) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
455+
auto ne = (*this != other); // compares real and imag individually
456+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
457+
return (ne.real() | ne.imag()) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
454458
}
455459

456460
#endif

aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,12 +483,16 @@ Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<f
483483

484484
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq(
485485
const Vectorized<c10::complex<float>>& other) const {
486-
return (*this == other) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
486+
auto eq = (*this == other); // compares real and imag individually
487+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
488+
return (eq.real() & eq.imag()) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
487489
}
488490

489491
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne(
490492
const Vectorized<c10::complex<float>>& other) const {
491-
return (*this != other) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
493+
auto ne = (*this != other); // compares real and imag individually
494+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
495+
return (ne.real() | ne.imag()) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
492496
}
493497

494498
#endif

aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,12 +515,14 @@ class Vectorized<ComplexDbl> {
515515
}
516516

517517
Vectorized<ComplexDbl> eq(const Vectorized<ComplexDbl>& other) const {
518-
auto ret = (*this == other);
519-
return ret & vd_one;
518+
auto eq = (*this == other); // compares real and imag individually
519+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
520+
return (eq.real() & eq.imag()) & vd_one;
520521
}
521522
Vectorized<ComplexDbl> ne(const Vectorized<ComplexDbl>& other) const {
522-
auto ret = (*this != other);
523-
return ret & vd_one;
523+
auto ne = (*this != other); // compares real and imag individually
524+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
525+
return (ne.real() | ne.imag()) & vd_one;
524526
}
525527

526528
Vectorized<ComplexDbl> lt(const Vectorized<ComplexDbl>& other) const {

aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -540,12 +540,14 @@ class Vectorized<ComplexFlt> {
540540
}
541541

542542
Vectorized<ComplexFlt> eq(const Vectorized<ComplexFlt>& other) const {
543-
auto ret = (*this == other);
544-
return ret & one;
543+
auto eq = (*this == other); // compares real and imag individually
544+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
545+
return (eq.real() & eq.imag()) & one;
545546
}
546547
Vectorized<ComplexFlt> ne(const Vectorized<ComplexFlt>& other) const {
547-
auto ret = (*this != other);
548-
return ret & one;
548+
auto ne = (*this != other); // compares real and imag individually
549+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
550+
return (ne.real() | ne.imag()) & one;
549551
}
550552

551553
Vectorized<ComplexFlt> sgn() const {

aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2212,10 +2212,18 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
22122212
}
22132213

22142214
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
2215-
return Vectorized<T>{_vec.eq(other._vec)};
2215+
auto eq = _vec.eq(other._vec); // compares real and imag individually
2216+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
2217+
auto real = eq & vinner_type(real_mask<underline_type>());
2218+
auto imag = (eq & vinner_type(image_mask<underline_type>())).swapped();
2219+
return Vectorized<T>{real & imag};
22162220
}
22172221
Vectorized<T> C10_ALWAYS_INLINE ne(const Vectorized<T>& other) const {
2218-
return Vectorized<T>{_vec.ne(other._vec)};
2222+
auto ne = _vec.ne(other._vec); // compares real and imag individually
2223+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
2224+
auto real = ne & vinner_type(real_mask<underline_type>());
2225+
auto imag = (ne & vinner_type(image_mask<underline_type>())).swapped();
2226+
return Vectorized<T>{real | imag};
22192227
}
22202228

22212229
Vectorized<T> real() const {

aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,11 +528,15 @@ Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<
528528
}
529529

530530
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
531-
return (*this == other) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
531+
auto eq = (*this == other); // compares real and imag individually
532+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
533+
return (eq.real() & eq.imag()) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
532534
}
533535

534536
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
535-
return (*this != other) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
537+
auto ne = (*this != other); // compares real and imag individually
538+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
539+
return (ne.real() | ne.imag()) & Vectorized<c10::complex<double>>(_mm512_set1_pd(1.0));
536540
}
537541

538542
#endif

aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,12 +1032,16 @@ Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<f
10321032

10331033
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq(
10341034
const Vectorized<c10::complex<float>>& other) const {
1035-
return (*this == other) & Vectorized<c10::complex<float>>(_mm512_set1_ps(1.0f));
1035+
auto eq = (*this == other); // compares real and imag individually
1036+
// If both real numbers and imag numbers are equal, then the complex numbers are equal
1037+
return (eq.real() & eq.imag()) & Vectorized<c10::complex<float>>(_mm512_set1_ps(1.0f));
10361038
}
10371039

10381040
inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne(
10391041
const Vectorized<c10::complex<float>>& other) const {
1040-
return (*this != other) & Vectorized<c10::complex<float>>(_mm512_set1_ps(1.0f));
1042+
auto ne = (*this != other); // compares real and imag individually
1043+
// If either real numbers or imag numbers are not equal, then the complex numbers are not equal
1044+
return (ne.real() | ne.imag()) & Vectorized<c10::complex<float>>(_mm512_set1_ps(1.0f));
10411045
}
10421046

10431047
#endif

aten/src/ATen/cuda/CUDAGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) {
7171
offset_extragraph_ = at::empty({1}, options);
7272

7373
seed_extragraph_.fill_(int64_t(gen->current_seed()));
74-
gen->capture_prologue(seed_extragraph_.data_ptr<int64_t>(), offset_extragraph_.data_ptr<int64_t>());
74+
gen->capture_prologue(seed_extragraph_.data_ptr<int64_t>(), offset_extragraph_.mutable_data_ptr<int64_t>());
7575

7676
auto stream = at::cuda::getCurrentCUDAStream();
7777

0 commit comments

Comments
 (0)