Classification metrics overhaul: input formatting standardization (1/n)#4837
Borda merged 53 commits into Lightning-AI:master from tadejsv:cls_metrics_input_formatting
Conversation
Thanks for splitting off the PR! Reviewing now
teddykoker
left a comment
Overall looks like good changes, just a few small things to fix.
Codecov Report
@@            Coverage Diff            @@
##           master    #4837    +/-   ##
=========================================
  Coverage      93%      93%
=========================================
  Files         129      130       +1
  Lines        9397     9527     +130
=========================================
+ Hits         8713     8843     +130
  Misses        684      684
SkafteNicki
left a comment
overall very good :)
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Is there anything else that needs to be done before this PR can be merged?
@tadejsv, thanks for the further description of the
@tadejsv mind resolving conflicts :] probably after #4549
Borda
left a comment
LGTM, it would just be nice to also have tests for all the helper functions raising some kind of exception...
Alright, merge conflicts resolved, ready for final review. @SkafteNicki please double check that the docs are ok (git diff is not useful there).
SkafteNicki
left a comment
LGTM, docs look fine :]
awaelchli
left a comment
High-level review, looks good, nice docs
This PR is a spin-off from #4835. It should be merged before any other spin-offs, as it provides a base for all of them.
What does this PR do?
General (fundamental) changes
I have created a new `_input_format_classification` function (in `metrics/classification/utils`). The job of this function is to a) validate, and b) transform the inputs into a common format. This common format is a binary label indicator array: either `(N, C)`, or `(N, C, X)` (the latter only for multi-dimensional multi-class inputs).

I believe that having such a "central" function is crucial, as it gets rid of code duplication (which was present in PL metrics before), and enables metric developers to focus on developing the metrics themselves, rather than on standardizing and validating inputs.
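To make the common format concrete, here is a minimal, hypothetical sketch (not the actual `_input_format_classification` implementation; the helper name `to_label_indicator` is invented for illustration) of converting integer multi-class labels into the `(N, C)` binary label indicator format:

```python
import torch

def to_label_indicator(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert integer labels of shape (N,) into an (N, C) binary label
    indicator tensor. Hypothetical helper, for illustration only."""
    indicator = torch.zeros(labels.size(0), num_classes, dtype=torch.int)
    # Place a 1 in the column corresponding to each sample's class.
    indicator.scatter_(1, labels.long().unsqueeze(1), 1)
    return indicator

# Three samples with classes 0, 2, 1 map to three one-hot rows.
print(to_label_indicator(torch.tensor([0, 2, 1]), num_classes=3).tolist())
# [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```

With every metric receiving inputs in this single shape, the per-metric `update` logic no longer needs its own branching over input types.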
The validation performed on the inputs basically makes sure that they fall into one of the possible input type cases, and that the values are consistent with both the type of the inputs and the additional parameters set (e.g. that there is no label in target higher than `num_classes`). The docstrings (and the new "Input types" section in the documentation) give all the details about how the standardization and validation are performed.

Here I'll list the parameters of this function (many of which are also present on some metrics), and why I decided to use them:

`threshold`: The probability threshold for binarizing binary and multi-label inputs.

`num_classes`: The number of classes. Used either to decide the `C` dimension of the inputs, or, if this is already implicitly given, to ensure consistency between the inputs and the number of classes the user specified when creating the metric (thus avoiding either having to check this manually in `update` for each metric, or raising an error when updating the state, which may not be very clear to the user).
`top_k`: For (multi-dimensional) multi-class inputs, if predictions are given as probabilities, selects the top k highest probabilities per sample. It's a generalization of the usual procedure, which corresponds to `k=1`. This will be used by the `Accuracy` metric in subsequent PRs.

`is_multiclass`: Used for transforming binary or multi-label inputs to 2-class multi-class and 2-class multi-dimensional multi-class inputs, respectively. And vice versa.

Why? This is similar to the `multilabel` argument that was (is?) present on some metrics. I believe this is a better name for it, as it also deals with transforming to/from binary. But why is it needed? There are cases where it is not clear what the inputs are: for example, say that both preds and target are of the form `[0, 1, 0, 1]`. This actually appears to be multi-class (it could be the case that it simply happened in this batch that there were only 0s and 1s), so an explicit instruction is needed to tell the metrics that this is in fact binary. On the other hand, sometimes we would like to treat binary inputs as two-class inputs - this is the case in the confusion matrix.

I also experimented with using `num_classes` to determine this. Besides this being a very confusing approach, requiring several paragraphs to explain clearly, it also does not resolve all ambiguities (is setting `num_classes=1` with 2-class probability predictions a request to treat the data as binary, or an inconsistency of inputs that should raise an error?). So I think `is_multiclass` is the best approach here.

Documentation
Instead of metrics being organized into "Class Metrics" and "Functional Metrics", they are now organized by topic (Classification, Regression, ...), and within topics split into class and functional metrics, if necessary. This makes it possible to add special topic-related sections - in this case I have added a section on what types of inputs are used for classification metrics - a section that metrics can link to, in order not to repeat the same thing 100 times, and to keep docstrings short and to the point.
The second half of the "Input types" section, with examples from the StatScores metric, will be added in that metric's PR.