Add a utility function to filter out elements that are not of a certain type#267
Add a utility function to filter out elements that are not of a certain type#267naoyam merged 16 commits intocsarofeen:20_7_6_develfrom
Conversation
| Container<FilterType*> filterVals(const Container<ElementType*>& container) { | ||
| Container<FilterType*> out; | ||
| for (auto& s : container) { | ||
| if (s->getValType() == FilterType::type) { |
There was a problem hiding this comment.
since you're interested in a better notation, would you be interested in adding another building block to the IR root class? ie. add this to Statement:
template<class T>
bool IsA() const {
return dynamic_cast<const T*>(this) != nullptr;
}this is a counterpart to Statement::as that I have on my todo list, and would allow us to say:
if (s->isA<FilterType>()) { ... }
... and also eventually eliminate the need for ValType altogether.
There was a problem hiding this comment.
That looks nice, but one thing I didn't want to change is that the type checking is currently always done with getValType(). I think we can replace that with something like IsA, but I wonder whether ValType values and classes have one-to-one relationship. For example, would there be a case where there could be two different classes that are a subclass of Val AND have the same ValType value. I don't think that's the case now, but is this something that could happen in the future?
There was a problem hiding this comment.
I think switch to a more consistent IsA makes sense to me (uniformed API).
Can't we have specialization of isA so that we cannot necessarily do dynamic cast? that seems to also resolve the restriction on the one-to-one mapping.
template<class T>
bool IsA() const {
return dynamic_cast<const T*>(this) != nullptr;
}
template<ValType T>
bool IsA() const {
return getValType() == T;
}
There was a problem hiding this comment.
Using IsA<>() with the non-mapping filter and return would still require us to add an extra template parameter in the filterVals, (one for filter criterion used in the IsA, one to cast the return type as). Might bother people 🤷
There was a problem hiding this comment.
@naoyam: a potential mismatch between the tag (ValType) and the actual type is one of the reasons I'm wary of using tags , especially when they duplicate the type hierarchy 1:1 (like we do today). Since we're using a hierarchy to model the taxonomy of nodes, I think it's both easier and safer to use the types themselves. Another benefit of using the types (through a notation like isA<T>), is that you can query for abstract (non-leaf) types too, ex. isA<Expr> or isA<Val>
@jjsjann123 : as I explained above, the main benefit of isA<T> is not the notation itself, but the fact that it avoids tags altogether. So a version which tries to mix tags and types seems confusing at best and I'd recommend against it.
I'm not suggesting we should go as far as eliminating tags in this change. It seems an opportunity for adding isA<T> since it fits well with the proposed utility, but it's not really critical to this PR.
There was a problem hiding this comment.
I agree that we could get rid of the tag and just use the runtime type information would be more robust. However, I'd prefer to keep using the tag for type checking at least in this PR as changing that is not the purpose. I'll create a separate issue to consider.
jjsjann123
left a comment
There was a problem hiding this comment.
I like the updated version a lot 👍
| class FilterValContainer { | ||
| public: | ||
| using value_type = FilterType*; | ||
| using const_iterator = FilterIterator<FilterType, InputIt>; |
There was a problem hiding this comment.
nitpick: Wouldn't we also want a non-const iterator?
There was a problem hiding this comment.
Do we want to modify the original container through the filtered view?
There was a problem hiding this comment.
There may be some use cases, but I suspect we don't need that at this time. I left a related comment for a future extension.
| Container<FilterType*> filterVals(const Container<ElementType*>& container) { | ||
| Container<FilterType*> out; | ||
| for (auto& s : container) { | ||
| if (s->getValType() == FilterType::type) { |
There was a problem hiding this comment.
I think switch to a more consistent IsA makes sense to me (uniformed API).
Can't we have specialization of isA so that we cannot necessarily do dynamic cast? that seems to also resolve the restriction on the one-to-one mapping.
template<class T>
bool IsA() const {
return dynamic_cast<const T*>(this) != nullptr;
}
template<ValType T>
bool IsA() const {
return getValType() == T;
}
| Container<FilterType*> filterVals(const Container<ElementType*>& container) { | ||
| Container<FilterType*> out; | ||
| for (auto& s : container) { | ||
| if (s->getValType() == FilterType::type) { |
There was a problem hiding this comment.
Using IsA<>() with the non-mapping filter and return would still require us to add an extra template parameter in the filterVals, (one for filter criterion used in the IsA, one to cast the return type as). Might bother people 🤷
| }; | ||
|
|
||
| template <typename FilterType, typename InputIt> | ||
| class FilterValContainer { |
There was a problem hiding this comment.
We should probably rename this to say "view" to better reflect what it is (it's not really a container since it doesn't own the values). FilteredView ?
Also, InputIt -> just Iterator? (technically it's not an input iterator but rather a forward iterator)
There was a problem hiding this comment.
Thanks for the suggestion. Renamed to FilteredView.
| class FilterValContainer { | ||
| public: | ||
| using value_type = FilterType*; | ||
| using const_iterator = FilterIterator<FilterType, InputIt>; |
There was a problem hiding this comment.
Do we want to modify the original container through the filtered view?
| } | ||
|
|
||
| private: | ||
| InputIt input_it_; |
There was a problem hiding this comment.
-
constfor the data members? -
Use more precise names for the iterators:
begin_andend_? (in particular, last != end so we should avoid using that name)
|
|
||
| template <typename FilterType, typename ContainerType> | ||
| auto filterVals(const ContainerType& inputs) { | ||
| return filterVals<FilterType>(inputs.begin(), inputs.end()); |
| return out; | ||
| } | ||
| #else | ||
| template <typename FilterType, typename InputIt> |
There was a problem hiding this comment.
InputIt -> IteratorType or just Iterator ?
There was a problem hiding this comment.
Renamed to Iterator.
| } | ||
|
|
||
| private: | ||
| InputIt input_it_; |
There was a problem hiding this comment.
please separate the data members from private functions into separate private sections
|
|
||
| private: | ||
| InputIt input_it_; | ||
| InputIt last_; |
There was a problem hiding this comment.
- rename input_it_ -> current_ ?
- as I mentioned in another comment, we should rename last_ -> end_
| } | ||
|
|
||
| bool operator==(const FilterIterator& other) const { | ||
| return input_it_ == other.input_it_ && last_ == other.last_; |
There was a problem hiding this comment.
I'd actually just TORCH_INTERNAL_ASSERT(last_ == other.last_) - if someone is trying to compare fitered iterators coming from different filtered views it's most likely an error and we should raise it as such
There was a problem hiding this comment.
That's a good idea. Done.
We eventually do casting so dynamic_cast-based type checking seems more appropriate
|
I actually changed the type checking to use |
| return (*current_)->template as<FilterType>(); | ||
| } | ||
|
|
||
| FilterType* operator->() const { |
| } | ||
|
|
||
| FilterIterator operator++(int) { | ||
| auto before_increment = *this; |
| return !(*this == other); | ||
| } | ||
|
|
||
| private: |
There was a problem hiding this comment.
please place the private state section last (public methods, private methods, private state)
| This doesn't compile | ||
| ../test/cpp/jit/test_gpu.cpp:800:51: error: no matching function for call to ‘std::vector<torch::jit::fuser::TensorView*>::vector(torch::jit::fuser::ir_utils::FilteredView<torch::jit::fuser::TensorView, __gnu_cxx::__normal_iterator<torch::jit::fuser::Val* const*, std::vector<torch::jit::fuser::Val*> > >::const_iterator, torch::jit::fuser::ir_utils::FilteredView<torch::jit::fuser::TensorView, __gnu_cxx::__normal_iterator<torch::jit::fuser::Val* const*, std::vector<torch::jit::fuser::Val*> > >::const_iterator)’ | ||
| ir_utils::filterVals<TensorView>(vals).end()); | ||
| ^ | ||
| In file included from /usr/include/c++/8/vector:64, | ||
| from ../c10/util/StringUtil.h:11, | ||
| from ../c10/util/Exception.h:5, | ||
| from ../c10/core/Device.h:5, | ||
| from ../c10/core/Allocator.h:6, | ||
| from ../aten/src/ATen/ATen.h:7, | ||
| from ../torch/csrc/jit/ir/attributes.h:2, | ||
| from ../torch/csrc/jit/ir/ir.h:3, | ||
| from ../test/cpp/jit/test_base.h:5, | ||
| from ../test/cpp/jit/test_gpu.cpp:3: | ||
| /usr/include/c++/8/bits/stl_vector.h:543:2: note: candidate: ‘template<class _InputIterator, class> std::vector<_Tp, _Alloc>::vector(_InputIterator, _InputIterator, const allocator_type&)’ | ||
| vector(_InputIterator __first, _InputIterator __last, | ||
| */ | ||
| #if 0 | ||
| std::vector<TensorView*> tvs( | ||
| ir_utils::filterVals<TensorView>(vals).begin(), | ||
| ir_utils::filterVals<TensorView>(vals).end()); |
There was a problem hiding this comment.
@tlemo Do you know why this doesn't work?
There was a problem hiding this comment.
Look at constructor (5) in https://en.cppreference.com/w/cpp/container/vector/vector
Your const_iterator needs to satisfy LegacyInputIterator... So you would need difference_type and bunch of that...
| TORCH_CHECK(ints.size() == 2); | ||
| TORCH_CHECK(ints[0] == scalar1); | ||
| TORCH_CHECK(ints[1] == scalar2); | ||
| return; |
| auto scalar1 = new Int(0); | ||
| auto scalar2 = new Int(1); | ||
|
|
||
| std::vector<Val*> vals = {tv0, scalar0, tv1, scalar1, scalar2}; |
There was a problem hiding this comment.
can you please make vals const?
| TORCH_CHECK(floats.size() == 1); | ||
| TORCH_CHECK(floats[0] == scalar0); | ||
|
|
||
| std::vector<Int*> ints( |
There was a problem hiding this comment.
one more test case please: a filtering which returns an empty list, ex:
for (auto ptr : ir_utils::filterVals<Expr>(vals)) {
TORCH_CHECK(false, "Not expecting any results");
}
|
I renamed the filter function to The example I gave first still looks mostly the same. We replace this: with: |
|
@tlemo Sorry, I missed your last comments. I'll update the test with a separate PR. |
…ty. (#267) (pytorch#89315) Summary: pytorch#89122 introduces internal compatibility issues with torchdeploy. However, GetPythonFramesFunction() never worked with torchdeploy, so this PR simply reverts to the original behavior of skipping the function if torchdeploy is used as a forward fix. Test Plan: Running failed tests in T128123281 ``` buck2 test @//mode/opt //multipy/runtime:test_deploy -- --exact 'multipy/runtime:test_deploy - TorchpyTest.TaggingRace' --run-disabled buck2 test mode/dev //multipy/runtime/testdev:test_deploy_from_python -- --exact 'multipy/runtime/testdev:test_deploy_from_python - multipy.runtime.testdev.test_deploy_from_python.TestDeployFromPython: test_deploy_from_python' ``` Differential Revision: D41414263 Pull Request resolved: pytorch#89315 Approved by: https://github.com/kurman
This adds a utility function that I think can reduce repetition. It's not yet completed, but I'd like to have feedback before going further.
We often have code like this:
This PR adds
ir_utils::filterVals, which can be used to simplify code patterns like the above:For this, I added
static constexpr ValType typeto theValclasses (not yet completed; only added toIterDomainandTensorViewfor now).Let me know if you have any feedback. Thanks!