
BUG Fixes regression for nllloss gradcheck #64203

Closed
thomasjpfan wants to merge 4 commits into pytorch:master from thomasjpfan:nll_loss_regression_backward

Conversation

@thomasjpfan
Contributor

@thomasjpfan thomasjpfan commented Aug 30, 2021

Fixes #64163

This PR includes the fix and the opinfo from #63854 for non-regression testing.

cc @albanD @mruberry @jbschlosser

@thomasjpfan thomasjpfan added the module: nn (Related to torch.nn) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Aug 30, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Aug 30, 2021

💊 CI failures summary and remediations

As of commit e549bc7 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Comment on lines +369 to +370
using accscalar_t = at::acc_type<scalar_t, /*is_cuda*/true>;
nll_loss_forward_reduce_cuda_kernel_2d<scalar_t, accscalar_t, index_t>
Contributor Author


This was what was causing the issue with gradcheck.

Contributor

@jbschlosser jbschlosser left a comment


LGTM! Thanks for the fix :)

   int n_classes,
   int64_t ignore_index) {
-  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 & threadIdx.z == 0);
+  CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
Contributor


oops good catch

auto weight_ = weight.defined() ? weight.contiguous() : weight;

-  if (reduction == Reduction::None & n_dims == 2) {
+  if (reduction == Reduction::None && n_dims == 2) {
Contributor


Here too :o Richard says we should have a linter rule for catching this sort of thing; seems like a good idea.

Collaborator


The compiler does warn about these, I think. Maybe we want to make these warnings errors.

Contributor


We should probably compile pytorch with -Werror for good practice. The downside is that pytorch already emits a lot of warnings, so getting it into good enough shape for this will take a while.
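One possible middle ground, sketched below, is promoting only the relevant diagnostics to errors instead of a blanket -Werror. This is a hedged illustration: -Wbitwise-instead-of-logical is a Clang diagnostic (Clang 14+), flag names differ on GCC/MSVC, and example.cpp is a placeholder file name.

```shell
# Promote only the diagnostics that catch this bug class to errors,
# so the rest of the (warning-heavy) build still compiles.
# Clang 14+ only; example.cpp is a placeholder.
clang++ -Wall -Werror=bitwise-instead-of-logical -c example.cpp
```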

@pmeier pmeier requested a review from zou3519 August 30, 2021 18:36
@gchanan gchanan added this to the 1.10.0 milestone Aug 30, 2021
@facebook-github-bot
Contributor

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

@albanD albanD left a comment


LGTM

@facebook-github-bot
Contributor

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@jbschlosser merged this pull request in a7ae73a.


Labels

cla signed, Merged, module: nn (Related to torch.nn), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Gradient is incorrect for torch.nn.functional.nll_loss for CUDA

7 participants