feat: add class_names support to Precision and Recall metrics #3732
rogueslasher wants to merge 2 commits into pytorch:master from
Conversation
Ok, so won't this code break? There can be more metrics that could benefit from this change. I am not sure, though.
@aaishwarymishra would raising errors for those conditions be a better option, or should we add full class_names support to Fbeta?
I am not sure, adding support for the class names in |
Could we please clarify |
If we can compute the F1 score per class, then we should propagate the dict structure until the output of the F1; otherwise, take
Hey @vfdev-5, would adding a helper in metric.py to handle dict arithmetic in all operator overloads, so the dict propagates through to the Fbeta output, be a good approach? Also, when average=True and class_names is set, should Fbeta return a scalar float or propagate the dict as is?
I do not know; what would you suggest? Let's try to make it as simple and intuitively clear as possible.
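The dict-arithmetic idea discussed above could be sketched as follows. This is a minimal illustration, not ignite code: the helper name `fbeta_per_class` and the assumption that Precision/Recall return `{class name: score}` dicts when class_names is set and average=False are hypothetical.

```python
# Hypothetical sketch: combine per-class precision and recall dicts into a
# per-class F-beta dict, so the class labels propagate to the Fbeta output.
def fbeta_per_class(precision, recall, beta=1.0, eps=1e-15):
    # precision / recall: dicts mapping class name -> score, as they would
    # look if class_names is set and average=False (assumed structure).
    b2 = beta * beta
    return {
        name: (1 + b2) * precision[name] * recall[name]
        / (b2 * precision[name] + recall[name] + eps)
        for name in precision
    }

p = {"cat": 0.8, "dog": 0.5}
r = {"cat": 0.6, "dog": 0.5}
f1 = fbeta_per_class(p, r)  # per-class F1, keyed by class name
```

In ignite, operator overloads on Metric produce a MetricsLambda, so a real implementation would need the underlying arithmetic to handle dict-valued compute() results elementwise; the sketch only shows the per-class formula with labels preserved.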
Fixes #1466
Description:
Adds an optional `class_names` parameter to `_BasePrecisionRecall`, allowing `compute()` to return a labeled `dict` instead of an unnamed tensor when `average=False` or `average=None`. Useful for per-class metric tracking, where knowing which score belongs to which class matters for logging and visualization.

Check list:
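As a rough illustration of the described behavior (a sketch under stated assumptions, not the actual ignite implementation): `compute()` is simulated here with plain Python, and the helper name `label_scores` is hypothetical.

```python
# Sketch of the described behavior: when class_names is given and
# average=False, the per-class scores come back as a dict keyed by class
# name instead of an unnamed sequence. Names here are illustrative only.
def label_scores(scores, class_names=None):
    if class_names is None:
        return scores  # unlabeled per-class scores, as before
    if len(class_names) != len(scores):
        raise ValueError("class_names length must match the number of classes")
    return dict(zip(class_names, scores))

per_class = [0.9, 0.75, 0.6]  # e.g. simulated per-class precision values
labeled = label_scores(per_class, class_names=["cat", "dog", "bird"])
# labeled maps each class name to its score, convenient for logging
```

Raising on a length mismatch mirrors the error-vs-silent-mislabeling question raised in the conversation: mismatched names would silently attach scores to the wrong classes otherwise.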