Use c10::variant-based enums for Reduction #27942
yf225 wants to merge 44 commits into gh/yf225/11/base from gh/yf225/11/head
Conversation
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
```cpp
TORCH_ARG(bool, swap) = false;
/// Specifies the reduction to apply to the output. Default: Mean
TORCH_ARG(Reduction::Reduction, reduction) = Reduction::Mean;
TORCH_ARG(10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>, reduction) = torch::kMean;
```
Sorry, I was looking to see if a PR of mine was in the list and found this PR, which concerns what I was doing :P I think you are missing a `c` here (it should be `c10`), and in previous declarations too!
@CarMiranda Thanks a lot for the catch! I just fixed it. After this PR is merged, I will do a sweep to change all torch::nn layers that use Reduction to use the corresponding variant type. :D
…div / mse_loss / binary_cross_entropy" Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
torch/csrc/api/include/torch/enum.h
```cpp
// ```
// Tensor some_functional(
//     const Tensor& input,
//     const SomeOptions& options = {}) {
```
This smells wrong to me. Why don't you just take it by value? It's a very small struct.
I updated the comment here and will do a sweep to take Options by value in all functionals in a follow-up PR. It doesn't fix the problem that TORCH_OPTIONS_CTOR_VARIANT_ARG3 / TORCH_OPTIONS_CTOR_VARIANT_ARG4 try to address, though :(
```cpp
return torch::l1_loss(
    input,
    target,
    c10::visit(enumtype::_reduction_get_enum{}, options.reduction()));
```
Instead of manually typing out c10::visit everywhere, why not come up with a good API for doing this and call that instead? Especially since _reduction_get_enum is underscored...
Thanks for the suggestion! I added torch::enumtype::reduction_get_enum() API for this purpose :D
Can you say more about what "fix F::kl_div / mse_loss / binary_cross_entropy" means?
The original logic in F::kl_div / mse_loss / binary_cross_entropy doesn't match that of the Python version. I moved those changes to another PR since they are not strictly related to the Reduction enum changes.
Stack from ghstack:
Differential Revision: D18202857