Skip to content

Commit 05dd7a5

Browse files
try with helper function duplicated in _cuda.pyx
1 parent 42e883c commit 05dd7a5

7 files changed

Lines changed: 121 additions & 82 deletions

File tree

python/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,6 @@ install(FILES "${ARROW_PYTHON_BINARY_DIR}/lib_api.h" "${ARROW_PYTHON_BINARY_DIR}
796796

797797
if(PYARROW_BUILD_CUDA)
798798
target_link_libraries(_cuda PRIVATE ${CUDA_LINK_LIBS})
799-
target_link_libraries(lib PRIVATE ${CUDA_LINK_LIBS})
800799
endif()
801800

802801
if(PYARROW_BUILD_FLIGHT)

python/pyarrow/_cuda.pyx

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,52 @@ def read_record_batch(object buffer, object schema, *,
965965
return pyarrow_wrap_batch(batch)
966966

967967

968+
def _import_device_array(in_ptr, type):
969+
cdef:
970+
void* c_ptr = _as_c_pointer(in_ptr)
971+
void* c_type_ptr
972+
shared_ptr[CArray] c_array
973+
974+
c_type = pyarrow_unwrap_data_type(type)
975+
if c_type == nullptr:
976+
# Not a DataType object, perhaps a raw ArrowSchema pointer
977+
c_type_ptr = _as_c_pointer(type)
978+
with nogil:
979+
c_array = GetResultValue(
980+
ImportDeviceArray(<ArrowDeviceArray*> c_ptr,
981+
<ArrowSchema*> c_type_ptr,
982+
CudaDefaultMemoryMapper)
983+
)
984+
else:
985+
with nogil:
986+
c_array = GetResultValue(
987+
ImportDeviceArray(<ArrowDeviceArray*> c_ptr, c_type,
988+
CudaDefaultMemoryMapper)
989+
)
990+
return pyarrow_wrap_array(c_array)
991+
992+
993+
def _import_device_recordbatch(in_ptr, schema):
994+
cdef:
995+
void* c_ptr = _as_c_pointer(in_ptr)
996+
void* c_schema_ptr
997+
shared_ptr[CRecordBatch] c_batch
998+
999+
c_schema = pyarrow_unwrap_schema(schema)
1000+
if c_schema == nullptr:
1001+
# Not a Schema object, perhaps a raw ArrowSchema pointer
1002+
c_schema_ptr = _as_c_pointer(schema, allow_null=True)
1003+
with nogil:
1004+
c_batch = GetResultValue(ImportDeviceRecordBatch(
1005+
<ArrowDeviceArray*> c_ptr, <ArrowSchema*> c_schema_ptr,
1006+
CudaDefaultMemoryMapper))
1007+
else:
1008+
with nogil:
1009+
c_batch = GetResultValue(ImportDeviceRecordBatch(
1010+
<ArrowDeviceArray*> c_ptr, c_schema, CudaDefaultMemoryMapper))
1011+
return pyarrow_wrap_batch(c_batch)
1012+
1013+
9681014
# Public API
9691015

9701016

python/pyarrow/array.pxi

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,50 @@ import os
2121
import warnings
2222
from cython import sizeof
2323

24-
# from pyarrow.includes.libarrow_cuda cimport DefaultMemoryMapper
25-
# from pyarrow.includes.libarrow cimport DefaultDeviceMapper as DefaultDeviceMemoryMapper
26-
from pyarrow.includes.libarrow_memory cimport CDefaultDeviceMemoryMapper
24+
25+
def _import_device_array(in_ptr, type):
26+
"""
27+
Import Array from a C ArrowDeviceArray struct, given its pointer
28+
and the imported array type.
29+
30+
Parameters
31+
----------
32+
in_ptr: int
33+
The raw pointer to a C ArrowDeviceArray struct.
34+
type: DataType or int
35+
Either a DataType object, or the raw pointer to a C ArrowSchema
36+
struct.
37+
38+
This is a low-level function intended for expert users.
39+
"""
40+
cdef:
41+
void* c_ptr = _as_c_pointer(in_ptr)
42+
void* c_type_ptr
43+
shared_ptr[CArray] c_array
44+
45+
c_type = pyarrow_unwrap_data_type(type)
46+
if c_type == nullptr:
47+
# Not a DataType object, perhaps a raw ArrowSchema pointer
48+
c_type_ptr = _as_c_pointer(type)
49+
with nogil:
50+
c_array = GetResultValue(
51+
ImportDeviceArray(<ArrowDeviceArray*> c_ptr,
52+
<ArrowSchema*> c_type_ptr,
53+
DefaultDeviceMapper)
54+
)
55+
else:
56+
with nogil:
57+
c_array = GetResultValue(
58+
ImportDeviceArray(<ArrowDeviceArray*> c_ptr, c_type,
59+
DefaultDeviceMapper)
60+
)
61+
return pyarrow_wrap_array(c_array)
62+
63+
64+
try:
65+
from pyarrow._cuda import _import_device_array
66+
except ImportError:
67+
pass
2768

2869

2970
cdef _sequence_to_array(object sequence, object mask, object size,
@@ -1826,28 +1867,7 @@ cdef class Array(_PandasConvertible):
18261867
18271868
This is a low-level function intended for expert users.
18281869
"""
1829-
cdef:
1830-
void* c_ptr = _as_c_pointer(in_ptr)
1831-
void* c_type_ptr
1832-
shared_ptr[CArray] c_array
1833-
1834-
c_type = pyarrow_unwrap_data_type(type)
1835-
if c_type == nullptr:
1836-
# Not a DataType object, perhaps a raw ArrowSchema pointer
1837-
c_type_ptr = _as_c_pointer(type)
1838-
with nogil:
1839-
c_array = GetResultValue(
1840-
ImportDeviceArray(<ArrowDeviceArray*> c_ptr,
1841-
<ArrowSchema*> c_type_ptr,
1842-
CDefaultDeviceMemoryMapper)
1843-
)
1844-
else:
1845-
with nogil:
1846-
c_array = GetResultValue(
1847-
ImportDeviceArray(<ArrowDeviceArray*> c_ptr, c_type,
1848-
CDefaultDeviceMemoryMapper)
1849-
)
1850-
return pyarrow_wrap_array(c_array)
1870+
return _import_device_array(in_ptr, type)
18511871

18521872
def __dlpack__(self, stream=None):
18531873
"""Export a primitive array as a DLPack capsule.

python/pyarrow/includes/libarrow_cuda.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ cdef extern from "arrow/gpu/cuda_api.h" namespace "arrow::cuda" nogil:
9393
CResult[shared_ptr[CCudaHostBuffer]] AllocateCudaHostBuffer(
9494
int device_number, const int64_t size)
9595

96-
CResult[shared_ptr[CMemoryManager]] DefaultMemoryMapper(
96+
CResult[shared_ptr[CMemoryManager]] CudaDefaultMemoryMapper" DefaultMemoryMapper"(
9797
ArrowDeviceType device_type, int64_t device_id)
9898

9999
# Cuda prefix is added to avoid picking up arrow::cuda functions

python/pyarrow/includes/libarrow_memory.pxd

Lines changed: 0 additions & 34 deletions
This file was deleted.

python/pyarrow/lib.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,8 @@ cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata(
643643
cdef object pyarrow_wrap_metadata(
644644
const shared_ptr[const CKeyValueMetadata]& meta)
645645

646+
cdef void* _as_c_pointer(v, allow_null=*) except *
647+
646648
#
647649
# Public Cython API for 3rd party code
648650
#

python/pyarrow/table.pxi

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,32 @@ from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCap
2020
import warnings
2121
from cython import sizeof
2222

23-
# from pyarrow.includes.libarrow_cuda cimport DefaultMemoryMapper
24-
# from pyarrow.includes.libarrow cimport DefaultDeviceMapper as DefaultDeviceMemoryMapper
25-
from pyarrow.includes.libarrow_memory cimport CDefaultDeviceMemoryMapper
23+
24+
def _import_device_recordbatch(in_ptr, schema):
25+
cdef:
26+
void* c_ptr = _as_c_pointer(in_ptr)
27+
void* c_schema_ptr
28+
shared_ptr[CRecordBatch] c_batch
29+
30+
c_schema = pyarrow_unwrap_schema(schema)
31+
if c_schema == nullptr:
32+
# Not a Schema object, perhaps a raw ArrowSchema pointer
33+
c_schema_ptr = _as_c_pointer(schema, allow_null=True)
34+
with nogil:
35+
c_batch = GetResultValue(ImportDeviceRecordBatch(
36+
<ArrowDeviceArray*> c_ptr, <ArrowSchema*> c_schema_ptr,
37+
DefaultDeviceMapper))
38+
else:
39+
with nogil:
40+
c_batch = GetResultValue(ImportDeviceRecordBatch(
41+
<ArrowDeviceArray*> c_ptr, c_schema, DefaultDeviceMapper))
42+
return pyarrow_wrap_batch(c_batch)
43+
44+
45+
try:
46+
from pyarrow._cuda import _import_device_recordbatch
47+
except ImportError:
48+
pass
2649

2750

2851
cdef class ChunkedArray(_PandasConvertible):
@@ -3590,24 +3613,7 @@ cdef class RecordBatch(_Tabular):
35903613
35913614
This is a low-level function intended for expert users.
35923615
"""
3593-
cdef:
3594-
void* c_ptr = _as_c_pointer(in_ptr)
3595-
void* c_schema_ptr
3596-
shared_ptr[CRecordBatch] c_batch
3597-
3598-
c_schema = pyarrow_unwrap_schema(schema)
3599-
if c_schema == nullptr:
3600-
# Not a Schema object, perhaps a raw ArrowSchema pointer
3601-
c_schema_ptr = _as_c_pointer(schema, allow_null=True)
3602-
with nogil:
3603-
c_batch = GetResultValue(ImportDeviceRecordBatch(
3604-
<ArrowDeviceArray*> c_ptr, <ArrowSchema*> c_schema_ptr,
3605-
CDefaultDeviceMemoryMapper))
3606-
else:
3607-
with nogil:
3608-
c_batch = GetResultValue(ImportDeviceRecordBatch(
3609-
<ArrowDeviceArray*> c_ptr, c_schema, CDefaultDeviceMemoryMapper))
3610-
return pyarrow_wrap_batch(c_batch)
3616+
return _import_device_recordbatch(in_ptr, schema)
36113617

36123618

36133619
def _reconstruct_record_batch(columns, schema):

0 commit comments

Comments
 (0)