[C++ API] Initialization functions #9295
aten/src/ATen/native/Init.cpp
torch/csrc/api/src/nn/init.cpp
@pytorchbot retest this please
Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01);
Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
Tensor xavier_normal_(Tensor tensor, double gain = 1.0);
Tensor xavier_uniform_(Tensor tensor, double gain = 1.0);
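For readers unfamiliar with what an initializer such as `xavier_uniform_` computes, here is a minimal, self-contained sketch of the standard Xavier (Glorot) uniform rule: values are drawn from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). It is not the code from this PR. The `Matrix` struct, the `xavier_uniform_bound` helper, and the explicit `rng` parameter are hypothetical stand-ins so the example compiles without ATen.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical stand-in for a 2-D tensor (not an ATen type):
// row-major values plus the shape.
struct Matrix {
  std::size_t rows;  // treated as fan_out
  std::size_t cols;  // treated as fan_in
  std::vector<double> data;
};

// Bound `a` of the Xavier/Glorot uniform distribution U(-a, a).
double xavier_uniform_bound(std::size_t fan_in, std::size_t fan_out,
                            double gain = 1.0) {
  assert(fan_in + fan_out > 0);
  return gain * std::sqrt(6.0 / static_cast<double>(fan_in + fan_out));
}

// Fills `m` in place and returns it, following the trailing-underscore
// naming used for in-place functions. The RNG is passed explicitly here,
// which is an assumption of this sketch.
Matrix& xavier_uniform_(Matrix& m, std::mt19937& rng, double gain = 1.0) {
  const double a = xavier_uniform_bound(m.cols, m.rows, gain);
  std::uniform_real_distribution<double> dist(-a, a);
  for (double& v : m.data) {
    v = dist(rng);
  }
  return m;
}
```

With this sketch, a 4x2 matrix and `gain = 1.0` give a bound of sqrt(6 / 6) = 1, so every sampled value lies in [-1, 1).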
facebook-github-bot
@goldsborough has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
facebook-github-bot
@goldsborough is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
To allow our C++ customers to use our initialization methods as well, this PR moves some of the code from `torch.nn.init` to ATen, calls it from Python, and adds equivalent code to the C++ frontend.

Notes:
1. Happy to hear thoughts on whether it's ok to have e.g. `torch.nn.init.dirac_` *and* `torch.dirac_` (the former has a `no_grad` guard). We have this for `ones_` and similar functions too, so I don't mind it.
2. I left the exception checking in Python because those checks throw `ValueError`s, while ATen errors show up as `RuntimeError`s. I imagine this would break users' error handling if someone had a `try`-`except` handler for `ValueError` (or maybe that's far-fetched).

EDIT: After discussions with zdevito, the PR now simply duplicates the code in C++ exclusively for the C++ API, and we leave the Python code as-is (to make it easier for people to read/modify).

ebetica ezyang apaszke

Pull Request resolved: pytorch#9295
Differential Revision: D8813793
Pulled By: goldsborough
fbshipit-source-id: 4b969f3f75952c1be4e837e19e23b8098e5fbd4b
To allow our C++ customers to use our initialization methods as well, this PR moves some of the code from `torch.nn.init` to ATen, calls it from Python, and adds equivalent code to the C++ frontend.

Notes:
1. Happy to hear thoughts on whether it's ok to have e.g. `torch.nn.init.dirac_` *and* `torch.dirac_` (the former has a `no_grad` guard). We have this for `ones_` and similar functions too, so I don't mind it.
2. I left the exception checking in Python because those checks throw `ValueError`s, while ATen errors show up as `RuntimeError`s. I imagine this would break users' error handling if someone had a `try`-`except` handler for `ValueError` (or maybe that's far-fetched).

EDIT: After discussions with @zdevito, the PR now simply duplicates the code in C++ exclusively for the C++ API, and we leave the Python code as-is (to make it easier for people to read/modify).

@ebetica @ezyang @apaszke
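As an illustration of the `no_grad` guard mentioned in note 1, the following is a minimal sketch of how a scoped guard can wrap an in-place initializer so that the write is not recorded for differentiation. It assumes a plain thread-local flag in place of a real autograd mode. `grad_enabled`, `NoGradGuardSketch`, and `guarded_fill_` are hypothetical names for this example and are not part of the PR or of the library.

```cpp
#include <cassert>

// Hypothetical thread-local flag standing in for an autograd "grad mode".
thread_local bool grad_enabled = true;

// RAII guard: turns gradient recording off for its lifetime and restores
// the previous state on destruction, even if the scope exits early.
class NoGradGuardSketch {
 public:
  NoGradGuardSketch() : previous_(grad_enabled) { grad_enabled = false; }
  ~NoGradGuardSketch() { grad_enabled = previous_; }
  NoGradGuardSketch(const NoGradGuardSketch&) = delete;
  NoGradGuardSketch& operator=(const NoGradGuardSketch&) = delete;

 private:
  bool previous_;
};

// Hypothetical in-place initializer. Returns true if the fill ran while
// recording was disabled, which is what the guard is meant to ensure.
bool guarded_fill_(double* data, int n, double value) {
  assert(n >= 0);
  NoGradGuardSketch guard;
  for (int i = 0; i < n; ++i) {
    data[i] = value;
  }
  return !grad_enabled;
}
```

The design point is that the guarded variant and the unguarded one can share the same fill logic; only the scope in which it runs differs.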