Machine learning models
The qim3d library aims to ease the creation of ML models for volumetric images.
qim3d.ml.models.UNet
Bases: Module
Constructs a 3D U-Net model for volumetric image segmentation.
The U-Net architecture consists of a contracting path (encoder) to capture context and a symmetric expanding path (decoder) that enables precise localization. This implementation wraps the MONAI U-Net and provides simplified presets for model depth and width via the size argument.
Model Presets (Channels per Layer):
- 'xxsmall': (4, 8) - Ultra-lightweight (2 layers)
- 'xsmall': (16, 32) - Lightweight (2 layers)
- 'small': (32, 64, 128) - Fast training (3 layers)
- 'medium': (64, 128, 256) - Balanced performance (3 layers, default)
- 'large': (64, 128, 256, 512, 1024) - High capacity (5 layers)
- 'xlarge': (64, 128, 256, 512, 1024, 2048) - Very high capacity (6 layers)
- 'xxlarge': (64, 128, 256, 512, 1024, 2048, 4096) - Maximum capacity (7 layers)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `size` | str | The complexity of the model. Must be one of 'xxsmall', 'xsmall', 'small', 'medium', 'large', 'xlarge', or 'xxlarge'. | `'medium'` |
| `dropout` | float | The dropout rate (0 to 1) applied to hidden layers to prevent overfitting. | `0` |
| `kernel_size` | int | The size of the convolution kernel. | `3` |
| `up_kernel_size` | int | The size of the up-convolution kernel. | `3` |
| `activation` | str | The activation function to use (e.g., 'RELU', 'PReLU', 'Sigmoid'). | `'PReLU'` |
| `bias` | bool | Whether to include bias terms in convolutions. | `True` |
| `adn_order` | str | The ordering of Activation (A), Dropout (D), and Normalization (N) blocks. | `'NDA'` |
Returns:
| Name | Type | Description |
|---|---|---|
| `model` | Module | The initialized 3D U-Net model. |
Raises:
| Type | Description |
|---|---|
| ValueError | If `size` is not one of the supported presets. |
Example
Source code in qim3d/ml/models/_unet.py
qim3d.ml.Augmentation
Configures data augmentation pipelines for 3D deep learning using MONAI.
This class simplifies the creation of augmentation strategies for training, validation, and testing. It allows you to select preset levels of intensity ('light', 'moderate', 'heavy') and define how input volumes are resized or cropped to match the model's input requirements.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `resize` | str | The method used to conform input images to a specific size: 'crop' extracts a central crop of the desired size, 'reshape' resizes (interpolates) the image to the desired size, and 'padding' pads the image with zeros to reach the desired size. | `'crop'` |
| `transform_train` | str \| None | The intensity of augmentation applied to the training set. Options: 'light', 'moderate', 'heavy', or `None`. | `'moderate'` |
| `transform_validation` | str \| None | The intensity of augmentation applied to the validation set. | `None` |
| `transform_test` | str \| None | The intensity of augmentation applied to the test set. | `None` |
Raises:
| Type | Description |
|---|---|
| ValueError | If `resize` or one of the transform levels is not a valid option. |
Example
Source code in qim3d/ml/_augmentations.py
qim3d.ml.Augmentation.augment
Builds a MONAI composition of transforms based on the specified intensity level.
This method constructs the actual pipeline of operations (e.g., rotations, flips, smoothing) that will be applied to the data.
Augmentation Levels:
- None: No augmentation. Only baseline formatting (ToTensor) is applied.
- 'light': Random 90-degree rotations.
- 'moderate': Rotations, flips, slight Gaussian smoothing, and minor affine transformations (scaling/translation).
- 'heavy': Aggressive rotations, flips, stronger smoothing, and significant affine transformations including shearing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `img_shape` | tuple | The target dimensions of the volume. | required |
| `level` | str \| None | The specific augmentation level to generate. Must be one of `None`, 'light', 'moderate', or 'heavy'. | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `Compose` | Compose | A MONAI `Compose` object with the assembled transform pipeline. |
Raises:
| Type | Description |
|---|---|
| ValueError | If `level` is not a recognized augmentation level. |
Source code in qim3d/ml/_augmentations.py
qim3d.ml.Hyperparameters
Configuration wrapper for training parameters (Optimizer, Loss, Epochs).
This class centralizes the setup of the training loop. It automatically instantiates the PyTorch optimizer and loss function based on string arguments, ensuring valid combinations and settings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | Module | The PyTorch model to be trained (required to register parameters with the optimizer). | required |
| `n_epochs` | int | The number of complete passes through the training dataset. | `10` |
| `learning_rate` | float | The step size for the optimizer. | `0.001` |
| `optimizer` | str | The optimization algorithm. Options: 'Adam', 'SGD', 'RMSprop'. | `'Adam'` |
| `momentum` | float | The momentum factor (only used for 'SGD' and 'RMSprop'). Accelerates gradient updates in consistent directions, leading to faster convergence. | `0` |
| `weight_decay` | float | L2 penalty applied to the weights to prevent overfitting. | `0` |
| `loss_function` | str | The objective function to minimize. Options: 'BCE' (Binary Cross Entropy with logits), 'Dice' (Dice Loss, good for class imbalance), 'Focal' (Focal Loss, focuses on hard examples), 'DiceCE' (weighted sum of Dice and Cross Entropy). | `'Focal'` |
Returns:
| Name | Type | Description |
|---|---|---|
| `hyperparameters` | dict | A dictionary containing the initialized objects, returned when the instance is called. |
Raises:
| Type | Description |
|---|---|
| ValueError | If `optimizer` or `loss_function` is not a supported option. |
Example
```python
import qim3d

# 1. Initialize model
model = qim3d.ml.models.UNet(size='small')

# 2. Define training configuration
hyperparameters = qim3d.ml.Hyperparameters(
    model=model,
    n_epochs=10,
    learning_rate=5e-3,
    optimizer='Adam',
    loss_function='DiceCE'
)

# 3. Retrieve initialized objects for the training loop
params_dict = hyperparameters()
optimizer = params_dict['optimizer']
criterion = params_dict['criterion']

print(f"Ready to train for {params_dict['n_epochs']} epochs with {optimizer.__class__.__name__}")
```
Source code in qim3d/ml/models/_unet.py
qim3d.ml.prepare_datasets
Loads, splits, and applies augmentations to the dataset.
This function automates the creation of PyTorch datasets for training, validation, and testing. It handles:
- Loading: Reads data from the specified directory structure (expects `train/images`, `train/labels`, `test/images`, `test/labels`).
- Splitting: Divides the training data into a training set and a validation set based on `val_fraction`.
- Augmentation: Applies the transformation pipelines defined in the `augmentation` object to each split.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | str | The root directory of the dataset. | required |
| `val_fraction` | float | The proportion of the training data to reserve for validation (0.0 to 1.0). | required |
| `model` | Module | The PyTorch model (used to determine the expected input dimensions/channels for resizing checks). | required |
| `augmentation` | Augmentation | An instance of `qim3d.ml.Augmentation` defining the transform pipelines. | required |
Returns:
| Type | Description |
|---|---|
| (train_set, val_set, test_set) | A tuple of the three datasets (training, validation, test). |
Raises:
| Type | Description |
|---|---|
| ValueError | If the data at `path` does not match the expected directory structure. |
Example
```python
import qim3d

# Define parameters
base_path = "dataset"
model = qim3d.ml.models.UNet(size='small')

# Configure augmentation
aug_pipeline = qim3d.ml.Augmentation(resize='crop', transform_train='light')

# Generate datasets
train_ds, val_ds, test_ds = qim3d.ml.prepare_datasets(
    path=base_path,
    val_fraction=0.2,
    model=model,
    augmentation=aug_pipeline
)

print(f"Training samples: {len(train_ds)}")
print(f"Validation samples: {len(val_ds)}")
```
Source code in qim3d/ml/_data.py
qim3d.ml.prepare_dataloaders
```python
prepare_dataloaders(
    train_set,
    val_set,
    test_set,
    batch_size,
    shuffle_train=True,
    num_workers=8,
    pin_memory=False,
)
```
Wraps the datasets into PyTorch DataLoaders for efficient batch processing.
DataLoaders handle the complexity of batching, shuffling, and parallel data loading. This function ensures that your model receives data in the correct format and speed during training.
Note: It is recommended to keep shuffle_train=True to improve model generalization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `train_set` | Dataset | The training dataset. | required |
| `val_set` | Dataset | The validation dataset. | required |
| `test_set` | Dataset | The testing dataset. | required |
| `batch_size` | int | The number of samples to process in a single step (batch). | required |
| `shuffle_train` | bool | If `True`, reshuffles the training data at every epoch. | `True` |
| `num_workers` | int | The number of subprocesses to use for data loading. | `8` |
| `pin_memory` | bool | If `True`, copies tensors into page-locked (pinned) memory, which can speed up host-to-GPU transfers. | `False` |
Returns:
| Type | Description |
|---|---|
| (train_loader, val_loader, test_loader) | A tuple containing the three DataLoaders. |
Example
```python
import qim3d

# ... (assume datasets are already created as train_set, val_set, test_set) ...

# Create DataLoaders
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set=train_set,
    val_set=val_set,
    test_set=test_set,
    batch_size=4,
    num_workers=4
)

# Iterate through the training loader
for batch in train_loader:
    inputs, labels = batch['image'], batch['label']
    # training step...
```
Source code in qim3d/ml/_data.py
qim3d.ml.model_summary
Generates a detailed summary of the model's architecture and parameter count.
This function provides a comprehensive overview of the model, including the output shape of each layer, the number of trainable parameters, and the estimated memory usage. It automatically infers the input dimensions by sampling a single batch from the provided DataLoader.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | Module | The PyTorch model to analyze. | required |
| `dataloader` | DataLoader | A DataLoader used to retrieve a sample batch for input shape inference. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `model_s` | ModelStatistics | An object containing the model statistics. When printed, it displays a formatted table of layers and parameters. |
Example
Source code in qim3d/ml/_ml_utils.py
qim3d.ml.train_model
```python
train_model(
    model,
    hyperparameters,
    train_loader,
    val_loader,
    checkpoint_directory=None,
    eval_every=1,
    print_every=1,
    plot=True,
    return_loss=False,
)
```
Executes the training loop for a PyTorch model.
This function manages the iterative process of training. It handles:
- Training Steps: Iterating through the training data, computing gradients (backpropagation), and updating model weights.
- Validation: Periodically evaluating the model on unseen data to monitor for overfitting.
- Logging: Printing loss values to track convergence.
- Checkpointing: Saving the final model weights to disk.
- Visualization: Plotting training and validation loss curves.
The function automatically detects if a GPU (CUDA) is available and moves the model and data to the appropriate device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | Module | The PyTorch model to train. | required |
| `hyperparameters` | Hyperparameters | A `qim3d.ml.Hyperparameters` instance providing the optimizer, loss function, and number of epochs. | required |
| `train_loader` | DataLoader | The DataLoader for the training set. | required |
| `val_loader` | DataLoader | The DataLoader for the validation set. | required |
| `checkpoint_directory` | str | The directory where the final model weights are saved. | `None` |
| `eval_every` | int | The number of epochs between validation runs. Defaults to 1 (validate every epoch). | `1` |
| `print_every` | int | The number of epochs between log updates. Defaults to 1 (log every epoch). | `1` |
| `plot` | bool | If `True`, plots the training and validation loss curves after training. | `True` |
| `return_loss` | bool | If `True`, returns the recorded training and validation losses. | `False` |
Returns:
| Type | Description |
|---|---|
| (train_loss, val_loss) tuple[dict, dict] \| None | Only returned if `return_loss=True`: the recorded training and validation losses. |
Example
```python
import qim3d

# 1. Setup components
base_path = "dataset"
model = qim3d.ml.models.UNet(size='xxsmall')
hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs=5)
augmentation = qim3d.ml.Augmentation(resize='crop', transform_train='light')

# 2. Prepare Data
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path=base_path,
    val_fraction=0.5,
    model=model,
    augmentation=augmentation
)
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set, val_set, test_set, batch_size=1
)

# 3. Train
qim3d.ml.train_model(
    model=model,
    hyperparameters=hyperparameters,
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_directory=base_path,
    plot=True
)
```
Source code in qim3d/ml/_ml_utils.py
qim3d.ml.load_checkpoint
Restores a model's state (weights and biases) from a saved checkpoint file.
This function loads a dictionary of learned parameters from a .pth file and applies them to the provided model architecture. This is essential for:
- Inference: Using a pre-trained model to make predictions on new data.
- Resuming Training: Continuing the training process from a specific point.
- Transfer Learning: Fine-tuning a pre-trained model on a new task.
Important: The architecture of the model object must match the architecture used when the checkpoint was saved. If the shapes of the layers do not align, a runtime error will occur.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | Module | The initialized PyTorch model architecture (e.g., a `UNet` instance). | required |
| `checkpoint_path` | str | The file path to the saved `.pth` checkpoint file. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `model` | Module | The model with its weights updated from the file. |
Example
```python
import qim3d

# 1. Define the architecture (must match the saved model)
model = qim3d.ml.models.UNet(size='small')

# 2. Path to the saved weights
checkpoint_path = "dataset/model_5epochs.pth"

# 3. Load the weights
model = qim3d.ml.load_checkpoint(model, checkpoint_path)

# The model is now ready for inference
print("Checkpoint loaded successfully.")
```
Source code in qim3d/ml/_ml_utils.py
qim3d.ml.test_model
Runs inference on a test dataset to generate segmentation predictions.
This function iterates through the provided test_set, applies the trained model, and post-processes the output. It automatically handles:
- Device Management: Moves data to GPU if available.
- Batching: Adds necessary batch dimensions for the model input.
- Activation: Applies a Sigmoid function to convert raw model outputs (logits) into probabilities.
- Thresholding: Converts probabilities into binary masks using the specified `threshold`.
- Format Conversion: Returns inputs, targets, and predictions as NumPy arrays for easy analysis or visualization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | Module | The trained PyTorch model. | required |
| `test_set` | Dataset | The dataset containing (image, label) pairs to evaluate. | required |
| `threshold` | float | The probability threshold for binary classification. Pixels with a probability higher than this value are classified as foreground (1). | `0.5` |
Returns:
| Name | Type | Description |
|---|---|---|
| `results` | list[tuple[ndarray, ndarray, ndarray]] | A list of tuples, one per sample in the test set, each containing the original input volume, the ground-truth target label, and the predicted binary segmentation mask. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the items yielded by `test_set` are not in the expected format. |
Example
```python
import qim3d
import matplotlib.pyplot as plt

# ... (assume model and test_set are already prepared) ...

# Run inference
results = qim3d.ml.test_model(model=model, test_set=test_set)

# Visualize the first result
vol, target, pred = results[0]

# Display the middle slice of the prediction
mid_slice = pred.shape[0] // 2
plt.imshow(pred[mid_slice], cmap='gray')
plt.title("Predicted Segmentation")
plt.show()
```
Source code in qim3d/ml/_ml_utils.py