Skip to content

Commit 6369cdb

Browse files
Add arrow schema extraction dispatch (#9169)
1 parent 9bf7dd7 commit 6369cdb

4 files changed

Lines changed: 34 additions & 1 deletion

File tree

dask/dataframe/backends.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
make_meta_dispatch,
3131
make_meta_obj,
3232
meta_nonempty,
33+
pyarrow_schema_dispatch,
3334
tolist_dispatch,
3435
union_categoricals_dispatch,
3536
)
@@ -81,6 +82,13 @@ def _(x, index=None):
8182
pass
8283

8384

85+
@pyarrow_schema_dispatch.register((pd.DataFrame,))
86+
def get_pyarrow_schema_pandas(obj):
87+
import pyarrow as pa
88+
89+
return pa.Schema.from_pandas(obj)
90+
91+
8492
@meta_nonempty.register(pd.DatetimeTZDtype)
8593
@make_meta_dispatch.register(pd.DatetimeTZDtype)
8694
def make_meta_pandas_datetime_tz(x, index=None):

dask/dataframe/dispatch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
is_categorical_dtype_dispatch = Dispatch("is_categorical_dtype")
2323
union_categoricals_dispatch = Dispatch("union_categoricals")
2424
grouper_dispatch = Dispatch("grouper")
25+
pyarrow_schema_dispatch = Dispatch("pyarrow_schema_dispatch")
2526

2627

2728
def concat(

dask/dataframe/io/parquet/arrow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from dask.base import tokenize
1313
from dask.core import flatten
14+
from dask.dataframe.backends import pyarrow_schema_dispatch
1415
from dask.dataframe.io.parquet.utils import (
1516
Engine,
1617
_get_aggregation_depth,
@@ -512,7 +513,7 @@ def initialize_write(
512513
):
513514
if schema == "infer" or isinstance(schema, dict):
514515
# Start with schema from _meta_nonempty
515-
inferred_schema = pa.Schema.from_pandas(
516+
inferred_schema = pyarrow_schema_dispatch(
516517
df._meta_nonempty.set_index(index_cols)
517518
if index_cols
518519
else df._meta_nonempty

dask/dataframe/io/tests/test_parquet.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4177,3 +4177,26 @@ def test_deprecate_gather_statistics(tmp_path, engine):
41774177
gather_statistics=True,
41784178
)
41794179
assert_eq(out, df)
4180+
4181+
4182+
@pytest.mark.gpu
4183+
def test_gpu_write_parquet_simple(tmpdir):
4184+
fn = str(tmpdir)
4185+
cudf = pytest.importorskip("cudf")
4186+
dask_cudf = pytest.importorskip("dask_cudf")
4187+
from dask.dataframe.dispatch import pyarrow_schema_dispatch
4188+
4189+
@pyarrow_schema_dispatch.register((cudf.DataFrame,))
4190+
def get_pyarrow_schema_cudf(obj):
4191+
return obj.to_arrow().schema
4192+
4193+
df = cudf.DataFrame(
4194+
{
4195+
"a": ["abc", "def"],
4196+
"b": ["a", "z"],
4197+
}
4198+
)
4199+
ddf = dask_cudf.from_cudf(df, 3)
4200+
ddf.to_parquet(fn)
4201+
got = dask_cudf.read_parquet(fn)
4202+
assert_eq(df, got)

0 commit comments

Comments
 (0)