Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
include_background: bool = True,
use_softmax: bool = False,
):
"""
Args:
Expand All @@ -171,6 +173,9 @@ def __init__(
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
epsilon : a very small number added to avoid division by zero, similar to a smoothing value. Defaults to 1e-7.
weight : weight for each loss function; if None, a value of 0.5 is used. Defaults to None.
include_background : whether to include the background class in loss calculation. Defaults to True.
use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.

Example:
>>> import torch
Expand All @@ -186,8 +191,12 @@ def __init__(
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_loss = AsymmetricFocalLoss(to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta
)
self.include_background = include_background
self.use_softmax = use_softmax

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
Expand All @@ -196,8 +205,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
a sigmoid or softmax in the forward function.
y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
It only supports binary segmentation.

Raises:
Expand Down Expand Up @@ -226,6 +235,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.")
else:
# if skipping background, removing first channel
y_pred = y_pred[:, 1:]
y_true = y_true[:, 1:]

if self.use_softmax:
y_pred = torch.softmax(y_pred.float(), dim=1)
else:
y_pred = torch.sigmoid(y_pred.float())

asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

Expand Down
Loading