Skip to content

Add a utility function to filter out elements that are not of a certain type#267

Merged
naoyam merged 16 commits intocsarofeen:20_7_6_develfrom
naoyam:ir-node-filter
Aug 6, 2020
Merged

Add a utility function to filter out elements that are not of a certain type#267
naoyam merged 16 commits intocsarofeen:20_7_6_develfrom
naoyam:ir-node-filter

Conversation

@naoyam
Copy link
Copy Markdown
Collaborator

@naoyam naoyam commented Aug 4, 2020

This adds a utility function that I think can reduce repetition. It's not yet completed, but I'd like to have feedback before going further.

We often have code like this:

std::vector<Val*> values;
for (auto x: values) {
  // do something only for TensorView vals
  if (x->getValType() == ValType::TensorView) {
   TensorView* x_tv = x->as<TensorView>();
   ...
  }
}

This PR adds ir_utils::filterVals, which can be used to simplify code patterns like the above:

for (auto x: filterVals<TensorView>(values)) {
  // x is TensorView*
 ... 
}

For this, I added static constexpr ValType type to the Val classes (not yet completed; only added to IterDomain and TensorView for now).

Let me know if you have any feedback. Thanks!

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
Container<FilterType*> filterVals(const Container<ElementType*>& container) {
Container<FilterType*> out;
for (auto& s : container) {
if (s->getValType() == FilterType::type) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since you're interested in a better notation, would you be interested in adding another building block to the IR root class? ie. add this to Statement:

 template<class T>
 bool IsA() const {
 return dynamic_cast<const T*>(this) != nullptr;
 }

this is a counterpart to Statement::as that I have on my todo list, and would allow us to say:

if (s->isA<FilterType>()) { ... }

... and also eventually eliminate the need for ValType altogether.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks nice, but one thing I didn't want to change is that the type checking is currently always done with getValType(). I think we can replace that with something like IsA, but I wonder whether ValType values and classes have one-to-one relationship. For example, would there be a case where there could be two different classes that are a subclass of Val AND have the same ValType value. I don't think that's the case now, but is this something that could happen in the future?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think switch to a more consistent IsA makes sense to me (uniformed API).
Can't we have specialization of isA so that we cannot necessarily do dynamic cast? that seems to also resolve the restriction on the one-to-one mapping.

 template<class T>
 bool IsA() const {
 return dynamic_cast<const T*>(this) != nullptr;
 }

 template<ValType T>
 bool IsA() const {
 return getValType() == T;
 }

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using IsA<>() with the non-mapping filter and return would still require us to add an extra template parameter in the filterVals, (one for filter criterion used in the IsA, one to cast the return type as). Might bother people 🤷

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@naoyam: a potential mismatch between the tag (ValType) and the actual type is one of the reasons I'm wary of using tags , especially when they duplicate the type hierarchy 1:1 (like we do today). Since we're using a hierarchy to model the taxonomy of nodes, I think it's both easier and safer to use the types themselves. Another benefit of using the types (through a notation like isA<T>), is that you can query for abstract (non-leaf) types too, ex. isA<Expr> or isA<Val>

@jjsjann123 : as I explained above, the main benefit of isA<T> is not the notation itself, but the fact that it avoids tags altogether. So a version which tries to mix tags and types seems confusing at best and I'd recommend against it.

I'm not suggesting we should go as far as eliminating tags in this change. It seems an opportunity for adding isA<T> since it fits well with the proposed utility, but it's not really critical to this PR.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that we could get rid of the tag and just use the runtime type information would be more robust. However, I'd prefer to keep using the tag for type checking at least in this PR as changing that is not the purpose. I'll create a separate issue to consider.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #268

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
Copy link
Copy Markdown
Collaborator

@jjsjann123 jjsjann123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the updated version a lot 👍

class FilterValContainer {
public:
using value_type = FilterType*;
using const_iterator = FilterIterator<FilterType, InputIt>;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: Wouldn't we also want a non-const iterator?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to modify the original container through the filtered view?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There may be some use cases, but I suspect we don't need that at this time. I left a related comment for a future extension.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
Container<FilterType*> filterVals(const Container<ElementType*>& container) {
Container<FilterType*> out;
for (auto& s : container) {
if (s->getValType() == FilterType::type) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think switch to a more consistent IsA makes sense to me (uniformed API).
Can't we have specialization of isA so that we cannot necessarily do dynamic cast? that seems to also resolve the restriction on the one-to-one mapping.

 template<class T>
 bool IsA() const {
 return dynamic_cast<const T*>(this) != nullptr;
 }

 template<ValType T>
 bool IsA() const {
 return getValType() == T;
 }

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
Container<FilterType*> filterVals(const Container<ElementType*>& container) {
Container<FilterType*> out;
for (auto& s : container) {
if (s->getValType() == FilterType::type) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using IsA<>() with the non-mapping filter and return would still require us to add an extra template parameter in the filterVals, (one for filter criterion used in the IsA, one to cast the return type as). Might bother people 🤷

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
};

template <typename FilterType, typename InputIt>
class FilterValContainer {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably rename this to say "view" to better reflect what it is (it's not really a container since it doesn't own the values). FilteredView ?

Also, InputIt -> just Iterator? (technically it's not an input iterator but rather a forward iterator)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. Renamed to FilteredView.

class FilterValContainer {
public:
using value_type = FilterType*;
using const_iterator = FilterIterator<FilterType, InputIt>;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to modify the original container through the filtered view?

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
}

private:
InputIt input_it_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. const for the data members?

  2. Use more precise names for the iterators: begin_ and end_ ? (in particular, last != end so we should avoid using that name)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Done.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated

template <typename FilterType, typename ContainerType>
auto filterVals(const ContainerType& inputs) {
return filterVals<FilterType>(inputs.begin(), inputs.end());
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: cbegin, cend

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
return out;
}
#else
template <typename FilterType, typename InputIt>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InputIt -> IteratorType or just Iterator ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to Iterator.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
}

private:
InputIt input_it_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please separate the data members from private functions into separate private sections

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h
Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated

private:
InputIt input_it_;
InputIt last_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. rename input_it_ -> current_ ?
  2. as I mentioned in another comment, we should rename last_ -> end_

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
}

bool operator==(const FilterIterator& other) const {
return input_it_ == other.input_it_ && last_ == other.last_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd actually just TORCH_INTERNAL_ASSERT(last_ == other.last_) - if someone is trying to compare fitered iterators coming from different filtered views it's most likely an error and we should raise it as such

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. Done.

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h
@naoyam
Copy link
Copy Markdown
Collaborator Author

naoyam commented Aug 5, 2020

I actually changed the type checking to use dynamic_cast instead of getValType() as it eventually does dynamic_cast. I think it makes more sense.

Copy link
Copy Markdown
Collaborator

@jjsjann123 jjsjann123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@naoyam naoyam marked this pull request as ready for review August 5, 2020 20:11
return (*current_)->template as<FilterType>();
}

FilterType* operator->() const {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, good catch!

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
}

FilterIterator operator++(int) {
auto before_increment = *this;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: const ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread torch/csrc/jit/codegen/cuda/ir_utils.h Outdated
return !(*this == other);
}

private:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please place the private state section last (public methods, private methods, private state)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread test/cpp/jit/test_gpu.cpp Outdated
Comment on lines +797 to +817
This doesn't compile
../test/cpp/jit/test_gpu.cpp:800:51: error: no matching function for call to ‘std::vector<torch::jit::fuser::TensorView*>::vector(torch::jit::fuser::ir_utils::FilteredView<torch::jit::fuser::TensorView, __gnu_cxx::__normal_iterator<torch::jit::fuser::Val* const*, std::vector<torch::jit::fuser::Val*> > >::const_iterator, torch::jit::fuser::ir_utils::FilteredView<torch::jit::fuser::TensorView, __gnu_cxx::__normal_iterator<torch::jit::fuser::Val* const*, std::vector<torch::jit::fuser::Val*> > >::const_iterator)’
ir_utils::filterVals<TensorView>(vals).end());
^
In file included from /usr/include/c++/8/vector:64,
from ../c10/util/StringUtil.h:11,
from ../c10/util/Exception.h:5,
from ../c10/core/Device.h:5,
from ../c10/core/Allocator.h:6,
from ../aten/src/ATen/ATen.h:7,
from ../torch/csrc/jit/ir/attributes.h:2,
from ../torch/csrc/jit/ir/ir.h:3,
from ../test/cpp/jit/test_base.h:5,
from ../test/cpp/jit/test_gpu.cpp:3:
/usr/include/c++/8/bits/stl_vector.h:543:2: note: candidate: ‘template<class _InputIterator, class> std::vector<_Tp, _Alloc>::vector(_InputIterator, _InputIterator, const allocator_type&)’
vector(_InputIterator __first, _InputIterator __last,
*/
#if 0
std::vector<TensorView*> tvs(
ir_utils::filterVals<TensorView>(vals).begin(),
ir_utils::filterVals<TensorView>(vals).end());
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tlemo Do you know why this doesn't work?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at constructor (5) in https://en.cppreference.com/w/cpp/container/vector/vector
Your const_iterator needs to satisfy LegacyInputIterator... So you would need difference_type and bunch of that...

Comment thread test/cpp/jit/test_gpu.cpp
TORCH_CHECK(ints.size() == 2);
TORCH_CHECK(ints[0] == scalar1);
TORCH_CHECK(ints[1] == scalar2);
return;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: return not needed

Comment thread test/cpp/jit/test_gpu.cpp
auto scalar1 = new Int(0);
auto scalar2 = new Int(1);

std::vector<Val*> vals = {tv0, scalar0, tv1, scalar1, scalar2};
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please make vals const?

Comment thread test/cpp/jit/test_gpu.cpp
TORCH_CHECK(floats.size() == 1);
TORCH_CHECK(floats[0] == scalar0);

std::vector<Int*> ints(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more test case please: a filtering which returns an empty list, ex:

for (auto ptr : ir_utils::filterVals<Expr>(vals)) {
 TORCH_CHECK(false, "Not expecting any results");
}

@naoyam naoyam changed the title [WIP] Add a utility function to filter out elements that are not of a certain type Add a utility function to filter out elements that are not of a certain type Aug 6, 2020
@naoyam
Copy link
Copy Markdown
Collaborator Author

naoyam commented Aug 6, 2020

I renamed the filter function to filterByType as it is actually no longer specific to Val. It just checks the runtime type of each element and filters out anything that is of that type.

The example I gave first still looks mostly the same. We replace this:

std::vector<Val*> values;
for (Val* x: values) {
  // do something only for TensorView vals
  if (x->getValType() == ValType::TensorView) {
   TensorView* x_tv = x->as<TensorView>();
   ...
  }
}

with:

for (TensorView* x: filterByType<TensorView>(values)) {
  // x is TensorView*
 ... 
}

@naoyam naoyam merged commit 1850edd into csarofeen:20_7_6_devel Aug 6, 2020
@naoyam
Copy link
Copy Markdown
Collaborator Author

naoyam commented Aug 6, 2020

@tlemo Sorry, I missed your last comments. I'll update the test with a separate PR.

@naoyam naoyam mentioned this pull request Aug 6, 2020
jjsjann123 pushed a commit that referenced this pull request Dec 22, 2022
…ty. (#267) (pytorch#89315)

Summary:
pytorch#89122 introduces internal compatibility issues with torchdeploy. However, GetPythonFramesFunction() never worked with torchdeploy, so this PR simply reverts to the original behavior of skipping the function if torchdeploy is used as a forward fix.

Test Plan:
Running failed tests in T128123281
```
buck2 test @//mode/opt //multipy/runtime:test_deploy -- --exact 'multipy/runtime:test_deploy - TorchpyTest.TaggingRace' --run-disabled

buck2 test mode/dev //multipy/runtime/testdev:test_deploy_from_python -- --exact 'multipy/runtime/testdev:test_deploy_from_python - multipy.runtime.testdev.test_deploy_from_python.TestDeployFromPython: test_deploy_from_python'
```

Differential Revision: D41414263

Pull Request resolved: pytorch#89315
Approved by: https://github.com/kurman
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants