Sparse softmax support (CPU)#36305
Conversation
Force-pushed 1842a1c to cb88366.
We had a bunch of discussions about whether it is appropriate to treat unspecified sparse tensor entries as anything other than zero; specifically, whether we should be allowed to overload the old function name in this situation (some discussion about this at #1369). Did you folks at Quansight come to a decision about what should be done here?
@ezyang we have not discussed the overloading vs name-prefixing options yet, but we have discussed introducing a so-called fill value for sparse tensors that would set the default value of unspecified sparse tensor entries. The fill value feature would also be useful for the sparse softmax case, as it would let users specify the desired behavior explicitly:
In addition, the fill value feature has the potential to eliminate the overloading vs name-prefixing question, and in general it would bridge sparse and dense tensors in a consistent way. The fill value has been implemented in https://sparse.pydata.org/ and from #9674 (comment) I see you have already discussed this option as well. For this PR, we have two choices:
In the long term, with the perspective of having the fill value feature available, option 1 makes sense to me. Still, I am not certain at the moment which approach would be preferable to users. Thoughts?
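To make the fill-value discussion concrete, here is a small sketch (not this PR's API) of the semantics being described: if unspecified entries are treated as a fill value of `-inf`, an ordinary dense softmax automatically ignores them, since `exp(-inf) = 0`.

```python
import torch

# Sketch only: emulate "unspecified entries filled with -inf" on a dense
# tensor, where `mask` marks the positions a sparse tensor would specify.
dense = torch.tensor([[1.0, 0.0, 2.0],
                      [0.0, 3.0, 0.0]])
mask = dense != 0

# Fill unspecified positions with -inf, then take an ordinary dense softmax;
# those positions contribute exp(-inf) = 0 to the normalizer.
filled = dense.masked_fill(~mask, float("-inf"))
out = torch.softmax(filled, dim=1)
```

Each row of `out` still sums to 1, but only over the specified entries, which is exactly the `-inf` fill-value behavior under discussion.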
If y'all are prepared to take on the maintenance burden of a fill value, I think I would also agree that this would be best, and then we don't have to introduce sparse versions of all of our functions. Without fill, (1) will technically require you to make BC-breaking changes in the future, as you will turn an operation that previously implicitly converted zero-filled tensors to neginf-filled tensors into an error. It will save you work in the future if you do (2) first.
OK, it makes sense. Here's a plan:
SGTM.
Force-pushed 9144b22 to 7661c03.
@pearu can you get someone from Quansight to review first?
|
All the loops could be replaced with TensorIterator loops, I think. TensorIterator loops over all dimensions not involved in the reduction, and the kernel loop iterates over the reduction dimension, so you can perform the reduction and scaling all at once. Otherwise you could iterate over each element and rescale it appropriately, as long as you know which value to rescale it with. TensorIterator also brings multithreading for free.
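As a rough illustration of the loop structure being discussed (the function and names here are mine, not the PR's): the sparse values can be grouped into "pools" by their position along the non-softmax dimensions, and each pool is reduced independently — this is the per-group work a TensorIterator kernel would perform.

```python
import torch

# Hypothetical sketch of the per-pool reduction: pool_ids[i] says which
# softmax group the i-th sparse value belongs to.
def softmax_over_pools(pool_ids, values):
    out = torch.empty_like(values)
    for p in pool_ids.unique():
        sel = pool_ids == p
        v = values[sel]
        e = (v - v.max()).exp()   # max-shift for numerical stability
        out[sel] = e / e.sum()    # normalize within the pool
    return out
```

TensorIterator would parallelize over the pools and fuse the max/exp/sum passes; this naive loop shows only the math.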
Force-pushed 5d652ec to f7f2101.
… Implement support for dtype argument.
…alues. Rename iscalar_t to index_t. Update docs. Clean up.
Force-pushed 85a58cf to 7d095a3.
```python
if TYPE_CHECKING:
    from torch import dtype as DType
else:
    DType = int
```
Yes, we're going to need to clue in TorchScript about this, unfortunately >:(
Here are some benchmark results for
Some notes:
Failure is legit:
facebook-github-bot
left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Interesting to read the conversations in this PR again in light of maskedtensor cc @george-qi |
Summary: This PR implements softmax support for sparse tensors. The sparse softmax is related to the dense softmax when the values of unspecified sparse tensor entries are taken to be `-inf`, which has the effect of "zero entries are ignored". This relation is used here for testing the correctness of the results. Resolves pytorch#23651 for CPU.

- [x] sparse softmax
  - [x] CPU C++ implementation
  - [x] unittests
  - [x] update softmax documentation
  - [x] autograd support
- [x] sparse log_softmax
  - [x] CPU C++ implementation
  - [x] unittests
  - [x] update log_softmax documentation
  - [x] autograd support

Pull Request resolved: pytorch#36305
Differential Revision: D21566540
Pulled By: ezyang
fbshipit-source-id: a632ea69c38622f960721482e442efeb8d0a54fc
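For reference, the functionality this PR landed is exposed as `torch.sparse.softmax` in current PyTorch; a minimal usage sketch on a COO tensor:

```python
import torch

# Build a small 2x3 sparse COO tensor with entries at (0,0), (0,2), (1,1).
i = torch.tensor([[0, 0, 1],
                  [0, 2, 1]])
v = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(i, v, (2, 3)).coalesce()

# Softmax along dim=1, computed only over the specified entries;
# unspecified entries behave as -inf (zero weight) and stay unspecified.
out = torch.sparse.softmax(s, dim=1)
dense = out.to_dense()
```

Row 1 has a single specified entry, so its softmax value is 1; in row 0 the two specified entries share all the weight, and the unspecified position stays zero after `to_dense()`.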