Skip to content

Support complex number list in JIT#51145

Closed
anjali411 wants to merge 9 commits intogh/anjali411/91/basefrom
gh/anjali411/91/head
Closed

Support complex number list in JIT#51145
anjali411 wants to merge 9 commits intogh/anjali411/91/basefrom
gh/anjali411/91/head

Conversation

@anjali411
Copy link
Copy Markdown
Contributor

@anjali411 anjali411 commented Jan 26, 2021

Stack from ghstack:

Differential Revision: D26154025

@facebook-github-bot facebook-github-bot added cla signed oncall: jit Add this issue/PR to JIT oncall triage queue labels Jan 26, 2021
anjali411 added a commit that referenced this pull request Jan 26, 2021
ghstack-source-id: 99d6bb5
Pull Request resolved: #51145
@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Jan 26, 2021

💊 CI failures summary and remediations

As of commit 8a9ae4c (more details on the Dr. CI page):


  • 1/2 failures introduced in this PR
  • 1/2 broken upstream at merge base 1836070 on Jan 26 from 12:49pm to 5:48pm

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_python_doc_build (1/1)

Step: "Doc Build and Push" (full log | diagnosis details | 🔁 rerun)

Feb 01 04:40:51 Makefile:38: recipe for target 'html' failed
Feb 01 04:40:50 
Feb 01 04:40:50 copying static files... ... done
Feb 01 04:40:50 copying extra files... done
Feb 01 04:40:51 dumping search index in English (code: en)... done
Feb 01 04:40:51 dumping object inventory... done
Feb 01 04:40:51 build finished with problems, 1 warning.
Feb 01 04:40:51 /var/lib/jenkins/workspace/docs/src/pytorch-sphinx-theme/pytorch_sphinx_theme/search.html:21: RemovedInSphinx30Warning: To modify script_files in the theme is deprecated. Please insert a <script> tag directly in your theme instead.
Feb 01 04:40:51   <p class="last">
Feb 01 04:40:51 /var/lib/jenkins/workspace/docs/src/pytorch-sphinx-theme/pytorch_sphinx_theme/search.html:24: RemovedInSphinx30Warning: To modify script_files in the theme is deprecated. Please insert a <script> tag directly in your theme instead.
Feb 01 04:40:51   </p>
Feb 01 04:40:51 Makefile:38: recipe for target 'html' failed
Feb 01 04:40:51 make: *** [html] Error 1
Feb 01 04:40:51 ++ code=2
Feb 01 04:40:51 ++ '[' 2 -ne 0 ']'
Feb 01 04:40:51 ++ set +x
Feb 01 04:40:51 =========================
Feb 01 04:40:51 /var/lib/jenkins/workspace/docs/source/notes/broadcasting.rst:6: WARNING: 'any' reference target not found: numpy.doc.broadcasting
Feb 01 04:40:51 =========================
Feb 01 04:40:51 Docs build failed. If the failure is not clear, scan back in the log
Feb 01 04:40:51 for any WARNINGS or for the line build finished with problems
Feb 01 04:40:51 (tried to echo the WARNINGS above the ==== line)

1 job timed out:

  • pytorch_linux_xenial_py3_clang5_asan_test1

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch (expand for instructions)

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Copy link
Copy Markdown

@SplitInfinity SplitInfinity left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is ready for review (I didn't check if you requested review already, oops) but there are a few distracting formatting changes, and there should be Python tests that use lists of complex numbers.

const FunctionSchema& forward_schema = getMethod("forward").getSchema();
std::string input_types = getSchemaInputTypesString(forward_schema);
const std::vector<Argument>& forward_args = forward_schema.arguments();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be a lot of formatting changes below this point. Can you undo them?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

anjali411 added a commit that referenced this pull request Jan 27, 2021
ghstack-source-id: b7c2bb8
Pull Request resolved: #51145
anjali411 added a commit that referenced this pull request Jan 27, 2021
ghstack-source-id: 23a086f
Pull Request resolved: #51145
anjali411 added a commit that referenced this pull request Jan 28, 2021
ghstack-source-id: 46f5840
Pull Request resolved: #51145
@anjali411
Copy link
Copy Markdown
Contributor Author

@SplitInfinity ready for review. Added the python test, but can't find a way to update that file without causing those formatting changes

anjali411 added a commit that referenced this pull request Jan 29, 2021
ghstack-source-id: 6c71350
Pull Request resolved: #51145
@anjali411 anjali411 added the module: complex Related to complex number support in PyTorch label Jan 29, 2021
Copy link
Copy Markdown

@SplitInfinity SplitInfinity left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Comment thread aten/src/ATen/test/ivalue_test.cpp Outdated
ASSERT_EQ(foo12.toComplexDouble(), c10::complex<double>(3,4));
ASSERT_EQ(foo1.use_count(), 2);
ASSERT_TRUE(baz1.toComplexDoubleVector() == std::vector<c10::complex<double>>({(3, 4), (3, -4), (5, 0)}));
IValue the_complex_list(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems a bit weird to call this a list when it is a tuple.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true updated the name!

Comment thread test/jit/test_complex.py
Comment on lines +17 to +21
def test_complexlist(self):
def fn(a: List[complex], idx: int):
return a[idx]

input = [1j, 2, 3 + 4j, -5, -7j]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about serialization of lists? Is that coming later?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add it in a follow-up PR

Comment thread torch/csrc/jit/ir/constants.cpp Outdated
Comment on lines +91 to +92
bool fast_path_list =
val.isBoolList() || val.isIntList() || val.isDoubleList();
bool fast_path_list = val.isBoolList() || val.isIntList() ||
val.isDoubleList() || val.isComplexDoubleList();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I recall, complex literals are not supported in TorchScript yet, so this is not being tested.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed!

anjali411 added a commit that referenced this pull request Jan 29, 2021
ghstack-source-id: 5a13a5b
Pull Request resolved: #51145
Copy link
Copy Markdown

@SplitInfinity SplitInfinity left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

anjali411 added a commit that referenced this pull request Feb 1, 2021
ghstack-source-id: c575e07
Pull Request resolved: #51145
@facebook-github-bot
Copy link
Copy Markdown
Contributor

@anjali411 merged this pull request in 508bab4.

@facebook-github-bot facebook-github-bot deleted the gh/anjali411/91/head branch February 4, 2021 15:22
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary: Pull Request resolved: pytorch#51145

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D26154025

Pulled By: anjali411

fbshipit-source-id: 74645f9b6467757ddb9d75846e778222109848f0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed Merged module: complex Related to complex number support in PyTorch oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants