Closed
Description
When working with dask-glm I find myself interacting with functions like the following (where x is a dask.array):
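```python
from dask.array import absolute, sign

lamda = 0.1  # regularization strength (placeholder value; `lambda` is a reserved word in Python)

def l2(x, t):
    return 1 / (1 + lamda * t) * x

def l1(x, t):
    return (absolute(x) > lamda * t) * (x - sign(x) * lamda * t)
```

These are costly in a few ways: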
- They have noticeable overhead, because every call regenerates a relatively large graph.
- At compute time, even if we fuse tasks, we create many intermediate copies of NumPy arrays.
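To make the second cost concrete, here is a single-chunk sketch in plain NumPy, roughly the work one fused task does per block. The in-place rewrite is a hand-written illustration (not something dask does today) of the temporaries that a compiled, fused kernel could avoid. `lamda` is a placeholder value and the functions assume floating-point input.

```python
import numpy as np

lamda = 0.1  # placeholder regularization strength

def l1_naive(x, t):
    # Every operator allocates a new array: |x|, the boolean mask,
    # sign(x), sign(x) * c, x - ..., and the final product (six in total).
    return (np.absolute(x) > lamda * t) * (x - np.sign(x) * lamda * t)

def l1_fewer_temporaries(x, t):
    # Same values with three allocations, reusing `out` in place.
    # Assumes x is a floating-point array; x itself is not modified.
    c = lamda * t
    out = np.sign(x)               # allocation 1
    out *= c                       # in place
    np.subtract(x, out, out=out)   # in place: x - sign(x) * c
    out *= np.absolute(x) > c      # allocations 2 and 3: |x| and the mask
    return out
```

Writing this by hand for every expression does not scale, which is why compiling fused tasks automatically is attractive.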
So there are two partial solutions that we could combine here:
- For any given dtype/shape/chunks signature, we could precompute a dask graph. When the same dtype/shape/chunks signature comes in again, we would stitch the new keys in at the right places, swap in the new tokenized values, and ship the result out without running the dask.array code again.
- We could `numba.jit` fused tasks.
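The first idea, caching one graph per dtype/shape/chunks signature and only renaming keys afterwards, might look roughly like the following. This is a plain-Python sketch under simplifying assumptions: `build_graph` is a hypothetical stand-in for the dask.array code path, graphs use the low-level dict-of-tuples spec (`(name, i)` keys, tasks as tuples starting with a callable), and names like `"x-abc"` stand in for real tokenized names. Real dask graphs are `HighLevelGraph`s with layers, so an actual implementation would have more to rewrite.

```python
import operator
from functools import lru_cache

# Placeholder names used only inside the cached template (hypothetical).
IN, OUT, SCALE = "__in__", "__out__", "__scale__"

def build_graph(in_name, out_name, nblocks, scale):
    """Stand-in for the expensive dask.array path: scale every block."""
    return {(out_name, i): (operator.mul, scale, (in_name, i))
            for i in range(nblocks)}

@lru_cache(maxsize=None)
def template_for(signature):
    """Build the graph once per (dtype, shape, chunks) signature."""
    dtype, shape, chunks = signature
    return build_graph(IN, OUT, len(chunks), SCALE)

def substitute(obj, mapping):
    """Recursively swap placeholder strings for concrete names/values."""
    if isinstance(obj, str):
        return mapping.get(obj, obj)
    if isinstance(obj, tuple):
        return tuple(substitute(o, mapping) for o in obj)
    return obj

def scaled_graph(signature, in_name, out_name, scale):
    """Reuse the cached template; only keys and scalar values change."""
    mapping = {IN: in_name, OUT: out_name, SCALE: scale}
    template = template_for(signature)  # never mutated, safe to share
    return {substitute(k, mapping): substitute(v, mapping)
            for k, v in template.items()}

sig = ("float64", (400,), (100, 100, 100, 100))
g1 = scaled_graph(sig, "x-abc", "l2-111", 0.5)   # builds the template
g2 = scaled_graph(sig, "x-abc", "l2-222", 0.25)  # cache hit, only renaming
```

The second call skips graph construction entirely and only walks the cached template. One caveat of string-based substitution is that a genuine string argument equal to a placeholder would be replaced too, so real placeholders would need to be unambiguous.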
Using numba would actually be pretty valuable in some cases in dask-glm. This could be an optimization at the task-graph level. I suspect that if we get good at recognizing recurring patterns and caching the compiled functions, we could make this reasonably fast, e.g. `(add, _, (mul, _, _))` -> `numba.jit(lambda x, y, z: x + y * z)`. We might also be able to back out patterns based on keys (not sure if this is safe).
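The `(add, _, (mul, _, _))` rewrite could be sketched as a small pattern matcher over dask-style task tuples. This is an illustration only: it assumes tasks are nested tuples using `operator.add`/`operator.mul`, whereas real dask graphs (especially after blockwise fusion) are structured differently. `rewrite` and `fused_add_mul` are hypothetical names, and the sketch falls back to plain Python when numba is not installed.

```python
import operator
from functools import lru_cache

import numpy as np

try:
    from numba import njit
except ImportError:  # keep the sketch runnable without numba
    def njit(func):
        return func

add, mul = operator.add, operator.mul

def _add_mul(x, y, z):
    # A single elementwise expression; numba can typically compile this
    # into one loop instead of materializing y * z first.
    return x + y * z

@lru_cache(maxsize=None)
def fused_add_mul():
    # Create the jitted function once and reuse it for every match;
    # numba additionally caches compiled code per argument types.
    return njit(_add_mul)

def rewrite(task):
    """Turn (add, a, (mul, b, c)) into (fused, a, b, c); leave other tasks alone."""
    if (isinstance(task, tuple) and len(task) == 3 and task[0] is add
            and isinstance(task[2], tuple) and len(task[2]) == 3
            and task[2][0] is mul):
        _, a, (_, b, c) = task
        return (fused_add_mul(), a, b, c)
    return task

def rewrite_graph(graph):
    return {key: rewrite(task) for key, task in graph.items()}

graph = {
    ("out", 0): (add, ("x", 0), (mul, ("y", 0), ("z", 0))),
    ("other", 0): (mul, ("x", 0), 2.0),
}
fused = rewrite_graph(graph)
```

Because the compiled function is cached, recognizing the same pattern again costs only the tuple inspection, which is the "cache well" part of the idea.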