Conversation
|
To begin, this PR creates a prototypical C++ API for |
|
This is cool! Just making sure you're aware of https://sourcegraph.com/github.com/pytorch/pytorch/-/blob/torch/csrc/jit/script/module.h - we need to make sure that there's a single nn.Module abstraction in C++ which works with python, script, tracing, etc |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/api/src/module.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Nice POC! I found a few worrying things and left comments |
There was a problem hiding this comment.
This is perhaps a naive question, but why provide a clone() method rather than a copy constructor? clone() requires CRTP if you want to return the correct type, and in the current implementation it dictates the storage for the resulting module (unique_ptr).
|
It's for when all you have is an abstract Module*, i.e. a "polymorphic copy constructor". In that case you need to know the static type to invoke the copy constructor. @ebetica was showing me some neuroevolution stuff where you guys needed this
|
|
@jgehring see the virtual copy constructor paradigm: https://www.geeksforgeeks.org/advanced-c-virtual-copy-constructor/ |
|
Ah thanks, of course, I completely missed that. Indeed, we discussed CRTP quite a lot for @ebetica's autogradpp... |
|
I'v squashed my commits a bit to now have:
|
| include_dirs += [ | ||
| cwd, | ||
| os.path.join(cwd, "torch", "csrc"), | ||
| os.path.join(cwd, "torch", "csrc", "api", "include"), |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (auto method = methods.find(name)) { | ||
| return method->get(); | ||
| } | ||
| return nullptr; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| #include "torch/csrc/autograd/variable.h" | ||
| #include <ATen/optional.h> | ||
|
|
||
| #include "torch/csrc/api/include/torch/detail/ordered_dict.h" |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| append(std::forward<Tail>(tail)...); | ||
| } | ||
|
|
||
| std::vector<std::unique_ptr<Module>> modules_; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| /// Adds a new `Module` to the `Sequential` container. | ||
| template <typename M> | ||
| void append(M&& module) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
CC @smessmer who is our resident C++ expert |
| message(STATUS "ATEN_BUILD_PATH is ${ATEN_BUILD_PATH}") | ||
| endif() | ||
|
|
||
| set(CMAKE_CXX_FLAGS "--std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-variable ${CMAKE_CXX_FLAGS}") |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Changes:
|
apaszke
left a comment
There was a problem hiding this comment.
Mostly LGTM. Let's get this in soon, because it's getting dangerously large.
|
|
||
| ## API tests | ||
|
|
||
| SET(TORCH_API_TEST_SRCS ${TORCH_SRC_DIR}/csrc/api/test/test.cpp) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| "${TORCH_SRC_DIR}/../aten/src/ATen/utils/catch/single_include" | ||
| "${COMMON_INCLUDES}") | ||
|
|
||
| INSTALL(TARGETS test_api |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| size_t size() const noexcept; | ||
|
|
||
| protected: | ||
| std::vector<Item> items_; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| void Module::type(at::ScalarType new_type) { | ||
| // parameters().apply([=](Tensor& tensor) { tensor.toType_(new_type); }); | ||
| // buffers().apply([=](Tensor& tensor) { tensor.toType_(new_type); }); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| #include "torch/csrc/autograd/variable.h" | ||
| #include <ATen/optional.h> | ||
|
|
||
| #include "torch/csrc/api/include/torch/detail/ordered_dict.h" |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if (auto method = methods.find(name)) { | ||
| return method->get(); | ||
| } | ||
| return nullptr; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| py::tuple result(self.get_parameters().size()); | ||
| size_t index = 0; | ||
| for (auto& parameter : self.get_parameters()) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| for (auto& method : self.get_methods()) { | ||
| method_names.push_back((*method)->name()); | ||
| } | ||
| return method_names; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@apaszke I was actually planning on doing all development in this PR, and squashing commits to make it manageable. This means we don't have to make it "public" yet (so people don't use it). And I think if I push it onto a branch, say |
|
I think it's easier to merge this thing in parts, even if it's going to undergo breaking changes. It's already 1,1k lines long, and reviewing PRs that have 2k+ is really hard. It's probably ok to merge this even before it's public. I doubt it's going to be used before we actually start advertising it, and we should make it clear that it's really experimental and will break your code. |
apaszke
left a comment
There was a problem hiding this comment.
Just to make sure this isn't merged before you remove the accidental ONNX submodule change.
|
Will be rebased on top of autogradpp |
This PR will be the development ground of PyTorch's C++ API for a while. To get CI feedback, without pushing to master, I will be posting continuous updates into this PR, one isolated, squashed commit at a time, and request reviews of individual commits.
CC @ebetica @jgehring