Reference implementations for softmax, log_softmax, logsumexp #79423
IvanYashchuk wants to merge 17 commits into pytorch:master from
Conversation
✅ No Failures (0 Pending) as of commit 8893024 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
What are the errors with the aten executor?

For
```python
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a = prims.convert_element_type(a, computation_dtype)
```
Let's make this conditional on the conversion being required -- we actually have a function for this if you'd prefer, but it's in an awkward place (it could be moved to utils if you'd rather use it vs. a custom conditional)
pytorch/torch/_prims/wrappers.py, line 20 in 38e717d
Thanks! `_maybe_convert_to_dtype` is already used in this file, so we can keep it defined in the awkward place for now 🙂
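For readers outside the thread: the point is to skip casts that would be no-ops. A minimal sketch of that pattern in plain torch (not the actual `_maybe_convert_to_dtype` from `torch/_prims/wrappers.py`, whose details may differ):

```python
import torch

def maybe_convert_to_dtype(a: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Only emit a conversion when the dtypes actually differ, so traces
    # built from the reference are not polluted with no-op casts.
    if a.dtype != dtype:
        return a.to(dtype)
    return a
```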
```python
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a = prims.convert_element_type(a, computation_dtype)
    a_max = amax(a, dim, keepdim=True)
    shifted = a - a_max
```
Nit: comment for why the shift occurs would be nice
Oh, the shift is actually not required here because stabilized logsumexp is used. I will remove it.
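In other words, `log_softmax` can lean on the stabilized `logsumexp` directly. A minimal sketch of that identity, written against `torch.logsumexp` for illustration:

```python
import torch

def log_softmax_sketch(a: torch.Tensor, dim: int) -> torch.Tensor:
    # log_softmax(a) = a - logsumexp(a); no separate max-shift is needed
    # because logsumexp is already numerically stabilized internally.
    return a - torch.logsumexp(a, dim, keepdim=True)

x = torch.randn(4, 5)
assert torch.allclose(log_softmax_sketch(x, 1), torch.log_softmax(x, 1))
```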
```python
    a_max = amax(a, dim, keepdim=True)
    shifted = a - a_max
    shifted_logsumexp = logsumexp(shifted, dim, keepdim=True)
    return prims.convert_element_type(
```
Let's also make this conditional on the conversion being required
```python
def _squeeze_multiple(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
```
Unify this with prims.squeeze -- see pytorch/torch/_prims/__init__.py, line 1636 in 38e717d.
I think we only need one? But maybe this implementation is better for the prim?
I didn't notice that `prims.squeeze` works for multiple specified dimensions. Both implementations are fine; I will not touch `prims.squeeze` at this time.
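As a sketch of what a multi-dimension squeeze helper can look like (a plain-torch illustration, not necessarily the PR's `_squeeze_multiple` or `prims.squeeze`):

```python
import torch

def squeeze_multiple(a: torch.Tensor, dims) -> torch.Tensor:
    # Squeeze the given dims from highest to lowest so each squeeze
    # does not shift the indices of the dims still to be squeezed.
    for d in sorted(dims, reverse=True):
        a = a.squeeze(d)
    return a

x = torch.ones(1, 3, 1, 2)
print(squeeze_multiple(x, (0, 2)).shape)  # torch.Size([3, 2])
```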
```python
@out_wrapper
def logsumexp(
    a: TensorLikeType,
    dims: DimsType,
```
Right, it should be `dim` to match the torch namespace. I think I was confused because it also accepts and works for several dimensions:

dim (int or tuple of python:ints)
```python
import torch
a = torch.ones(3, 3)
torch.logsumexp(a, (0, 1))
# tensor(3.1972)
```

```python
    keepdim: bool = False,
) -> TensorLikeType:
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type dims which expands integers to tuples of length 1
```
```python
        a_max_squeezed = _squeeze_multiple(a_max, dims) if not keepdim else a_max
        result = log(sum(exp(a - a_max), dims, keepdim=keepdim)) + a_max_squeezed
    else:
        result = log(sum(exp(a), dims, keepdim=keepdim))
```
Add a comment for what this case covers (integer and boolean dtypes)
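Putting the pieces together, a self-contained sketch of the branching above in plain torch (illustrative only; the actual reference uses prims/refs helpers, canonicalizes dims, and also handles complex dtypes):

```python
import torch

def logsumexp_sketch(a: torch.Tensor, dims, keepdim: bool = False) -> torch.Tensor:
    # Assumes non-negative integer dims; illustration only.
    if isinstance(dims, int):
        dims = (dims,)
    if a.is_floating_point():
        # Stabilized path: shift by the per-slice max so exp() cannot overflow.
        a_max = torch.amax(a, dims, keepdim=True)
        # An all-inf slice would turn the shift into -inf - (-inf) = nan,
        # so infinite maxima are replaced with 0 before shifting.
        a_max = torch.where(a_max.abs() == float("inf"), torch.zeros_like(a_max), a_max)
        result = torch.log(torch.sum(torch.exp(a - a_max), dims, keepdim=keepdim))
        if not keepdim:
            for d in sorted(dims, reverse=True):
                a_max = a_max.squeeze(d)
        result = result + a_max
    else:
        # Integer and boolean inputs take the plain, unshifted path.
        result = torch.log(torch.sum(torch.exp(a), dims, keepdim=keepdim))
    return result

x = torch.randn(3, 3)
assert torch.allclose(logsumexp_sketch(x, (0, 1)), torch.logsumexp(x, (0, 1)))
```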
```python
        dims = (dims,)
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        a_max = amax(a, dims, keepdim=True)
        a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
```
really elegant code here
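The `where` masking is what makes slices with an infinite maximum behave: without it, the shift `a - a_max` would produce `-inf - (-inf) = nan`. A quick sanity check against `torch.logsumexp` (expected output in the comment):

```python
import torch

a = torch.tensor([float("-inf"), float("-inf")])
# With the masking, the result stays the mathematically correct
# log(exp(-inf) + exp(-inf)) = log(0) = -inf instead of nan.
print(torch.logsumexp(a, 0))  # tensor(-inf)
```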
```python
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a = prims.convert_element_type(a, computation_dtype)
```
Similar comments here as with log_softmax re: conditional conversions
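The same pattern applies here: cast into the computation dtype and back out only when the dtypes actually differ. A plain-torch sketch of a softmax reference with conditional conversions (illustrative; `get_computation_dtype` is approximated, and the PR's version goes through prims):

```python
from typing import Optional

import torch

def softmax_sketch(a: torch.Tensor, dim: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    result_dtype = dtype or a.dtype
    # Rough stand-in for utils.get_computation_dtype: accumulate
    # low-precision float inputs in float32 for accuracy.
    computation_dtype = torch.float32 if a.dtype in (torch.float16, torch.bfloat16) else a.dtype
    if a.dtype != computation_dtype:  # conditional conversion in
        a = a.to(computation_dtype)
    a_exp = torch.exp(a - torch.amax(a, dim, keepdim=True))
    result = a_exp / torch.sum(a_exp, dim, keepdim=True)
    if result.dtype != result_dtype:  # conditional conversion out
        result = result.to(result_dtype)
    return result

x = torch.randn(2, 5)
assert torch.allclose(softmax_sketch(x, 1), torch.softmax(x, 1))
```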
mruberry left a comment:

This is awesome, @IvanYashchuk! I made some inline comments for your review, nothing major, and the lint job needs to be fixed, but I'm approving this for velocity because I'm sure you'll sort out the review and the jobs.
Forward AD test fails with new sample input that reduces over multiple dims
@pytorchbot merge -g

@pytorchbot successfully started a merge job. Check the current status here.
Hey @IvanYashchuk.
Summary: This PR adds references for:
- `torch.softmax`
- `torch.log_softmax`
- `torch.logsumexp`

Unfortunately, none of them currently pass `test_python_ref_executor`, even with the `"aten"` executor.

Pull Request resolved: #79423
Approved by: https://github.com/mruberry
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4fc7832d72926d65d773ae4e4ae0ed7fc573f0c7
Reviewed By: malfet
Differential Revision: D37156829
fbshipit-source-id: 88a1ed3d42fda30b880d8a2fe48f385ebdb98d22