Commit d178233

Efficient dataframe.convert_string support for read_parquet (#9979)

Authored by j-bennet and jrbourbeau
Co-authored-by: James Bourbeau <jrbourbeau@gmail.com>
1 parent 970da68 · commit d178233

4 files changed, with 174 additions and 32 deletions.

File tree:
- dask/dataframe/io/parquet/arrow.py
- dask/dataframe/io/parquet/fastparquet.py
- dask/dataframe/io/parquet/utils.py
- dask/dataframe/io/tests/test_parquet.py

dask/dataframe/io/parquet/arrow.py
Lines changed: 78 additions & 30 deletions
```diff
@@ -7,8 +7,15 @@
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
+
+try:
+    from pyarrow.parquet import filters_to_expression
+except ImportError:
+    from pyarrow.parquet import _filters_to_expression as filters_to_expression
+
 from packaging.version import parse as parse_version
 
+from dask import config
 from dask.base import tokenize
 from dask.core import flatten
 from dask.dataframe.backends import pyarrow_schema_dispatch
```
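The new try/except import is a compatibility shim: `filters_to_expression` only became a public `pyarrow.parquet` name in newer pyarrow releases, while older versions expose it as `_filters_to_expression`. A minimal sketch of the shim in use (illustrative values):

```python
try:
    from pyarrow.parquet import filters_to_expression  # public in newer pyarrow
except ImportError:
    from pyarrow.parquet import _filters_to_expression as filters_to_expression

# A flat list of tuples is read as a conjunction (AND) in DNF filter notation:
expr = filters_to_expression([("b", "==", "cat"), ("x", ">", 3)])
print(type(expr))  # a pyarrow compute Expression, usable as a dataset filter
```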
```diff
@@ -436,7 +443,9 @@ def read_metadata(
         )
 
         # Stage 2: Generate output `meta`
-        meta = cls._create_dd_meta(dataset_info, use_nullable_dtypes)
+        meta = cls._create_dd_meta(
+            dataset_info, use_nullable_dtypes=use_nullable_dtypes
+        )
 
         # Stage 3: Generate parts and stats
         parts, stats, common_kwargs = cls._construct_collection_plan(dataset_info)
```
```diff
@@ -1091,6 +1100,7 @@ def _collect_dataset_info(
             "metadata_task_size": metadata_task_size,
             "kwargs": {
                 "dataset": _dataset_kwargs,
+                "convert_string": config.get("dataframe.convert_string"),
                 **kwargs,
             },
         }
```
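This line snapshots the global `dataframe.convert_string` option into `dataset_info["kwargs"]` when the graph is built, so later stages don't re-read global state. A short sketch of toggling the option via `dask.config` (the default passed to `get` below is defensive, not required):

```python
import dask

# Opt in for a block of code; the engine captures the value when
# dd.read_parquet constructs the graph, so set it before that call.
with dask.config.set({"dataframe.convert_string": True}):
    print(dask.config.get("dataframe.convert_string", False))  # True
```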
```diff
@@ -1123,11 +1133,13 @@ def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
 
         # Use _arrow_table_to_pandas to generate meta
         arrow_to_pandas = dataset_info["kwargs"].get("arrow_to_pandas", {}).copy()
+        convert_string = dataset_info["kwargs"].get("convert_string", False)
         meta = cls._arrow_table_to_pandas(
             schema.empty_table(),
             categories,
             arrow_to_pandas=arrow_to_pandas,
             use_nullable_dtypes=use_nullable_dtypes,
+            convert_string=convert_string,
         )
         index_names = list(meta.index.names)
         column_names = list(meta.columns)
```
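`_create_dd_meta` derives its empty `meta` frame by pushing `schema.empty_table()` through the same conversion path as real partitions, so dtype decisions (now including `convert_string`) show up in `meta` without reading any data. A standalone sketch of that trick, outside of dask:

```python
import pandas as pd
import pyarrow as pa

schema = pa.schema([("a", pa.string()), ("b", pa.int64())])
# Zero rows, but the dtypes match what a full read would produce:
meta = schema.empty_table().to_pandas(
    types_mapper={pa.string(): pd.StringDtype("pyarrow")}.get
)
print(meta.dtypes)  # a -> string[pyarrow], b -> int64
```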
```diff
@@ -1317,7 +1329,7 @@ def _construct_collection_plan(cls, dataset_info):
         # Get/translate filters
         ds_filters = None
         if filters is not None:
-            ds_filters = pq._filters_to_expression(filters)
+            ds_filters = filters_to_expression(filters)
 
         # Define subset of `dataset_info` required by _collect_file_parts
         dataset_info_kwargs = {
@@ -1666,7 +1678,7 @@ def _read_table(
                 use_threads=False,
                 schema=schema,
                 columns=cols,
-                filter=pq._filters_to_expression(filters) if filters else None,
+                filter=filters_to_expression(filters) if filters else None,
             )
         else:
             arrow_table = _read_table_from_path(
```
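Both call sites now route through the shimmed name instead of the private `pq._filters_to_expression`. For reference, the resulting expression plugs directly into a pyarrow dataset scan; a self-contained sketch with an in-memory dataset and illustrative values:

```python
import pyarrow as pa
import pyarrow.dataset as ds
from pyarrow.parquet import filters_to_expression  # assumes pyarrow >= 10

table = pa.table({"b": ["cat", "dog"], "x": [1, 5]})
dataset = ds.dataset(table)  # in-memory dataset, for illustration only
filtered = dataset.to_table(filter=filters_to_expression([("b", "==", "cat")]))
print(filtered.to_pydict())  # {'b': ['cat'], 'x': [1]}
```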
```diff
@@ -1699,40 +1711,76 @@ def _read_table(
         return arrow_table
 
     @classmethod
-    def _arrow_table_to_pandas(
-        cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
-    ) -> pd.DataFrame:
-        _kwargs = kwargs.get("arrow_to_pandas", {})
-        _kwargs.update({"use_threads": False, "ignore_metadata": False})
-
-        if use_nullable_dtypes:
-            # Determine if `pandas` or `pyarrow`-backed dtypes should be used
-            if use_nullable_dtypes == "pandas":
-                default_types_mapper = PYARROW_NULLABLE_DTYPE_MAPPING.get
+    def _determine_type_mapper(
+        cls, *, use_nullable_dtypes=False, convert_string=False, **kwargs
+    ):
+        user_mapper = kwargs.get("arrow_to_pandas", {}).get("types_mapper")
+        type_mappers = []
+
+        def pyarrow_type_mapper(pyarrow_dtype):
+            # Special case pyarrow strings to use more feature complete dtype
+            # See https://github.com/pandas-dev/pandas/issues/50074
+            if pyarrow_dtype == pa.string():
+                return pd.StringDtype("pyarrow")
             else:
-                # use_nullable_dtypes == "pyarrow"
+                return pd.ArrowDtype(pyarrow_dtype)
 
-                def default_types_mapper(pyarrow_dtype):  # type: ignore
-                    # Special case pyarrow strings to use more feature complete dtype
-                    # See https://github.com/pandas-dev/pandas/issues/50074
-                    if pyarrow_dtype == pa.string():
-                        return pd.StringDtype("pyarrow")
-                    else:
-                        return pd.ArrowDtype(pyarrow_dtype)
+        # always use the user-defined mapper first, if available
+        if user_mapper is not None:
+            type_mappers.append(user_mapper)
 
-        if "types_mapper" in _kwargs:
-            # User-provided entries take priority over default_types_mapper
-            types_mapper = _kwargs["types_mapper"]
+        # next in priority is converting strings
+        if convert_string:
+            type_mappers.append({pa.string(): pd.StringDtype("pyarrow")}.get)
 
-            def _types_mapper(pa_type):
-                return types_mapper(pa_type) or default_types_mapper(pa_type)
+        # and then nullable types
+        if use_nullable_dtypes == "pandas":
+            type_mappers.append(PYARROW_NULLABLE_DTYPE_MAPPING.get)
+        elif use_nullable_dtypes:  # "pyarrow" or True
+            type_mappers.append(pyarrow_type_mapper)
 
-            _kwargs["types_mapper"] = _types_mapper
+        def default_types_mapper(pyarrow_dtype):
+            """Try all type mappers in order, starting from the user type mapper."""
+            for type_converter in type_mappers:
+                converted_type = type_converter(pyarrow_dtype)
+                if converted_type is not None:
+                    return converted_type
 
-        else:
-            _kwargs["types_mapper"] = default_types_mapper
+        if len(type_mappers) > 0:
+            return default_types_mapper
 
-        return arrow_table.to_pandas(categories=categories, **_kwargs)
+    @classmethod
+    def _arrow_table_to_pandas(
+        cls,
+        arrow_table: pa.Table,
+        categories,
+        use_nullable_dtypes=False,
+        convert_string=False,
+        **kwargs,
+    ) -> pd.DataFrame:
+        _kwargs = kwargs.get("arrow_to_pandas", {})
+        _kwargs.update({"use_threads": False, "ignore_metadata": False})
+
+        types_mapper = cls._determine_type_mapper(
+            use_nullable_dtypes=use_nullable_dtypes,
+            convert_string=convert_string,
+            **kwargs,
+        )
+        if types_mapper is not None:
+            _kwargs["types_mapper"] = types_mapper
+
+        res = arrow_table.to_pandas(categories=categories, **_kwargs)
+        # TODO: remove this when fixed in pyarrow: https://github.com/apache/arrow/issues/34283
+        if (
+            convert_string
+            and isinstance(res.index, pd.Index)
+            and not isinstance(res.index, pd.MultiIndex)
+            and pd.api.types.is_string_dtype(res.index.dtype)
+            and res.index.dtype
+            not in (pd.StringDtype("pyarrow"), pd.ArrowDtype(pa.string()))
+        ):
+            res.index = res.index.astype(pd.StringDtype("pyarrow"))
+        return res
 
     @classmethod
     def collect_file_metadata(cls, path, fs, file_path):
```
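The refactor replaces the nested if/else with an ordered chain of type mappers: the user's mapper first, then the `convert_string` string mapper, then the nullable-dtype mapper. The first mapper returning a non-None dtype wins, and None falls through to pyarrow's default conversion. A standalone sketch of that chaining with illustrative mappers:

```python
import pandas as pd
import pyarrow as pa

# Each mapper returns a pandas dtype or None; the first non-None answer wins.
user_mapper = {pa.float32(): pd.Float64Dtype()}.get
string_mapper = {pa.string(): pd.StringDtype("pyarrow")}.get
mappers = [user_mapper, string_mapper]

def chained(pyarrow_dtype):
    for mapper in mappers:
        dtype = mapper(pyarrow_dtype)
        if dtype is not None:
            return dtype
    return None  # None means "use pyarrow's default conversion"

print(chained(pa.float32()))  # Float64 (user mapper takes priority)
print(chained(pa.string()))   # string[pyarrow]
print(chained(pa.int64()))    # None -> default conversion
```

The trailing index `astype` in the new `_arrow_table_to_pandas` exists because `types_mapper` is not applied to the index (apache/arrow#34283, per the TODO in the diff); it can be dropped once that is fixed upstream.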

dask/dataframe/io/parquet/fastparquet.py
Lines changed: 6 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 except ImportError:
     pass
 
+from dask import config
 from dask.base import tokenize
 
 #########################
@@ -891,6 +892,11 @@ def read_metadata(
             raise ValueError(
                 "`use_nullable_dtypes` is not supported by the fastparquet engine"
             )
+        if config.get("dataframe.convert_string", False):
+            warnings.warn(
+                "`dataframe.convert_string` is not supported by the fastparquet engine",
+                category=UserWarning,
+            )
 
         # Stage 1: Collect general dataset information
         dataset_info = cls._collect_dataset_info(
```

dask/dataframe/io/parquet/utils.py
Lines changed: 2 additions & 0 deletions
```diff
@@ -194,6 +194,8 @@ def read_partition(
         use_nullable_dtypes: boolean
             Whether to use pandas nullable dtypes (like "string" or "Int64")
             where appropriate when reading parquet files.
+        convert_string: boolean
+            Whether to use pyarrow strings when reading parquet files.
         **kwargs:
             Includes `"kwargs"` values stored within the `parts` output
             of `engine.read_metadata`. May also include arguments to be
```

dask/dataframe/io/tests/test_parquet.py
Lines changed: 88 additions & 2 deletions
```diff
@@ -3330,11 +3330,20 @@ def clamp_arrow_datetimes(cls, arrow_table: pa.Table) -> pa.Table:
 
     @classmethod
     def _arrow_table_to_pandas(
-        cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
+        cls,
+        arrow_table: pa.Table,
+        categories,
+        use_nullable_dtypes=False,
+        convert_string=False,
+        **kwargs,
     ) -> pd.DataFrame:
         fixed_arrow_table = cls.clamp_arrow_datetimes(arrow_table)
         return super()._arrow_table_to_pandas(
-            fixed_arrow_table, categories, use_nullable_dtypes, **kwargs
+            fixed_arrow_table,
+            categories,
+            use_nullable_dtypes=use_nullable_dtypes,
+            convert_string=convert_string,
+            **kwargs,
         )
 
     # this should not fail, but instead produce timestamps that are in the valid range
```
```diff
@@ -4556,3 +4565,80 @@ def test_select_filtered_column(tmp_path, engine):
     with pytest.warns(UserWarning, match="Sorted columns detected"):
         ddf = dd.read_parquet(path, engine=engine, filters=[("b", "==", "cat")])
     assert_eq(df, ddf)
+
+
+@PYARROW_MARK
+@pytest.mark.parametrize("convert_string", [True, False])
+@pytest.mark.skipif(not PANDAS_GT_150, reason="requires pd.ArrowDtype")
+def test_read_parquet_convert_string(tmp_path, convert_string, engine):
+    df = pd.DataFrame(
+        {"A": ["def", "abc", "ghi"], "B": [5, 2, 3], "C": ["x", "y", "z"]}
+    ).set_index("C")
+
+    outfile = tmp_path / "out.parquet"
+    df.to_parquet(outfile, engine=engine)
+
+    with dask.config.set({"dataframe.convert_string": convert_string}):
+        ddf = dd.read_parquet(outfile, engine="pyarrow")
+
+    if convert_string:
+        expected = df.astype({"A": "string[pyarrow]"})
+        expected.index = expected.index.astype("string[pyarrow]")
+    else:
+        expected = df
+    assert_eq(ddf, expected)
+    assert len(ddf.dask.layers) == 1
+
+
+@PYARROW_MARK
+@pytest.mark.skipif(not PANDAS_GT_150, reason="requires pd.ArrowDtype")
+def test_read_parquet_convert_string_nullable_mapper(tmp_path, engine):
+    """Make sure that when convert_string, use_nullable_dtypes and types_mapper are set,
+    all three are used."""
+    df = pd.DataFrame(
+        {
+            "A": pd.Series(["def", "abc", "ghi"], dtype="string"),
+            "B": pd.Series([5, 2, 3], dtype="Int64"),
+            "C": pd.Series([1.1, 6.3, 8.4], dtype="Float32"),
+            "I": pd.Series(["x", "y", "z"], dtype="string"),
+        }
+    ).set_index("I")
+
+    outfile = tmp_path / "out.parquet"
+    df.to_parquet(outfile, engine=engine)
+
+    types_mapper = {
+        pa.float32(): pd.Float64Dtype(),
+    }
+
+    with dask.config.set({"dataframe.convert_string": True}):
+        ddf = dd.read_parquet(
+            tmp_path,
+            engine="pyarrow",
+            use_nullable_dtypes="pandas",
+            arrow_to_pandas={"types_mapper": types_mapper.get},
+        )
+
+    expected = df.astype(
+        {
+            "A": "string[pyarrow]",  # bc dataframe.convert_string=True
+            "B": pd.Int64Dtype(),  # bc use_nullable_dtypes="pandas"
+            "C": pd.Float64Dtype(),  # bc user mapper
+        }
+    )
+    expected.index = expected.index.astype("string[pyarrow]")
+
+    assert_eq(ddf, expected)
+
+
+@FASTPARQUET_MARK
+def test_read_parquet_convert_string_fastparquet_warns(tmp_path):
+    df = pd.DataFrame({"A": ["def", "abc", "ghi"], "B": [5, 2, 3]})
+    outfile = tmp_path / "out.parquet"
+    df.to_parquet(outfile)
+
+    with dask.config.set({"dataframe.convert_string": True}):
+        with pytest.warns(
+            UserWarning, match="`dataframe.convert_string` is not supported"
+        ):
+            dd.read_parquet(outfile, engine="fastparquet")
```
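Taken together, the tests pin down the user-facing behavior. A minimal end-to-end sketch (assumes pyarrow is installed and the working directory is writable):

```python
import pandas as pd
import dask
import dask.dataframe as dd

pd.DataFrame({"A": ["def", "abc"], "B": [5, 2]}).to_parquet("out.parquet")

with dask.config.set({"dataframe.convert_string": True}):
    ddf = dd.read_parquet("out.parquet", engine="pyarrow")

print(ddf.dtypes)  # A -> string[pyarrow]; B stays int64
```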
