Skip to content

Add overloads of std:: math functions for c10::complex [resubmit]#37468

Closed
zasdfgbnm wants to merge 11 commits intomasterfrom
ci-all/complex-math
Closed

Add overloads of std:: math functions for c10::complex [resubmit]#37468
zasdfgbnm wants to merge 11 commits intomasterfrom
ci-all/complex-math

Conversation

@zasdfgbnm
Copy link
Copy Markdown
Collaborator

This reverts commit d167a7f.

@dr-ci
Copy link
Copy Markdown

dr-ci Bot commented Apr 29, 2020

💊 Build failures summary and remediations

As of commit 9637a1c (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

See how this bot performed.

This comment has been revised 31 times.

@zasdfgbnm zasdfgbnm added the module: complex Related to complex number support in PyTorch label Apr 29, 2020
@zasdfgbnm zasdfgbnm marked this pull request as ready for review April 29, 2020 03:09
@zasdfgbnm
Copy link
Copy Markdown
Collaborator Author

@anjali411 This should be ready, the fix is:

#if CUDA_VERSION < 10000
#define CUDA92_BUG(x) thrust::complex<T>(x.real(), x.imag())
#else
#define CUDA92_BUG(x) x
#endif

template<typename T>
C10_HOST_DEVICE c10::complex<T> exp(c10::complex<T> x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
  return static_cast<c10::complex<T>>(thrust::exp(static_cast<thrust::complex<T>>(CUDA92_BUG(x))));
#else
  return static_cast<c10::complex<T>>(std::exp(static_cast<std::complex<T>>(x)));
#endif
}

@zasdfgbnm zasdfgbnm requested a review from anjali411 April 29, 2020 06:29
Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zasdfgbnm
Copy link
Copy Markdown
Collaborator Author

@anjali411 Conflicts resolved.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Copy Markdown
Contributor

@anjali411 merged this pull request in c5624e8.

@zasdfgbnm zasdfgbnm deleted the ci-all/complex-math branch April 30, 2020 18:30
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…torch#37468)

Summary:
This reverts commit 080a1ae.
Pull Request resolved: pytorch#37468

Differential Revision: D21305110

Pulled By: anjali411

fbshipit-source-id: d1bdc9d9feac00331fc2b2b905d49f80bef680f9
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: complex Related to complex number support in PyTorch open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants