@@ -134,6 +134,121 @@ you can also pass in an :doc:`datamodules <../data/datamodule>` that have overri
134134 # test (pass in datamodule)
135135 trainer.test(datamodule = dm)
136136
137+
138+ Test with Multiple DataLoaders
139+ ==============================
140+
141+ When you need to evaluate your model on multiple test datasets simultaneously (e.g., different domains, conditions, or
142+ evaluation scenarios), PyTorch Lightning supports multiple test dataloaders out of the box.
143+
144+ To use multiple test dataloaders, simply return a list of dataloaders from your ``test_dataloader()`` method:
145+
146+ .. code-block:: python
147+
148+     class LitModel(L.LightningModule):
149+         def test_dataloader(self):
150+             return [
151+                 DataLoader(clean_test_dataset, batch_size=32),
152+                 DataLoader(noisy_test_dataset, batch_size=32),
153+                 DataLoader(adversarial_test_dataset, batch_size=32),
154+             ]
155+
156+ When using multiple test dataloaders, your ``test_step`` method **must** include a ``dataloader_idx`` parameter:
157+
158+ .. code-block:: python
159+
160+     def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
161+         x, y = batch
162+         y_hat = self(x)
163+         loss = F.cross_entropy(y_hat, y)
164+
165+         # Use dataloader_idx to handle different test scenarios
166+         return {'test_loss': loss}
167+
168+ Logging Metrics Per Dataloader
169+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
170+
171+ Lightning provides automatic support for logging metrics per dataloader:
172+
173+ .. code-block:: python
174+
175+     def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
176+         x, y = batch
177+         y_hat = self(x)
178+         loss = F.cross_entropy(y_hat, y)
179+         acc = (y_hat.argmax(dim=1) == y).float().mean()
180+
181+         # Lightning automatically adds "/dataloader_idx_X" suffix
182+         self.log('test_loss', loss, add_dataloader_idx=True)
183+         self.log('test_acc', acc, add_dataloader_idx=True)
184+
185+         return loss
186+
187+ This will create metrics like ``test_loss/dataloader_idx_0``, ``test_loss/dataloader_idx_1``, etc.
188+
189+ For more meaningful metric names, you can use custom naming; just make sure each name is
190+ unique across dataloaders.
191+
192+ .. code-block:: python
193+
194+     def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
195+         # Define meaningful names for each dataloader
196+         dataloader_names = {0: "clean", 1: "noisy", 2: "adversarial"}
197+         dataset_name = dataloader_names.get(dataloader_idx, f"dataset_{dataloader_idx}")
198+
199+         # Log with custom names
200+         self.log(f'test_loss_{dataset_name}', loss, add_dataloader_idx=False)
201+         self.log(f'test_acc_{dataset_name}', acc, add_dataloader_idx=False)
202+
203+ Processing Entire Datasets Per Dataloader
204+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
205+
206+ To perform calculations on the entire test dataset for each dataloader (e.g., computing overall metrics, creating
207+ visualizations), accumulate results during ``test_step`` and process them in ``on_test_epoch_end``:
208+
209+ .. code-block:: python
210+
211+     class LitModel(L.LightningModule):
212+         def __init__(self):
213+             super().__init__()
214+             # Store outputs per dataloader
215+             self.test_outputs = {}
216+
217+         def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
218+             x, y = batch
219+             y_hat = self(x)
220+             loss = F.cross_entropy(y_hat, y)
221+
222+             # Initialize and store results
223+             if dataloader_idx not in self.test_outputs:
224+                 self.test_outputs[dataloader_idx] = {'predictions': [], 'targets': []}
225+             self.test_outputs[dataloader_idx]['predictions'].append(y_hat)
226+             self.test_outputs[dataloader_idx]['targets'].append(y)
227+             return loss
228+
229+         def on_test_epoch_end(self):
230+             for dataloader_idx, outputs in self.test_outputs.items():
231+                 # Concatenate all predictions and targets for this dataloader
232+                 all_predictions = torch.cat(outputs['predictions'], dim=0)
233+                 all_targets = torch.cat(outputs['targets'], dim=0)
234+
235+                 # Calculate metrics on the entire dataset, log and create visualizations
236+                 overall_accuracy = (all_predictions.argmax(dim=1) == all_targets).float().mean()
237+                 self.log(f'test_overall_acc_dataloader_{dataloader_idx}', overall_accuracy)
238+                 self._save_results(all_predictions, all_targets, dataloader_idx)
239+
240+             self.test_outputs.clear()
241+
242+ .. note::
243+     When using multiple test dataloaders, ``trainer.test()`` returns a list of results, one for each dataloader:
244+
245+     .. code-block:: python
246+
247+         results = trainer.test(model)
248+         print(f"Results from {len(results)} test dataloaders:")
249+         for i, result in enumerate(results):
250+             print(f"Dataloader {i}: {result}")
251+
137 252 ----------
138 253
139 254 **********
0 commit comments