Skip to content

Return namedtuples from torch.* function with multiple return arguments for C++ operators#15429

Closed
zasdfgbnm wants to merge 49 commits intopytorch:masterfrom
zasdfgbnm:namedtuple
Closed

Return namedtuples from torch.* function with multiple return arguments for C++ operators#15429
zasdfgbnm wants to merge 49 commits intopytorch:masterfrom
zasdfgbnm:namedtuple

Conversation

@zasdfgbnm
Copy link
Collaborator

@zasdfgbnm zasdfgbnm commented Dec 20, 2018

Partially fixes: #394

Implementation detail:

Codegen is modified to generate codes that looks like below:

static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
  }, /*traceable=*/true);

  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  static PyStructSequence_Field fields0[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc0 = {
    "torch.return_types.svd_out", nullptr,
    fields0, 3
  };
  static PyTypeObject type0;
  static bool namedtuple_type_initialized0 = false;
  if (!namedtuple_type_initialized0) {
    PyStructSequence_InitType(&type0, &desc0);
    namedtuple_type_initialized0 = true;
  }
  static PyStructSequence_Field fields1[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc1 = {
    "torch.return_types.svd", nullptr,
    fields1, 3
  };
  static PyTypeObject type1;
  static bool namedtuple_type_initialized1 = false;
  if (!namedtuple_type_initialized1) {
    PyStructSequence_InitType(&type1, &desc1);
    namedtuple_type_initialized1 = true;
  }
  if (r.idx == 0) {
    if (r.isNone(3)) {
      return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2)));
    } else {
      auto results = r.tensorlist_n<3>(3);
      return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2), results[0], results[1], results[2]));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

Types are defined as static member of THPVariable_${op_name} functions, and initialized at the first time the function is called.

When parsing function prototypes in native_functions.yaml, the parser will set the specified name as field_name when see things like -> (Tensor t1, ...). These field names will be the field names of namedtuple. The class of namedtuples will be named torch.return_types.${op_name}.

In some python 2, PyStructSequence is not a subtype of tuple, so we have to create some functions to check if an object is a tuple or namedtuple for compatibility issue.

Operators in native_functions.yaml are changed such that only max and svd are generated as namedtuple. Tests are added for these two operators to see if the return value works as expected. Docs for these two ops are also updated to explicitly mention the return value is a namedtuple. More ops will be added in later PRs.

There is some issue with Windows build of linker unable to resolve PyStructSequence_UnnamedField, and some workaround is added to deal with this case.

@zasdfgbnm zasdfgbnm changed the title [WIP] Return namedtuples from torch.* function with multiple return arguments [WIP] Return namedtuples from torch.* function with multiple return arguments for C++ operators Dec 20, 2018
Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I didn’t know about PyStructSequence! My only concern is that we’re creating those types at every invocation instead of caching them. Can you please run some benchmarks to see how the overhead has changed?

@ezyang
Copy link
Contributor

ezyang commented Dec 20, 2018

I think test failures are real.

One consequence of this change is that return names are part of the public API and no longer can be changed. The names you added look fine, but let's be cognizant of this.

@zasdfgbnm
Copy link
Collaborator Author

@ezyang Yes, it looks real, I will look into that.

For the "API" consideration, yes, we need to be very careful. But, is it a good idea to take the following strategy?

If in the doc, it explicitly says "return a namedtuple XXXXX", we consider this as public API, otherwise, we just consider their name as internal usage and might change without further notification.

Since there are many changes to many ops, this strategy allows us to make changes incrementally:

  • In this PR, we update native_functions.yaml, but don't touch docs. This means, we are just adding infrastructure to support namedtuple return, but no API change yet.
  • In future PRs, we change docs and add tests to these ops's namedtuple API to unit test, a few ops per PR.

This way would make the writing and reviewing PR easier.

@ezyang
Copy link
Contributor

ezyang commented Dec 21, 2018

Unfortunately, that's not how BC works :) People will depend on something like this whether you want them to or not. (to be clear, I'm not saying that we shouldn't do this; just that we should commit to the names we make available.)

@zasdfgbnm
Copy link
Collaborator Author

@ezyang Yes, I'm thinking the same thing, let me quickly disable that piece of code.

@zasdfgbnm
Copy link
Collaborator Author

@ezyang Please see 8ed0fd2

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


- func: adaptive_max_pool3d_out(Tensor output, Tensor indices, Tensor self, IntList[3] output_size) -> (Tensor output, Tensor indices)
# Return: (Tensor output, Tensor indices)
- func: adaptive_max_pool3d_out(Tensor output, Tensor indices, Tensor self, IntList[3] output_size) -> (Tensor, Tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did some of these returns lose their names?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We hadn't finalized whether or not these were the names we wanted for these returns, so they were commented out for now, to be readded when we agreed on them.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment to that effect would be nice :).

@gchanan
Copy link
Contributor

gchanan commented Jan 22, 2019

were the benchmarks that @apaszke requested above done? I didn't see them when I looked briefly, but may have missed them.

@ezyang
Copy link
Contributor

ezyang commented Jan 22, 2019

@gchanan In the latest revision of the PR, the types are cached, so I figured that mooted Adam's benchmark request.

@gchanan
Copy link
Contributor

gchanan commented Jan 22, 2019

My guess is it probably doesn't matter, but it's good to do a sanity check.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 22, 2019
…ts for C++ operators (#15429)

Summary:
Partially fixes: pytorch/pytorch#394

Implementation detail:

Codegen is modified to generate codes that looks like below:
```C++
static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
  }, /*traceable=*/true);

  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  static PyStructSequence_Field fields0[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc0 = {
    "torch.return_types.svd_out", nullptr,
    fields0, 3
  };
  static PyTypeObject type0;
  static bool namedtuple_type_initialized0 = false;
  if (!namedtuple_type_initialized0) {
    PyStructSequence_InitType(&type0, &desc0);
    namedtuple_type_initialized0 = true;
  }
  static PyStructSequence_Field fields1[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc1 = {
    "torch.return_types.svd", nullptr,
    fields1, 3
  };
  static PyTypeObject type1;
  static bool namedtuple_type_initialized1 = false;
  if (!namedtuple_type_initialized1) {
    PyStructSequence_InitType(&type1, &desc1);
    namedtuple_type_initialized1 = true;
  }
  if (r.idx == 0) {
    if (r.isNone(3)) {
      return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2)));
    } else {
      auto results = r.tensorlist_n<3>(3);
      return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2), results[0], results[1], results[2]));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```
Types are defined as static member of `THPVariable_${op_name}` functions, and initialized at the first time the function is called.

When parsing function prototypes in `native_functions.yaml`, the parser will set the specified name as `field_name` when see things like `-> (Tensor t1, ...)`. These field names will be the field names of namedtuple. The class of namedtuples will be named `torch.return_types.${op_name}`.

In some python 2, `PyStructSequence` is not a subtype of tuple, so we have to create some functions to check if an object is a tuple or namedtuple for compatibility issue.

Operators in `native_functions.yaml` are changed such that only `max` and `svd` are generated as namedtuple. Tests are added for these two operators to see if the return value works as expected. Docs for these two ops are also updated to explicitly mention the return value is a namedtuple. More ops will be added in later PRs.

There is some issue with Windows build of linker unable to resolve `PyStructSequence_UnnamedField`, and some workaround is added to deal with this case.
Pull Request resolved: pytorch/pytorch#15429

Differential Revision: D13709678

Pulled By: ezyang

fbshipit-source-id: 23a511c9436977098afc49374e9a748b6e30bccf
@zasdfgbnm
Copy link
Collaborator Author

@ezyang do you want me to do the benchmarking as @gchanan says? or you will do it?

@ezyang
Copy link
Contributor

ezyang commented Jan 22, 2019

If you could do it that would be very helpful.

@zasdfgbnm
Copy link
Collaborator Author

@ezyang I will do it, maybe this evening.

@zasdfgbnm
Copy link
Collaborator Author

@gchanan @ezyang The overhead introduced by this PR is negligible.

The commit before this: 1e19fd9
Result:
image

After:
after

@zasdfgbnm zasdfgbnm deleted the namedtuple branch January 22, 2019 21:26
@ezyang
Copy link
Contributor

ezyang commented Jan 22, 2019

Thank you, much appreciated!

@gchanan
Copy link
Contributor

gchanan commented Jan 23, 2019

Thanks!


inline bool isTuple(pybind11::handle input) {
std::string m = pybind11::str(input.get_type().attr("__module__"));
return pybind11::isinstance<pybind11::tuple>(input) || m == "torch.return_types";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, this is really expensive!!! In most cases those will be simply tuples, so we should just use PyTuple_Check (which is heavily optimized and much faster than regular isinstance calls), and only then try to fall back to checking the module like this!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More importantly, this PR doesn't refactor all places that use PyTuple_Check so we current have a very limited coverage, and will start rejecting things as tuples. What versions of Python fail the PyTuple_Check predicate with struct sequences?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apaszke All I know is it fails on some 2.7, but not all 2.7. It seems that its hard to find a pattern which version fails, I even doubt that even for the same version 2.7.9, some build success while other fails. I suggest to look at the build redish of e3d03bf.
image

Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Feb 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apaszke This is a very subtle compatibility issue, I don't think we can replace all PyTuple_Check with six::isTuple. For example when working on this, I actually tried in 8dfcc4c to make python_arg_parser.cpp support namedtuple to allow something like ret1 = torch.max(a, dim=0, out=my_namedtuple), but his gives segfault. So finally I decided to reverted this change in 304cd44 and simply not supporting this feature in these python2.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's very hard to reproduce this error because it only fails on some build of python2, I can not reproduce it on the python2 on my machine, so all I can do when debugging was to add some print to the code, see the printed message on CI, and decide what to change....

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it looks like all of Python 2 CI is red. Build jobs are green, test jobs are red. And all Caffe2 jobs are obviously green as well. So it seems to be a problem universal in Python 2.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What worries me very much is why do we consider those types tuples in there but not elsewhere?! It seems like a very very arbitrary distinction and it will likely break in very weird ways.

Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Feb 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apaszke Just did some test:

Python 2.7.15 |Anaconda, Inc.| (default, Dec 14 2018, 19:04:19) 
[GCC 7.3.0] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> a = torch.randn(5, 5)
>>> ret = a.max(dim=0)
>>> ret
torch.return_types.max(values=tensor([1.7617, 0.9970, 0.5510, 1.0230, 1.4386]), indices=tensor([1, 3, 4, 0, 3]))
>>> isinstance(ret, tuple)
False
>>> from collections import namedtuple
>>> T = namedtuple('T', ['a', 'b'])
>>> ab = T(1, 2)
>>> isinstance(ab, tuple)
True
>>> tuple(ret)
(tensor([1.7617, 0.9970, 0.5510, 1.0230, 1.4386]), tensor([1, 3, 4, 0, 3]))
>>> import inspect
>>> inspect.getsource(namedtuple)
'def namedtuple(typename, field_names, verbose=False, rename=False):\n    """Returns a new subclass of tuple with named fields.\n\n    >>> Point = namedtuple(\'Point\', [\'x\', \'y\'])\n    >>> Point.__doc__                   # docstring for the new class\n    \'Point(x, y)\'\n    >>> p = Point(11, y=22)             # instantiate with positional args or keywords\n    >>> p[0] + p[1]                     # indexable like a plain tuple\n    33\n    >>> x, y = p                        # unpack like a regular tuple\n    >>> x, y\n    (11, 22)\n    >>> p.x + p.y                       # fields also accessible by name\n    33\n    >>> d = p._asdict()                 # convert to a dictionary\n    >>> d[\'x\']\n    11\n    >>> Point(**d)                      # convert from a dictionary\n    Point(x=11, y=22)\n    >>> p._replace(x=100)               # _replace() is like str.replace() but targets named fields\n    Point(x=100, y=22)\n\n    """\n\n    # Validate the field names.  At the user\'s option, either generate an error\n    # message or automatically replace the field name with a valid name.\n    if isinstance(field_names, basestring):\n        field_names = field_names.replace(\',\', \' \').split()\n    field_names = map(str, field_names)\n    typename = str(typename)\n    if rename:\n        seen = set()\n        for index, name in enumerate(field_names):\n            if (not all(c.isalnum() or c==\'_\' for c in name)\n                or _iskeyword(name)\n                or not name\n                or name[0].isdigit()\n                or name.startswith(\'_\')\n                or name in seen):\n                field_names[index] = \'_%d\' % index\n            seen.add(name)\n    for name in [typename] + field_names:\n        if type(name) != str:\n            raise TypeError(\'Type names and field names must be strings\')\n        if not all(c.isalnum() or c==\'_\' for c in name):\n            raise ValueError(\'Type names and field names can only contain \'\n                             \'alphanumeric characters and underscores: %r\' % name)\n        if _iskeyword(name):\n            raise ValueError(\'Type names and field names cannot be a \'\n                             \'keyword: %r\' % name)\n        if name[0].isdigit():\n            raise ValueError(\'Type names and field names cannot start with \'\n                             \'a number: %r\' % name)\n    seen = set()\n    for name in field_names:\n        if name.startswith(\'_\') and not rename:\n            raise ValueError(\'Field names cannot start with an underscore: \'\n                             \'%r\' % name)\n        if name in seen:\n            raise ValueError(\'Encountered duplicate field name: %r\' % name)\n        seen.add(name)\n\n    # Fill-in the class template\n    class_definition = _class_template.format(\n        typename = typename,\n        field_names = tuple(field_names),\n        num_fields = len(field_names),\n        arg_list = repr(tuple(field_names)).replace("\'", "")[1:-1],\n        repr_fmt = \', \'.join(_repr_template.format(name=name)\n                             for name in field_names),\n        field_defs = \'\\n\'.join(_field_template.format(index=index, name=name)\n                               for index, name in enumerate(field_names))\n    )\n    if verbose:\n        print class_definition\n\n    # Execute the template string in a temporary namespace and support\n    # tracing utilities by setting a value for frame.f_globals[\'__name__\']\n    namespace = dict(_itemgetter=_itemgetter, __name__=\'namedtuple_%s\' % typename,\n                     OrderedDict=OrderedDict, _property=property, _tuple=tuple)\n    try:\n        exec class_definition in namespace\n    except SyntaxError as e:\n        raise SyntaxError(e.message + \':\\n\' + class_definition)\n    result = namespace[typename]\n\n    # For pickling to work, the __module__ variable needs to be set to the frame\n    # where the named tuple is created.  Bypass this step in environments where\n    # sys._getframe is not defined (Jython for example) or sys._getframe is not\n    # defined for arguments greater than 0 (IronPython).\n    try:\n        result.__module__ = _sys._getframe(1).f_globals.get(\'__name__\', \'__main__\')\n    except (AttributeError, ValueError):\n        pass\n\n    return result\n'

I think the problem is, namedtuple created on python side uses the code in https://github.com/python/cpython/blob/2.7/Lib/collections.py, which put tuple as base class in
https://github.com/python/cpython/blob/2.7/Lib/collections.py#L252
while from C API uses the code of PyStructSequence, these two APIs are designed to provide identical behavior to the user, but unfortunately cpython 2.7 didn't do this correctly for example in
https://github.com/python/cpython/blob/2.7/Objects/structseq.c#L466
tp_base was not set to PyTuple_Type while in python 3 it does: https://github.com/python/cpython/blob/master/Objects/structseq.c#L376

This means, the currently check code in pytorch would only reject namedtuples created from C, i.e. those returned by operators, but would correctly handle user created namedtuples.

I believe this is a bug of CPython, and maybe we should file an issue with them?

Copy link
Collaborator Author

@zasdfgbnm zasdfgbnm Feb 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also #16659, the test on python 2.7 succeed on namedtuples created on python.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apaszke The problem of rejecting PyStructSequence as tuple on python 2.7 was not introduced by this PR, but already there since the beginning. So if the user create their own C binding of python using PyStructSequence, they would get problem of getting rejected when they pass it as arguments where tuple is required. So the only thing that this PR make it worse is,
it makes much much easier for user to create a PyStructSequence, because some pytorch ops natively return PyStructSequence object.

In my opinion, the biggest problem that this PR introduce is:

ret = a.max(dim=0)
if isinstance(ret, tuple):
  do something

The above user code would definitely break on python 2.7, even if we have our own six::isTuple of six.is_tuple, the user would not use it, this is a bc-breaking change...

But since the problem is in CPython, do you think, maybe we should disable this feature completely on python 2.7? Otherwise there would be lots of dirty workaround in pytorch for this issue.

facebook-github-bot pushed a commit that referenced this pull request Feb 16, 2019
…amedtuple return API (#16186)

Summary:
This partially fixes #394 and depend on #15429. I suggest to review this only after #15429 get landed, otherwise the diff might be large to review.

The test only allows explicitly whitelisted operators to have named return.

Differential Revision: D14070735

Pulled By: ezyang

fbshipit-source-id: ace2a672998b4e4a8094f52cbda5aa1cea6e3b42
zdevito pushed a commit to zdevito/ATen that referenced this pull request Feb 16, 2019
…amedtuple return API (#16186)

Summary:
This partially fixes pytorch/pytorch#394 and depend on pytorch/pytorch#15429. I suggest to review this only after pytorch/pytorch#15429 get landed, otherwise the diff might be large to review.

The test only allows explicitly whitelisted operators to have named return.

Differential Revision: D14070735

Pulled By: ezyang

fbshipit-source-id: ace2a672998b4e4a8094f52cbda5aa1cea6e3b42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: bc-breaking Related to a BC-breaking change open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Return namedtuples from torch.* function with multiple return arguments

5 participants