Skip to content

Implement NLLLossNd#4035

Merged
soumith merged 3 commits intopytorch:masterfrom
zou3519:nlllossNd
Dec 18, 2017
Merged

Implement NLLLossNd#4035
soumith merged 3 commits intopytorch:masterfrom
zou3519:nlllossNd

Conversation

@zou3519
Copy link
Copy Markdown
Contributor

@zou3519 zou3519 commented Dec 5, 2017

Needed for #3556

I'm not sure this is the best way to implement because the .contiguous() calls might be slow.

One alternative way to implement this is to copy and modify gather. Without any of the extra keyword modifiers, with reduce=False, the following is equivalent to NLLLossNd:

def nlllossNd(input, target):
    target = target.unsqueeze(1)
    out = torch.gather(input, 1, target)
    return out.squeeze(1)

I tried benchmarking this against what I have right now (this diff that uses .contiguous() calls and NLLLoss2d) and using gather is around 2x slower, even for non-contiguous inputs, so I went with this approach.

Test Plan

Unit tests for NLLLossNd with a NLLLossNd reference function

Comment thread test/common_nn.py
'KLDivLoss': kldivloss_reference,
'NLLLoss': nllloss_reference,
'NLLLoss2d': nllloss2d_reference,
'NLLLossNd': nlllossNd_reference,

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@zou3519
Copy link
Copy Markdown
Contributor Author

zou3519 commented Dec 6, 2017

@pytorchbot retest this please

@pietern
Copy link
Copy Markdown
Contributor

pietern commented Dec 7, 2017

There was some CI maintenance happening this morning -- retriggering build.

@pytorchbot retest this please

@soumith soumith merged commit 30e6898 into pytorch:master Dec 18, 2017
@zou3519 zou3519 deleted the nlllossNd branch January 3, 2018 19:58
@soumith soumith added the 0.3.1 label Feb 4, 2018
soumith pushed a commit that referenced this pull request Feb 7, 2018
* Implement NLLLossNd

* Fix tests and typos

* Fix tests
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
* Implement NLLLossNd

* Fix tests and typos

* Fix tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants