Skip to content

DOC: add RawKernel example using complex-valued arrays#1866

Closed
grlee77 wants to merge 3 commits intocupy:masterfrom
grlee77:rawkernel_docs
Closed

DOC: add RawKernel example using complex-valued arrays#1866
grlee77 wants to merge 3 commits intocupy:masterfrom
grlee77:rawkernel_docs

Conversation

@grlee77
Copy link
Copy Markdown
Member

@grlee77 grlee77 commented Dec 5, 2018

This adds an example to the custom kernels tutorial showing how to define a RawKernel for complex-valued arrays. The appropriate include file and notation took me some reading of the cupy sources and some trial & error to figure out. Perhaps there are also other ways, but this was what I got working.

Notably, it seems one should use the C++-style complex<float> notation from Thrust and not the C-style cuFloatComplex, etc.

The example also demonstrates casting of a constant to the appropriate dtype.

@asi1024
Copy link
Copy Markdown
Member

asi1024 commented Dec 6, 2018

Thank you for PR!

We don't want to recommend this solution (because of #1398 (comment)), but we agree with you that we need to provide a way to handle complex arrays. We'll discuss it and reply soon.

@grlee77
Copy link
Copy Markdown
Member Author

grlee77 commented Dec 6, 2018

Thanks. I understand if you want to implement and recommend a different approach.

Is my proposed approach here likely to continue to work in the future? If so, I can use it for now until/unless something better comes along.

@kmaehashi
Copy link
Copy Markdown
Member

Signatures like complex<float> can be considered stable.
The header filename <cupy/complex.cuh> is considered as a kind of private API, but currently there's no plan to change it.

@okuta okuta added the cat:document Documentation label Dec 24, 2018
@grlee77
Copy link
Copy Markdown
Member Author

grlee77 commented Jan 4, 2019

I am closing this based on the feedback above. Thanks

@grlee77 grlee77 closed this Jan 4, 2019
@kmaehashi kmaehashi added this to the Closed issues and PRs milestone Jan 24, 2019
@leofang
Copy link
Copy Markdown
Member

leofang commented Jan 26, 2019

@kmaehashi @asi1024 can we please consider reopening this PR? It is important to keep this non-intuitive usage documented somewhere, because CUDA's native cuComplex.h is not supported by cupy.Rawkernel (or CuPy's NVRTC wrapper). As a simple example, attempting to run the code below

import cupy as cp
cuda_ker = r'''
#include <cuComplex.h>

extern "C"{
__global__ void add_one(cuDoubleComplex * A, int N) {
   int id = threadIdx.x + blockIdx.x * blockDim.x;
   if (id<N) {
      A[id] = cuCadd(A[id], make_cuDoubleComplex(1.0, 0.0));
   }
}
}
'''

a = cp.random.random(10) + 1j*cp.random.random(10)
b = a.copy()
ker = cp.RawKernel(cuda_ker, 'add_one')
block = (10, 0, 0)
grid = (1, 0 , 0)
args = (b, 10)
ker(grid, block, args)
assert (a+1==b).all()

will see the following error

Traceback (most recent call last):
  File "/GPFS/XF03ID1/home/leofang/test/cupy2/cupy/cuda/compiler.py", line 242, in compile
    nvrtc.compileProgram(self.ptr, options)
  File "cupy/cuda/nvrtc.pyx", line 98, in cupy.cuda.nvrtc.compileProgram
  File "cupy/cuda/nvrtc.pyx", line 108, in cupy.cuda.nvrtc.compileProgram
  File "cupy/cuda/nvrtc.pyx", line 53, in cupy.cuda.nvrtc.check_status
cupy.cuda.nvrtc.NVRTCError: NVRTC_ERROR_COMPILATION (6)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test_raw3.py", line 21, in <module>
    ker(grid, block, args)
  File "cupy/core/raw.pyx", line 45, in cupy.core.raw.RawKernel.__call__
  File "cupy/util.pyx", line 48, in cupy.util.memoize.decorator.ret
  File "cupy/core/raw.pyx", line 51, in cupy.core.raw._get_raw_kernel
  File "cupy/core/carray.pxi", line 125, in cupy.core.core.compile_with_cache
  File "cupy/core/carray.pxi", line 164, in cupy.core.core.compile_with_cache
  File "/GPFS/XF03ID1/home/leofang/test/cupy2/cupy/cuda/compiler.py", line 165, in compile_with_cache
    ptx = compile_using_nvrtc(source, options, arch, name + '.cu')
  File "/GPFS/XF03ID1/home/leofang/test/cupy2/cupy/cuda/compiler.py", line 81, in compile_using_nvrtc
    ptx = prog.compile(options)
  File "/GPFS/XF03ID1/home/leofang/test/cupy2/cupy/cuda/compiler.py", line 246, in compile
    raise CompileException(log, self.src, self.name, options)
cupy.cuda.compiler.CompileException: /usr/include/features.h(374): catastrophic error: cannot open source file "sys/cdefs.h"

1 catastrophic error detected in the compilation of "/tmp/tmpavsru0q0/47fdde39a2a40ca9fc79d6ba5643eae3_2.cubin.cu".
Compilation terminated.

This error will remain there even if we change to manipulate any array not of any complex type. In fact, simply adding the #include <cuComplex.h> statement to any valid CUDA code is enough to cause the error for RawKernel, highlighting the importance of OP's request.

Note that if instead of doing ker = cp.RawKernel(cuda_ker, 'add_one') we pre-compile the kernel to a .cubin and load it using the undocumented feature #1657

mod = cp.cuda.function.Module()
mod.load_file("test_raw3.cubin")
ker = mod.get_function('add_one')

the code will run correctly.

@kmaehashi
Copy link
Copy Markdown
Member

@leofang Thanks for the comment!
I discussed again with @okuta and concluded that documenting cupy/complex.cuh is better in this case.

@kmaehashi kmaehashi reopened this Feb 6, 2019
@kmaehashi
Copy link
Copy Markdown
Member

@grlee77 Can I takeover your PR?

@grlee77
Copy link
Copy Markdown
Member Author

grlee77 commented Feb 13, 2019

Yes, feel free to update it.

@leofang
Copy link
Copy Markdown
Member

leofang commented Mar 20, 2019

Apparently the fact that complex numbers are treated specially in CuPy starts to cause issues (ex: #2111). I'd suggest @kmaehashi to merge this PR for the time being, and polish it later since it's not as critical but does present somehow as an obstacle for advanced users.

@kmaehashi kmaehashi self-assigned this Mar 27, 2019
@leofang
Copy link
Copy Markdown
Member

leofang commented Jul 19, 2019

Hi @kmaehashi @asi1024, any chance get this PR merged? Just ping you folks again in case the ball is dropped. Thanks.

@leofang leofang mentioned this pull request Aug 16, 2019
@leofang leofang mentioned this pull request Sep 5, 2019
9 tasks
@grlee77
Copy link
Copy Markdown
Member Author

grlee77 commented Sep 18, 2019

I have rebased to fix the merge conflict and made the change suggested by @leofang

@leofang
Copy link
Copy Markdown
Member

leofang commented Sep 19, 2019

LGTM! Thanks for revisiting this PR, @grlee77.

@kmaehashi @asi1024 any chance you could take a look at this before the weekend? Thanks.

@leofang
Copy link
Copy Markdown
Member

leofang commented Nov 12, 2019

Note: This PR is merged into #2551.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cat:document Documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants