Skip to content

[jit] Initial torchbind prototype#21098

Closed
Chillee wants to merge 79 commits intopytorch:masterfrom
Chillee:torchbind
Closed

[jit] Initial torchbind prototype#21098
Chillee wants to merge 79 commits intopytorch:masterfrom
Chillee:torchbind

Conversation

@Chillee
Copy link
Collaborator

@Chillee Chillee commented May 29, 2019

I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify test_libtorch to point to where you have pytorch built. I currently require that pybind11 is included as a subdirectory of the test, but added it to the .gitignore to make this reviewable.

Currently, something like this works:

struct Foo {
  int x, y;
  Foo(): x(2), y(5){}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  void display() {
    cout<<"x: "<<x<<' '<<"y: "<<y<<endl;
  }
  int64_t add(int64_t z) {
    return (x+y)*z;
  }
};
static auto test = torch::jit::class_<Foo>("Foo")
                    .def(torch::jit::init<int64_t, int64_t>())
                    .def("display", &Foo::display)
                    .def("add", &Foo::add)
                    .def("combine", &Foo::combine);

with

@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val.display()
    print(val.add(3))

results in

x: 5 y: 3
24

Current issues:

  • The python class created by torchscript doesn't interactly properly with the surrounding code.
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    return val
  • Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe).
  void combine(Foo x) {
  • Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object).
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val2 = torch._C.Foo(100, 0)
    val.display()
    print(val.add(3))
  • Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods).
  • init is a little bit different syntax than pybind. .init<...>() instead of .def(py::init<>())
  • I couldn't figure out how to add some files into the build so they'd be copied to the include/ directories, so I symlinked them manually.
  • Currently, the conversion from Python into Torchscript doesn't work.
  • Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
  • We pass back into Python by value, currently. There's no way of passing by reference.
  • Currently can only register one method with the same type signature. This is because we create a static auto opRegistry, and the function is templated on the type signature.

Somewhat blocked on #21177. We currently use some structures that will be refactored by his PR (namely return_type_to_ivalue and ivalue_to_arg_type.

@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label May 29, 2019
@Chillee Chillee requested review from jamesr66a and suo May 29, 2019 22:35
@pytorchbot pytorchbot added the module: build Build system issues label May 29, 2019
Copy link
Collaborator

@jamesr66a jamesr66a left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! I think you're on the right track. Can you add some documentation about what this currently breaks in the PR description?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

@pytorchbot pytorchbot added module: cpp Related to C++ API module: docs Related to our documentation, both in docs/ and docblocks module: infra Relates to CI infrastructure module: pybind Related to our Python bindings / interactions with other Python libraries labels Jun 3, 2019
@pytorchbot pytorchbot added module: custom-operators custom operators, custom ops, custom-operators, custom-ops module: internals Related to internal abstractions in c10 and ATen labels Jun 4, 2019
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Chillee is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Chillee is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 3, 2019
Summary:
I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable.

Currently, something like this works:
```cpp
struct Foo {
  int x, y;
  Foo(): x(2), y(5){}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  void display() {
    cout<<"x: "<<x<<' '<<"y: "<<y<<endl;
  }
  int64_t add(int64_t z) {
    return (x+y)*z;
  }
};
static auto test = torch::jit::class_<Foo>("Foo")
                    .def(torch::jit::init<int64_t, int64_t>())
                    .def("display", &Foo::display)
                    .def("add", &Foo::add)
                    .def("combine", &Foo::combine);

```
with
```py
torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val.display()
    print(val.add(3))
```
results in
```
x: 5 y: 3
24
```

Current issues:
- [x] The python class created by torchscript doesn't interactly properly with the surrounding code.
```
torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    return val
```
- [x] Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe).
```cpp
  void combine(Foo x) {
```

- [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object).
```py
torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val2 = torch._C.Foo(100, 0)
    val.display()
    print(val.add(3))
```
- [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods).
- [x] `init` is a little bit different syntax than `pybind`. `.init<...>()` instead of `.def(py::init<>())`
- [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually.
- [ ] Currently, the conversion from Python into Torchscript doesn't work.
- [ ] Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
- [ ] We pass back into Python by value, currently. There's no way of passing by reference.
- [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature.

Somewhat blocked on pytorch/pytorch#21177. We currently use some structures that will be refactored by his PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`.
Pull Request resolved: pytorch/pytorch#21098

Differential Revision: D16634872

Pulled By: Chillee

fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
@facebook-github-bot
Copy link
Contributor

@Chillee merged this pull request in f81db8a.

@galv
Copy link
Collaborator

galv commented Aug 8, 2019

@Chillee Is there a high level explanation of why pybind11 is inappropriate for pytorch's jit module? I'm just curious.

@suo
Copy link
Member

suo commented Aug 8, 2019

@galv: we would like to be able to bind types into Python and the JIT at the same time, and ensure they have the same underlying implementation. pybind11 is great (and we try to create a similar interface here) but of course does not know about the JIT

@galv
Copy link
Collaborator

galv commented Aug 8, 2019 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

caffe2 Merged module: build Build system issues module: ci Related to continuous integration module: cpp Related to C++ API module: custom-operators custom operators, custom ops, custom-operators, custom-ops module: docs Related to our documentation, both in docs/ and docblocks module: infra Relates to CI infrastructure module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants