[WIP] Support topk classification metrics [ci skip]#3822
Conversation
|
Hello @rohitgr7! Thanks for updating this PR.
Comment last updated at 2020-11-21 23:06:17 UTC |
|
This looks better and cleaner. I guess natively supporting |
Codecov Report
@@ Coverage Diff @@
## master #3822 +/- ##
========================================
- Coverage 92% 77% -15%
========================================
Files 113 118 +5
Lines 8278 11313 +3035
========================================
+ Hits 7632 8766 +1134
- Misses 646 2547 +1901 |
| topk: Optional[int] = None,
| ) -> torch.Tensor:
| if (target.ndim > 1) and (topk is not None):
|     raise ValueError(
Can you add a test case for this?
yeah sure, it's in TODO actually.
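A test for this guard could look like the following sketch (plain Python rather than the project's pytest suite; `accuracy` here is a hypothetical stand-in that reproduces only the check quoted above, not the real metric):

```python
import torch

def accuracy(pred, target, topk=None):
    # Hypothetical stand-in reproducing only the guard under review;
    # the real function also computes the metric itself.
    if (target.ndim > 1) and (topk is not None):
        raise ValueError("topk is not supported for multi-dimensional targets")
    return None

# A 2D (multi-dimensional) target combined with topk should raise.
try:
    accuracy(torch.rand(4, 3), torch.randint(0, 2, (4, 3)), topk=2)
    raised = False
except ValueError:
    raised = True
assert raised
```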
| topk = 1
|
| if topk > pred.size(1):
|     raise ValueError(
also a test case for this :)
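A matching test could be sketched like this (again with a hypothetical `accuracy` stand-in that reproduces only the bound check quoted above):

```python
import torch

def accuracy(pred, target, topk=1):
    # Hypothetical stand-in reproducing only the bound check under review.
    if topk > pred.size(1):
        raise ValueError("topk cannot exceed the number of classes in pred")
    return None

# Asking for the top 5 out of only 3 classes should raise.
try:
    accuracy(torch.rand(4, 3), torch.randint(0, 3, (4,)), topk=5)
    raised = False
except ValueError:
    raised = True
assert raised
```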
| pred: predicted labels or probabilities
| target: ground truth labels
| num_classes: number of classes
| topk: number of most likely outcomes considered to find the correct label
| - topk: number of most likely outcomes considered to find the correct label
| + topk: number of most likely outcomes considered to find the correct label. Defaults to the usual top1 accuracy.
| pred: predicted labels or probabilities
| target: ground-truth labels
| num_classes: number of classes
| topk: number of most likely outcomes considered to find the correct label

| - topk: number of most likely outcomes considered to find the correct label
| + topk: number of most likely outcomes considered to find the correct label. Defaults to the usual top1.
|
This pull request is now in conflict... :( |
1 similar comment
|
Here is a top k accuracy implementation using the new Metric class:

class AccuracyTopK(pl.metrics.Metric):
    def __init__(self, top_k=1, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.k = top_k
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits, y):
        _, pred = logits.topk(self.k, dim=1)
        pred = pred.t()
        corr = pred.eq(y.view(1, -1).expand_as(pred))
        self.correct += corr[:self.k].sum()
        self.total += y.numel()

    def compute(self):
        return self.correct.float() / self.total |
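A quick sanity check of the top-k counting used in update above, written as a self-contained function so it runs without the Lightning Metric base class (the tensor values are made up for illustration):

```python
import torch

def topk_correct(logits, y, k):
    # Same counting scheme as AccuracyTopK.update above.
    _, pred = logits.topk(k, dim=1)                # (N, k) top-k class indices
    pred = pred.t()                                # (k, N)
    corr = pred.eq(y.view(1, -1).expand_as(pred))  # hit matrix
    return corr.sum().item()

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1],
                       [0.3, 0.2, 0.5]])
y = torch.tensor([1, 2, 0])

print(topk_correct(logits, y, 1) / y.numel())  # top-1 accuracy: 1/3
print(topk_correct(logits, y, 2) / y.numel())  # top-2 accuracy: 2/3
```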
|
This is correct. I used the same formula earlier. Can we simply keep |
|
@oke-aditya I think I prefer that as well, especially since |
|
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions. |
|
I guess we can replace |
|
@oke-aditya thanks. I'll complete this one in a few days. The metrics package has been changed and I haven't checked it out yet. |
|
Great. 😀 |
Force-pushed eb8687d to 0fa9697
|
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions. |
Force-pushed 3061742 to 0fa9697
|
It looks like AccuracyTopK is not available in the current version? Will it be in the next version? |
|
@rohitgr7 @justusschock how is it going here, ready to review/merge? 🐰 |
|
closing this, will be added in #4837 and follow up PRs. |
What does this PR do?
Add support for topk in classification metrics. Let's see how it goes with tests.
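The semantics being added can be sketched as a minimal functional form (the name `topk_accuracy` and its signature are illustrative only, not this PR's actual API): a sample counts as correct when its true class appears among the topk highest-scoring classes.

```python
import torch

def topk_accuracy(pred, target, topk=1):
    # Illustrative only: a sample is correct if its true class is among
    # the topk highest-scoring classes in pred (shape: N x num_classes).
    _, idx = pred.topk(topk, dim=1)               # (N, topk)
    hit = idx.eq(target.unsqueeze(1)).any(dim=1)  # (N,)
    return hit.float().mean()

pred = torch.tensor([[0.1, 0.7, 0.2],
                     [0.6, 0.3, 0.1],
                     [0.3, 0.2, 0.5]])
target = torch.tensor([1, 2, 0])
print(topk_accuracy(pred, target, topk=2))  # top-2: 2 of 3 samples correct
```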
TODO:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃