add cross int/float ops to Sym types#87722
add cross int/float ops to Sym types#87722albanD wants to merge 3 commits intogh/albanD/117/basefrom
Conversation
[ghstack-poisoned]
| return self._expr | ||
|
|
||
| def wrap(self, num): | ||
| if not isinstance(num, int): |
There was a problem hiding this comment.
This is a new invariant added here to catch bugs where we would get SymIntNode(PySymFloat(..)). Is that ok?
There was a problem hiding this comment.
This seems good to me - although if we're gonna add the asserts here, what do you think of additionally adding them to the PythonSymIntNodeImpl and PythonSymFloatNodeImpl constructors? e.g. that the PythonSymIntNodeImpl class should only ever hold a PyObject corresponding to an int64_t, or a sym int.
There was a problem hiding this comment.
Forcing SymFooNode to only contain PySymFoo sounds good.
For the content of the PySymFoo, I'm not sure (as discussed below) as we need to at least support foo and bool.
Leaving that for a follow up.
| if SYM_FUNCTION_MODE: | ||
| return _handle_sym_dispatch(sym_int, (self,), {}) | ||
| # TODO: same considerations as PySymInt.__sym_float__ above | ||
| return PySymInt(self.expr, self.shape_env) |
There was a problem hiding this comment.
This means we can have a PySymInt that wraps a sympy float. Is that ok?
The __sym_float__ above is already creating PySymFloat containing sympy integers so I assume yes?
There was a problem hiding this comment.
yeah I think this is wrong, we need to truncate in sympy. In my patch, I'm just not finishing the sym int impl.
| } | ||
|
|
||
| float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"} | ||
| float_magic_methods = {"neg", "add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"} |
There was a problem hiding this comment.
neg was added here.
| float_magic_methods = {"add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"} | ||
| float_magic_methods = {"neg", "add", "sub", "mul", "truediv", "ceil", "floor", "eq", "gt", "lt", "le", "ge", "pow"} | ||
|
|
||
| always_float_magic_methods = {'truediv'} |
There was a problem hiding this comment.
This is just a refactor to be able to re-use these in the testing.
| # TODO: relational operators actually technically return a | ||
| # PySymBool, this is a type error | ||
| return py_type(out, self.shape_env) | ||
| if py_type is PySymFloat or isinstance(other, PySymFloat): |
There was a problem hiding this comment.
Both our PySymInt and PySymFloat might contain sympy bool! Our custom bool will make sure this goes smothly
There was a problem hiding this comment.
hmm, what does this line actually have to do with sympy bool? I thought that this line is just the "type promotion" line - where for most math ops, mixed int/float compute should give you a float.
| if COLLECT_EXPECT: | ||
| atexit.register(print_seen) | ||
|
|
||
| expected_failure_sym_magic_methods = { |
There was a problem hiding this comment.
All of these should be fairly easy to fix but this PR was getting too big already.
| if (torch::is_symfloat_node(b)) { | ||
| return b.cast<c10::SymFloatNode>(); | ||
| } else { | ||
| return a->sym_float()->wrap(b.cast<double>()); |
There was a problem hiding this comment.
Major change compared to what is in the branch: We have a a->sym_float() here instead of just a. If we don't do that, we end up with a SymFloatNode(PySymInt(double)) which is... not good. Added extra assert in in PySym* to reduce the chance of these things happening.
| .def( | ||
| "__neg__", | ||
| [](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); }) | ||
| .def( |
There was a problem hiding this comment.
The __min__ and __max__ are not actually a thing in python so just removing these. Happy to add them back if anyone has some custom logic for using these (but I couldn't find any).
| "__neg__", | ||
| [](c10::SymFloatNode a) -> c10::SymFloatNode { return a->neg(); }) | ||
| // We must define `__bool__` on the subclass as the default is `return True` | ||
| .def("__bool__", [](c10::SymFloatNode a) { return a->bool_(); }) |
There was a problem hiding this comment.
This is required otherwise bool(SymFloatNode()) is always True
| // We must define `__bool__` on the subclass as the default is `return True` | ||
| .def("__bool__", [](c10::SymFloatNode a) { return a->bool_(); }) | ||
| .def("__str__", [](c10::SymFloatNode a) { return a->str(); }) | ||
| .def("guard_float", [](c10::SymFloatNode a) { |
| # | ||
| # 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_dynamic_shapes.py -k TestSymNumberMagicMethods`. | ||
| # 2. Given the printed xfail list, add them to the set expected_failure_sym_magic_methods. | ||
| COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1' |
| always_float_magic_methods = {'truediv'} | ||
| always_int_magic_methods = {'floordiv', 'ceil', 'floor'} | ||
|
|
||
| magic_methods_on_builtins = {"min", "max"} |
There was a problem hiding this comment.
should we just kill these, since they desugar into < and >? Or we can always save it for a later PR.
There was a problem hiding this comment.
This add spurious attributes indeed. But at least they get tested because they're in this list. So even if we stop adding attributes for these below, we should keep them in this list for testing to run.
| out = func(expr, other_expr) | ||
| out = sympy.expand(out) | ||
| if method in ["truediv"]: | ||
| if method in always_float_magic_methods: |
There was a problem hiding this comment.
oh an extra assert in PythonSymInt/FloatNodeImpl constructors might be a good idea because of this - if someone updates this list, but forgets to update the corresponding output type of the methods on the PythonSymIntNodeImpl class, an assert would be nice.
There was a problem hiding this comment.
In the constructor, it is another issue as today we do allow SymIntNode(PySymInt()) to contain anything: int, float or bool (same for the float version).
We could restrict this to int+bool and float+bool maybe but that might be breaking a bunch of stuff.
[ghstack-poisoned]
| }; | ||
| virtual SymFloatNode eq(const SymIntNode& other) { | ||
| TORCH_CHECK(false, "NYI"); | ||
| }; |
| TORCH_CHECK(false, "NYI"); | ||
| }; | ||
| virtual SymIntNode ge(const SymIntNode& other) { | ||
| virtual SymFloatNode le(const SymFloatNode& other) { |
There was a problem hiding this comment.
This looks wrong, less than or equals should return a bool (or failing that, an int)
There was a problem hiding this comment.
The current situation here is that both symfloat AND symint can contain bools. And both of them support bool() as expected.
I can move everything to say that only symint can contain a bool (and forbid calling bool(symfloat).
I don't have any strong opinion either way.
@bdhirsh what's your vote?
There was a problem hiding this comment.
SymBoolNode feels like it would be nicer to have in the long term. Today, we have:
SymIntNode holds a pyobject, which we guarantee is always either a PySymInt or int64_t (well, until nested tensor starts using them). The PySymIntholds asympy.expr`, which... is always either an integer expression or a boolean expression
SymFloatNode holds a pyobject that's guaranteed to be either a PySymFloat or a double, and PySymFloat holds a sympy.expr which is always either a float expression or a boolean expression.
I guess for me, given that we already have a looser definition around what PySymInt can hold, it doesn't feel too bad for PySymFloat to act the same way? Unless we want to go all-in on the one-to-one correspondence and add SymBool.
I'm not strongly opinionated about this though.
|
I'm not a fan of bool getting shoved in int/float. If you want, I can bash out SymBool and type erasure on Node. I think it would be pretty easy. |
[ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
…Node" This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
…Node" This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> ghstack-source-id: 4800f32 Pull Request resolved: #87817
…Node" This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
…Node" This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> cc jansel mlazos soumith voznesenskym yanboliang penguinwu anijain2305 [ghstack-poisoned]
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyangfb.com> ghstack-source-id: 7e4e282 Pull Request resolved: #87817
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy #87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyang@fb.com> cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 Pull Request resolved: #87817 Approved by: https://github.com/albanD, https://github.com/anjali411
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy pytorch#87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyang@fb.com> cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 Pull Request resolved: pytorch#87817 Approved by: https://github.com/albanD, https://github.com/anjali411
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy pytorch#87722 This PR takes a different approach. The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This is type erased; we no longer know statically at C++ if we have an int/float and have to test it with the is_int()/is_float() virtual methods. This has a number of knock on effects. - We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have a SymInt/SymFloat class defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes. - Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works. Some miscellaneous improvements: - SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue reference works automatically. - We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via the SymIntNode/SymFloatNode. - Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions, pretty sure this doesn't do anything. - guard_int is now a free function, since to guard on an int you cannot assume the method exists. A function can handle both int and SymInt inputs. - We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types. Signed-off-by: Edward Z. Yang <ezyang@fb.com> cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 Pull Request resolved: pytorch#87817 Approved by: https://github.com/albanD, https://github.com/anjali411
Stack from ghstack (oldest at bottom):