Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 230 additions & 81 deletions dask/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,6 @@ class Blockwise(Layer):
single input to the block function
new_axes: Dict
New index dimensions that may have been created, and their extent
io_subgraph: Tuple[str, Dict or Mapping]
If the blockwise operation corresponds to the generation of a new
collection (i.e. `indices` includes keys for a collection that has
yet to be constructed), this argument must be used to specify the
`(key-name, subgraph)` pair for the to-be-created collection. Note
that `key-name` must match the name used in both `indices` and in
the keys of `subgraph`.

NOTE: The `subgraph` must comprise of exactly N keys (where N is the
number of chunks/partitions in the new collection), and each value
must correspond to a "callable" task. The first element (a callable
function) must be the same for all N tasks. This "uniformity" is
required for the abstract `SubgraphCallable` representation used
within Blockwise.
output_blocks: Set[Tuple]
Specify a specific set of required output blocks. Since the graph
will only contain the necessary tasks to generate these outputs,
Expand All @@ -207,6 +193,7 @@ class Blockwise(Layer):

See Also
--------
dask.blockwise.BlockwiseIO
dask.blockwise.blockwise
dask.array.blockwise
"""
Expand All @@ -220,40 +207,13 @@ def __init__(
numblocks,
concatenate=None,
new_axes=None,
io_subgraph=None,
output_blocks=None,
annotations=None,
):
super().__init__(annotations=annotations)
self.output = output
self.output_indices = tuple(output_indices)
self.io_subgraph = io_subgraph[1] if io_subgraph else None
self.io_name = io_subgraph[0] if io_subgraph else None
self.output_blocks = output_blocks
if not dsk:
# If there is no `dsk` input, there must be an IO subgraph.
if io_subgraph is None:
raise ValueError("io_subgraph required if dsk is not supplied.")

# Extract actual IO function for SubgraphCallable construction.
# Wrap func in `PackedFunctionCall`, since it will receive
# all arguments as a sigle (packed) tuple at run time.
if self.io_subgraph:
# We assume a 1-to-1 mapping between keys (i.e. tasks) and
# chunks/partitions in `io_subgraph`, and assume the first
# (callable) element is the same for all tasks.
any_key = next(iter(self.io_subgraph))
io_func = self.io_subgraph.get(any_key)[0]
else:
io_func = None
ninds = 1 if isinstance(output_indices, str) else len(output_indices)
dsk = {
output: (
PackedFunctionCall(io_func),
*[blockwise_token(i) for i in range(ninds)],
)
}

self.dsk = dsk
self.indices = tuple(
(name, tuple(ind) if ind is not None else ind) for name, ind in indices
Expand All @@ -271,6 +231,12 @@ def __init__(
self.concatenate = concatenate
self.new_axes = new_axes or {}

# No IO subgraph allowed in `Blockwise`.
# Use `BlockwiseIO` to include IO. Note that the `io_name`
# attribute is only included in `Blockwise` to help
# simplify `_optimize_blockwise` logic.
self.io_name = None

@property
def dims(self):
"""Returns a dictionary mapping between each index specified in
Expand Down Expand Up @@ -304,18 +270,6 @@ def _dict(self):
dims=self.dims,
)

if self.io_subgraph:
# This is an IO layer.
for k in dsk:
io_key = (self.io_name,) + tuple([k[i] for i in range(1, len(k))])
if io_key in dsk[k]:
# Inject IO-function arguments into the blockwise graph
# as a single (packed) tuple.
io_item = self.io_subgraph.get(io_key)
io_item = list(io_item[1:]) if len(io_item) > 1 else []
new_task = [io_item if v == io_key else v for v in dsk[k]]
dsk[k] = tuple(new_task)

self._cached_dict = {"dsk": dsk}
return self._cached_dict["dsk"]

Expand Down Expand Up @@ -395,9 +349,6 @@ def __dask_distributed_pack__(self, client):
"numblocks": self.numblocks,
"concatenate": self.concatenate,
"new_axes": self.new_axes,
"io_subgraph": (self.io_name, self.io_subgraph)
if self.io_name
else (None, None),
"annotations": self.pack_annotations(),
"output_blocks": self.output_blocks,
"dims": self.dims,
Expand All @@ -421,21 +372,8 @@ def __dask_distributed_unpack__(cls, state, dsk, dependencies, annotations):
deserializing=True,
func_future_args=state["func_future_args"],
)
io_name, io_subgraph = state["io_subgraph"]
global_dependencies = list(state["global_dependencies"])

if io_subgraph:
# This is an IO layer.
for k in raw:
io_key = (io_name,) + tuple([k[i] for i in range(1, len(k))])
if io_key in raw[k]:
# Inject IO-function arguments into the blockwise graph
# as a single (packed) tuple.
io_item = io_subgraph.get(io_key)
io_item = list(io_item[1:]) if len(io_item) > 1 else []
new_task = [io_item if v == io_key else v for v in raw[k]]
raw[k] = tuple(new_task)

if state["annotations"]:
annotations.update(cls.expand_annotations(state["annotations"], raw.keys()))

Expand Down Expand Up @@ -502,8 +440,8 @@ def _cull(self, output_blocks):
self.numblocks,
concatenate=self.concatenate,
new_axes=self.new_axes,
io_subgraph=(self.io_name, self.io_subgraph) if self.io_name else None,
output_blocks=output_blocks,
annotations=self.annotations,
)

def cull(self, keys, all_hlg_keys):
Expand All @@ -525,6 +463,203 @@ def cull(self, keys, all_hlg_keys):
return self, culled_deps


class BlockwiseIO(Blockwise):
    """Blockwise layer with IO

    This is a specialized Blockwise layer containing the IO "subgraph"
    required to construct a brand new collection.

    Parameters
    ----------
    io_name: str
        The name of the new collection to be created within this layer.
        Note that this must match the name used in both `indices` and in
        the keys of the `io_subgraph` input.
    io_subgraph: Dict or Mapping
        The graph needed to construct a new collection within this layer.
        The subgraph must comprise exactly N keys (where N is the
        number of chunks/partitions in the new collection), and each value
        must correspond to a "callable" task. The first element (a callable
        function) must be the same for all N tasks. This "uniformity" is
        required for the abstract `SubgraphCallable` representation used
        within Blockwise.
    output: str
        The name of the output collection. Used in keynames.
    output_indices: tuple
        The output indices, like ``('i', 'j', 'k')`` used to determine the
        structure of the block computations.
    dsk: dict
        A small graph to apply per-output-block. May include keys from the
        input indices. If empty or ``None``, this is treated as a "pure"
        IO layer and a minimal graph is generated from ``io_subgraph``.
    indices: Tuple[str, Tuple[str, str]]
        An ordered mapping from input key name, like ``'x'``
        to input indices, like ``('i', 'j')``.
        Or includes literals, which have ``None`` for an index value.
    numblocks: Dict[key, Sequence[int]]
        Number of blocks along each dimension for each input.
    concatenate: boolean
        Whether or not to pass contracted dimensions as a list of inputs or a
        single input to the block function.
    new_axes: Dict
        New index dimensions that may have been created, and their extent.
    output_blocks: Set[Tuple]
        Specify a specific set of required output blocks. Since the graph
        will only contain the necessary tasks to generate these outputs,
        this kwarg can be used to "cull" the abstract layer (without needing
        to materialize the low-level graph).
    annotations: dict (optional)
        Layer annotations.

    See Also
    --------
    dask.blockwise.Blockwise
    """

    def __init__(
        self,
        io_name,
        io_subgraph,
        output,
        output_indices,
        dsk,
        indices,
        numblocks,
        concatenate=None,
        new_axes=None,
        output_blocks=None,
        annotations=None,
    ):
        # Handle the case that this is a "pure" IO layer (dsk is None).
        # Note that `dsk` should only be defined after fusion.
        if not dsk:
            # Extract actual IO function for SubgraphCallable construction.
            # Wrap func in `PackedFunctionCall`, since it will receive
            # all arguments as a single (packed) tuple at run time.
            if io_subgraph:
                # We assume a 1-to-1 mapping between keys (i.e. tasks) and
                # chunks/partitions in `io_subgraph`, and assume the first
                # (callable) element is the same for all tasks.
                any_key = next(iter(io_subgraph))
                io_func = io_subgraph.get(any_key)[0]
            else:
                io_func = None
            ninds = 1 if isinstance(output_indices, str) else len(output_indices)
            dsk = {
                output: (
                    PackedFunctionCall(io_func),
                    *[blockwise_token(i) for i in range(ninds)],
                )
            }

        # Super-class initializer.
        # BUG FIX: forward the caller-supplied `concatenate`, `new_axes`,
        # and `output_blocks` values instead of hard-coding them to None
        # (the previous code silently dropped these arguments).
        super().__init__(
            output,
            output_indices,
            dsk,
            indices,
            numblocks,
            concatenate=concatenate,
            new_axes=new_axes,
            output_blocks=output_blocks,
            annotations=annotations,
        )

        # BlockwiseIO requires `io_name` and `io_subgraph` inputs
        self.io_name = io_name
        self.io_subgraph = io_subgraph

    @property
    def _dict(self):
        # Materialize (and cache) the low-level graph, injecting the
        # IO-subgraph task arguments into the blockwise tasks.
        if hasattr(self, "_cached_dict"):
            return self._cached_dict["dsk"]
        else:
            keys = tuple(map(blockwise_token, range(len(self.indices))))
            dsk, _ = fuse(self.dsk, [self.output])
            func = SubgraphCallable(dsk, self.output, keys)

            dsk = make_blockwise_graph(
                func,
                self.output,
                self.output_indices,
                *list(toolz.concat(self.indices)),
                new_axes=self.new_axes,
                numblocks=self.numblocks,
                concatenate=self.concatenate,
                output_blocks=self.output_blocks,
                dims=self.dims,
            )

            # Handle IO Subgraph
            for k in dsk:
                io_key = (self.io_name,) + tuple([k[i] for i in range(1, len(k))])
                if io_key in dsk[k]:
                    # Inject IO-function arguments into the blockwise graph
                    # as a single (packed) tuple.
                    io_item = self.io_subgraph.get(io_key)
                    io_item = list(io_item[1:]) if len(io_item) > 1 else []
                    new_task = [io_item if v == io_key else v for v in dsk[k]]
                    dsk[k] = tuple(new_task)

            self._cached_dict = {"dsk": dsk}
        return self._cached_dict["dsk"]

    def __dask_distributed_pack__(self, client):
        # Extend the base-class state with the IO-specific information
        # needed by `__dask_distributed_unpack__` on the scheduler.
        ret = super().__dask_distributed_pack__(client)
        ret["io_info"] = (self.io_name, self.io_subgraph)
        return ret

    @classmethod
    def __dask_distributed_unpack__(cls, state, dsk, dependencies, annotations):
        # NOTE(review): the `annotations` parameter and its expansion below
        # were missing from this override, while the parent
        # `Blockwise.__dask_distributed_unpack__` accepts and handles it —
        # restored here so both classes satisfy the same unpack protocol.
        raw, raw_deps = make_blockwise_graph(
            state["func"],
            state["output"],
            state["output_indices"],
            *state["indices"],
            new_axes=state["new_axes"],
            numblocks=state["numblocks"],
            concatenate=state["concatenate"],
            output_blocks=state["output_blocks"],
            dims=state["dims"],
            return_key_deps=True,
            deserializing=True,
            func_future_args=state["func_future_args"],
        )
        io_name, io_subgraph = state["io_info"]
        global_dependencies = list(state["global_dependencies"])

        if io_subgraph:
            # This is an IO layer.
            for k in raw:
                io_key = (io_name,) + tuple([k[i] for i in range(1, len(k))])
                if io_key in raw[k]:
                    # Inject IO-function arguments into the blockwise graph
                    # as a single (packed) tuple.
                    io_item = io_subgraph.get(io_key)
                    io_item = list(io_item[1:]) if len(io_item) > 1 else []
                    new_task = [io_item if v == io_key else v for v in raw[k]]
                    raw[k] = tuple(new_task)

        if state["annotations"]:
            annotations.update(cls.expand_annotations(state["annotations"], raw.keys()))

        raw = {stringify(k): stringify_collection_keys(v) for k, v in raw.items()}
        dsk.update(raw)

        for k, v in raw_deps.items():
            dependencies[stringify(k)] = [stringify(d) for d in v] + global_dependencies

    def _cull(self, output_blocks):
        # Return a culled copy restricted to `output_blocks`.
        # BUG FIX: forward `annotations` (previously dropped on cull),
        # consistent with `Blockwise._cull`.
        return BlockwiseIO(
            self.io_name,
            self.io_subgraph,
            self.output,
            self.output_indices,
            self.dsk,
            self.indices,
            self.numblocks,
            concatenate=self.concatenate,
            new_axes=self.new_axes,
            output_blocks=output_blocks,
            annotations=self.annotations,
        )


def _get_coord_mapping(
dims,
output,
Expand Down Expand Up @@ -1161,17 +1296,31 @@ def rewrite_blockwise(inputs):
numblocks = toolz.merge([inp.numblocks for inp in inputs.values()])
numblocks = {k: v for k, v in numblocks.items() if v is None or k in indices_check}

out = Blockwise(
root,
inputs[root].output_indices,
dsk,
new_indices,
numblocks=numblocks,
new_axes=new_axes,
concatenate=concatenate,
io_subgraph=io_info,
annotations=inputs[root].annotations,
)
if io_info:
# Fused layer includes IO
out = BlockwiseIO(
*io_info, # (io_name, io_subgraph)
root,
inputs[root].output_indices,
dsk,
new_indices,
numblocks=numblocks,
new_axes=new_axes,
concatenate=concatenate,
annotations=inputs[root].annotations,
)
else:
# Fused layer does NOT include IO
out = Blockwise(
root,
inputs[root].output_indices,
dsk,
new_indices,
numblocks=numblocks,
new_axes=new_axes,
concatenate=concatenate,
annotations=inputs[root].annotations,
)

return out

Expand Down
Loading