Skip to content

Move where cuda implementation to TensorIterator#32984

Closed
zasdfgbnm wants to merge 18 commits intopytorch:masterfrom
zasdfgbnm:where
Closed

Move where cuda implementation to TensorIterator#32984
zasdfgbnm wants to merge 18 commits intopytorch:masterfrom
zasdfgbnm:where

Conversation

@zasdfgbnm
Copy link
Collaborator

@zasdfgbnm zasdfgbnm commented Feb 4, 2020

where is special because the arguments do not have the same type, which does not satisfy the assumption in modern #32383. I migrate it to TensorIterator so that there is something to test that this case is not broken. Currently, this case fallback to using legacy (not vectorized, not unrolled) code. It should be supported in the future when I cleanup Loops.cuh.

I also move some sharing part of CUDALoops.cuh and ROCmLoops.cuh into Loops.cuh so that to logic for checking whether func_t has the same arg types could be shared.


} // namespace modern

template<typename func_t, int nargs=function_traits<func_t>::arity>
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to Loops.cuh


} // namespace modern

template<typename func_t, int nargs=function_traits<func_t>::arity>
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to Loops.cuh

#include <ATen/native/cuda/ROCmLoops.cuh>
#endif

namespace at { namespace native {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from CUDALoops.cuh and ROCmLoops.cuh, this part of the code is identical for CUDA and ROCm.


namespace at { namespace native { namespace modern { namespace detail {

template<typename func_t, int remaining=function_traits<func_t>::arity-1>
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part is newly added

arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
});
} else if (iter.has_contiguous_first_dim() && modern::detail::has_same_arg_types<func_t>::value) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only line changed for this copy-pasted chunk of code.

@vadimkantorov
Copy link
Contributor

I hope this enables #9190

@zasdfgbnm
Copy link
Collaborator Author

@vadimkantorov Unfortunately this doesn't...

@zasdfgbnm zasdfgbnm changed the title [WIP] Move where cuda implementation to TensorIterator Move where cuda implementation to TensorIterator Feb 4, 2020
@zasdfgbnm zasdfgbnm requested a review from ngimel February 4, 2020 22:54
@kostmo
Copy link
Member

kostmo commented Feb 4, 2020

💊 CircleCI build failures summary and remediations

As of commit f5e1114:

None of the build failures appear to be your fault.

  • 1/1 broken upstream at merge base 74c8a8f since Feb 11

    Please rebase on the viable/strict branch (expand for instructions)

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch origin viable/strict
    git rebase --onto viable/strict $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch origin viable/strict
    git rebase viable/strict
    

    Check out the recency history of this "viable master" tracking branch.

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🚧 1 upstream failure recognized by patterns:

These builds matched patterns, but were probably caused by upstream breakages:


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

This comment has been revised 16 times.

@jerryzh168 jerryzh168 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 5, 2020
};

// simple compile time test for has_same_arg_types:
using func1_t = int (*)(float, float);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This belongs in tests, not in actual source?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a compile-time unit test for has_same_arg_types. Maybe I should remove it from Loops.cuh and move to somewhere else?

using traits = function_traits<func_t>;
static constexpr bool value = std::is_same<
typename traits::template arg<remaining>::type,
typename traits::template arg<remaining-1>::type
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity, how does this work with -1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is specialized as true as in the code below. For nullary function, arity == 0, therefore has_same_arg_types<func_t> will becomes has_same_arg_types<func_t, function_traits<func_t>::arity-1> which is has_same_arg_types<func_t, -1>


namespace at { namespace native {

// `needs_dynamic_casting` compares the types expected by iterator
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngimel Docs added here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, so it'll be good to go once you move out tests from Loops.cuh.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test moved to cuda_vectorized_test.cu

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@ngimel merged this pull request in 367488b.

@zasdfgbnm zasdfgbnm restored the where branch February 12, 2020 04:07
@zasdfgbnm zasdfgbnm reopened this Feb 12, 2020
@zasdfgbnm zasdfgbnm closed this Feb 12, 2020
@zasdfgbnm zasdfgbnm deleted the where branch February 12, 2020 04:12
facebook-github-bot pushed a commit that referenced this pull request Feb 12, 2020
Summary:
Reopen of #32984
Pull Request resolved: #33228

Differential Revision: D19850862

Pulled By: ngimel

fbshipit-source-id: b92446a49b4980188fa4788220a2164650e905c2
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
`where` is special because the arguments do not have the same type, which does not satisfy the assumption in modern pytorch#32383. I migrate it to TensorIterator so that there is something to test that this case is not broken. Currently, this case fallback to using legacy (not vectorized, not unrolled) code. It should be supported in the future when I cleanup `Loops.cuh`.

I also move some sharing part of `CUDALoops.cuh` and `ROCmLoops.cuh` into `Loops.cuh` so that to logic for checking whether `func_t` has the same arg types could be shared.
Pull Request resolved: pytorch#32984

Differential Revision: D19825127

Pulled By: ngimel

fbshipit-source-id: bbf4682349d96b4480c4d657f3c18a3a67a9bf17
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
Reopen of pytorch#32984
Pull Request resolved: pytorch#33228

Differential Revision: D19850862

Pulled By: ngimel

fbshipit-source-id: b92446a49b4980188fa4788220a2164650e905c2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants