Machine learning models

The qim3d library aims to ease the creation of ML models for volumetric images.

qim3d.ml.models.UNet

Bases: Module

Constructs a 3D U-Net model for volumetric image segmentation.

The U-Net architecture consists of a contracting path (encoder) to capture context and a symmetric expanding path (decoder) that enables precise localization. This implementation wraps the MONAI U-Net and provides simplified presets for model depth and width via the size argument.

Model Presets (Channels per Layer):

  • 'xxsmall': (4, 8) - Ultra-lightweight (2 layers)
  • 'xsmall': (16, 32) - Lightweight (2 layers)
  • 'small': (32, 64, 128) - Fast training (3 layers)
  • 'medium': (64, 128, 256) - Balanced performance (3 layers, default)
  • 'large': (64, 128, 256, 512, 1024) - High capacity (5 layers)
  • 'xlarge': (64, 128, 256, 512, 1024, 2048) - Very high capacity (6 layers)
  • 'xxlarge': (64, 128, 256, 512, 1024, 2048, 4096) - Maximum capacity (7 layers)
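Each preset maps directly onto MONAI's `channels` and `strides` arguments, with a stride-2 downsampling between consecutive levels. A minimal sketch of that mapping (mirroring the source shown further down; illustrative only, not a public qim3d API):

```python
# Sketch of how a size preset becomes MONAI's `channels`/`strides` arguments.
SIZE_OPTIONS = {
    'xxsmall': (4, 8),
    'xsmall': (16, 32),
    'small': (32, 64, 128),
    'medium': (64, 128, 256),
    'large': (64, 128, 256, 512, 1024),
    'xlarge': (64, 128, 256, 512, 1024, 2048),
    'xxlarge': (64, 128, 256, 512, 1024, 2048, 4096),
}

def preset_to_monai_args(size: str) -> dict:
    """Resolve a preset name to the channel tuple and downsampling strides."""
    if size not in SIZE_OPTIONS:
        raise ValueError(f"Unknown size '{size}'. Choose from {list(SIZE_OPTIONS)}")
    channels = SIZE_OPTIONS[size]
    # One stride-2 downsampling between each pair of consecutive levels.
    return {'channels': channels, 'strides': (2,) * (len(channels) - 1)}

print(preset_to_monai_args('small'))
# → {'channels': (32, 64, 128), 'strides': (2, 2)}
```

Wider presets capture more features per level but grow memory use quickly, since 3D feature maps are large.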

Parameters:

  • size (str): The complexity of the model. Must be one of 'xxsmall', 'xsmall', 'small', 'medium', 'large', 'xlarge', or 'xxlarge'. Defaults to 'medium'.
  • dropout (float): The dropout rate (0 to 1) applied to hidden layers to prevent overfitting. Defaults to 0.
  • kernel_size (int): The size of the convolution kernel. Defaults to 3.
  • up_kernel_size (int): The size of the up-convolution kernel. Defaults to 3.
  • activation (str): The activation function to use (e.g., 'RELU', 'PReLU', 'Sigmoid'). Defaults to 'PReLU'.
  • bias (bool): Whether to include bias terms in convolutions. Defaults to True.
  • adn_order (str): The ordering of Activation (A), Dropout (D), and Normalization (N) blocks. Defaults to 'NDA'.

Returns:

  • model (Module): The initialized 3D U-Net model.

Raises:

  • ValueError: If size is not a valid preset string.

Example
import qim3d

# Initialize a small U-Net for quick experiments
model = qim3d.ml.models.UNet(size='small', dropout=0.2)

print(model)
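Because every level after the first downsamples the volume by a factor of 2, input volumes generally need spatial dimensions divisible by 2^(levels − 1) for the encoder and decoder feature maps to line up. A small hypothetical helper (not part of qim3d) to sanity-check a shape before training:

```python
def fits_unet(shape: tuple[int, int, int], n_levels: int) -> bool:
    """True if each spatial dimension is divisible by 2**(n_levels - 1)."""
    factor = 2 ** (n_levels - 1)
    return all(dim % factor == 0 for dim in shape)

# The 'small' preset has 3 levels, so dimensions must be divisible by 4.
print(fits_unet((64, 64, 64), n_levels=3))   # → True
print(fits_unet((64, 64, 62), n_levels=3))   # → False
```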
Source code in qim3d/ml/models/_unet.py
class UNet(torch.nn.Module):
    """
    Constructs a 3D U-Net model for volumetric image segmentation.

    The U-Net architecture consists of a contracting path (encoder) to capture context and a symmetric expanding path (decoder) that enables precise localization. This implementation wraps the [MONAI U-Net](https://docs.monai.io/en/stable/networks.html#unet) and provides simplified presets for model depth and width via the `size` argument.

    **Model Presets (Channels per Layer):**

    * **'xxsmall'**: (4, 8) - *Ultra-lightweight (2 layers)*
    * **'xsmall'**: (16, 32) - *Lightweight (2 layers)*
    * **'small'**: (32, 64, 128) - *Fast training (3 layers)*
    * **'medium'**: (64, 128, 256) - *Balanced performance (3 layers, default)*
    * **'large'**: (64, 128, 256, 512, 1024) - *High capacity (5 layers)*
    * **'xlarge'**: (64, 128, 256, 512, 1024, 2048) - *Very high capacity (6 layers)*
    * **'xxlarge'**: (64, 128, 256, 512, 1024, 2048, 4096) - *Maximum capacity (7 layers)*

    Args:
        size (str, optional): The complexity of the model. Must be one of 'xxsmall', 'xsmall', 'small', 'medium', 'large', 'xlarge', or 'xxlarge'. Defaults to 'medium'.
        dropout (float, optional): The dropout rate (0 to 1) applied to hidden layers to prevent overfitting. Defaults to 0.
        kernel_size (int, optional): The size of the convolution kernel. Defaults to 3.
        up_kernel_size (int, optional): The size of the up-convolution kernel. Defaults to 3.
        activation (str, optional): The activation function to use (e.g., 'RELU', 'PReLU', 'Sigmoid'). Defaults to 'PReLU'.
        bias (bool, optional): Whether to include bias terms in convolutions. Defaults to `True`.
        adn_order (str, optional): The ordering of Activation (A), Dropout (D), and Normalization (N) blocks. Defaults to 'NDA'.

    Returns:
        model (torch.nn.Module): The initialized 3D U-Net model.

    Raises:
        ValueError: If `size` is not a valid preset string.

    Example:
        ```python
        import qim3d

        # Initialize a small U-Net for quick experiments
        model = qim3d.ml.models.UNet(size='small', dropout=0.2)

        print(model)
        ```
    """

    def __init__(
        self,
        size: str = 'medium',
        dropout: float = 0,
        kernel_size: int = 3,
        up_kernel_size: int = 3,
        activation: str = 'PReLU',
        bias: bool = True,
        adn_order: str = 'NDA',
    ):
        super().__init__()

        self.size = size
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.activation = activation
        self.bias = bias
        self.adn_order = adn_order

        self.model = self._model_choice()

    def _model_choice(self) -> torch.nn.Module:
        monai = optional_import('monai', extra='deep-learning')

        size_options = {
            'xxsmall': (4, 8),  # 2 layers
            'xsmall': (16, 32),  # 2 layers
            'small': (32, 64, 128),  # 3 layers
            'medium': (64, 128, 256),  # 3 layers
            'large': (64, 128, 256, 512, 1024),  # 5 layers
            'xlarge': (64, 128, 256, 512, 1024, 2048),  # 6 layers
            'xxlarge': (64, 128, 256, 512, 1024, 2048, 4096),  # 7 layers
        }

        if self.size in size_options:
            self.channels = size_options[self.size]
        else:
            message = (
                f"Unknown size '{self.size}'. Choose from {list(size_options.keys())}"
            )
            raise ValueError(message)

        model = monai.networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=self.channels,
            strides=(2,) * (len(self.channels) - 1),
            num_res_units=2,
            kernel_size=self.kernel_size,
            up_kernel_size=self.up_kernel_size,
            act=self.activation,
            dropout=self.dropout,
            bias=self.bias,
            adn_ordering=self.adn_order,
        )
        return model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x

qim3d.ml.Augmentation

Configures data augmentation pipelines for 3D deep learning using MONAI.

This class simplifies the creation of augmentation strategies for training, validation, and testing. It allows you to select preset levels of intensity ('light', 'moderate', 'heavy') and define how input volumes are resized or cropped to match the model's input requirements.

Parameters:

  • resize (str): The method used to conform input images to a specific size. Defaults to 'crop'.
      • 'crop': Extracts a central crop of the desired size.
      • 'reshape': Resizes (interpolates) the image to the desired size.
      • 'padding': Pads the image with zeros to reach the desired size.
  • transform_train (str | None): The intensity of augmentation applied to the training set. Options: 'light', 'moderate', 'heavy', or None. Defaults to 'moderate'.
  • transform_validation (str | None): The intensity of augmentation applied to the validation set. Defaults to None (no augmentation).
  • transform_test (str | None): The intensity of augmentation applied to the test set. Defaults to None.

Raises:

  • ValueError: If resize is not one of 'crop', 'reshape', or 'padding'.

Example
import qim3d

# Create an augmentation strategy that crops images and applies moderate 
# transformations during training.
augmentation = qim3d.ml.Augmentation(resize='crop', transform_train='moderate')
Source code in qim3d/ml/_augmentations.py
class Augmentation:
    """
    Configures data augmentation pipelines for 3D deep learning using MONAI.

    This class simplifies the creation of augmentation strategies for training, validation, and testing. It allows you to select preset levels of intensity ('light', 'moderate', 'heavy') and define how input volumes are resized or cropped to match the model's input requirements.

    Args:
        resize (str, optional): The method used to conform input images to a specific size. Defaults to 'crop'.
            * **'crop'**: Extracts a central crop of the desired size.
            * **'reshape'**: Resizes (interpolates) the image to the desired size.
            * **'padding'**: Pads the image with zeros to reach the desired size.
        transform_train (str | None, optional): The intensity of augmentation applied to the training set. Options: 'light', 'moderate', 'heavy', or `None`. Defaults to 'moderate'.
        transform_validation (str | None, optional): The intensity of augmentation applied to the validation set. Defaults to `None` (no augmentation).
        transform_test (str | None, optional): The intensity of augmentation applied to the test set. Defaults to `None`.

    Raises:
        ValueError: If `resize` is not one of 'crop', 'reshape', or 'padding'.

    Example:
        ```python
        import qim3d

        # Create an augmentation strategy that crops images and applies moderate 
        # transformations during training.
        augmentation = qim3d.ml.Augmentation(resize='crop', transform_train='moderate')
        ```
    """

    def __init__(
        self,
        resize: str = 'crop',
        transform_train: str | None = 'moderate',
        transform_validation: str | None = None,
        transform_test: str | None = None,
    ):
        if resize not in ['crop', 'reshape', 'padding']:
            msg = f"Invalid resize type: {resize}. Use either 'crop', 'reshape' or 'padding'."
            raise ValueError(msg)

        self.resize = resize
        self.transform_train = transform_train
        self.transform_validation = transform_validation
        self.transform_test = transform_test

    def augment(
        self, img_shape: tuple, level: str | None = None
    ) -> monai.transforms.Compose:
        """
        Builds a MONAI composition of transforms based on the specified intensity level.

        This method constructs the actual pipeline of operations (e.g., rotations, flips, smoothing) that will be applied to the data.

        **Augmentation Levels:**

        * **None**: No augmentation. Only baseline formatting (ToTensor) is applied.
        * **'light'**: Random 90-degree rotations.
        * **'moderate'**: Rotations, flips, slight Gaussian smoothing, and minor affine transformations (scaling/translation).
        * **'heavy'**: Aggressive rotations, flips, stronger smoothing, and significant affine transformations including shearing.

        Args:
            img_shape (tuple): The target dimensions of the volume as `(Depth, Height, Width)`.
            level (str | None, optional): The specific augmentation level to generate. Must be one of `None`, 'light', 'moderate', or 'heavy'. Defaults to `None`.

        Returns:
            Compose (monai.transforms.Compose): A MONAI `Compose` object containing the sequence of transforms.

        Raises:
            ValueError: If `img_shape` is not 3D or if `level` is invalid.
        """
        from monai.transforms import (
            CenterSpatialCropd,
            Compose,
            RandAffined,
            RandFlipd,
            RandGaussianSmoothd,
            RandRotate90d,
            Resized,
            SpatialPadd,
            ToTensor,
        )

        # Check if image is 3D
        if len(img_shape) == 3:
            im_d, im_h, im_w = img_shape

        else:
            msg = f'Invalid image shape: {img_shape}. Must be 3D.'
            raise ValueError(msg)

        # Check if one of standard augmentation levels
        if level not in [None, 'light', 'moderate', 'heavy']:
            msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
            raise ValueError(msg)

        # Baseline augmentations
        # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
        baseline_aug = [ToTensor()]  # , NormalizeIntensityd(keys=["image"])]

        # Resize augmentations
        if self.resize == 'crop':
            resize_aug = [
                CenterSpatialCropd(keys=['image', 'label'], roi_size=(im_d, im_h, im_w))
            ]

        elif self.resize == 'reshape':
            resize_aug = [
                Resized(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
            ]

        elif self.resize == 'padding':
            resize_aug = [
                SpatialPadd(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
            ]

        # Level of augmentation
        if level is None:
            # No augmentation for the validation and test sets
            level_aug = []
            resize_aug = []

        elif level == 'light':
            # TODO: Do rotations along other axes?
            level_aug = [
                RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1))
            ]

        elif level == 'moderate':
            level_aug = [
                RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
                RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=0),
                RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=1),
                RandGaussianSmoothd(keys=['image'], sigma_x=(0.7, 0.7), prob=0.1),
                RandAffined(
                    keys=['image', 'label'],
                    prob=0.5,
                    translate_range=(0.1, 0.1),
                    scale_range=(0.9, 1.1),
                ),
            ]

        elif level == 'heavy':
            level_aug = [
                RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
                RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=0),
                RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=1),
                RandGaussianSmoothd(keys=['image'], sigma_x=(1.2, 1.2), prob=0.3),
                RandAffined(
                    keys=['image', 'label'],
                    prob=0.5,
                    translate_range=(0.2, 0.2),
                    scale_range=(0.8, 1.4),
                    shear_range=(-15, 15),
                ),
            ]

        return Compose(baseline_aug + resize_aug + level_aug)

qim3d.ml.Augmentation.augment

augment(img_shape, level=None)

Builds a MONAI composition of transforms based on the specified intensity level.

This method constructs the actual pipeline of operations (e.g., rotations, flips, smoothing) that will be applied to the data.

Augmentation Levels:

  • None: No augmentation. Only baseline formatting (ToTensor) is applied.
  • 'light': Random 90-degree rotations.
  • 'moderate': Rotations, flips, slight Gaussian smoothing, and minor affine transformations (scaling/translation).
  • 'heavy': Aggressive rotations, flips, stronger smoothing, and significant affine transformations including shearing.

Parameters:

  • img_shape (tuple): The target dimensions of the volume as (Depth, Height, Width). Required.
  • level (str | None): The specific augmentation level to generate. Must be one of None, 'light', 'moderate', or 'heavy'. Defaults to None.

Returns:

  • Compose (monai.transforms.Compose): A MONAI Compose object containing the sequence of transforms.

Raises:

  • ValueError: If img_shape is not 3D or if level is invalid.

Source code in qim3d/ml/_augmentations.py
def augment(
    self, img_shape: tuple, level: str | None = None
) -> monai.transforms.Compose:
    """
    Builds a MONAI composition of transforms based on the specified intensity level.

    This method constructs the actual pipeline of operations (e.g., rotations, flips, smoothing) that will be applied to the data.

    **Augmentation Levels:**

    * **None**: No augmentation. Only baseline formatting (ToTensor) is applied.
    * **'light'**: Random 90-degree rotations.
    * **'moderate'**: Rotations, flips, slight Gaussian smoothing, and minor affine transformations (scaling/translation).
    * **'heavy'**: Aggressive rotations, flips, stronger smoothing, and significant affine transformations including shearing.

    Args:
        img_shape (tuple): The target dimensions of the volume as `(Depth, Height, Width)`.
        level (str | None, optional): The specific augmentation level to generate. Must be one of `None`, 'light', 'moderate', or 'heavy'. Defaults to `None`.

    Returns:
        Compose (monai.transforms.Compose): A MONAI `Compose` object containing the sequence of transforms.

    Raises:
        ValueError: If `img_shape` is not 3D or if `level` is invalid.
    """
    from monai.transforms import (
        CenterSpatialCropd,
        Compose,
        RandAffined,
        RandFlipd,
        RandGaussianSmoothd,
        RandRotate90d,
        Resized,
        SpatialPadd,
        ToTensor,
    )

    # Check if image is 3D
    if len(img_shape) == 3:
        im_d, im_h, im_w = img_shape

    else:
        msg = f'Invalid image shape: {img_shape}. Must be 3D.'
        raise ValueError(msg)

    # Check if one of standard augmentation levels
    if level not in [None, 'light', 'moderate', 'heavy']:
        msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
        raise ValueError(msg)

    # Baseline augmentations
    # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
    baseline_aug = [ToTensor()]  # , NormalizeIntensityd(keys=["image"])]

    # Resize augmentations
    if self.resize == 'crop':
        resize_aug = [
            CenterSpatialCropd(keys=['image', 'label'], roi_size=(im_d, im_h, im_w))
        ]

    elif self.resize == 'reshape':
        resize_aug = [
            Resized(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
        ]

    elif self.resize == 'padding':
        resize_aug = [
            SpatialPadd(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
        ]

    # Level of augmentation
    if level is None:
        # No augmentation for the validation and test sets
        level_aug = []
        resize_aug = []

    elif level == 'light':
        # TODO: Do rotations along other axes?
        level_aug = [
            RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1))
        ]

    elif level == 'moderate':
        level_aug = [
            RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
            RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=1),
            RandGaussianSmoothd(keys=['image'], sigma_x=(0.7, 0.7), prob=0.1),
            RandAffined(
                keys=['image', 'label'],
                prob=0.5,
                translate_range=(0.1, 0.1),
                scale_range=(0.9, 1.1),
            ),
        ]

    elif level == 'heavy':
        level_aug = [
            RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
            RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=1),
            RandGaussianSmoothd(keys=['image'], sigma_x=(1.2, 1.2), prob=0.3),
            RandAffined(
                keys=['image', 'label'],
                prob=0.5,
                translate_range=(0.2, 0.2),
                scale_range=(0.8, 1.4),
                shear_range=(-15, 15),
            ),
        ]

    return Compose(baseline_aug + resize_aug + level_aug)

qim3d.ml.Hyperparameters

Configuration wrapper for training parameters (Optimizer, Loss, Epochs).

This class centralizes the setup of the training loop. It automatically instantiates the PyTorch optimizer and loss function based on string arguments, ensuring valid combinations and settings.

Parameters:

  • model (Module): The PyTorch model to be trained (required to register its parameters with the optimizer). Required.
  • n_epochs (int): The number of complete passes through the training dataset. Defaults to 10.
  • learning_rate (float): The step size for the optimizer. Defaults to 1e-3.
  • optimizer (str): The optimization algorithm. Options: 'Adam', 'SGD', 'RMSprop'. Defaults to 'Adam'.
  • momentum (float): The momentum factor (only used for 'SGD' and 'RMSprop'). Accelerates gradients in the relevant direction, leading to faster convergence. Defaults to 0.
  • weight_decay (float): L2 penalty applied to the weights to prevent overfitting. Defaults to 0.
  • loss_function (str): The objective function to minimize. Defaults to 'Focal'.
      • 'BCE': Binary Cross Entropy (with Logits).
      • 'Dice': Dice Loss (good for class imbalance).
      • 'Focal': Focal Loss (focuses on hard examples).
      • 'DiceCE': Weighted sum of Dice and Cross Entropy.

Returns:

  • hyperparameters (dict): A dictionary containing the initialized objects, retrieved by calling the instance with the () operator:
      • 'optimizer': The torch.optim object.
      • 'criterion': The loss function module.
      • 'n_epochs': The integer number of epochs.

Raises:

  • ValueError: If loss_function or optimizer is not among the supported options.

Example
import qim3d

# 1. Initialize model
model = qim3d.ml.models.UNet(size='small')

# 2. Define training configuration
hyperparameters = qim3d.ml.Hyperparameters(
    model=model,
    n_epochs=10,
    learning_rate=5e-3,
    optimizer='Adam',
    loss_function='DiceCE'
)

# 3. Retrieve initialized objects for the training loop
params_dict = hyperparameters()

optimizer = params_dict['optimizer']
criterion = params_dict['criterion']
print(f"Ready to train for {params_dict['n_epochs']} epochs with {optimizer.__class__.__name__}")
Source code in qim3d/ml/models/_unet.py
class Hyperparameters:
    """
    Configuration wrapper for training parameters (Optimizer, Loss, Epochs).

    This class centralizes the setup of the training loop. It automatically instantiates the PyTorch optimizer and loss function based on string arguments, ensuring valid combinations and settings.

    Args:
        model (torch.nn.Module): The PyTorch model to be trained (required to register parameters with the optimizer).
        n_epochs (int, optional): The number of complete passes through the training dataset. Defaults to 10.
        learning_rate (float, optional): The step size for the optimizer. Defaults to 1e-3.
        optimizer (str, optional): The optimization algorithm. Options: 'Adam', 'SGD', 'RMSprop'. Defaults to 'Adam'.
        momentum (float, optional): The momentum factor (only used for 'SGD' and 'RMSprop'). Accelerates gradients in the relevant direction, leading to faster convergence. Defaults to 0.
        weight_decay (float, optional): L2 penalty applied to the weights to prevent overfitting. Defaults to 0.
        loss_function (str, optional): The objective function to minimize. Options:
            * 'BCE': Binary Cross Entropy (with Logits).
            * 'Dice': Dice Loss (good for class imbalance).
            * 'Focal': Focal Loss (focuses on hard examples).
            * 'DiceCE': Weighted sum of Dice and Cross Entropy.
            Defaults to 'Focal'.

    Returns:
        hyperparameters (dict):
            A dictionary containing the initialized objects, accessible via the `()` operator:
            * 'optimizer': The torch.optim object.
            * 'criterion': The loss function module.
            * 'n_epochs': The integer number of epochs.

    Raises:
        ValueError: If `loss_function` or `optimizer` are not among the supported options.

    Example:
        ```python
        import qim3d

        # 1. Initialize model
        model = qim3d.ml.models.UNet(size='small')

        # 2. Define training configuration
        hyperparameters = qim3d.ml.Hyperparameters(
            model=model,
            n_epochs=10,
            learning_rate=5e-3,
            optimizer='Adam',
            loss_function='DiceCE'
        )

        # 3. Retrieve initialized objects for the training loop
        params_dict = hyperparameters()

        optimizer = params_dict['optimizer']
        criterion = params_dict['criterion']
        print(f"Ready to train for {params_dict['n_epochs']} epochs with {optimizer.__class__.__name__}")
        ```
    """

    def __init__(
        self,
        model: torch.nn.Module,
        n_epochs: int = 10,
        learning_rate: float = 1e-3,
        optimizer: str = 'Adam',
        momentum: float = 0,
        weight_decay: float = 0,
        loss_function: str = 'Focal',
    ):
        # TODO: Implement custom loss_functions? Then add a check to see if loss works for segmentation.
        if loss_function not in ['BCE', 'Dice', 'Focal', 'DiceCE']:
            msg = f'Invalid loss function: {loss_function}. Loss criterion must be one of the following: "BCE", "Dice", "Focal", "DiceCE".'
            raise ValueError(msg)

        # TODO: Implement custom optimizer? And add check to see if valid.
        if optimizer not in ['Adam', 'SGD', 'RMSprop']:
            msg = f'Invalid optimizer: {optimizer}. Optimizer must be one of the following: "Adam", "SGD", "RMSprop".'
            raise ValueError(msg)

        if (momentum != 0) and optimizer == 'Adam':
            log.info(
                "Momentum isn't an input in the 'Adam' optimizer. "
                "Change optimizer to 'SGD' or 'RMSprop' to use momentum."
            )

        self.model = model
        self.n_epochs = n_epochs
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.loss_function = loss_function

    def __call__(self):
        return self.model_params(
            self.model,
            self.n_epochs,
            self.optimizer,
            self.learning_rate,
            self.weight_decay,
            self.momentum,
            self.loss_function,
        )

    def model_params(
        self,
        model: torch.nn.Module,
        n_epochs: int,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        momentum: float,
        loss_function: str,
    ) -> dict:
        optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
        criterion = self._loss_functions(loss_function)

        hyper_dict = {
            'optimizer': optim,
            'criterion': criterion,
            'n_epochs': n_epochs,
        }
        return hyper_dict

    # Selecting the optimizer
    def _optimizer(
        self,
        model: torch.nn.Module,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        momentum: float,
    ) -> torch.optim.Optimizer:
        torch = optional_import('torch', extra='deep-learning')

        if optimizer == 'Adam':
            optim = torch.optim.Adam(
                model.parameters(), lr=learning_rate, weight_decay=weight_decay
            )
        elif optimizer == 'SGD':
            optim = torch.optim.SGD(
                model.parameters(),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay,
            )
        elif optimizer == 'RMSprop':
            optim = torch.optim.RMSprop(
                model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                momentum=momentum,
            )
        return optim

    # Selecting the loss function
    def _loss_functions(self, loss_function: str) -> torch.nn.Module:
        monai = optional_import('monai', extra='deep-learning')
        torch = optional_import('torch', extra='deep-learning')

        if loss_function == 'BCE':
            criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
        elif loss_function == 'Dice':
            criterion = monai.losses.DiceLoss(sigmoid=True, reduction='mean')
        elif loss_function == 'Focal':
            criterion = monai.losses.FocalLoss(reduction='mean')
        elif loss_function == 'DiceCE':
            criterion = monai.losses.DiceCELoss(sigmoid=True, reduction='mean')
        return criterion

qim3d.ml.prepare_datasets

prepare_datasets(path, val_fraction, model, augmentation)

Loads, splits, and applies augmentations to the dataset.

This function automates the creation of PyTorch datasets for training, validation, and testing. It handles:

  1. Loading: Reads data from the specified directory structure (expects train/images, train/labels, test/images, test/labels).
  2. Splitting: Divides the training data into a training set and a validation set based on val_fraction.
  3. Augmentation: Applies the transformation pipelines defined in the augmentation object to each split.

Parameters:

  • path (str): The root directory of the dataset. Required.
  • val_fraction (float): The proportion of the training data to reserve for validation (0.0 to 1.0). Required.
  • model (Module): The PyTorch model (used to determine the expected input dimensions/channels for resizing checks). Required.
  • augmentation (Augmentation): An instance of qim3d.ml.Augmentation containing the transformation rules for each split. Required.

Returns:

  • (train_set, val_set, test_set): A tuple of torch.utils.data.Subset objects ready for use with a DataLoader.

Raises:

  • ValueError: If val_fraction is not a float between 0 and 1.

Example
import qim3d

# Define parameters
base_path = "dataset"
model = qim3d.ml.models.UNet(size='small')

# Configure augmentation
aug_pipeline = qim3d.ml.Augmentation(resize='crop', transform_train='light')

# Generate datasets
train_ds, val_ds, test_ds = qim3d.ml.prepare_datasets(
    path=base_path, 
    val_fraction=0.2, 
    model=model, 
    augmentation=aug_pipeline
)

print(f"Training samples: {len(train_ds)}")
print(f"Validation samples: {len(val_ds)}")
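As the source below shows, the validation split is taken from the front of a random permutation of the training indices: floor(val_fraction × n) samples go to validation and the remainder to training. The arithmetic in isolation (a pure-Python sketch; qim3d itself uses torch.randperm):

```python
import math
import random

def split_indices(n: int, val_fraction: float, seed: int = 0):
    """Randomly partition range(n) into (train, val) index lists."""
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)
    split_idx = math.floor(val_fraction * n)  # first split_idx go to validation
    return indices[split_idx:], indices[:split_idx]

train_idx, val_idx = split_indices(n=100, val_fraction=0.2)
print(len(train_idx), len(val_idx))   # → 80 20
```

Because the training and validation Dataset objects are built with different transforms before splitting, the same underlying files receive training-level augmentation in one subset and validation-level augmentation in the other.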
Source code in qim3d/ml/_data.py
def prepare_datasets(
    path: str,
    val_fraction: float,
    model: torch.nn.Module,
    augmentation: Augmentation,
) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
    """
    Loads, splits, and applies augmentations to the dataset.

    This function automates the creation of PyTorch datasets for training, validation, and testing. It handles:

    1.  **Loading**: Reads data from the specified directory structure (expects `train/images`, `train/labels`, `test/images`, `test/labels`).
    2.  **Splitting**: Divides the training data into a training set and a validation set based on `val_fraction`.
    3.  **Augmentation**: Applies the transformation pipelines defined in the `augmentation` object to each split.

    Args:
        path (str): The root directory of the dataset.
        val_fraction (float): The proportion of the training data to reserve for validation (0.0 to 1.0).
        model (torch.nn.Module): The PyTorch model. (Used to determine the expected input dimensions/channels for resizing checks).
        augmentation (Augmentation): An instance of `qim3d.ml.Augmentation` containing the transformation rules for each split.

    Returns:
        (train_set, val_set, test_set): A tuple of `torch.utils.data.Subset` objects ready for use with a DataLoader.

    Raises:
        ValueError: If `val_fraction` is not a float between 0 and 1.

    Example:
        ```python
        import qim3d

        # Define parameters
        base_path = "dataset"
        model = qim3d.ml.models.UNet(size='small')

        # Configure augmentation
        aug_pipeline = qim3d.ml.Augmentation(resize='crop', transform_train='light')

        # Generate datasets
        train_ds, val_ds, test_ds = qim3d.ml.prepare_datasets(
            path=base_path, 
            val_fraction=0.2, 
            model=model, 
            augmentation=aug_pipeline
        )

        print(f"Training samples: {len(train_ds)}")
        print(f"Validation samples: {len(val_ds)}")
        ```
    """

    if not isinstance(val_fraction, float) or not (0 <= val_fraction < 1):
        msg = 'The validation fraction must be a float between 0 and 1.'
        raise ValueError(msg)

    resize = augmentation.resize
    n_channels = len(model.channels)

    # Get the first image to check the shape
    im_path = Path(path) / 'train'
    first_img = sorted((im_path / 'images').iterdir())[0]

    # Load 3D volume
    image = qim3d.io.load(first_img)
    orig_shape = image.shape

    final_shape = check_resize(orig_shape, resize, n_channels)

    train_set = Dataset(
        root_path=path,
        transform=augmentation.augment(final_shape, level=augmentation.transform_train),
    )
    val_set = Dataset(
        root_path=path,
        transform=augmentation.augment(
            final_shape, level=augmentation.transform_validation
        ),
    )
    test_set = Dataset(
        root_path=path,
        split='test',
        transform=augmentation.augment(final_shape, level=augmentation.transform_test),
    )

    split_idx = int(np.floor(val_fraction * len(train_set)))
    indices = torch.randperm(len(train_set))

    train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
    val_set = torch.utils.data.Subset(val_set, indices[:split_idx])

    return train_set, val_set, test_set
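The split logic at the end of `prepare_datasets` can be sketched independently of PyTorch: shuffle the indices once, then carve off the first `floor(val_fraction * n)` of them for validation. A minimal sketch, with plain Python standing in for `torch.randperm` and `torch.utils.data.Subset`:

```python
import math
import random

def split_indices(n_samples: int, val_fraction: float, seed: int = 0):
    """Shuffle indices once and split them, mirroring the
    floor(val_fraction * n) cut used by prepare_datasets."""
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)  # stand-in for torch.randperm
    split_idx = math.floor(val_fraction * n_samples)
    # Validation takes the first split_idx shuffled indices,
    # training keeps the rest -- the two splits never overlap.
    val_idx = indices[:split_idx]
    train_idx = indices[split_idx:]
    return train_idx, val_idx

train_idx, val_idx = split_indices(n_samples=10, val_fraction=0.2)
print(len(train_idx), len(val_idx))  # 8 2
```

Because both `Subset` objects are built from the same shuffled index tensor, a sample can never appear in both the training and validation sets.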

qim3d.ml.prepare_dataloaders

prepare_dataloaders(
    train_set,
    val_set,
    test_set,
    batch_size,
    shuffle_train=True,
    num_workers=8,
    pin_memory=False,
)

Wraps the datasets into PyTorch DataLoaders for efficient batch processing.

DataLoaders handle the complexity of batching, shuffling, and parallel data loading. This function ensures that your model receives correctly formatted batches at a rate that keeps training efficient.

Note: It is recommended to keep shuffle_train=True to improve model generalization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_set` | `Dataset` | The training dataset. | *required* |
| `val_set` | `Dataset` | The validation dataset. | *required* |
| `test_set` | `Dataset` | The testing dataset. | *required* |
| `batch_size` | `int` | The number of samples to process in a single step (batch). | *required* |
| `shuffle_train` | `bool` | If `True`, the training data is reshuffled at every epoch. This prevents the model from learning the order of the data. | `True` |
| `num_workers` | `int` | The number of subprocesses to use for data loading. `0` means that the data will be loaded in the main process. Increasing this speeds up data loading but consumes more RAM. | `8` |
| `pin_memory` | `bool` | If `True`, the data loader will copy Tensors into CUDA pinned memory before returning them. This can speed up data transfer to the GPU. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `(train_loader, val_loader, test_loader)` | A tuple containing the three `torch.utils.data.DataLoader` objects. |

Example

```python
import qim3d

# ... (assume datasets are already created as train_set, val_set, test_set) ...

# Create DataLoaders
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set=train_set,
    val_set=val_set,
    test_set=test_set,
    batch_size=4,
    num_workers=4
)

# Iterate through the training loader
for batch in train_loader:
    inputs, labels = batch
    # training step...
```
Source code in qim3d/ml/_data.py
def prepare_dataloaders(
    train_set: torch.utils.data.Dataset,
    val_set: torch.utils.data.Dataset,
    test_set: torch.utils.data.Dataset,
    batch_size: int,
    shuffle_train: bool = True,
    num_workers: int = 8,
    pin_memory: bool = False,
) -> tuple[
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
]:
    """
    Wraps the datasets into PyTorch DataLoaders for efficient batch processing.

    DataLoaders handle the complexity of batching, shuffling, and parallel data loading. This function ensures that your model receives correctly formatted batches at a rate that keeps training efficient.

    **Note:** It is recommended to keep `shuffle_train=True` to improve model generalization.

    Args:
        train_set (torch.utils.data.Dataset): The training dataset.
        val_set (torch.utils.data.Dataset): The validation dataset.
        test_set (torch.utils.data.Dataset): The testing dataset.
        batch_size (int): The number of samples to process in a single step (batch).
        shuffle_train (bool, optional): If `True`, the training data is reshuffled at every epoch. This prevents the model from learning the order of the data. Defaults to `True`.
        num_workers (int, optional): The number of subprocesses to use for data loading. `0` means that the data will be loaded in the main process. Increasing this speeds up data loading but consumes more RAM. Defaults to 8.
        pin_memory (bool, optional): If `True`, the data loader will copy Tensors into CUDA pinned memory before returning them. This can speed up data transfer to the GPU. Defaults to `False`.

    Returns:
        (train_loader, val_loader, test_loader): A tuple containing the three `torch.utils.data.DataLoader` objects.

    Example:
        ```python
        import qim3d

        # ... (assume datasets are already created as train_set, val_set, test_set) ...

        # Create DataLoaders
        train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
            train_set=train_set,
            val_set=val_set,
            test_set=test_set,
            batch_size=4,
            num_workers=4
        )

        # Iterate through the training loader
        for batch in train_loader:
            inputs, labels = batch
            # training step...
        ```
    """
    from torch.utils.data import DataLoader

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        dataset=val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, test_loader
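One detail worth knowing when choosing `batch_size`: with PyTorch's default `drop_last=False` (which the loaders above use), a DataLoader yields `ceil(len(dataset) / batch_size)` batches per epoch, so the last batch may be smaller than the rest. A quick sketch of that arithmetic:

```python
import math

def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch for a dataset
    of n_samples, matching PyTorch's drop_last semantics."""
    if drop_last:
        # Incomplete trailing batch is discarded
        return n_samples // batch_size
    # Incomplete trailing batch is kept
    return math.ceil(n_samples / batch_size)

print(num_batches(10, 4))                   # 3 (batches of 4, 4, 2)
print(num_batches(10, 4, drop_last=True))   # 2
```

A smaller final batch is usually harmless, but it is something to keep in mind if a loss term or normalization layer assumes a fixed batch size.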

qim3d.ml.model_summary

model_summary(model, dataloader)

Generates a detailed summary of the model's architecture and parameter count.

This function provides a comprehensive overview of the model, including the output shape of each layer, the number of trainable parameters, and the estimated memory usage. It automatically infers the input dimensions by sampling a single batch from the provided DataLoader.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The PyTorch model to analyze. | *required* |
| `dataloader` | `DataLoader` | A DataLoader used to retrieve a sample batch for input shape inference. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `model_s` | `ModelStatistics` | An object containing the model statistics. When printed, it displays a formatted table of layers and parameters. |

Example

```python
import qim3d

# Define model and data components
model = qim3d.ml.models.UNet(size='small')

# ... (assume train_loader is already prepared) ...

# Print model summary
summary = qim3d.ml.model_summary(model, train_loader)
print(summary)
```
Source code in qim3d/ml/_ml_utils.py
def model_summary(
    model: torch.nn.Module, dataloader: torch.utils.data.DataLoader
) -> torchinfo.ModelStatistics:
    """
    Generates a detailed summary of the model's architecture and parameter count.

    This function provides a comprehensive overview of the model, including the output shape of each layer, the number of trainable parameters, and the estimated memory usage. It automatically infers the input dimensions by sampling a single batch from the provided DataLoader.

    Args:
        model (torch.nn.Module): The PyTorch model to analyze.
        dataloader (torch.utils.data.DataLoader): A DataLoader used to retrieve a sample batch for input shape inference.

    Returns:
        model_s (torchinfo.ModelStatistics):
            An object containing the model statistics. When printed, it displays a formatted table of layers and parameters.

    Example:
        ```python
        import qim3d

        # Define model and data components
        model = qim3d.ml.models.UNet(size='small')

        # ... (assume train_loader is already prepared) ...

        # Print model summary
        summary = qim3d.ml.model_summary(model, train_loader)
        print(summary)
        ```
    """
    # Sample one batch to infer the model's expected input shape
    images, _ = next(iter(dataloader))
    input_shape = tuple(images.shape)
    model_s = torchinfo.summary(model, input_shape, depth=torch.inf)

    return model_s
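The parameter counts in the summary are just the element counts of each weight tensor added up. For a single 3D convolution they can be checked by hand; a sketch of the arithmetic (not the torchinfo implementation, and the 1-input-channel first layer is illustrative):

```python
def conv3d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """Trainable parameters of one 3D convolution layer: one kernel^3
    filter per (input, output) channel pair, plus one bias per output channel."""
    weights = in_ch * out_ch * kernel ** 3
    return weights + (out_ch if bias else 0)

# First layer of the 'small' preset: 1 input channel -> 32 channels, 3x3x3 kernel
print(conv3d_params(1, 32, 3))  # 896
```

Summing this over every layer reproduces the "Total params" figure, which is why widening a preset (more channels) grows the count quadratically while enlarging the kernel grows it cubically.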

qim3d.ml.train_model

train_model(
    model,
    hyperparameters,
    train_loader,
    val_loader,
    checkpoint_directory=None,
    eval_every=1,
    print_every=1,
    plot=True,
    return_loss=False,
)

Executes the training loop for a PyTorch model.

This function manages the iterative process of training. It handles:

  1. Training Steps: Iterating through the training data, computing gradients (backpropagation), and updating model weights.
  2. Validation: Periodically evaluating the model on unseen data to monitor for overfitting.
  3. Logging: Printing loss values to track convergence.
  4. Checkpointing: Saving the final model weights to disk.
  5. Visualization: Plotting training and validation loss curves.

The function automatically detects if a GPU (CUDA) is available and moves the model and data to the appropriate device.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The PyTorch model to train. | *required* |
| `hyperparameters` | `Hyperparameters` | A `qim3d.ml.Hyperparameters` object containing the optimizer, loss function, and epoch count. | *required* |
| `train_loader` | `DataLoader` | The DataLoader for the training set. | *required* |
| `val_loader` | `DataLoader` | The DataLoader for the validation set. | *required* |
| `checkpoint_directory` | `str` | The directory where the final model weights (`.pth` file) will be saved. If `None`, the model is not saved to disk. | `None` |
| `eval_every` | `int` | The number of epochs between validation runs. A value of 1 validates every epoch. | `1` |
| `print_every` | `int` | The number of epochs between log updates. A value of 1 logs every epoch. | `1` |
| `plot` | `bool` | If `True`, displays a plot of the loss history after training finishes. | `True` |
| `return_loss` | `bool` | If `True`, returns the history of loss values. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `tuple[dict, dict] \| None` | `(train_loss, val_loss)`, returned only if `return_loss` is `True`. `train_loss` is a dictionary containing `'loss'` (per epoch) and `'batch_loss'` (per iteration); `val_loss` contains the same keys for the validation set. |

Example

```python
import qim3d

# 1. Setup components
base_path = "dataset"
model = qim3d.ml.models.UNet(size='xxsmall')
hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs=5)
augmentation = qim3d.ml.Augmentation(resize='crop', transform_train='light')

# 2. Prepare Data
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path=base_path,
    val_fraction=0.5,
    model=model,
    augmentation=augmentation
)

train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set, val_set, test_set, batch_size=1
)

# 3. Train
qim3d.ml.train_model(
    model=model,
    hyperparameters=hyperparameters,
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_directory=base_path,
    plot=True
)
```
Source code in qim3d/ml/_ml_utils.py
def train_model(
    model: torch.nn.Module,
    hyperparameters: Hyperparameters,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    checkpoint_directory: str = None,
    eval_every: int = 1,
    print_every: int = 1,
    plot: bool = True,
    return_loss: bool = False,
) -> tuple[dict, dict] | None:
    """
    Executes the training loop for a PyTorch model.

    This function manages the iterative process of training. It handles:

    1.  **Training Steps**: Iterating through the training data, computing gradients (backpropagation), and updating model weights.
    2.  **Validation**: Periodically evaluating the model on unseen data to monitor for overfitting.
    3.  **Logging**: Printing loss values to track convergence.
    4.  **Checkpointing**: Saving the final model weights to disk.
    5.  **Visualization**: Plotting training and validation loss curves.

    The function automatically detects if a GPU (CUDA) is available and moves the model and data to the appropriate device.

    Args:
        model (torch.nn.Module): The PyTorch model to train.
        hyperparameters (Hyperparameters): A `qim3d.ml.Hyperparameters` object containing the optimizer, loss function, and epoch count.
        train_loader (torch.utils.data.DataLoader): The DataLoader for the training set.
        val_loader (torch.utils.data.DataLoader): The DataLoader for the validation set.
        checkpoint_directory (str, optional): The directory where the final model weights (`.pth` file) will be saved. If `None`, the model is not saved to disk. Defaults to `None`.
        eval_every (int, optional): The number of epochs between validation runs. Defaults to 1 (validate every epoch).
        print_every (int, optional): The number of epochs between log updates. Defaults to 1 (log every epoch).
        plot (bool, optional): If `True`, displays a plot of the loss history after training finishes. Defaults to `True`.
        return_loss (bool, optional): If `True`, returns the history of loss values. Defaults to `False`.

    Returns:
        (train_loss, val_loss) (tuple[dict, dict] | None):
            Only returned if `return_loss` is `True`.
            * **train_loss**: A dictionary containing 'loss' (per epoch) and 'batch_loss' (per iteration).
            * **val_loss**: A dictionary containing 'loss' and 'batch_loss' for the validation set.

    Example:
        ```python
        import qim3d

        # 1. Setup components
        base_path = "dataset"
        model = qim3d.ml.models.UNet(size='xxsmall')
        hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs=5)
        augmentation = qim3d.ml.Augmentation(resize='crop', transform_train='light')

        # 2. Prepare Data
        train_set, val_set, test_set = qim3d.ml.prepare_datasets(
            path=base_path,
            val_fraction=0.5,
            model=model,
            augmentation=augmentation
        )

        train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
            train_set, val_set, test_set, batch_size=1
        )

        # 3. Train
        qim3d.ml.train_model(
            model=model,
            hyperparameters=hyperparameters,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_directory=base_path,
            plot=True
        )
        ```
    """
    # Get hyperparameters
    params_dict = hyperparameters()

    n_epochs = params_dict['n_epochs']
    optimizer = params_dict['optimizer']
    criterion = params_dict['criterion']

    # Choosing best device available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)

    # Avoid logging twice
    log.propagate = False

    # Set up dictionaries to store training and validation losses
    train_loss = {'loss': [], 'batch_loss': []}
    val_loss = {'loss': [], 'batch_loss': []}

    with logging_redirect_tqdm():
        for epoch in tqdm(range(n_epochs), desc='Training epochs', unit='epoch'):
            epoch_loss = 0
            step = 0

            model.train()

            for data in train_loader:
                inputs, targets = data
                inputs = inputs.to(device)
                targets = targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)

                loss = criterion(outputs, targets)

                # Backpropagation
                loss.backward()
                optimizer.step()

                epoch_loss += loss.detach().item()
                step += 1

                # Log and store batch training loss
                train_loss['batch_loss'].append(loss.detach().item())

            # Log and store average training loss per epoch
            epoch_loss = epoch_loss / step
            train_loss['loss'].append(epoch_loss)

            if epoch % eval_every == 0:
                eval_loss = 0
                step = 0

                model.eval()

                for data in val_loader:
                    inputs, targets = data
                    inputs = inputs.to(device)
                    targets = targets.to(device)

                    with torch.no_grad():
                        outputs = model(inputs)
                        loss = criterion(outputs, targets)

                    eval_loss += loss.item()
                    step += 1

                    # Log and store batch validation loss
                    val_loss['batch_loss'].append(loss.item())

                # Log and store average validation loss
                eval_loss = eval_loss / step
                val_loss['loss'].append(eval_loss)

                if epoch % print_every == 0:
                    # Index with -1 so logging works when eval_every > 1
                    # (val_loss['loss'] then has fewer entries than epochs)
                    log.info(
                        f"Epoch {epoch: 3}, train loss: {train_loss['loss'][-1]:.4f}, "
                        f"val loss: {val_loss['loss'][-1]:.4f}"
                    )

    if checkpoint_directory:
        checkpoint_filename = f'model_{n_epochs}epochs.pth'
        checkpoint_path = os.path.join(checkpoint_directory, checkpoint_filename)

        # Save model checkpoint to .pth file
        torch.save(model.state_dict(), checkpoint_path)
        log.info(f'Model checkpoint saved at: {checkpoint_path}')

    if plot:
        plot_metrics(train_loss, val_loss, labels=['Train', 'Valid.'], show=True)

    if return_loss:
        return train_loss, val_loss
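The loss bookkeeping in the loop above reduces to one rule: `batch_loss` stores every per-iteration value, while `loss` stores the mean of those values for each epoch. A minimal sketch of that reduction, detached from the model and optimizer:

```python
def epoch_average(batch_losses: list[float]) -> float:
    """Average per-batch losses into one per-epoch value,
    mirroring the epoch_loss / step division in train_model."""
    if not batch_losses:
        raise ValueError('Empty epoch: no batches were seen.')
    return sum(batch_losses) / len(batch_losses)

# Two hypothetical epochs of three and two batches
history = {'loss': [], 'batch_loss': []}
for epoch_batches in [[1.0, 2.0, 3.0], [4.0, 6.0]]:
    history['batch_loss'].extend(epoch_batches)          # per-iteration record
    history['loss'].append(epoch_average(epoch_batches))  # per-epoch record

print(history['loss'])  # [2.0, 5.0]
```

Plotting `batch_loss` shows the noisy per-iteration signal, while `loss` gives the smoothed per-epoch curve used in the loss plot.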

qim3d.ml.load_checkpoint

load_checkpoint(model, checkpoint_path)

Restores a model's state (weights and biases) from a saved checkpoint file.

This function loads a dictionary of learned parameters from a .pth file and applies them to the provided model architecture. This is essential for:

  • Inference: Using a pre-trained model to make predictions on new data.
  • Resuming Training: Continuing the training process from a specific point.
  • Transfer Learning: Fine-tuning a pre-trained model on a new task.

Important: The architecture of the model object must match the architecture used when the checkpoint was saved. If the shapes of the layers do not align, a runtime error will occur.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The initialized PyTorch model architecture (e.g., a `UNet` instance). | *required* |
| `checkpoint_path` | `str` | The file path to the `.pth` checkpoint. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `Module` | The model with its weights updated from the file. |

Example

```python
import qim3d

# 1. Define the architecture (must match the saved model)
model = qim3d.ml.models.UNet(size='small')

# 2. Path to the saved weights
checkpoint_path = "dataset/model_5epochs.pth"

# 3. Load the weights
model = qim3d.ml.load_checkpoint(model, checkpoint_path)

# The model is now ready for inference
print("Checkpoint loaded successfully.")
```
Source code in qim3d/ml/_ml_utils.py
def load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """
    Restores a model's state (weights and biases) from a saved checkpoint file.

    This function loads a dictionary of learned parameters from a `.pth` file and applies them to the provided model architecture. This is essential for:

    * **Inference**: Using a pre-trained model to make predictions on new data.
    * **Resuming Training**: Continuing the training process from a specific point.
    * **Transfer Learning**: Fine-tuning a pre-trained model on a new task.

    **Important:** The architecture of the `model` object must match the architecture used when the checkpoint was saved. If the shapes of the layers do not align, a runtime error will occur.

    Args:
        model (torch.nn.Module): The initialized PyTorch model architecture (e.g., a `UNet` instance).
        checkpoint_path (str): The file path to the `.pth` checkpoint.

    Returns:
        model (torch.nn.Module):
            The model with its weights updated from the file.

    Example:
        ```python
        import qim3d

        # 1. Define the architecture (must match the saved model)
        model = qim3d.ml.models.UNet(size='small')

        # 2. Path to the saved weights
        checkpoint_path = "dataset/model_5epochs.pth"

        # 3. Load the weights
        model = qim3d.ml.load_checkpoint(model, checkpoint_path)

        # The model is now ready for inference
        print("Checkpoint loaded successfully.")
        ```
    """
    model.load_state_dict(torch.load(checkpoint_path))
    log.info(f'Model checkpoint loaded from: {checkpoint_path}')

    return model
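The "architecture must match" requirement comes down to `load_state_dict` matching parameter names and tensor shapes between the model and the checkpoint. A conceptual pre-flight check, using plain dicts of shape tuples as stand-ins for real state dicts (the helper is hypothetical, not part of qim3d):

```python
def check_state_dict_compat(model_shapes: dict, checkpoint_shapes: dict) -> list[str]:
    """Return a list of mismatches between a model's parameter shapes and a
    checkpoint's; an empty list means the checkpoint should load cleanly."""
    problems = []
    for name, shape in model_shapes.items():
        if name not in checkpoint_shapes:
            problems.append(f'missing in checkpoint: {name}')
        elif checkpoint_shapes[name] != shape:
            problems.append(
                f'shape mismatch for {name}: {shape} vs {checkpoint_shapes[name]}'
            )
    for name in checkpoint_shapes:
        if name not in model_shapes:
            problems.append(f'unexpected key: {name}')
    return problems

# Hypothetical first-layer shapes for two checkpoints
model_shapes = {'conv1.weight': (32, 1, 3, 3, 3), 'conv1.bias': (32,)}
good_ckpt = dict(model_shapes)
bad_ckpt = {'conv1.weight': (64, 1, 3, 3, 3)}  # e.g. saved from a wider preset

print(check_state_dict_compat(model_shapes, good_ckpt))       # []
print(len(check_state_dict_compat(model_shapes, bad_ckpt)))   # 2
```

This is why a checkpoint saved from `UNet(size='small')` cannot be loaded into `UNet(size='medium')`: the layer names may line up, but the channel dimensions do not.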

qim3d.ml.test_model

test_model(model, test_set, threshold=0.5)

Runs inference on a test dataset to generate segmentation predictions.

This function iterates through the provided test_set, applies the trained model, and post-processes the output. It automatically handles:

  1. Device Management: Moves data to GPU if available.
  2. Batching: Adds necessary batch dimensions for the model input.
  3. Activation: Applies a Sigmoid function to convert raw model outputs (logits) into probabilities.
  4. Thresholding: Converts probabilities into binary masks using the specified threshold.
  5. Format Conversion: Returns inputs, targets, and predictions as NumPy arrays for easy analysis or visualization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The trained PyTorch model. | *required* |
| `test_set` | `Dataset` | The dataset containing (image, label) pairs to evaluate. | *required* |
| `threshold` | `float` | The probability threshold for binary classification. Pixels with a probability higher than this value are classified as foreground (1). | `0.5` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `results` | `list[tuple[ndarray, ndarray, ndarray]]` | A list of tuples, one per sample in the test set, each containing: `volume`, the original input image; `target`, the ground truth label; `pred`, the predicted binary segmentation mask. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the items yielded by `test_set` are not PyTorch tensors. |

Example

```python
import qim3d
import matplotlib.pyplot as plt

# ... (Assume model and test_set are already prepared) ...

# Run inference
results = qim3d.ml.test_model(model=model, test_set=test_set)

# Visualize the first result
vol, target, pred = results[0]

# Display the middle slice of the prediction
mid_slice = pred.shape[0] // 2
plt.imshow(pred[mid_slice], cmap='gray')
plt.title("Predicted Segmentation")
plt.show()
```
Source code in qim3d/ml/_ml_utils.py
def test_model(
    model: torch.nn.Module,
    test_set: torch.utils.data.Dataset,
    threshold: float = 0.5,
) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Runs inference on a test dataset to generate segmentation predictions.

    This function iterates through the provided `test_set`, applies the trained `model`, and post-processes the output. It automatically handles:

    1.  **Device Management**: Moves data to GPU if available.
    2.  **Batching**: Adds necessary batch dimensions for the model input.
    3.  **Activation**: Applies a Sigmoid function to convert raw model outputs (logits) into probabilities.
    4.  **Thresholding**: Converts probabilities into binary masks using the specified `threshold`.
    5.  **Format Conversion**: Returns inputs, targets, and predictions as NumPy arrays for easy analysis or visualization.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        test_set (torch.utils.data.Dataset): The dataset containing (image, label) pairs to evaluate.
        threshold (float, optional): The probability threshold for binary classification. Pixels with a probability higher than this value are classified as foreground (1). Defaults to 0.5.

    Returns:
        results (list[tuple[np.ndarray, np.ndarray, np.ndarray]]):
            A list of tuples, where each tuple corresponds to one sample in the test set and contains:
            * **volume**: The original input image.
            * **target**: The ground truth label.
            * **pred**: The predicted binary segmentation mask.

    Raises:
        ValueError: If the items yielded by `test_set` are not PyTorch tensors.

    Example:
        ```python
        import qim3d
        import matplotlib.pyplot as plt

        # ... (Assume model and test_set are already prepared) ...

        # Run inference
        results = qim3d.ml.test_model(model=model, test_set=test_set)

        # Visualize the first result
        vol, target, pred = results[0]

        # Display the middle slice of the prediction
        mid_slice = pred.shape[0] // 2
        plt.imshow(pred[mid_slice], cmap='gray')
        plt.title("Predicted Segmentation")
        plt.show()
        ```
    """
    # Set model to evaluation mode
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    # List to store results
    results = []

    for volume, target in test_set:
        if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
            msg = 'Data items must consist of tensors'
            raise ValueError(msg)

        # Add batch and channel dimensions
        volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
        target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]

        with torch.no_grad():
            # Get model predictions (logits)
            output = model(volume)

            # Convert logits to probabilities [0, 1]
            pred = torch.sigmoid(output)

            # Convert to binary mask by thresholding the probabilities
            pred = (pred > threshold).float()

            # Remove batch and channel dimensions
            volume = volume.squeeze().cpu().numpy()
            target = target.squeeze().cpu().numpy()
            pred = pred.squeeze().cpu().numpy()

        # TODO: Compute DICE score between target and prediction?

        # Append results to list
        results.append((volume, target, pred))

    return results
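Steps 3 and 4 above (activation and thresholding) are a plain element-wise computation. Sketched with Python's `math.exp` standing in for `torch.sigmoid` on a list of logits:

```python
import math

def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def binarize(logits: list[float], threshold: float = 0.5) -> list[float]:
    """Logits -> probabilities -> binary mask, as in test_model's
    torch.sigmoid followed by (pred > threshold)."""
    return [1.0 if sigmoid(x) > threshold else 0.0 for x in logits]

# Logit 0.0 maps to probability exactly 0.5, which is NOT above the
# default threshold, so it lands in the background class
print(binarize([-2.0, 0.0, 3.0]))  # [0.0, 0.0, 1.0]
```

Raising `threshold` above 0.5 makes the segmentation more conservative (fewer foreground voxels); lowering it makes it more permissive.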