Skip to content

Commit a93e481

Browse files
authored
Doctests for RMSE and MeanPairwiseDistance (#2307)
* Add doctests for RMSE and MeanPairwiseDistance * Add doctests for RMSE and MeanPairwiseDistance * Make metric name consistent
1 parent fb6ba0a commit a93e481

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

docs/source/conf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,14 @@ def run(self):
341341
from ignite.utils import *
342342
343343
manual_seed(666)
344+
345+
# create default evaluator for doctests
346+
347+
def process_function(engine, batch):
348+
y_pred, y = batch
349+
return y_pred, y
350+
351+
default_evaluator = Engine(process_function)
344352
"""
345353

346354

ignite/metrics/mean_pairwise_distance.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,32 @@ class MeanPairwiseDistance(Metric):
2626
device: specifies which device updates are accumulated on. Setting the
2727
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
2828
non-blocking. By default, CPU.
29+
30+
Examples:
31+
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
32+
The output of the engine's ``process_function`` needs to be in the format of
33+
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
34+
to the metric to transform the output into the form expected by the metric.
35+
36+
``y_pred`` and ``y`` should have the same shape.
37+
38+
.. testcode::
39+
40+
metric = MeanPairwiseDistance(p=4)
41+
metric.attach(default_evaluator, 'mpd')
42+
preds = torch.Tensor([
43+
[1, 2, 4, 1],
44+
[2, 3, 1, 5],
45+
[1, 3, 5, 1],
46+
[1, 5, 1 ,11]
47+
])
48+
target = preds * 0.75
49+
state = default_evaluator.run([[preds, target]])
50+
print(state.metrics['mpd'])
51+
52+
.. testoutput::
53+
54+
1.5955...
2955
"""
3056

3157
def __init__(

ignite/metrics/root_mean_squared_error.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,32 @@ class RootMeanSquaredError(MeanSquaredError):
2626
device: specifies which device updates are accumulated on. Setting the
2727
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
2828
non-blocking. By default, CPU.
29+
30+
Examples:
31+
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
32+
The output of the engine's ``process_function`` needs to be in the format of
33+
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
34+
to the metric to transform the output into the form expected by the metric.
35+
36+
``y_pred`` and ``y`` should have the same shape.
37+
38+
.. testcode::
39+
40+
metric = RootMeanSquaredError()
41+
metric.attach(default_evaluator, 'rmse')
42+
preds = torch.Tensor([
43+
[1, 2, 4, 1],
44+
[2, 3, 1, 5],
45+
[1, 3, 5, 1],
46+
[1, 5, 1 ,11]
47+
])
48+
target = preds * 0.75
49+
state = default_evaluator.run([[preds, target]])
50+
print(state.metrics['rmse'])
51+
52+
.. testoutput::
53+
54+
1.956559480312316
2955
"""
3056

3157
def compute(self) -> Union[torch.Tensor, float]:

0 commit comments

Comments
 (0)