@@ -244,13 +244,14 @@ def _setup_metrics(self):
244244 else :
245245 self .metrics = self .custom_metrics
246246
247- def calculate_loss (self , output : Dict , y : torch .Tensor , tag : str ) -> torch .Tensor :
247+ def calculate_loss (self , output : Dict , y : torch .Tensor , tag : str , sync_dist : bool = False ) -> torch .Tensor :
248248 """Calculates the loss for the model.
249249
250250 Args:
251251 output (Dict): The output dictionary from the model
252252 y (torch.Tensor): The target tensor
253253 tag (str): The tag to use for logging
254+ sync_dist (bool): if True, synchronize the logged values across distributed processes. Defaults to False.
254255
255256 Returns:
256257 torch.Tensor: The loss value
@@ -270,6 +271,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
270271 on_step = False ,
271272 logger = True ,
272273 prog_bar = False ,
274+ sync_dist = sync_dist ,
273275 )
274276 if self .hparams .task == "regression" :
275277 computed_loss = reg_loss
@@ -284,6 +286,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
284286 on_step = False ,
285287 logger = True ,
286288 prog_bar = False ,
289+ sync_dist = sync_dist ,
287290 )
288291 else :
289292 # TODO loss fails with batch size of 1?
@@ -301,6 +304,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
301304 on_step = False ,
302305 logger = True ,
303306 prog_bar = False ,
307+ sync_dist = sync_dist ,
304308 )
305309 start_index = end_index
306310 self .log (
@@ -311,10 +315,13 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
311315 # on_step=False,
312316 logger = True ,
313317 prog_bar = True ,
318+ sync_dist = sync_dist ,
314319 )
315320 return computed_loss
316321
317- def calculate_metrics (self , y : torch .Tensor , y_hat : torch .Tensor , tag : str ) -> List [torch .Tensor ]:
322+ def calculate_metrics (
323+ self , y : torch .Tensor , y_hat : torch .Tensor , tag : str , sync_dist : bool = False
324+ ) -> List [torch .Tensor ]:
318325 """Calculates the metrics for the model.
319326
320327 Args:
@@ -324,6 +331,8 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
324331
325332 tag (str): The tag to use for logging
326333
334+ sync_dist (bool): if True, synchronize the logged values across distributed processes. Defaults to False.
335+
327336 Returns:
328337 List[torch.Tensor]: The list of metric values
329338
@@ -356,6 +365,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
356365 on_step = False ,
357366 logger = True ,
358367 prog_bar = False ,
368+ sync_dist = sync_dist ,
359369 )
360370 _metrics .append (_metric )
361371 avg_metric = torch .stack (_metrics , dim = 0 ).sum ()
@@ -379,6 +389,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
379389 on_step = False ,
380390 logger = True ,
381391 prog_bar = False ,
392+ sync_dist = sync_dist ,
382393 )
383394 _metrics .append (_metric )
384395 start_index = end_index
@@ -391,6 +402,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
391402 on_step = False ,
392403 logger = True ,
393404 prog_bar = True ,
405+ sync_dist = sync_dist ,
394406 )
395407 return metrics
396408
@@ -523,19 +535,19 @@ def validation_step(self, batch, batch_idx):
523535 # fetched from the batch
524536 y = batch ["target" ] if y is None else y
525537 y_hat = output ["logits" ]
526- self .calculate_loss (output , y , tag = "valid" )
527- self .calculate_metrics (y , y_hat , tag = "valid" )
538+ self .calculate_loss (output , y , tag = "valid" , sync_dist = True )
539+ self .calculate_metrics (y , y_hat , tag = "valid" , sync_dist = True )
528540 return y_hat , y
529541
530542 def test_step (self , batch , batch_idx ):
531543 with torch .no_grad ():
532544 output , y = self .forward_pass (batch )
533- # y is not None for SSL task.Rest of the tasks target is
545+ # y is not None for SSL task. Rest of the tasks target is
534546 # fetched from the batch
535547 y = batch ["target" ] if y is None else y
536548 y_hat = output ["logits" ]
537- self .calculate_loss (output , y , tag = "test" )
538- self .calculate_metrics (y , y_hat , tag = "test" )
549+ self .calculate_loss (output , y , tag = "test" , sync_dist = True )
550+ self .calculate_metrics (y , y_hat , tag = "test" , sync_dist = True )
539551 return y_hat , y
540552
541553 def configure_optimizers (self ):
0 commit comments