From a582d1365c7fae02699965a1f13ff549f95872cb Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Tue, 11 Nov 2025 22:08:33 +0900 Subject: [PATCH 01/25] init commit Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/__init__.py | 50 +++ .../image/anomaly_dino/lightning_model.py | 288 ++++++++++++++++++ .../models/image/anomaly_dino/torch_model.py | 288 ++++++++++++++++++ 3 files changed, 626 insertions(+) create mode 100644 src/anomalib/models/image/anomaly_dino/__init__.py create mode 100644 src/anomalib/models/image/anomaly_dino/lightning_model.py create mode 100644 src/anomalib/models/image/anomaly_dino/torch_model.py diff --git a/src/anomalib/models/image/anomaly_dino/__init__.py b/src/anomalib/models/image/anomaly_dino/__init__.py new file mode 100644 index 0000000000..32a5345f6d --- /dev/null +++ b/src/anomalib/models/image/anomaly_dino/__init__.py @@ -0,0 +1,50 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AnomalyDINO: Boosting Patch-based Few-shot Anomaly Detection with DINOv2. + +This module implements AnomalyDINO. A memory-bank model for anomaly detection +that utilizes DINOv2-Small as its backbone. At inference time it uses kNN +to search for anomalous patches. The image anomaly score is dependent on the worse +99th percentile of the pixel-wise anomaly score. + +The model has optional masking to remove noisy background components, +also optionally can use greedy coreset-subsampling if needed. + +Example: + >>> from anomalib.data import MVTecAD + >>> from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO + >>> from anomalib.engine import Engine + >>> + >>> MVTEC_CATEGORIES = [ + ... "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather", + ... "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper" + ... ] + >>> MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"] + >>> + >>> for category in MVTEC_CATEGORIES: + ... mask = category in MASKED_CATEGORIES + ... print(f"--- Running category: {category} | masking={mask} ---") + ... + ... # Initialize data module + ... datamodule = MVTecAD(category=category) + ... + ... # Initialize model + ... model = AnomalyDINO( + ... num_neighbours=1, + ... encoder_name="dinov2_vit_small_14", + ... masking=mask, + ... coreset_subsampling=False, + ... ) + ... + ... # Train and test + ... engine = Engine() + ... engine.fit(model=model, datamodule=datamodule) + ... engine.test(datamodule=datamodule) + >>> + >>> print("All categories processed.") +""" + +from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO + +__all__ = ["AnomalyDINO"] diff --git a/src/anomalib/models/image/anomaly_dino/lightning_model.py b/src/anomalib/models/image/anomaly_dino/lightning_model.py new file mode 100644 index 0000000000..8530501f62 --- /dev/null +++ b/src/anomalib/models/image/anomaly_dino/lightning_model.py @@ -0,0 +1,288 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AnomalyDINO: Boosting Patch-based Few-shot Anomaly Detection with DINOv2. + +This module implements AnomalyDINO. A memory-bank model for anomaly detection +that utilizes DINOv2-Small as its backbone. At inference time it uses kNN +to search for anomalous patches. The image anomaly score is dependent on the worse +99th percentile of the pixel-wise anomaly score. 
+ +The model has optional masking to remove noisy background components, +also optionally can use greedy coreset-subsampling if needed. + +Example: + >>> from anomalib.data import MVTecAD + >>> from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO + >>> from anomalib.engine import Engine + >>> + >>> MVTEC_CATEGORIES = [ + ... "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather", + ... "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper" + ... ] + >>> MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"] + >>> + >>> for category in MVTEC_CATEGORIES: + ... mask = category in MASKED_CATEGORIES + ... print(f"--- Running category: {category} | masking={mask} ---") + ... + ... # Initialize data module + ... datamodule = MVTecAD(category=category) + ... + ... # Initialize model + ... model = AnomalyDINO( + ... num_neighbours=1, + ... encoder_name="dinov2_vit_small_14", + ... masking=mask, + ... coreset_subsampling=False, + ... ) + ... + ... # Train and test + ... engine = Engine() + ... engine.fit(model=model, datamodule=datamodule) + ... engine.test(datamodule=datamodule) + >>> + >>> print("All categories processed.") +""" + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import nn +from torchvision.transforms.v2 import Compose, InterpolationMode, Normalize, Resize + +from anomalib import LearningType +from anomalib.data import Batch +from anomalib.metrics import Evaluator +from anomalib.models.components import AnomalibModule, MemoryBankMixin +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer + +from .torch_model import AnomalyDINOModel + +logger = logging.getLogger(__name__) + + +class AnomalyDINO(MemoryBankMixin, AnomalibModule): + """AnomalyDINO Lightning Module for anomaly detection. + + This class implements the AnomalyDINO algorithm, which leverages self-supervised + DINO (self-distillation with no labels) vision transformer (ViT) encoders for + feature extraction in anomaly detection tasks. Similar to PatchCore, it uses a + memory bank of patch embeddings and performs nearest neighbor search to identify + anomalous regions in test images. + + The model operates in two phases: + 1. **Training**: Extracts and stores patch embeddings from normal training images. + 2. **Inference**: Compares test image patch embeddings with the memory bank + to identify anomalies based on distance metrics. + + Args: + num_neighbours (int, optional): Number of nearest neighbors to use for + anomaly scoring. Defaults to ``1``. + encoder_name (str, optional): Name of the pretrained DINO encoder to use. + Defaults to ``"dinov2_vits14"``. + masking (bool, optional): Whether to apply masking during feature extraction + to simulate occlusions or missing patches. Defaults to ``False``. + coreset_subsampling (bool, optional): Whether to apply coreset subsampling + to reduce the size of the memory bank. Defaults to ``False``. + sampling ratio(float, optional): If coreset subsampling, by what ratio + should we subsample. Defaults to ``0.1`` + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + bool flag to enable default preprocessing. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance or + bool flag to enable default postprocessing. Defaults to ``True``. 
+ evaluator (Evaluator | bool, optional): Evaluator instance or bool flag for + performance computation. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or bool flag + to enable visualization. Defaults to ``True``. + + Example: + >>> from anomalib.data import MVTecAD + >>> from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO + >>> from anomalib.engine import Engine + >>> + >>> MVTEC_CATEGORIES = [ + ... "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather", + ... "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper" + ... ] + >>> MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"] + >>> + >>> for category in MVTEC_CATEGORIES: + ... mask = category in MASKED_CATEGORIES + ... print(f"--- Running category: {category} | masking={mask} ---") + ... + ... # Initialize data module + ... datamodule = MVTecAD(category=category) + ... + ... # Initialize model + ... model = AnomalyDINO( + ... num_neighbours=1, + ... encoder_name="dinov2_vit_small_14", + ... masking=mask, + ... coreset_subsampling=False, + ... ) + ... + ... # Train and test + ... engine = Engine() + ... engine.fit(model=model, datamodule=datamodule) + ... engine.test(datamodule=datamodule) + >>> + >>> print("All categories processed.") + + Notes: + - The model does not require backpropagation or optimization, as it relies + on pretrained transformer embeddings and similarity search. + - Works best when trained exclusively on normal (non-anomalous) samples. + + See Also: + - :class:`anomalib.models.components.AnomalibModule`: + Base class for all anomaly detection models + - :class:`anomalib.models.components.MemoryBankMixin`: + Mixin class for models using memory bank embeddings + """ + + def __init__( + self, + num_neighbours: int = 1, + encoder_name: str = "dinov2_vit_small_14", + masking: bool = False, + coreset_subsampling: bool = False, + sampling_ratio: float = 0.1, + pre_processor: nn.Module | bool = True, + post_processor: nn.Module | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, + ) -> None: + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) + self.model: AnomalyDINOModel = AnomalyDINOModel( + num_neighbours=num_neighbours, + encoder_name=encoder_name, + masking=masking, + coreset_subsampling=coreset_subsampling, + sampling_ratio=sampling_ratio, + ) + + @classmethod + def configure_pre_processor( + cls, + image_size: tuple[int, int] | None = None, + ) -> PreProcessor: + """Configure the default pre-processor for AnomalyDINO. + + Args: + image_size (tuple[int, int] | None, optional): Target size for resizing + input images. Defaults to ``(252, 252)``. + + Returns: + PreProcessor: Configured pre-processor instance. + + Example: + >>> pre_processor = AnomalyDINO.configure_pre_processor( + ... image_size=(252, 252) + ... ) + >>> transformed_image = pre_processor(image) + """ + image_size = image_size or (252, 252) + transform = Compose([ + Resize(image_size, antialias=True, interpolation=InterpolationMode.BICUBIC), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + return PreProcessor(transform=transform) + + @staticmethod + def configure_optimizers() -> None: + """Configure optimizers. + + Returns: + None: AnomalyDINO does not require optimization or gradient updates. 
+ """ + return + + def training_step(self, batch: Batch, *args, **kwargs) -> None: + """Extract feature embeddings from training images. + + Args: + batch (Batch): Input batch containing images and metadata. + *args: Additional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + torch.Tensor: Dummy loss tensor for Lightning compatibility. + + Note: + The extracted embeddings are stored in the models memory bank for + later use during the coreset sampling or inference phase. + """ + del args, kwargs # These variables are not used. + _ = self.model(batch.image) + return torch.tensor(0.0, requires_grad=True, device=self.device) + + def fit(self) -> None: + """Optional fitting step. + + This method is a placeholder for potential post-training operations + such as coreset subsampling or feature normalization. + + Note: + The current implementation is a no-op, as AnomalyDINO typically + performs inference directly after feature extraction. + """ + self.model.fit() + + def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Generate anomaly predictions for a validation batch. + + Args: + batch (Batch): Input batch containing images and metadata. + *args: Additional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + STEP_OUTPUT: Batch with added predictions including anomaly maps and + scores computed using nearest neighbor search. + """ + del args, kwargs + predictions = self.model(batch.image) + return batch.update(**predictions._asdict()) + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Default PyTorch Lightning trainer arguments for AnomalyDINO. + + Returns: + dict[str, Any]: Trainer configuration with: + - ``gradient_clip_val``: ``0`` (no gradient clipping) + - ``max_epochs``: ``1`` (single pass over training data) + - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) + - ``devices``: ``1`` (single GPU supported) + """ + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1} + + @property + def learning_type(self) -> LearningType: + """Get the learning type for AnomalyDINO. + + Returns: + LearningType: Always ``LearningType.ONE_CLASS`` since the model is + trained only on normal samples. + """ + return LearningType.ONE_CLASS + + @staticmethod + def configure_post_processor() -> PostProcessor: + """Configure the default post-processor. + + Returns: + PostProcessor: Post-processor that converts raw model scores into + interpretable anomaly predictions and maps. + """ + return PostProcessor() diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py new file mode 100644 index 0000000000..b627472ab6 --- /dev/null +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -0,0 +1,288 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""PyTorch model implementation for AnomalyDINO. + +This module defines the low-level PyTorch implementation of the AnomalyDINO model, +which combines a DINOv2 Vision Transformer encoder with a memory-bank approach +for few-shot anomaly detection. It performs patch-based feature extraction, +optional background masking, and k-nearest neighbor search for anomaly scoring. + +Example: + >>> from anomalib.models.image.anomaly_dino.torch_model import AnomalyDINOModel + >>> model = AnomalyDINOModel( + ... num_neighbours=1, + ... encoder_name="dinov2_vit_small_14", + ... masking=False, + ... 
coreset_subsampling=False, + ... sampling_ratio=0.1, + ... ) +""" + +import cv2 +import numpy as np +import torch +from sklearn.decomposition import PCA +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.data import InferenceBatch +from anomalib.models.components import DynamicBufferMixin, KCenterGreedy +from anomalib.models.image.dinomaly.components import load as load_dinov2_model +from anomalib.models.image.patchcore.anomaly_map import AnomalyMapGenerator + + +class AnomalyDINOModel(DynamicBufferMixin, nn.Module): + """AnomalyDINO base PyTorch model for patch-based anomaly detection. + + This model uses DINOv2 transformers as feature extractors and applies + a memory-bank mechanism for few-shot anomaly detection, similar to PatchCore. + It supports optional background masking and coreset subsampling. + + Args: + num_neighbours (int, optional): Number of nearest neighbors used for + anomaly scoring. Defaults to ``1``. + encoder_name (str, optional): DINOv2 encoder architecture name. + Must start with ``"dinov2"``. Defaults to ``"dinov2_vit_small_14"``. + masking (bool, optional): Whether to apply PCA-based masking to suppress + background features. Defaults to ``False``. + coreset_subsampling (bool, optional): Whether to apply greedy coreset + selection to reduce memory bank size. Defaults to ``False``. + sampling_ratio (float, optional): Fraction of samples retained during + coreset subsampling. Defaults to ``0.1``. + + Example: + >>> model = AnomalyDINOModel(masking=True, coreset_subsampling=True) + >>> x = torch.randn(1, 3, 224, 224) + >>> preds = model(x) + >>> preds.pred_score.shape + torch.Size([1, 1]) + """ + + def __init__( + self, + num_neighbours: int = 1, + encoder_name: str = "dinov2_vit_small_14", + masking: bool = False, + coreset_subsampling: bool = False, + sampling_ratio: float = 0.1, + ) -> None: + super().__init__() + self.num_neighbours = num_neighbours + self.encoder_name = encoder_name + self.masking = masking + self.coreset_subsampling = coreset_subsampling + self.sampling_ratio = sampling_ratio + + # Load DINOv2 backbone + assert encoder_name.startswith("dinov2"), f"Encoder must be dinov2, got {encoder_name}" + self.feature_encoder = load_dinov2_model(self.encoder_name) + self.feature_encoder.eval() + + # Memory bank and embedding storage + self.register_buffer("memory_bank", torch.empty(0)) + self.embedding_store: list[torch.Tensor] = [] + + # Anomaly map generator for visualization and scoring + self.anomaly_map_generator = AnomalyMapGenerator() + + def fit(self) -> None: + """Finalize and optionally subsample the memory bank after training. + + Once all embeddings from normal training images have been collected, + this method consolidates them into the memory bank and optionally + performs coreset-based subsampling. + + Raises: + ValueError: If called before collecting any embeddings. + """ + if len(self.embedding_store) == 0: + err_str = "No embeddings collected. Run model in training mode first." + raise ValueError(err_str) + + # Stack and normalize embeddings + self.memory_bank = torch.vstack(self.embedding_store) + self.embedding_store.clear() + + # Optional coreset selection + if self.coreset_subsampling: + sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=self.sampling_ratio) + self.memory_bank = sampler.sample_coreset() + + def extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor: + """Extract patch-level feature embeddings from the last transformer layer. 
+ + Returns flattened patch tokens excluding CLS and register tokens. + + Args: + image_tensor (torch.Tensor): Input image tensor of shape ``(B, 3, H, W)``. + + Returns: + torch.Tensor: Patch feature embeddings of shape ``(B, N, D)``, + where ``N`` is the number of patches and ``D`` the feature dimension. + """ + with torch.inference_mode(): + tokens = self.feature_encoder.get_intermediate_layers(image_tensor, n=1)[0] + start = self.feature_encoder.num_tokens + self.feature_encoder.num_register_tokens + return tokens[:, start:, :] + + @staticmethod + def compute_background_masks( + batch_features: np.ndarray, + grid_size: tuple[int, int], + threshold: float = 10.0, + kernel_size: int = 3, + border: float = 0.2, + ) -> np.ndarray: + """Compute binary masks to identify foreground patches. + + This method uses PCA on patch embeddings to estimate foreground regions, + followed by morphological operations to clean up the mask. + + Args: + batch_features (np.ndarray): Patch embeddings of shape ``(B, N, D)``. + grid_size (tuple[int, int]): Spatial grid dimensions (H, W). + threshold (float, optional): PCA threshold for foreground separation. + Defaults to ``10.0``. + kernel_size (int, optional): Morphological kernel size. Defaults to ``3``. + border (float, optional): Fraction of image borders excluded from + thresholding. Defaults to ``0.2``. + + Returns: + np.ndarray: Boolean masks of shape ``(B, N)``, where ``True`` indicates + foreground patches. + """ + b, n, _ = batch_features.shape + masks = np.ones((b, n), dtype=bool) + + for i in range(b): + img_features = batch_features[i] + pca = PCA(n_components=1, svd_solver="randomized") + first_pc = pca.fit_transform(img_features.astype(np.float32)) + mask = first_pc > threshold + + mask_2d = mask.reshape(grid_size) + h, w = grid_size + y0, y1 = int(h * border), int(h * (1 - border)) + x0, x1 = int(w * border), int(w * (1 - border)) + center_crop = mask_2d[y0:y1, x0:x1] + + # Flip sign if PCA direction is inverted + if center_crop.sum() <= center_crop.size * 0.35: + mask = (-first_pc) > threshold + mask_2d = mask.reshape(grid_size) + + # Morphological cleanup + kernel = np.ones((kernel_size, kernel_size), np.uint8) + mask_2d = cv2.dilate(mask_2d.astype(np.uint8), kernel).astype(bool) + mask_2d = cv2.morphologyEx(mask_2d.astype(np.uint8), cv2.MORPH_CLOSE, kernel).astype(bool) + + masks[i] = mask_2d.flatten() + + return masks + + @staticmethod + def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Compute pairwise Euclidean distances between all pairs of vectors. + + Efficiently computes the distance matrix between ``x`` and ``y`` without + using ``torch.cdist()``, improving compatibility with ONNX and OpenVINO. + + Args: + x (torch.Tensor): Tensor of shape ``(n, d)``. + y (torch.Tensor): Tensor of shape ``(m, d)``. + + Returns: + torch.Tensor: Distance matrix of shape ``(n, m)``. + + Example: + >>> x = torch.randn(100, 512) + >>> y = torch.randn(50, 512) + >>> distances = AnomalyDINOModel.euclidean_dist(x, y) + >>> distances.shape + torch.Size([100, 50]) + """ + x_norm = x.pow(2).sum(dim=-1, keepdim=True) + y_norm = y.pow(2).sum(dim=-1, keepdim=True) + res = x_norm - 2 * torch.matmul(x, y.transpose(-2, -1)) + y_norm.transpose(-2, -1) + return res.clamp_min_(0).sqrt_() + + @staticmethod + def mean_top1p(distances: torch.Tensor) -> torch.Tensor: + """Compute the mean of the top 1% distances per image. + + Used as a robust aggregation of patch-level anomaly scores into a + single image-level anomaly score. 
+ + Args: + distances (torch.Tensor): Patch-level distances of shape ``(B, N)``. + + Returns: + torch.Tensor: Mean of the top 1% distances per image, shape ``(B, 1)``. + """ + n = distances.shape[-1] + num_top = max(int(n * 0.01), 1) + topk_vals, _ = torch.topk(distances, num_top, dim=1, largest=True) + return topk_vals.mean(dim=1, keepdim=True) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: + """Forward pass for both training and inference. + + In training mode: + - Extracts normalized patch features. + - Collects embeddings into the memory bank. + + In inference mode: + - Computes distances between input features and the memory bank. + - Performs kNN-based scoring and anomaly map generation. + + Args: + input_tensor (torch.Tensor): Input batch of shape ``(B, 3, H, W)``. + + Returns: + Union[torch.Tensor, InferenceBatch]: + - In training: dummy scalar tensor (no loss backprop). + - In inference: :class:`anomalib.data.InferenceBatch` containing: + * ``pred_score``: Image-level anomaly score ``(B, 1)`` + * ``anomaly_map``: Pixel-level anomaly heatmap ``(B, 1, H, W)`` + """ + b, _, w, h = input_tensor.shape + cropped_width = w - w % self.feature_encoder.patch_size + cropped_height = h - h % self.feature_encoder.patch_size + grid_size = ( + cropped_height // self.feature_encoder.patch_size, + cropped_width // self.feature_encoder.patch_size, + ) + + device = input_tensor.device + features = self.extract_features(input_tensor) + + if self.masking: + features_np = features.detach().cpu().numpy() + masks_np = self.compute_background_masks(features_np, grid_size) + masks = torch.from_numpy(masks_np).to(device) + else: + masks = torch.ones(features.shape[:2], dtype=torch.bool, device=device) + + features = features[masks] + features = F.normalize(features, p=2, dim=1) + + if self.training: + self.embedding_store.append(features) + return torch.tensor(0.0, device=device, requires_grad=True) + + # Inference + dist_matrix = self.euclidean_dist(features, self.memory_bank) + k = max(1, self.num_neighbours) + topk_vals, _ = torch.topk(dist_matrix, k=k, dim=1, largest=False) + min_dists = topk_vals.mean(dim=1) + + distances_full = torch.zeros((b, grid_size[0] * grid_size[1]), device=device) + batch_idx, patch_idx = torch.nonzero(masks, as_tuple=True) + distances_full[batch_idx, patch_idx] = min_dists + + image_score = self.mean_top1p(distances_full) + anomaly_map = distances_full.view(b, 1, *grid_size) + anomaly_map = self.anomaly_map_generator(anomaly_map, (h, w)) + + return InferenceBatch(pred_score=image_score, anomaly_map=anomaly_map) From 08e722a27a3ba1d1c7679573ad62a0b1496bfde8 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Tue, 11 Nov 2025 22:29:19 +0900 Subject: [PATCH 02/25] add cdist instead of euclidean distance, divide by 2 for cosine and improve comments Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/torch_model.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index b627472ab6..6ad7d8efae 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -272,16 +272,32 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: return torch.tensor(0.0, device=device, requires_grad=True) # Inference 
- dist_matrix = self.euclidean_dist(features, self.memory_bank) + # L2-normalized distances + # memory_bank : [M, D], features : [Q, D] + + # Compute pairwise distances [Q, M] + dists = torch.cdist(features, self.memory_bank, p=2) + + # Convert L2 to cosine distance + # (since both vectors are normalized, divide by 2) + dists = dists / 2.0 + + # Get top-k nearest neighbors k = max(1, self.num_neighbours) - topk_vals, _ = torch.topk(dist_matrix, k=k, dim=1, largest=False) - min_dists = topk_vals.mean(dim=1) + topk_vals, _ = torch.topk(dists, k=k, dim=1, largest=False) + # Mean over k neighbors if needed + min_dists = topk_vals.mean(dim=1) if k > 1 else topk_vals.squeeze(1) + + # Vectorized reconstruction distances_full = torch.zeros((b, grid_size[0] * grid_size[1]), device=device) batch_idx, patch_idx = torch.nonzero(masks, as_tuple=True) distances_full[batch_idx, patch_idx] = min_dists + # Aggregate image-level anomaly scores image_score = self.mean_top1p(distances_full) + + # Generate final anomaly map anomaly_map = distances_full.view(b, 1, *grid_size) anomaly_map = self.anomaly_map_generator(anomaly_map, (h, w)) From 1a92038013d5fe0e5845bbe9822f373ee6e3f8ef Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:28:45 +0900 Subject: [PATCH 03/25] remove redundant euclidean distance Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/torch_model.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index 6ad7d8efae..b4d2b7c9bf 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -181,32 +181,6 @@ def compute_background_masks( return masks - @staticmethod - def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Compute pairwise Euclidean distances between all pairs of vectors. - - Efficiently computes the distance matrix between ``x`` and ``y`` without - using ``torch.cdist()``, improving compatibility with ONNX and OpenVINO. - - Args: - x (torch.Tensor): Tensor of shape ``(n, d)``. - y (torch.Tensor): Tensor of shape ``(m, d)``. - - Returns: - torch.Tensor: Distance matrix of shape ``(n, m)``. - - Example: - >>> x = torch.randn(100, 512) - >>> y = torch.randn(50, 512) - >>> distances = AnomalyDINOModel.euclidean_dist(x, y) - >>> distances.shape - torch.Size([100, 50]) - """ - x_norm = x.pow(2).sum(dim=-1, keepdim=True) - y_norm = y.pow(2).sum(dim=-1, keepdim=True) - res = x_norm - 2 * torch.matmul(x, y.transpose(-2, -1)) + y_norm.transpose(-2, -1) - return res.clamp_min_(0).sqrt_() - @staticmethod def mean_top1p(distances: torch.Tensor) -> torch.Tensor: """Compute the mean of the top 1% distances per image. From 5cfd65cfcaf724218afc1705c25f8f91fe279455 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:41:57 +0900 Subject: [PATCH 04/25] Add AnomalyDINO to model list. 
Also alphabetically re-order some models Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- src/anomalib/models/__init__.py | 6 ++++-- src/anomalib/models/image/__init__.py | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index f7ed76d1f4..1051215d16 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -55,6 +55,7 @@ from anomalib.utils.path import convert_snake_to_pascal_case, convert_to_snake_case, convert_to_title_case from .image import ( + AnomalyDINO, Cfa, Cflow, Csflow, @@ -93,6 +94,8 @@ class UnknownModelError(ModuleNotFoundError): __all__ = [ + "AiVad", + "AnomalyDINO", "Cfa", "Cflow", "Csflow", @@ -104,6 +107,7 @@ class UnknownModelError(ModuleNotFoundError): "EfficientAd", "Fastflow", "Fre", + "Fuvas", "Ganomaly", "Padim", "Patchcore", @@ -114,8 +118,6 @@ class UnknownModelError(ModuleNotFoundError): "UniNet", "VlmAd", "WinClip", - "AiVad", - "Fuvas", ] logger = logging.getLogger(__name__) diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index d5adc65ead..cdfc4f44f1 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -23,6 +23,7 @@ >>> predictions = engine.predict(model=model, datamodule=datamodule) # doctest: +SKIP Available Models: + - :class: `AnomalyDINO`: Boost Memorybank Models with DINOv2 - :class:`Cfa`: Contrastive Feature Aggregation - :class:`Cflow`: Conditional Normalizing Flow - :class:`Csflow`: Conditional Split Flow @@ -44,6 +45,7 @@ - :class:`WinClip`: Zero-/Few-Shot CLIP-based Detection """ +from .anomaly_dino import AnomalyDINO from .cfa import Cfa from .cflow import Cflow from .csflow import Csflow @@ -67,10 +69,12 @@ from .winclip import WinClip __all__ = [ + "AnomalyDINO", "Cfa", "Cflow", "Csflow", "Dfkde", + "Dinomaly", "Dfm", "Draem", "Dsr", @@ -87,5 +91,4 @@ "UniNet", "VlmAd", "WinClip", - "Dinomaly", ] From 6727541c1482121040b8e0f6320a0a884a8f1ceb Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:49:05 +0900 Subject: [PATCH 05/25] add precision modifier Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/lightning_model.py | 15 ++++++++++++++- .../models/image/anomaly_dino/torch_model.py | 4 ++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/anomalib/models/image/anomaly_dino/lightning_model.py b/src/anomalib/models/image/anomaly_dino/lightning_model.py index 8530501f62..aee8afee36 100644 --- a/src/anomalib/models/image/anomaly_dino/lightning_model.py +++ b/src/anomalib/models/image/anomaly_dino/lightning_model.py @@ -53,7 +53,7 @@ from torch import nn from torchvision.transforms.v2 import Compose, InterpolationMode, Normalize, Resize -from anomalib import LearningType +from anomalib import LearningType, PrecisionType from anomalib.data import Batch from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule, MemoryBankMixin @@ -91,6 +91,9 @@ class AnomalyDINO(MemoryBankMixin, AnomalibModule): to reduce the size of the memory bank. Defaults to ``False``. sampling ratio(float, optional): If coreset subsampling, by what ratio should we subsample. Defaults to ``0.1`` + precision (str, optional): Precision type for model computations. + Supported values are defined in :class:`PrecisionType`. + Defaults to ``PrecisionType.FLOAT32``. 
pre_processor (PreProcessor | bool, optional): Pre-processor instance or bool flag to enable default preprocessing. Defaults to ``True``. post_processor (PostProcessor | bool, optional): Post-processor instance or @@ -152,6 +155,7 @@ def __init__( masking: bool = False, coreset_subsampling: bool = False, sampling_ratio: float = 0.1, + precision: str = PrecisionType.FLOAT32, pre_processor: nn.Module | bool = True, post_processor: nn.Module | bool = True, evaluator: Evaluator | bool = True, @@ -171,6 +175,15 @@ def __init__( sampling_ratio=sampling_ratio, ) + if precision == PrecisionType.FLOAT16: + self.model = self.model.half() + elif precision == PrecisionType.FLOAT32: + self.model = self.model.float() + else: + msg = f"""Unsupported precision type: {precision}. + Supported types are: {PrecisionType.FLOAT16}, {PrecisionType.FLOAT32}.""" + raise ValueError(msg) + @classmethod def configure_pre_processor( cls, diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index b4d2b7c9bf..50d0af4630 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -220,6 +220,10 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: * ``pred_score``: Image-level anomaly score ``(B, 1)`` * ``anomaly_map``: Pixel-level anomaly heatmap ``(B, 1, H, W)`` """ + # set precicion + input_tensor = input_tensor.type(self.memory_bank.dtype) + + # work out sizing b, _, w, h = input_tensor.shape cropped_width = w - w % self.feature_encoder.patch_size cropped_height = h - h % self.feature_encoder.patch_size From 822f8a3a29a0215fcb8ad1c2e6cac49262df0ed1 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Wed, 12 Nov 2025 01:04:36 +0900 Subject: [PATCH 06/25] remove fit comments. small typo of shape dimensions Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/lightning_model.py | 13 +++++-------- .../models/image/anomaly_dino/torch_model.py | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/anomalib/models/image/anomaly_dino/lightning_model.py b/src/anomalib/models/image/anomaly_dino/lightning_model.py index aee8afee36..38aff73cae 100644 --- a/src/anomalib/models/image/anomaly_dino/lightning_model.py +++ b/src/anomalib/models/image/anomaly_dino/lightning_model.py @@ -187,13 +187,13 @@ def __init__( @classmethod def configure_pre_processor( cls, - image_size: tuple[int, int] | None = None, + image_size: tuple[int, int] | int | None = None, ) -> PreProcessor: """Configure the default pre-processor for AnomalyDINO. Args: - image_size (tuple[int, int] | None, optional): Target size for resizing - input images. Defaults to ``(252, 252)``. + image_size (tuple[int, int] | int | None, optional): Target size for resizing + input images. Defaults to ``(252, 252)``. Note if int, keeps aspect ratio and resizes shortest side. Returns: PreProcessor: Configured pre-processor instance. @@ -243,11 +243,8 @@ def fit(self) -> None: """Optional fitting step. This method is a placeholder for potential post-training operations - such as coreset subsampling or feature normalization. - - Note: - The current implementation is a no-op, as AnomalyDINO typically - performs inference directly after feature extraction. + such as coreset subsampling or feature normalization. The model + handles fitting (if-needed). 
""" self.model.fit() diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index 50d0af4630..91f2072394 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -224,7 +224,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: input_tensor = input_tensor.type(self.memory_bank.dtype) # work out sizing - b, _, w, h = input_tensor.shape + b, _, h, w = input_tensor.shape cropped_width = w - w % self.feature_encoder.patch_size cropped_height = h - h % self.feature_encoder.patch_size grid_size = ( From 4753d087820430609c7b94874a8c21802985613a Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Wed, 12 Nov 2025 01:26:53 +0900 Subject: [PATCH 07/25] update docs for anomaly dino Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../guides/reference/models/image/anomaly_dino.md | 13 +++++++++++++ .../markdown/guides/reference/models/image/index.md | 8 ++++++++ examples/configs/README.md | 1 + examples/configs/model/anomaly_dino.yaml | 7 +++++++ 4 files changed, 29 insertions(+) create mode 100644 docs/source/markdown/guides/reference/models/image/anomaly_dino.md create mode 100644 examples/configs/model/anomaly_dino.yaml diff --git a/docs/source/markdown/guides/reference/models/image/anomaly_dino.md b/docs/source/markdown/guides/reference/models/image/anomaly_dino.md new file mode 100644 index 0000000000..700f6c8da0 --- /dev/null +++ b/docs/source/markdown/guides/reference/models/image/anomaly_dino.md @@ -0,0 +1,13 @@ +# AnomalyDINO + +```{eval-rst} +.. automodule:: anomalib.models.image.anomaly_dino.lightning_model + :members: AnomalyDINO + :show-inheritance: +``` + +```{eval-rst} +.. 
automodule:: anomalib.models.image.anomaly_dino.torch_model + :members: AnomalyDINOModel + :show-inheritance: +``` diff --git a/docs/source/markdown/guides/reference/models/image/index.md b/docs/source/markdown/guides/reference/models/image/index.md index bfb6d5ab91..d637893b79 100644 --- a/docs/source/markdown/guides/reference/models/image/index.md +++ b/docs/source/markdown/guides/reference/models/image/index.md @@ -4,6 +4,13 @@ :margin: 1 1 0 0 :gutter: 1 +:::{grid-item-card} {material-regular}`model_training;1.5em` AnomalyDINO +:link: ./anomaly_dino +:link-type: doc + +Boosting Patch-based Few-shot Anomaly Detection with DINOv2 +::: + :::{grid-item-card} {material-regular}`model_training;1.5em` CFA :link: ./cfa :link-type: doc @@ -142,6 +149,7 @@ WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation :caption: Data :hidden: +./anomaly_dino ./cfa ./cflow ./csflow diff --git a/examples/configs/README.md b/examples/configs/README.md index 8f7af15e1b..d0eab157bf 100644 --- a/examples/configs/README.md +++ b/examples/configs/README.md @@ -21,6 +21,7 @@ configs/ │ └── visa.yaml └── model ├── ai_vad.yaml + ├── anomaly_dino.yaml ├── cfa.yaml ├── cflow.yaml ├── csflow.yaml diff --git a/examples/configs/model/anomaly_dino.yaml b/examples/configs/model/anomaly_dino.yaml new file mode 100644 index 0000000000..6d69899984 --- /dev/null +++ b/examples/configs/model/anomaly_dino.yaml @@ -0,0 +1,7 @@ +model: + class_path: anomalib.models.AnomalyDINO + init_args: + num_neighbours: 1 + encoder_name: dinov2_vit_small_14 + masking: False + coreset_subsampling: False From 6cf0d11db5d067f639d0528e383956a00a4d66ae Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Wed, 12 Nov 2025 01:39:42 +0900 Subject: [PATCH 08/25] cleanup docstrings Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/__init__.py | 13 +++++----- .../image/anomaly_dino/lightning_model.py | 25 +++++++++---------- .../models/image/anomaly_dino/torch_model.py | 6 +++-- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/anomalib/models/image/anomaly_dino/__init__.py b/src/anomalib/models/image/anomaly_dino/__init__.py index 32a5345f6d..a3ba92fa0f 100644 --- a/src/anomalib/models/image/anomaly_dino/__init__.py +++ b/src/anomalib/models/image/anomaly_dino/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2022-2025 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """AnomalyDINO: Boosting Patch-based Few-shot Anomaly Detection with DINOv2. @@ -15,20 +15,20 @@ >>> from anomalib.data import MVTecAD >>> from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO >>> from anomalib.engine import Engine - >>> + >>> MVTEC_CATEGORIES = [ ... "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather", ... "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper" ... ] >>> MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"] - >>> + >>> for category in MVTEC_CATEGORIES: ... mask = category in MASKED_CATEGORIES ... print(f"--- Running category: {category} | masking={mask} ---") - ... + ... # Initialize data module ... datamodule = MVTecAD(category=category) - ... + ... # Initialize model ... model = AnomalyDINO( ... num_neighbours=1, @@ -36,12 +36,11 @@ ... masking=mask, ... coreset_subsampling=False, ... ) - ... + ... # Train and test ... engine = Engine() ... 
engine.fit(model=model, datamodule=datamodule) ... engine.test(datamodule=datamodule) - >>> >>> print("All categories processed.") """ diff --git a/src/anomalib/models/image/anomaly_dino/lightning_model.py b/src/anomalib/models/image/anomaly_dino/lightning_model.py index 38aff73cae..10913d1c91 100644 --- a/src/anomalib/models/image/anomaly_dino/lightning_model.py +++ b/src/anomalib/models/image/anomaly_dino/lightning_model.py @@ -1,4 +1,4 @@ -# Copyright (C) 2022-2025 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """AnomalyDINO: Boosting Patch-based Few-shot Anomaly Detection with DINOv2. @@ -15,20 +15,20 @@ >>> from anomalib.data import MVTecAD >>> from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO >>> from anomalib.engine import Engine - >>> + >>> MVTEC_CATEGORIES = [ ... "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather", ... "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper" ... ] >>> MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"] - >>> + >>> for category in MVTEC_CATEGORIES: ... mask = category in MASKED_CATEGORIES ... print(f"--- Running category: {category} | masking={mask} ---") - ... + ... # Initialize data module ... datamodule = MVTecAD(category=category) - ... + ... # Initialize model ... model = AnomalyDINO( ... num_neighbours=1, @@ -36,12 +36,11 @@ ... masking=mask, ... coreset_subsampling=False, ... ) - ... + ... # Train and test ... engine = Engine() ... engine.fit(model=model, datamodule=datamodule) ... engine.test(datamodule=datamodule) - >>> >>> print("All categories processed.") """ @@ -107,20 +106,20 @@ class AnomalyDINO(MemoryBankMixin, AnomalibModule): >>> from anomalib.data import MVTecAD >>> from anomalib.models.image.anomaly_dino.lightning_model import AnomalyDINO >>> from anomalib.engine import Engine - >>> + >>> MVTEC_CATEGORIES = [ ... "hazelnut", "grid", "carpet", "bottle", "cable", "capsule", "leather", ... "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper" ... ] >>> MASKED_CATEGORIES = ["capsule", "hazelnut", "pill", "screw", "toothbrush"] - >>> + >>> for category in MVTEC_CATEGORIES: ... mask = category in MASKED_CATEGORIES ... print(f"--- Running category: {category} | masking={mask} ---") - ... + ... # Initialize data module ... datamodule = MVTecAD(category=category) - ... + ... # Initialize model ... model = AnomalyDINO( ... num_neighbours=1, @@ -128,12 +127,12 @@ class AnomalyDINO(MemoryBankMixin, AnomalibModule): ... masking=mask, ... coreset_subsampling=False, ... ) - ... + ... # Train and test ... engine = Engine() ... engine.fit(model=model, datamodule=datamodule) ... engine.test(datamodule=datamodule) - >>> + >>> print("All categories processed.") Notes: diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index 91f2072394..d23ee29d56 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -1,4 +1,4 @@ -# Copyright (C) 2022-2025 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """PyTorch model implementation for AnomalyDINO. 
@@ -75,7 +75,9 @@ def __init__( self.sampling_ratio = sampling_ratio # Load DINOv2 backbone - assert encoder_name.startswith("dinov2"), f"Encoder must be dinov2, got {encoder_name}" + if not encoder_name.startswith("dinov2"): + err_str = f"Encoder must be dinov2, got {encoder_name}" + raise ValueError(err_str) self.feature_encoder = load_dinov2_model(self.encoder_name) self.feature_encoder.eval() From 107432cc959d2dd74cb660e3ed8e878c1c53240b Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Wed, 12 Nov 2025 02:42:04 +0900 Subject: [PATCH 09/25] add unit tests for anomalydino. change distance computation from cdist to matmul, work with half tensors Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/image/anomaly_dino/torch_model.py | 26 +++-- .../models/image/anomaly_dino/__init__.py | 4 + .../image/anomaly_dino/test_torch_model.py | 101 ++++++++++++++++++ 3 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 tests/unit/models/image/anomaly_dino/__init__.py create mode 100644 tests/unit/models/image/anomaly_dino/test_torch_model.py diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index d23ee29d56..93e035d81c 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -251,16 +251,24 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: self.embedding_store.append(features) return torch.tensor(0.0, device=device, requires_grad=True) + # check bank isn't empty at inference + if self.memory_bank.numel() == 0: + msg = "Memory bank is empty. Run the model in training mode and call `fit()` before inference." + raise RuntimeError(msg) + + # Ensure dtype consistency + if features.dtype != self.memory_bank.dtype: + features = features.to(self.memory_bank.dtype) + # Inference # L2-normalized distances # memory_bank : [M, D], features : [Q, D] - # Compute pairwise distances [Q, M] - dists = torch.cdist(features, self.memory_bank, p=2) - - # Convert L2 to cosine distance - # (since both vectors are normalized, divide by 2) - dists = dists / 2.0 + # Compute cosine distance using matrix multiplication + # both features and memory_bank are already L2-normalized. + # cdist is not for half precision, but matmul is. 
+ similarity = torch.matmul(features, self.memory_bank.T) # [Q, M] + dists = (torch.ones_like(similarity) - similarity).clamp(min=0.0, max=2.0) # cosine distance ∈ [0, 2] # Get top-k nearest neighbors k = max(1, self.num_neighbours) @@ -270,7 +278,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: min_dists = topk_vals.mean(dim=1) if k > 1 else topk_vals.squeeze(1) # Vectorized reconstruction - distances_full = torch.zeros((b, grid_size[0] * grid_size[1]), device=device) + distances_full = torch.zeros( + (b, grid_size[0] * grid_size[1]), + device=device, + dtype=min_dists.dtype, + ) batch_idx, patch_idx = torch.nonzero(masks, as_tuple=True) distances_full[batch_idx, patch_idx] = min_dists diff --git a/tests/unit/models/image/anomaly_dino/__init__.py b/tests/unit/models/image/anomaly_dino/__init__.py new file mode 100644 index 0000000000..41e06ba1b8 --- /dev/null +++ b/tests/unit/models/image/anomaly_dino/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for AnomalyDINO.""" diff --git a/tests/unit/models/image/anomaly_dino/test_torch_model.py b/tests/unit/models/image/anomaly_dino/test_torch_model.py new file mode 100644 index 0000000000..87153c02d9 --- /dev/null +++ b/tests/unit/models/image/anomaly_dino/test_torch_model.py @@ -0,0 +1,101 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the WinCLIP torch model.""" + +import numpy as np +import pytest +import torch +from _pytest.monkeypatch import MonkeyPatch + +from anomalib.models.image.anomaly_dino.torch_model import AnomalyDINOModel + + +class TestAnomalyDINOModel: + """Test the AnomalyDINO torch model.""" + + @staticmethod + def test_initialization_defaults() -> None: + """Test initialization with default arguments.""" + model = AnomalyDINOModel() + assert model.encoder_name.startswith("dinov2") + assert model.memory_bank.numel() == 0 + + @staticmethod + def test_invalid_encoder_name_raises() -> None: + """Test that invalid encoder names raise an error.""" + with pytest.raises(ValueError, match="Encoder must be dinov2"): + _ = AnomalyDINOModel(encoder_name="resnet50") + + @staticmethod + def test_fit_raises_without_embeddings() -> None: + """Test that fit raises when no embeddings have been collected.""" + model = AnomalyDINOModel() + with pytest.raises(ValueError, match="No embeddings collected"): + model.fit() + + @staticmethod + def test_forward_train_adds_embeddings(monkeypatch: MonkeyPatch) -> None: + """Test training mode collects embeddings into store.""" + model = AnomalyDINOModel() + model.train() + + fake_features = torch.randn(2, 8, 128) + monkeypatch.setattr(model, "extract_features", lambda _: fake_features) + + x = torch.randn(2, 3, 224, 224) + output = model(x) + assert torch.is_tensor(output) + assert output.requires_grad + assert len(model.embedding_store) == 1 + assert model.embedding_store[0].ndim == 2 + + @staticmethod + def test_forward_eval_raises_with_empty_memory_bank(monkeypatch: MonkeyPatch) -> None: + """Test that inference raises an error when memory bank is empty.""" + model = AnomalyDINOModel() + model.eval() + + fake_features = torch.randn(1, 16, 64) + monkeypatch.setattr(model, "extract_features", lambda _: fake_features) + model.register_buffer("memory_bank", torch.empty(0, 64)) + + x = torch.randn(1, 3, 224, 224) + with pytest.raises(RuntimeError, match="Memory bank is empty"): + _ = model(x) + + @staticmethod + def 
test_compute_background_masks_runs() -> None: + """Test that background mask computation produces boolean masks.""" + b, h, w, d = 2, 8, 8, 16 + features = np.random.randn(b, h * w, d).astype(np.float32) # noqa: NPY002 + masks = AnomalyDINOModel.compute_background_masks(features, (h, w)) + assert masks.shape == (b, h * w) + assert masks.dtype == bool + + @staticmethod + def test_mean_top1p_computation() -> None: + """Test that mean_top1p returns expected shape and value.""" + distances = torch.arange(0, 100, dtype=torch.float32).view(1, -1) + result = AnomalyDINOModel.mean_top1p(distances) + assert result.shape == (1, 1) + assert torch.allclose(result, torch.tensor([[99.0]])) + + @staticmethod + def test_forward_half_precision_eval(monkeypatch: MonkeyPatch) -> None: + """Test inference in half precision (float16) using matmul cosine distance.""" + model = AnomalyDINOModel().half() + model.eval() + + fake_features = torch.randn(1, 16, 64, dtype=torch.float16) + monkeypatch.setattr(model, "extract_features", lambda _: fake_features) + monkeypatch.setattr(model.anomaly_map_generator, "__call__", lambda x, __: x) + + model.register_buffer("memory_bank", torch.randn(16, 64, dtype=torch.float16)) + x = torch.randn(1, 3, 224, 224, dtype=torch.float16) + out = model(x) + + assert hasattr(out, "pred_score") + assert out.pred_score.shape == (1, 1) + # outputs should be float16-safe with matmul + assert out.pred_score.dtype == torch.float16 From 4a4423f84b62b28140dbb37e37dca17fdf214c3f Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:52:16 +0900 Subject: [PATCH 10/25] add vit/dino implementation (no xformers). implement factory class for generating dinov2. update anomaly_dino to use factory method Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/components/dinov2/__init__.py | 31 + .../models/components/dinov2/dinov2_loader.py | 239 ++++++++ .../components/dinov2/layers/__init__.py | 32 ++ .../components/dinov2/layers/attention.py | 150 +++++ .../models/components/dinov2/layers/block.py | 243 ++++++++ .../components/dinov2/layers/dino_head.py | 124 ++++ .../components/dinov2/layers/drop_path.py | 62 ++ .../components/dinov2/layers/layer_scale.py | 53 ++ .../models/components/dinov2/layers/mlp.py | 60 ++ .../components/dinov2/layers/patch_embed.py | 110 ++++ .../components/dinov2/layers/swiglu_ffn.py | 146 +++++ .../components/dinov2/vision_transformer.py | 542 ++++++++++++++++++ .../models/image/anomaly_dino/torch_model.py | 8 +- 13 files changed, 1795 insertions(+), 5 deletions(-) create mode 100644 src/anomalib/models/components/dinov2/__init__.py create mode 100644 src/anomalib/models/components/dinov2/dinov2_loader.py create mode 100644 src/anomalib/models/components/dinov2/layers/__init__.py create mode 100644 src/anomalib/models/components/dinov2/layers/attention.py create mode 100644 src/anomalib/models/components/dinov2/layers/block.py create mode 100644 src/anomalib/models/components/dinov2/layers/dino_head.py create mode 100644 src/anomalib/models/components/dinov2/layers/drop_path.py create mode 100644 src/anomalib/models/components/dinov2/layers/layer_scale.py create mode 100644 src/anomalib/models/components/dinov2/layers/mlp.py create mode 100644 src/anomalib/models/components/dinov2/layers/patch_embed.py create mode 100644 src/anomalib/models/components/dinov2/layers/swiglu_ffn.py create mode 100644 src/anomalib/models/components/dinov2/vision_transformer.py diff 
--git a/src/anomalib/models/components/dinov2/__init__.py b/src/anomalib/models/components/dinov2/__init__.py new file mode 100644 index 0000000000..e9db6cf8f2 --- /dev/null +++ b/src/anomalib/models/components/dinov2/__init__.py @@ -0,0 +1,31 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +"""Anomalib's Vision Transformer implementation. + +References: +https://github.com/facebookresearch/dinov2/blob/main/dinov2/ +""" + +# vision transformer +# loader +from .dinov2_loader import DinoV2Loader +from .vision_transformer import ( + DinoVisionTransformer, + vit_base, + vit_giant2, + vit_large, + vit_small, +) + +__all__ = [ + # vision transformer + "DinoVisionTransformer", + "vit_base", + "vit_giant2", + "vit_large", + "vit_small", + # loader + "DinoV2Loader", +] diff --git a/src/anomalib/models/components/dinov2/dinov2_loader.py b/src/anomalib/models/components/dinov2/dinov2_loader.py new file mode 100644 index 0000000000..6c122bafff --- /dev/null +++ b/src/anomalib/models/components/dinov2/dinov2_loader.py @@ -0,0 +1,239 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Loader for DINOv2 Vision Transformer models. + +This module provides a simple interface for loading pre-trained DINOv2 Vision Transformer models for the +Dinomaly anomaly detection framework. + +Example: + model = DinoV2Loader.from_name("dinov2_vit_base_14") + model = DinoV2Loader.from_name("dinomaly_vit_base_14") +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import ClassVar +from urllib.request import urlretrieve + +import torch + +from anomalib.data.utils import DownloadInfo +from anomalib.data.utils.download import DownloadProgressBar +from anomalib.models.components.dinov2 import vision_transformer as dinov2_models +from anomalib.models.image.dinomaly.components import vision_transformer as dinomaly_models + +logger = logging.getLogger(__name__) + +MODEL_FACTORIES: dict[str, object] = { + "dinov2": dinov2_models, + "dinov2_reg": dinov2_models, + "dinomaly": dinomaly_models, +} + + +class DinoV2Loader: + """Simple loader for DINOv2 Vision Transformer models. + + Supports loading dinov2, dinov2_reg, and dinomaly model variants across small, base, + and large architectures. + """ + + DINOV2_BASE_URL: ClassVar[str] = "https://dl.fbaipublicfiles.com/dinov2" + + MODEL_CONFIGS: ClassVar[dict[str, dict[str, int]]] = { + "small": {"embed_dim": 384, "num_heads": 6}, + "base": {"embed_dim": 768, "num_heads": 12}, + "large": {"embed_dim": 1024, "num_heads": 16}, + } + + def __init__(self, cache_dir: str | Path = "./pre_trained/") -> None: + """Initialize a model loader instance. + + Args: + cache_dir: Directory in which downloaded weights will be stored. + """ + self.cache_dir: Path = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def load(self, model_name: str) -> torch.nn.Module: + """Load a DINOv2 model by name. + + Args: + model_name: Model identifier such as "dinov2_vit_base_14". + + Returns: + A fully constructed and weight-loaded PyTorch module. + + Raises: + ValueError: If the requested model name is malformed or unsupported. 
+ """ + model_type, architecture, patch_size = self._parse_name(model_name) + model = self._create_model(model_type, architecture, patch_size) + self._load_weights(model, model_type, architecture, patch_size) + + logger.info(f"Loaded model: {model_name}") + return model + + @classmethod + def from_name( + cls, + model_name: str, + cache_dir: str | Path = "./pre_trained/", + ) -> torch.nn.Module: + """Instantiate a loader and return the requested model.""" + loader = cls(cache_dir) + return loader.load(model_name) + + def _parse_name(self, name: str) -> tuple[str, str, int]: + """Parse a model name string into components. + + Args: + name: Full model name string. + + Returns: + Tuple of (model_type, architecture_name, patch_size). + + Raises: + ValueError: If the prefix or architecture is unknown. + """ + parts = name.split("_") + prefix = parts[0] + architecture = parts[-2] + patch_size = int(parts[-1]) + + if prefix == "dinov2reg": + model_type = "dinov2_reg" + elif prefix == "dinov2": + model_type = "dinov2" + elif prefix == "dinomaly": + model_type = "dinomaly" + else: + msg = f"Unknown model type prefix '{prefix}'." + raise ValueError(msg) + + if architecture not in self.MODEL_CONFIGS: + msg = f"Invalid architecture '{architecture}'. Expected one of: {list(self.MODEL_CONFIGS)}" + raise ValueError( + msg, + ) + + return model_type, architecture, patch_size + + @staticmethod + def _create_model( + model_type: str, + architecture: str, + patch_size: int, + ) -> torch.nn.Module: + """Construct a model instance using the configured factory modules. + + Args: + model_type: Model family, e.g., "dinov2", "dinov2_reg", "dinomaly". + architecture: Architecture label ("small", "base", "large"). + patch_size: Patch resolution. + + Returns: + An instantiated PyTorch module. + + Raises: + ValueError: If the relevant constructor cannot be found. + """ + model_kwargs: dict[str, object] = { + "patch_size": patch_size, + "img_size": 518, + "block_chunks": 0, + "init_values": 1e-8, + "interpolate_antialias": False, + "interpolate_offset": 0.1, + } + + if model_type == "dinov2_reg": + model_kwargs["num_register_tokens"] = 4 + + module = MODEL_FACTORIES.get(model_type) + if module is None: + msg = f"Unknown model type '{model_type}'." + raise ValueError(msg) + + ctor = getattr(module, f"vit_{architecture}", None) + if ctor is None: + msg = f"No constructor 'vit_{architecture}' in module {module}." 
+ raise ValueError(msg) + + model: torch.nn.Module = ctor(**model_kwargs) + return model + + def _load_weights( + self, + model: torch.nn.Module, + model_type: str, + architecture: str, + patch_size: int, + ) -> None: + """Load pre-trained weights from disk, downloading them if necessary.""" + weight_path = self._get_weight_path(model_type, architecture, patch_size) + + if not weight_path.exists(): + self._download_weights(model_type, architecture, patch_size) + + # Using weights_only=True for safety mitigation (see Anomalib PR #2729) + state_dict = torch.load(weight_path, map_location="cpu", weights_only=True) # nosec B614 + model.load_state_dict(state_dict, strict=False) + + def _get_weight_path( + self, + model_type: str, + architecture: str, + patch_size: int, + ) -> Path: + """Return the expected local path for downloaded weights.""" + arch_code = architecture[0] + + if model_type == "dinov2_reg": + filename = f"dinov2_vit{arch_code}{patch_size}_reg4_pretrain.pth" + else: + filename = f"dinov2_vit{arch_code}{patch_size}_pretrain.pth" + + return self.cache_dir / filename + + def _download_weights( + self, + model_type: str, + architecture: str, + patch_size: int, + ) -> None: + """Download DINOv2 weight files using Anomalib's standardized utilities.""" + weight_path = self._get_weight_path(model_type, architecture, patch_size) + arch_code = architecture[0] + + model_dir = f"dinov2_vit{arch_code}{patch_size}" + url = f"{self.DINOV2_BASE_URL}/{model_dir}/{weight_path.name}" + + download_info = DownloadInfo( + name=f"DINOv2 {model_type} {architecture} weights", + url=url, + hashsum="", # DINOv2 publishes no official hash + filename=weight_path.name, + ) + + logger.info( + f"Downloading DINOv2 weights: {weight_path.name} to {self.cache_dir}", + ) + + self.cache_dir.mkdir(parents=True, exist_ok=True) + + with DownloadProgressBar( + unit="B", + unit_scale=True, + miniters=1, + desc=download_info.name, + ) as progress_bar: + # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected # noqa: ERA001, E501 + urlretrieve( # noqa: S310 # nosec B310 + url=url, + filename=weight_path, + reporthook=progress_bar.update_to, + ) diff --git a/src/anomalib/models/components/dinov2/layers/__init__.py b/src/anomalib/models/components/dinov2/layers/__init__.py new file mode 100644 index 0000000000..e2c88d7aa2 --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/__init__.py @@ -0,0 +1,32 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Layers needed to build DINOv2. 
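For reference, a minimal usage sketch of the loader above (not part of the patch); it assumes network access on the first call, since the official weights are fetched into ``cache_dir`` before being loaded:

>>> import torch
>>> from anomalib.models.components.dinov2 import DinoV2Loader
>>> model = DinoV2Loader.from_name("dinov2_vit_small_14", cache_dir="./pre_trained/")
>>> model = model.eval()
>>> with torch.inference_mode():
...     tokens = model.get_intermediate_layers(torch.randn(1, 3, 224, 224), n=1)[0]
>>> tuple(tokens.shape)  # 224 / 14 = 16, so 16 * 16 = 256 patch tokens of dim 384
(1, 256, 384)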
+ +References: +https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/__init__.py +""" + +from .attention import Attention, MemEffAttention +from .block import Block, CausalAttentionBlock +from .dino_head import DINOHead +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNAligned, SwiGLUFFNFused + +__all__ = [ + "Attention", + "CausalAttentionBlock", + "Block", + "DINOHead", + "DropPath", + "LayerScale", + "MemEffAttention", + "Mlp", + "PatchEmbed", + "SwiGLUFFN", + "SwiGLUFFNAligned", + "SwiGLUFFNFused", +] diff --git a/src/anomalib/models/components/dinov2/layers/attention.py b/src/anomalib/models/components/dinov2/layers/attention.py new file mode 100644 index 0000000000..0f1ae6ef14 --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/attention.py @@ -0,0 +1,150 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Attention layers for DINOv2 Vision Transformers. + +This module provides: +- A standard multi-head self-attention implementation (`Attention`) +- A memory-efficient xFormers-based version (`MemEffAttention`) when xFormers is available + +These layers are used as core components within DINOv2 and Dinomaly transformer +blocks for feature extraction and masked modeling. +""" + +from __future__ import annotations + +import logging + +import torch +from torch import Tensor, nn +from torch.nn import functional as F # noqa: N812 + +logger = logging.getLogger(__name__) + + +class Attention(nn.Module): + """Standard multi-head self-attention layer. + + Implements a QKV-projection attention block with optional bias, dropout, and + projection layers. This is the default attention mechanism used in DINOv2 + when memory-efficient attention kernels are not available. + + Args: + dim: Embedding dimension. + num_heads: Number of attention heads. + qkv_bias: Whether to include bias in the QKV projections. + proj_bias: Whether to include bias in the output projection. + attn_drop: Dropout probability applied to attention weights. + proj_drop: Dropout probability applied after projection. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def init_weights( + self, + init_attn_std: float | None = None, + init_proj_std: float | None = None, + factor: float = 1.0, + ) -> None: + """Initialize QKV and projection weights. + + Args: + init_attn_std: Standard deviation for attention weights. + init_proj_std: Standard deviation for projection weights. + factor: Additional scaling factor for projection initialization. + """ + init_attn_std = init_attn_std or (self.dim**-0.5) + init_proj_std = init_proj_std or (init_attn_std * factor) + + nn.init.normal_(self.qkv.weight, std=init_attn_std) + nn.init.normal_(self.proj.weight, std=init_proj_std) + + if self.qkv.bias is not None: + nn.init.zeros_(self.qkv.bias) + if self.proj.bias is not None: + nn.init.zeros_(self.proj.bias) + + def forward(self, x: Tensor, is_causal: bool = False) -> Tensor: + """Apply multi-head self-attention. 
+ + Args: + x: Input sequence of shape ``(B, N, C)``. + is_causal: If True, applies causal masking. + + Returns: + Tensor of shape ``(B, N, C)`` containing attended features. + """ + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads) + q, k, v = torch.unbind(qkv, 2) + q, k, v = (t.transpose(1, 2) for t in (q, k, v)) + + x = nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=self.attn_drop if self.training else 0.0, + is_causal=is_causal, + ) + + x = x.transpose(1, 2).contiguous().view(b, n, c) + return self.proj_drop(self.proj(x)) + + +class MemEffAttention(Attention): + """Memory-efficient attention from the dinov2 implementation with a small change. + + Reference: + https://github.com/facebookresearch/dinov2/blob/592541c8d842042bb5ab29a49433f73b544522d5/dinov2/eval/segmentation_m2f/models/backbones/vit.py#L159 + + Instead of using xformers's memory_efficient_attention() method, which requires adding a new dependency to anomalib, + this implementation uses the scaled dot product from torch. + """ + + def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor: + """Compute memory-efficient attention using PyTorch's scaled dot product attention. + + Args: + x: Input tensor of shape (batch_size, seq_len, embed_dim). + attn_bias: Optional attention bias mask. Default: None. + + Returns: + Output tensor of shape (batch_size, seq_len, embed_dim). + """ + batch_size, seq_len, embed_dim = x.shape + qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, embed_dim // self.num_heads) + + q, k, v = qkv.unbind(2) + + # Use PyTorch's native scaled dot product attention for memory efficiency. + # Replaced xformers's memory_efficient_attention() method with pytorch's scaled + # dot product. + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=attn_bias, + ) + x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + + x = self.proj(x) + return self.proj_drop(x) diff --git a/src/anomalib/models/components/dinov2/layers/block.py b/src/anomalib/models/components/dinov2/layers/block.py new file mode 100644 index 0000000000..1afe3512de --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/block.py @@ -0,0 +1,243 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Transformer blocks used in DINOv2 Vision Transformers. + +This module implements: +- Standard transformer blocks with attention and MLP (`Block`) +- Causal attention blocks (`CausalAttentionBlock`) + +The implementation is adapted from the original DINO and timm Vision +Transformer code: + +- https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +- https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from torch import Tensor, nn + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +if TYPE_CHECKING: + from collections.abc import Callable + +logger = logging.getLogger("dinov2") + + +class Block(nn.Module): + """Standard transformer block with attention and MLP. + + This block applies layer normalization, multi-head self-attention, optional + layer scaling, stochastic depth, and a feed-forward network. + + Args: + dim: Embedding dimension. + num_heads: Number of attention heads. 
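As a quick shape check (illustrative only, with arbitrary sizes), both attention variants map ``(B, N, C)`` to ``(B, N, C)`` and can be used interchangeably inside a block:

>>> import torch
>>> from anomalib.models.components.dinov2.layers import Attention, MemEffAttention
>>> x = torch.randn(2, 257, 384)                    # (batch, tokens, channels)
>>> attn = Attention(dim=384, num_heads=6, qkv_bias=True)
>>> mem_attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
>>> attn(x).shape == mem_attn(x).shape == x.shape   # both are shape-preserving
True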
+ mlp_ratio: Expansion ratio for the MLP hidden dimension. + qkv_bias: Whether to use bias in the QKV projections. + proj_bias: Whether to use bias in the attention projection. + ffn_bias: Whether to use bias in the MLP. + drop: Dropout probability applied to projections and MLP. + attn_drop: Dropout probability applied to attention weights. + init_values: Initial value for LayerScale. If ``None``, LayerScale is disabled. + drop_path: Stochastic depth rate. + act_layer: Activation layer factory for the MLP. + norm_layer: Normalization layer factory. + attn_class: Attention layer factory. + ffn_layer: Feed-forward layer factory. + """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values: float | Tensor | None = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + + self.norm1: nn.Module = norm_layer(dim) + self.attn: nn.Module = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1: nn.Module = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1: nn.Module = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2: nn.Module = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp: nn.Module = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2: nn.Module = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2: nn.Module = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio: float = drop_path + + def forward(self, x: Tensor) -> Tensor: + """Apply attention and MLP residual blocks with optional stochastic depth.""" + + def attn_residual_func(inp: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(inp))) + + def ffn_residual_func(inp: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(inp))) + + if self.training and self.sample_drop_ratio > 0.1: + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 # noqa: FIX001, TD001, TD002, TD003 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +class CausalAttentionBlock(nn.Module): + """Transformer block with causal attention. + + This block applies causal self-attention followed by a feed-forward network, + with optional LayerScale and dropout. + + Args: + dim: Embedding dimension. + num_heads: Number of attention heads. + ffn_ratio: Expansion ratio for the feed-forward network. + ls_init_value: Initial value for LayerScale. If ``None``, LayerScale is disabled. + is_causal: Whether to apply causal masking. + act_layer: Activation layer factory for the MLP. + norm_layer: Normalization layer factory. 
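A small sketch of how ``Block`` composes these pieces (the dimensions are arbitrary, chosen to match the ViT-S configuration):

>>> import torch
>>> from anomalib.models.components.dinov2.layers import Block, MemEffAttention
>>> block = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True,
...               init_values=1e-5, attn_class=MemEffAttention)
>>> tokens = torch.randn(2, 257, 384)
>>> block(tokens).shape                             # pre-norm attention and MLP residuals
torch.Size([2, 257, 384])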
+ dropout_prob: Dropout probability applied to attention and MLP. + """ + + def __init__( + self, + dim: int, + num_heads: int, + ffn_ratio: float = 4.0, + ls_init_value: float | None = None, + is_causal: bool = True, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + dropout_prob: float = 0.0, + ) -> None: + super().__init__() + + self.dim: int = dim + self.is_causal: bool = is_causal + self.ls1: nn.Module = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() + self.attention_norm: nn.Module = norm_layer(dim) + self.attention: Attention = Attention( + dim, + num_heads, + attn_drop=dropout_prob, + proj_drop=dropout_prob, + ) + + self.ffn_norm: nn.Module = norm_layer(dim) + ffn_hidden_dim = int(dim * ffn_ratio) + self.feed_forward: Mlp = Mlp( + in_features=dim, + hidden_features=ffn_hidden_dim, + drop=dropout_prob, + act_layer=act_layer, + ) + + self.ls2: nn.Module = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity() + + def init_weights( + self, + init_attn_std: float | None = None, + init_proj_std: float | None = None, + init_fc_std: float | None = None, + factor: float = 1.0, + ) -> None: + """Initialize attention and MLP weights.""" + init_attn_std = init_attn_std or (self.dim**-0.5) + init_proj_std = init_proj_std or (init_attn_std * factor) + init_fc_std = init_fc_std or (2 * self.dim) ** -0.5 + + self.attention.init_weights(init_attn_std, init_proj_std) + self.attention_norm.reset_parameters() + nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std) + nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std) + self.ffn_norm.reset_parameters() + + def forward(self, x: Tensor) -> Tensor: + """Apply causal attention followed by a feed-forward block.""" + x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal)) + return x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn))) + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + """Apply stochastic depth to a residual branch on a subset of samples. + + Args: + x: Input tensor of shape ``(B, N, C)``. + residual_func: Function computing the residual on a subset of samples. + sample_drop_ratio: Fraction of samples to drop for residual computation. + + Returns: + Tensor with residual added to a subset of samples. + """ + b, _, _ = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = torch.randperm(b, device=x.device)[:sample_subset_size] + x_subset = x[brange] + + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual_flat = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + x_plus_residual = torch.index_add( + x_flat, + 0, + brange, + residual_flat.to(dtype=x.dtype), + alpha=residual_scale_factor, + ) + return x_plus_residual.view_as(x) diff --git a/src/anomalib/models/components/dinov2/layers/dino_head.py b/src/anomalib/models/components/dinov2/layers/dino_head.py new file mode 100644 index 0000000000..811195053b --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/dino_head.py @@ -0,0 +1,124 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""DINO projection head module. 
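``CausalAttentionBlock`` follows the same residual pattern but masks future positions; a minimal sketch with arbitrary sizes:

>>> import torch
>>> from anomalib.models.components.dinov2.layers import CausalAttentionBlock
>>> blk = CausalAttentionBlock(dim=256, num_heads=8, ls_init_value=1e-5)
>>> seq = torch.randn(2, 16, 256)    # (batch, sequence length, channels)
>>> blk(seq).shape                   # causal self-attention followed by the FFN
torch.Size([2, 16, 256])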
+ +Reference: +https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/dino_head.py +""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + """Projection head used in DINO and DINOv2. + + This module applies a multi-layer perceptron (MLP) followed by weight-normalized + output projection, matching the design used in the official DINOv2 models. + + Args: + in_dim: Input embedding dimension. + out_dim: Output projection dimension. + use_bn: Whether to insert BatchNorm1d layers in the MLP. + nlayers: Number of MLP layers. + hidden_dim: Hidden layer size for intermediate MLP layers. + bottleneck_dim: Dimension of the final MLP output before projection. + mlp_bias: Whether to use bias in Linear layers. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + use_bn: bool = False, + nlayers: int = 3, + hidden_dim: int = 2048, + bottleneck_dim: int = 256, + mlp_bias: bool = True, + ) -> None: + super().__init__() + + nlayers = max(nlayers, 1) + + self.mlp: nn.Module = _build_mlp( + nlayers=nlayers, + in_dim=in_dim, + bottleneck_dim=bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + + self.apply(self._init_weights) + + self.last_layer: nn.Module = weight_norm( + nn.Linear(bottleneck_dim, out_dim, bias=False), + ) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, module: nn.Module) -> None: # noqa: PLR6301 + """Initialize Linear layers with truncated normal weights.""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + """Run the DINO projection head forward pass.""" + x = self.mlp(x) + + eps: float = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + + return self.last_layer(x) + + +def _build_mlp( + nlayers: int, + in_dim: int, + bottleneck_dim: int, + hidden_dim: int | None = None, + use_bn: bool = False, + bias: bool = True, +) -> nn.Module: + """Construct an MLP with optional batch normalization. + + Args: + nlayers: Number of layers in the MLP. + in_dim: Input feature dimension. + bottleneck_dim: Output dimension of the final layer. + hidden_dim: Hidden dimension for intermediate layers. + use_bn: Whether to insert BatchNorm1d layers. + bias: Whether to enable Linear layer bias. + + Returns: + A fully constructed torch.nn.Module representing the MLP. 
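A shape-level sketch of the projection head; 65536 prototypes is the value commonly used in DINO training and serves only as an illustration here:

>>> import torch
>>> from anomalib.models.components.dinov2.layers import DINOHead
>>> head = DINOHead(in_dim=384, out_dim=65536, nlayers=3, hidden_dim=2048, bottleneck_dim=256)
>>> feats = torch.randn(8, 384)      # e.g. class tokens from a ViT-S backbone
>>> head(feats).shape                # MLP -> L2 normalisation -> weight-normalised projection
torch.Size([8, 65536])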
+ """ + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + + assert hidden_dim is not None, "hidden_dim must be provided when nlayers > 1" + + layers: list[nn.Module] = [ + nn.Linear(in_dim, hidden_dim, bias=bias), + ] + + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + + layers.append(nn.GELU()) + + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + + return nn.Sequential(*layers) diff --git a/src/anomalib/models/components/dinov2/layers/drop_path.py b/src/anomalib/models/components/dinov2/layers/drop_path.py new file mode 100644 index 0000000000..def3884e82 --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/drop_path.py @@ -0,0 +1,62 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Stochastic depth drop-path implementation used in DINOv2. + +This module provides a functional drop-path operation and a corresponding +nn.Module wrapper. Drop-path (also known as stochastic depth) randomly +drops entire residual branches during training to improve model robustness. +""" + +from __future__ import annotations + +from torch import Tensor, nn + + +def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False) -> Tensor: + """Apply stochastic depth to an input tensor. + + Args: + x: Input tensor to process. + drop_prob: Probability of dropping the path. + training: Whether the module is in training mode. + + Returns: + Tensor with dropped paths applied during training, or the original + tensor during evaluation. + + Notes: + Drop-path randomly zeroes the entire residual branch for each sample + in the batch while scaling the remaining samples appropriately. + """ + if drop_prob == 0.0 or not training: + return x + + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + + return x * random_tensor + + +class DropPath(nn.Module): + """Stochastic depth module for residual blocks. + + Applies drop-path per sample. During training, residual branches are + randomly removed with probability ``drop_prob`` while scaling the + remaining paths. In evaluation mode, the module becomes a no-op. + + Args: + drop_prob: Probability of dropping a path. + """ + + def __init__(self, drop_prob: float | None = None) -> None: + super().__init__() + self.drop_prob = drop_prob if drop_prob is not None else 0.0 + + def forward(self, x: Tensor) -> Tensor: + """Forward pass applying stochastic depth.""" + return drop_path(x, drop_prob=self.drop_prob, training=self.training) diff --git a/src/anomalib/models/components/dinov2/layers/layer_scale.py b/src/anomalib/models/components/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000..a64e57257e --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/layer_scale.py @@ -0,0 +1,53 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""LayerScale module used in DINOv2. + +LayerScale applies a learnable per-channel scaling parameter (gamma) +to stabilize deep transformer training. It is frequently used in +Vision Transformers with residual connections. 
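Drop-path is only active in training mode; a minimal check:

>>> import torch
>>> from anomalib.models.components.dinov2.layers import DropPath
>>> dp = DropPath(drop_prob=0.2)     # modules are created in training mode
>>> x = torch.randn(4, 257, 384)
>>> y = dp(x)                        # roughly 20% of samples zeroed, the rest scaled by 1/0.8
>>> dp = dp.eval()
>>> torch.equal(dp(x), x)            # identity at inference time
True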
+""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn + + +class LayerScale(nn.Module): + """Learnable per-channel scaling factor. + + This module introduces a learnable scale parameter ``gamma`` applied + to the input tensor. It is commonly used in modern transformer + architectures to improve optimization stability. + + Args: + dim: Number of feature channels. + init_values: Initial value for the scale parameter; may be a float + or a tensor of shape ``(dim,)``. + inplace: Whether to apply the scaling operation in-place. + device: Optional torch device for parameter initialization. + dtype: Optional torch dtype for parameter initialization. + """ + + def __init__( + self, + dim: int, + init_values: float | Tensor = 1e-5, + inplace: bool = False, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset scale parameters to their initialization values.""" + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + """Apply channel-wise scaling to the input tensor.""" + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/anomalib/models/components/dinov2/layers/mlp.py b/src/anomalib/models/components/dinov2/layers/mlp.py new file mode 100644 index 0000000000..9dc80253a5 --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/mlp.py @@ -0,0 +1,60 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Feed-forward MLP block used in DINOv2 Vision Transformers. + +This module implements the standard 2-layer transformer MLP block with an +activation function and dropout. It is used as the feed-forward component +inside each transformer block. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torch import Tensor, nn + +if TYPE_CHECKING: + from collections.abc import Callable + + +class Mlp(nn.Module): + """Two-layer feed-forward MLP used inside transformer blocks. + + Args: + in_features: Input feature dimension. + hidden_features: Dimension of the hidden expansion layer. Defaults to + ``in_features`` when ``None``. + out_features: Output feature dimension. Defaults to ``in_features`` when + ``None``. + act_layer: Activation layer constructor. + drop: Dropout probability applied after each layer. + bias: Whether linear layers use bias terms. 
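A quick illustration of the per-channel scaling (sizes and values are arbitrary):

>>> import torch
>>> from anomalib.models.components.dinov2.layers import LayerScale
>>> ls = LayerScale(dim=384, init_values=1e-5)
>>> x = torch.randn(2, 257, 384)
>>> torch.allclose(ls(x), x * 1e-5)  # gamma starts as a constant 1e-5 per channel
True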
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + """Apply the two-layer feed-forward transformation.""" + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) diff --git a/src/anomalib/models/components/dinov2/layers/patch_embed.py b/src/anomalib/models/components/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000..1e39b25b8b --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/patch_embed.py @@ -0,0 +1,110 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Patch embedding module for DINOv2 Vision Transformers. + +This module converts an image into a grid of patch embeddings using a strided +convolution. It supports square or rectangular image sizes and patch sizes, +optional output reshaping, and optional normalization. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torch import Tensor, nn + +if TYPE_CHECKING: + from collections.abc import Callable + + +def make_2tuple(x: int | tuple[int, int]) -> tuple[int, int]: + """Ensure a value is represented as a 2-tuple. + + Args: + x: Integer or tuple representing height/width. + + Returns: + A tuple ``(h, w)``. + """ + if isinstance(x, tuple): + assert len(x) == 2 + return x + return (x, x) + + +class PatchEmbed(nn.Module): + """Image-to-patch embedding layer. + + Converts a 2D image tensor of shape ``(B, C, H, W)`` into a sequence of + flattened patch embeddings of shape ``(B, N, D)`` where ``N`` is the number + of patches and ``D`` is the embedding dimension. + + Args: + img_size: Input image size (integer or ``(H, W)``). + patch_size: Patch dimensions (integer or ``(H_p, W_p)``). + in_chans: Number of input channels. + embed_dim: Output embedding dimension. + norm_layer: Optional normalization layer constructor. + flatten_embedding: Whether to flatten to ``(B, N, D)`` (True) or return + ``(B, H_p, W_p, D)`` (False). 
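For completeness, a tiny usage sketch of the feed-forward block:

>>> import torch
>>> from anomalib.models.components.dinov2.layers import Mlp
>>> mlp = Mlp(in_features=384, hidden_features=1536)  # Linear -> GELU -> Dropout -> Linear -> Dropout
>>> mlp(torch.randn(2, 257, 384)).shape
torch.Size([2, 257, 384])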
+ """ + + def __init__( + self, + img_size: int | tuple[int, int] = 224, + patch_size: int | tuple[int, int] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Callable[[int], nn.Module] | None = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_hw = make_2tuple(img_size) + patch_hw = make_2tuple(patch_size) + + grid_h = image_hw[0] // patch_hw[0] + grid_w = image_hw[1] // patch_hw[1] + + self.img_size = image_hw + self.patch_size = patch_hw + self.patches_resolution = (grid_h, grid_w) + self.num_patches = grid_h * grid_w + + self.in_chans = in_chans + self.embed_dim = embed_dim + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw) + self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Embed the input image into patch tokens.""" + _, _, h, w = x.shape + patch_h, patch_w = self.patch_size + + if h % patch_h != 0: + msg = f"Input image height {h} must be divisible by patch height {patch_h}" + raise AssertionError(msg) + if w % patch_w != 0: + msg = f"Input image width {w} must be divisible by patch width {patch_w}" + raise AssertionError(msg) + + x = self.proj(x) # (B, D, H', W') + h_out, w_out = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) # (B, N, D) + x = self.norm(x) + + if not self.flatten_embedding: + x = x.reshape(-1, h_out, w_out, self.embed_dim) # (B, H', W', D) + + return x + + def flops(self) -> float: + """Compute FLOPs for the patch embedding layer.""" + grid_h, grid_w = self.patches_resolution + flops = grid_h * grid_w * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops += grid_h * grid_w * self.embed_dim # normalization cost + return float(flops) diff --git a/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py b/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000..1866574f74 --- /dev/null +++ b/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,146 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""SwiGLU-based feed-forward layers used in DINOv2. + +This module provides multiple variants of SwiGLU feed-forward networks, +including: +- A pure PyTorch implementation (`SwiGLUFFN`) +- A fused xFormers version when available (`SwiGLUFFNFused`) +- An aligned variant for memory efficiency (`SwiGLUFFNAligned`) + +These layers are used as transformer FFN blocks in DINOv2 models. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn + +if TYPE_CHECKING: + from collections.abc import Callable + + +class SwiGLUFFN(nn.Module): + """Pure PyTorch SwiGLU feed-forward network. + + This network computes: + hidden = silu(W1(x)) * W2(x) + output = W3(hidden) + + Args: + in_features: Input feature dimension. + hidden_features: Hidden layer dimension (defaults to ``in_features``). + out_features: Output feature dimension (defaults to ``in_features``). + act_layer: Unused placeholder to mimic MLP API. + drop: Unused dropout placeholder for API compatibility. + bias: Whether to use bias in linear layers. 
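A minimal sketch of the patch embedding on a 224-pixel input (height and width must be divisible by the patch size):

>>> import torch
>>> from anomalib.models.components.dinov2.layers import PatchEmbed
>>> embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=384)
>>> image = torch.randn(1, 3, 224, 224)
>>> tuple(embed(image).shape)        # 16 x 16 patches flattened to 256 tokens of dim 384
(1, 256, 384)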
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] | None = None, # noqa: ARG002 + drop: float = 0.0, # noqa: ARG002 + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + """Apply the SwiGLU feed-forward transformation.""" + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class SwiGLUFFNFused(SwiGLUFFN): + """Fused SwiGLU FFN using xFormers when available. + + This implementation reduces memory usage by aligning hidden dimensions + and delegating the SwiGLU computation to optimized xFormers kernels. + + Args: + in_features: Input feature dimension. + hidden_features: Hidden layer dimension (defaults to ``in_features``). + out_features: Output feature dimension (defaults to ``in_features``). + act_layer: Unused placeholder for API compatibility. + drop: Unused dropout placeholder. + bias: Whether linear layers use bias. + """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] | None = None, # noqa: ARG002 + drop: float = 0.0, # noqa: ARG002 + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # Align hidden dimension for fused kernels + hidden_aligned = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + super().__init__( + in_features=in_features, + hidden_features=hidden_aligned, + out_features=out_features, + bias=bias, + ) + + +class SwiGLUFFNAligned(nn.Module): + """SwiGLU FFN with explicit alignment for hardware efficiency. + + Args: + in_features: Input feature dimension. + hidden_features: Hidden layer dimension (defaults to ``in_features``). + out_features: Output feature dimension (defaults to ``in_features``). + act_layer: Activation layer (unused; API compatibility). + drop: Dropout (unused). + bias: Whether linear layers use bias. + align_to: Alignment multiple for hidden dimension. + device: Optional device for parameter initialization. 
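Note that, despite its name, ``SwiGLUFFNFused`` keeps the pure PyTorch forward pass of ``SwiGLUFFN`` and only shrinks the requested hidden width to 2/3 and rounds it up to a multiple of 8. A small sketch:

>>> import torch
>>> from anomalib.models.components.dinov2.layers import SwiGLUFFN, SwiGLUFFNFused
>>> ffn = SwiGLUFFN(in_features=384, hidden_features=1536)
>>> ffn(torch.randn(2, 257, 384)).shape     # silu(W1 x) * W2 x, then W3
torch.Size([2, 257, 384])
>>> fused = SwiGLUFFNFused(in_features=384, hidden_features=1536)
>>> fused.w12.out_features                  # 2 * ((1536 * 2 // 3 + 7) // 8 * 8) = 2 * 1024
2048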
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] = nn.GELU, # noqa: ARG002 + drop: float = 0.0, # noqa: ARG002 + bias: bool = True, + align_to: int = 8, + device=None, # noqa: ANN001 + ) -> None: + super().__init__() + + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + d = int(hidden_features * 2 / 3) + hidden_aligned = d + (-d % align_to) + + self.w1 = nn.Linear(in_features, hidden_aligned, bias=bias, device=device) + self.w2 = nn.Linear(in_features, hidden_aligned, bias=bias, device=device) + self.w3 = nn.Linear(hidden_aligned, out_features, bias=bias, device=device) + + def forward(self, x: Tensor) -> Tensor: + """Apply aligned SwiGLU feed-forward transformation.""" + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) diff --git a/src/anomalib/models/components/dinov2/vision_transformer.py b/src/anomalib/models/components/dinov2/vision_transformer.py new file mode 100644 index 0000000000..836eadac0a --- /dev/null +++ b/src/anomalib/models/components/dinov2/vision_transformer.py @@ -0,0 +1,542 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Loader for DINOv2 Vision Transformer models. + +This module provides a PyTorch implementation of the DINOv2 Vision Transformer +architecture. It includes utilities for building transformers of different sizes, +handling positional embeddings, preparing tokens with masking, extracting intermediate +layer outputs, and applying initialization schemes compatible with the timm library. + +The module forms the backbone for DINO-based feature extraction used in Dinomaly +and related anomaly detection frameworks. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import TYPE_CHECKING + +import numpy as np +import torch +from torch import nn +from torch.nn.init import trunc_normal_ + +from anomalib.models.components.dinov2.layers import Block, MemEffAttention, Mlp, PatchEmbed, SwiGLUFFNFused + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + +def named_apply( + fn: Callable[..., object], + module: nn.Module, + name: str = "", + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + """Recursively apply a function to a module and all its children. + + Args: + fn: Callable applied to each visited module. + module: Module to traverse. + name: Base name for hierarchical module naming. + depth_first: If True, apply the function after visiting children. + include_root: If True, apply function to the root module itself. + + Returns: + The modified module. + """ + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + named_apply( + fn=fn, + module=child_module, + name=full_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + """Container for sequential execution of transformer blocks. + + This utility groups multiple transformer blocks into a single module list + to improve processing efficiency, particularly in distributed or chunked + execution settings. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply all blocks in the chunk sequentially. 
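``SwiGLUFFNAligned`` applies the same 2/3 reduction but pads the hidden width up to a multiple of ``align_to``; with an odd request such as 1000 hidden features (an arbitrary value used only for illustration):

>>> import torch
>>> from anomalib.models.components.dinov2.layers import SwiGLUFFNAligned
>>> ffn = SwiGLUFFNAligned(in_features=384, hidden_features=1000, align_to=8)
>>> ffn.w1.out_features              # int(1000 * 2 / 3) = 666, padded up to 672
672
>>> ffn(torch.randn(2, 257, 384)).shape
torch.Size([2, 257, 384])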
+ + Args: + x: Input tensor. + + Returns: + Tensor output after sequential block processing. + """ + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + """Vision Transformer backbone used in DINOv2. + + This class implements the complete forward pipeline for DINOv2-style Vision + Transformers, including patch embedding, positional encoding, transformer blocks, + register token handling, intermediate-layer extraction, and output normalization. + + The architecture supports: + - Optional register tokens. + - Chunked transformer blocks for FSDP or memory-efficient training. + - Mask tokens for masked feature modeling. + - Interpolated positional encodings for variable-sized input images. + - Flexible FFN selection (MLP, SwiGLU, fused SwiGLU, identity). + + Args: + img_size: Input image resolution. + patch_size: Patch size for patch embedding. + in_chans: Number of input channels. + embed_dim: Embedding dimensionality. + depth: Number of transformer layers. + num_heads: Number of self-attention heads. + mlp_ratio: Expansion ratio in feed-forward layers. + qkv_bias: Whether to include bias in QKV projections. + ffn_bias: Whether to include bias in FFN layers. + proj_bias: Whether to include bias in projection layers. + drop_path_rate: Stochastic depth rate. + drop_path_uniform: Whether to apply uniform drop-path. + init_values: Initial values for layer-scale (None disables). + embed_layer: Patch embedding layer class. + act_layer: Activation function. + block_fn: Transformer block class. + ffn_layer: Feed-forward layer type or constructor. + block_chunks: Number of chunks to split block sequence into. + num_register_tokens: Number of extra learned register tokens. + interpolate_antialias: Whether to apply antialiasing on position interpolation. + interpolate_offset: Offset to avoid floating point interpolation artifacts. 
+ """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + ffn_bias: bool = True, + proj_bias: bool = True, + drop_path_rate: float = 0.0, + drop_path_uniform: bool = False, + init_values: float | None = None, + embed_layer: type[nn.Module] = PatchEmbed, + act_layer: type[nn.Module] = nn.GELU, + block_fn: Callable[..., nn.Module] = Block, + ffn_layer: str | Callable[..., nn.Module] = "mlp", + block_chunks: int = 1, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + ) -> None: + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features: int = embed_dim + self.embed_dim: int = embed_dim + self.num_tokens: int = 1 + self.n_blocks: int = depth + self.num_heads: int = num_heads + self.patch_size: int = patch_size + self.num_register_tokens: int = num_register_tokens + self.interpolate_antialias: bool = interpolate_antialias + self.interpolate_offset: float = interpolate_offset + + self.patch_embed: nn.Module = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token: nn.Parameter = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed: nn.Parameter = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim), + ) + self.register_tokens: nn.Parameter | None = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + dpr = [drop_path_rate] * depth if drop_path_uniform else np.linspace(0, drop_path_rate, depth).tolist() + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer in {"swiglu", "swiglufused"}: + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + + def f(*args: object, **kwargs: object) -> nn.Identity: # noqa: ARG001 + return nn.Identity() + + ffn_layer = f + elif isinstance(ffn_layer, str): + raise NotImplementedError + # else assume callable + + blocks_list: list[nn.Module] = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + + if block_chunks > 0: + self.chunked_blocks: bool = True + chunksize = depth // block_chunks + chunked_blocks: list[list[nn.Module]] = [ + [nn.Identity()] * i + blocks_list[i : i + chunksize] for i in range(0, depth, chunksize) + ] + self.blocks: nn.ModuleList = nn.ModuleList(BlockChunk(p) for p in chunked_blocks) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm: nn.LayerNorm = norm_layer(embed_dim) + self.head: nn.Module = nn.Identity() + self.mask_token: nn.Parameter = nn.Parameter(torch.zeros(1, embed_dim)) + self.init_weights() + + def init_weights(self) -> None: + """Initialize model weights, positional embeddings, and register tokens.""" + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding( + self, + x: torch.Tensor, + w: int, + h: int, + ) -> torch.Tensor: + """Interpolate positional encodings for inputs whose resolution differs from training 
time.""" + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + n_pos = self.pos_embed.shape[1] - 1 + if npatch == n_pos and w == h: + return self.pos_embed + + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + + w0 = w // self.patch_size + h0 = h // self.patch_size + m = int(math.sqrt(n_pos)) + assert n_pos == m * m + + kwargs: dict[str, object] = {} + if self.interpolate_offset: + sx = float(w0 + self.interpolate_offset) / m + sy = float(h0 + self.interpolate_offset) / m + kwargs["scale_factor"] = (sx, sy) + else: + kwargs["size"] = (w0, h0) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks( + self, + x: torch.Tensor, + masks: torch.Tensor | None = None, + ) -> torch.Tensor: + """Prepare input tokens with optional masking and positional encoding.""" + _, _, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), + self.mask_token.to(x.dtype).unsqueeze(0), + x, + ) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + return x + + def forward_features_list( + self, + x_list: list[torch.Tensor], + masks_list: list[torch.Tensor | None], + ) -> list[dict[str, torch.Tensor | None]]: + """Compute forward features for a list of inputs with corresponding masks.""" + x: list[torch.Tensor] = [ + self.prepare_tokens_with_masks(x_item, masks) for x_item, masks in zip(x_list, masks_list, strict=True) + ] + for blk in self.blocks: + x = blk(x) + + output: list[dict[str, torch.Tensor | None]] = [] + for x_item, masks in zip(x, masks_list, strict=True): + x_norm = self.norm(x_item) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x_item, + "masks": masks, + }, + ) + return output + + def forward_features( + self, + x: torch.Tensor | list[torch.Tensor], + masks: torch.Tensor | list[torch.Tensor | None] | None = None, + ) -> dict[str, torch.Tensor | None] | list[dict[str, torch.Tensor | None]]: + """Compute forward features for a single batch or list of batches.""" + if isinstance(x, list): + masks_list: list[torch.Tensor | None] + if masks is None: + masks_list = [None] * len(x) + elif isinstance(masks, list): + masks_list = masks + else: + masks_list = [masks for _ in x] + return self.forward_features_list(x, masks_list) + + features = self.prepare_tokens_with_masks( + x, + masks if isinstance(masks, torch.Tensor) else None, + ) + + for blk in self.blocks: + features = blk(features) + + x_norm = self.norm(features) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": features, + "masks": masks if isinstance(masks, torch.Tensor) else None, + } + + def _get_intermediate_layers_not_chunked( + self, + x: torch.Tensor, + n: int | 
Sequence[int] = 1, + ) -> list[torch.Tensor]: + """Extract intermediate outputs from specific layers when blocks are not chunked.""" + x = self.prepare_tokens_with_masks(x) + output: list[torch.Tensor] = [] + total_block_len = len(self.blocks) + blocks_to_take: range | Sequence[int] + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + return output + + def _get_intermediate_layers_chunked( + self, + x: torch.Tensor, + n: int | Sequence[int] = 1, + ) -> list[torch.Tensor]: + """Extract intermediate outputs from specific layers when blocks are chunked.""" + x = self.prepare_tokens_with_masks(x) + output: list[torch.Tensor] = [] + i = 0 + total_block_len = len(self.blocks[-1]) + blocks_to_take: range | Sequence[int] + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: int | Sequence[int] = 1, + reshape: bool = False, + return_class_token: bool = False, + norm: bool = True, + ) -> tuple[torch.Tensor, ...] | tuple[tuple[torch.Tensor, torch.Tensor], ...]: + """Retrieve intermediate layer outputs. + + Args: + x: Input tensor. + n: Number of layers or explicit list of layer indices. + reshape: Whether to reshape patch tokens into feature maps. + return_class_token: Whether to include class tokens in the output. + norm: Whether to apply final normalization. + + Returns: + Tuple of intermediate outputs, optionally paired with class tokens. + """ + outputs = ( + self._get_intermediate_layers_chunked(x, n) + if self.chunked_blocks + else self._get_intermediate_layers_not_chunked(x, n) + ) + + if norm: + outputs = [self.norm(out) for out in outputs] + + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + + if reshape: + batch_size, _, w, h = x.shape + outputs = [ + out.reshape(batch_size, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens, strict=True)) + return tuple(outputs) + + def forward( + self, + *args: object, + is_training: bool = False, + **kwargs: object, + ) -> dict[str, torch.Tensor | None] | torch.Tensor: + """Apply the forward pass, returning classification output or full features.""" + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + if isinstance(ret, list): + msg = "forward() received list output in inference mode" + raise TypeError(msg) + # inference: ret is a dict for non-list input + return self.head(ret["x_norm_clstoken"]) # type: ignore[misc] + + +def init_weights_vit_timm( + module: nn.Module, + name: str = "", # noqa: ARG001 +) -> None: + """Initialize module weights following the timm ViT initialization scheme.""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small( + patch_size: int = 16, + num_register_tokens: int = 0, + **kwargs, +) -> DinoVisionTransformer: + """Construct a small DINO Vision Transformer (ViT-S/16).""" + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, 
attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + +def vit_base( + patch_size: int = 16, + num_register_tokens: int = 0, + **kwargs, +) -> DinoVisionTransformer: + """Construct a base DINO Vision Transformer (ViT-B/16).""" + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + +def vit_large( + patch_size: int = 16, + num_register_tokens: int = 0, + **kwargs, +) -> DinoVisionTransformer: + """Construct a large DINO Vision Transformer (ViT-L/16).""" + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + +def vit_giant2( + patch_size: int = 16, + num_register_tokens: int = 0, + **kwargs, +) -> DinoVisionTransformer: + """Construct a Giant-2 DINO Vision Transformer variant.""" + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) diff --git a/src/anomalib/models/image/anomaly_dino/torch_model.py b/src/anomalib/models/image/anomaly_dino/torch_model.py index 93e035d81c..5cc81826bf 100644 --- a/src/anomalib/models/image/anomaly_dino/torch_model.py +++ b/src/anomalib/models/image/anomaly_dino/torch_model.py @@ -28,7 +28,7 @@ from anomalib.data import InferenceBatch from anomalib.models.components import DynamicBufferMixin, KCenterGreedy -from anomalib.models.image.dinomaly.components import load as load_dinov2_model +from anomalib.models.components.dinov2 import DinoV2Loader from anomalib.models.image.patchcore.anomaly_map import AnomalyMapGenerator @@ -78,7 +78,7 @@ def __init__( if not encoder_name.startswith("dinov2"): err_str = f"Encoder must be dinov2, got {encoder_name}" raise ValueError(err_str) - self.feature_encoder = load_dinov2_model(self.encoder_name) + self.feature_encoder = DinoV2Loader.from_name(self.encoder_name) self.feature_encoder.eval() # Memory bank and embedding storage @@ -124,9 +124,7 @@ def extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor: where ``N`` is the number of patches and ``D`` the feature dimension. 
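A randomly initialised sanity check of the backbone API used by AnomalyDINO and Dinomaly (pretrained weights would normally come from ``DinoV2Loader``); the keyword arguments below mirror those the loader passes:

>>> import torch
>>> from anomalib.models.components.dinov2 import vit_small
>>> vit = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1e-8)
>>> image = torch.randn(1, 3, 224, 224)
>>> (patch_tokens,) = vit.get_intermediate_layers(image, n=1)
>>> tuple(patch_tokens.shape)        # class (and register) tokens are stripped, 256 patches remain
(1, 256, 384)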
""" with torch.inference_mode(): - tokens = self.feature_encoder.get_intermediate_layers(image_tensor, n=1)[0] - start = self.feature_encoder.num_tokens + self.feature_encoder.num_register_tokens - return tokens[:, start:, :] + return self.feature_encoder.get_intermediate_layers(image_tensor, n=1)[0] @staticmethod def compute_background_masks( From d0c9c14c77306569a06a19fa814bed3b518d2c96 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Thu, 13 Nov 2025 21:15:22 +0900 Subject: [PATCH 11/25] change dinov2loader to factory method, remove duplicated components from dinomaly Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/components/dinov2/dinov2_loader.py | 53 ++--- .../image/dinomaly/components/__init__.py | 8 +- .../dinomaly/components/dinov2_loader.py | 191 ------------------ .../image/dinomaly/components/layers.py | 44 +--- .../dinomaly/components/vision_transformer.py | 5 +- .../models/image/dinomaly/torch_model.py | 10 +- 6 files changed, 35 insertions(+), 276 deletions(-) delete mode 100644 src/anomalib/models/image/dinomaly/components/dinov2_loader.py diff --git a/src/anomalib/models/components/dinov2/dinov2_loader.py b/src/anomalib/models/components/dinov2/dinov2_loader.py index 6c122bafff..1d37458e52 100644 --- a/src/anomalib/models/components/dinov2/dinov2_loader.py +++ b/src/anomalib/models/components/dinov2/dinov2_loader.py @@ -8,7 +8,8 @@ Example: model = DinoV2Loader.from_name("dinov2_vit_base_14") - model = DinoV2Loader.from_name("dinomaly_vit_base_14") + model = DinoV2Loader.from_name("vit_base_14") + model = DinoV2Loader(vit_factory=my_custom_vit_module).load("dinov2reg_vit_base_14") """ from __future__ import annotations @@ -23,14 +24,12 @@ from anomalib.data.utils import DownloadInfo from anomalib.data.utils.download import DownloadProgressBar from anomalib.models.components.dinov2 import vision_transformer as dinov2_models -from anomalib.models.image.dinomaly.components import vision_transformer as dinomaly_models logger = logging.getLogger(__name__) MODEL_FACTORIES: dict[str, object] = { "dinov2": dinov2_models, "dinov2_reg": dinov2_models, - "dinomaly": dinomaly_models, } @@ -49,13 +48,13 @@ class DinoV2Loader: "large": {"embed_dim": 1024, "num_heads": 16}, } - def __init__(self, cache_dir: str | Path = "./pre_trained/") -> None: - """Initialize a model loader instance. - - Args: - cache_dir: Directory in which downloaded weights will be stored. - """ - self.cache_dir: Path = Path(cache_dir) + def __init__( + self, + cache_dir: str | Path = "./pre_trained/", + vit_factory: object | None = None, + ) -> None: + self.cache_dir = Path(cache_dir) + self.vit_factory = vit_factory self.cache_dir.mkdir(parents=True, exist_ok=True) def load(self, model_name: str) -> torch.nn.Module: @@ -71,7 +70,7 @@ def load(self, model_name: str) -> torch.nn.Module: ValueError: If the requested model name is malformed or unsupported. 
""" model_type, architecture, patch_size = self._parse_name(model_name) - model = self._create_model(model_type, architecture, patch_size) + model = self.create_model(model_type, architecture, patch_size) self._load_weights(model, model_type, architecture, patch_size) logger.info(f"Loaded model: {model_name}") @@ -122,26 +121,21 @@ def _parse_name(self, name: str) -> tuple[str, str, int]: return model_type, architecture, patch_size - @staticmethod - def _create_model( - model_type: str, - architecture: str, - patch_size: int, - ) -> torch.nn.Module: - """Construct a model instance using the configured factory modules. + def create_model(self, model_type: str, architecture: str, patch_size: int) -> torch.nn.Module: + """Create a Vision Transformer model. Args: - model_type: Model family, e.g., "dinov2", "dinov2_reg", "dinomaly". - architecture: Architecture label ("small", "base", "large"). - patch_size: Patch resolution. + model_type: Normalized model family name (e.g., "dinov2", "dinov2_reg"). + architecture: Architecture size (e.g., "small", "base", "large"). + patch_size: ViT patch size. Returns: - An instantiated PyTorch module. + Instantiated Vision Transformer model. Raises: - ValueError: If the relevant constructor cannot be found. + ValueError: If no matching constructor exists. """ - model_kwargs: dict[str, object] = { + model_kwargs = { "patch_size": patch_size, "img_size": 518, "block_chunks": 0, @@ -153,18 +147,15 @@ def _create_model( if model_type == "dinov2_reg": model_kwargs["num_register_tokens"] = 4 - module = MODEL_FACTORIES.get(model_type) - if module is None: - msg = f"Unknown model type '{model_type}'." - raise ValueError(msg) + # If user supplied a custom ViT module, use it + module = self.vit_factory or MODEL_FACTORIES[model_type] ctor = getattr(module, f"vit_{architecture}", None) if ctor is None: - msg = f"No constructor 'vit_{architecture}' in module {module}." + msg = f"No constructor vit_{architecture} in module {module}" raise ValueError(msg) - model: torch.nn.Module = ctor(**model_kwargs) - return model + return ctor(**model_kwargs) def _load_weights( self, diff --git a/src/anomalib/models/image/dinomaly/components/__init__.py b/src/anomalib/models/image/dinomaly/components/__init__.py index 70385c143b..c691865366 100644 --- a/src/anomalib/models/image/dinomaly/components/__init__.py +++ b/src/anomalib/models/image/dinomaly/components/__init__.py @@ -4,12 +4,9 @@ """Components module for Dinomaly model. This module provides all the necessary components for the Dinomaly Vision Transformer -architecture including layers, model loader, utilities, and vision transformer implementations. +architecture including layers, utilities, and vision transformer implementations. """ -# Model loader -from .dinov2_loader import DinoV2Loader, load - # Layer components from .layers import Block, DinomalyMLP, LinearAttention, MemEffAttention @@ -26,9 +23,6 @@ "DinomalyMLP", "LinearAttention", "MemEffAttention", - # Model loader - "DinoV2Loader", - "load", # Utils "StableAdamW", "WarmCosineScheduler", diff --git a/src/anomalib/models/image/dinomaly/components/dinov2_loader.py b/src/anomalib/models/image/dinomaly/components/dinov2_loader.py deleted file mode 100644 index a6f1206a30..0000000000 --- a/src/anomalib/models/image/dinomaly/components/dinov2_loader.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""Loader for DINOv2 Vision Transformer models. 
- -This module provides a simple interface for loading pre-trained DINOv2 Vision Transformer models for the -Dinomaly anomaly detection framework. -""" - -from __future__ import annotations - -import logging -from pathlib import Path -from typing import ClassVar -from urllib.request import urlretrieve - -import torch - -from anomalib.data.utils import DownloadInfo -from anomalib.data.utils.download import DownloadProgressBar -from anomalib.models.image.dinomaly.components import vision_transformer as dinov2_models - -logger = logging.getLogger(__name__) - - -class DinoV2Loader: - """Simple loader for DINOv2 Vision Transformer models. - - Supports loading dinov2 and dinov2_reg models with small, base, and large architectures. - """ - - DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" - - # Model configurations - MODEL_CONFIGS: ClassVar[dict[str, dict[str, int]]] = { - "small": {"embed_dim": 384, "num_heads": 6}, - "base": {"embed_dim": 768, "num_heads": 12}, - "large": {"embed_dim": 1024, "num_heads": 16}, - } - - def __init__(self, cache_dir: str | Path = "./pre_trained/") -> None: - """Initialize model loader. - - Args: - cache_dir: Directory to store downloaded model weights. - """ - self.cache_dir = Path(cache_dir) - self.cache_dir.mkdir(parents=True, exist_ok=True) - - def load(self, model_name: str) -> torch.nn.Module: - """Load a DINOv2 model by name. - - Args: - model_name: Name like 'dinov2_vit_base_14' or 'dinov2reg_vit_small_14'. - - Returns: - Loaded PyTorch model ready for inference. - - Raises: - ValueError: If model name is invalid or unsupported. - """ - # Parse model name - model_type, architecture, patch_size = self._parse_name(model_name) - - # Create model - model = self._create_model(model_type, architecture, patch_size) - - # Load weights - self._load_weights(model, model_type, architecture, patch_size) - - logger.info(f"Loaded model: {model_name}") - return model - - def _parse_name(self, name: str) -> tuple[str, str, int]: - """Parse model name into components.""" - parts = name.split("_") - - if len(parts) < 3: - msg = f"Invalid model name format: {name}. Expected format: 'dinov2_vit__'" - raise ValueError(msg) - - # Determine model type and extract architecture/patch_size - if "dinov2reg" in name or "reg" in name: - model_type = "dinov2_reg" - architecture = parts[-2] - patch_size = int(parts[-1]) - else: - model_type = "dinov2" - architecture = parts[-2] - patch_size = int(parts[-1]) - - if architecture not in self.MODEL_CONFIGS: - valid_archs = list(self.MODEL_CONFIGS.keys()) - msg = f"Invalid architecture '{architecture}' in model name '{name}'. 
Valid architectures: {valid_archs}" - raise ValueError(msg) - - return model_type, architecture, patch_size - - @staticmethod - def _create_model(model_type: str, architecture: str, patch_size: int) -> torch.nn.Module: - """Create model with appropriate configuration.""" - model_kwargs = { - "patch_size": patch_size, - "img_size": 518, - "block_chunks": 0, - "init_values": 1e-8, - "interpolate_antialias": False, - "interpolate_offset": 0.1, - } - - # Add register tokens for reg models - if model_type == "dinov2_reg": - model_kwargs["num_register_tokens"] = 4 - - # Get model constructor function - model_fn = getattr(dinov2_models, f"vit_{architecture}", None) - if model_fn is None: - msg = f"Model function vit_{architecture} not found in dinov2_models" - raise ValueError(msg) - - return model_fn(**model_kwargs) - - def _load_weights(self, model: torch.nn.Module, model_type: str, architecture: str, patch_size: int) -> None: - """Download and load model weights using standardized Anomalib utilities.""" - weight_path = self._get_weight_path(model_type, architecture, patch_size) - - if not weight_path.exists(): - self._download_weights(model_type, architecture, patch_size) - - # Weights_only is set to True - # See mitigation details in https://github.com/open-edge-platform/anomalib/pull/2729 - # nosemgrep: trailofbits.python.pickles-in-pytorch.pickles-in-pytorch - state_dict = torch.load(weight_path, map_location="cpu", weights_only=True) # nosec B614 - model.load_state_dict(state_dict, strict=False) - - def _get_weight_path(self, model_type: str, architecture: str, patch_size: int) -> Path: - """Get local path for model weights.""" - arch_code = architecture[0] # s, b, or l - - if model_type == "dinov2_reg": - filename = f"dinov2_vit{arch_code}{patch_size}_reg4_pretrain.pth" - else: - filename = f"dinov2_vit{arch_code}{patch_size}_pretrain.pth" - - return self.cache_dir / filename - - def _download_weights(self, model_type: str, architecture: str, patch_size: int) -> None: - """Download model weights using standardized Anomalib download utilities.""" - arch_code = architecture[0] - weight_path = self._get_weight_path(model_type, architecture, patch_size) - - # Build download URL - model_dir = f"dinov2_vit{arch_code}{patch_size}" - url = f"{self.DINOV2_BASE_URL}/{model_dir}/{weight_path.name}" - - # Create DownloadInfo for standardized download - download_info = DownloadInfo( - name=f"DINOv2 {model_type} {architecture} weights", - url=url, - hashsum="", # DINOv2 doesn't provide official hashes, but we use empty string for now - filename=weight_path.name, - ) - - logger.info(f"Downloading DINOv2 weights: {weight_path.name} to {self.cache_dir}") - - # Ensure cache directory exists - self.cache_dir.mkdir(parents=True, exist_ok=True) - - # Download with progress bar (following Anomalib patterns) - with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=download_info.name) as progress_bar: - # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected # noqa: ERA001, E501 - urlretrieve( # noqa: S310 # nosec B310 - url=url, - filename=weight_path, - reporthook=progress_bar.update_to, - ) - - -def load(model_name: str) -> torch.nn.Module: - """Convenience function to load a model. - - This can be later extended to be a factory method to load other models. - - Args: - model_name: Name like 'dinov2_vit_base_14' or 'dinov2reg_vit_small_14'. - - Returns: - Loaded PyTorch model. 
- """ - loader = DinoV2Loader() - return loader.load(model_name) diff --git a/src/anomalib/models/image/dinomaly/components/layers.py b/src/anomalib/models/image/dinomaly/components/layers.py index d8bcd9d311..32bfae914b 100644 --- a/src/anomalib/models/image/dinomaly/components/layers.py +++ b/src/anomalib/models/image/dinomaly/components/layers.py @@ -16,52 +16,12 @@ from typing import Any import torch -from timm.layers.drop import DropPath -from timm.models.vision_transformer import Attention, LayerScale from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 -logger = logging.getLogger("dinov2") - - -class MemEffAttention(Attention): - """Memory-efficient attention from the dinov2 implementation with a small change. - - Reference: - https://github.com/facebookresearch/dinov2/blob/592541c8d842042bb5ab29a49433f73b544522d5/dinov2/eval/segmentation_m2f/models/backbones/vit.py#L159 - - Instead of using xformers's memory_efficient_attention() method, which requires adding a new dependency to anomalib, - this implementation uses the scaled dot product from torch. - """ - - def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor: - """Compute memory-efficient attention using PyTorch's scaled dot product attention. - - Args: - x: Input tensor of shape (batch_size, seq_len, embed_dim). - attn_bias: Optional attention bias mask. Default: None. +from anomalib.models.components.dinov2.layers import Attention, DropPath, LayerScale, MemEffAttention - Returns: - Output tensor of shape (batch_size, seq_len, embed_dim). - """ - batch_size, seq_len, embed_dim = x.shape - qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, embed_dim // self.num_heads) - - q, k, v = qkv.unbind(2) - - # Use PyTorch's native scaled dot product attention for memory efficiency. - # Replaced xformers's memory_efficient_attention() method with pytorch's scaled - # dot product. 
- x = F.scaled_dot_product_attention( - q.transpose(1, 2), - k.transpose(1, 2), - v.transpose(1, 2), - attn_mask=attn_bias, - ) - x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim) - - x = self.proj(x) - return self.proj_drop(x) +logger = logging.getLogger("dinov2") class LinearAttention(nn.Module): diff --git a/src/anomalib/models/image/dinomaly/components/vision_transformer.py b/src/anomalib/models/image/dinomaly/components/vision_transformer.py index 7704ffd2f3..fc9310a426 100644 --- a/src/anomalib/models/image/dinomaly/components/vision_transformer.py +++ b/src/anomalib/models/image/dinomaly/components/vision_transformer.py @@ -19,12 +19,11 @@ from functools import partial import torch -import torch.utils.checkpoint -from timm.layers.patch_embed import PatchEmbed from torch import nn from torch.nn.init import trunc_normal_ -from anomalib.models.image.dinomaly.components.layers import Block, DinomalyMLP, MemEffAttention +from anomalib.models.components.dinov2.layers import MemEffAttention, PatchEmbed +from anomalib.models.image.dinomaly.components.layers import Block, DinomalyMLP logger = logging.getLogger("dinov2") diff --git a/src/anomalib/models/image/dinomaly/torch_model.py b/src/anomalib/models/image/dinomaly/torch_model.py index 2dd567ed52..a362f64feb 100644 --- a/src/anomalib/models/image/dinomaly/torch_model.py +++ b/src/anomalib/models/image/dinomaly/torch_model.py @@ -22,8 +22,11 @@ from anomalib.data import InferenceBatch from anomalib.models.components import GaussianBlur2d +from anomalib.models.components.dinov2 import DinoV2Loader from anomalib.models.image.dinomaly.components import CosineHardMiningLoss, DinomalyMLP, LinearAttention -from anomalib.models.image.dinomaly.components import load as load_dinov2_model +from anomalib.models.image.dinomaly.components.vision_transformer import ( + DinoVisionTransformer as DinomalyVisionTransformer, +) # Encoder architecture configurations for DINOv2 models. 
# The target layers are the @@ -117,7 +120,10 @@ def __init__( if fuse_layer_decoder is None: fuse_layer_decoder = DEFAULT_FUSE_LAYERS - encoder = load_dinov2_model(encoder_name) + self.encoder_name = encoder_name + encoder = DinoV2Loader( + vit_factory=DinomalyVisionTransformer, + ).from_name(encoder_name) # Extract architecture configuration based on the model name arch_config = self._get_architecture_config(encoder_name, target_layers) From e82da760c3b4a24f98a5cd7a9eac0a28c2726d0b Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Thu, 13 Nov 2025 21:31:18 +0900 Subject: [PATCH 12/25] add from_name back Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- src/anomalib/models/components/dinov2/dinov2_loader.py | 2 +- .../image/dinomaly/components/vision_transformer.py | 3 ++- src/anomalib/models/image/dinomaly/torch_model.py | 8 +++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/anomalib/models/components/dinov2/dinov2_loader.py b/src/anomalib/models/components/dinov2/dinov2_loader.py index 1d37458e52..05b2b88561 100644 --- a/src/anomalib/models/components/dinov2/dinov2_loader.py +++ b/src/anomalib/models/components/dinov2/dinov2_loader.py @@ -148,7 +148,7 @@ def create_model(self, model_type: str, architecture: str, patch_size: int) -> t model_kwargs["num_register_tokens"] = 4 # If user supplied a custom ViT module, use it - module = self.vit_factory or MODEL_FACTORIES[model_type] + module = self.vit_factory if self.vit_factory is not None else MODEL_FACTORIES[model_type] ctor = getattr(module, f"vit_{architecture}", None) if ctor is None: diff --git a/src/anomalib/models/image/dinomaly/components/vision_transformer.py b/src/anomalib/models/image/dinomaly/components/vision_transformer.py index fc9310a426..d1ab5513cc 100644 --- a/src/anomalib/models/image/dinomaly/components/vision_transformer.py +++ b/src/anomalib/models/image/dinomaly/components/vision_transformer.py @@ -19,10 +19,11 @@ from functools import partial import torch +from timm.layers.patch_embed import PatchEmbed from torch import nn from torch.nn.init import trunc_normal_ -from anomalib.models.components.dinov2.layers import MemEffAttention, PatchEmbed +from anomalib.models.components.dinov2.layers import MemEffAttention from anomalib.models.image.dinomaly.components.layers import Block, DinomalyMLP logger = logging.getLogger("dinov2") diff --git a/src/anomalib/models/image/dinomaly/torch_model.py b/src/anomalib/models/image/dinomaly/torch_model.py index a362f64feb..0c9a5cf62f 100644 --- a/src/anomalib/models/image/dinomaly/torch_model.py +++ b/src/anomalib/models/image/dinomaly/torch_model.py @@ -24,9 +24,7 @@ from anomalib.models.components import GaussianBlur2d from anomalib.models.components.dinov2 import DinoV2Loader from anomalib.models.image.dinomaly.components import CosineHardMiningLoss, DinomalyMLP, LinearAttention -from anomalib.models.image.dinomaly.components.vision_transformer import ( - DinoVisionTransformer as DinomalyVisionTransformer, -) +from anomalib.models.image.dinomaly.components import vision_transformer as dinomaly_vision_transformer # Encoder architecture configurations for DINOv2 models. 
# The target layers are the @@ -122,8 +120,8 @@ def __init__( self.encoder_name = encoder_name encoder = DinoV2Loader( - vit_factory=DinomalyVisionTransformer, - ).from_name(encoder_name) + vit_factory=dinomaly_vision_transformer, + ).load(encoder_name) # Extract architecture configuration based on the model name arch_config = self._get_architecture_config(encoder_name, target_layers) From 42c2b2c7b5050798c23393332810c9cafde3a24c Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:26:36 +0900 Subject: [PATCH 13/25] add tests for vit and dinov2loader Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../unit/models/components/dinov2/__init__.py | 4 + .../components/dinov2/test_dinov2loader.py | 157 ++++++++++++++++++ .../unit/models/components/dinov2/test_vit.py | 132 +++++++++++++++ 3 files changed, 293 insertions(+) create mode 100644 tests/unit/models/components/dinov2/__init__.py create mode 100644 tests/unit/models/components/dinov2/test_dinov2loader.py create mode 100644 tests/unit/models/components/dinov2/test_vit.py diff --git a/tests/unit/models/components/dinov2/__init__.py b/tests/unit/models/components/dinov2/__init__.py new file mode 100644 index 0000000000..4c9f9289ac --- /dev/null +++ b/tests/unit/models/components/dinov2/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dinov2 implementation and loader.""" diff --git a/tests/unit/models/components/dinov2/test_dinov2loader.py b/tests/unit/models/components/dinov2/test_dinov2loader.py new file mode 100644 index 0000000000..2899aa2e8d --- /dev/null +++ b/tests/unit/models/components/dinov2/test_dinov2loader.py @@ -0,0 +1,157 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for DinoV2Loader.""" + +from __future__ import annotations + +import re +from unittest.mock import MagicMock, patch + +import pytest +import torch +from torch import nn + +from anomalib.models.components.dinov2.dinov2_loader import DinoV2Loader + + +@pytest.fixture() +def dummy_model() -> nn.Module: + """Return a simple dummy model used by fake constructors.""" + + class Dummy(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(4, 4) + + return Dummy() + + +@pytest.fixture() +def loader() -> DinoV2Loader: + """Return a loader instance with a non-functional cache path.""" + return DinoV2Loader(cache_dir="not_used_in_unit_tests") + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("dinov2_vit_base_14", ("dinov2", "base", 14)), + ("dinov2reg_vit_small_16", ("dinov2_reg", "small", 16)), + ("dinomaly_vit_large_14", ("dinomaly", "large", 14)), + ], +) +def test_parse_name_valid( + loader: DinoV2Loader, + name: str, + expected: tuple[str, str, int], +) -> None: + """Validate that supported model names parse correctly.""" + assert loader._parse_name(name) == expected # noqa: SLF001 + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ("foo_vit_base_14", "foo"), + ("x_vit_small_16", "x"), + ("wrongprefix_vit_large_14", "wrongprefix"), + ], +) +def test_parse_name_invalid_prefix(loader: DinoV2Loader, name: str, expected: str) -> None: + """Ensure invalid model prefixes raise ValueError.""" + msg = f"Unknown model type prefix '{expected}'." 
+ with pytest.raises(ValueError, match=msg): + loader._parse_name(name) # noqa: SLF001 + + +def test_parse_name_invalid_architecture(loader: DinoV2Loader) -> None: + """Ensure unknown architecture names raise ValueError.""" + expected_msg = f"Invalid architecture 'tiny'. Expected one of: {list(loader.MODEL_CONFIGS)}" + with pytest.raises(ValueError, match=re.escape(expected_msg)): + loader._parse_name("dinov2_vit_tiny_14") # noqa: SLF001 + + +def test_create_model_success(loader: DinoV2Loader, dummy_model: nn.Module) -> None: + """Verify model creation succeeds when constructor exists.""" + fake_module = MagicMock() + fake_module.vit_small = MagicMock(return_value=dummy_model) + + loader.vit_factory = fake_module + model = loader.create_model("dinov2", "small", 14) + + fake_module.vit_small.assert_called_once() + assert model is dummy_model + + +def test_create_model_missing_constructor(loader: DinoV2Loader) -> None: + """Verify missing constructors cause ValueError.""" + loader.vit_factory = object() + expected_msg = f"No constructor vit_base in module {loader.vit_factory}" + with pytest.raises(ValueError, match=expected_msg): + loader.create_model("dinov2", "base", 14) + + +def test_get_weight_path_dinov2(loader: DinoV2Loader) -> None: + """Check generated weight filename for default dinov2 models.""" + path = loader._get_weight_path("dinov2", "base", 14) # noqa: SLF001 + assert path.name == "dinov2_vitb14_pretrain.pth" + + +def test_get_weight_path_reg(loader: DinoV2Loader) -> None: + """Check generated weight filename for register-token models.""" + path = loader._get_weight_path("dinov2_reg", "large", 16) # noqa: SLF001 + assert path.name == "dinov2_vitl16_reg4_pretrain.pth" + + +@patch("anomalib.models.components.dinov2.dinov2_loader.torch.load") +@patch("anomalib.models.components.dinov2.dinov2_loader.DinoV2Loader._download_weights") +def test_load_calls_weight_loading( + mock_download: MagicMock, + mock_torch_load: MagicMock, + loader: DinoV2Loader, + dummy_model: nn.Module, +) -> None: + """Confirm load() uses existing weights without downloading.""" + fake_module = MagicMock() + fake_module.vit_base = MagicMock(return_value=dummy_model) + loader.vit_factory = fake_module + + fake_path = MagicMock() + fake_path.exists.return_value = True + loader._get_weight_path = MagicMock(return_value=fake_path) # noqa: SLF001 + + mock_torch_load.return_value = {"layer": torch.zeros(1)} + + loaded = loader.load("dinov2_vit_base_14") + + fake_module.vit_base.assert_called_once() + mock_download.assert_not_called() + mock_torch_load.assert_called_once() + assert loaded is dummy_model + + +@patch("anomalib.models.components.dinov2.dinov2_loader.torch.load") +@patch("anomalib.models.components.dinov2.dinov2_loader.DinoV2Loader._download_weights") +def test_load_triggers_download_when_missing( + mock_download: MagicMock, + mock_torch_load: MagicMock, + loader: DinoV2Loader, + dummy_model: nn.Module, +) -> None: + """Confirm load() triggers weight download when file is missing.""" + fake_module = MagicMock() + fake_module.vit_small = MagicMock(return_value=dummy_model) + loader.vit_factory = fake_module + + fake_path = MagicMock() + fake_path.exists.return_value = False + loader._get_weight_path = MagicMock(return_value=fake_path) # noqa: SLF001 + + mock_torch_load.return_value = {"test": torch.zeros(1)} + + loader.load("dinov2_vit_small_14") + + mock_download.assert_called_once() + mock_torch_load.assert_called_once() + fake_module.vit_small.assert_called_once() diff --git 
a/tests/unit/models/components/dinov2/test_vit.py b/tests/unit/models/components/dinov2/test_vit.py new file mode 100644 index 0000000000..862457b4ba --- /dev/null +++ b/tests/unit/models/components/dinov2/test_vit.py @@ -0,0 +1,132 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for DINOv2 ViT / Loader.""" + +from __future__ import annotations + +import pytest +import torch +from torch import Tensor + +from anomalib.models.components.dinov2.vision_transformer import ( + DinoVisionTransformer, + vit_base, + vit_large, + vit_small, +) + + +@pytest.fixture() +def tiny_vit() -> DinoVisionTransformer: + """Return a very small ViT model for unit testing.""" + return DinoVisionTransformer( + img_size=32, + patch_size=8, + embed_dim=64, + depth=2, + num_heads=4, + ) + + +@pytest.fixture() +def tiny_input() -> Tensor: + """Return a small dummy input tensor.""" + return torch.randn(2, 3, 32, 32) # (B=2, C=3, H=W=32) + + +def test_model_initializes(tiny_vit: DinoVisionTransformer) -> None: + """Model constructs and exposes expected attributes.""" + m: DinoVisionTransformer = tiny_vit + + assert m.embed_dim == 64 + assert m.patch_size == 8 + assert m.n_blocks == 2 + assert hasattr(m, "patch_embed") + assert hasattr(m, "cls_token") + assert hasattr(m, "pos_embed") + assert hasattr(m, "blocks") + + +def test_patch_embedding_shape( + tiny_vit: DinoVisionTransformer, + tiny_input: Tensor, +) -> None: + """Patch embedding output has correct (B, N, C) shape.""" + patches: Tensor = tiny_vit.patch_embed(tiny_input) + b, n, c = patches.shape + + assert b == 2 + assert n == 16 # 32x32 with patch_size=8 → 4x4 → 16 patches + assert tiny_vit.embed_dim == c + + +def test_prepare_tokens_output_shape( + tiny_vit: DinoVisionTransformer, + tiny_input: Tensor, +) -> None: + """prepare_tokens_with_masks adds CLS and keeps correct embedding dims.""" + tokens: Tensor = tiny_vit.prepare_tokens_with_masks(tiny_input) + + expected_tokens: int = 1 + tiny_vit.patch_embed.num_patches + assert tokens.shape == (2, expected_tokens, tiny_vit.embed_dim) + + +def test_forward_features_training_output_shapes( + tiny_vit: DinoVisionTransformer, + tiny_input: Tensor, +) -> None: + """forward(is_training=True) returns a dict with expected shapes.""" + out: dict[str, Tensor | None] = tiny_vit(tiny_input, is_training=True) # type: ignore[assignment] + + assert isinstance(out, dict) + assert out["x_norm_clstoken"] is not None + assert out["x_norm_patchtokens"] is not None + + cls: Tensor = out["x_norm_clstoken"] # type: ignore[assignment] + patches: Tensor = out["x_norm_patchtokens"] # type: ignore[assignment] + + assert cls.shape == (2, tiny_vit.embed_dim) + assert patches.shape[1] == tiny_vit.patch_embed.num_patches + + +def test_forward_inference_output_shape( + tiny_vit: DinoVisionTransformer, + tiny_input: Tensor, +) -> None: + """Inference mode returns class-token output only.""" + out: Tensor = tiny_vit(tiny_input) # default is is_training=False + + assert isinstance(out, Tensor) + assert out.shape == (2, tiny_vit.embed_dim) + + +def test_get_intermediate_layers_shapes( + tiny_vit: DinoVisionTransformer, + tiny_input: Tensor, +) -> None: + """Intermediate layer extraction returns tensors shaped (B, tokens, C).""" + feats: tuple[Tensor, ...] 
= tiny_vit.get_intermediate_layers( + tiny_input, + n=1, + ) + + assert len(feats) == 1 + + f: Tensor = feats[0] + assert f.shape[0] == 2 # batch + assert f.shape[2] == tiny_vit.embed_dim + + +@pytest.mark.parametrize( + "factory", + [vit_small, vit_base, vit_large], +) +def test_vit_factories_create_models(factory) -> None: # noqa: ANN001 + """vit_small/base/large should return valid models.""" + model: DinoVisionTransformer = factory() + + assert isinstance(model, DinoVisionTransformer) + assert model.embed_dim > 0 + assert model.n_blocks > 0 + assert model.num_heads > 0 From 0739df4c6278bfd5df55cdcb35897c066eb9f332 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:53:39 +0900 Subject: [PATCH 14/25] update docstrings Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/components/dinov2/__init__.py | 4 +++ .../models/components/dinov2/dinov2_loader.py | 27 ++++++++++++++----- .../components/dinov2/layers/__init__.py | 16 ++++++++++- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/anomalib/models/components/dinov2/__init__.py b/src/anomalib/models/components/dinov2/__init__.py index e9db6cf8f2..684bcc9ace 100644 --- a/src/anomalib/models/components/dinov2/__init__.py +++ b/src/anomalib/models/components/dinov2/__init__.py @@ -6,6 +6,10 @@ References: https://github.com/facebookresearch/dinov2/blob/main/dinov2/ + +Classes: + DinoVisionTransformer: DINOv2 implementation. + DinoV2Loader: Loader class to support downloading and loading weights. """ # vision transformer diff --git a/src/anomalib/models/components/dinov2/dinov2_loader.py b/src/anomalib/models/components/dinov2/dinov2_loader.py index 05b2b88561..6823d6919c 100644 --- a/src/anomalib/models/components/dinov2/dinov2_loader.py +++ b/src/anomalib/models/components/dinov2/dinov2_loader.py @@ -1,15 +1,30 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Loader for DINOv2 Vision Transformer models. +"""Loading pre-trained DINOv2 Vision Transformer models. -This module provides a simple interface for loading pre-trained DINOv2 Vision Transformer models for the -Dinomaly anomaly detection framework. +This module provides the :class:`DinoV2Loader` class for constructing and loading +pre-trained DINOv2 Vision Transformer models used in the Dinomaly anomaly detection +framework. It supports both standard DINOv2 models and register-token variants, and +allows custom Vision Transformer factories to be supplied. Example: - model = DinoV2Loader.from_name("dinov2_vit_base_14") - model = DinoV2Loader.from_name("vit_base_14") - model = DinoV2Loader(vit_factory=my_custom_vit_module).load("dinov2reg_vit_base_14") + >>> from anomalib.models.components.dinov2 import DinoV2Loader + >>> loader = DinoV2Loader() + >>> model = loader.load("dinov2_vit_base_14") + >>> model = loader.load("vit_base_14") + >>> custom_loader = DinoV2Loader(vit_factory=my_custom_vit_module) + >>> model = custom_loader.load("dinov2reg_vit_base_14") + +The DINOv2 loader handles: + +- Parsing model names and validating architecture types +- Constructing the appropriate Vision Transformer model +- Locating or downloading the corresponding pre-trained weights +- Supporting custom ViT implementations via a pluggable factory + +This enables a simple, unified interface for accessing DINOv2-based backbones in +downstream anomaly detection tasks. 
""" from __future__ import annotations diff --git a/src/anomalib/models/components/dinov2/layers/__init__.py b/src/anomalib/models/components/dinov2/layers/__init__.py index e2c88d7aa2..61af687569 100644 --- a/src/anomalib/models/components/dinov2/layers/__init__.py +++ b/src/anomalib/models/components/dinov2/layers/__init__.py @@ -4,7 +4,21 @@ """Layers needed to build DINOv2. References: -https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/__init__.py + https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/__init__.py + +Classes: + Attention: Standard multi-head self-attention layer used in Vision Transformers. + MemEffAttention: Memory-efficient variant of multi-head attention optimized for large inputs. + Block: Transformer block consisting of attention, MLP, residuals, and normalization layers. + CausalAttentionBlock: Transformer block with causal (autoregressive) attention masking. + DINOHead: Projection head used in DINO/DINOv2 for self-supervised feature learning. + DropPath: Implements stochastic depth, randomly dropping residual connections during training. + LayerScale: Applies learnable per-channel scaling to stabilize deep transformer training. + Mlp: Feedforward network used inside Vision Transformer blocks. + PatchEmbed: Converts image patches into token embeddings for Vision Transformer inputs. + SwiGLUFFN: SwiGLU-based feedforward network used in DINOv2 for improved expressiveness. + SwiGLUFFNAligned: Variant of SwiGLUFFN with tensor alignment optimizations. + SwiGLUFFNFused: Fused implementation of SwiGLUFFN for improved computational efficiency. """ from .attention import Attention, MemEffAttention From 826c18cdc5794885778257bcd5b1ca5eb901c514 Mon Sep 17 00:00:00 2001 From: Niclas <152474825+waschsalz@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:11:52 +0100 Subject: [PATCH 15/25] fix(accelerator): Adding name method in XPUAccelerator (#3108) * Update xpu.py regarind PR #3092 Added the name method to fix an issue related to a newly added feature in lightning 2.5.6 Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com> * Update xpu.py Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com> * Update xpu.py Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com> * Update xpu.py with docstring Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com> * Update xpu.py with correct docstring Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com> * added name method for XPUAccelerator Signed-off-by: waschsalz --------- Signed-off-by: Niclas <152474825+waschsalz@users.noreply.github.com> Signed-off-by: waschsalz Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- src/anomalib/engine/accelerator/xpu.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/anomalib/engine/accelerator/xpu.py b/src/anomalib/engine/accelerator/xpu.py index a4f2dbb309..dc6d56ee82 100644 --- a/src/anomalib/engine/accelerator/xpu.py +++ b/src/anomalib/engine/accelerator/xpu.py @@ -12,7 +12,10 @@ class XPUAccelerator(Accelerator): """Support for a XPU, optimized for large-scale machine learning.""" - accelerator_name = "xpu" + @property + def name(self) -> str: + """Setting the name of the accelerator which is required for accelerators by pytorch-lightning >= 2.5.6.""" + return "xpu" @staticmethod def setup_device(device: torch.device) -> None: @@ -59,7 +62,7 @@ def teardown(self) -> None: AcceleratorRegistry.register( - 
XPUAccelerator.accelerator_name, + XPUAccelerator().name, XPUAccelerator, description="Accelerator supports XPU devices", ) From 0788a326169b1af7fd1252ed10a6671adbdf9618 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> Date: Fri, 14 Nov 2025 03:18:20 +0900 Subject: [PATCH 16/25] change licesning with meta. Tensor is torch.Tensor. remove __future__. Signed-off-by: Alfie Roddan <228966941+alfieroddanintel@users.noreply.github.com> --- .../models/components/dinov2/dinov2_loader.py | 2 -- .../components/dinov2/layers/attention.py | 12 +++---- .../models/components/dinov2/layers/block.py | 36 ++++++++----------- .../components/dinov2/layers/dino_head.py | 9 ++--- .../components/dinov2/layers/drop_path.py | 13 +++---- .../components/dinov2/layers/layer_scale.py | 10 +++--- .../models/components/dinov2/layers/mlp.py | 14 ++++---- .../components/dinov2/layers/patch_embed.py | 14 ++++---- .../components/dinov2/layers/swiglu_ffn.py | 16 ++++----- .../components/dinov2/vision_transformer.py | 9 ++--- .../models/image/dinomaly/torch_model.py | 4 +-- .../unit/models/components/dinov2/test_vit.py | 31 ++++++++-------- 12 files changed, 75 insertions(+), 95 deletions(-) diff --git a/src/anomalib/models/components/dinov2/dinov2_loader.py b/src/anomalib/models/components/dinov2/dinov2_loader.py index 6823d6919c..a4ab8a4c91 100644 --- a/src/anomalib/models/components/dinov2/dinov2_loader.py +++ b/src/anomalib/models/components/dinov2/dinov2_loader.py @@ -27,8 +27,6 @@ downstream anomaly detection tasks. """ -from __future__ import annotations - import logging from pathlib import Path from typing import ClassVar diff --git a/src/anomalib/models/components/dinov2/layers/attention.py b/src/anomalib/models/components/dinov2/layers/attention.py index 0f1ae6ef14..5db5a74e1b 100644 --- a/src/anomalib/models/components/dinov2/layers/attention.py +++ b/src/anomalib/models/components/dinov2/layers/attention.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """Attention layers for DINOv2 Vision Transformers. @@ -11,12 +13,10 @@ blocks for feature extraction and masked modeling. """ -from __future__ import annotations - import logging import torch -from torch import Tensor, nn +from torch import nn from torch.nn import functional as F # noqa: N812 logger = logging.getLogger(__name__) @@ -82,7 +82,7 @@ def init_weights( if self.proj.bias is not None: nn.init.zeros_(self.proj.bias) - def forward(self, x: Tensor, is_causal: bool = False) -> Tensor: + def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor: """Apply multi-head self-attention. Args: @@ -90,7 +90,7 @@ def forward(self, x: Tensor, is_causal: bool = False) -> Tensor: is_causal: If True, applies causal masking. Returns: - Tensor of shape ``(B, N, C)`` containing attended features. + torch.Tensor of shape ``(B, N, C)`` containing attended features. """ b, n, c = x.shape qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads) @@ -120,7 +120,7 @@ class MemEffAttention(Attention): this implementation uses the scaled dot product from torch. """ - def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor: + def forward(self, x: torch.Tensor, attn_bias: torch.Tensor | None = None) -> torch.Tensor: """Compute memory-efficient attention using PyTorch's scaled dot product attention. 
Args: diff --git a/src/anomalib/models/components/dinov2/layers/block.py b/src/anomalib/models/components/dinov2/layers/block.py index 1afe3512de..26ac43475a 100644 --- a/src/anomalib/models/components/dinov2/layers/block.py +++ b/src/anomalib/models/components/dinov2/layers/block.py @@ -1,35 +1,27 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 + """Transformer blocks used in DINOv2 Vision Transformers. This module implements: - Standard transformer blocks with attention and MLP (`Block`) - Causal attention blocks (`CausalAttentionBlock`) - -The implementation is adapted from the original DINO and timm Vision -Transformer code: - -- https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -- https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py """ -from __future__ import annotations - import logging -from typing import TYPE_CHECKING +from collections.abc import Callable import torch -from torch import Tensor, nn +from torch import nn from .attention import Attention from .drop_path import DropPath from .layer_scale import LayerScale from .mlp import Mlp -if TYPE_CHECKING: - from collections.abc import Callable - logger = logging.getLogger("dinov2") @@ -66,7 +58,7 @@ def __init__( ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, - init_values: float | Tensor | None = None, + init_values: float | torch.Tensor | None = None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, @@ -101,13 +93,13 @@ def __init__( self.sample_drop_ratio: float = drop_path - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply attention and MLP residual blocks with optional stochastic depth.""" - def attn_residual_func(inp: Tensor) -> Tensor: + def attn_residual_func(inp: torch.Tensor) -> torch.Tensor: return self.ls1(self.attn(self.norm1(inp))) - def ffn_residual_func(inp: Tensor) -> Tensor: + def ffn_residual_func(inp: torch.Tensor) -> torch.Tensor: return self.ls2(self.mlp(self.norm2(inp))) if self.training and self.sample_drop_ratio > 0.1: @@ -200,17 +192,17 @@ def init_weights( nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std) self.ffn_norm.reset_parameters() - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply causal attention followed by a feed-forward block.""" x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal)) return x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn))) def drop_add_residual_stochastic_depth( - x: Tensor, - residual_func: Callable[[Tensor], Tensor], + x: torch.Tensor, + residual_func: Callable[[torch.Tensor], torch.Tensor], sample_drop_ratio: float = 0.0, -) -> Tensor: +) -> torch.Tensor: """Apply stochastic depth to a residual branch on a subset of samples. Args: @@ -219,7 +211,7 @@ def drop_add_residual_stochastic_depth( sample_drop_ratio: Fraction of samples to drop for residual computation. Returns: - Tensor with residual added to a subset of samples. + torch.Tensor with residual added to a subset of samples. 
""" b, _, _ = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) diff --git a/src/anomalib/models/components/dinov2/layers/dino_head.py b/src/anomalib/models/components/dinov2/layers/dino_head.py index 811195053b..37fb892056 100644 --- a/src/anomalib/models/components/dinov2/layers/dino_head.py +++ b/src/anomalib/models/components/dinov2/layers/dino_head.py @@ -1,5 +1,8 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 + """DINO projection head module. @@ -7,10 +10,8 @@ https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/dino_head.py """ -from __future__ import annotations - import torch -from torch import Tensor, nn +from torch import nn from torch.nn.init import trunc_normal_ from torch.nn.utils import weight_norm @@ -68,7 +69,7 @@ def _init_weights(self, module: nn.Module) -> None: # noqa: PLR6301 if module.bias is not None: nn.init.constant_(module.bias, 0) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Run the DINO projection head forward pass.""" x = self.mlp(x) diff --git a/src/anomalib/models/components/dinov2/layers/drop_path.py b/src/anomalib/models/components/dinov2/layers/drop_path.py index def3884e82..2a10ce53a3 100644 --- a/src/anomalib/models/components/dinov2/layers/drop_path.py +++ b/src/anomalib/models/components/dinov2/layers/drop_path.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """Stochastic depth drop-path implementation used in DINOv2. @@ -8,12 +10,11 @@ drops entire residual branches during training to improve model robustness. """ -from __future__ import annotations - -from torch import Tensor, nn +import torch +from torch import nn -def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False) -> Tensor: +def drop_path(x: torch.torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """Apply stochastic depth to an input tensor. Args: @@ -22,7 +23,7 @@ def drop_path(x: Tensor, drop_prob: float = 0.0, training: bool = False) -> Tens training: Whether the module is in training mode. Returns: - Tensor with dropped paths applied during training, or the original + torch.Tensor with dropped paths applied during training, or the original tensor during evaluation. Notes: @@ -57,6 +58,6 @@ def __init__(self, drop_prob: float | None = None) -> None: super().__init__() self.drop_prob = drop_prob if drop_prob is not None else 0.0 - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass applying stochastic depth.""" return drop_path(x, drop_prob=self.drop_prob, training=self.training) diff --git a/src/anomalib/models/components/dinov2/layers/layer_scale.py b/src/anomalib/models/components/dinov2/layers/layer_scale.py index a64e57257e..11c48e6fe9 100644 --- a/src/anomalib/models/components/dinov2/layers/layer_scale.py +++ b/src/anomalib/models/components/dinov2/layers/layer_scale.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """LayerScale module used in DINOv2. @@ -8,10 +10,8 @@ Vision Transformers with residual connections. 
""" -from __future__ import annotations - import torch -from torch import Tensor, nn +from torch import nn class LayerScale(nn.Module): @@ -33,7 +33,7 @@ class LayerScale(nn.Module): def __init__( self, dim: int, - init_values: float | Tensor = 1e-5, + init_values: float | torch.Tensor = 1e-5, inplace: bool = False, device: torch.device | None = None, dtype: torch.dtype | None = None, @@ -48,6 +48,6 @@ def reset_parameters(self) -> None: """Reset scale parameters to their initialization values.""" nn.init.constant_(self.gamma, self.init_values) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply channel-wise scaling to the input tensor.""" return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/anomalib/models/components/dinov2/layers/mlp.py b/src/anomalib/models/components/dinov2/layers/mlp.py index 9dc80253a5..9a46205d70 100644 --- a/src/anomalib/models/components/dinov2/layers/mlp.py +++ b/src/anomalib/models/components/dinov2/layers/mlp.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """Feed-forward MLP block used in DINOv2 Vision Transformers. @@ -8,14 +10,10 @@ inside each transformer block. """ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from torch import Tensor, nn +from collections.abc import Callable -if TYPE_CHECKING: - from collections.abc import Callable +import torch +from torch import nn class Mlp(nn.Module): @@ -51,7 +49,7 @@ def __init__( self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop = nn.Dropout(drop) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the two-layer feed-forward transformation.""" x = self.fc1(x) x = self.act(x) diff --git a/src/anomalib/models/components/dinov2/layers/patch_embed.py b/src/anomalib/models/components/dinov2/layers/patch_embed.py index 1e39b25b8b..bf3e50d9f1 100644 --- a/src/anomalib/models/components/dinov2/layers/patch_embed.py +++ b/src/anomalib/models/components/dinov2/layers/patch_embed.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """Patch embedding module for DINOv2 Vision Transformers. @@ -8,14 +10,10 @@ optional output reshaping, and optional normalization. 
""" -from __future__ import annotations - -from typing import TYPE_CHECKING - -from torch import Tensor, nn +from collections.abc import Callable -if TYPE_CHECKING: - from collections.abc import Callable +import torch +from torch import nn def make_2tuple(x: int | tuple[int, int]) -> tuple[int, int]: @@ -79,7 +77,7 @@ def __init__( self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw) self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity() - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Embed the input image into patch tokens.""" _, _, h, w = x.shape patch_h, patch_w = self.patch_size diff --git a/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py b/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py index 1866574f74..e56c2ff4c1 100644 --- a/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py +++ b/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """SwiGLU-based feed-forward layers used in DINOv2. @@ -12,15 +14,11 @@ These layers are used as transformer FFN blocks in DINOv2 models. """ -from __future__ import annotations - -from typing import TYPE_CHECKING +from collections.abc import Callable +import torch import torch.nn.functional as F # noqa: N812 -from torch import Tensor, nn - -if TYPE_CHECKING: - from collections.abc import Callable +from torch import nn class SwiGLUFFN(nn.Module): @@ -55,7 +53,7 @@ def __init__( self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the SwiGLU feed-forward transformation.""" x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) @@ -138,7 +136,7 @@ def __init__( self.w2 = nn.Linear(in_features, hidden_aligned, bias=bias, device=device) self.w3 = nn.Linear(hidden_aligned, out_features, bias=bias, device=device) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply aligned SwiGLU feed-forward transformation.""" x1 = self.w1(x) x2 = self.w2(x) diff --git a/src/anomalib/models/components/dinov2/vision_transformer.py b/src/anomalib/models/components/dinov2/vision_transformer.py index 836eadac0a..fa9f5bdded 100644 --- a/src/anomalib/models/components/dinov2/vision_transformer.py +++ b/src/anomalib/models/components/dinov2/vision_transformer.py @@ -1,5 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2025 Meta Platforms, Inc. and affiliates. +# SPDX-License-Identifier: Apache-2.0 """Loader for DINOv2 Vision Transformer models. @@ -12,11 +14,9 @@ and related anomaly detection frameworks. 
""" -from __future__ import annotations - import math +from collections.abc import Callable, Sequence from functools import partial -from typing import TYPE_CHECKING import numpy as np import torch @@ -25,9 +25,6 @@ from anomalib.models.components.dinov2.layers import Block, MemEffAttention, Mlp, PatchEmbed, SwiGLUFFNFused -if TYPE_CHECKING: - from collections.abc import Callable, Sequence - def named_apply( fn: Callable[..., object], diff --git a/src/anomalib/models/image/dinomaly/torch_model.py b/src/anomalib/models/image/dinomaly/torch_model.py index 0c9a5cf62f..da47cd808b 100644 --- a/src/anomalib/models/image/dinomaly/torch_model.py +++ b/src/anomalib/models/image/dinomaly/torch_model.py @@ -119,9 +119,7 @@ def __init__( fuse_layer_decoder = DEFAULT_FUSE_LAYERS self.encoder_name = encoder_name - encoder = DinoV2Loader( - vit_factory=dinomaly_vision_transformer, - ).load(encoder_name) + encoder = DinoV2Loader(vit_factory=dinomaly_vision_transformer).load(encoder_name) # Extract architecture configuration based on the model name arch_config = self._get_architecture_config(encoder_name, target_layers) diff --git a/tests/unit/models/components/dinov2/test_vit.py b/tests/unit/models/components/dinov2/test_vit.py index 862457b4ba..730fe6a87e 100644 --- a/tests/unit/models/components/dinov2/test_vit.py +++ b/tests/unit/models/components/dinov2/test_vit.py @@ -7,7 +7,6 @@ import pytest import torch -from torch import Tensor from anomalib.models.components.dinov2.vision_transformer import ( DinoVisionTransformer, @@ -30,7 +29,7 @@ def tiny_vit() -> DinoVisionTransformer: @pytest.fixture() -def tiny_input() -> Tensor: +def tiny_input() -> torch.Tensor: """Return a small dummy input tensor.""" return torch.randn(2, 3, 32, 32) # (B=2, C=3, H=W=32) @@ -50,10 +49,10 @@ def test_model_initializes(tiny_vit: DinoVisionTransformer) -> None: def test_patch_embedding_shape( tiny_vit: DinoVisionTransformer, - tiny_input: Tensor, + tiny_input: torch.Tensor, ) -> None: """Patch embedding output has correct (B, N, C) shape.""" - patches: Tensor = tiny_vit.patch_embed(tiny_input) + patches: torch.Tensor = tiny_vit.patch_embed(tiny_input) b, n, c = patches.shape assert b == 2 @@ -63,10 +62,10 @@ def test_patch_embedding_shape( def test_prepare_tokens_output_shape( tiny_vit: DinoVisionTransformer, - tiny_input: Tensor, + tiny_input: torch.Tensor, ) -> None: """prepare_tokens_with_masks adds CLS and keeps correct embedding dims.""" - tokens: Tensor = tiny_vit.prepare_tokens_with_masks(tiny_input) + tokens: torch.Tensor = tiny_vit.prepare_tokens_with_masks(tiny_input) expected_tokens: int = 1 + tiny_vit.patch_embed.num_patches assert tokens.shape == (2, expected_tokens, tiny_vit.embed_dim) @@ -74,17 +73,17 @@ def test_prepare_tokens_output_shape( def test_forward_features_training_output_shapes( tiny_vit: DinoVisionTransformer, - tiny_input: Tensor, + tiny_input: torch.Tensor, ) -> None: """forward(is_training=True) returns a dict with expected shapes.""" - out: dict[str, Tensor | None] = tiny_vit(tiny_input, is_training=True) # type: ignore[assignment] + out: dict[str, torch.Tensor | None] = tiny_vit(tiny_input, is_training=True) # type: ignore[assignment] assert isinstance(out, dict) assert out["x_norm_clstoken"] is not None assert out["x_norm_patchtokens"] is not None - cls: Tensor = out["x_norm_clstoken"] # type: ignore[assignment] - patches: Tensor = out["x_norm_patchtokens"] # type: ignore[assignment] + cls: torch.Tensor = out["x_norm_clstoken"] # type: ignore[assignment] + patches: torch.Tensor = 
out["x_norm_patchtokens"] # type: ignore[assignment] assert cls.shape == (2, tiny_vit.embed_dim) assert patches.shape[1] == tiny_vit.patch_embed.num_patches @@ -92,28 +91,28 @@ def test_forward_features_training_output_shapes( def test_forward_inference_output_shape( tiny_vit: DinoVisionTransformer, - tiny_input: Tensor, + tiny_input: torch.Tensor, ) -> None: """Inference mode returns class-token output only.""" - out: Tensor = tiny_vit(tiny_input) # default is is_training=False + out: torch.Tensor = tiny_vit(tiny_input) # default is is_training=False - assert isinstance(out, Tensor) + assert isinstance(out, torch.Tensor) assert out.shape == (2, tiny_vit.embed_dim) def test_get_intermediate_layers_shapes( tiny_vit: DinoVisionTransformer, - tiny_input: Tensor, + tiny_input: torch.Tensor, ) -> None: """Intermediate layer extraction returns tensors shaped (B, tokens, C).""" - feats: tuple[Tensor, ...] = tiny_vit.get_intermediate_layers( + feats: tuple[torch.Tensor, ...] = tiny_vit.get_intermediate_layers( tiny_input, n=1, ) assert len(feats) == 1 - f: Tensor = feats[0] + f: torch.Tensor = feats[0] assert f.shape[0] == 2 # batch assert f.shape[2] == tiny_vit.embed_dim From fcb1189514ff54f03457e185e5ca3012205873a3 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:49:18 +0000 Subject: [PATCH 17/25] Update src/anomalib/models/components/dinov2/layers/block.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/block.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/block.py b/src/anomalib/models/components/dinov2/layers/block.py index 26ac43475a..6d93a9b2ee 100644 --- a/src/anomalib/models/components/dinov2/layers/block.py +++ b/src/anomalib/models/components/dinov2/layers/block.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0 From c48ed3f57d1c61dfcea04d6d6c0120078601f327 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:49:31 +0000 Subject: [PATCH 18/25] Update src/anomalib/models/components/dinov2/layers/attention.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/attention.py b/src/anomalib/models/components/dinov2/layers/attention.py index 5db5a74e1b..f48bcbc24c 100644 --- a/src/anomalib/models/components/dinov2/layers/attention.py +++ b/src/anomalib/models/components/dinov2/layers/attention.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. 
# SPDX-License-Identifier: Apache-2.0 From c3a6cdfb96b3ed02d332cf90b290cdb91ed919a5 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:49:41 +0000 Subject: [PATCH 19/25] Update src/anomalib/models/components/dinov2/layers/dino_head.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/dino_head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/dino_head.py b/src/anomalib/models/components/dinov2/layers/dino_head.py index 37fb892056..81789bb94c 100644 --- a/src/anomalib/models/components/dinov2/layers/dino_head.py +++ b/src/anomalib/models/components/dinov2/layers/dino_head.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0 From bffa69ebf02cba12644cf142bc198a3a7ee02c72 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:49:54 +0000 Subject: [PATCH 20/25] Update src/anomalib/models/components/dinov2/layers/drop_path.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/drop_path.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/drop_path.py b/src/anomalib/models/components/dinov2/layers/drop_path.py index 2a10ce53a3..5f21ff8af9 100644 --- a/src/anomalib/models/components/dinov2/layers/drop_path.py +++ b/src/anomalib/models/components/dinov2/layers/drop_path.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0 From 6ff87de98806965ad3cee3f63b6fb56037902acf Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:50:03 +0000 Subject: [PATCH 21/25] Update src/anomalib/models/components/dinov2/layers/layer_scale.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/layer_scale.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/layer_scale.py b/src/anomalib/models/components/dinov2/layers/layer_scale.py index 11c48e6fe9..5f0770cf45 100644 --- a/src/anomalib/models/components/dinov2/layers/layer_scale.py +++ b/src/anomalib/models/components/dinov2/layers/layer_scale.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0 From 4c8e61adf149d38a34f853f580b573897422fc5e Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:50:19 +0000 Subject: [PATCH 22/25] Update src/anomalib/models/components/dinov2/layers/mlp.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/mlp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/mlp.py b/src/anomalib/models/components/dinov2/layers/mlp.py index 9a46205d70..fa0c2aed83 100644 --- a/src/anomalib/models/components/dinov2/layers/mlp.py +++ b/src/anomalib/models/components/dinov2/layers/mlp.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. 
# SPDX-License-Identifier: Apache-2.0 From 86e15765f825ef7ead52ae8b5d9382c584af5751 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:50:29 +0000 Subject: [PATCH 23/25] Update src/anomalib/models/components/dinov2/layers/patch_embed.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/patch_embed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/patch_embed.py b/src/anomalib/models/components/dinov2/layers/patch_embed.py index bf3e50d9f1..73ee524064 100644 --- a/src/anomalib/models/components/dinov2/layers/patch_embed.py +++ b/src/anomalib/models/components/dinov2/layers/patch_embed.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0 From 7c13702bb926e308a32a0af0e8dcb7850c1f56a2 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:50:39 +0000 Subject: [PATCH 24/25] Update src/anomalib/models/components/dinov2/layers/swiglu_ffn.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/layers/swiglu_ffn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py b/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py index e56c2ff4c1..e7c64a20c9 100644 --- a/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py +++ b/src/anomalib/models/components/dinov2/layers/swiglu_ffn.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0 From 34aad7877bd7c1785c73e34544382c52d1b728c0 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 24 Nov 2025 09:50:49 +0000 Subject: [PATCH 25/25] Update src/anomalib/models/components/dinov2/vision_transformer.py Signed-off-by: Samet Akcay --- src/anomalib/models/components/dinov2/vision_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/models/components/dinov2/vision_transformer.py b/src/anomalib/models/components/dinov2/vision_transformer.py index fa9f5bdded..64308316a7 100644 --- a/src/anomalib/models/components/dinov2/vision_transformer.py +++ b/src/anomalib/models/components/dinov2/vision_transformer.py @@ -1,5 +1,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2025 Meta Platforms, Inc. and affiliates. # SPDX-License-Identifier: Apache-2.0
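
Taken together, patches 10-12 leave DinoV2Loader as the single entry point for DINOv2 backbones: it parses names such as "dinov2_vit_small_14" or "dinov2reg_vit_base_14", builds the Vision Transformer (optionally through a user-supplied vit_factory module exposing vit_small / vit_base / vit_large constructors), and downloads the pretrained weights into a local cache. The sketch below is a minimal, hedged example of how the post-series API might be exercised; it is not part of the patches themselves. The dummy input tensor, the printed shape, and the cache directory are illustrative assumptions, and the first call requires network access to fetch the DINOv2 weights.

import torch

from anomalib.models.components.dinov2 import DinoV2Loader
from anomalib.models.image.dinomaly.components import vision_transformer as dinomaly_vit

# Build the default DINOv2 backbone; weights are downloaded once and cached
# under ./pre_trained/ (network access is needed on first use).
loader = DinoV2Loader(cache_dir="./pre_trained/")
backbone = loader.load("dinov2_vit_small_14")
backbone.eval()

# Extract patch features the same way patch 10 does for AnomalyDINO: take the
# last intermediate layer, shaped (batch, num_patches, embed_dim).
image = torch.randn(1, 3, 518, 518)  # dummy input at the default DINOv2 resolution
with torch.inference_mode():
    patch_tokens = backbone.get_intermediate_layers(image, n=1)[0]
print(patch_tokens.shape)  # e.g. torch.Size([1, 1369, 384]) for ViT-S/14 at 518x518

# Dinomaly-style usage with a custom ViT factory: any module that exposes
# vit_small / vit_base / vit_large constructors can be passed as vit_factory.
dinomaly_backbone = DinoV2Loader(vit_factory=dinomaly_vit).load("dinov2reg_vit_base_14")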