import os
import torch
import SimpleITK as sitk
import numpy as np
from typing import Tuple, Optional, Union
from pathlib import Path
from .image import MedicalImage, Mask
from .logger import logger
def _check_geometry_match(image_sitk: sitk.Image, mask_sitk: sitk.Image, tolerance: float = 1e-4) -> bool:
"""
Validates if two SimpleITK images share the exact analytical physical space constraints.
"""
if image_sitk.GetDimension() != mask_sitk.GetDimension():
return False
s1, s2 = np.array(image_sitk.GetSpacing()), np.array(mask_sitk.GetSpacing())
if np.max(np.abs(s1 - s2)) > tolerance:
return False
o1, o2 = np.array(image_sitk.GetOrigin()), np.array(mask_sitk.GetOrigin())
if np.max(np.abs(o1 - o2)) > tolerance:
return False
d1, d2 = np.array(image_sitk.GetDirection()), np.array(mask_sitk.GetDirection())
if np.max(np.abs(d1 - d2)) > tolerance:
return False
if image_sitk.GetSize() != mask_sitk.GetSize():
return False
return True
def resample_to_isotropic(image: sitk.Image, mask: sitk.Image, target_spacing: Tuple[float, float, float]) -> Tuple[sitk.Image, sitk.Image]:
"""
Resamples an image and mask natively to a forced isotropic target spacing.
Uses generic B-Spline for continuous intensities and Nearest-Neighbor strictly for categorical masks.
"""
logger.info(f"Resampling continuous structures strictly bound at target spacing: {target_spacing}")
orig_spacing = image.GetSpacing()
orig_size = image.GetSize()
# Calculate dimensional bounds mapping continuously
new_size = [
int(round(orig_size[0] * (orig_spacing[0] / target_spacing[0]))),
int(round(orig_size[1] * (orig_spacing[1] / target_spacing[1]))),
int(round(orig_size[2] * (orig_spacing[2] / target_spacing[2])))
]
# Standard Continuous Image interpolation
resampler_img = sitk.ResampleImageFilter()
resampler_img.SetOutputSpacing(target_spacing)
resampler_img.SetSize(new_size)
resampler_img.SetOutputDirection(image.GetDirection())
resampler_img.SetOutputOrigin(image.GetOrigin())
resampler_img.SetTransform(sitk.Transform())
resampler_img.SetDefaultPixelValue(image.GetPixelIDValue())
resampler_img.SetInterpolator(sitk.sitkBSpline)
img_res = resampler_img.Execute(image)
# Discrete Categorical Mask interpolation
resampler_mask = sitk.ResampleImageFilter()
resampler_mask.SetOutputSpacing(target_spacing)
resampler_mask.SetSize(new_size)
resampler_mask.SetOutputDirection(mask.GetDirection())
resampler_mask.SetOutputOrigin(mask.GetOrigin())
resampler_mask.SetTransform(sitk.Transform())
resampler_mask.SetDefaultPixelValue(0)
resampler_mask.SetInterpolator(sitk.sitkNearestNeighbor)
mask_res = resampler_mask.Execute(mask)
return img_res, mask_res
def crop_to_bbox(image: sitk.Image, mask: sitk.Image, label: int = 1, pad: int = 0) -> Tuple[sitk.Image, sitk.Image]:
"""
Crops the images tightly adhering around the label constraint bound to dramatically accelerate GPU throughput.
"""
label_shape_filter = sitk.LabelShapeStatisticsImageFilter()
label_shape_filter.Execute(mask)
if not label_shape_filter.HasLabel(label):
logger.warning(f"Label {label} does not exist strictly within mask scope.")
return image, mask
bbox = list(label_shape_filter.GetBoundingBox(label))
# Bbox format in SITK (3D): (startX, startY, startZ, sizeX, sizeY, sizeZ)
for i in range(3):
start = max(0, bbox[i] - pad)
end = min(mask.GetSize()[i], bbox[i] + bbox[i+3] + pad)
bbox[i] = start
bbox[i+3] = end - start
roi_filter = sitk.RegionOfInterestImageFilter()
roi_filter.SetSize(bbox[3:])
roi_filter.SetIndex(bbox[:3])
img_crop = roi_filter.Execute(image)
mask_crop = roi_filter.Execute(mask)
return img_crop, mask_crop
def _sitk_to_tensor(sitk_img: sitk.Image) -> Tuple[torch.Tensor, Tuple[float, float, float]]:
"""
Translates a SimpleITK object explicitly into contiguous PyTorch bindings correctly swapping Z, Y, X layout.
"""
data = sitk.GetArrayFromImage(sitk_img).astype(np.float32)
# SITK array is natively returned in (z, y, x) matching expected tensor bindings exactly.
spacing = sitk_img.GetSpacing()
# SITK spacing is (x, y, z), we map into scientific notation (z, y, x)
spacing_zyx = (float(spacing[2]), float(spacing[1]), float(spacing[0]))
return torch.from_numpy(data), spacing_zyx
def _read_sitk_image(path: Union[str, Path]) -> sitk.Image:
"""
Safely reads DICOM directories or single NIfTI/NRRD volume files smoothly mimicking legacy I/O bindings.
"""
path_str = str(path)
if os.path.isdir(path_str):
reader = sitk.ImageSeriesReader()
dicom_names = reader.GetGDCMSeriesFileNames(path_str)
if not dicom_names:
raise ValueError(f"Directory {path_str} does not contain valid DICOM series.")
reader.SetFileNames(dicom_names)
return reader.Execute()
else:
return sitk.ReadImage(path_str)
[docs]
def load_and_align(image_path: Union[str, Path], mask_path: Union[str, Path], resample_spacing: Optional[Tuple[float, float, float]] = None, crop: bool = True) -> Tuple[MedicalImage, Mask]:
"""
Core entrypoint matching the PyRadiomics `pyradiomics.imageoperations` logic exactly.
"""
logger.info(f"Loading Image: {image_path}")
image_sitk = _read_sitk_image(image_path)
logger.info(f"Loading Mask: {mask_path}")
mask_sitk = _read_sitk_image(mask_path)
# 1. Geometry Handshake
if not _check_geometry_match(image_sitk, mask_sitk):
logger.warning("Geometry validation failed standard tolerance! PyRadiomics usually throws exceptions here. Forcing nearest-neighbor overlay to mimic analytical bounding constraints.")
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(image_sitk)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0)
mask_sitk = resampler.Execute(mask_sitk)
# 2. Resampling Hooks
if resample_spacing is not None:
# Re-ordering from (z, y, x) target back to SITK parameter format (x, y, z)
sitk_spacing = (resample_spacing[2], resample_spacing[1], resample_spacing[0])
image_sitk, mask_sitk = resample_to_isotropic(image_sitk, mask_sitk, sitk_spacing)
# 3. Dynamic Memory Constraints Cropping Layer
if crop:
image_sitk, mask_sitk = crop_to_bbox(image_sitk, mask_sitk)
# 4. Final Proxy Bridging Construction
img_t, img_s = _sitk_to_tensor(image_sitk)
mask_t, mask_s = _sitk_to_tensor(mask_sitk)
img_obj = MedicalImage(tensor=img_t, spacing=img_s)
mask_obj = Mask(tensor=mask_t, spacing=mask_s)
return img_obj, mask_obj