Skip to content

Commit 1df9932

Browse files
Some improvement in Precision metric docs (#2946)
* Apply improvements * Fix a typo * Revert Makefile change * Update docs/Makefile --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent f61364f commit 1df9932

File tree

2 files changed

+72
-39
lines changed

2 files changed

+72
-39
lines changed

ignite/metrics/metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(
219219
@abstractmethod
220220
def reset(self) -> None:
221221
"""
222-
Resets the metric to it's initial state.
222+
Resets the metric to its initial state.
223223
224224
By default, this is called at the start of each epoch.
225225
"""
@@ -240,7 +240,7 @@ def update(self, output: Any) -> None:
240240
@abstractmethod
241241
def compute(self) -> Any:
242242
"""
243-
Computes the metric based on it's accumulated state.
243+
Computes the metric based on its accumulated state.
244244
245245
By default, this is called at the end of each epoch.
246246
@@ -273,7 +273,7 @@ def iteration_completed(self, engine: Engine) -> None:
273273
274274
Note:
275275
``engine.state.output`` is used to compute metric values.
276-
The majority of implemented metrics accepts the following formats for ``engine.state.output``:
276+
The majority of implemented metrics accept the following formats for ``engine.state.output``:
277277
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. ``y_pred`` and ``y`` can be torch tensors or
278278
list of tensors/numbers if applicable.
279279

ignite/metrics/precision.py

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
6161
num_classes = 2 if self._type == "binary" else y_pred.size(1)
6262
if self._type == "multiclass" and y.max() + 1 > num_classes:
6363
raise ValueError(
64-
f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
65-
f" and element in y has invalid class = {y.max().item() + 1}."
64+
f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}"
65+
f" and an element in y has invalid class = {y.max().item() + 1}."
6666
)
6767
y = y.view(-1)
6868
if self._type == "binary" and self._average is False:
@@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
8686

8787
@reinit__is_reduced
8888
def reset(self) -> None:
89-
# `numerator`, `denominator` and `weight` are three variables chosen to be abstract
90-
# representatives of the ones that are measured for cases with different `average` parameters.
91-
# `weight` is only used when `average='weighted'`. Actual value of these three variables is
92-
# as follows.
93-
#
94-
# average='samples':
95-
# numerator (torch.Tensor): sum of metric value for samples
96-
# denominator (int): number of samples
97-
#
98-
# average='weighted':
99-
# numerator (torch.Tensor): number of true positives per class/label
100-
# denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
101-
# positives per class/label
102-
# weight (torch.Tensor): number of actual positives per class
103-
#
104-
# average='micro':
105-
# numerator (torch.Tensor): sum of number of true positives for classes/labels
106-
# denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives
107-
# for classes/labels
108-
#
109-
# average='macro' or boolean or None:
110-
# numerator (torch.Tensor): number of true positives per class/label
111-
# denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
112-
# positives per class/label
89+
"""
90+
`numerator`, `denominator` and `weight` are three variables chosen to be abstract
91+
representatives of the ones that are measured for cases with different `average` parameters.
92+
`weight` is only used when `average='weighted'`. Actual value of these three variables is
93+
as follows.
94+
95+
average='samples':
96+
numerator (torch.Tensor): sum of metric value for samples
97+
denominator (int): number of samples
98+
99+
average='weighted':
100+
numerator (torch.Tensor): number of true positives per class/label
101+
denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
102+
class/label.
103+
weight (torch.Tensor): number of actual positives per class
104+
105+
average='micro':
106+
numerator (torch.Tensor): sum of number of true positives for classes/labels
107+
denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives for
108+
classes/labels.
109+
110+
average='macro' or boolean or None:
111+
numerator (torch.Tensor): number of true positives per class/label
112+
denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
113+
class/label.
114+
"""
113115

114116
self._numerator: Union[int, torch.Tensor] = 0
115117
self._denominator: Union[int, torch.Tensor] = 0
@@ -120,16 +122,20 @@ def reset(self) -> None:
120122

121123
@sync_all_reduce("_numerator", "_denominator")
122124
def compute(self) -> Union[torch.Tensor, float]:
123-
# Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
124-
#
125-
# .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
126-
#
127-
# wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C`
128-
# for the `macro` one. :math:`C` is the number of classes/labels.
129-
#
130-
# Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
131-
#
132-
# .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator }
125+
r"""
126+
Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
127+
128+
.. math::
129+
\text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
130+
131+
wherein `weight` is the internal variable `_weight` for `'weighted'` option and :math:`1/C`
132+
for the `macro` one. :math:`C` is the number of classes/labels.
133+
134+
Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
135+
136+
.. math::
137+
\text{Precision/Recall} = \frac{ numerator }{ denominator }
138+
"""
133139

134140
if not self._updated:
135141
raise NotComputableError(
@@ -367,6 +373,33 @@ def thresholded_output_transform(output):
367373

368374
@reinit__is_reduced
369375
def update(self, output: Sequence[torch.Tensor]) -> None:
376+
r"""
377+
Update the metric state using prediction and target.
378+
379+
Args:
380+
output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch
381+
dimension, `...` for possible additional dimensions and C for class dimension.
382+
383+
.. list-table::
384+
:widths: 20 10 10 10
385+
:header-rows: 1
386+
387+
* - Output member\\Data type
388+
- Binary
389+
- Multiclass
390+
- Multilabel
391+
* - y_pred
392+
- (N, ...)
393+
- (N, C, ...)
394+
- (N, C, ...)
395+
* - y
396+
- (N, ...)
397+
- (N, ...)
398+
- (N, C, ...)
399+
400+
For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
401+
data, y_pred and y should consist of probabilities and integers respectively.
402+
"""
370403
self._check_shape(output)
371404
self._check_type(output)
372405
y_pred, y, correct = self._prepare_output(output)

0 commit comments

Comments
 (0)