Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 234 additions & 54 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@

from __future__ import annotations

import warnings
from collections.abc import Callable

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.losses.dice import DiceLoss
from monai.networks import one_hot
from monai.utils import LossReduction
from monai.utils.deprecate_utils import deprecated_arg


def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore
"""
Expand Down Expand Up @@ -92,26 +100,6 @@ def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
return skel


def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """
    Compute the soft dice loss, skipping channel index 0 in every sum.

    Adapted from:
        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22

    Args:
        y_true: the shape should be BCH(WD)
        y_pred: the shape should be BCH(WD)
        smooth: constant added to both the numerator and the denominator of the dice coefficient

    Returns:
        dice loss (``1 - dice coefficient``) as a zero-dimensional tensor
    """
    # Multiply first and slice afterwards, exactly like the reference implementation.
    overlap = torch.sum((y_true * y_pred)[:, 1:, ...])
    true_mass = torch.sum(y_true[:, 1:, ...])
    pred_mass = torch.sum(y_pred[:, 1:, ...])

    numerator = 2.0 * overlap + smooth
    denominator = true_mass + pred_mass + smooth
    loss: torch.Tensor = 1.0 - numerator / denominator
    return loss


class SoftclDiceLoss(_Loss):
"""
Compute the Soft clDice loss defined in:
Expand All @@ -121,64 +109,256 @@ class SoftclDiceLoss(_Loss):

Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7

The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
Note that axis N of `input` is expected to be logits or probabilities for each class. If passing logits as input,
you must set `sigmoid=True` or `softmax=True`, or specify `other_act`. The same axis of `target`
can be 1 or N (one-hot format).

"""

def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
def __init__(
self,
iter_: int = 3,
smooth_nr: float = 1.0,
smooth_dr: float = 1.0,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
) -> None:
"""
Args:
iter_: Number of iterations for skeletonization
smooth: Smoothing parameter
iter_: Number of iterations for skeletonization. Must be a non-negative integer.
smooth_nr: a small constant added to the numerator to avoid zero.
smooth_dr: a small constant added to the denominator to avoid nan.
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
by the signal from the background so excluding it in such cases helps convergence.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
``other_act = torch.tanh``.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
TypeError: When ``iter_`` is not an integer.
ValueError: When ``iter_`` is negative.

Comment on lines +153 to +157
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Raises section omits iter_-related exceptions.

Both the TypeError (non-integer iter_) and ValueError (negative iter_) raised at Lines 165–167 are absent from the docstring.

📝 Proposed fix
     Raises:
         TypeError: When ``other_act`` is not an ``Optional[Callable]``.
         ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
             Incompatible values.
+        TypeError: When ``iter_`` is not an integer.
+        ValueError: When ``iter_`` is negative.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
TypeError: When ``iter_`` is not an integer.
ValueError: When ``iter_`` is negative.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/cldice.py` around lines 153 - 157, The docstring's Raises
section for the cldice loss is missing the exceptions related to the iter_
parameter; update the Raises block in the cldice docstring (the docstring for
the clDice loss class/function that accepts the iter_ parameter) to include a
TypeError when iter_ is not an int and a ValueError when iter_ is negative,
matching the existing phrasing/style used for other exceptions (i.e., list
"TypeError: When ``iter_`` is not an ``int``." and "ValueError: When ``iter_``
is negative.").

"""
super().__init__()
super().__init__(reduction=LossReduction(reduction).value)
if other_act is not None and not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
if not isinstance(iter_, int):
raise TypeError(f"iter_ must be an integer but got {type(iter_).__name__}.")
if iter_ < 0:
raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.")
self.iter = iter_
self.smooth = smooth
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
Comment on lines +159 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

smooth_dr=0 still produces NaN; no validation added.

With iter_=0 and an all-zero input, skel_pred is all-zero, so torch.sum(skel_pred, dim=reduce_axis) + self.smooth_dr becomes 0 when smooth_dr=0.0, causing NaN. The defaults of 1.0 guard normal usage, but an explicit smooth_dr=0.0 slips through silently.

🛡️ Proposed fix
+        if smooth_dr <= 0:
+            raise ValueError(f"smooth_dr must be positive but got {smooth_dr}.")
         self.iter = iter_
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 161-161: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 163-163: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 165-165: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 167-167: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/cldice.py` around lines 159 - 170, Validate the smoothing
denominator to prevent division-by-zero: in the constructor where iter_,
smooth_nr and smooth_dr are handled (the __init__ that sets self.iter and
self.smooth_dr), ensure smooth_dr is a numeric (coerce to float) and strictly
greater than 0 (raise ValueError if <= 0 or not convertible), then assign
self.smooth_dr = float(smooth_dr); this prevents torch.sum(...)+self.smooth_dr
from becoming zero when iter_=0 and inputs are all zeros.

self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.sigmoid = sigmoid
self.softmax = softmax
self.other_act = other_act

@deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.")
@deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.")
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

y_true and y_pred will have to be marked as deprecated arguments here, in case anyone has used them by name.

"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
AssertionError: When input and target (after one hot transform if set)
have different shapes.

"""
n_pred_ch = input.shape[1]

if self.sigmoid:
input = torch.sigmoid(input)

if self.softmax:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
else:
input = torch.softmax(input, dim=1)

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
if self.other_act is not None:
input = self.other_act(input)

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

if not self.include_background:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `include_background=False` ignored.", stacklevel=2)
else:
target = target[:, 1:]
input = input[:, 1:]

if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

skel_pred = soft_skel(input, self.iter)
skel_true = soft_skel(target, self.iter)

# Compute per-batch clDice by reducing over channel and spatial dimensions
# reduce_axis includes all dimensions except batch (dim 0)
reduce_axis: list[int] = list(range(1, len(input.shape)))

tprec = (torch.sum(torch.multiply(skel_pred, target), dim=reduce_axis) + self.smooth_nr) / (
torch.sum(skel_pred, dim=reduce_axis) + self.smooth_dr
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
tsens = (torch.sum(torch.multiply(skel_true, input), dim=reduce_axis) + self.smooth_nr) / (
torch.sum(skel_true, dim=reduce_axis) + self.smooth_dr
)
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
# Add small epsilon for numerical stability in harmonic mean
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
cl_dice = torch.mean(cl_dice)
elif self.reduction == LossReduction.SUM.value:
cl_dice = torch.sum(cl_dice)
elif self.reduction == LossReduction.NONE.value:
pass # keep per-batch values
else:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return cl_dice


class SoftDiceclDiceLoss(_Loss):
"""
Compute the Soft clDice loss defined in:
Compute both Dice loss and clDice loss, and return the weighted sum of these two losses.
The details of Dice loss are shown in ``monai.losses.DiceLoss``.
The details of clDice loss are shown in ``monai.losses.SoftclDiceLoss``.

Adapted from:
Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)

Adapted from:
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
"""

def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
def __init__(
self,
iter_: int = 3,
alpha: float = 0.5,
smooth_nr: float = 1.0,
smooth_dr: float = 1.0,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
) -> None:
"""
Args:
iter_: Number of iterations for skeletonization
smooth: Smoothing parameter
alpha: Weighing factor for cldice
iter_: Number of iterations for skeletonization, used by clDice. Must be a non-negative integer.
alpha: Weighting factor for the clDice component. Total loss = (1 - alpha) * dice + alpha * cldice.
Defaults to 0.5.
smooth_nr: a small constant added to the numerator to avoid zero, used by both Dice and clDice.
smooth_dr: a small constant added to the denominator to avoid nan, used by both Dice and clDice.
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
by the signal from the background so excluding it in such cases helps convergence.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
``other_act = torch.tanh``.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
ValueError: When ``alpha`` is not in ``[0, 1]``.

"""
Comment on lines +297 to 302
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Raises section missing ValueError for alpha out of [0, 1].

Line 305 raises it, but the docstring doesn't document it.

📝 Proposed fix
     Raises:
         TypeError: When ``other_act`` is not an ``Optional[Callable]``.
         ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
             Incompatible values.
+        ValueError: When ``alpha`` is not in ``[0, 1]``.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
"""
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
ValueError: When ``alpha`` is not in ``[0, 1]``.
"""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/cldice.py` around lines 297 - 302, Add a ValueError entry to the
existing Raises section of the docstring in monai/losses/cldice.py to document
that the parameter alpha must be within [0, 1]; specifically, add a line like
"ValueError: When ``alpha`` is not in ``[0, 1]``." to the Raises block for the
callable that accepts the alpha parameter (the function/class docstring where
alpha is validated).

super().__init__()
self.iter = iter_
self.smooth = smooth
self.alpha = alpha

def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
dice = soft_dice(y_true, y_pred, self.smooth)
skel_pred = soft_skel(y_pred, self.iter)
skel_true = soft_skel(y_true, self.iter)
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
if not 0.0 <= alpha <= 1.0:
raise ValueError(f"alpha must be in [0, 1] but got {alpha}.")
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=False,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
reduction=reduction,
smooth_nr=smooth_nr,
smooth_dr=smooth_dr,
)
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
torch.sum(skel_true[:, 1:, ...]) + self.smooth
self.cldice = SoftclDiceLoss(
iter_=iter_,
smooth_nr=smooth_nr,
smooth_dr=smooth_dr,
include_background=include_background,
to_onehot_y=False,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
reduction=reduction,
)
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
self.alpha = alpha
self.to_onehot_y = to_onehot_y

@deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.")
@deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.")
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with the names here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have addressed all the above, thank you!

"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
ValueError: When number of dimensions for input and target are different.
ValueError: When number of channels for target is neither 1 nor the same as input.

"""
if input.dim() != target.dim():
raise ValueError(
f"the number of dimensions for input and target should be the same, got shape {input.shape} and {target.shape}."
)

if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
raise ValueError(
f"number of channels for target is neither 1 nor the same as input, got shape {input.shape} and {target.shape}."
)

if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
else:
target = one_hot(target, num_classes=n_pred_ch)

dice_loss = self.dice(input, target)
cldice_loss = self.cldice(input, target)
total_loss: torch.Tensor = (1.0 - self.alpha) * dice_loss + self.alpha * cldice_loss

return total_loss
Loading
Loading