Source code for fastrad.dense_extractor

import torch
from typing import Dict, Tuple, Union
from .settings import FeatureSettings
from .image import MedicalImage, Mask
from .extractor import FeatureExtractor, _FEATURE_MAP
from .logger import logger

class DenseFeatureExtractor(FeatureExtractor):
    """Feature extractor that outputs dense, voxel-wise feature maps.

    Subclass of FeatureExtractor that slides a 3D window over the volume
    (via ``torch.Tensor.unfold`` patch views) and evaluates every configured
    feature class on each window that overlaps the mask.
    """

    @staticmethod
    def _as_triple(value, name):
        """Expand an int to ``(v, v, v)`` and require three positive values."""
        if isinstance(value, int):
            value = (value, value, value)
        triple = tuple(value)
        if len(triple) != 3 or any(v <= 0 for v in triple):
            raise ValueError(
                f"{name} must be a positive int or a tuple of three positive ints, "
                f"got {value!r}"
            )
        return triple

    def extract_dense(
        self,
        image: MedicalImage,
        mask: Mask,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
    ) -> Dict[str, torch.Tensor]:
        """Execute dense feature extraction on the given image and mask.

        Args:
            image: Baseline medical volume.
            mask: Binary ROIs. Only windows containing positive mask voxels
                are computed.
            kernel_size: 3D window dimensions (z, y, x); an int is used for
                all three axes.
            stride: Step size for window extraction; an int is used for all
                three axes.

        Returns:
            Dict mapping feature names to dense feature maps of shape
            (Dz_out, Dy_out, Dx_out), matching the sliding window grid.
            Entries for windows with no mask voxels, or whose computation
            failed (the failure is logged), are NaN. The dict is empty when
            no window produced any feature value.

        Raises:
            ValueError: If image and mask shapes differ, the tensors are not
                3D, kernel_size/stride are not three positive values, the
                kernel does not fit inside the volume, or a configured
                feature class is unknown.
            torch.cuda.OutOfMemoryError: If raised while the tensors are not
                on a CUDA device (no CPU fallback is possible).
        """
        # Move tensors to the requested device.
        img_tensor = image.tensor.to(self.device)
        mask_tensor = mask.tensor.to(self.device)

        if img_tensor.shape != mask_tensor.shape:
            raise ValueError(
                f"Image and mask shape mismatch: {img_tensor.shape} != {mask_tensor.shape}"
            )

        # Update spacing implicitly for downstream routines.
        self.settings.spacing = image.spacing

        if img_tensor.dim() != 3:
            raise ValueError(
                f"Expected a 3D (z, y, x) volume, got shape {tuple(img_tensor.shape)}"
            )

        # Reject zero/negative values up front: a zero stride would otherwise
        # surface as ZeroDivisionError and a negative one could slip past the
        # output-size guard below.
        kz, ky, kx = self._as_triple(kernel_size, "kernel_size")
        sz, sy, sx = self._as_triple(stride, "stride")

        Dz, Dy, Dx = img_tensor.shape
        out_z = (Dz - kz) // sz + 1
        out_y = (Dy - ky) // sy + 1
        out_x = (Dx - kx) // sx + 1

        if out_z <= 0 or out_y <= 0 or out_x <= 0:
            raise ValueError(
                "Kernel size larger than tensor dimensions or invalid stride/kernel combination."
            )

        # Resolve the compute functions once; this is loop-invariant and lets
        # an unknown feature class fail before any work is done.
        feature_fns = []
        for feature_class in self.settings.feature_classes:
            if feature_class not in _FEATURE_MAP:
                raise ValueError(f"Unknown feature class: {feature_class}")
            feature_fns.append((feature_class, _FEATURE_MAP[feature_class]))

        # unfold() returns strided patch views of shape
        # (out_z, out_y, out_x, kz, ky, kx) rather than copies.
        img_patches = img_tensor.unfold(0, kz, sz).unfold(1, ky, sy).unfold(2, kx, sx)
        mask_patches = mask_tensor.unfold(0, kz, sz).unfold(1, ky, sy).unfold(2, kx, sx)

        # One batched reduction replaces a host/device sync per window.
        # torch.nonzero yields indices in C order, i.e. the same order a
        # nested z/y/x loop would visit them.
        window_sums = mask_patches.sum(dim=(-3, -2, -1))
        active_windows = torch.nonzero(window_sums != 0).tolist()

        # Use the tensor's actual placement: self.device may be "cuda:0" or a
        # torch.device, neither of which compares equal to the string "cuda".
        on_cuda = img_tensor.is_cuda

        dense_features: Dict[str, torch.Tensor] = {}

        # Feature classes that hit a CUDA OOM once are computed on CPU for
        # all remaining windows.
        cpu_fallback_needed = set()

        # Compute features sequentially over patches.
        # Future architecture changes can vectorize logic completely.
        for zi, yi, xi in active_windows:
            p_img = img_patches[zi, yi, xi]
            p_mask = mask_patches[zi, yi, xi]

            for feature_class, compute_fn in feature_fns:
                try:
                    if feature_class in cpu_fallback_needed:
                        f_vals = compute_fn(p_img.cpu(), p_mask.cpu(), self.settings)
                    else:
                        f_vals = compute_fn(p_img, p_mask, self.settings)
                except torch.cuda.OutOfMemoryError:
                    if not on_cuda:
                        raise
                    logger.warning(
                        f"CUDA OutOfMemoryError caught for {feature_class} in patch {zi},{yi},{xi}. "
                        f"Falling back to CPU computation."
                    )
                    torch.cuda.empty_cache()
                    cpu_fallback_needed.add(feature_class)
                    f_vals = compute_fn(p_img.cpu(), p_mask.cpu(), self.settings)
                except Exception as e:
                    # Best effort: a failed window stays NaN in the output.
                    logger.error(
                        f"Failed patch extraction at {zi},{yi},{xi} for {feature_class}: {e}"
                    )
                    continue

                if f_vals:
                    for k, v in f_vals.items():
                        if k not in dense_features:
                            # Maps are allocated lazily, NaN-filled, on first use.
                            dense_features[k] = torch.full(
                                (out_z, out_y, out_x), float('nan'), device=self.device
                            )
                        dense_features[k][zi, yi, xi] = v

        return dense_features