Skip to content
42 changes: 27 additions & 15 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any

import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor
Expand Down Expand Up @@ -68,7 +67,7 @@ class TestTimeAugmentation:
Args:
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
. All random transforms must be of type `InvertibleTransform`.
When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.
batch_size: number of realizations to infer at once.
num_workers: how many subprocesses to use for data.
inferrer_fn: function to use to perform inference.
Expand All @@ -92,6 +91,11 @@ class TestTimeAugmentation:
will return the full data. Dimensions will be same size as when passing a single image through
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
progress: whether to display a progress bar.
apply_inverse_to_pred: whether to apply inverse transformations to the predictions.
If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions
back to the original spatial reference.
If the prediction is non-spatial (e.g. classification label or score), this should be `False` to
aggregate the raw predictions directly. Defaults to `True`.

Example:
.. code-block:: python
Expand Down Expand Up @@ -125,6 +129,7 @@ def __init__(
post_func: Callable = _identity,
return_full_data: bool = False,
progress: bool = True,
apply_inverse_to_pred: bool = True,
) -> None:
self.transform = transform
self.batch_size = batch_size
Expand All @@ -134,6 +139,7 @@ def __init__(
self.image_key = image_key
self.return_full_data = return_full_data
self.progress = progress
self.apply_inverse_to_pred = apply_inverse_to_pred
self._pred_key = CommonKeys.PRED
self.inverter = Invertd(
keys=self._pred_key,
Expand All @@ -152,20 +158,23 @@ def __init__(

def _check_transforms(self):
"""Should be at least 1 random transform, and all random transforms should be invertible."""
ts = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
randoms = np.array([isinstance(t, Randomizable) for t in ts])
invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
# check at least 1 random
if sum(randoms) == 0:
transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms
warns = []
randoms = []

for idx, t in enumerate(transforms):
if isinstance(t, Randomizable):
randoms.append(t)
if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):
warns.append(f"Transform #{idx} (type {type(t).__name__}) is random but not invertible.")

if len(randoms) == 0:
warns.append("TTA usually requires at least one `Randomizable` transform in the given transform sequence.")

if len(warns) > 0:
warnings.warn(
"TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms."
"TTA has encountered issues with the given transforms:\n " + "\n ".join(warns), stacklevel=2
)
# check that whenever randoms is True, invertibles is also true
for r, i in zip(randoms, invertibles):
if r and not i:
warnings.warn(
f"Not all applied random transform(s) are invertible. Problematic transform: {type(r).__name__}"
)

def __call__(
self, data: dict[str, Any], num_examples: int = 10
Expand Down Expand Up @@ -199,7 +208,10 @@ def __call__(
for b in tqdm(dl) if has_tqdm and self.progress else dl:
# do model forward pass
b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
if self.apply_inverse_to_pred:
outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])
else:
outs.extend([i[self._pred_key] for i in decollate_batch(b)])

output: NdarrayOrTensor = stack(outs, 0)

Expand Down
39 changes: 38 additions & 1 deletion tests/integration/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_test_time_augmentation(self):
# output might be different size, so pad so that they match
train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)
loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

Expand Down Expand Up @@ -181,6 +181,43 @@ def test_image_no_label(self):
tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
tta(self.get_data(1, (20, 20), include_label=False))

def test_non_spatial_output(self):
"""
Test TTA for non-spatial output (e.g., classification scores).
Verifies that setting `apply_inverse_to_pred=False` correctly aggregates
predictions without attempting spatial inversion.
"""
input_size = (20, 20)
data = {"image": np.random.rand(1, *input_size).astype(np.float32)}

transforms = Compose(
[EnsureChannelFirstd("image", channel_dim="no_channel"), RandFlipd("image", prob=1.0, spatial_axis=0)]
)

def mock_classifier(x):
batch_size = x.shape[0]
return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device)

tt_aug = TestTimeAugmentation(
transform=transforms,
batch_size=2,
num_workers=0,
inferrer_fn=mock_classifier,
device="cpu",
orig_key="image",
apply_inverse_to_pred=False,
return_full_data=False,
)
mode, mean, std, vvc = tt_aug(data, num_examples=4)

self.assertEqual(mean.shape, (2,))
np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)
np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)

tt_aug.return_full_data = True
full_output = tt_aug(data, num_examples=4)
self.assertEqual(full_output.shape, (4, 2))


if __name__ == "__main__":
unittest.main()
Loading