Currently the dask array matmul implementation is similar to the tensordot implementation (understandably), but it lacks a clever trick that tensordot uses to avoid contractions in blockwise (contractions can result in large pileups of intermediate results).
Lines 290 to 329 in ff3ea6c

```python
@derived_from(np)
def matmul(a, b):
    a = asanyarray(a)
    b = asanyarray(b)
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError("`matmul` does not support scalars.")
    a_is_1d = False
    if a.ndim == 1:
        a_is_1d = True
        a = a[np.newaxis, :]
    b_is_1d = False
    if b.ndim == 1:
        b_is_1d = True
        b = b[:, np.newaxis]
    if a.ndim < b.ndim:
        a = a[(b.ndim - a.ndim) * (np.newaxis,)]
    elif a.ndim > b.ndim:
        b = b[(a.ndim - b.ndim) * (np.newaxis,)]
    out = blockwise(
        np.matmul,
        tuple(range(1, a.ndim + 1)),
        a,
        tuple(range(1, a.ndim - 1)) + (a.ndim - 1, 0),
        b,
        tuple(range(1, a.ndim - 1)) + (0, a.ndim),
        dtype=result_type(a, b),
        concatenate=True,
    )
    if a_is_1d:
        out = out[..., 0, :]
    if b_is_1d:
        out = out[..., 0]
    return out
```
Tensordot uses only outer products in blockwise (there are no shared indices in the blockwise call), and then reduces the partial results with a `sum` call. Because `sum` uses tree reductions, this results in much nicer memory usage. Matmul doesn't use this trick.
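To illustrate the trick for the simple 2-D case, here is a minimal sketch (not dask's actual tensordot code): the contracted index also appears in the output index, so blockwise pairs blocks up without concatenating along that axis, and the shared axis is reduced afterwards with `sum`. The `partial_matmul` helper is a hypothetical name introduced for this example.

```python
import numpy as np
import dask.array as da


def partial_matmul(a, b):
    # Per-block matmul that keeps the contracted axis around as length 1,
    # so no cross-block contraction happens inside blockwise itself.
    return np.matmul(a, b)[:, None, :]


x = da.ones((6, 4), chunks=2)
y = da.ones((4, 8), chunks=2)

# The shared index "j" also appears in the output index "ijk", so blockwise
# pairs up one block of x with one block of y per task (no concatenation
# along "j", hence no in-task contraction).
partial = da.blockwise(
    partial_matmul, "ijk",
    x, "ij",
    y, "jk",
    dtype=x.dtype,
    adjust_chunks={"j": 1},
)

# The partial products are then combined with sum, which dask executes as a
# tree reduction, bounding the number of intermediates alive at once.
result = partial.sum(axis=1)
```

The memory benefit comes from the tree reduction: instead of one task receiving all blocks along the contracted axis at once (as `concatenate=True` in the current matmul does), partial results are combined a few at a time.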
I briefly tried replacing the blockwise call in matmul with a call to our internal dot function, but various tests failed. It would be good to do one of two things:
- Find a way to implement matmul on top of tensordot, so that the reduction trick lives in a common codebase
- Replicate the reduction trick in matmul
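For the first option, the non-batched case is straightforward, since 2-D matmul coincides with a single-axis tensordot; a quick sketch of that equivalence:

```python
import numpy as np
import dask.array as da

a = da.ones((6, 4), chunks=2)
b = da.ones((4, 8), chunks=2)

# In 2-D, matmul is exactly tensordot over one axis, so routing through
# tensordot inherits its outer-product + tree-reduction behavior for free.
c = da.tensordot(a, b, axes=1)
```

The hard part is presumably the batched semantics: matmul broadcasts over leading dimensions, whereas tensordot treats non-contracted dimensions as an outer product rather than aligning them, which may be why the naive replacement broke tests.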
This was originally reported by @eric-czech and @tomwhite in https://github.com/pystatgen/sgkit/issues/375#issuecomment-731060672