C++ API parity: at::Tensor::requires_grad_ #26332

pbelevich wants to merge 22 commits into gh/pbelevich/8/base from
Conversation
yf225
left a comment
Overall looks awesome! I left a minor comment.
This will break things in

@eellison Do you mind elaborating more on the use cases that this will break? The

All of the functions in
Differential Revision: [D17427575](https://our.internmc.facebook.com/intern/diff/D17427575)
Could you remove the register_prim_ops implementation? Those are for registering ops that are not bound to the torch C++ library. There is no need to have it in both C++ and register_prim_ops, since the C++ ops are exposed to the JIT already.
```cpp
        return 0;
      },
      aliasAnalysisConservative()),
  Operator(
```
This looks good and should fix the issue in JIT
@pbelevich edited my comment
@eellison, this one wouldn't be exposed correctly. Note the "conservative" annotation in the prim_ops registration.
Do you know whether the register_prim_ops schema is matched before the native_functions one?
@eellison that seems to be how it is used in the file, but I don't know if bugs have crept in; the c10 dispatch code is extremely hard to understand. I believe that if there is an issue, there will be a runtime error.
From looking at the code I would suspect this should raise, since it's the same schema with different options. If it doesn't, then I would think that's a bug. cc @smessmer
@smessmer why is it not a bug that https://bddppq.github.io/codebrowser/pytorch/pytorch/aten/src/ATen/core/dispatch/Dispatcher.cpp.html#63 didn't fire? |
smessmer
left a comment
Oh wait, I'm a bit confused here. The entry in native_functions.yaml should already create a JIT op for this in register_aten_ops.cpp. Why is the one in register_prim_ops.cpp needed?
@eellison This doesn't crash because register_prim_ops.cpp isn't the c10 operator library; it's a shortcut directly into the JIT that should only be used if absolutely needed.
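For context on the point above: the JIT binding in register_aten_ops.cpp is generated from the operator's entry in native_functions.yaml. A sketch of what such an entry looks like, matching the schema discussed in this PR (the actual entry may carry additional fields such as dispatch keys):

```yaml
# Hypothetical native_functions.yaml entry; field names follow the
# usual format, but the real entry for this op may differ in detail.
- func: requires_grad_(Tensor(a!) self, bool _requires_grad=True) -> Tensor(a!)
  variants: method
```

Because codegen already produces a JIT-visible op from this entry, a second manual registration in register_prim_ops.cpp would be redundant.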
smessmer
left a comment
withdrawing my concerns with recent changes
@pbelevich merged this pull request in 46f96d1.
Summary: Pull Request resolved: pytorch/pytorch#26332
Test Plan: Imported from OSS
Differential Revision: D17427575
Pulled By: pbelevich
fbshipit-source-id: 5500169a4fa0ef9cc2a7272e13b6e2d89df09260
```cpp
      },
      aliasAnalysisConservative()),
  Operator(
      "aten::requires_grad_(Tensor(a!) self, bool _requires_grad=True) -> Tensor(a!)",
```
@pbelevich hey, any reason we have this schema with _requires_grad instead of requires_grad? This creates a discrepancy between the JIT API and the Python API, since the Python method takes requires_grad.
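To illustrate the naming discrepancy being raised: in eager-mode Python the in-place method is `Tensor.requires_grad_` and its keyword argument is spelled `requires_grad`, not `_requires_grad` as in the JIT schema above. A minimal sketch:

```python
import torch

t = torch.zeros(3)

# The Python API accepts the keyword `requires_grad`; under the JIT
# schema in this PR the same argument is named `_requires_grad`.
t.requires_grad_(requires_grad=True)

assert t.requires_grad  # the flag is now set in place
```

So scripted code addressing the argument by name would have to use a different spelling than eager code, which is the mismatch the comment points out.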