Skip to content

Plot and render

Visualization of volumetric data.

qim3d.viz.export_rotation

export_rotation(
    path,
    vol,
    degrees=360,
    num_frames=180,
    fps=30,
    image_size=(256, 256),
    color_map='magma',
    camera_height=2.0,
    camera_distance='auto',
    camera_focus='center',
    show=False,
)

Export a rotation animation of volume.

Parameters:

Name Type Description Default
path str

The path to save the output. The path should end with .gif, .avi, .mp4 or .webm. If no file extension is specified, .gif is automatically added.

required
vol ndarray

Volume to render into the animation.

required
degrees int

The number of degrees the volume rotates through. Defaults to 360.

360
num_frames int

The number of frames to generate. Defaults to 180.

180
fps int

The number of frames per second in the resulting animation. This determines the speed of the rotation of the volume. Defaults to 30.

30
image_size tuple of ints or None

Pixel size (width, height) of each frame. If None, the plotter's default size is used. Defaults to (256, 256).

(256, 256)
color_map str

Determines color map of volume. Defaults to 'magma'.

'magma'
camera_height float

Determines the height of the camera rotating around the volume. The float value represents a multiple of the height of the z-axis. Defaults to 2.0.

2.0
camera_distance float or str

Determines the distance of the camera from the center point. If 'auto' is used, it will be auto calculated. Otherwise a float value representing voxel distance is expected. Defaults to 'auto'.

'auto'
camera_focus list or str

Determines the voxel that the camera rotates around. Using 'center' will default to the center of the volume. Otherwise a list of three integers is expected. Defaults to 'center'.

'center'
show bool

If True, the resulting animation will be shown in the Jupyter notebook. Defaults to False.

False

Returns:

Type Description
None

None

Raises:

Type Description
TypeError

If the camera focus argument is incorrectly used.

TypeError

If the camera_distance argument is incorrectly used.

ValueError

If the path contains an unrecognized file extension.

Example

Creation of .gif file with default parameters of a generated volume.

import qim3d
vol = qim3d.generate.volume()

qim3d.viz.export_rotation('test.gif', vol, show=True)
export_rotation_defaults

Example

Creation of a .webm file with specified parameters of a generated volume in the shape of a tube.

import qim3d

vol = qim3d.generate.volume(shape='tube')

qim3d.viz.export_rotation('test.webm', vol,
                          degrees = 360,
                          num_frames = 120,
                          fps = 30,
                          image_size = (512,512),
                          camera_height = 3.0,
                          camera_distance = 'auto',
                          camera_focus = 'center',
                          show = True)
export_rotation_video

Source code in qim3d/viz/_data_exploration.py
def export_rotation(
    path: str,
    vol: np.ndarray,
    degrees: int = 360,
    num_frames: int = 180,
    fps: int = 30,
    image_size: tuple[int, int] | None = (256, 256),
    color_map: str = 'magma',
    camera_height: float = 2.0,
    camera_distance: float | str = 'auto',
    camera_focus: list | str = 'center',
    show: bool = False,
) -> None:
    """
    Export a rotation animation of volume.

    Args:
        path (str): The path to save the output. The path should end with .gif, .avi, .mp4 or .webm. If no file extension is specified, .gif is automatically added.
        vol (np.ndarray): Volume to create .gif of.
        degrees (int, optional): The amount of degrees for the volume to rotate. Defaults to 360.
        num_frames (int, optional): The amount of frames to generate. Defaults to 180.
        fps (int, optional): The amount of frames per second in the resulting animation. This determines the speed of the rotation of the volume. Defaults to 30.
        image_size (tuple of ints or None, optional): Pixel size (width, height) of each frame. If None, the plotter's default size is used. Defaults to (256, 256).
        color_map (str, optional): Determines color map of volume. Defaults to 'magma'.
        camera_height (float, optional): Determines the height of the camera rotating around the volume. The float value represents a multiple of the height of the z-axis. Defaults to 2.0.
        camera_distance (int or string, optional): Determines the distance of the camera from the center point. If 'auto' is used, it will be auto calculated. Otherwise a float value representing voxel distance is expected. Defaults to 'auto'.
        camera_focus (list or str, optional): Determines the voxel that the camera rotates around. Using 'center' will default to the center of the volume. Otherwise a list of three integers is expected. Defaults to 'center'.
        show (bool, optional): If True, the resulting animation will be shown in the Jupyter notebook. Defaults to False.

    Returns:
        None


    Raises:
        TypeError: If the camera focus argument is incorrectly used.
        TypeError: If the camera_distance argument is incorrectly used.
        ValueError: If the path contains an unrecognized file extension.

    Example:
        Creation of .gif file with default parameters of a generated volume.
        ```python
        import qim3d
        vol = qim3d.generate.volume()

        qim3d.viz.export_rotation('test.gif', vol, show=True)
        ```
        ![export_rotation_defaults](../../assets/screenshots/export_rotation_defaults.gif)

    Example:
        Creation of a .webm file with specified parameters of a generated volume in the shape of a tube.
        ```python
        import qim3d

        vol = qim3d.generate.volume(shape='tube')

        qim3d.viz.export_rotation('test.webm', vol,
                                  degrees = 360,
                                  num_frames = 120,
                                  fps = 30,
                                  image_size = (512,512),
                                  camera_height = 3.0,
                                  camera_distance = 'auto',
                                  camera_focus = 'center',
                                  show = True)
        ```
        ![export_rotation_video](../../assets/screenshots/export_rotation_video.gif)

    """
    if not (
        camera_focus == 'center'
        or (
            isinstance(camera_focus, list | np.ndarray)
            and not isinstance(camera_focus, str)
            and len(camera_focus) == 3
        )
    ):
        msg = f'Value "{camera_focus}" for camera focus is invalid. Use "center" or a list of three values.'
        raise TypeError(msg)
    if not (isinstance(camera_distance, float) or camera_distance == 'auto'):
        msg = f'Value "{camera_distance}" for camera distance is invalid. Use "auto" or a float value.'
        raise TypeError(msg)

    if Path(path).suffix == '':
        print(f'Input path: "{path}" does not have a filetype. Defaulting to .gif.')
        path += '.gif'

    # Handle img in (xyz) instead of (zyx) (due to rendering issues with the up-vector, ensure that z=y, such that we now have (x,z,y))
    vol = np.transpose(vol, (2, 0, 1))

    # Create a uniform grid
    grid = pv.ImageData()
    grid.dimensions = np.array(vol.shape) + 1  # PyVista dims are +1 from volume shape
    grid.spacing = (1, 1, 1)
    grid.origin = (0, 0, 0)
    grid.cell_data['values'] = vol.flatten(order='F')  # Fortran order

    # Initialize plotter
    plotter = pv.Plotter(off_screen=True)
    plotter.add_volume(grid, opacity='linear', cmap=color_map)
    plotter.remove_scalar_bar()  # Remove colorbar

    frames = []
    camera_height = vol.shape[1] * camera_height

    if camera_distance == 'auto':
        bounds = np.array(plotter.bounds)  # (xmin, xmax, ymin, ymax, zmin, zmax)
        diag = np.linalg.norm(
            [bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]]
        )
        camera_distance = diag * 2.0

    if camera_focus == 'center':
        _, center, _ = plotter.camera_position
    else:
        center = camera_focus

    center = np.array(center)

    angle_per_frame = degrees / num_frames
    radians_per_frame = np.radians(angle_per_frame)

    # Set up orbit radius and fixed up
    radius = camera_distance
    fixed_up = [0, 1, 0]
    for i in tqdm(range(num_frames), desc='Rendering'):
        theta = radians_per_frame * i
        x = radius * np.sin(theta)
        z = radius * np.cos(theta)
        y = camera_height  # fixed height

        eye = center + np.array([x, y, z])
        plotter.camera_position = [eye.tolist(), center.tolist(), fixed_up]

        plotter.render()
        img = plotter.screenshot(return_img=True, window_size=image_size)
        frames.append(img)

    if path[-4:] == '.gif':
        imageio.mimsave(path, frames, fps=fps, loop=0)

    elif path[-4:] == '.avi' or path[-4:] == '.mp4':
        writer = imageio.get_writer(path, fps=fps)
        for frame in frames:
            writer.append_data(frame)
        writer.close()

    elif path[-5:] == '.webm':
        writer = imageio.get_writer(
            path, fps=fps, codec='vp9', ffmpeg_params=['-crf', '32']
        )
        for frame in frames:
            writer.append_data(frame)
        writer.close()

    else:
        msg = 'Invalid file extension. Please use .gif, .avi, .mp4 or .webm'
        raise ValueError(msg)

    path = _get_save_path(path)
    log.info('File saved to ' + str(path.resolve()))

    if show:
        if path.suffix == '.gif':
            display(Image(filename=path))
        elif path.suffix in ['.avi', '.mp4', '.webm']:
            display(Video(filename=path, html_attributes='controls autoplay loop'))

qim3d.viz.circles

circles(blobs, vol, alpha=0.5, color='#ff9900', **kwargs)

Plots the blobs found on a slice of the volume.

This function takes in a 3D volume and a list of blobs (detected features) and plots the blobs on a specified slice of the volume. If no slice is specified, it defaults to the middle slice of the volume.

Parameters:

Name Type Description Default
blobs ndarray

An array-like object of blobs, where each blob is represented as a 4-tuple (p, r, c, radius). Usually the result of qim3d.detection.blobs(vol)

required
vol ndarray

The 3D volume on which to plot the blobs.

required
alpha float

The transparency of the blobs. Defaults to 0.5.

0.5
color str

The color of the blobs. Defaults to "#ff9900".

'#ff9900'
**kwargs Any

Arbitrary keyword arguments for the slices function.

{}

Returns:

Name Type Description
slicer_obj interactive

An interactive widget for visualizing the blobs.

Example

import qim3d
import qim3d.detection

# Get data
vol = qim3d.examples.cement_128x128x128

# Detect blobs, and get binary mask
blobs, _ = qim3d.detection.blobs(
    vol,
    min_sigma=1,
    max_sigma=8,
    threshold=0.001,
    overlap=0.1,
    background="bright"
    )

# Visualize detected blobs with circles method
qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue')
blob detection

Source code in qim3d/viz/_detection.py
def circles(
    blobs: np.ndarray,
    vol: np.ndarray,
    alpha: float = 0.5,
    color: str = '#ff9900',
    **kwargs,
) -> widgets.interactive:
    """
    Plots the blobs found on a slice of the volume.

    This function takes in a 3D volume and a list of blobs (detected features)
    and plots the blobs on a specified slice of the volume. If no slice is specified,
    it defaults to the middle slice of the volume.

    Args:
        blobs (np.ndarray): An array-like object of blobs, where each blob is represented
            as a 4-tuple (p, r, c, radius). Usually the result of `qim3d.detection.blobs(vol)`
        vol (np.ndarray): The 3D volume on which to plot the blobs.
        alpha (float, optional): The transparency of the blobs. Defaults to 0.5.
        color (str, optional): The color of the blobs. Defaults to "#ff9900".
        **kwargs (Any): Arbitrary keyword arguments for the `slices` function.

    Returns:
        slicer_obj (ipywidgets.interactive): An interactive widget for visualizing the blobs.

    Example:
        ```python
        import qim3d
        import qim3d.detection

        # Get data
        vol = qim3d.examples.cement_128x128x128

        # Detect blobs, and get binary mask
        blobs, _ = qim3d.detection.blobs(
            vol,
            min_sigma=1,
            max_sigma=8,
            threshold=0.001,
            overlap=0.1,
            background="bright"
            )

        # Visualize detected blobs with circles method
        qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue')
        ```
        ![blob detection](../../assets/screenshots/blob_detection.gif)

    """

    def _slicer(z_slice):
        """Render slice `z_slice` and overlay the cross-section of every intersecting blob."""
        clear_output(wait=True)
        fig = qim3d.viz.slices_grid(
            vol[z_slice : z_slice + 1],
            num_slices=1,
            color_map='gray',
            display_figure=False,
            display_positions=False,
            **kwargs,
        )
        ax = fig.get_axes()[0]

        # Add circles from detected blobs
        for z, y, x, s in blobs:
            distance_from_center = abs(z - z_slice)
            if not distance_from_center < s:
                continue  # The blob does not reach this slice

            # Heuristic shrinking of the drawn radius: the angle goes from 0 at the blob
            # center to pi/2 at its edge, so the radius follows a cosine curve.
            angle = np.pi / 2 * (distance_from_center / s)
            adjusted_radius = s * np.cos(angle)

            # Skip sub-pixel circles
            if adjusted_radius > 0.5:
                ax.add_patch(
                    plt.Circle(
                        (x, y),
                        adjusted_radius,
                        color=color,
                        linewidth=0,
                        fill=True,
                        alpha=alpha,
                    )
                )

        display(fig)
        # Each slider move creates a new figure; close it so they do not accumulate.
        plt.close(fig)
        return fig

    position_slider = widgets.IntSlider(
        value=vol.shape[0] // 2,
        min=0,
        max=vol.shape[0] - 1,
        description='Slice',
        continuous_update=True,
    )
    slicer_obj = widgets.interactive(_slicer, z_slice=position_slider)
    slicer_obj.layout = widgets.Layout(align_items='flex-start')

    return slicer_obj

qim3d.viz.local_thickness

local_thickness(
    image,
    image_lt,
    max_projection=False,
    axis=0,
    slice_idx=None,
    show=False,
    figsize=(15, 5),
)

Visualizes the local thickness of a 2D or 3D image.

Parameters:

Name Type Description Default
image ndarray

2D or 3D NumPy array representing the image/volume.

required
image_lt ndarray

2D or 3D NumPy array representing the local thickness of the input image/volume.

required
max_projection bool

If True, displays the maximum projection of the local thickness. Only used for 3D images. Defaults to False.

False
axis int

The axis along which to visualize the local thickness. Unused for 2D images. Defaults to 0.

0
slice_idx int or float

The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used for 3D images. Unused for 2D images. Defaults to None.

None
show bool

If True, displays the plot (i.e. calls plt.show()). Defaults to False.

False
figsize tuple

The size of the figure. Defaults to (15, 5).

(15, 5)

Raises:

Type Description
ValueError

If the slice index is not an integer or a float between 0 and 1.

Returns:

Name Type Description
local_thickness interactive or Figure

If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.

Example

import qim3d

fly = qim3d.examples.fly_150x256x256
lt_fly = qim3d.processing.local_thickness(fly)
qim3d.viz.local_thickness(fly, lt_fly, axis=0)
local thickness 3d

Source code in qim3d/viz/_local_thickness.py
def local_thickness(
    image: np.ndarray,
    image_lt: np.ndarray,
    max_projection: bool = False,
    axis: int = 0,
    slice_idx: Optional[Union[int, float]] = None,
    show: bool = False,
    figsize: Tuple[int, int] = (15, 5),
) -> Union[plt.Figure, widgets.interactive]:
    """
    Visualizes the local thickness of a 2D or 3D image.

    Args:
        image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
        image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input
            image/volume.
        max_projection (bool, optional): If True, displays the maximum projection of the local
            thickness. Only used for 3D images. Defaults to False.
        axis (int, optional): The axis along which to visualize the local thickness.
            Unused for 2D images.
            Defaults to 0.
        slice_idx (int or float, optional): The initial slice to be visualized. The slice index
            can afterwards be changed. If value is an integer, it will be the index of the slice
            to be visualized. If value is a float between 0 and 1, it will be multiplied by the
            number of slices along `axis`, rounded to the nearest integer and clamped to the last
            slice. If None, the middle slice will be used for 3D images. Unused for 2D images.
            Defaults to None.
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
        figsize (tuple, optional): The size of the figure. Defaults to (15, 5).

    Raises:
        ValueError: If the slice index is a float outside [0, 1].

    Returns:
        local_thickness (widgets.interactive or plt.Figure): If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.

    Example:
        ```python
        import qim3d

        fly = qim3d.examples.fly_150x256x256
        lt_fly = qim3d.processing.local_thickness(fly)
        qim3d.viz.local_thickness(fly, lt_fly, axis=0)
        ```
        ![local thickness 3d](../../assets/screenshots/local_thickness_3d.gif)


    """

    def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None):
        """Draw image, local thickness map (with colorbar) and thickness histogram side by side."""
        if slice_idx is not None:
            image = image.take(slice_idx, axis=axis)
            image_lt = image_lt.take(slice_idx, axis=axis)

        fig, axs = plt.subplots(1, 3, figsize=figsize, layout='constrained')

        axs[0].imshow(image, cmap='gray')
        axs[0].set_title('Original image')
        axs[0].axis('off')

        # Keep the image handle so the colorbar reuses it instead of drawing a second image.
        lt_im = axs[1].imshow(image_lt, cmap='viridis')
        axs[1].set_title('Local thickness')
        axs[1].axis('off')
        fig.colorbar(lt_im, ax=axs[1], orientation='vertical')

        # Only non-zero thickness values (i.e. inside the structure) enter the histogram.
        axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor='black')
        axs[2].set_title('Local thickness histogram')
        axs[2].set_xlabel('Local thickness')
        axs[2].set_ylabel('Count')

        if show:
            plt.show()

        plt.close(fig)

        return fig

    if len(image.shape) != 3:
        if max_projection:
            log.warning(
                'max_projection is only used for 3D images. It will be ignored.'
            )
        if slice_idx is not None:
            log.warning('slice_idx is only used for 3D images. It will be ignored.')
        return _local_thickness(image, image_lt, show, figsize)

    if max_projection:
        if slice_idx is not None:
            log.warning(
                'slice_idx is not used for max_projection. It will be ignored.'
            )
        image = image.max(axis=axis)
        image_lt = image_lt.max(axis=axis)
        return _local_thickness(image, image_lt, show, figsize)

    # Resolve the initial slice along the chosen axis
    n_slices = image.shape[axis]
    if slice_idx is None:
        slice_idx = n_slices // 2
    elif isinstance(slice_idx, float):
        if not 0 <= slice_idx <= 1:
            raise ValueError(
                'Values of slice_idx of float type must be between 0 and 1.'
            )
        # Relative position -> index; 1.0 maps to the last slice rather than out of range.
        slice_idx = min(round(slice_idx * n_slices), n_slices - 1)

    slide_idx_slider = widgets.IntSlider(
        min=0,
        max=n_slices - 1,
        step=1,
        value=slice_idx,
        description='Slice index',
        layout=widgets.Layout(width='450px'),
    )
    widget_obj = widgets.interactive(
        _local_thickness,
        image=widgets.fixed(image),
        image_lt=widgets.fixed(image_lt),
        show=widgets.fixed(True),
        figsize=widgets.fixed(figsize),
        axis=widgets.fixed(axis),
        slice_idx=slide_idx_slider,
    )
    widget_obj.layout = widgets.Layout(align_items='center')
    if show:
        display(widget_obj)
    return widget_obj

qim3d.viz.vectors

vectors(
    volume,
    vec,
    axis=0,
    volume_cmap='grey',
    vmin=None,
    vmax=None,
    slice_idx=None,
    grid_size=10,
    interactive=True,
    figsize=(10, 5),
    show=False,
)

Visualizes the orientation of the structures in a 3D volume using the eigenvectors of the structure tensor.

Parameters:

Name Type Description Default
volume ndarray

The 3D volume to be sliced.

required
vec ndarray

The eigenvectors of the structure tensor.

required
axis int

The axis along which to visualize the orientation. Defaults to 0.

0
volume_cmap str

Defines colormap for display of the volume

'grey'
vmin float

Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.

None
vmax float

Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None

None
slice_idx int or float or None

The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used. Defaults to None.

None
grid_size int

The size of the grid. Defaults to 10.

10
interactive bool

If True, returns an interactive widget. Defaults to True.

True
figsize tuple

The size of the figure. Defaults to (10, 5).

(10, 5)
show bool

If True, displays the plot (i.e. calls plt.show()). Defaults to False.

False

Raises:

Type Description
ValueError

If the axis to slice along is not 0, 1, or 2.

ValueError

If the slice index is not an integer or a float between 0 and 1.

Returns:

Name Type Description
fig interactive or Figure

If interactive is True, returns an interactive widget. Otherwise, returns a matplotlib figure.

Note

The orientation of the vectors is visualized using an HSV color map, where the saturation corresponds to the vector component of the slicing direction (i.e. z-component when choosing visualization along axis = 0). Hence, if an orientation in the volume is orthogonal to the slicing direction, the corresponding color of the visualization will be gray.

Example

import qim3d

vol = qim3d.examples.NT_128x128x128
val, vec = qim3d.processing.structure_tensor(vol)

# Visualize the structure tensor
qim3d.viz.vectors(vol, vec, axis = 2, interactive = True)
structure tensor

Source code in qim3d/viz/_structure_tensor.py
def vectors(
    volume: np.ndarray,
    vec: np.ndarray,
    axis: int = 0,
    volume_cmap: str = 'grey',
    vmin: float | None = None,
    vmax: float | None = None,
    slice_idx: Union[int, float] | None = None,
    grid_size: int = 10,
    interactive: bool = True,
    figsize: Tuple[int, int] = (10, 5),
    show: bool = False,
) -> Union[plt.Figure, widgets.interactive]:
    """
    Visualizes the orientation of the structures in a 3D volume using the eigenvectors of the structure tensor.

    Args:
        volume (np.ndarray): The 3D volume to be sliced.
        vec (np.ndarray): The eigenvectors of the structure tensor.
        axis (int, optional): The axis along which to visualize the orientation. Defaults to 0.
        volume_cmap (str, optional): Defines colormap for display of the volume
        vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        slice_idx (int or float or None, optional): The initial slice to be visualized. The slice index
            can afterwards be changed. If value is an integer, it will be the index of the slice
            to be visualized. If value is a float between 0 and 1, it will be multiplied by the
            number of slices along `axis`, rounded to the nearest integer and clamped to the last
            slice. If None, the middle slice will be used. Defaults to None.
        grid_size (int, optional): The size of the grid. Defaults to 10.
        interactive (bool, optional): If True, returns an interactive widget. Defaults to True.
        figsize (tuple, optional): The size of the figure. Defaults to (10, 5).
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.

    Raises:
        ValueError: If the axis to slice along is not 0, 1, or 2.
        ValueError: If the slice index is a float outside [0, 1].

    Returns:
        fig (widgets.interactive or plt.Figure): If `interactive` is True, returns an interactive widget. Otherwise, returns a matplotlib figure.

    Note:
        The orientation of the vectors is visualized using an HSV color map, where the saturation corresponds to the vector component
        of the slicing direction (i.e. z-component when choosing visualization along `axis = 0`). Hence, if an orientation in the volume
        is orthogonal to the slicing direction, the corresponding color of the visualization will be gray.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.NT_128x128x128
        val, vec = qim3d.processing.structure_tensor(vol)

        # Visualize the structure tensor
        qim3d.viz.vectors(vol, vec, axis = 2, interactive = True)
        ```
        ![structure tensor](../../assets/screenshots/structure_tensor_visualization.gif)

    """
    # Validate the axis before it is used to index volume.shape below; an invalid axis
    # would otherwise surface as an IndexError instead of the documented ValueError.
    if axis not in (0, 1, 2):
        raise ValueError('Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.')

    # Ensure volume is a float
    if volume.dtype != np.float32 and volume.dtype != np.float64:
        volume = volume.astype(np.float32)

    # Bring intensities into [0, 1] for the grayscale blending in the third panel.
    # Data up to 255 keeps the 8-bit scaling; wider ranges (e.g. 16-bit) are scaled by their
    # maximum instead, so they are not clipped to white.
    max_value = volume.max()
    if max_value > 1.0:
        volume = volume / (255.0 if max_value <= 255.0 else max_value)

    # Define grid size limits
    min_grid_size = max(1, volume.shape[axis] // 50)
    max_grid_size = max(1, volume.shape[axis] // 10)
    if max_grid_size <= min_grid_size:
        max_grid_size = min_grid_size * 5

    if not grid_size:
        grid_size = (min_grid_size + max_grid_size) // 2

    if grid_size < min_grid_size or grid_size > max_grid_size:
        # Adjust grid size as little as possible to be within the limits
        grid_size = min(max(min_grid_size, grid_size), max_grid_size)
        log.warning(f'Adjusting grid size to {grid_size} as it is out of bounds.')

    def _blend_hue_saturation(hue, sat):
        """Fade a hue towards mid-gray (0.5) as the out-of-plane component `sat` grows."""
        return hue * (1 - sat) + 0.5 * sat

    def _blend_slice_colors(slice_rgba, colors):
        """Average an RGBA image slice with the orientation colors."""
        return 0.5 * (slice_rgba + colors)

    def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show):
        """Build the three-panel orientation figure for one slice."""
        # Choose the in-plane (x, y) and out-of-plane (z) eigenvector components for this axis
        if axis == 0:
            data_slice = volume[slice_idx, :, :]
            vectors_slice_x = vec[0, slice_idx, :, :]
            vectors_slice_y = vec[1, slice_idx, :, :]
            vectors_slice_z = vec[2, slice_idx, :, :]

        elif axis == 1:
            data_slice = volume[:, slice_idx, :]
            vectors_slice_x = vec[0, :, slice_idx, :]
            vectors_slice_y = vec[2, :, slice_idx, :]
            vectors_slice_z = vec[1, :, slice_idx, :]

        elif axis == 2:
            data_slice = volume[:, :, slice_idx]
            vectors_slice_x = vec[1, :, :, slice_idx]
            vectors_slice_y = vec[2, :, :, slice_idx]
            vectors_slice_z = vec[0, :, :, slice_idx]

        else:
            raise ValueError('Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.')

        # Create three subplots
        fig, ax = plt.subplots(1, 3, figsize=figsize, layout='constrained')

        # ----- Subplot 1: Image slice with orientation vectors ----- #
        # Row/column index grids matching the slice
        xmesh, ymesh = np.mgrid[0 : data_slice.shape[0], 0 : data_slice.shape[1]]

        # Sub-sample every `grid_size` pixels, starting half a cell in
        g = slice(grid_size // 2, None, grid_size)

        # Angles from 0 to pi (orientations are sign-less)
        angles_quiver = np.mod(
            np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi
        )

        # Squared out-of-plane component acts as the saturation
        saturation_quiver = (vectors_slice_z[g, g] ** 2)[:, :, np.newaxis]

        hue_quiver = plt.cm.hsv(angles_quiver / np.pi)

        rgba_quiver = np.clip(_blend_hue_saturation(hue_quiver, saturation_quiver), 0, 1)
        rgba_quiver_flat = rgba_quiver.reshape(
            (rgba_quiver.shape[0] * rgba_quiver.shape[1], 4)
        )  # Flatten array for quiver plot

        # Draw each orientation in both directions, since eigenvectors have no sign
        for sign in (1, -1):
            ax[0].quiver(
                ymesh[g, g],
                xmesh[g, g],
                sign * vectors_slice_x[g, g],
                sign * vectors_slice_y[g, g],
                color=rgba_quiver_flat,
                angles='xy',
            )

        ax[0].imshow(data_slice, cmap=volume_cmap, vmin=vmin, vmax=vmax)
        ax[0].set_title(
            f'Orientation vectors (slice {slice_idx})'
            if not interactive
            else 'Orientation vectors'
        )
        ax[0].set_axis_off()

        # ----- Subplot 2: Orientation histogram ----- #
        nbins = 36

        # Angles from 0 to pi
        angles = np.mod(np.arctan2(vectors_slice_y, vectors_slice_x), np.pi)

        # Orientation histogram over angles
        distribution, bin_edges = np.histogram(angles, bins=nbins, range=(0.0, np.pi))

        # Half circle (180 deg)
        bin_centers = (np.arange(nbins) + 0.5) * np.pi / nbins

        # Mean squared out-of-plane component (saturation) per bin
        bins = np.digitize(angles.ravel(), bin_edges)
        z_squared = (vectors_slice_z**2).ravel()
        saturation_bin = np.array(
            [
                np.mean(z_squared[bins == i]) if np.any(bins == i) else 0
                for i in range(1, len(bin_edges))
            ]
        )

        hue_bin = plt.cm.hsv(bin_centers / np.pi)

        rgba_bin = hue_bin.copy()
        rgba_bin[:, :3] = _blend_hue_saturation(
            hue_bin[:, :3], saturation_bin[:, np.newaxis]
        )

        ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin)
        ax[1].set_xlabel('Angle [radians]')
        ax[1].set_xlim([0, np.pi])
        ax[1].set_aspect(np.pi / ax[1].get_ylim()[1])
        ax[1].set_xticks([0, np.pi / 2, np.pi])
        ax[1].set_xticklabels(['0', '$\\frac{\\pi}{2}$', '$\\pi$'])
        ax[1].set_yticks([])
        ax[1].set_ylabel('Frequency')
        ax[1].set_title('Histogram over orientation angles')

        # ----- Subplot 3: Image slice colored according to orientation ----- #
        saturation = (vectors_slice_z**2)[:, :, np.newaxis]
        hue = plt.cm.hsv(angles / np.pi)
        rgba = _blend_hue_saturation(hue, saturation)

        # Grayscale image slice blended with orientation colors
        data_slice_orientation_colored = (
            _blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255
        ).astype('uint8')

        ax[2].imshow(data_slice_orientation_colored)
        ax[2].set_title(
            f'Colored orientations (slice {slice_idx})'
            if not interactive
            else 'Colored orientations'
        )
        ax[2].set_axis_off()

        if show:
            plt.show()

        plt.close(fig)

        return fig

    if vec.ndim == 5:
        vec = vec[0, ...]
        log.warning(
            'Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used.'
        )

    n_slices = volume.shape[axis]
    if slice_idx is None:
        slice_idx = n_slices // 2

    elif isinstance(slice_idx, float):
        if not 0 <= slice_idx <= 1:
            raise ValueError(
                'Values of slice_idx of float type must be between 0 and 1.'
            )
        # Relative position -> index along `axis`; 1.0 maps to the last slice.
        slice_idx = min(round(slice_idx * n_slices), n_slices - 1)

    if not interactive:
        return _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show)

    slide_idx_slider = widgets.IntSlider(
        min=0,
        max=n_slices - 1,
        step=1,
        value=slice_idx,
        description='Slice index',
        layout=widgets.Layout(width='450px'),
    )

    grid_size_slider = widgets.IntSlider(
        min=min_grid_size,
        max=max_grid_size,
        step=1,
        value=grid_size,
        description='Grid size',
        layout=widgets.Layout(width='450px'),
    )

    widget_obj = widgets.interactive(
        _structure_tensor,
        volume=widgets.fixed(volume),
        vec=widgets.fixed(vec),
        axis=widgets.fixed(axis),
        slice_idx=slide_idx_slider,
        grid_size=grid_size_slider,
        figsize=widgets.fixed(figsize),
        show=widgets.fixed(True),
    )
    # Arrange sliders horizontally above the output area
    sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider])
    widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]])
    widget_obj.layout.align_items = 'center'

    if show:
        display(widget_obj)

    return widget_obj