
Reuse tensordot memory tricks in matmul #6874

@mrocklin

Description


Currently the dask array matmul implementation is similar to the tensordot implementation (understandably), but it lacks a clever trick that tensordot uses to avoid contractions in blockwise (contractions there can result in large pileups of intermediate results).

dask/dask/array/routines.py

Lines 290 to 329 in ff3ea6c

@derived_from(np)
def matmul(a, b):
    a = asanyarray(a)
    b = asanyarray(b)
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("`matmul` does not support scalars.")
    a_is_1d = False
    if a.ndim == 1:
        a_is_1d = True
        a = a[np.newaxis, :]
    b_is_1d = False
    if b.ndim == 1:
        b_is_1d = True
        b = b[:, np.newaxis]
    if a.ndim < b.ndim:
        a = a[(b.ndim - a.ndim) * (np.newaxis,)]
    elif a.ndim > b.ndim:
        b = b[(a.ndim - b.ndim) * (np.newaxis,)]
    out = blockwise(
        np.matmul,
        tuple(range(1, a.ndim + 1)),
        a,
        tuple(range(1, a.ndim - 1)) + (a.ndim - 1, 0),
        b,
        tuple(range(1, a.ndim - 1)) + (0, a.ndim),
        dtype=result_type(a, b),
        concatenate=True,
    )
    if a_is_1d:
        out = out[..., 0, :]
    if b_is_1d:
        out = out[..., 0]
    return out

Tensordot uses only outer products in blockwise (there are no shared indices in the blockwise call), and then reduces them with a sum call. Because sum uses tree reductions, this gives much nicer memory usage. Matmul doesn't use this trick.
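To illustrate the idea in plain NumPy (the function below is a hypothetical sketch, not dask code): each per-block partial product is materialized independently, and the contraction happens afterwards in a single sum over the contracted block axis — the shape of computation that dask's sum can execute as a tree reduction instead of a serial pileup of intermediates.

```python
import numpy as np

def blocked_matmul_via_sum(a_blocks, b_blocks):
    """Multiply two block matrices, given as nested lists of NumPy blocks.

    Rather than accumulating partial products in place, compute every
    partial product a[i][k] @ b[k][j] independently and reduce them with
    one sum over k -- mirroring how dask's tensordot emits only
    non-contracting products in blockwise and reduces afterwards.
    """
    n_i = len(a_blocks)        # block rows of a
    n_k = len(a_blocks[0])     # contracted block dimension
    n_j = len(b_blocks[0])     # block columns of b
    out = []
    for i in range(n_i):
        row = []
        for j in range(n_j):
            partials = [a_blocks[i][k] @ b_blocks[k][j] for k in range(n_k)]
            row.append(np.sum(partials, axis=0))  # the reduction step
        out.append(row)
    return out
```

In dask terms, the `partials` list corresponds to blockwise output with an extra (non-shared) index for `k`, and `np.sum(..., axis=0)` corresponds to the tree-reducing `sum` over that axis.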

I briefly tried replacing the blockwise call in matmul with a call to our internal dot function, but various tests failed. It would be good to do one of two things:

  1. Find a way to back matmul by tensordot, so that the reduction trick lives in a common codebase
  2. Replicate the reduction trick in matmul
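For option 1, the relevant observation is that for 2-D inputs, matmul agrees with a tensordot contracting `a`'s last axis against `b`'s first axis, so the 2-D case could defer to dask's tensordot (which already has the reduction trick); the batch dimensions that matmul broadcasts over are what a naive swap doesn't handle, which is presumably why the tests failed. A NumPy sketch of the 2-D equivalence (the helper name is hypothetical):

```python
import numpy as np

def matmul_via_tensordot_2d(a, b):
    """2-D matmul expressed as a tensordot contraction.

    axes=1 contracts the last axis of a with the first axis of b,
    which for 2-D inputs is exactly matrix multiplication. Stacked
    (batched) inputs would need the batch axes threaded through
    separately, since tensordot has no broadcasting batch semantics.
    """
    return np.tensordot(a, b, axes=1)
```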

This was originally reported by @eric-czech and @tomwhite in https://github.com/pystatgen/sgkit/issues/375#issuecomment-731060672
