Support cuComplex.h in cupy.RawKernel and cupy.RawModule #2551

Merged
kmaehashi merged 15 commits into cupy:master from leofang:cuComplex
Nov 12, 2019

Conversation

@leofang
Member

@leofang leofang commented Oct 18, 2019

Closes #1866. Closes #2111.

Previously, complex-number support in cupy.RawKernel (and thus cupy.RawModule) was poorly documented; see #1866 (comment) and #2111 for examples. This PR aims to solve the problem once and for all, and it supersedes #1866.

The cuComplex.h APIs are supported by prepending a set of C macros that replace those APIs with their Thrust counterparts. Users can therefore copy and paste their cuComplex-based code into cupy.RawKernel or cupy.RawModule and have it work without any modification. Note that the complex mathematical functions are supported without any additional macros.

For backward compatibility, the newly added option enable_cuComplex is set to False by default.
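As a rough illustration of the macro-prepending idea described above, here is a minimal pure-Python sketch. The helper name `prepend_cucomplex_macros` and the constant `CUCOMPLEX_MACROS` are hypothetical, and only the four macros visible in this PR's review excerpts are listed; the actual macro set (including any type aliases) is not reproduced here.

```python
# Hypothetical sketch: only the four macros quoted in the review excerpts of
# this PR are listed. The full set used by CuPy is an assumption left out.
CUCOMPLEX_MACROS = """\
#define cuCrealf real
#define cuCimagf imag
#define cuCreal real
#define cuCimag imag
"""


def prepend_cucomplex_macros(source, enable_cuComplex=False):
    """Return `source` with the macro prelude prepended when the flag is set."""
    if not enable_cuComplex:
        # Default path: the user's source is passed through untouched.
        return source
    return CUCOMPLEX_MACROS + source


# User code written against cuComplex.h. The mapping of the cuFloatComplex
# type itself is not shown in the excerpts, so it is left as-is here.
kernel_source = """\
extern "C" __global__ void get_real(const cuFloatComplex* z, float* out) {
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    out[i] = cuCrealf(z[i]);
}
"""

translated = prepend_cucomplex_macros(kernel_source, enable_cuComplex=True)
print(translated)
```

With the flag left at its default, the helper returns the source unchanged, which mirrors the backward-compatibility point above.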

attn: @grlee77

@leofang
Member Author

leofang commented Oct 18, 2019

Note that the commits in #1866 are taken over to here.

@leofang leofang mentioned this pull request Oct 18, 2019
9 tasks
Comment on lines +141 to +142
#define cuCrealf real
#define cuCimagf imag
Member Author

This is enabled by #2520.

Comment on lines +156 to +157
#define cuCreal real
#define cuCimag imag
Member Author

Ditto.

@hvy
Member

hvy commented Oct 21, 2019

Sorry for the drop-by comment, but thanks for the PR, @leofang. @kmaehashi, could you have a look at this PR? If you're busy, @asi1024 said he might be able to take over.

@leofang
Member Author

leofang commented Oct 21, 2019

@hvy No need to apologize! Thanks for your followup and kind reply.


# First, we comment out the line that includes cuComplex.h
for i, line in enumerate(source_lines):
    if '#include' in line and 'cuComplex.h' in line:
Member

Well, this approach sounds fragile and may surprise users. Do we really need to support cuComplex.h? How about providing a bridge interface in a separate header file (in that case users need to replace cuComplex.h with something like cupy/cuComplex_bridge.h)?
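To make the fragility concern concrete, the following small sketch applies the substring check quoted from the diff to a few sample lines. The function name `naive_is_cucomplex_include` and the sample lines are made up for illustration; only the condition itself comes from the diff.

```python
def naive_is_cucomplex_include(line):
    # The condition quoted from the diff: two independent substring tests.
    return '#include' in line and 'cuComplex.h' in line


# Illustrative inputs (assumptions, not taken from any real project).
samples = [
    '#include <cuComplex.h>',                     # the intended target
    '// #include <cuComplex.h>',                  # already commented out
    '// TODO: drop the #include of cuComplex.h',  # plain comment text
    '#include "my_cuComplex.h"',                  # a different header
]

results = {line: naive_is_cucomplex_include(line) for line in samples}
for line, hit in results.items():
    print(hit, line)
```

Every sample is flagged, including the three that are not a live include of cuComplex.h.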

Member Author

@leofang leofang Oct 24, 2019

Thanks for raising your concern, @kmaehashi.

> Do we really need to support cuComplex.h?

Short answer: Yes. Please do.

Long answer: #2111 is a long list of the cases I have encountered that request this. Please do not surprise our users. Many scientific projects need complex numbers, but some existing CUDA projects either do not want to depend on Thrust or are too large for a clean overhaul to Thrust's complex types. So, when people port code to CuPy, they are often frustrated. I've helped many people resolve this, and I think this PR solves the issue once and for all. (So people would stop bothering their local CuPy expert, i.e., me...)

If you are really, really against such support, can you please at least merge #1866? It has been pending for quite a while, and it's really frustrating. In either case, we need to make the situation clear to the users. No surprises.

> How about providing a bridge interface in a separate header file (in that case users need to replace cuComplex.h with something like cupy/cuComplex_bridge.h)?

I do not think this is less fragile than the current Pythonic approach. First, this requires code change (although it's only one line), which is against the design principle I had in mind when pushing for RawModule. (Users should just import the source code and let CuPy take care of the rest to make it work.)

Second, even if the macros in this PR go into the cupy/cuComplex_bridge.h header that you suggested, the net effect after C preprocessing is identical. So I don't see why we would want to bother users with making a change.

If you are still concerned and not fully convinced, may I propose to mark the new enable_cuComplex flag as experimental and subject to change? Let me reiterate that the changes in this PR are not enabled by default, so normal use cases would just continue to work (as you will see from Jenkins). It would not hurt to consider this PR.

Thanks again, @kmaehashi.

Member Author

@kmaehashi Just checking in. Any comments? Thanks.

Member

OK, now I understand your motivation to use cuComplex code directly on CuPy RawModule/RawKernel.

  • Regarding the option name, how about translate_cucomplex instead of enable_cuComplex? (1) It is difficult for users to guess from "enable" that their code is automatically rewritten. (2) PEP 8 recommends lowercase for variable names.
  • Could you make the option a keyword-only argument so that we can easily reorder or rename it later?
  • Please use a stricter regex that matches the whole line to detect the header include. You need to consider cases like # include <cuComplex.h> or // #include <cuComplex.h>.
  • Could you extract the macro contents into a separate file (cupy/cuComplex_bridge.h), and then replace #include <cuComplex.h> with #include <cupy/cuComplex_bridge.h> // translate_cucomplex to minimize the amount of modification? (This also keeps the original line numbers, so it's better for cuda-gdb.)

Member Author

Thank you very much for your understanding and detailed comments, @kmaehashi. I'm working on it.

Two quick replies:

> Could you make the option a keyword-only argument so that we can easily consider reorder/renaming later?

If I understand you correctly, I don't think this is possible for compile_with_cache, unfortunately. cdef and cpdef functions do not support **kwargs at the end of the function signature (you get an Expected ')', found '**' error).

> this also keeps the original line numbers so it's better for cuda-gdb

This is a very good suggestion, which I would never have come up with myself! I have never used cuda-gdb, after all...

Member

> If I understand you correctly, I don't think for compile_with_cache this is possible unfortunately. cdef and cpdef functions do not support **kwargs in the end of the function signature (will get an Expected ')', found '**' error).

Yes, you're correct. I think it's OK for compile_with_cache, as it is not a public (documented) interface.
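For readers unfamiliar with the keyword-only request discussed in this thread, here is a minimal plain-Python sketch of the idea. The function `compile_source` is a hypothetical stand-in and is not the signature of any CuPy function; it only shows how a bare `*` forces the option to be passed by name.

```python
def compile_source(source, options=(), *, translate_cucomplex=False):
    """Hypothetical stand-in showing a keyword-only option after a bare '*'."""
    return {
        'source': source,
        'options': tuple(options),
        'translate_cucomplex': translate_cucomplex,
    }


# Passing the option by name works.
ok = compile_source('/* kernel */', ('-std=c++11',), translate_cucomplex=True)

# Passing it positionally is rejected, so the parameter can later be
# reordered or renamed without silently breaking positional callers.
try:
    compile_source('/* kernel */', ('-std=c++11',), True)
except TypeError as exc:
    positional_error = exc
else:
    positional_error = None

print(ok['translate_cucomplex'], type(positional_error).__name__)
```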

@leofang
Member Author

leofang commented Nov 8, 2019

Done. PTAL.

@leofang
Member Author

leofang commented Nov 11, 2019

The requested changes are done except for the regex one. Tested locally.

Member

@kmaehashi kmaehashi left a comment

Some more nitpicks

leofang and others added 2 commits November 11, 2019 03:05
Co-authored-by: Kenichi Maehashi <webmaster@kenichimaehashi.com>
Co-Authored-By: Kenichi Maehashi <webmaster@kenichimaehashi.com>
@leofang
Member Author

leofang commented Nov 11, 2019

@kmaehashi Most requested changes are done except for the **kwargs comment.

Also, I took the liberty of removing the six-related code in cupy/core/raw.pyx. Since PY2 support is dropped, I hope this is acceptable. If not, I'll bring it back.

Thanks.

@kmaehashi
Member

pfnCI, test this please.
(Running tests as this PR is almost complete.)

@pfn-ci-bot
Collaborator

Successfully created a job for commit 9c02d56:

@chainer-ci
Member

Jenkins CI test (for commit 9c02d56, target branch master) failed with status FAILURE.

@kmaehashi kmaehashi added the cat:enhancement (Improvements to existing features) and takeover (Pull-requests taken over from other contributor) labels Nov 11, 2019
@leofang
Member Author

leofang commented Nov 12, 2019

Everything seems to work. The only CI failure seems unrelated. The six support is restored, and the keyword-only argument is used. @kmaehashi PTAL, thanks.

Member

@kmaehashi kmaehashi left a comment

LGTM!

@kmaehashi
Member

pfnCI, test this please.

@pfn-ci-bot
Collaborator

Successfully created a job for commit 0ff24d7:

@chainer-ci
Member

Jenkins CI test (for commit 0ff24d7, target branch master) failed with status FAILURE.

@kmaehashi
Member

Test failures are not related to this PR.

@kmaehashi kmaehashi merged commit af7e41e into cupy:master Nov 12, 2019
@kmaehashi kmaehashi added this to the v7.0.0 milestone Nov 12, 2019
@leofang leofang deleted the cuComplex branch November 12, 2019 12:10
@leofang
Member Author

leofang commented Nov 12, 2019

Thanks a lot @kmaehashi for your help and @grlee77 for the update in the docs!

grlee77 added a commit to grlee77/cupy that referenced this pull request Nov 14, 2019
* upstream/master: (28 commits)
  apply review (need six; use keyword-only args)
  Apply suggestions from code review
  use regex
  apply review
  apply review from cupy#2551 (comment)
  Try using a shape of `()` for `tobytes` testing
  Update tests/cupy_tests/core_tests/test_ndarray_conversion.py
  Implement `tobytes` for CuPy arrays
  apply review
  Type `dumps` return value as `bytes`
  silence flake8
  fix cupy#2610
  Update CODE_OF_CONDUCT.md
  silence flake8
  ensure axis is None before device_reduce
  temporarily removed WIP for device segmented reduce
  device reduce for argmin and argmax works
  this compiles but is not yet tested
  propagate metadata to RawKernels
  ensure operation order is preserved
  ...

Labels

cat:enhancement (Improvements to existing features), takeover (Pull-requests taken over from other contributor)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Use complex number in cupy.RawKernel

6 participants