Commit d2c9e39

Serialize all pyarrow extension arrays efficiently (#9740)

1 parent 7a0e873 commit d2c9e39

3 files changed

Lines changed: 152 additions & 227 deletions


dask/dataframe/_compat.py

Lines changed: 13 additions & 12 deletions
@@ -4,18 +4,19 @@

 import numpy as np
 import pandas as pd
-from packaging.version import parse as parse_version
-
-PANDAS_VERSION = parse_version(pd.__version__)
-PANDAS_GT_104 = PANDAS_VERSION >= parse_version("1.0.4")
-PANDAS_GT_110 = PANDAS_VERSION >= parse_version("1.1.0")
-PANDAS_GT_120 = PANDAS_VERSION >= parse_version("1.2.0")
-PANDAS_GT_121 = PANDAS_VERSION >= parse_version("1.2.1")
-PANDAS_GT_130 = PANDAS_VERSION >= parse_version("1.3.0")
-PANDAS_GT_131 = PANDAS_VERSION >= parse_version("1.3.1")
-PANDAS_GT_133 = PANDAS_VERSION >= parse_version("1.3.3")
-PANDAS_GT_140 = PANDAS_VERSION >= parse_version("1.4.0")
-PANDAS_GT_150 = PANDAS_VERSION >= parse_version("1.5.0")
+from packaging.version import Version
+
+PANDAS_VERSION = Version(pd.__version__)
+PANDAS_GT_104 = PANDAS_VERSION >= Version("1.0.4")
+PANDAS_GT_110 = PANDAS_VERSION >= Version("1.1.0")
+PANDAS_GT_120 = PANDAS_VERSION >= Version("1.2.0")
+PANDAS_GT_121 = PANDAS_VERSION >= Version("1.2.1")
+PANDAS_GT_130 = PANDAS_VERSION >= Version("1.3.0")
+PANDAS_GT_131 = PANDAS_VERSION >= Version("1.3.1")
+PANDAS_GT_133 = PANDAS_VERSION >= Version("1.3.3")
+PANDAS_GT_140 = PANDAS_VERSION >= Version("1.4.0")
+PANDAS_GT_150 = PANDAS_VERSION >= Version("1.5.0")
+PANDAS_GT_200 = PANDAS_VERSION.major >= 2

 import pandas.testing as tm
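A note on the new PANDAS_GT_200 flag: it compares PANDAS_VERSION.major rather than
testing `>= Version("2.0.0")`, presumably so that pandas 2.0 pre-releases are also
treated as 2.x. A quick sketch of the difference (illustration only, not part of the
diff), using just packaging.version:

    from packaging.version import Version

    dev = Version("2.0.0.dev0")
    print(dev >= Version("2.0.0"))  # False -- per PEP 440, dev releases sort before the final release
    print(dev.major >= 2)           # True -- the major-only check still treats it as 2.x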

dask/dataframe/_pyarrow_compat.py

Lines changed: 24 additions & 120 deletions
@@ -1,143 +1,47 @@
 import copyreg
-import math
 
-import numpy as np
 import pandas as pd
 
 try:
     import pyarrow as pa
 except ImportError:
     pa = None
 
+from dask.dataframe._compat import PANDAS_GT_130, PANDAS_GT_150, PANDAS_GT_200
+
 # Pickling of pyarrow arrays is effectively broken - pickling a slice of an
 # array ends up pickling the entire backing array.
 #
 # See https://issues.apache.org/jira/browse/ARROW-10739
 #
 # This comes up when using pandas `string[pyarrow]` dtypes, which are backed by
 # a `pyarrow.StringArray`. To fix this, we register a *global* override for
-# pickling `pandas.core.arrays.ArrowStringArray` types. We do this at the
-# pandas level rather than the pyarrow level for efficiency reasons (a pandas
-# ArrowStringArray may contain many small pyarrow StringArray objects).
-#
-# This pickling implementation manually mucks with the backing buffers in a
-# fairly efficient way:
-#
-# - The data buffer is never copied
-# - The offsets buffer is only copied if the array is sliced with a start index
-#   (x[start:])
-# - The mask buffer is never copied
-#
-# This implementation works with pickle protocol 5, allowing support for true
-# zero-copy sends.
+# pickling `ArrowStringArray` or `ArrowExtensionArray` types (where available).
+# We do this at the pandas level rather than the pyarrow level for efficiency reasons
+# (a pandas ArrowStringArray may contain many small pyarrow StringArray objects).
 #
-# XXX: Once pyarrow (or pandas) has fixed this bug, we should skip registering
-# with copyreg for versions that lack this issue.
-
-
-def pyarrow_stringarray_to_parts(array):
-    """Decompose a ``pyarrow.StringArray`` into a tuple of components.
-
-    The resulting tuple can be passed to
-    ``pyarrow_stringarray_from_parts(*components)`` to reconstruct the
-    ``pyarrow.StringArray``.
-    """
-    # Access the backing buffers.
-    #
-    # - mask: None, or a bitmask of length ceil(nitems / 8). 0 bits mark NULL
-    #   elements, only present if NULL data is present, commonly None.
-    # - offsets: A uint32 array of nitems + 1 items marking the start/stop
-    #   indices for the individual elements in `data`
-    # - data: All the utf8 string data concatenated together
-    #
-    # The structure of these buffers comes from the arrow format, documented at
-    # https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout.
-    # In particular, this is a `StringArray` (4 byte offsets), rather than a
-    # `LargeStringArray` (8 byte offsets).
-    assert pa.types.is_string(array.type)
-
-    mask, offsets, data = array.buffers()
-    nitems = len(array)
-
-    if not array.offset:
-        # No leading offset, only need to slice any unnecessary data from the
-        # backing buffers
-        offsets = offsets[: 4 * (nitems + 1)]
-        data_stop = int.from_bytes(offsets[-4:], "little")
-        data = data[:data_stop]
-        if mask is None:
-            return nitems, offsets, data
-        else:
-            mask = mask[: math.ceil(nitems / 8)]
-            return nitems, offsets, data, mask
-
-    # There is a leading offset. This complicates things a bit.
-    offsets_start = array.offset * 4
-    offsets_stop = offsets_start + (nitems + 1) * 4
-    data_start = int.from_bytes(offsets[offsets_start : offsets_start + 4], "little")
-    data_stop = int.from_bytes(offsets[offsets_stop - 4 : offsets_stop], "little")
-    data = data[data_start:data_stop]
-
-    if mask is None:
-        npad = 0
-    else:
-        # Since the mask is a bitmask, it can only represent even units of 8
-        # elements. To avoid shifting any bits, we pad the array with up to 7
-        # elements so the mask array can always be serialized zero copy.
-        npad = array.offset % 8
-        mask_start = array.offset // 8
-        mask_stop = math.ceil((array.offset + nitems) / 8)
-        mask = mask[mask_start:mask_stop]
-
-    # Subtract the offset of the starting element from every used offset in the
-    # offsets array, ensuring the first element in the serialized `offsets`
-    # array is always 0.
-    offsets_array = np.frombuffer(offsets, dtype="i4")
-    offsets_array = (
-        offsets_array[array.offset : array.offset + nitems + 1]
-        - offsets_array[array.offset]
-    )
-    # Pad the new offsets by `npad` offsets of 0 (see the `mask` comment above). We wrap
-    # this in a `pyarrow.py_buffer`, since this type transparently supports pickle 5,
-    # avoiding an extra copy inside the pickler.
-    offsets = pa.py_buffer(
-        b"\x00" * (4 * npad) + offsets_array.data if npad else offsets_array.data
-    )
-
-    if mask is None:
-        return nitems, offsets, data
-    else:
-        return nitems, offsets, data, mask, npad
-
-
-def pyarrow_stringarray_from_parts(nitems, data_offsets, data, mask=None, offset=0):
-    """Reconstruct a ``pyarrow.StringArray`` from the parts returned by
-    ``pyarrow_stringarray_to_parts``."""
-    return pa.StringArray.from_buffers(nitems, data_offsets, data, mask, offset=offset)
+# The implementation here is based on https://github.com/pandas-dev/pandas/pull/49078
+# which is included in pandas=2+. We can remove all this once Dask's minimum
+# supported pandas version is at least 2.0.0.
 
 
-def rebuild_arrowstringarray(*chunk_parts):
-    """Rebuild a ``pandas.core.arrays.ArrowStringArray``"""
-    array = pa.chunked_array(
-        [pyarrow_stringarray_from_parts(*parts) for parts in chunk_parts],
-        type=pa.string(),
-    )
-    return pd.arrays.ArrowStringArray(array)
+def rebuild_arrowextensionarray(type_, chunks):
+    array = pa.chunked_array(chunks)
+    return type_(array)
 
 
-def reduce_arrowstringarray(x):
-    """A pickle override for ``pandas.core.arrays.ArrowStringArray`` that avoids
-    serializing unnecessary data, while also avoiding/minimizing data copies"""
-    # Decompose each chunk in the backing ChunkedArray into their individual
-    # components for serialization. We filter out 0-length chunks, since they
-    # add no meaningful value to the chunked array.
-    chunks = tuple(
-        pyarrow_stringarray_to_parts(chunk)
-        for chunk in x._data.chunks
-        if len(chunk) > 0
-    )
-    return (rebuild_arrowstringarray, chunks)
+def reduce_arrowextensionarray(x):
+    return (rebuild_arrowextensionarray, (type(x), x._data.combine_chunks()))
 
 
-if hasattr(pd.arrays, "ArrowStringArray") and pa is not None:
-    copyreg.dispatch_table[pd.arrays.ArrowStringArray] = reduce_arrowstringarray
+# `pandas=2` includes efficient serialization of `pyarrow`-backed extension arrays.
+# See https://github.com/pandas-dev/pandas/pull/49078 for details.
+# We only need to backport efficient serialization for `pandas<2`.
+if pa is not None and not PANDAS_GT_200:
+    if PANDAS_GT_150:
+        # Applies to all `pyarrow`-backed extension arrays (e.g. `string[pyarrow]`, `int64[pyarrow]`)
+        for type_ in [pd.arrays.ArrowExtensionArray, pd.arrays.ArrowStringArray]:
+            copyreg.dispatch_table[type_] = reduce_arrowextensionarray
+    elif PANDAS_GT_130:
+        # Only `string[pyarrow]` is implemented, so just patch that
+        copyreg.dispatch_table[pd.arrays.ArrowStringArray] = reduce_arrowextensionarray
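For context on the mechanism (illustration only, not part of the diff): assigning a
reducer function into copyreg.dispatch_table globally overrides how pickle serializes
instances of that exact class; this is the same hook the old reduce_arrowstringarray
used. A minimal, self-contained sketch of the pattern, using a hypothetical Point
class:

    import copyreg
    import pickle

    class Point:
        def __init__(self, x, y):
            self.x = x
            self.y = y

    def rebuild_point(x, y):
        # Runs at unpickle time with the arguments returned by reduce_point
        return Point(x, y)

    def reduce_point(p):
        # pickle stores (callable, args) in place of the default instance state
        return (rebuild_point, (p.x, p.y))

    copyreg.dispatch_table[Point] = reduce_point

    p2 = pickle.loads(pickle.dumps(Point(1, 2)))
    assert (p2.x, p2.y) == (1, 2)

The new reduce_arrowextensionarray follows the same shape, but first calls
x._data.combine_chunks() so the backing ChunkedArray is concatenated into contiguous
buffers; per the comment at the top of the file, that is what keeps a sliced array
from pickling its entire backing array, without the manual buffer surgery the old
code needed.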
