Skip to content

Commit 945435b — "Remove statistics-based set_index logic from read_parquet (#9661)"
(1 parent: b1e468e)

3 files changed

Lines changed: 63 additions & 68 deletions

File tree

dask/dataframe/io/parquet/arrow.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
995995
columns = None
996996

997997
# Use pandas metadata to update categories
998-
pandas_metadata = _get_pandas_metadata(schema)
998+
pandas_metadata = _get_pandas_metadata(schema) or {}
999999
if pandas_metadata:
10001000
if categories is None:
10011001
categories = []
@@ -1021,7 +1021,15 @@ def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
10211021

10221022
# Use index specified in the pandas metadata if
10231023
# the index column was not specified by the user
1024-
if index is None and index_names:
1024+
if (
1025+
index is None
1026+
and index_names
1027+
and (
1028+
# Only set to `[None]` if pandas metadata includes an index
1029+
index_names != [None]
1030+
or pandas_metadata.get("index_columns", None)
1031+
)
1032+
):
10251033
index = index_names
10261034

10271035
# Set proper index for meta

dask/dataframe/io/parquet/core.py

Lines changed: 42 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def read_parquet(
509509
aggregation_depth = parts[0].pop("aggregation_depth", aggregation_depth)
510510

511511
# Parse dataset statistics from metadata (if available)
512-
parts, divisions, index, index_in_columns = process_statistics(
512+
parts, divisions, index = process_statistics(
513513
parts,
514514
statistics,
515515
filters,
@@ -522,18 +522,10 @@ def read_parquet(
522522

523523
# Account for index and columns arguments.
524524
# Modify `meta` dataframe accordingly
525-
meta, index, columns = set_index_columns(
526-
meta, index, columns, index_in_columns, auto_index_allowed
527-
)
525+
meta, index, columns = set_index_columns(meta, index, columns, auto_index_allowed)
528526
if meta.index.name == NONE_LABEL:
529527
meta.index.name = None
530528

531-
# Set the index that was previously treated as a column
532-
if index_in_columns:
533-
meta = meta.set_index(index)
534-
if meta.index.name == NONE_LABEL:
535-
meta.index.name = None
536-
537529
if len(divisions) < 2:
538530
# empty dataframe - just use meta
539531
divisions = (None, None)
@@ -1209,11 +1201,13 @@ def get_engine(engine):
12091201
#####################
12101202

12111203

1212-
def sorted_columns(statistics):
1204+
def sorted_columns(statistics, columns=None):
12131205
"""Find sorted columns given row-group statistics
12141206
1215-
This finds all columns that are sorted, along with appropriate divisions
1216-
values for those columns
1207+
This finds all columns that are sorted, along with the
1208+
appropriate ``divisions`` for those columns. If the (optional)
1209+
``columns`` argument is used, the search will be restricted
1210+
to the specified column set.
12171211
12181212
Returns
12191213
-------
@@ -1224,6 +1218,8 @@ def sorted_columns(statistics):
12241218

12251219
out = []
12261220
for i, c in enumerate(statistics[0]["columns"]):
1221+
if columns and c["name"] not in columns:
1222+
continue
12271223
if not all(
12281224
"min" in s["columns"][i] and "max" in s["columns"][i] for s in statistics
12291225
):
@@ -1347,7 +1343,6 @@ def process_statistics(
13471343
"""Process row-group column statistics in metadata
13481344
Used in read_parquet.
13491345
"""
1350-
index_in_columns = False
13511346
if statistics and len(parts) != len(statistics):
13521347
# It is up to the Engine to guarantee that these
13531348
# lists are the same length (if statistics are defined).
@@ -1362,6 +1357,7 @@ def process_statistics(
13621357
)
13631358
statistics = []
13641359

1360+
divisions = None
13651361
if statistics:
13661362
result = list(
13671363
zip(
@@ -1382,52 +1378,43 @@ def process_statistics(
13821378
parts, statistics, chunksize, split_row_groups, fs, aggregation_depth
13831379
)
13841380

1385-
out = sorted_columns(statistics)
1381+
# Convert str index to list
1382+
index = [index] if isinstance(index, str) else index
13861383

1387-
if index and isinstance(index, str):
1388-
index = [index]
1389-
if index and out:
1390-
# Only one valid column
1391-
out = [o for o in out if o["name"] in index]
1392-
if index is not False and len(out) == 1:
1393-
# Use only sorted column with statistics as the index
1394-
divisions = out[0]["divisions"]
1395-
if index is None:
1396-
index_in_columns = True
1397-
index = [out[0]["name"]]
1398-
elif index != [out[0]["name"]]:
1399-
raise ValueError(f"Specified index is invalid.\nindex: {index}")
1400-
elif index is not False and len(out) > 1:
1401-
if any(o["name"] == NONE_LABEL for o in out):
1402-
# Use sorted column matching NONE_LABEL as the index
1403-
[o] = [o for o in out if o["name"] == NONE_LABEL]
1404-
divisions = o["divisions"]
1405-
if index is None:
1406-
index = [o["name"]]
1407-
index_in_columns = True
1408-
elif index != [o["name"]]:
1409-
raise ValueError(f"Specified index is invalid.\nindex: {index}")
1410-
else:
1411-
# Multiple sorted columns found, cannot autodetect the index
1384+
# TODO: Remove `filters` criteria below after deprecation cycle.
1385+
# We can then remove the `sorted_col_names` logic and warning.
1386+
# See: https://github.com/dask/dask/pull/9661
1387+
process_columns = index if index and len(index) == 1 else None
1388+
if filters:
1389+
process_columns = None
1390+
1391+
# Use statistics to define divisions
1392+
if process_columns or filters:
1393+
sorted_col_names = []
1394+
for sorted_column_info in sorted_columns(
1395+
statistics, columns=process_columns
1396+
):
1397+
if index and sorted_column_info["name"] in index:
1398+
divisions = sorted_column_info["divisions"]
1399+
break
1400+
else:
1401+
# Filtered columns may also be sorted
1402+
sorted_col_names.append(sorted_column_info["name"])
1403+
1404+
if index is None and sorted_col_names:
1405+
assert bool(filters) # Should only get here when filtering
14121406
warnings.warn(
1413-
"Multiple sorted columns found %s, cannot\n "
1414-
"autodetect index. Will continue without an index.\n"
1415-
"To pick an index column, use the index= keyword; to \n"
1416-
"silence this warning use index=False."
1417-
"" % [o["name"] for o in out],
1418-
RuntimeWarning,
1407+
f"Sorted columns detected: {sorted_col_names}\n"
1408+
f"Use the `index` argument to set a sorted column as your "
1409+
f"index to create a DataFrame collection with known `divisions`.",
1410+
UserWarning,
14191411
)
1420-
index = False
1421-
divisions = [None] * (len(parts) + 1)
1422-
else:
1423-
divisions = [None] * (len(parts) + 1)
1424-
else:
1425-
divisions = [None] * (len(parts) + 1)
14261412

1427-
return parts, divisions, index, index_in_columns
1413+
divisions = divisions or (None,) * (len(parts) + 1)
1414+
return parts, divisions, index
14281415

14291416

1430-
def set_index_columns(meta, index, columns, index_in_columns, auto_index_allowed):
1417+
def set_index_columns(meta, index, columns, auto_index_allowed):
14311418
"""Handle index/column arguments, and modify `meta`
14321419
Used in read_parquet.
14331420
"""
@@ -1471,18 +1458,7 @@ def set_index_columns(meta, index, columns, index_in_columns, auto_index_allowed
14711458
"index: {} | column: {}".format(index, columns)
14721459
)
14731460

1474-
# Leaving index as a column in `meta`, because the index
1475-
# will be reset below (in case the index was detected after
1476-
# meta was created)
1477-
if index_in_columns:
1478-
meta = meta[columns + index]
1479-
else:
1480-
meta = meta[columns]
1481-
1482-
else:
1483-
meta = meta[list(columns)]
1484-
1485-
return meta, index, columns
1461+
return meta[list(columns)], index, columns
14861462

14871463

14881464
def aggregate_row_groups(

dask/dataframe/io/tests/test_parquet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4401,3 +4401,14 @@ def test_retries_on_remote_filesystem(tmpdir):
44014401
layer = hlg_layer(ddf2.dask, "read-parquet")
44024402
assert layer.annotations
44034403
assert layer.annotations["retries"] == 2
4404+
4405+
4406+
def test_select_filtered_column(tmp_path, engine):
4407+
4408+
df = pd.DataFrame({"a": range(10), "b": ["cat"] * 10})
4409+
path = tmp_path / "test_select_filtered_column.parquet"
4410+
df.to_parquet(path, index=False)
4411+
4412+
with pytest.warns(UserWarning, match="Sorted columns detected"):
4413+
ddf = dd.read_parquet(path, engine=engine, filters=[("b", "==", "cat")])
4414+
assert_eq(df, ddf)

0 commit comments

Comments (0)