[primTorch] Implement NLL loss reference#81128
rdspring1 wants to merge 29 commits into pytorch:master
#79820 was merged. Is advanced indexing still a blocker? What exactly doesn't work?
https://github.com/pytorch/pytorch/pull/81128/files#diff-93c7b95139f636278cc494028e322a2c3c3c9ba1e83b2adb35d54ccabed5b47aR431
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
else:
    result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
    return torch.reshape(result, out_size)
else:
Add a comment describing what this else branch is for.
From a code organization and readability standpoint these branches seem a little odd. Maybe we can explain them better?
In particular, can input have zero or one dimension? If so, how do we interpret that? The documentation suggests that input should have at least two dimensions. And why are inputs with three or four dimensions special?
Finally, prefer putting shorter branches which short-circuit first. That typically lets code have fewer indentation levels:
# short-circuits if foo because...
if foo:
    return x
# implicit else branch doesn't have to be indented
...
I refactored _nll_loss_nd to handle 1-3 dimensions. If the input has more than 3 dimensions, the trailing dimensions are flattened into a single K dimension to create a 3D tensor. The ATen implementation used a dedicated 4D case for image inputs.
# The _nll_loss_nd helper function handles the most common cases.
# ndim == 1 (Single Example)
# => Batch Size: 1, Input: (C), Target: ()
# ndim == 2 (k = 1)
# => Batch Size: N, Input: (N, C), Target: (N)
# ndim == 3 (k > 1)
# => Batch Size: N, Input: (N, C, K), Target: (N, K)
# ndim > 3
# => reshape the input and target to the 3-D case
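To make the ndim > 3 case concrete, here is a minimal shape-only sketch of the flattening described above. It is not the code from this PR: `flatten_to_3d` is a hypothetical helper that works on plain shape tuples, and it assumes the input is laid out as (N, C, d1, ..., dk) with a target of shape (N, d1, ..., dk).

```python
import math


def flatten_to_3d(input_shape, target_shape):
    """Return the (N, C, K) input shape and (N, K) target shape.

    Hypothetical helper for illustration only; it manipulates shape
    tuples and never touches tensor data.
    """
    if len(input_shape) <= 3:
        raise ValueError("expected an input with more than 3 dimensions")

    n, c, *rest = input_shape
    if tuple(target_shape) != (n, *rest):
        raise ValueError("target shape must be (N, d1, ..., dk)")

    # Collapse every dimension after the class dimension into one K dimension.
    k = math.prod(rest)
    return (n, c, k), (n, k)


input_3d, target_2d = flatten_to_3d((2, 5, 4, 3), (2, 4, 3))
print(input_3d, target_2d)  # (2, 5, 12) (2, 12)
```

With reduction "none" the per-element losses come back with shape (N, K), so they would still need to be reshaped to the original target shape, which is what the `torch.reshape(result, out_size)` call in the diff above appears to do.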
Is the 4D case interesting to model here?
utils.check(
    isinstance(target, FakeTensor) or bool(class_check.item()),
    lambda: "A target class is out-of-bounds and not the ignore index.",
)
Comment this out for now until we have a debug mode for data-dependent checks.
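For context, the check is data-dependent because `class_check.item()` has to read concrete values out of the target tensor, which is not possible when tracing with fake tensors. The condition it guards can be sketched in plain Python as follows. This is an illustration only: `targets_in_bounds` is a hypothetical name, and it operates on a list of integers rather than a tensor.

```python
def targets_in_bounds(targets, num_classes, ignore_index):
    """Return True if every target is a valid class or the ignore index.

    Hypothetical illustration of the condition; the PR computes an
    equivalent boolean on tensors and then calls .item() on it.
    """
    return all(
        t == ignore_index or 0 <= t < num_classes
        for t in targets
    )


# -100 is the default ignore_index of torch.nn.functional.nll_loss.
print(targets_in_bounds([0, 2, -100], num_classes=3, ignore_index=-100))
print(targets_in_bounds([0, 3], num_classes=3, ignore_index=-100))
```

The first call accepts the targets because -100 is skipped as the ignore index. The second rejects them because class 3 is out of range for 3 classes.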
mruberry left a comment
Cool! Let's just update the data-dependent check per @IvanYashchuk's comment
@register_decomposition(torch.ops.aten.nll_loss)
def nll_loss(
Try wrapping this with the type promotion decorator.
@pytorchbot merge
Add Reference:
Depends on: