Currently the dask array `matmul` implementation is similar to the `tensordot` implementation (understandably), but it lacks a clever trick that `tensordot` uses to avoid contractions in `blockwise`, which can result in large pileups of intermediate results. For reference, the current implementation (dask/dask/array/routines.py, lines 290 to 329 at ff3ea6c):
```python
@derived_from(np)
def matmul(a, b):
    a = asanyarray(a)
    b = asanyarray(b)

    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("`matmul` does not support scalars.")

    # Promote 1-D operands to 2-D so the blockwise indices below apply;
    # the added dimensions are dropped again at the end.
    a_is_1d = False
    if a.ndim == 1:
        a_is_1d = True
        a = a[np.newaxis, :]

    b_is_1d = False
    if b.ndim == 1:
        b_is_1d = True
        b = b[:, np.newaxis]

    # Broadcast the leading (batch) dimensions against each other.
    if a.ndim < b.ndim:
        a = a[(b.ndim - a.ndim) * (np.newaxis,)]
    elif a.ndim > b.ndim:
        b = b[(a.ndim - b.ndim) * (np.newaxis,)]

    # Index 0 is the contracted dimension: it appears in both input
    # index tuples but not in the output, so blockwise concatenates a
    # whole block row of `a` and block column of `b` per output block
    # (concatenate=True) and contracts them in one np.matmul call.
    out = blockwise(
        np.matmul,
        tuple(range(1, a.ndim + 1)),
        a,
        tuple(range(1, a.ndim - 1)) + (a.ndim - 1, 0),
        b,
        tuple(range(1, a.ndim - 1)) + (0, a.ndim),
        dtype=result_type(a, b),
        concatenate=True,
    )

    if a_is_1d:
        out = out[..., 0, :]
    if b_is_1d:
        out = out[..., 0]

    return out
```
`tensordot` uses only outer products in its `blockwise` call (there are no contracted indices in the blockwise call), and then reduces these partial products with a `sum` call. Because `sum` uses a tree reduction, this results in much nicer memory usage. `matmul` doesn't use this trick.
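The difference is visible in the task graphs of the two routines. A quick, illustrative way to compare them (the array size and chunking here are arbitrary, and exact task counts depend on the dask version):

```python
import dask.array as da

x = da.ones((4000, 4000), chunks=1000)

# matmul contracts inside blockwise: each output block is a single
# task whose input is a whole concatenated block row of the left
# operand and block column of the right operand.
n_matmul = len(da.matmul(x, x).__dask_graph__())

# tensordot emits one small partial product per block pair (outer
# products only), then sum() adds a tree of reduction tasks on top.
n_tensordot = len(da.tensordot(x, x, axes=1).__dask_graph__())

# tensordot's graph has more, but individually much smaller, tasks
print(n_matmul, n_tensordot)
```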
I briefly tried replacing the `blockwise` call in `matmul` with a call to our internal `dot` function, but various tests failed. It would be good to do one of two things:
- Find a way to back `matmul` with `tensordot`, so that the reduction trick is shared in a common codebase
- Replicate the reduction trick in `matmul` (a rough sketch of what this could look like follows below)
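For the second option, here is a minimal sketch of what the trick could look like for the plain 2-D case. The names `matmul_outer` and `_partial_matmul` are hypothetical, not dask's, and the broadcasting/1-D handling of the real `matmul` is omitted:

```python
import numpy as np
import dask.array as da

def _partial_matmul(a, b):
    # Per-block-pair product; the extra leading singleton axis keeps a
    # slot for the contracted index "k" in the blockwise output.
    return np.matmul(a, b)[None, ...]

def matmul_outer(a, b):
    """Hypothetical 2-D matmul using tensordot's reduction trick."""
    if a.ndim != 2 or b.ndim != 2:
        raise NotImplementedError("sketch covers the 2-D case only")

    # "k" appears in the output index, so nothing is contracted inside
    # blockwise: each task touches exactly one block of each operand.
    partials = da.blockwise(
        _partial_matmul,
        "kij",
        a, "ik",
        b, "kj",
        dtype=np.result_type(a.dtype, b.dtype),
        adjust_chunks={"k": 1},  # one size-1 slab along k per block pair
    )
    # sum() tree-reduces the partial products over the k axis.
    return partials.sum(axis=0)

x = da.random.random((1200, 900), chunks=300)
y = da.random.random((900, 600), chunks=300)
assert np.allclose(matmul_outer(x, y).compute(), (x @ y).compute())
```

The singleton axis added in `_partial_matmul`, together with `adjust_chunks={"k": 1}`, is the same device `tensordot` uses to keep the contracted dimension alive for the follow-up `sum`; extending the pattern to `matmul`'s batch dimensions and 1-D cases is the open work.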
This was originally reported by @eric-czech and @tomwhite in https://github.com/pystatgen/sgkit/issues/375#issuecomment-731060672