Skip to content

Commit 02d4c81

Browse files
puhukvfdev-5
andauthored
Change data generating in classification, fbeta, mae (#2667)
* Change data generating in classification, fbeta, mae Change data generating in `classification_report`, `fbeta`, `mae` * Update with review * Update test_classification_report.py Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent e64eb4b commit 02d4c81

File tree

3 files changed

+46
-33
lines changed

3 files changed

+46
-33
lines changed

tests/ignite/metrics/test_classification_report.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,20 @@
1212
def _test_integration_multiclass(device, output_dict):
1313

1414
rank = idist.get_rank()
15-
torch.manual_seed(12)
1615

1716
def _test(metric_device, n_classes, labels=None):
1817

1918
classification_report = ClassificationReport(device=metric_device, output_dict=output_dict, labels=labels)
2019
n_iters = 80
21-
s = 16
22-
offset = n_iters * s
23-
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
24-
y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
20+
batch_size = 16
21+
22+
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
23+
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
2524

2625
def update(engine, i):
2726
return (
28-
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
29-
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
27+
y_preds[i * batch_size : (i + 1) * batch_size, :],
28+
y_true[i * batch_size : (i + 1) * batch_size],
3029
)
3130

3231
engine = Engine(update)
@@ -36,6 +35,9 @@ def update(engine, i):
3635
data = list(range(n_iters))
3736
engine.run(data=data)
3837

38+
y_preds = idist.all_gather(y_preds)
39+
y_true = idist.all_gather(y_true)
40+
3941
assert "cr" in engine.state.metrics
4042
res = engine.state.metrics["cr"]
4143
res2 = classification_report.compute()
@@ -60,7 +62,8 @@ def update(engine, i):
6062
assert sklearn_result["macro avg"]["recall"] == pytest.approx(res["macro avg"]["recall"])
6163
assert sklearn_result["macro avg"]["f1-score"] == pytest.approx(res["macro avg"]["f1-score"])
6264

63-
for _ in range(5):
65+
for i in range(5):
66+
torch.manual_seed(12 + rank + i)
6467
# check multiple random inputs as random exact occurencies are rare
6568
metric_devices = ["cpu"]
6669
if device.type != "xla":
@@ -77,24 +80,22 @@ def update(engine, i):
7780
def _test_integration_multilabel(device, output_dict):
7881

7982
rank = idist.get_rank()
80-
torch.manual_seed(12)
8183

8284
def _test(metric_device, n_epochs, labels=None):
8385

8486
classification_report = ClassificationReport(device=metric_device, output_dict=output_dict, is_multilabel=True)
8587

8688
n_iters = 10
87-
s = 16
89+
batch_size = 16
8890
n_classes = 7
8991

90-
offset = n_iters * s
91-
y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
92-
y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
92+
y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
93+
y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
9394

9495
def update(engine, i):
9596
return (
96-
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
97-
y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
97+
y_preds[i * batch_size : (i + 1) * batch_size, ...],
98+
y_true[i * batch_size : (i + 1) * batch_size, ...],
9899
)
99100

100101
engine = Engine(update)
@@ -104,6 +105,9 @@ def update(engine, i):
104105
data = list(range(n_iters))
105106
engine.run(data=data, max_epochs=n_epochs)
106107

108+
y_preds = idist.all_gather(y_preds)
109+
y_true = idist.all_gather(y_true)
110+
107111
assert "cr" in engine.state.metrics
108112
res = engine.state.metrics["cr"]
109113
res2 = classification_report.compute()
@@ -121,6 +125,7 @@ def update(engine, i):
121125
sklearn_result = sklearn_classification_report(np_y_true, np_y_preds, output_dict=True, zero_division=1)
122126

123127
for i in range(n_classes):
128+
torch.manual_seed(12 + rank + i)
124129
label_i = labels[i] if labels else str(i)
125130
assert sklearn_result[str(i)]["precision"] == pytest.approx(res[label_i]["precision"])
126131
assert sklearn_result[str(i)]["f1-score"] == pytest.approx(res[label_i]["f1-score"])

tests/ignite/metrics/test_fbeta.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,21 @@ def update_fn(engine, batch):
9191
def _test_distrib_integration(device):
9292

9393
rank = idist.get_rank()
94-
torch.manual_seed(12)
9594

9695
def _test(p, r, average, n_epochs, metric_device):
9796
n_iters = 60
98-
s = 16
97+
batch_size = 16
9998
n_classes = 7
10099

101-
offset = n_iters * s
102-
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
103-
y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
100+
torch.manual_seed(12 + rank)
101+
102+
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
103+
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
104104

105105
def update(engine, i):
106106
return (
107-
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
108-
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
107+
y_preds[i * batch_size : (i + 1) * batch_size, :],
108+
y_true[i * batch_size : (i + 1) * batch_size],
109109
)
110110

111111
engine = Engine(update)
@@ -116,6 +116,9 @@ def update(engine, i):
116116
data = list(range(n_iters))
117117
engine.run(data=data, max_epochs=n_epochs)
118118

119+
y_preds = idist.all_gather(y_preds)
120+
y_true = idist.all_gather(y_true)
121+
119122
assert "f2.5" in engine.state.metrics
120123
res = engine.state.metrics["f2.5"]
121124
if isinstance(res, torch.Tensor):

tests/ignite/metrics/test_mean_absolute_error.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,22 @@ def _test_distrib_integration(device):
5959
from ignite.engine import Engine
6060

6161
rank = idist.get_rank()
62-
n_iters = 80
63-
s = 50
64-
offset = n_iters * s
6562

66-
y_true = torch.arange(0, offset * idist.get_world_size(), dtype=torch.float).to(device)
67-
y_preds = torch.ones(offset * idist.get_world_size(), dtype=torch.float).to(device)
63+
def _test(metric_device):
6864

69-
def update(engine, i):
70-
return (
71-
y_preds[i * s + offset * rank : (i + 1) * s + offset * rank],
72-
y_true[i * s + offset * rank : (i + 1) * s + offset * rank],
73-
)
65+
n_iters = 80
66+
batch_size = 50
67+
torch.manual_seed(12 + rank)
68+
69+
y_true = torch.arange(0, n_iters * batch_size, dtype=torch.float).to(device)
70+
y_preds = torch.ones(n_iters * batch_size, dtype=torch.float).to(device)
71+
72+
def update(engine, i):
73+
return (
74+
y_preds[i * batch_size : (i + 1) * batch_size],
75+
y_true[i * batch_size : (i + 1) * batch_size],
76+
)
7477

75-
def _test(metric_device):
7678
engine = Engine(update)
7779

7880
m = MeanAbsoluteError(device=metric_device)
@@ -81,6 +83,9 @@ def _test(metric_device):
8183
data = list(range(n_iters))
8284
engine.run(data=data, max_epochs=1)
8385

86+
y_preds = idist.all_gather(y_preds)
87+
y_true = idist.all_gather(y_true)
88+
8489
assert "mae" in engine.state.metrics
8590
res = engine.state.metrics["mae"]
8691

0 commit comments

Comments
 (0)