Skip to content

Commit 1bb21a0

Browse files
committed
Update on "Doc update for complex numbers"
[ghstack-poisoned]
2 parents 373dd30 + ba316a7 commit 1bb21a0

43 files changed

Lines changed: 1120 additions & 159 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[submodule "third_party/pybind11"]
22
ignore = dirty
33
path = third_party/pybind11
4-
url = https://github.com/seemethere/pybind11.git
4+
url = https://github.com/pybind/pybind11.git
55
[submodule "third_party/cub"]
66
ignore = dirty
77
path = third_party/cub

aten/src/ATen/core/Vitals.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include <ATen/core/Vitals.h>
2+
#include <cstdlib>
3+
4+
namespace at {
5+
namespace vitals {
6+
7+
TorchVitalAttr& TorchVital::create(const std::string& attr) {
8+
if (!torchVitalEnabled()) {
9+
static TorchVitalAttr disabled;
10+
return disabled;
11+
}
12+
auto iter = attrs.find(attr);
13+
if (iter == attrs.end()) {
14+
auto r = attrs.emplace(std::make_pair(attr, TorchVitalAttr()));
15+
return r.first->second;
16+
}
17+
return iter->second;
18+
}
19+
20+
bool torchVitalEnabled() {
21+
// If this is a performance hit, make `enabled` variable static
22+
// and return `const bool&` instead
23+
bool enabled = []() {
24+
auto e = getenv("TORCH_VITAL");
25+
if (e != nullptr) {
26+
return strlen(e) > 0;
27+
}
28+
return false;
29+
}();
30+
return enabled;
31+
}
32+
33+
} // namespace at
34+
} // namespace vitals

aten/src/ATen/core/Vitals.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
#include <cstring>
3+
#include <iostream>
4+
#include <sstream>
5+
#include <unordered_map>
6+
7+
namespace at {
8+
namespace vitals {
9+
10+
bool torchVitalEnabled();
11+
12+
struct TorchVitalAttr {
13+
// always initialized to empty
14+
std::string value = "";
15+
template <typename T>
16+
TorchVitalAttr& operator<<(const T& t) {
17+
if (torchVitalEnabled()) {
18+
std::stringstream ss;
19+
ss << t;
20+
value += ss.str();
21+
}
22+
return *this;
23+
}
24+
};
25+
26+
struct TorchVital {
27+
std::string name;
28+
std::unordered_map<std::string, TorchVitalAttr> attrs;
29+
30+
explicit TorchVital(std::string n) : name(std::move(n)) {}
31+
TorchVital() = delete;
32+
33+
TorchVitalAttr& create(const std::string& attr);
34+
35+
~TorchVital() {
36+
for (const auto& m : attrs) {
37+
std::cout << "[TORCH_VITAL] " << name << "." << m.first << "\t\t "
38+
<< m.second.value << "\n";
39+
}
40+
}
41+
};
42+
43+
} // namespace at
44+
} // namespace vitals
45+
46+
#define TORCH_VITAL_DECLARE(name) extern TorchVital TorchVital_##name;
47+
48+
#define TORCH_VITAL_DEFINE(name) TorchVital TorchVital_##name(#name);
49+
50+
#define TORCH_VITAL(name, attr) TorchVital_##name.create(#attr)

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,6 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
327327

328328
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
329329
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
330-
TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
331330
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle,
332331
opa, opb, (int)m, (int)n, (int)k,
333332
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
@@ -343,7 +342,7 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
343342
(int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
344343
0, 0, NULL, NULL));
345344
#else
346-
TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU");
345+
TORCH_CHECK(false, "CUDA BFloat16 bgemm requires CUDA 11 or later");
347346
#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000
348347
}
349348
#endif // __HIP_PLATFORM_HCC__
@@ -550,37 +549,26 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
550549
float fbeta = beta;
551550
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
552551
GEMM_CHECK_ARGVALUES(at::BFloat16);
553-
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
554-
if (prop->major >= 8) {
555-
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
556-
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
557-
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
558-
TORCH_CUDABLAS_CHECK(cublasGemmEx(
559-
handle,
560-
opa,
561-
opb,
562-
m,
563-
n,
564-
k,
565-
&falpha,
566-
a,
567-
CUDA_R_16BF,
568-
lda,
569-
b,
570-
CUDA_R_16BF,
571-
ldb,
572-
&fbeta,
573-
c,
574-
CUDA_R_16BF,
575-
ldc,
576-
CUDA_R_32F,
577-
CUBLAS_GEMM_DFALT_TENSOR_OP));
578-
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
579-
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
580-
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
581-
} else {
582-
TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
583-
}
552+
TORCH_CUDABLAS_CHECK(cublasGemmEx(
553+
handle,
554+
opa,
555+
opb,
556+
m,
557+
n,
558+
k,
559+
&falpha,
560+
a,
561+
CUDA_R_16BF,
562+
lda,
563+
b,
564+
CUDA_R_16BF,
565+
ldb,
566+
&fbeta,
567+
c,
568+
CUDA_R_16BF,
569+
ldc,
570+
CUDA_R_32F,
571+
CUBLAS_GEMM_DFALT_TENSOR_OP));
584572
}
585573
#endif
586574

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
126126
}
127127

128128
Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
129+
NoTF32Guard disable_tf32;
129130
ScalarType t = input.scalar_type();
130131
TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
131132
&& input.dim() >= 2,

aten/src/ATen/test/vitals.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <ATen/ATen.h>
4+
#include <ATen/core/Vitals.h>
5+
#include <cstdlib>
6+
7+
using namespace at::vitals;
8+
9+
TEST(Vitals, Basic) {
10+
std::stringstream buffer;
11+
12+
std::streambuf* sbuf = std::cout.rdbuf();
13+
std::cout.rdbuf(buffer.rdbuf());
14+
{
15+
#ifdef _WIN32
16+
_putenv("TORCH_VITAL=1");
17+
#else
18+
setenv("TORCH_VITAL", "1", 1);
19+
#endif
20+
TORCH_VITAL_DEFINE(Testing);
21+
TORCH_VITAL(Testing, Attribute0) << 1;
22+
TORCH_VITAL(Testing, Attribute1) << "1";
23+
TORCH_VITAL(Testing, Attribute2) << 1.0f;
24+
TORCH_VITAL(Testing, Attribute3) << 1.0;
25+
auto t = at::ones({1, 1});
26+
TORCH_VITAL(Testing, Attribute4) << t;
27+
}
28+
std::cout.rdbuf(sbuf);
29+
30+
auto s = buffer.str();
31+
ASSERT_TRUE(s.find("Testing.Attribute0\t\t 1") != std::string::npos);
32+
ASSERT_TRUE(s.find("Testing.Attribute1\t\t 1") != std::string::npos);
33+
ASSERT_TRUE(s.find("Testing.Attribute2\t\t 1") != std::string::npos);
34+
ASSERT_TRUE(s.find("Testing.Attribute3\t\t 1") != std::string::npos);
35+
ASSERT_TRUE(s.find("Testing.Attribute4\t\t 1") != std::string::npos);
36+
}
37+
38+
TEST(Vitals, MultiString) {
39+
std::stringstream buffer;
40+
41+
std::streambuf* sbuf = std::cout.rdbuf();
42+
std::cout.rdbuf(buffer.rdbuf());
43+
{
44+
#ifdef _WIN32
45+
_putenv("TORCH_VITAL=1");
46+
#else
47+
setenv("TORCH_VITAL", "1", 1);
48+
#endif
49+
TORCH_VITAL_DEFINE(Testing);
50+
TORCH_VITAL(Testing, Attribute0) << 1 << " of " << 2;
51+
TORCH_VITAL(Testing, Attribute1) << 1;
52+
TORCH_VITAL(Testing, Attribute1) << " of ";
53+
TORCH_VITAL(Testing, Attribute1) << 2;
54+
}
55+
std::cout.rdbuf(sbuf);
56+
57+
auto s = buffer.str();
58+
ASSERT_TRUE(s.find("Testing.Attribute0\t\t 1 of 2") != std::string::npos);
59+
ASSERT_TRUE(s.find("Testing.Attribute1\t\t 1 of 2") != std::string::npos);
60+
}
61+
62+
TEST(Vitals, OnAndOff) {
63+
for (auto i = 0; i < 2; ++i) {
64+
std::stringstream buffer;
65+
66+
std::streambuf* sbuf = std::cout.rdbuf();
67+
std::cout.rdbuf(buffer.rdbuf());
68+
{
69+
#ifdef _WIN32
70+
if (i) {
71+
_putenv("TORCH_VITAL=1");
72+
} else {
73+
_putenv("TORCH_VITAL=0");
74+
}
75+
#else
76+
setenv("TORCH_VITAL", i ? "1" : "", 1);
77+
#endif
78+
TORCH_VITAL_DEFINE(Testing);
79+
TORCH_VITAL(Testing, Attribute0) << 1;
80+
}
81+
std::cout.rdbuf(sbuf);
82+
83+
auto s = buffer.str();
84+
auto f = s.find("Testing.Attribute0\t\t 1");
85+
if (i) {
86+
ASSERT_TRUE(f != std::string::npos);
87+
} else {
88+
ASSERT_TRUE(f == std::string::npos);
89+
}
90+
}
91+
}

0 commit comments

Comments
 (0)