Add at::one_hot #15208

Closed
zasdfgbnm wants to merge 25 commits into pytorch:master from zasdfgbnm:onehot

Conversation

@zasdfgbnm
Collaborator

Closes: #15060

cc: @zou3519

I happen to need this feature, so I just implement it...

This PR also removes some unnecessary lambdas in test_torch.py. These unnecessary lambdas were discovered by automated tools, and I went ahead and fixed some of them. If you don't like this change, I can revert it.

@zasdfgbnm zasdfgbnm changed the title Add at::to_one_hot Add at::one_hot Dec 14, 2018
Contributor

@zou3519 zou3519 left a comment


Looks great! Minor comments:

  • the default behavior of num_classes <= 0 is a little strange; can we use optional<int64_t>? I'm not 100% sure our code generation supports this
  • Please add test cases for some corner cases
  • minor documentation nits

@zasdfgbnm
Collaborator Author

@zou3519 How does this look now? The optional approach does not work with the JIT, but that feature should be coming soon: #15154

Contributor

@zou3519 zou3519 left a comment


Three more minor things and then this should be good to go!

  • Let's change the default to one_hot(self, num_classes=-1), where -1 means "PyTorch, please infer the number of classes from the data"
  • check that the input to one_hot is an integral-type tensor
  • add a CUDA test in test/test_cuda.py
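For reference, the semantics being requested here (a default of num_classes=-1 meaning "infer from the data", a non-negative integral input, and class values bounded by num_classes) can be modeled in plain Python. This is an illustrative sketch over flat lists of ints, not the ATen implementation; the error messages are approximations:

```python
def one_hot(indices, num_classes=-1):
    """Toy model of the proposed at::one_hot on a flat list of ints."""
    if any(not isinstance(i, int) for i in indices):
        raise TypeError("one_hot is only applicable to index (integral) input")
    if any(i < 0 for i in indices):
        raise ValueError("Class values must be non-negative")
    if num_classes == -1:
        # -1 means: infer the number of classes from the data
        num_classes = max(indices) + 1
    elif max(indices) >= num_classes:
        raise ValueError("Class values must be smaller than num_classes")
    return [[1 if j == i else 0 for j in range(num_classes)] for i in indices]
```

For example, `one_hot([0, 2, 1])` infers three classes and yields `[[1, 0, 0], [0, 0, 1], [0, 1, 0]]`.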

@zasdfgbnm
Collaborator Author

zasdfgbnm commented Dec 14, 2018

@zou3519 Sorry, I didn't see your suggestion about -1 before I finished modifying the JIT to support optional<int64_t>... A new PR is coming (#15235); how do we deal with it?

I will make the other changes according to your suggestions.

@zasdfgbnm
Collaborator Author

@zou3519 How about we keep -1 as a legal input and land this PR, then work on #15235?

I think #15235 would also be useful to finish #12064

@zou3519
Contributor

zou3519 commented Dec 14, 2018

@zasdfgbnm let's do these pull requests one step at a time. This one, introducing torch.one_hot, is the first. I'm not sure how long it'll take to get the second one (using optional<int64_t> with the codegen) working, so let's write this one under the assumption that it is standalone.

Given that, it would be great to have num_classes=-1 as the default in this PR, because it makes more sense than num_classes=0 from a PyTorch API perspective.

@zou3519
Contributor

zou3519 commented Dec 14, 2018

Oh I see, your new PR contains this one @zasdfgbnm. I think we should either (1) finish this one, since it is close to being good, or (2) abandon this one in favor of your new PR so there are fewer things moving around. It's up to you which option you'd like to pursue.

@zasdfgbnm
Collaborator Author

@zou3519 I think landing this one alone (with num_classes=-1) is a good idea, because this is much easier in terms of both implementation and review.

```cpp
namespace at { namespace native {

Tensor one_hot(const Tensor &self, int64_t num_classes) {
  AT_ASSERTM(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
```
Contributor

AT_ERROR is better; IIRC, AT_ASSERTM tells the user to submit a bug report to PyTorch.

Collaborator Author

@zou3519 Do you mean AT_CHECK? I don't think AT_ERROR takes any condition as an argument; it simply reports an error unconditionally.

```cpp
}

// non-empty tensor
AT_ASSERTM(self.min().item().toLong() >= 0, "Class values must be non-negative.");
```
Contributor

Use AT_ERROR instead

```cpp
if (num_classes == -1) {
  num_classes = self.max().item().toLong() + 1;
} else {
  AT_ASSERTM(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
```
Contributor

Use AT_ERROR instead. It might be better for user usability if it were something like

```cpp
AT_ERROR(num_classes > self.max().item().toLong(), "Class values (", self.max().item().toLong(), ") must be smaller than num_classes (", num_classes, ")");
```

Contributor

@zou3519 zou3519 left a comment

lgtm, with a few minor nits:

  • Let's use AT_ERROR for the error message instead of AT_ASSERTM: AT_ASSERTM is for developers, AT_ERROR is for users
  • Please fix the documentation as noted

Contributor

@facebook-github-bot facebook-github-bot left a comment

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519
Contributor

zou3519 commented Dec 20, 2018

@zasdfgbnm test failures look real

@zasdfgbnm
Collaborator Author

@zou3519 Replacing AT_ERROR with AT_CHECK should fix the error. Let's see the outcome of CI.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519
Contributor

zou3519 commented Dec 20, 2018

Right, AT_ERROR errors unconditionally, while AT_CHECK checks a condition before throwing an error; that is my bad. Thanks for catching that!
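The distinction can be modeled with two tiny Python helpers (hypothetical stand-ins for the C++ macros, for illustration only):

```python
def at_error(*msg_parts):
    """Stand-in for ATen's AT_ERROR: unconditionally raises with the message."""
    raise RuntimeError("".join(str(p) for p in msg_parts))

def at_check(cond, *msg_parts):
    """Stand-in for ATen's AT_CHECK: raises only when the condition is false."""
    if not cond:
        at_error(*msg_parts)

# AT_CHECK-style guard from the one_hot discussion: passes silently when the
# condition holds, produces a user-facing error when it does not.
at_check(3 > 2, "Class values must be smaller than num_classes.")
```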

@ezyang
Contributor

ezyang commented Dec 20, 2018

Test failures are fake

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Dec 20, 2018
Summary: Closes: #15060

Differential Revision: D13528014

Pulled By: ezyang

fbshipit-source-id: 5a18689a4c5638d92f9390c91517f741e5396293
@vishwakftw
Contributor

This PR can be closed now, I suppose.

@zasdfgbnm
Collaborator Author

Why didn't @facebook-github-bot close this?

@zasdfgbnm zasdfgbnm closed this Dec 21, 2018
@zasdfgbnm zasdfgbnm deleted the onehot branch December 21, 2018 00:02
@zasdfgbnm zasdfgbnm mentioned this pull request Dec 21, 2018
facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2019
Summary:
This PR does three things:

~~Allow `int64_t?` in function schema, which provides an elegant way of implementing nullable int arguments, as discussed in #15208 (review)~~

~~Originally implemented in #15235~~

~~Example:~~

```yaml
- func: myop(Tensor self, int64_t? dim=None) -> Tensor
  variants: function
```

~~cc: zou3519~~

Edit: implemented in #15234

Previously tried in #12064. There was a problem: C++ does not have kwarg support, which makes it confusing whether `unique(t, 1)` means `unique(t, dim=1)` or `unique(t, sorted=1)`.

Now I think I have a better idea of how to implement this: there are two ATen operators, `unique` and `unique_dim`. `unique` has the same signature as in Python and is exported to both Python and C++. `unique_dim` has the signature `unique_dim(tensor, dim, sorted=False, return_inverse=False)` and is exported only to C++, where it can be used more naturally by a C++ user.

Differential Revision: D13540278

Pulled By: wanchaol

fbshipit-source-id: 3768c76a90b0881f565a1f890459ebccbdfe6ecd
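The two-operator `unique`/`unique_dim` design described in the commit message above can be sketched in pure Python. This is a hypothetical toy model on lists, not the ATen code; `unique_dim` here only handles dim=0 over a list of rows:

```python
_sort = sorted  # keep a handle on the builtin, since the ATen arg is literally named `sorted`

def unique(values, sorted=False, return_inverse=False):
    """Model of the Python-facing `unique`: same signature as in Python,
    with no dim argument, so positional calls are unambiguous."""
    uniq = list(dict.fromkeys(values))  # first-occurrence order
    if sorted:
        uniq = _sort(uniq)
    if return_inverse:
        index = {v: i for i, v in enumerate(uniq)}
        return uniq, [index[v] for v in values]
    return uniq

def unique_dim(values, dim, sorted=False, return_inverse=False):
    """Model of the C++-only `unique_dim(tensor, dim, sorted, return_inverse)`:
    `dim` is an explicit positional argument, so unique_dim(t, 1) can only
    mean dim=1. This toy version supports dim=0 over a list of rows."""
    if dim != 0:
        raise NotImplementedError("toy sketch only supports dim=0")
    return unique([tuple(row) for row in values],
                  sorted=sorted, return_inverse=return_inverse)
```

This keeps `unique(t, True)` unambiguous in Python (it can only mean `sorted=True`), while C++ callers who need a dim get the separate, explicit `unique_dim` overload.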
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 21, 2019

Successfully merging this pull request may close these issues.

Add to_one_hot function

9 participants