1212def _test_integration_multiclass (device , output_dict ):
1313
1414 rank = idist .get_rank ()
15- torch .manual_seed (12 )
1615
1716 def _test (metric_device , n_classes , labels = None ):
1817
1918 classification_report = ClassificationReport (device = metric_device , output_dict = output_dict , labels = labels )
2019 n_iters = 80
21- s = 16
22- offset = n_iters * s
23- y_true = torch .randint (0 , n_classes , size = (offset * idist . get_world_size () ,)).to (device )
24- y_preds = torch .rand (offset * idist . get_world_size () , n_classes ).to (device )
20+ batch_size = 16
21+
22+ y_true = torch .randint (0 , n_classes , size = (n_iters * batch_size ,)).to (device )
23+ y_preds = torch .rand (n_iters * batch_size , n_classes ).to (device )
2524
2625 def update (engine , i ):
2726 return (
28- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , :],
29- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset ],
27+ y_preds [i * batch_size : (i + 1 ) * batch_size , :],
28+ y_true [i * batch_size : (i + 1 ) * batch_size ],
3029 )
3130
3231 engine = Engine (update )
@@ -36,6 +35,9 @@ def update(engine, i):
3635 data = list (range (n_iters ))
3736 engine .run (data = data )
3837
38+ y_preds = idist .all_gather (y_preds )
39+ y_true = idist .all_gather (y_true )
40+
3941 assert "cr" in engine .state .metrics
4042 res = engine .state .metrics ["cr" ]
4143 res2 = classification_report .compute ()
@@ -60,7 +62,8 @@ def update(engine, i):
6062 assert sklearn_result ["macro avg" ]["recall" ] == pytest .approx (res ["macro avg" ]["recall" ])
6163 assert sklearn_result ["macro avg" ]["f1-score" ] == pytest .approx (res ["macro avg" ]["f1-score" ])
6264
63- for _ in range (5 ):
65+ for i in range (5 ):
66+ torch .manual_seed (12 + rank + i )
6467 # check multiple random inputs as random exact occurencies are rare
6568 metric_devices = ["cpu" ]
6669 if device .type != "xla" :
@@ -77,24 +80,22 @@ def update(engine, i):
7780def _test_integration_multilabel (device , output_dict ):
7881
7982 rank = idist .get_rank ()
80- torch .manual_seed (12 )
8183
8284 def _test (metric_device , n_epochs , labels = None ):
8385
8486 classification_report = ClassificationReport (device = metric_device , output_dict = output_dict , is_multilabel = True )
8587
8688 n_iters = 10
87- s = 16
89+ batch_size = 16
8890 n_classes = 7
8991
90- offset = n_iters * s
91- y_true = torch .randint (0 , 2 , size = (offset * idist .get_world_size (), n_classes , 6 , 8 )).to (device )
92- y_preds = torch .randint (0 , 2 , size = (offset * idist .get_world_size (), n_classes , 6 , 8 )).to (device )
92+ y_true = torch .randint (0 , 2 , size = (n_iters * batch_size , n_classes , 6 , 8 )).to (device )
93+ y_preds = torch .randint (0 , 2 , size = (n_iters * batch_size , n_classes , 6 , 8 )).to (device )
9394
9495 def update (engine , i ):
9596 return (
96- y_preds [i * s + rank * offset : (i + 1 ) * s + rank * offset , ...],
97- y_true [i * s + rank * offset : (i + 1 ) * s + rank * offset , ...],
97+ y_preds [i * batch_size : (i + 1 ) * batch_size , ...],
98+ y_true [i * batch_size : (i + 1 ) * batch_size , ...],
9899 )
99100
100101 engine = Engine (update )
@@ -104,6 +105,9 @@ def update(engine, i):
104105 data = list (range (n_iters ))
105106 engine .run (data = data , max_epochs = n_epochs )
106107
108+ y_preds = idist .all_gather (y_preds )
109+ y_true = idist .all_gather (y_true )
110+
107111 assert "cr" in engine .state .metrics
108112 res = engine .state .metrics ["cr" ]
109113 res2 = classification_report .compute ()
@@ -121,6 +125,7 @@ def update(engine, i):
121125 sklearn_result = sklearn_classification_report (np_y_true , np_y_preds , output_dict = True , zero_division = 1 )
122126
123127 for i in range (n_classes ):
128+ torch .manual_seed (12 + rank + i )
124129 label_i = labels [i ] if labels else str (i )
125130 assert sklearn_result [str (i )]["precision" ] == pytest .approx (res [label_i ]["precision" ])
126131 assert sklearn_result [str (i )]["f1-score" ] == pytest .approx (res [label_i ]["f1-score" ])
0 commit comments