Skip to content

Commit b1e468e

Browse files
authored
Add support for use_nullable_dtypes to dd.read_parquet (#9617)
1 parent f309f9f commit b1e468e

7 files changed

Lines changed: 244 additions & 13 deletions

File tree

dask/dataframe/io/parquet/arrow.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from dask.base import tokenize
1313
from dask.core import flatten
14+
from dask.dataframe._compat import PANDAS_GT_120
1415
from dask.dataframe.backends import pyarrow_schema_dispatch
1516
from dask.dataframe.io.parquet.utils import (
1617
Engine,
@@ -37,6 +38,23 @@
3738
partitioning_supported = _pa_version >= parse_version("5.0.0")
3839
del _pa_version
3940

41+
PYARROW_NULLABLE_DTYPE_MAPPING = {
42+
pa.int8(): pd.Int8Dtype(),
43+
pa.int16(): pd.Int16Dtype(),
44+
pa.int32(): pd.Int32Dtype(),
45+
pa.int64(): pd.Int64Dtype(),
46+
pa.uint8(): pd.UInt8Dtype(),
47+
pa.uint16(): pd.UInt16Dtype(),
48+
pa.uint32(): pd.UInt32Dtype(),
49+
pa.uint64(): pd.UInt64Dtype(),
50+
pa.bool_(): pd.BooleanDtype(),
51+
pa.string(): pd.StringDtype(),
52+
}
53+
54+
if PANDAS_GT_120:
55+
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float32()] = pd.Float32Dtype()
56+
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float64()] = pd.Float64Dtype()
57+
4058
#
4159
# Helper Utilities
4260
#
@@ -321,6 +339,7 @@ def read_metadata(
321339
paths,
322340
categories=None,
323341
index=None,
342+
use_nullable_dtypes=False,
324343
gather_statistics=None,
325344
filters=None,
326345
split_row_groups=False,
@@ -350,7 +369,7 @@ def read_metadata(
350369
)
351370

352371
# Stage 2: Generate output `meta`
353-
meta = cls._create_dd_meta(dataset_info)
372+
meta = cls._create_dd_meta(dataset_info, use_nullable_dtypes)
354373

355374
# Stage 3: Generate parts and stats
356375
parts, stats, common_kwargs = cls._construct_collection_plan(dataset_info)
@@ -375,6 +394,7 @@ def read_partition(
375394
pieces,
376395
columns,
377396
index,
397+
use_nullable_dtypes=False,
378398
categories=(),
379399
partitions=(),
380400
filters=None,
@@ -445,7 +465,9 @@ def read_partition(
445465
arrow_table = pa.concat_tables(tables)
446466

447467
# Convert to pandas
448-
df = cls._arrow_table_to_pandas(arrow_table, categories, **kwargs)
468+
df = cls._arrow_table_to_pandas(
469+
arrow_table, categories, use_nullable_dtypes=use_nullable_dtypes, **kwargs
470+
)
449471

450472
# For pyarrow.dataset api, need to convert partition columns
451473
# to categorical manually for integer types.
@@ -958,7 +980,7 @@ def _collect_dataset_info(
958980
}
959981

960982
@classmethod
961-
def _create_dd_meta(cls, dataset_info):
983+
def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
962984
"""Use parquet schema and hive-partition information
963985
(stored in dataset_info) to construct DataFrame metadata.
964986
"""
@@ -989,6 +1011,7 @@ def _create_dd_meta(cls, dataset_info):
9891011
schema.empty_table(),
9901012
categories,
9911013
arrow_to_pandas=arrow_to_pandas,
1014+
use_nullable_dtypes=use_nullable_dtypes,
9921015
)
9931016
index_names = list(meta.index.names)
9941017
column_names = list(meta.columns)
@@ -1543,11 +1566,26 @@ def _read_table(
15431566

15441567
@classmethod
15451568
def _arrow_table_to_pandas(
1546-
cls, arrow_table: pa.Table, categories, **kwargs
1569+
cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
15471570
) -> pd.DataFrame:
15481571
_kwargs = kwargs.get("arrow_to_pandas", {})
15491572
_kwargs.update({"use_threads": False, "ignore_metadata": False})
15501573

1574+
if use_nullable_dtypes:
1575+
if "types_mapper" in _kwargs:
1576+
# User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
1577+
types_mapper = _kwargs["types_mapper"]
1578+
1579+
def _types_mapper(pa_type):
1580+
return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
1581+
pa_type
1582+
)
1583+
1584+
_kwargs["types_mapper"] = _types_mapper
1585+
1586+
else:
1587+
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get
1588+
15511589
return arrow_table.to_pandas(categories=categories, **_kwargs)
15521590

15531591
@classmethod

dask/dataframe/io/parquet/core.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
meta,
4646
columns,
4747
index,
48+
use_nullable_dtypes,
4849
kwargs,
4950
common_kwargs,
5051
):
@@ -53,6 +54,7 @@ def __init__(
5354
self.meta = meta
5455
self._columns = columns
5556
self.index = index
57+
self.use_nullable_dtypes = use_nullable_dtypes
5658

5759
# `kwargs` = user-defined kwargs to be passed
5860
# identically for all partitions.
@@ -78,6 +80,7 @@ def project_columns(self, columns):
7880
self.meta,
7981
columns,
8082
self.index,
83+
self.use_nullable_dtypes,
8184
None, # Already merged into common_kwargs
8285
self.common_kwargs,
8386
)
@@ -101,6 +104,7 @@ def __call__(self, part):
101104
],
102105
self.columns,
103106
self.index,
107+
self.use_nullable_dtypes,
104108
self.common_kwargs,
105109
)
106110

@@ -181,6 +185,7 @@ def read_parquet(
181185
index=None,
182186
storage_options=None,
183187
engine="auto",
188+
use_nullable_dtypes=False,
184189
calculate_divisions=None,
185190
ignore_metadata_file=False,
186191
metadata_task_size=None,
@@ -433,6 +438,7 @@ def read_parquet(
433438
"index": index,
434439
"storage_options": storage_options,
435440
"engine": engine,
441+
"use_nullable_dtypes": use_nullable_dtypes,
436442
"calculate_divisions": calculate_divisions,
437443
"ignore_metadata_file": ignore_metadata_file,
438444
"metadata_task_size": metadata_task_size,
@@ -475,6 +481,7 @@ def read_parquet(
475481
paths,
476482
categories=categories,
477483
index=index,
484+
use_nullable_dtypes=use_nullable_dtypes,
478485
gather_statistics=calculate_divisions,
479486
filters=filters,
480487
split_row_groups=split_row_groups,
@@ -540,6 +547,7 @@ def read_parquet(
540547
meta,
541548
columns,
542549
index,
550+
use_nullable_dtypes,
543551
{}, # All kwargs should now be in `common_kwargs`
544552
common_kwargs,
545553
)
@@ -578,7 +586,9 @@ def check_multi_support(engine):
578586
return hasattr(engine, "multi_support") and engine.multi_support()
579587

580588

581-
def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
589+
def read_parquet_part(
590+
fs, engine, meta, part, columns, index, use_nullable_dtypes, kwargs
591+
):
582592
"""Read a part of a parquet dataset
583593
584594
This function is used by `read_parquet`."""
@@ -587,22 +597,39 @@ def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
587597
# Part kwargs expected
588598
func = engine.read_partition
589599
dfs = [
590-
func(fs, rg, columns.copy(), index, **toolz.merge(kwargs, kw))
600+
func(
601+
fs,
602+
rg,
603+
columns.copy(),
604+
index,
605+
use_nullable_dtypes=use_nullable_dtypes,
606+
**toolz.merge(kwargs, kw),
607+
)
591608
for (rg, kw) in part
592609
]
593610
df = concat(dfs, axis=0) if len(dfs) > 1 else dfs[0]
594611
else:
595612
# No part specific kwargs, let engine read
596613
# list of parts at once
597614
df = engine.read_partition(
598-
fs, [p[0] for p in part], columns.copy(), index, **kwargs
615+
fs,
616+
[p[0] for p in part],
617+
columns.copy(),
618+
index,
619+
use_nullable_dtypes=use_nullable_dtypes,
620+
**kwargs,
599621
)
600622
else:
601623
# NOTE: `kwargs` are the same for all parts, while `part_kwargs` may
602624
# be different for each part.
603625
rg, part_kwargs = part
604626
df = engine.read_partition(
605-
fs, rg, columns, index, **toolz.merge(kwargs, part_kwargs)
627+
fs,
628+
rg,
629+
columns,
630+
index,
631+
use_nullable_dtypes=use_nullable_dtypes,
632+
**toolz.merge(kwargs, part_kwargs),
606633
)
607634

608635
if meta.columns.name:

dask/dataframe/io/parquet/fastparquet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ def read_metadata(
821821
paths,
822822
categories=None,
823823
index=None,
824+
use_nullable_dtypes=False,
824825
gather_statistics=None,
825826
filters=None,
826827
split_row_groups=False,
@@ -831,6 +832,10 @@ def read_metadata(
831832
parquet_file_extension=None,
832833
**kwargs,
833834
):
835+
if use_nullable_dtypes:
836+
raise ValueError(
837+
"`use_nullable_dtypes` is not supported by the fastparquet engine"
838+
)
834839

835840
# Stage 1: Collect general dataset information
836841
dataset_info = cls._collect_dataset_info(
@@ -890,6 +895,7 @@ def read_partition(
890895
pieces,
891896
columns,
892897
index,
898+
use_nullable_dtypes=False,
893899
categories=(),
894900
root_cats=None,
895901
root_file_scheme=None,

dask/dataframe/io/parquet/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def read_metadata(
1717
paths,
1818
categories=None,
1919
index=None,
20+
use_nullable_dtypes=False,
2021
gather_statistics=None,
2122
filters=None,
2223
**kwargs,
@@ -37,6 +38,9 @@ def read_metadata(
3738
The column name(s) to be used as the index.
3839
If set to ``None``, pandas metadata (if available) can be used
3940
to reset the value in this function
41+
use_nullable_dtypes: boolean
42+
Whether to use pandas nullable dtypes (like "string" or "Int64")
43+
where appropriate when reading parquet files.
4044
gather_statistics: bool
4145
Whether or not to gather statistics to calculate divisions
4246
for the output DataFrame collection.
@@ -73,7 +77,9 @@ def read_metadata(
7377
raise NotImplementedError()
7478

7579
@classmethod
76-
def read_partition(cls, fs, piece, columns, index, **kwargs):
80+
def read_partition(
81+
cls, fs, piece, columns, index, use_nullable_dtypes=False, **kwargs
82+
):
7783
"""Read a single piece of a Parquet dataset into a Pandas DataFrame
7884
7985
This function is called many times in individual tasks
@@ -88,6 +94,9 @@ def read_partition(cls, fs, piece, columns, index, **kwargs):
8894
List of column names to pull out of that row group
8995
index: str, List[str], or False
9096
The index name(s).
97+
use_nullable_dtypes: boolean
98+
Whether to use pandas nullable dtypes (like "string" or "Int64")
99+
where appropriate when reading parquet files.
91100
**kwargs:
92101
Includes `"kwargs"` values stored within the `parts` output
93102
of `engine.read_metadata`. May also include arguments to be

0 commit comments

Comments
 (0)