
Implement logaddexp #38384

Closed

cloudhan wants to merge 18 commits into pytorch:master from cloudhan:impl-logaddexp

Conversation

@cloudhan (Contributor) commented May 13, 2020

Resolve #38377
Related #38349

This op should be disambiguated from logsumexp, which does a reduction on a tensor over a specific axis.
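
For readers comparing the two ops, a minimal sketch (output values approximate, assuming a build that includes this PR's op):

```python
import torch

a = torch.tensor([-1000.0, -1000.0])
b = torch.tensor([-1000.0, -1000.0])

# Computed naively, exp(-1000) underflows to 0 and log(0) is -inf;
# both ops avoid this by working in log space.
torch.logaddexp(a, b)      # elementwise log(exp(a) + exp(b)): ~[-999.31, -999.31]
torch.logsumexp(a, dim=0)  # reduction log(sum(exp(a))) over dim 0: ~ -999.31
```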

@dr-ci (bot) commented May 13, 2020

💊 CI failures summary and remediations

As of commit 97ab0ca (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@ezyang changed the title from "Implement logaddexp" to "[WIP] Implement logaddexp" May 13, 2020
@ezyang (Contributor) commented May 13, 2020

Remove WIP from the title when you're ready for review

@mruberry added the module: numpy label May 15, 2020
@mruberry self-requested a review May 15, 2020 08:37
@cloudhan (Contributor, Author) commented

@mruberry anything left?

@mruberry (Collaborator) commented

> @mruberry anything left?

Just mark it as no longer WIP/draft when you're ready for review.

@cloudhan changed the title from "[WIP] Implement logaddexp" to "Implement logaddexp" May 16, 2020
@cloudhan marked this pull request as ready for review May 16, 2020 01:57
Comment thread: aten/src/ATen/native/BinaryOps.cpp (outdated)
Collaborator

Do you need a /*check_mem_overlap=*/true here? What if result is self or other?

Contributor Author

There will be no problem if result is self or other, but a problem might occur when a view of one tensor overlaps with another view of the same storage, so check_mem_overlap is definitely needed here.
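
A hypothetical sketch of the overlap case under discussion (the exact error message is an implementation detail):

```python
import torch

base = torch.randn(6)
a = base[0:5]    # view of base
out = base[1:6]  # another view of the same storage, partially overlapping a

# With check_mem_overlap enabled, writing to `out` would clobber elements
# of `a` before they are read, so the op should raise rather than compute
# a silently wrong result:
# torch.logaddexp(a, a, out=out)  # expected: RuntimeError about memory overlap
```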

Comment thread: aten/src/ATen/native/BinaryOps.cpp (outdated)
Collaborator

Same check_mem_overlap question as above.

Collaborator

What happens if a == b == inf or -inf?

Collaborator

Same question when a == b == +/- inf
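
The concern is that the usual stable formula, max(a, b) + log1p(exp(-|a - b|)), degenerates when both inputs are the same infinity, because inf - inf is nan. A scalar Python sketch of the pitfall:

```python
import math

def logaddexp_naive(a, b):
    # stable for finite inputs: max + log1p(exp(-|a - b|))
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

inf = float("inf")
logaddexp_naive(1.0, 1.0)    # 1.0 + log(2), as expected
logaddexp_naive(inf, 1.0)    # inf, as expected
logaddexp_naive(inf, inf)    # nan, because abs(inf - inf) is nan; should be inf
logaddexp_naive(-inf, -inf)  # nan for the same reason; should be -inf
# Implementations therefore special-case equal-infinity inputs before
# applying the stable formula.
```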

Comment thread: tools/autograd/derivatives.yaml (outdated)
Collaborator

@albanD Take a look, would you?

Collaborator

The formula looks good.
For testing, you want to add an entry here to make sure the gradients will be properly checked.
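
Independent of the derivatives.yaml entry, the analytic formula can be sanity-checked with gradcheck; a minimal sketch, assuming a build that includes this op:

```python
import torch

# d/da log(exp(a) + exp(b)) = sigmoid(a - b); gradcheck compares the
# analytic backward against finite differences in double precision.
a = torch.randn(4, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(torch.logaddexp, (a, b))
```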

Comment thread: test/test_torch.py (outdated)
Collaborator

@albanD What's our plan for testing the derivatives of new functions like this?

Contributor Author

I am also curious.

Comment thread: test/test_torch.py (outdated)
Collaborator

Don't import NumPy directly; instead, assert TEST_NUMPY.
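
A sketch of the pattern being requested, assuming the torch.testing._internal.common_utils helpers:

```python
import unittest
import torch
from torch.testing._internal.common_utils import TEST_NUMPY, TestCase

class TestLogaddexpNumpy(TestCase):
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logaddexp_matches_numpy(self):
        import numpy as np  # only imported once the guard has passed
        a, b = torch.randn(10), torch.randn(10)
        np.testing.assert_allclose(torch.logaddexp(a, b).numpy(),
                                   np.logaddexp(a.numpy(), b.numpy()),
                                   rtol=1e-6)
```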

Comment thread: test/test_torch.py (outdated)
Collaborator

You should move these tests into TestTorchDeviceType so you can run them on both the CPU and GPU. Right now you're just testing the CPU. Then see the helper function "_np_compare," which you can use to simplify your tests, and the @dtypes decorator so you can test multiple dtypes. Since you're only enabling logaddexp on float types, you should test at least torch.long (assert throws RuntimeError), torch.float32, and torch.complex64 (assert throws RuntimeError). That way if someone implements logaddexp for complex64, for example, they'll know to update this test.

_np_compare works by taking the values to test at. Your value generation is OK, but you should add some more "interesting" values like -math.pi, 0, and math.pi, as well as extremal values like nan, -inf, and inf.
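
A rough sketch of the suggested dtype coverage (the internal TestTorchDeviceType/@dtypes machinery and the _np_compare helper are elided here, so this is written as a plain function over device/dtype pairs):

```python
import math
import torch

def check_logaddexp(device, dtype):
    if dtype in (torch.long, torch.complex64):
        # logaddexp is only enabled for floating-point types; if someone
        # later adds complex support, this check will flag the test.
        x = torch.ones(3, device=device, dtype=dtype)
        try:
            torch.logaddexp(x, x)
        except RuntimeError:
            return
        raise AssertionError(f"expected RuntimeError for {dtype}")
    vals = [-math.pi, 0.0, math.pi, float("nan"), float("-inf"), float("inf")]
    a = torch.tensor(vals, device=device, dtype=dtype)
    torch.logaddexp(a, a.flip(0))  # should run and match np.logaddexp
```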

Comment thread: torch/_torch_docs.py (outdated)
Comment thread: torch/_torch_docs.py (outdated)
Collaborator

This comment is a little misleading. Maybe something like: "the tensor whose exponential is added to the exponential of input before the log is taken"?

Contributor Author

Your suggestion is too verbose; maybe I should just delete it and let the generator generate the entry for the second input tensor?

Collaborator

Sure.

Comment thread: torch/_torch_docs.py (outdated)
Comment thread: torch/_torch_docs.py (outdated)
Comment thread: torch/_torch_docs.py (outdated)
Collaborator

Same language/doc changes here as with the above.

@mruberry (Collaborator) left a comment

This PR is looking really good. I requested changes to the tests and docs and have a couple additional questions. I also want to check with @albanD if/how we're planning to validate the gradients of these functions.

Once we get that fixed up I think this PR will be good to go!

@cloudhan requested review from albanD and mruberry May 20, 2020 09:18
Comment thread: tools/autograd/derivatives.yaml (outdated)
@albanD (Collaborator) left a comment

LGTM
I'll let @mruberry do a final pass.

@cloudhan (Contributor, Author) commented

@mruberry

@mruberry (Collaborator) left a comment

Nice work, @cloudhan!

If you're still interested in contributing functions to PyTorch, I encourage you to pick one of the more challenging functions in #38349, like count_nonzero or divmod.

@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented

@mruberry merged this pull request in 05f097b.

@muthuArivoli mentioned this pull request Jul 30, 2020
@cloudhan deleted the impl-logaddexp branch September 26, 2021 05:29
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Resolve pytorch#38377
Related pytorch#38349

This op should be disambiguated from `logsumexp`, which does a reduction on a tensor over a specific axis.
Pull Request resolved: pytorch#38384

Differential Revision: D21737336

Pulled By: mruberry

fbshipit-source-id: 7864d04ca304c0fb2937bb083583e3e3d6ef205d

Labels

Merged
module: numpy (Related to numpy support, and also numpy compatibility of our operators)
open source
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Implementing logaddexp, logaddexp2

7 participants