Conversation
zou3519 left a comment

Looks great! Minor comments:
- the default behavior of `num_classes <= 0` is a little strange; can we use `optional<int64_t>`? I'm not 100% sure our code generation supports this
- Please add some test cases for some corner cases
- minor documentation nits
zou3519 left a comment

Three more minor things and then this should be good to go!
- Let's change the default to `one_hot(self, num_classes=-1)`, where `-1` means "PyTorch, please infer the number of classes from the data"
- Check that the input to `one_hot` is an integral-type tensor
- Add a CUDA test in test/test_cuda.py
@zasdfgbnm let's do those pull requests one step at a time. This one, introducing […] Given that, it would really be great if we had the […]

Oh I see, your new PR contains this one @zasdfgbnm. I think we should either (1) finish this one since it is close to being good, or (2) abandon this one in favor of your new PR so there are fewer things that move around. It's up to you which option you'd like to pursue.

@zou3519 I think landing this one alone (with […])
aten/src/ATen/native/Onehot.cpp (Outdated)

```cpp
namespace at { namespace native {

Tensor one_hot(const Tensor &self, int64_t num_classes) {
    AT_ASSERTM(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
```
AT_ERROR is better, iirc AT_ASSERTM tells the user to submit a bug to pytorch
@zou3519 Do you mean AT_CHECK? I don't think AT_ERROR takes any condition as an arg; it simply reports an error unconditionally.
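For readers unfamiliar with the ATen error macros, the distinction being debated can be sketched in Python. These are hypothetical stand-ins for illustration, not real PyTorch APIs: `AT_ERROR` raises unconditionally with a formatted message, while `AT_CHECK` takes a condition and raises only when it is false.

```python
# Hypothetical Python analogues of the ATen error macros under discussion.
# Illustrative stand-ins only, not real PyTorch APIs.

def at_error(*msg_parts):
    """Like AT_ERROR: unconditionally raise, concatenating the message parts."""
    raise RuntimeError("".join(str(p) for p in msg_parts))

def at_check(cond, *msg_parts):
    """Like AT_CHECK: raise only when the condition is false."""
    if not cond:
        at_error(*msg_parts)
```

So a call like `at_check(num_classes > 2, "num_classes (", 2, ") too small")` is a no-op when the condition holds and raises otherwise, which is the behavior the review comments are converging on.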
aten/src/ATen/native/Onehot.cpp (Outdated)

```cpp
// non-empty tensor
AT_ASSERTM(self.min().item().toLong() >= 0, "Class values must be non-negative.");
```
aten/src/ATen/native/Onehot.cpp (Outdated)

```cpp
if (num_classes == -1) {
    num_classes = self.max().item().toLong() + 1;
} else {
    AT_ASSERTM(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
}
```
Use AT_ERROR instead. It might be better for user usability if it were something like:

```cpp
AT_ERROR(num_classes > self.max().item().toLong(), "Class values (", self.max().item().toLong(), ") must be smaller than num_classes (", num_classes, ")");
```
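The semantics under review can be sketched in plain Python. This is a simplified, list-based stand-in for the ATen implementation (the name `one_hot_sketch` is illustrative, not the real API): `num_classes=-1` means "infer the number of classes as `max(indices) + 1`", and otherwise `num_classes` must be strictly greater than every class value.

```python
def one_hot_sketch(indices, num_classes=-1):
    """Simplified sketch of the one_hot semantics discussed in this PR.

    indices: a non-empty list of non-negative ints.
    num_classes == -1 means: infer it as max(indices) + 1.
    """
    if any(i < 0 for i in indices):
        raise ValueError("Class values must be non-negative.")
    if num_classes == -1:
        num_classes = max(indices) + 1
    elif num_classes <= max(indices):
        raise ValueError(
            f"Class values ({max(indices)}) must be smaller than "
            f"num_classes ({num_classes})")
    # One row per index, with a single 1 in the index's column.
    return [[1 if j == i else 0 for j in range(num_classes)] for i in indices]
```

For example, `one_hot_sketch([0, 2, 1])` infers `num_classes = 3` and returns `[[1, 0, 0], [0, 0, 1], [0, 1, 0]]`, while passing a too-small `num_classes` raises the error the reviewers are bikeshedding above.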
zou3519 left a comment

lgtm, with a few minor nits:
- Let's use AT_ERROR for the error message instead of AT_ASSERTM: AT_ASSERTM is for developers, AT_ERROR is for users
- Please fix the documentation as noted
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@zasdfgbnm test failures look real

@zou3519 Replacing […]
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Right, AT_ERROR unconditionally errors but AT_CHECK checks a condition before throwing an error; that is my bad. Thanks for catching that!

Test failures are fake
facebook-github-bot left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Closes: #15060
Differential Revision: D13528014
Pulled By: ezyang
fbshipit-source-id: 5a18689a4c5638d92f9390c91517f741e5396293
This PR can be closed now, I suppose.

Why didn't @facebook-github-bot close this?
Summary: This PR does three things:

~~Allow `int64_t?` in function schema, which provides an elegant way of implementing nullable int arguments, as discussed in #15208 (review). Originally implemented in #15235. Example:~~

```yaml
- func: myop(Tensor self, int64_t? dim=None) -> Tensor
  variants: function
```

~~cc: zou3519~~

Edit: implemented in #15234

Previously tried in #12064. There was a problem that C++ does not have kwarg support, which makes it confusing to know whether `unique(t, 1)` actually means `unique(t, dim=1)` or `unique(t, sorted=1)`. Now I think I have a better idea of how to implement this: there are two ATen operators, `unique` and `unique_dim`. `unique` has the same signature as in Python and is exported to both Python and C++. `unique_dim` has the signature `unique_dim(tensor, dim, sorted=False, return_inverse=False)` and is only exported to C++, where it can be used more naturally by a C++ user.

Differential Revision: D13540278
Pulled By: wanchaol
fbshipit-source-id: 3768c76a90b0881f565a1f890459ebccbdfe6ecd
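The `unique` signature described in the summary can be sketched in plain Python. This is an illustrative list-based stand-in for the ATen operator, not the real implementation; `unique_dim` would be a second, C++-only operator with `dim` as a required positional argument, which sidesteps the `unique(t, 1)` ambiguity since C++ has no kwargs.

```python
def unique_sketch(values, sorted=False, return_inverse=False):
    """Sketch of the `unique` signature from the summary: deduplicate,
    optionally sort, and optionally return the inverse mapping from
    each input element to its position in the output."""
    out = list(dict.fromkeys(values))  # dedupe, preserving first-seen order
    if sorted:
        out.sort()
    if return_inverse:
        pos = {v: i for i, v in enumerate(out)}
        return out, [pos[v] for v in values]
    return out
```

For example, `unique_sketch([3, 1, 3, 2], sorted=True, return_inverse=True)` returns `([1, 2, 3], [2, 0, 2, 1])`: the sorted unique values, plus for each input element the index of its value in the output.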
Closes: #15060
cc: @zou3519
I happen to need this feature, so I just implemented it...

This PR also removes some unnecessary lambdas in test_torch.py. These unnecessary lambdas were discovered by automated tools, and I just went ahead and fixed some of them. If you don't like this change, I can revert it.