Skip to content

Commit 83b0d58

Browse files
fixup import from pyarrow._cuda
1 parent 05dd7a5 commit 83b0d58

File tree

5 files changed

+29
-32
lines changed

5 files changed

+29
-32
lines changed

python/pyarrow/_cuda.pyx

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,8 @@ def read_record_batch(object buffer, object schema, *,
965965
return pyarrow_wrap_batch(batch)
966966

967967

968-
def _import_device_array(in_ptr, type):
968+
def _import_device_array_cuda(in_ptr, type):
969+
# equivalent to the definition in array.pxi but using CudaDefaultMemoryMapper
969970
cdef:
970971
void* c_ptr = _as_c_pointer(in_ptr)
971972
void* c_type_ptr
@@ -978,19 +979,20 @@ def _import_device_array(in_ptr, type):
978979
with nogil:
979980
c_array = GetResultValue(
980981
ImportDeviceArray(<ArrowDeviceArray*> c_ptr,
981-
<ArrowSchema*> c_type_ptr,
982-
CudaDefaultMemoryMapper)
982+
<ArrowSchema*> c_type_ptr,
983+
CudaDefaultMemoryMapper)
983984
)
984985
else:
985986
with nogil:
986987
c_array = GetResultValue(
987988
ImportDeviceArray(<ArrowDeviceArray*> c_ptr, c_type,
988-
CudaDefaultMemoryMapper)
989+
CudaDefaultMemoryMapper)
989990
)
990991
return pyarrow_wrap_array(c_array)
991992

992993

993-
def _import_device_recordbatch(in_ptr, schema):
994+
def _import_device_recordbatch_cuda(in_ptr, schema):
995+
# equivalent to the definition in table.pxi but using CudaDefaultMemoryMapper
994996
cdef:
995997
void* c_ptr = _as_c_pointer(in_ptr)
996998
void* c_schema_ptr
@@ -1003,11 +1005,13 @@ def _import_device_recordbatch(in_ptr, schema):
10031005
with nogil:
10041006
c_batch = GetResultValue(ImportDeviceRecordBatch(
10051007
<ArrowDeviceArray*> c_ptr, <ArrowSchema*> c_schema_ptr,
1006-
CudaDefaultMemoryMapper))
1008+
CudaDefaultMemoryMapper)
1009+
)
10071010
else:
10081011
with nogil:
10091012
c_batch = GetResultValue(ImportDeviceRecordBatch(
1010-
<ArrowDeviceArray*> c_ptr, c_schema, CudaDefaultMemoryMapper))
1013+
<ArrowDeviceArray*> c_ptr, c_schema, CudaDefaultMemoryMapper)
1014+
)
10111015
return pyarrow_wrap_batch(c_batch)
10121016

10131017

python/pyarrow/array.pxi

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,7 @@ import warnings
2222
from cython import sizeof
2323

2424

25-
def _import_device_array(in_ptr, type):
26-
"""
27-
Import Array from a C ArrowDeviceArray struct, given its pointer
28-
and the imported array type.
29-
30-
Parameters
31-
----------
32-
in_ptr: int
33-
The raw pointer to a C ArrowDeviceArray struct.
34-
type: DataType or int
35-
Either a DataType object, or the raw pointer to a C ArrowSchema
36-
struct.
37-
38-
This is a low-level function intended for expert users.
39-
"""
25+
def _import_device_array_cpu(in_ptr, type):
4026
cdef:
4127
void* c_ptr = _as_c_pointer(in_ptr)
4228
void* c_type_ptr
@@ -62,9 +48,11 @@ def _import_device_array(in_ptr, type):
6248

6349

6450
try:
65-
from pyarrow._cuda import _import_device_array
51+
from pyarrow._cuda import _import_device_array_cuda
52+
53+
_import_device_array = _import_device_array_cuda
6654
except ImportError:
67-
pass
55+
_import_device_array = _import_device_array_cpu
6856

6957

7058
cdef _sequence_to_array(object sequence, object mask, object size,

python/pyarrow/includes/libarrow_cuda.pxd

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ cdef extern from "arrow/gpu/cuda_api.h" namespace "arrow::cuda" nogil:
9393
CResult[shared_ptr[CCudaHostBuffer]] AllocateCudaHostBuffer(
9494
int device_number, const int64_t size)
9595

96-
CResult[shared_ptr[CMemoryManager]] CudaDefaultMemoryMapper" DefaultMemoryMapper"(
97-
ArrowDeviceType device_type, int64_t device_id)
96+
CResult[shared_ptr[CMemoryManager]] \
97+
CudaDefaultMemoryMapper" arrow::cuda::DefaultMemoryMapper"(
98+
ArrowDeviceType device_type, int64_t device_id)
9899

99100
# Cuda prefix is added to avoid picking up arrow::cuda functions
100101
# from arrow namespace.

python/pyarrow/lib.pyx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import os
2626
import sys
2727

2828
from cython.operator cimport dereference as deref
29+
from cython cimport binding
30+
2931
from pyarrow.includes.libarrow cimport *
3032
from pyarrow.includes.libarrow_python cimport *
3133
from pyarrow.includes.common cimport PyObject_to_object
@@ -162,6 +164,9 @@ include "pandas-shim.pxi"
162164
# Memory pools and allocation
163165
include "memory.pxi"
164166

167+
# File IO
168+
include "io.pxi"
169+
165170
# DataType, Field, Schema
166171
include "types.pxi"
167172

@@ -183,9 +188,6 @@ include "tensor.pxi"
183188
# DLPack
184189
include "_dlpack.pxi"
185190

186-
# File IO
187-
include "io.pxi"
188-
189191
# IPC / Messaging
190192
include "ipc.pxi"
191193

python/pyarrow/table.pxi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import warnings
2121
from cython import sizeof
2222

2323

24-
def _import_device_recordbatch(in_ptr, schema):
24+
def _import_device_recordbatch_cpu(in_ptr, schema):
2525
cdef:
2626
void* c_ptr = _as_c_pointer(in_ptr)
2727
void* c_schema_ptr
@@ -43,9 +43,11 @@ def _import_device_recordbatch(in_ptr, schema):
4343

4444

4545
try:
46-
from pyarrow._cuda import _import_device_recordbatch
46+
from pyarrow._cuda import _import_device_recordbatch_cuda
47+
48+
_import_device_recordbatch = _import_device_recordbatch_cuda
4749
except ImportError:
48-
pass
50+
_import_device_recordbatch = _import_device_recordbatch_cpu
4951

5052

5153
cdef class ChunkedArray(_PandasConvertible):

0 commit comments

Comments
 (0)