Source code for fastrad.filters

import torch
import torch.nn.functional as F
from typing import Dict, List
import math
import numpy as np

from .image import MedicalImage
from .logger import logger

def _get_gaussian_kernel_3d(sigma: float, size: int = 0) -> torch.Tensor:
    """
    Builds a normalised, separable 3D Gaussian smoothing kernel.

    Args:
        sigma: Standard deviation of the Gaussian, expressed in grid steps
            (the sample positions are unit-spaced offsets from the centre).
        size: Edge length of the cubic kernel. When 0, the edge length is
            derived from sigma as 2 * ceil(3 * sigma) + 1.

    Returns:
        A float32 tensor of shape (size, size, size) whose entries sum to 1.
    """
    if size == 0:
        # Default support: three standard deviations on each side of the centre.
        size = 2 * int(math.ceil(3.0 * sigma)) + 1

    # Unit-spaced sample offsets, centred on zero.
    half_extent = (size - 1) / 2.0
    offsets = torch.arange(size, dtype=torch.float32) - half_extent

    # Unnormalised 1D Gaussian profile along a single axis.
    profile = torch.exp(-offsets ** 2 / (2 * sigma ** 2))

    # The 3D Gaussian is separable: build it as successive outer products.
    plane = profile[:, None] * profile[None, :]
    volume = plane[:, :, None] * profile[None, None, :]

    # Normalise so the kernel preserves the mean intensity.
    return volume / volume.sum()

def _get_LoG_kernel_3d(sigma: float, size: int = 0) -> torch.Tensor:
    """
    Generates a sampled 3D Laplacian of Gaussian (LoG) kernel.

    The kernel samples the closed-form Laplacian of an isotropic 3D Gaussian

        G(r)   = (2 * pi * sigma^2)^(-3/2) * exp(-r^2 / (2 * sigma^2))
        LoG(r) = (r^2 - 3 * sigma^2) / sigma^4 * G(r)

    on a unit-spaced grid centred on zero, then subtracts the mean of the
    samples so that the discrete kernel sums to zero (a constant image
    therefore filters to zero away from the borders).

    Args:
        sigma: Standard deviation of the Gaussian, expressed in grid steps.
        size: Edge length of the cubic kernel. When 0, the edge length is
            derived from sigma as 2 * ceil(3 * sigma) + 1.

    Returns:
        A float32 tensor of shape (size, size, size).
    """
    if size == 0:
        radius = int(math.ceil(3.0 * sigma))
        size = 2 * radius + 1

    grid = torch.arange(size, dtype=torch.float32)
    grid = grid - (size - 1) / 2.0

    # Squared Euclidean distance of every sample from the kernel centre.
    z, y, x = torch.meshgrid(grid, grid, grid, indexing='ij')
    squared_dist = x**2 + y**2 + z**2
    variance = sigma ** 2

    # Normalised 3D Gaussian. Note the exponent -3/2: the 2D normalisation
    # 1 / (2 * pi * sigma^2) is not valid for a volumetric kernel.
    gaussian_norm = (2.0 * math.pi * variance) ** -1.5
    gaussian = gaussian_norm * torch.exp(-squared_dist / (2 * variance))

    # Laplacian in n dimensions is (r^2 - n * sigma^2) / sigma^4 * G; n = 3
    # here, so the zero crossing lies at r^2 = 3 * sigma^2.
    log_3d = (squared_dist - 3.0 * variance) / (variance ** 2) * gaussian

    # Truncating and sampling leaves a small residual sum; remove it.
    log_3d = log_3d - torch.mean(log_3d)
    return log_3d

def get_LoG_image(image: MedicalImage, sigmas: List[float]) -> Dict[str, MedicalImage]:
    """
    Applies Laplacian of Gaussian (LoG) spatial filtration.
    Mimics `pyradiomics.imageoperations.getLoGImage`.

    For every sigma a LoG kernel is built and applied with `F.conv3d` using
    zero padding of half the kernel width, so each output has the same
    spatial shape as the input.

    Args:
        image: Source volume. `image.tensor` is unsqueezed twice to form the
            (1, 1, Z, Y, X) layout `conv3d` expects, so it should be a 3D
            tensor. Non-floating-point tensors are promoted to float32.
        sigmas: Gaussian widths to apply, one output image per entry.

    Returns:
        Dictionary keyed by 'log-sigma-{sigma}-mm-3D', where the decimal point
        of `str(sigma)` is replaced by a dash (e.g. 'log-sigma-1-0-mm-3D').
        Each value carries the filtered tensor and the input's spacing.
    """
    logger.info(f"Applying analytical Laplacian of Gaussian kernels for sigmas: {sigmas}")

    volume = image.tensor
    # conv3d requires a floating dtype shared by input and weight; integer
    # volumes would otherwise fail, so promote them once up front.
    if not torch.is_floating_point(volume):
        volume = volume.to(torch.float32)

    # Requires batched layout: (1, 1, Z, Y, X)
    tensor = volume.unsqueeze(0).unsqueeze(0)
    filtered_images = {}

    for sigma in sigmas:
        # NOTE(review): the kernel is sampled on the voxel grid and
        # image.spacing is not used to scale sigma, although the output name
        # says "mm" -- confirm this is intended for non-unit spacing.
        kernel = _get_LoG_kernel_3d(sigma)
        # Match both device and dtype of the input (the kernel is built as
        # float32 on the CPU), otherwise float64 inputs raise in conv3d.
        kernel = kernel.unsqueeze(0).unsqueeze(0).to(device=tensor.device, dtype=tensor.dtype)

        padding = kernel.shape[-1] // 2
        filtered_tensor = F.conv3d(tensor, kernel, padding=padding)
        filtered_tensor = filtered_tensor.squeeze(0).squeeze(0)

        # PyRadiomics has a specific naming convention: 'log-sigma-1-0-mm-3D'
        name = f"log-sigma-{str(sigma).replace('.', '-')}-mm-3D"
        filtered_images[name] = MedicalImage(tensor=filtered_tensor, spacing=image.spacing)

    return filtered_images


def get_Square_image(image: MedicalImage) -> Dict[str, MedicalImage]:
    """
    Squares every element of the image tensor.

    Returns:
        A single-entry dictionary keyed 'square' whose value is a new
        MedicalImage holding the squared tensor and the input's spacing.
    """
    squared = image.tensor.square()
    result = MedicalImage(tensor=squared, spacing=image.spacing)
    return {"square": result}

def get_SquareRoot_image(image: MedicalImage) -> Dict[str, MedicalImage]:
    """
    Takes the square root of the absolute value of every element.

    The sign of the input is discarded before the root is taken, so the
    output is non-negative.

    Returns:
        A single-entry dictionary keyed 'squareroot' whose value is a new
        MedicalImage holding the result and the input's spacing.
    """
    magnitude = image.tensor.abs()
    rooted = magnitude.sqrt()
    result = MedicalImage(tensor=rooted, spacing=image.spacing)
    return {"squareroot": result}

def get_Logarithm_image(image: MedicalImage) -> Dict[str, MedicalImage]:
    """
    Takes the natural logarithm of the absolute value of every element.

    A constant 1e-6 is added to the magnitude before the logarithm so that
    zero-valued elements do not map to negative infinity.

    Returns:
        A single-entry dictionary keyed 'logarithm' whose value is a new
        MedicalImage holding the result and the input's spacing.
    """
    magnitude = image.tensor.abs()
    # Offset keeps log() finite where the magnitude is exactly zero.
    shifted = magnitude + 1e-6
    result = MedicalImage(tensor=shifted.log(), spacing=image.spacing)
    return {"logarithm": result}

def get_Exponential_image(image: MedicalImage) -> Dict[str, MedicalImage]:
    """
    Applies the element-wise exponential e**x to the image tensor.

    No rescaling is applied to the input before exponentiation.

    Returns:
        A single-entry dictionary keyed 'exponential' whose value is a new
        MedicalImage holding the result and the input's spacing.
    """
    exponentiated = image.tensor.exp()
    result = MedicalImage(tensor=exponentiated, spacing=image.spacing)
    return {"exponential": result}


def apply_builtin_filters(image: MedicalImage, filter_types: Dict[str, Dict]) -> Dict[str, MedicalImage]:
    """
    Master router handling multiple simultaneous math mappings to mirror legacy automated scaling.

    Argument format mimics standard PyRadiomics filter definition dictionaries:
    e.g. {"Original": {}, "LoG": {"sigma": [1.0, 2.0]}, "Square": {}}

    Args:
        image: The source image every requested filter is applied to.
        filter_types: Mapping of filter name (matched case-insensitively) to
            its settings. Only "LoG" reads its settings: the "sigma" entry may
            be a single number, a list or a tuple, and defaults to [1.0].
            Settings of None are treated as empty.

    Returns:
        Dictionary of output name to image. "Original" maps the key
        'original' to the input image itself; the other filters contribute
        the entries returned by their respective get_*_image function.
        Unsupported filter names are skipped with a logged warning.
    """
    mapped_images = {}

    for key, params in filter_types.items():
        filter_name = key.lower()

        if filter_name == "original":
            mapped_images["original"] = image
        elif filter_name == "log":
            # A filter listed without settings (None) uses the default sigma.
            sigmas = (params or {}).get("sigma", [1.0])
            if isinstance(sigmas, tuple):
                # A tuple of sigmas is a sequence of widths, not one width.
                sigmas = list(sigmas)
            elif not isinstance(sigmas, list):
                sigmas = [sigmas]
            mapped_images.update(get_LoG_image(image, sigmas))
        elif filter_name == "square":
            mapped_images.update(get_Square_image(image))
        elif filter_name == "squareroot":
            mapped_images.update(get_SquareRoot_image(image))
        elif filter_name == "logarithm":
            mapped_images.update(get_Logarithm_image(image))
        elif filter_name == "exponential":
            mapped_images.update(get_Exponential_image(image))
        else:
            logger.warning(f"Filter type {key} is currently unsupported natively via analytical scaling.")

    return mapped_images