Rewrite matmul as blockwise without concatenated contraction #7000
TomAugspurger merged 1 commit into dask:master from
Conversation
Force-pushed from eb0ebff to eb6d82b
Nice @ravwojdyla! A couple quick questions:
Chunks of the inputs are passed into the function at Line 288 in eb6d82b. When the last dimension of A or the last two dimensions of B are chunked, we need to sum the partial results (the sum axis depends on the chunking scheme); otherwise we have full results.
Afaiu this is handled by the
For the 2D case there isn't much difference, but this also handles matrices with more than 2 dimensions.
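A minimal NumPy sketch (the sizes and the chunk width of 3 are hypothetical) of why partial results must be summed when the contraction dimension is chunked:

```python
import numpy as np

A = np.random.rand(4, 6)
B = np.random.rand(6, 5)

# Split the contraction dimension (size 6) into two chunks of width 3.
# Each block-level matmul yields only a partial result.
partials = [A[:, k:k + 3] @ B[k:k + 3, :] for k in (0, 3)]

# Summing the partial results recovers the full product.
assert np.allclose(sum(partials), A @ B)
```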
Hey @mrocklin / @TomAugspurger do you think you could help us direct a review for this one since it's important for unblocking some of our scalability testing (cf. sgkit#390)? FYI this work supersedes what we proposed in #6924 -- I believe @tomwhite was going to close it.
I'm not sure what bandwidth he has available, but cc @gforsyth just in case he has a chance to take a look.
gforsyth left a comment
Hey @ravwojdyla -- thanks for putting this in and for the very thorough notebook running through the performance changes.
The changes you've made look good to me and I've tried to break this locally and haven't been able to. 🎉
Regarding the slight performance regression vs. the original implementation when the arrays fit cleanly in memory: the chunking used in those (in-memory) examples defaults to (1000, 250), which puts it at roughly 2 MB per chunk. Bumping that up (as you do in the following example) shows that your implementation is much more performant overall.
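A quick back-of-the-envelope check of that per-chunk size, assuming float64 elements:

```python
# A (1000, 250) chunk of float64 values, at 8 bytes per element.
chunk_nbytes = 1000 * 250 * 8
assert chunk_nbytes == 2_000_000  # ~2 MB per chunk
```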
My only request here is that we add a few more tests along the lines of @eric-czech's questions -- all the current tests seem to have symmetric chunk sizes and it would be good to have a few where the chunk sizes are a bit more disparate, e.g.
X = da.random.random(size=(3, 3, 50, 100), chunks=(1, 3, 10, 25))
Y = da.random.random(size=(3, 3, 100, 50), chunks=(1, 3, 20, 5))
I'd like to see that and maybe two other random disparate chunk size examples in the tests just to cover our bases and then this is good to go in.
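A sketch of the kind of test being requested, assuming the usual pattern of checking da.matmul against np.matmul on the computed inputs:

```python
import numpy as np
import dask.array as da

# Deliberately disparate chunk sizes, per the examples above.
x = da.random.random(size=(3, 3, 50, 100), chunks=(1, 3, 10, 25))
y = da.random.random(size=(3, 3, 100, 50), chunks=(1, 3, 20, 5))

# The chunked result should agree with the in-memory NumPy result.
expected = np.matmul(x.compute(), y.compute())
assert np.allclose(da.matmul(x, y).compute(), expected)
```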
Hi @gforsyth, thanks for the review!
Certainly, will add.
Force-pushed from eb6d82b to ada36fa
@gforsyth tests added: dask/dask/array/tests/test_routines.py Lines 252 to 259 in ada36fa
gforsyth left a comment
Thanks for putting this in @ravwojdyla -- this looks great! @jrbourbeau this is ready to go in.
Thanks all!
Thanks @gforsyth and @TomAugspurger! I'll go ahead and create an issue to validate that the new implementation works and scales well on cupy and sparse. The (unit) tests included in dask pass (at least the sparse ones; cupy is not tested as part of the CI, right?), but I would also like to validate the performance for those two array types before release. Are there any other array types we should look into?
# Since we have performed the contraction via matmul
# but blockwise expects all dimensions back, we need
# to add one dummy dimension back
return chunk[..., np.newaxis]
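As a small illustration of what this comment describes (block shapes here are hypothetical): matmul removes the contraction dimension from each block, and the size-1 dummy axis restores the dimensionality that blockwise expects:

```python
import numpy as np

a_blk = np.random.rand(4, 3)   # block of A with dims (i, k)
b_blk = np.random.rand(3, 5)   # block of B with dims (k, j)

# The contraction over k removes a dimension: result is (4, 5).
chunk = a_blk @ b_blk

# blockwise still expects the contracted dimension in the output index,
# so a size-1 dummy axis is appended.
out = chunk[..., np.newaxis]
assert out.shape == (4, 5, 1)
```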
Your PR is an impressive piece of code @ravwojdyla !
I know this PR is closed. But I'd appreciate it if you would answer one last question from me.
This is the only line that puzzles me — while I understand why the number of the dimensions of the output chunk should match what blockwise expects, I can't figure out why the new axis is inserted in the last position instead of the second to last position. Perhaps it doesn't matter in the end, but I'd like to know why.
I mean, why not the following:
# to add one dummy dimension back
return chunk[..., np.newaxis, :]

which more closely matches the expected output of the blockwise call?
Thanks!
👋 @ParticularMiner my initial reaction would be that it probably doesn't matter (though it might require some changes in the downstream logic after blockwise). That extra dummy dimension is "squeezed" out later AFAIR, and as you have pointed out, it's there to make blockwise happy at the metadata level. That said, it's been a while, and this is just my initial reaction.
Thank you very much for your reply @ravwojdyla despite this conversation being almost a year old!
Indeed, it turns out through testing that it doesn’t matter, which puzzled me.
Also the downstream logic after blockwise operates on the assumption that the contraction-axis is the second-to-last axis of the output of blockwise.
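For illustration (with hypothetical data), the two placements of the dummy axis carry exactly the same values; they differ only in where the size-1 axis sits, which is why the end result is the same once that axis is reduced or squeezed away:

```python
import numpy as np

chunk = np.arange(20.0).reshape(4, 5)

last = chunk[..., np.newaxis]            # dummy axis last: (4, 5, 1)
second_last = chunk[..., np.newaxis, :]  # dummy axis second-to-last: (4, 1, 5)

assert last.shape == (4, 5, 1)
assert second_last.shape == (4, 1, 5)
# Same data either way; only the position of the size-1 axis differs.
assert np.array_equal(last.squeeze(), second_last.squeeze())
```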
The following is not vital (performance-wise). But just to make you aware:
While mirroring your code elsewhere, it turned out that by using

# to add one dummy dimension back
return chunk[..., np.newaxis, :]

instead of

# to add one dummy dimension back
return chunk[..., np.newaxis]

one can avoid the extra conditional branches in the downstream logic after blockwise by replacing them with a single sum. That is, this line:

out = out.sum(axis=-2)

can replace the following lines:
# When we perform reduction, we need to worry about the last 2 dimensions
# which hold the matrices, some care is required to handle chunking in
# that space.
contraction_dimension_is_chunked = (
    max(min(a.chunks[-1], b.chunks[-2])) < a.shape[-1]
)
b_last_dim_max_chunk = max(b.chunks[-1])
if contraction_dimension_is_chunked or b_last_dim_max_chunk < b.shape[-1]:
    if b_last_dim_max_chunk > 1:
        # This is the case when both contraction and last dimension axes
        # are chunked
        out = out.reshape(out.shape[:-1] + (1, -1))
        out = out.sum(axis=-3)
        out = out.reshape(out.shape[:-2] + (b.shape[-1],))
    else:
        # Contraction axis is chunked
        out = out.sum(axis=-2)
else:
    # Neither contraction nor last dimension axes are chunked, we
    # remove the dummy dimension without reduction
    out = out.reshape(out.shape[:-2] + (b.shape[-1],))

I haven't tested this yet for dask itself, but it works in my particular application. If you feel this is worth exploring for dask, then just drop me a line and I'll create a PR with extensive testing. But I doubt it will significantly improve current performance; rather, it will merely create more readable code, which is often useful. 😉
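A NumPy sketch (hypothetical block sizes) of the proposed simplification: with the dummy axis placed second-to-last, the partial products stacked along that axis collapse with a single sum, regardless of how the contraction axis was chunked:

```python
import numpy as np

A = np.random.rand(4, 6)
B = np.random.rand(6, 5)

# Partial products from two chunks of the contraction axis, each with
# the dummy axis inserted second-to-last, giving shape (4, 1, 5).
partials = [(A[:, k:k + 3] @ B[k:k + 3, :])[..., np.newaxis, :]
            for k in (0, 3)]

# Stacking along the dummy axis and summing it out recovers A @ B
# with one reduction instead of the conditional branches above.
out = np.concatenate(partials, axis=-2).sum(axis=-2)
assert np.allclose(out, A @ B)
```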
@ParticularMiner that's great! Given that those reshape ops should be largely metadata operations, I agree that I wouldn't expect a performance impact, but I do love the simpler code. Definitely +1 to a PR, if you test that it works well for different chunking schemes (as mentioned in the comment above).
Many thanks for your interest!
I've made the suggested changes on my local git and the relevant existing unit tests (in test_matmul()) have all passed.
The only thing left now is to run your benchmarking Jupyter notebook above. (IMO the results there will be similar; still, it is always safer to check.) But to do that, I need access to a computing cluster, which I currently do not have. Any pointers on where I can access one (preferably for free)? 😉
Sorry. I misunderstood. It seems you ran your notebook on a standard laptop, not a multi-node cluster. I can probably do that too.
Hi @ravwojdyla
You’ll find my suggested PR at #8423.
Currently the CI tests are broken for other reasons. So it is not clear that the PR checks out.
Re: #6874

- black dask / flake8 dask
- Performance stats in a notebook

There are more optimizations I have in mind, but first I would like to get some feedback about the overall direction. This PR implements matmul as blockwise without concatenated contraction, which reduces the memory footprint and allows for better control over execution time and memory usage. Please see the notebook above for the impact on performance.
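A small end-to-end sketch of the idea (array sizes here are illustrative), checking the blockwise matmul against NumPy:

```python
import numpy as np
import dask.array as da

x = da.random.random(size=(200, 300), chunks=(50, 100))
y = da.random.random(size=(300, 120), chunks=(100, 40))

# Each task multiplies one pair of blocks and partial results are summed,
# so no single task ever materializes a concatenated contraction axis.
z = da.matmul(x, y)
assert np.allclose(z.compute(), np.matmul(x.compute(), y.compute()))
```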