
dask.array.jit or dask.array.vectorize #1946

@mrocklin

Description

When working with dask-glm I find myself interacting with functions like the following (where x is a dask.array):

from dask.array import absolute, sign

# `lamda` is the regularization parameter, defined elsewhere in dask-glm
# (spelled without the "b" to avoid the Python `lambda` keyword)

def l2(x, t):
    return 1 / (1 + lamda * t) * x

def l1(x, t):
    return (absolute(x) > lamda * t) * (x - sign(x) * lamda * t)

These are costly in a few ways:

  1. They have significant overhead, because each call regenerates a relatively large task graph
  2. On computation, even if we fuse tasks, we create many intermediate copies of numpy arrays
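To make the second cost concrete, here is a minimal numpy-only sketch (the `lamda` value and function names are illustrative) contrasting the eager elementwise form of l1, which allocates a temporary array per operator, with a single fused pass of the kind numba could compile:

```python
import numpy as np

lamda = 0.1  # illustrative regularization parameter

def l1_eager(x, t):
    # absolute, >, sign, *, -, * each materialize a full temporary array
    return (np.absolute(x) > lamda * t) * (x - np.sign(x) * lamda * t)

def l1_fused(x, t):
    # one pass, one output allocation -- the kind of kernel that
    # numba.jit could generate from a fused task
    thresh = lamda * t
    out = np.empty_like(x)
    for i, v in enumerate(x.flat):
        out.flat[i] = (v - np.sign(v) * thresh) if abs(v) > thresh else 0.0
    return out

x = np.array([-2.0, -0.01, 0.0, 0.01, 2.0])
assert np.allclose(l1_eager(x, 1.0), l1_fused(x, 1.0))
```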

So there are two partial solutions that we could combine here:

  1. For any given dtype/shape/chunks signature, we could precompute a dask graph. When the same dtype/shape/chunks signature comes in again, we would stitch the new keys in at the right place, update some tokenized values, and ship the result out without re-running all of the dask.array graph-construction code.
  2. We could `numba.jit` fused tasks
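A minimal sketch of idea 1, assuming a plain dict-of-tuples graph: cache a template graph per signature with placeholder keys, then substitute the real keys on each call. The `rekey` helper and all key names here are hypothetical, not dask API:

```python
# Hedged sketch: build the graph once per (dtype, shape, chunks)
# signature, then re-key the cached template for each new input.

def rekey(template_graph, key_map):
    """Substitute real keys for placeholder keys throughout a graph."""
    def sub(task):
        if task in key_map:          # whole-key match (keys are tuples)
            return key_map[task]
        if isinstance(task, tuple):  # recurse into nested tasks
            return tuple(sub(t) for t in task)
        return task
    return {sub(k): sub(v) for k, v in template_graph.items()}

# Template built once for a given signature; ('_in', 0) is a placeholder.
template = {('out', 0): ('mul', 0.5, ('_in', 0))}
graph = rekey(template, {('_in', 0): ('x-1234', 0),
                         ('out', 0): ('y-5678', 0)})
# graph == {('y-5678', 0): ('mul', 0.5, ('x-1234', 0))}
```

This skips re-running graph construction entirely; only a dict traversal remains on the hot path.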

Using numba would actually be pretty valuable in some cases in dask-glm. This could be an optimization at the task-graph level. I suspect that if we get good at recognizing recurring patterns and caching the compiled kernels, we could make this fast-ish: `(add, _, (mul, _, _)) -> numba.jit(lambda x, y, z: x + y * z)`. We might also be able to back out patterns based on keys (not sure if this is safe).
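A hedged sketch of the pattern-recognition step, in plain Python: recognize `(add, _, (mul, _, _))` in a task and dispatch to a fused callable, which `numba.jit` could compile in practice. The `match` helper and the wildcard encoding are illustrative:

```python
def match(pattern, task, args):
    """Match task against pattern, collecting '_' wildcards into args.

    On a failed match, args may hold partial results; fine for a sketch.
    """
    if pattern == '_':
        args.append(task)
        return True
    if isinstance(pattern, tuple) and isinstance(task, tuple):
        return (len(pattern) == len(task)
                and all(match(p, t, args) for p, t in zip(pattern, task)))
    return pattern == task

def add(x, y): return x + y
def mul(x, y): return x * y

# Pattern (add, _, (mul, _, _)) -> fused x + y * z
pattern = (add, '_', (mul, '_', '_'))
fused = lambda x, y, z: x + y * z   # in practice: numba.jit(...)

task = (add, 1.0, (mul, 2.0, 3.0))
args = []
if match(pattern, task, args):
    result = fused(*args)           # 1.0 + 2.0 * 3.0
```

Caching `fused` per pattern would amortize the compilation cost across repeated tasks.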

cc @jcrist @eriknw @sklam @seibert @shoyer
