OpInfo for nn.functional.softmax #62077
Conversation
💊 CI failures summary: as of commit bd9a1d3 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚
This looks pretty good @krshrimali; I have one suggestion (inline) for how to tweak testing the `dtype` kwarg.
A few updates:

- This PR now uses the sample inputs function of `log_softmax` for `softmax`.
- Added an alias for `softmax` (a rough sketch of the registration follows this list).
- Code clean-up for some OpInfos where params are passed with default values (which isn't needed).
- Skip removal for the `test_jit_alias_remapping` test of `log_softmax`.
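As a rough sketch of how the alias is registered (fields other than `aliases` are assumptions for illustration, not copied from the diff):

```python
# Illustrative OpInfo entry (not verbatim from the PR): declaring the alias
# by its dotted name lets the OpInfo framework test nn.functional.softmax
# against torch.softmax with the same sample inputs.
OpInfo('softmax',
       aliases=('nn.functional.softmax',),
       sample_inputs_func=sample_inputs_softmax_variant),
```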
cc: @mruberry @zou3519 (sorry for the ping over the weekend, please review whenever you find time)
Update: I'm taking a look at the XLA error; if it's something non-trivial, I'll probably remove the scalar input and file an issue for this error. Also, this is only reproducible on XLA (tested on Google Colab) with scalar tensors:

```python
# Works on CPU
>>> torch.log_softmax(torch.rand((), device='cpu'), dim=0)
tensor(0.)

# Fails on XLA
>>> torch.log_softmax(torch.rand((), device='xla'), dim=0)
```

Error on XLA (the full trace is verbose; repeated interpreter frames are elided below):

```
RuntimeError Traceback (most recent call last)
<ipython-input-4-470607feefbc> in <module>()
----> 1 torch.log_softmax(torch.rand((), device='xla'), dim=0)
RuntimeError: torch_xla/csrc/helpers.cpp:97 : Check failed: min_shape_dim <= dim && dim <= max_shape_dim
*** Begin stack trace ***
tensorflow::CurrentStackTrace()
torch_xla::XlaHelpers::GetCanonicalDimensionIndex(long, long)
torch_xla::XLATensor::log_softmax(torch_xla::XLATensor const&, long, c10::optional<c10::ScalarType>)
torch_xla::AtenXlaType::_log_softmax(at::Tensor const&, long, bool)
c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, long, bool), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, long, bool> >, at::Tensor (at::Tensor const&, long, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, long, bool)
at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, long, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, long, bool)> const&, c10::DispatchKeySet, at::Tensor const&, long, bool) const
at::redispatch::_log_softmax(c10::DispatchKeySet, at::Tensor const&, long, bool)
at::_log_softmax(at::Tensor const&, long, bool)
at::native::log_softmax(at::Tensor const&, long, c10::optional<c10::ScalarType>)
at::Tensor::log_softmax(long, c10::optional<c10::ScalarType>) const
... (repeated CPython interpreter frames elided: _PyEval_EvalFrameDefault, _PyFunction_FastCallKeywords, ..., __libc_start_main, _start) ...
*** End stack trace ***
Value out of range (expected to be in range of [0, -1], but got 0)
```

Also, the error message doesn't look correct: an expected range of `[0, -1]` is empty, and for a 0d tensor PyTorch's valid range is `[-1, 0]`. Will update once I've finished building XLA locally to test the macros used. Thanks!
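For reference, eager PyTorch wraps `dim` for 0-d tensors as if they were 1-d, so the valid range is `[-1, 0]`; a quick check on CPU:

```python
import torch

x = torch.rand(())  # 0-d (scalar) tensor
# Both calls succeed on CPU: dim may be 0 or -1 for a 0-d tensor,
# i.e. the valid range is [-1, 0], not the [0, -1] from the XLA error.
print(torch.log_softmax(x, dim=0))   # tensor(0.)
print(torch.log_softmax(x, dim=-1))  # tensor(0.)
```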
Updates: PyTorch on 0d tensors with `dim=0` doesn't throw an error, but XLA does, so the scalar case is now handled separately (see the review discussion below). This PR should be ready for review; hopefully the tests will pass. Thanks!
pmeier
left a comment
Two nits inline, otherwise LGTM! Thanks @krshrimali.
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
```python
def generator():
    for shape, args, kwargs in cases:
        yield SampleInput(make_arg(shape), args=args, kwargs=kwargs)

cases = (
```
This is a pathological case, but could we add `((), (0,))` to the list?
Thanks for the comment @zou3519! I had added this before, but the test fails on the XLA device: #62077 (comment). PyTorch on 0d tensors with `dim=0` doesn't throw an error, but on XLA it does.
Gotcha, sorry for not seeing that! The action items you proposed (file an issue, leave that case out of the OpInfo) sgtm
Thanks, @zou3519! After discussing with @mruberry, we decided to add a separate case when the device type isn't XLA, so that we don't skip this input for CPU/CUDA devices. I've also filed an issue here: pytorch/xla#3061.
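A sketch of what that gating could look like inside the sample-inputs function (names follow the snippets quoted in this thread; the shape list and the device check are illustrative, not the exact code that landed):

```python
def sample_inputs_softmax_variant(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
    # Assumes the test-suite helpers used elsewhere in this thread:
    # make_tensor, SampleInput, and functools.partial.
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Illustrative shapes only.
    cases = [
        ((5,), (0,)),
        ((5, 5), (0,)),
        ((5, 5), (1,)),
    ]
    # PyTorch accepts dim=0 for 0-d tensors, but XLA raises an error
    # (pytorch/xla#3061), so only add the scalar case off-XLA.
    if torch.device(device).type != 'xla':
        cases.append(((), (0,)))

    return [
        SampleInput(make_arg(shape), args=dim,
                    kwargs=dict(dtype=torch.float64) if with_dtype else None)
        for shape, dim in cases
    ]
```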
zou3519
left a comment
This looks pretty good. I added some comments about some more cases for completeness; after that we should be good to go.
Lint is failing: https://github.com/pytorch/pytorch/pull/62077/checks?check_run_id=3191463113
Thanks, @zou3519, for the pointer; I've fixed it now. Hopefully the tests will pass :)
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```python
for shape, args, kwargs in cases:
    yield SampleInput(make_arg(shape), args=args, kwargs=kwargs)

# PyTorch on XLA throws an error when a `dim` argument is passed for a 0d tensor.
# See https://github.com/pytorch/xla/issues/3061 for more details.
```
```python
return list(generator())

return [
    SampleInput(make_arg(shape), args=dim, kwargs=dict(dtype=torch.float64) if with_dtype else None)
    for shape, dim in cases
```
Style nit: put the for loop first for readability (doesn't have to be changed in this PR)
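One reading of the nit, with an explicit loop leading instead of the comprehension (equivalent behavior; purely a readability rewrite, and illustrative only since the reviewer notes it need not change in this PR):

```python
# Same samples as the comprehension above, built with the loop first.
samples = []
for shape, dim in cases:
    samples.append(SampleInput(
        make_arg(shape),
        args=dim,
        kwargs=dict(dtype=torch.float64) if with_dtype else None))
return samples
```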
```python
def sample_inputs_log_softmax(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):

# Used for both log_softmax and softmax
def sample_inputs_softmax_variant(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
```
"with_dtype" should be kwarg-only
```python
       dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
       supports_forward_ad=True,
       sample_inputs_func=sample_inputs_max_min_binary,),
# `softmax` supports different dtypes based on whether `dtype` argument,
```
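A sketch of how the two dtype configurations can share one sample-inputs function via `functools.partial` (field values are illustrative, not copied from the PR; `variant_test_name` is how OpInfo distinguishes two entries for the same op):

```python
# Illustrative entries: one configuration that never passes the `dtype`
# kwarg and one that always does, sharing the same sample-inputs function.
OpInfo('softmax',
       aliases=('nn.functional.softmax',),
       sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=False)),
OpInfo('softmax',
       variant_test_name='with_dtype',
       sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True)),
```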
mruberry
left a comment
Really nice skip removal. Overall looks good. I made a few comments inline to consider in future PRs; no changes needed for this one.
This PR:

- Adds an OpInfo for `softmax`, with `nn.functional.softmax` as an alias.
- Removes the skip for the `test_jit_alias_remapping` test of `log_softmax`.

Please see pytorch/functorch#78 and #54261.
cc: @mruberry @zou3519 @pmeier