Skip to content

Commit a407a6b

Browse files
GH-40698: [C++] Create registry for Devices to map DeviceType to MemoryManager in C Device Data import (#40699)
### Rationale for this change Follow-up on #39980 (comment) Right now, the user of `ImportDeviceArray` or `ImportDeviceRecordBatch` needs to provide a `DeviceMemoryMapper` mapping the device type and id to a MemoryManager. We provide a default implementation of that mapper that just knows about the default CPU memory manager (and there is another implementation in `arrow::cuda`, but you need to explicitly pass that to the import function) To make this easier, this PR adds a registry such that default device mappers can be added separately. ### What changes are included in this PR? This PR adds two new public functions to register device types (`RegisterDeviceMemoryManager`) and retrieve the mapper from the registry (`GetDeviceMemoryManager`). Further, it provides a `RegisterCUDADevice` to optionally register the CUDA devices (by default only CPU device is registered). ### Are these changes tested? ### Are there any user-facing changes? * GitHub Issue: #40698 Lead-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com> Co-authored-by: Antoine Pitrou <pitrou@free.fr> Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent aae2557 commit a407a6b

8 files changed

Lines changed: 139 additions & 26 deletions

File tree

cpp/src/arrow/buffer_test.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,4 +1023,17 @@ TEST(TestBufferConcatenation, EmptyBuffer) {
10231023
AssertMyBufferEqual(*result, contents);
10241024
}
10251025

1026+
TEST(TestDeviceRegistry, Basics) {
1027+
// Test the error cases for the device registry
1028+
1029+
// CPU is already registered
1030+
ASSERT_RAISES(KeyError,
1031+
RegisterDeviceMapper(DeviceAllocationType::kCPU, [](int64_t device_id) {
1032+
return default_cpu_memory_manager();
1033+
}));
1034+
1035+
// VPI is not registered
1036+
ASSERT_RAISES(KeyError, GetDeviceMapper(DeviceAllocationType::kVPI));
1037+
}
1038+
10261039
} // namespace arrow

cpp/src/arrow/c/bridge.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,12 +1967,11 @@ Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
19671967
return ImportRecordBatch(array, *maybe_schema);
19681968
}
19691969

1970-
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType device_type,
1971-
int64_t device_id) {
1972-
if (device_type != ARROW_DEVICE_CPU) {
1973-
return Status::NotImplemented("Only importing data on CPU is supported");
1974-
}
1975-
return default_cpu_memory_manager();
1970+
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMemoryMapper(
1971+
ArrowDeviceType device_type, int64_t device_id) {
1972+
ARROW_ASSIGN_OR_RAISE(auto mapper,
1973+
GetDeviceMapper(static_cast<DeviceAllocationType>(device_type)));
1974+
return mapper(device_id);
19761975
}
19771976

19781977
Result<std::shared_ptr<Array>> ImportDeviceArray(struct ArrowDeviceArray* array,

cpp/src/arrow/c/bridge.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ using DeviceMemoryMapper =
219219
std::function<Result<std::shared_ptr<MemoryManager>>(ArrowDeviceType, int64_t)>;
220220

221221
ARROW_EXPORT
222-
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType device_type,
223-
int64_t device_id);
222+
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMemoryMapper(
223+
ArrowDeviceType device_type, int64_t device_id);
224224

225225
/// \brief EXPERIMENTAL: Import C++ device array from the C data interface.
226226
///
@@ -236,7 +236,7 @@ Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType devic
236236
ARROW_EXPORT
237237
Result<std::shared_ptr<Array>> ImportDeviceArray(
238238
struct ArrowDeviceArray* array, std::shared_ptr<DataType> type,
239-
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
239+
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);
240240

241241
/// \brief EXPERIMENTAL: Import C++ device array and its type from the C data interface.
242242
///
@@ -253,7 +253,7 @@ Result<std::shared_ptr<Array>> ImportDeviceArray(
253253
ARROW_EXPORT
254254
Result<std::shared_ptr<Array>> ImportDeviceArray(
255255
struct ArrowDeviceArray* array, struct ArrowSchema* type,
256-
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
256+
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);
257257

258258
/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device from the C data
259259
/// interface.
@@ -271,7 +271,7 @@ Result<std::shared_ptr<Array>> ImportDeviceArray(
271271
ARROW_EXPORT
272272
Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
273273
struct ArrowDeviceArray* array, std::shared_ptr<Schema> schema,
274-
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
274+
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);
275275

276276
/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device and its schema
277277
/// from the C data interface.
@@ -291,7 +291,7 @@ Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
291291
ARROW_EXPORT
292292
Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
293293
struct ArrowDeviceArray* array, struct ArrowSchema* schema,
294-
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
294+
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);
295295

296296
/// @}
297297

cpp/src/arrow/device.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "arrow/device.h"
1919

2020
#include <cstring>
21+
#include <mutex>
22+
#include <unordered_map>
2123
#include <utility>
2224

2325
#include "arrow/array.h"
@@ -268,4 +270,65 @@ std::shared_ptr<MemoryManager> CPUDevice::default_memory_manager() {
268270
return default_cpu_memory_manager();
269271
}
270272

273+
namespace {
274+
275+
class DeviceMapperRegistryImpl {
276+
public:
277+
DeviceMapperRegistryImpl() {}
278+
279+
Status RegisterDevice(DeviceAllocationType device_type, DeviceMapper memory_mapper) {
280+
std::lock_guard<std::mutex> lock(lock_);
281+
auto [_, inserted] = registry_.try_emplace(device_type, std::move(memory_mapper));
282+
if (!inserted) {
283+
return Status::KeyError("Device type ", static_cast<int>(device_type),
284+
" is already registered");
285+
}
286+
return Status::OK();
287+
}
288+
289+
Result<DeviceMapper> GetMapper(DeviceAllocationType device_type) {
290+
std::lock_guard<std::mutex> lock(lock_);
291+
auto it = registry_.find(device_type);
292+
if (it == registry_.end()) {
293+
return Status::KeyError("Device type ", static_cast<int>(device_type),
294+
"is not registered");
295+
}
296+
return it->second;
297+
}
298+
299+
private:
300+
std::mutex lock_;
301+
std::unordered_map<DeviceAllocationType, DeviceMapper> registry_;
302+
};
303+
304+
Result<std::shared_ptr<MemoryManager>> DefaultCPUDeviceMapper(int64_t device_id) {
305+
return default_cpu_memory_manager();
306+
}
307+
308+
static std::unique_ptr<DeviceMapperRegistryImpl> CreateDeviceRegistry() {
309+
auto registry = std::make_unique<DeviceMapperRegistryImpl>();
310+
311+
// Always register the CPU device
312+
DCHECK_OK(registry->RegisterDevice(DeviceAllocationType::kCPU, DefaultCPUDeviceMapper));
313+
314+
return registry;
315+
}
316+
317+
DeviceMapperRegistryImpl* GetDeviceRegistry() {
318+
static auto g_registry = CreateDeviceRegistry();
319+
return g_registry.get();
320+
}
321+
322+
} // namespace
323+
324+
Status RegisterDeviceMapper(DeviceAllocationType device_type, DeviceMapper mapper) {
325+
auto registry = GetDeviceRegistry();
326+
return registry->RegisterDevice(device_type, std::move(mapper));
327+
}
328+
329+
Result<DeviceMapper> GetDeviceMapper(DeviceAllocationType device_type) {
330+
auto registry = GetDeviceRegistry();
331+
return registry->GetMapper(device_type);
332+
}
333+
271334
} // namespace arrow

cpp/src/arrow/device.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,4 +363,32 @@ class ARROW_EXPORT CPUMemoryManager : public MemoryManager {
363363
ARROW_EXPORT
364364
std::shared_ptr<MemoryManager> default_cpu_memory_manager();
365365

366+
using DeviceMapper =
367+
std::function<Result<std::shared_ptr<MemoryManager>>(int64_t device_id)>;
368+
369+
/// \brief Register a function to retrieve a MemoryManager for a Device type
370+
///
371+
/// This registers the device type globally. A specific device type can only
372+
/// be registered once. This method is thread-safe.
373+
///
374+
/// Currently, this registry is only used for importing data through the C Device
375+
/// Data Interface (for the default Device to MemoryManager mapper in
376+
/// arrow::ImportDeviceArray/ImportDeviceRecordBatch).
377+
///
378+
/// \param[in] device_type the device type for which to register a MemoryManager
379+
/// \param[in] mapper function that takes a device id and returns the appropriate
380+
/// MemoryManager for the registered device type and given device id
381+
/// \return Status
382+
ARROW_EXPORT
383+
Status RegisterDeviceMapper(DeviceAllocationType device_type, DeviceMapper mapper);
384+
385+
/// \brief Get the registered function to retrieve a MemoryManager for the
386+
/// given Device type
387+
///
388+
/// \param[in] device_type the device type
389+
/// \return function that takes a device id and returns the appropriate
390+
/// MemoryManager for the registered device type and given device id
391+
ARROW_EXPORT
392+
Result<DeviceMapper> GetDeviceMapper(DeviceAllocationType device_type);
393+
366394
} // namespace arrow

cpp/src/arrow/gpu/cuda_memory.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cuda.h>
2828

2929
#include "arrow/buffer.h"
30+
#include "arrow/device.h"
3031
#include "arrow/io/memory.h"
3132
#include "arrow/memory_pool.h"
3233
#include "arrow/status.h"
@@ -501,5 +502,23 @@ Result<std::shared_ptr<MemoryManager>> DefaultMemoryMapper(ArrowDeviceType devic
501502
}
502503
}
503504

505+
namespace {
506+
507+
Result<std::shared_ptr<MemoryManager>> DefaultCUDADeviceMapper(int64_t device_id) {
508+
ARROW_ASSIGN_OR_RAISE(auto device, arrow::cuda::CudaDevice::Make(device_id));
509+
return device->default_memory_manager();
510+
}
511+
512+
bool RegisterCUDADeviceInternal() {
513+
DCHECK_OK(RegisterDeviceMapper(DeviceAllocationType::kCUDA, DefaultCUDADeviceMapper));
514+
// TODO add the CUDA_HOST and CUDA_MANAGED allocation types when they are supported in
515+
// the CudaDevice
516+
return true;
517+
}
518+
519+
static auto cuda_registered = RegisterCUDADeviceInternal();
520+
521+
} // namespace
522+
504523
} // namespace cuda
505524
} // namespace arrow

cpp/src/arrow/gpu/cuda_memory.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,9 @@ Result<uintptr_t> GetDeviceAddress(const uint8_t* cpu_data,
260260
ARROW_EXPORT
261261
Result<uint8_t*> GetHostAddress(uintptr_t device_ptr);
262262

263-
ARROW_EXPORT
263+
ARROW_DEPRECATED(
264+
"Deprecated in 16.0.0. The CUDA device is registered by default, and you can use "
265+
"arrow::DefaultDeviceMapper instead.")
264266
Result<std::shared_ptr<MemoryManager>> DefaultMemoryMapper(ArrowDeviceType device_type,
265267
int64_t device_id);
266268

cpp/src/arrow/gpu/cuda_test.cc

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -716,17 +716,6 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
716716
public:
717717
using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;
718718

719-
static Result<std::shared_ptr<MemoryManager>> DeviceMapper(ArrowDeviceType type,
720-
int64_t id) {
721-
if (type != ARROW_DEVICE_CUDA) {
722-
return Status::NotImplemented("should only be CUDA device");
723-
}
724-
725-
ARROW_ASSIGN_OR_RAISE(auto manager, cuda::CudaDeviceManager::Instance());
726-
ARROW_ASSIGN_OR_RAISE(auto device, manager->GetDevice(id));
727-
return device->default_memory_manager();
728-
}
729-
730719
static ArrayFactory JSONArrayFactory(std::shared_ptr<DataType> type, const char* json) {
731720
return [=]() { return ArrayFromJSON(type, json); };
732721
}
@@ -759,7 +748,7 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
759748

760749
std::shared_ptr<Array> device_array_roundtripped;
761750
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
762-
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
751+
ImportDeviceArray(&c_array, &c_schema));
763752
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
764753
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
765754

@@ -779,7 +768,7 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
779768
ASSERT_OK(ExportDeviceArray(*device_array, sync, &c_array, &c_schema));
780769
device_array_roundtripped.reset();
781770
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
782-
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
771+
ImportDeviceArray(&c_array, &c_schema));
783772
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
784773
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));
785774

0 commit comments

Comments
 (0)