High Level Expressions #7933

@mrocklin

Description

Summary

We should make space for high level query optimization. There are a couple of ways to do this. This issue includes motivation, a description of two approaches, and some thoughts on trade-offs.

Motivation

There are a variety of situations where we might want to rewrite a user's code.

Dataframes

  1. Column projection, as in dd.read_parquet(...)["x"] -> dd.read_parquet(..., columns=["x"])
  2. Predicate pushdown (same as above)
  3. High level expression fusion (what we do today with blockwise)
  4. Pushing length calls down through elementwise calls
  5. Pushing filters earlier in a computation
  6. ...

Arrays

  1. Automatic rechunking at the beginning of a computation based on the end of the computation
  2. Slicing

History

Today there is no real place where we capture a user's intent or their lines of code. We immediately create a task graph for the requested operation, dump it into a mapping, and create a new dask.dataframe.DataFrame or dask.array.Array instance. That instance has no knowledge of what created it, or what created the other input dataframes on which it depends.

This was a good choice early on. It made it easy for us to quickly implement lots of complex operations without thinking about a class hierarchy for them. This choice followed on from our experience with Blaze, where we started with high level expressions, but got a bit stuck because they constrained our thinking (and no one really cares about high level query optimization for a system that they don't use).

However, today we have maybe reached a point where our "keep everything low-level and simple" strategy has hit its limit, and now we're curious about how best to bolt on a high level expression system. Doing this smoothly given the current system is hard. I see two ways out.

High Level Graph layers

We do have a record of what operations came before us in the high level graph layers. Currently the API of layers is very generic: they must be a mapping that adheres to the Dask graph spec, and they must be serializable in a certain way. There are some structured subclasses, like Blockwise, which enable high level optimizations that have proven useful.

There isn't really much structure here though, and as a result it's hard to do interesting optimizations. For example, it would be nice if we could change a layer at the very bottom of the graph, and then replay all of the operations on that input over again to see how they would change. High level layers today don't have enough shared structure for us to know how to do this.

I like High Level Graph Layers because they give us a space to hijack and add in all sorts of complex machinery without affecting the user-facing DataFrame class. We would have to add a lot more structure here though, and we would always be working around the collection class, which is a drawback.

Collection subclasses

I'm going to focus on an alternative that is a bit more radical. We could also have every user call generate a DataFrame subclass. There would still be a DataFrame instance that took in a generic graph/divisions/meta, but that would be mostly for backwards compatibility. Instead most dataframe operations would produce subclasses that had a well-defined common structure, as well as more custom attributes for their specific operation. Let's look at a couple of examples.

# API calls just create instances.  All logic happens there.
def read_parquet(file, columns, filters):
    return ReadParquet(file, columns, filters)

class ReadParquet(DataFrame):
    args = ["file", "columns", "filters"]  # List of arguments to use when reconstructing
    inputs = []  # List of arguments that are DataFrame objects

    def __init__(self, file, columns, filters):
        self.file = file
        self.columns = columns
        self.filters = filters

        self.divisions, self.meta = ..., ...  # do a bit of work on metadata

    def _generate_dask_layer(self) -> dict:
        ...

class ColumnProjection(DataFrame):
    args = ["dataframe", "columns"]
    inputs = ["dataframe"]

    def __init__(self, dataframe, columns):
        self.dataframe = dataframe
        self.columns = columns
        self._meta = self.dataframe._meta[columns]

    def _generate_dask_layer(self) -> dict:
        ...

class Add(DataFrame):
    args = ["left", "right"]

    def __init__(self, left, right):
        self.left = left
        self.right = right
        self.inputs = []
        if is_dask_collection(left):
            self.inputs.append("left")
        if is_dask_collection(right):
            self.inputs.append("right")

        self._meta = ...
        self._divisions = ...

    def _generate_dask_layer(self) -> dict:
        ...

As folks familiar with SymPy will recall, having attributes like args/inputs around makes it possible to re-generate a DataFrame automatically. So if we do something like the following:

df = dd.read_parquet(...)
z = df.x + df.y

Then this turns into an expression tree like the following:

Add(
    ColumnProjection(
        ReadParquet(..., columns=None),
        "x",
    ),
    ColumnProjection(
        ReadParquet(..., columns=None),
        "y",
    ),
)

We can then traverse this tree (which is easy, because the inputs attribute lists every attribute that is a dask collection) and apply optimizations (which is easy, because the args attribute lets us reconstruct any node).
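To make this concrete, here is a minimal, self-contained sketch of how args/inputs make tree rewrites mechanical. Everything below (the Expr base class, _rebuild, substitute) is hypothetical illustration for this issue, not an existing Dask API:

```python
class Expr:
    args = []    # names of constructor arguments, in order
    inputs = []  # the subset of args that are themselves expressions

    def _rebuild(self, **overrides):
        # Reconstruct this node, replacing some of its arguments
        kwargs = {name: getattr(self, name) for name in self.args}
        kwargs.update(overrides)
        return type(self)(**kwargs)


class ReadParquet(Expr):
    args = ["file", "columns"]

    def __init__(self, file, columns=None):
        self.file = file
        self.columns = columns


class ColumnProjection(Expr):
    args = ["dataframe", "columns"]
    inputs = ["dataframe"]

    def __init__(self, dataframe, columns):
        self.dataframe = dataframe
        self.columns = columns


def substitute(expr, old, new):
    # Depth-first walk: rebuild every node on the path to `old`
    if expr is old:
        return new
    overrides = {
        name: substitute(getattr(expr, name), old, new)
        for name in expr.inputs
    }
    return expr._rebuild(**overrides) if overrides else expr


# Swap the leaf ReadParquet for one that only reads column "x"
rp = ReadParquet("data.parquet")
proj = ColumnProjection(rp, ["x"])
out = substitute(proj, rp, ReadParquet("data.parquet", columns=["x"]))
```

This is the "change a layer at the very bottom and replay everything above it" operation that is hard to express with today's HLG layers.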

For example the ColumnProjection class may have an optimization method like the following:

class ColumnProjection(DataFrame):
    ... # continued from above

    def _optimize(self) -> DataFrame:
        if isinstance(self.dataframe, ReadParquet):
            args = {arg: getattr(self.dataframe, arg) for arg in self.dataframe.args}
            args["columns"] = self.columns
            return ReadParquet(**args)._optimize()

        # here is another optimization, just to show variety
        if isinstance(self.dataframe, ColumnProjection):  # like df[["x", "y"]]["x"]
            return ColumnProjection(self.dataframe.dataframe, self.columns)._optimize()

        # no known optimizations, optimize all inputs and then reconstruct (this would live in a superclass)
        args = []
        for arg in self.args:
            if arg in self.inputs:
                arg = getattr(self, arg)._optimize()
            else:
                arg = getattr(self, arg)
            args.append(arg)
        
        return type(self)(*args)

This is just one way of doing a traversal, using a method on a class. We can do fancier things. Mostly what I wanted to show here was that because we have args/inputs and class types it's fairly easy to encode optimizations and rewrite things.
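For instance, the generic bottom-up driver could live on a base class, with each subclass contributing only local rewrite rules. The following is a hedged sketch with stand-in classes, not the real Dask DataFrame machinery:

```python
class Expr:
    args = []
    inputs = []

    def _rewrite(self):
        # One local rewrite step; None means "no rule applies"
        return None

    def optimize(self):
        # First optimize the inputs, rebuilding this node if any changed
        node = self
        new_inputs = {name: getattr(node, name).optimize() for name in node.inputs}
        if any(new_inputs[n] is not getattr(node, n) for n in new_inputs):
            kwargs = {a: getattr(node, a) for a in node.args}
            kwargs.update(new_inputs)
            node = type(node)(**kwargs)
        # Then apply local rules to this node until a fixed point
        rewritten = node._rewrite()
        while rewritten is not None:
            node = rewritten
            rewritten = node._rewrite()
        return node


class ReadParquet(Expr):
    args = ["file", "columns"]

    def __init__(self, file, columns=None):
        self.file = file
        self.columns = columns


class ColumnProjection(Expr):
    args = ["dataframe", "columns"]
    inputs = ["dataframe"]

    def __init__(self, dataframe, columns):
        self.dataframe = dataframe
        self.columns = columns

    def _rewrite(self):
        # Push the projection into the parquet read
        if isinstance(self.dataframe, ReadParquet):
            return ReadParquet(self.dataframe.file, columns=self.columns)
        # Collapse nested projections, as in df[["x", "y"]]["x"]
        # (assumes the outer selection is a subset, as in the example above)
        if isinstance(self.dataframe, ColumnProjection):
            return ColumnProjection(self.dataframe.dataframe, self.columns)
        return None


expr = ColumnProjection(ColumnProjection(ReadParquet("data.parquet"), ["x", "y"]), "x")
optimized = expr.optimize()
```

Here the two projections collapse and then fold into the ReadParquet node, so the whole tree reduces to a single read of column "x".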

What's notable here is that we aren't generating the graph, or even any semblance of the graph, ahead of time. At any point where we run code that requires something like _meta or _divisions from an input, we close off the opportunity to change the graph beneath that stage. This is OK for our internal code; I think we can defer graph generation.
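Deferred graph generation could look something like the sketch below: each node contributes a layer only when the full graph is finally requested, so the tree stays rewritable until then. The classes and the toy task tuples are illustrative assumptions; only the `__dask_graph__` protocol name comes from Dask itself:

```python
class Expr:
    inputs = []

    def _generate_dask_layer(self) -> dict:
        # Tasks contributed by this node alone
        raise NotImplementedError

    def __dask_graph__(self):
        # Only here, at compute time, do we materialize tasks;
        # everything before this point is free to rewrite the tree
        graph = {}
        stack = [self]
        while stack:
            node = stack.pop()
            graph.update(node._generate_dask_layer())
            stack.extend(getattr(node, name) for name in node.inputs)
        return graph


class ReadParquet(Expr):
    def __init__(self, file):
        self.file = file

    def _generate_dask_layer(self):
        # Stand-in task, not a real parquet read
        return {("read-parquet", 0): ("read-task", self.file)}


class ColumnProjection(Expr):
    inputs = ["dataframe"]

    def __init__(self, dataframe, columns):
        self.dataframe = dataframe
        self.columns = columns

    def _generate_dask_layer(self):
        # Depends on the key produced by the input layer
        return {("project", 0): ("project-task", ("read-parquet", 0), self.columns)}


expr = ColumnProjection(ReadParquet("data.parquet"), ["x"])
graph = expr.__dask_graph__()
```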

I think that the main advantage to this approach is that we can easily reconstruct expressions given newly modified inputs. I think that this is fundamentally what is lacking with our current HLG layer approach.

However, this would also be a significant deviation from how Dask works today. It's likely that this would affect all downstream projects that subclass Dask arrays/dataframes today (RAPIDS, Pint, yt, ...). I think that that's probably ok. Those groups will, I think, understand.

Personal Thoughts

I've historically been against doing collection subclasses (the second approach), but after thinking about this for a while I think I'm now more in favor. It seems like maybe we've arrived at the time when the benefits to doing this outweigh the complexity costs. I think that this is motivated also by the amount of complexity that we're encountering walking down the HLG path.

However, I do think that this will be hard for us to implement in an incremental way. If we want to go down this path we probably need to think about an approach that lets the current DataFrame API persist for backwards compatibility with downstream projects (they wouldn't get any benefits of this system though) and a way where we can move over collection subclasses incrementally.

cc @rjzamora @jcrist

Labels: highlevelgraph, needs attention