Add torch.dot for complex tensors#42745
Add torch.dot for complex tensors#42745anjali411 wants to merge 13 commits intogh/anjali411/49/basefrom
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 1ff598c (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 64 times. |
TODO: potentially add a fast path for complex dot [ghstack-poisoned]
TODO: potentially add a fast path for complex dot [ghstack-poisoned]
TODO: potentially add a fast path for complex dot [ghstack-poisoned]
|
|
||
| #if AT_BUILD_WITH_BLAS() | ||
| extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy); | ||
| extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy); |
There was a problem hiding this comment.
https://www.math.utah.edu/software/c-with-fortran.html#function-return-types
"...you should not expect to use Fortran functions that return types such as COMPLEX or COMPLEX*16. Write a SUBROUTINE interface to your Fortran function instead, and then invoke it as a void function from C or C++."
[ghstack-poisoned]
|
|
||
| template <> | ||
| void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) { | ||
| TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x), |
There was a problem hiding this comment.
Hmm, I wonder if we shouldn't have some methods on c10::complex for doing pointery conversions like this. It would be nice to not have to be slinging reinterpret cast everywhere. (No action needed for PR)
[ghstack-poisoned]
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
|
@anjali411 as discussed offline, the reason CPU results for the ROCm build are failing is because diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp
index ef05cb8..1fe8a73 100644
--- a/aten/src/ATen/native/BlasKernel.cpp
+++ b/aten/src/ATen/native/BlasKernel.cpp
@@ -17,9 +17,6 @@ extern "C" void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int
# define ffloat float
#endif
-extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
-extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
-extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#ifdef BLAS_USE_CBLAS_DOT
extern "C" float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
@@ -40,6 +37,10 @@ static inline void zdotu_(std::complex<double> *res, const int *n, const std::co
cblas_zdotu_sub(*n, x, *incx, y, *incy, res);
}
#endif // THBlas_cblas_dot_
+#else // BLAS_USE_CBLAS_DOT
+extern "C" ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
+extern "C" void cdotu_(std::complex<float> *res, int *n, std::complex<float> *x, int *incx, std::complex<float> *y, int *incy);
+extern "C" void zdotu_(std::complex<double> *res, int *n, std::complex<double> *x, int *incx, std::complex<double> *y, int *incy);
#endif // BLAS_USE_CBLAS_DOT
#endif // AT_BUILD_WITH_BLAS |
|
BLAS development began in fortran, and the function calling convention differs between compilers and compiler flags, making it difficult to get correct when calling the fortran function from C. Sometimes the complex value is returned as a hidden positional argument as in your extern declaration, and sometimes it is returned like a regular C function call return. Since the openblas symbols were getting linked, and their extern declaration was incorrect, the function was getting called but there was a mismatch between the expected function signature and how it was declared. |
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
Differential Revision: [D23056382](https://our.internmc.facebook.com/intern/diff/D23056382) [ghstack-poisoned]
|
@anjali411 merged this pull request in aab6660. |
| TH_EXTERNC void cblas_cdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu); | ||
| TH_EXTERNC void cblas_zdotu_sub(const int n, const void *x, const int incx, const void *y, const int incy, void *dotu); | ||
|
|
||
| #ifndef THBlas_cblas_dot_ |
There was a problem hiding this comment.
why are these symbols being defined?
Summary: Pull Request resolved: pytorch#42745 Test Plan: Imported from OSS Reviewed By: izdeby Differential Revision: D23056382 Pulled By: anjali411 fbshipit-source-id: c97f15e057095f78069844dbe0299c14104d2fce
Stack from ghstack:
Differential Revision: D23056382