Skip to content

Commit 9915a79

Browse files
d-v-b and dcherian authored
use integer multiple of shard shape when rechunking in to_zarr (#12106)
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent 06e75c7 commit 9915a79

2 files changed

Lines changed: 97 additions & 21 deletions

File tree

dask/array/core.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3889,22 +3889,47 @@ def to_zarr(
38893889
"Cannot store into in memory Zarr Array using "
38903890
"the distributed scheduler."
38913891
)
3892+
zarr_write_chunks = _get_zarr_write_chunks(z)
3893+
dask_write_chunks = normalize_chunks(
3894+
chunks="auto",
3895+
shape=z.shape,
3896+
dtype=z.dtype,
3897+
previous_chunks=zarr_write_chunks,
3898+
)
38923899

3900+
for ax, (dw, zw) in enumerate(
3901+
zip(dask_write_chunks, zarr_write_chunks, strict=True)
3902+
):
3903+
if len(dw) >= 1:
3904+
nominal_dask_chunk_size = dw[0]
3905+
if not nominal_dask_chunk_size % zw == 0:
3906+
safe_chunk_size = np.prod(zarr_write_chunks) * max(
3907+
1, z.dtype.itemsize
3908+
)
3909+
msg = (
3910+
f"The input Dask array will be rechunked along axis {ax} with chunk size "
3911+
f"{nominal_dask_chunk_size}, but a chunk size divisible by {zw} is "
3912+
f"required for Dask to write safely to the Zarr array {z}. "
3913+
"To avoid risk of data loss when writing to this Zarr array, set the "
3914+
'"array.chunk-size" configuration parameter to at least the size in'
3915+
" bytes of a single on-disk "
3916+
f"chunk (or shard) of the Zarr array, which in this case is "
3917+
f"{safe_chunk_size} bytes. "
3918+
f'E.g., dask.config.set({{"array.chunk-size": {safe_chunk_size}}})'
3919+
)
3920+
raise PerformanceWarning(msg)
3921+
break
38933922
if region is None:
38943923
# Rechunk to the write-safe chunks computed above (based on the shard shape if sharding, else the chunk shape)
3895-
write_chunks = _get_zarr_write_chunks(z)
3896-
arr = arr.rechunk(write_chunks)
3924+
arr = arr.rechunk(dask_write_chunks)
38973925
regions = None
38983926
else:
38993927
from dask.array.slicing import new_blockdim, normalize_index
39003928

3901-
# For regions, use the appropriate write granularity
3902-
write_chunks = _get_zarr_write_chunks(z)
3903-
old_chunks = normalize_chunks(write_chunks, z.shape)
39043929
index = normalize_index(region, z.shape)
39053930
chunks = tuple(
39063931
tuple(new_blockdim(s, c, r))
3907-
for s, c, r in zip(z.shape, old_chunks, index)
3932+
for s, c, r in zip(z.shape, dask_write_chunks, index)
39083933
)
39093934
arr = arr.rechunk(chunks)
39103935
regions = [region]

dask/array/tests/test_array_core.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pathlib
66
import re
77
import xml.etree.ElementTree
8+
from typing import Literal
89

910
import pytest
1011

@@ -31,6 +32,7 @@
3132
from dask.array.core import (
3233
Array,
3334
BlockView,
35+
PerformanceWarning,
3436
blockdims_from_blockshape,
3537
broadcast_chunks,
3638
broadcast_shapes,
@@ -5122,44 +5124,93 @@ def test_from_array_respects_zarr_shards():
51225124
assert all(c % s == 0 for c, s in zip(dz.chunksize, z.shards))
51235125

51245126

5125-
def test_zarr_chunk_shards_mismatch_warns():
5127+
@pytest.mark.parametrize("region_spec", [None, "all", "half"])
5128+
def test_zarr_to_zarr_shards(region_spec: None | Literal["all", "half"]):
51265129
"""
51275130
Test that calling to_zarr with a dask array with chunks that do not match the
5128-
shard shape of the zarr array automatically rechunks to the shard shape to ensure
5129-
safe writes.
5131+
shard shape of the zarr array automatically rechunks to a multiple of the
5132+
shard shape to ensure safe writes.
5133+
5134+
This test is parametrized over different regions, because the rechunking logic in
5135+
to_zarr contains a branch depending on whether a region parameter was specified.
51305136
"""
51315137
zarr = pytest.importorskip("zarr", minversion="3.0.0")
5132-
import numpy as np
51335138

5134-
shape = (24,)
5135-
dask_chunks = (10,) # Not aligned with shard boundaries
5136-
zarr_chunk_shape = (4,) # Inner chunk shape
5137-
zarr_shard_shape = (12,) # Shard contains 3 chunks of size 4
5139+
shape = (100,)
5140+
dask_chunks = (10,)
5141+
zarr_chunk_shape = (1,)
5142+
zarr_shard_shape = (2,)
51385143

51395144
# Create a dask array with chunks that don't align with shards
51405145
arr = da.arange(shape[0], chunks=dask_chunks)
51415146

5147+
# the region parameter we will pass into to_zarr
5148+
region: tuple[slice, ...] | None
5149+
5150+
# The region of the zarr array we will write into
5151+
sel: tuple[slice, ...]
5152+
5153+
if region_spec is None:
5154+
sel = (slice(None),)
5155+
region = None
5156+
elif region_spec == "all":
5157+
sel = (slice(None),)
5158+
region = sel
5159+
else:
5160+
sel = (slice(shape[0] // 2),)
5161+
region = sel
5162+
# crop the source data
5163+
arr = arr[sel]
5164+
51425165
# Create a sharded zarr array
51435166
# In Zarr v3: chunks = inner chunk shape, shards = shard shape
51445167
z = zarr.create_array(
5145-
store={}, # Use in-memory store
5168+
store={},
51465169
shape=shape,
51475170
chunks=zarr_chunk_shape,
51485171
shards=zarr_shard_shape,
51495172
dtype=arr.dtype,
51505173
)
51515174

5152-
# to_zarr should automatically rechunk to shard boundaries
5153-
result = arr.to_zarr(z, compute=False)
5175+
# to_zarr should automatically rechunk to a multiple of the shard shape
5176+
result = arr.to_zarr(z, region=region, compute=False)
51545177

51555178
# Verify the array was rechunked to a multiple of the shard shape
5156-
assert result.chunks == (
5157-
(zarr_shard_shape[0], zarr_shard_shape[0]),
5158-
), f"Expected chunks {((zarr_shard_shape[0], zarr_shard_shape[0]),)}, got {result.chunks}"
5179+
assert all(c % s == 0 for c, s in zip(result.chunksize, zarr_shard_shape))
51595180

51605181
# Verify data correctness
51615182
result.compute()
5162-
assert_eq(z[:], np.arange(shape[0]))
5183+
assert_eq(z[sel], arr.compute())
5184+
5185+
5186+
def test_zarr_risky_shards_warns():
5187+
"""
5188+
Test that we see a performance warning when dask chooses a chunk size that will cause data loss
5189+
for zarr arrays.
5190+
"""
5191+
zarr = pytest.importorskip("zarr", minversion="3.0.0")
5192+
5193+
shape = (100,)
5194+
dask_chunks = (10,)
5195+
zarr_chunk_shape = (3,)
5196+
zarr_shard_shape = (6,)
5197+
5198+
arr = da.arange(shape[0], chunks=dask_chunks)
5199+
5200+
z = zarr.create_array(
5201+
store={},
5202+
shape=shape,
5203+
chunks=zarr_chunk_shape,
5204+
shards=zarr_shard_shape,
5205+
dtype=arr.dtype,
5206+
)
5207+
5208+
with dask.config.set({"array.chunk-size": 1}):
5209+
with pytest.raises(
5210+
PerformanceWarning,
5211+
match="The input Dask array will be rechunked along axis",
5212+
):
5213+
arr.to_zarr(z)
51635214

51645215

51655216
def test_zarr_nocompute():

0 commit comments

Comments
 (0)