|
5 | 5 | import pathlib |
6 | 6 | import re |
7 | 7 | import xml.etree.ElementTree |
| 8 | +from typing import Literal |
8 | 9 |
|
9 | 10 | import pytest |
10 | 11 |
|
|
31 | 32 | from dask.array.core import ( |
32 | 33 | Array, |
33 | 34 | BlockView, |
| 35 | + PerformanceWarning, |
34 | 36 | blockdims_from_blockshape, |
35 | 37 | broadcast_chunks, |
36 | 38 | broadcast_shapes, |
@@ -5122,44 +5124,93 @@ def test_from_array_respects_zarr_shards(): |
5122 | 5124 | assert all(c % s == 0 for c, s in zip(dz.chunksize, z.shards)) |
5123 | 5125 |
|
5124 | 5126 |
|
5125 | | -def test_zarr_chunk_shards_mismatch_warns(): |
@pytest.mark.parametrize("region_spec", [None, "all", "half"])
def test_zarr_to_zarr_shards(region_spec: None | Literal["all", "half"]):
    """
    Check that ``to_zarr`` hands back an array whose chunk size is a multiple
    of the target zarr array's shard shape, and that the written data is correct.

    The test is parametrized over ``region_spec`` because the rechunking logic
    in ``to_zarr`` contains a branch that depends on whether a ``region``
    argument was supplied:

    * ``None``   -- no ``region`` is passed; the whole zarr array is written.
    * ``"all"``  -- ``region`` is given explicitly and spans the whole array.
    * ``"half"`` -- ``region`` spans the first half of the array, and the
      source dask array is cropped to match.
    """
    zarr = pytest.importorskip("zarr", minversion="3.0.0")

    n_items = 100
    shape = (n_items,)
    dask_chunks = (10,)
    # In Zarr v3, ``chunks`` is the inner chunk shape and ``shards`` the shard shape.
    zarr_chunk_shape = (1,)
    zarr_shard_shape = (2,)

    # NOTE(review): dask_chunks=(10,) is already a multiple of
    # zarr_shard_shape=(2,), so the divisibility assertion below would also
    # hold if no rechunking took place -- confirm this is the intended setup.
    arr = da.arange(n_items, chunks=dask_chunks)

    # ``region`` is what gets passed to to_zarr; ``sel`` is the part of the
    # zarr array that is expected to be written (used for the readback check).
    region: tuple[slice, ...] | None
    sel: tuple[slice, ...]

    if region_spec is None or region_spec == "all":
        sel = (slice(None),)
        region = None if region_spec is None else sel
    else:
        sel = (slice(n_items // 2),)
        region = sel
        # The source must have the same extent as the region it is written to.
        arr = arr[sel]

    # Sharded zarr array backed by an in-memory store.
    z = zarr.create_array(
        store={},
        shape=shape,
        chunks=zarr_chunk_shape,
        shards=zarr_shard_shape,
        dtype=arr.dtype,
    )

    # to_zarr should rechunk to a multiple of the shard shape when needed.
    result = arr.to_zarr(z, region=region, compute=False)

    # The resulting chunk size must be a whole multiple of the shard shape.
    shard_aligned = [
        chunk % shard == 0 for chunk, shard in zip(result.chunksize, zarr_shard_shape)
    ]
    assert all(shard_aligned)

    # Trigger the write, then compare the written region against the source.
    result.compute()
    assert_eq(z[sel], arr.compute())
| 5184 | + |
| 5185 | + |
def test_zarr_risky_shards_warns():
    """
    Test that ``to_zarr`` emits a ``PerformanceWarning`` when the dask array is
    rechunked for a sharded zarr array under a very small chunk-size limit.

    The dask chunks (10) are not a multiple of the shard shape (6), and
    ``array.chunk-size`` is set to 1 for the duration of the call.
    NOTE(review): presumably that limit keeps dask from choosing a chunk size
    that is safe for sharded writes, which is the risk the warning reports --
    confirm against the ``to_zarr`` implementation.
    """
    zarr = pytest.importorskip("zarr", minversion="3.0.0")

    shape = (100,)
    dask_chunks = (10,)
    zarr_chunk_shape = (3,)
    zarr_shard_shape = (6,)

    arr = da.arange(shape[0], chunks=dask_chunks)

    # Sharded zarr array backed by an in-memory store.
    z = zarr.create_array(
        store={},
        shape=shape,
        chunks=zarr_chunk_shape,
        shards=zarr_shard_shape,
        dtype=arr.dtype,
    )

    with dask.config.set({"array.chunk-size": 1}):
        # Use pytest.warns rather than pytest.raises: a warning only becomes an
        # exception when a warnings filter escalates it, so pytest.raises would
        # make this test depend on the project's global filterwarnings setting.
        with pytest.warns(
            PerformanceWarning,
            match="The input Dask array will be rechunked along axis",
        ):
            arr.to_zarr(z)
5163 | 5214 |
|
5164 | 5215 |
|
5165 | 5216 | def test_zarr_nocompute(): |
|
0 commit comments