
[CPU] Add torch.trace for complex tensors #50380

Closed
anjali411 wants to merge 10 commits into gh/anjali411/79/base from gh/anjali411/79/head

Conversation

Contributor

anjali411 commented Jan 11, 2021

Stack from ghstack:

Differential Revision: D25949361

anjali411 added a commit that referenced this pull request Jan 11, 2021
ghstack-source-id: 81bd0e4
Pull Request resolved: #50380
Contributor

facebook-github-bot commented Jan 11, 2021

💊 CI failures summary and remediations

As of commit 76d05d5 (more details on the Dr. CI page):


  • 2/3 failures possibly* introduced in this PR
    • 2/2 non-CircleCI failure(s)
  • 1/3 broken upstream at merge base 78f3038 on Jan 22 from 1:38pm to 3:34pm

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

Comment thread: aten/src/ATen/native/ReduceOps.cpp (outdated)
// all integer types get promoted to kLong
if (result.scalar_type() == at::kLong) {
  *result.data_ptr<int64_t>() = sum;
Contributor Author

separated dispatch because this cast causes a compile error for complex otherwise
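
For context, a minimal sketch of the problem (hypothetical names, not the actual kernel): inside a dispatch that also instantiates complex types, the integer-only store is type-checked for every scalar_t, so the complex instantiation fails to compile even though that branch can never be taken at runtime.

#include <ATen/ATen.h>

// Hypothetical, reduced illustration; compiled once per scalar_t the dispatch covers.
template <typename scalar_t>
void store_result_sketch(at::Tensor& result, scalar_t sum) {
  if (result.scalar_type() == at::kLong) {
    // For scalar_t = c10::complex<float>, the next line fails to compile:
    // there is no conversion from c10::complex<float> to int64_t, even though
    // this branch is dead for complex inputs at runtime.
    *result.data_ptr<int64_t>() = sum;
  } else {
    *result.data_ptr<scalar_t>() = sum;
  }
}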

Contributor

This seems like a good opportunity to use an if_constexpr to fix the problem. Test for std::is_integral<scalar_t>. Check the if_constexpr docs for how to use it.

Contributor Author

anjali411 Jan 19, 2021

still gives a compile error

anjali411 added the module: complex label (Related to complex number support in PyTorch) Jan 11, 2021

codecov Bot commented Jan 11, 2021

Codecov Report

Merging #50380 (19e8616) into gh/anjali411/79/base (e29082b) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@                  Coverage Diff                  @@
##           gh/anjali411/79/base   #50380   +/-   ##
=====================================================
  Coverage                 80.71%   80.71%           
=====================================================
  Files                      1904     1904           
  Lines                    206598   206600    +2     
=====================================================
+ Hits                     166750   166753    +3     
+ Misses                    39848    39847    -1     

Comment thread: aten/src/ATen/native/ReduceOps.cpp (outdated)
using accscalar_t = at::acc_type<scalar_t, false>;
accscalar_t sum = 0;
const auto* t_data = self.data_ptr<scalar_t>();
if (result.scalar_type() == at::kLong) {
Contributor

Hm, is this test right? What about other integral types besides Long?

Collaborator

Surprisingly this is kinda the right test.

if (dtype == at::kLong) {
  ...
}

might be clearer. Add a comment here explaining that all integer types (including bool) will return kLong from get_dtype(), and that this branch handles integer types.
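
For reference, a small standalone sketch (not code from this PR) of the promotion this check relies on: reductions over integral and bool inputs produce kLong results, so comparing the result dtype against kLong covers every integral input type at once.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Integer and bool inputs promote to kLong in sum-like reductions,
  // so the result dtype is kLong for every integral input type.
  auto from_int = at::sum(at::ones({3}, at::kInt));
  auto from_bool = at::sum(at::ones({3}, at::kBool));
  std::cout << (from_int.scalar_type() == at::kLong) << " "    // prints 1
            << (from_bool.scalar_type() == at::kLong) << "\n";  // prints 1
  return 0;
}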

Collaborator

Why is the new branch needed, however? It seems like, as in the previous kernel, the only difference is the assignment to result at the end?

Contributor Author

yeah it is a bit confusing with no comment, so I added one for clarity.

Collaborator

Oh I see. OK.

Comment thread: test/test_torch.py (outdated)
anjali411 added a commit that referenced this pull request Jan 15, 2021
ghstack-source-id: bfd6b0e
Pull Request resolved: #50380
anjali411 requested a review from mruberry January 15, 2021 21:15
@anjali411
Contributor Author

@mruberry I updated the PR based on your comments. could you take a look?

anjali411 added a commit that referenced this pull request Jan 19, 2021
ghstack-source-id: c8391ae
Pull Request resolved: #50380
@mruberry
Collaborator

Looks like CUDA supports bool and bfloat16, so those need to be added to the OpInfo's CUDA dtypes.

anjali411 added a commit that referenced this pull request Jan 19, 2021
ghstack-source-id: 7c4dd86
Pull Request resolved: #50380
Contributor

ezyang left a comment

Approving to unblock, but please give the if_constexpr trick a try, I think you can avoid the duplication with it.

anjali411 added a commit that referenced this pull request Jan 20, 2021
ghstack-source-id: 558f1f6
Pull Request resolved: #50380
Comment thread: aten/src/ATen/native/ReduceOps.cpp (outdated)
c10::guts::if_constexpr<std::is_integral<accscalar_t>::value>(
    // all integer types get promoted to kLong
    [&] (auto _) { *result.data_ptr<int64_t>() = sum; }, // then-case, invalid for non-integral types
    [&] (auto _) { *result.data_ptr<scalar_t>() = sum; } // else-case, invalid for integral types
);
Contributor Author

cc. @ezyang the build still fails

Contributor

The compiler is eagerly type-checking the else-case even if the condition is true (and the other way round). You need to prevent the compiler from doing so, so that it only type-checks the branch that's actually called. You can do this by wrapping things in the _ you get passed as an argument. See https://github.com/pytorch/pytorch/blob/master/c10/util/C%2B%2B17.h#L257
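
A sketch of what that wrapping might look like for the snippet above, reusing its names (illustrative, not necessarily the exact merged code): _ is an identity callable passed into each lambda, and routing sum through it makes the problematic conversion dependent on the lambda's own template parameter, so it is only type-checked for the branch that actually gets instantiated.

c10::guts::if_constexpr<std::is_integral<accscalar_t>::value>(
    // then-case: the complex -> int64_t conversion now sits behind _(sum),
    // so it is only checked when this lambda is instantiated, i.e. only for
    // integral accscalar_t
    [&] (auto _) { *result.data_ptr<int64_t>() = _(sum); },
    // else-case: only instantiated for non-integral (floating/complex) accscalar_t
    [&] (auto _) { *result.data_ptr<scalar_t>() = _(sum); }
);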

anjali411 added a commit that referenced this pull request Jan 22, 2021
ghstack-source-id: fc1a160
Pull Request resolved: #50380
@facebook-github-bot
Contributor

@anjali411 merged this pull request in e544d74.

facebook-github-bot deleted the gh/anjali411/79/head branch January 27, 2021 15:21
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary: Pull Request resolved: pytorch#50380

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D25949361

Pulled By: anjali411

fbshipit-source-id: 9910bc5b532c9bf3add530221d643b2c41c62d01

Labels

cla signed, Merged, module: complex (Related to complex number support in PyTorch)

5 participants