Conversation
✅ No Failures (0 Pending) as of commit 50810d1 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
Branch updated from e43d362 to 8bdcc93.
```cpp
virtual bool bool_() { TORCH_CHECK(false, "NYI"); };
virtual int64_t int_() { TORCH_CHECK(false, "NYI"); }
virtual std::string str() { TORCH_CHECK(false, "NYI"); };
virtual std::ostream& operator<<(std::ostream& os) {
```
Does this still need to be virtual?
I think it does; the LTC impl won't have a Python object to redispatch or fall back to.
But they'll implement str() instead?
Oh, just realized what you meant... operator<< does NOT need to be virtual, methinks :)
```cpp
virtual SymbolicIntNode* gt(SymbolicIntNode* other) { TORCH_CHECK(false, "NYI"); };
virtual SymbolicIntNode* lt(SymbolicIntNode* other) { TORCH_CHECK(false, "NYI"); };
virtual SymbolicIntNode* wrap(int64_t num) { TORCH_CHECK(false, "NYI"); };
virtual bool bool_() { TORCH_CHECK(false, "NYI"); };
```
Hmm, not sure why we need this on top of the int conversion. Is the problem that you also want a SymbolicBoolNode as well?
@horace overrode bool in his PoC, so I added it as well. I can provide a default implementation, `static_cast<bool>(this->int_())`, or just let implementers implement it via int_.
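The default implementation mentioned above could look like the following sketch. `SymbolicIntNodeSketch` and `ConstantNode` are stand-in names, not the real PyTorch classes; the point is just that bool_() can be defined once in terms of int_() so implementers only override int_():

```cpp
#include <cstdint>
#include <stdexcept>

// Stand-in for SymbolicIntNode: bool_() gets a default implementation in
// terms of int_(), so subclasses only need to override int_().
struct SymbolicIntNodeSketch {
  virtual ~SymbolicIntNodeSketch() = default;
  virtual int64_t int_() { throw std::runtime_error("NYI"); }
  // Default: truthiness of the underlying integer value.
  virtual bool bool_() { return static_cast<bool>(this->int_()); }
};

// Hypothetical implementer that only provides int_().
struct ConstantNode : SymbolicIntNodeSketch {
  explicit ConstantNode(int64_t v) : v_(v) {}
  int64_t int_() override { return v_; }
  int64_t v_;
};
```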
```cpp
// if we are going to use sym sizes, we should be setting sym strides at the same time,
// otherwise it's very easy to misuse this API
virtual void set_sym_sizes_and_strides(c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides);
```
It's not clear to me why this needs to be virtual.
Agreed; @suo mentioned he would help formalize this API. I'll remove virtual in the meantime.
```cpp
void set_custom_device(bool custom_device) {
  custom_device_ = custom_device;
}
protected:
```
```cpp
//
// Can override: strides(), is_contiguous(), sizes(), dim(), numel()
CustomSizes = 2,
CustomSymSizes = 3,
```
Skeptical about this. I'll read through the uses first.
```
six
types-dataclasses
typing_extensions
dataclasses; python_version<"3.7"
```
We're py3.7 and up only, so this really shouldn't be needed.
Ah, this is a bad merge... will fix.
```cpp
PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_)
{
  HANDLE_TH_ERRORS
```
You shouldn't need this macro here; this function is not directly bound to Python.
```cpp
  return intlistWithDefault(i, signature.params[i].default_intlist);
}
```
```cpp
TORCH_API bool is_symint_node(py::handle obj);
```
Why not just include the header?
Yeah, let me inline it back.
```cpp
const auto size1 = signature.params[i].size;
if (size1 > 0 && THPUtils_checkLong(args[i])) {
  return std::vector<c10::SymInt>(size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
}
```
You need to replicate this logic for a solitary symint arg as well.
Yes, a good catch, thank you!
```cpp
// we need to clear SymIntTable until we have python
// otherwise python classes are already deregistered

//c10::getSymIntTable().clear();
```
```cpp
void SymIntTable::clear() {
  std::lock_guard<std::mutex> lock(mutex_);
  nodes_.clear();
```
```cpp
void TensorImpl::set_sym_sizes_and_strides(c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides) {
  has_symbolic_sizes_strides_ = true;
  sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
```
This seems to me like CustomSymSizes isn't actually being used!
We set the CustomSizes policy for Python tensors (i.e. those made via make_wrapper_class), so calls to sizes() throw for them. Unfortunately, that means sym_sizes() also throws. We'd actually like to just run the default implementation in this case, hence CustomSymSizes, which is indeed overridden by LTC. I'm open to ideas on how to make this cleaner.
I think the easiest thing is to just call into Python if the Python key is set and you have a custom sizes policy.
```cpp
}
if (r.toBool(11)) {
  data.unsafeGetTensorImpl()->set_custom_device(true);
}
```
Oh crap, a bad merge :(
```cpp
// NB: pin_memory doesn't actually do anything
// TODO: strides variant?
static PythonArgParser parser({
  "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, int64_t? storage_offset=None, *, MemoryFormat? memory_format=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
```
How about keeping only one _make_wrapper_subclass and just having a second overload for PythonArgParser? Having it as an overload should also help reduce duplication in this variant of the function.
Mkkk... it's already a pretty branchy and long function.
It being long is a good reason not to copy-paste, right? ;)
```cpp
pyobj_ = std::make_shared<c10::SafePyObject>(pyobj.release().ptr(), getPyInterpreter());
};
```
```cpp
virtual SymbolicIntNode* wrap(int64_t num) {
```
Remind me again why we are doing raw-pointer memory management here?
```cpp
virtual bool bool_() {
  py::gil_scoped_acquire acquire;
  return py::str(getPyObj().attr("__bool__")()).is(py::str(Py_True));
```
Why are you doing a string comparison to test what the bool result is?
Lemme try simplifying it a bit. This is what SO recommends, but I agree it's convoluted. There's no py::cast to bool, but maybe we don't have to do py::str on both sides.
When taking questionable advice from Stack Overflow, I highly recommend leaving a link to the URL of the question.
Cleaned up this part. Now it should make more sense.
```cpp
virtual SymbolicIntNode* dispatch_common_(const char* fname, SymbolicIntNode* other) {
  auto pother = dynamic_cast<PythonSymbolicIntNode*>(other);
  TORCH_CHECK(pother);
  auto magic_fname = std::string("__") + fname + std::string("__");
```
I'd much rather you had taken the magic_fname as an argument, lol. With a macro you could paste the __ together with a string constant without having to do a string concat on every function call (which is wasteful).
Haha, sorry, I was going to do that but didn't before your review.
```cpp
.def_static("isinstance", [](py::object obj, bool convert) -> bool {
  return pybind11::detail::type_caster<std::shared_ptr<c10::SymbolicIntNode>>().load(obj, convert);
  //return false;
})
```
Dead code, sorry...
```cpp
if (torch::is_symint_node(b)) {
  return std::shared_ptr<c10::SymbolicIntNode>(a->add(b.cast<c10::SymbolicIntNode*>()));
} else {
  return std::shared_ptr<c10::SymbolicIntNode>(a->add(a->wrap(b.cast<int64_t>())));
```
This feels like a helper function would help a bit here. But it's also not entirely clear you want to wrap integers into symbolic int nodes (that denote plain integers); it seems like it would be more user-friendly if these showed up at the dispatch site as plain integers. It might be a bit easier to make the add method accept an IValue instead of a SymbolicIntNode, so you can pass in either an int or a symbolic int without needing to unconditionally accept a SymbolicIntNode.
It looks way nicer now.
I don't want to introduce a dependency on IValue into SymbolicIntNode :( It seems more complex architecturally, and possibly less user-friendly, since both LTC and AOTAutograd would need to parse IValues explicitly. Both LTC and AOTAutograd already wrap ints into sympy.Integer or prim::Constant.
Branch updated from b539c7d to 4e3fa42.
@pytorchbot merge this

❌ 🤖 pytorchbot command failed: Try

@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here

Hey @Krovatkin.
```cpp
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
  return sym_sizes_custom();
}
virtual c10::SymIntArrayRef sym_sizes() const {
```
Why do we want this to be virtual now, instead of doing the policy thing that we do with all of our other de-virtualized methods?
(I need to fix it up for functionalization, which will be pretty easy; just curious about the reasoning.)
And do we even need a sym_sizes_custom() anymore if sym_sizes() is virtual?
@bdhirsh Python tensor subclasses and LTC want to do different things when the policy is set to CustomSizes, and there's no easy way to implement both via _custom, so we had to make sym_sizes() virtual for now.
We suspect that this PR broke TorchVision's tests. Information available here: pytorch/vision#6166 (comment)
@pytorchbot revert -m "broke torchvision tests" -c weird

@pytorchbot successfully started a revert job. Check the current status here
This reverts commit d332724. Reverted #78135 on behalf of https://github.com/ezyang due to broke torchvision tests
This reverts commit b8db0a0. [ghstack-poisoned]
This reverts commit b8db0a0. ghstack-source-id: 602ffd6 Pull Request resolved: pytorch#79608
This PR adds support for `SymInt`s in python. Namely,
* `THPVariable_size` now returns `sym_sizes()`
* python arg parser is modified to parse PyObjects into ints and `SymbolicIntNode`s
* pybind11 bindings for `SymbolicIntNode` are added, so size expressions can be traced
* a large number of tests added to demonstrate how to implement python symints.

Pull Request resolved: pytorch#78135 Approved by: https://github.com/ezyang
This reverts commit e2fdcf8. Reverted pytorch#78135 on behalf of https://github.com/ezyang due to broke torchvision tests