-
Notifications
You must be signed in to change notification settings - Fork 844
π feat(model): add GLASS model into Anomalib #2629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
code-dev05
wants to merge
29
commits into
open-edge-platform:feature/model/glass
Choose a base branch
from
code-dev05:feature/model/glass
base: feature/model/glass
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
5b4931b
Initial Implementation of GLASS Model
code-dev05 4789f49
Created the trainer class for glass model
code-dev05 050fd4c
Added suggested changes
code-dev05 cdd0984
Modified forward method for model
code-dev05 381eec6
Fixed backbone loading logic
code-dev05 9b1c51a
Added type for input shape
code-dev05 161005c
Fixed bugs
code-dev05 3d78beb
Merge branch 'main' into feature/model/glass
samet-akcay 617cf49
Changed files as needed
code-dev05 f9d3207
Merge remote-tracking branch 'origin/feature/model/glass' into featurβ¦
code-dev05 7fea20f
Matched code to the original implementation
code-dev05 1beedf5
Added support for gpu
code-dev05 838bc50
Refactored code from lightning model to torch model
code-dev05 1baa0b7
GPU bug fixed
code-dev05 f066b3c
used image device in torch model
code-dev05 6e780b0
fixed bug
code-dev05 b1be6f5
Added validation step
code-dev05 20d97dd
Merge branch 'main' into feature/model/glass
samet-akcay d5affe4
Refactored code for better readability
code-dev05 f008537
Merge remote-tracking branch 'origin/feature/model/glass' into featurβ¦
code-dev05 a1097e5
Set automatic optimization to False and made component functions
code-dev05 7e9d4d4
Resolved third-party-programs.txt conflict
code-dev05 44dcd60
Added automated download for dtd dataset in Glass Model
code-dev05 da57095
Removed some input args
code-dev05 ba5a6dd
Change in default parameters
code-dev05 714a3c3
Fixed default backbone name
code-dev05 1a3519c
Changed configure pre_processor method
code-dev05 9e12285
Merge remote-tracking branch 'up/main' into feature/model/glass
code-dev05 5466d46
Made some changes to the workflow of GLASS Model
code-dev05 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| model: | ||
| class_path: anomalib.models.Glass | ||
| init_args: | ||
| input_shape: [288, 288] | ||
| backbone: resnet18 | ||
| pretrain_embed_dim: 1024 | ||
| target_embed_dim: 1024 | ||
| patchsize: 3 | ||
| patchstride: 1 | ||
| pre_trained: true | ||
| pre_projection: 1 | ||
| discriminator_layers: 2 | ||
| discriminator_hidden: 1024 | ||
| discriminator_margin: 0.5 | ||
| learning_rate: 0.0001 | ||
| step: 20 | ||
| svd: 0 | ||
|
|
||
| trainer: | ||
| max_epochs: 640 | ||
| callbacks: | ||
| - class_path: lightning.pytorch.callbacks.EarlyStopping | ||
| init_args: | ||
| patience: 5 | ||
| monitor: pixel_AUROC | ||
| mode: max |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """GLASS - Unsupervised anomaly detection via Gradient Ascent for Industrial Anomaly detection and localization. | ||
|
|
||
| This module implements the GLASS model for unsupervised anomaly detection and localization. GLASS synthesizes both | ||
| global and local anomalies using Gaussian noise guided by gradient ascent to enhance weak defect detection in | ||
| industrial settings. | ||
|
|
||
| The model consists of: | ||
| - A feature extractor and feature adaptor to obtain robust normal representations | ||
| - A Global Anomaly Synthesis (GAS) module that perturbs features using Gaussian noise and gradient ascent with | ||
| truncated projection | ||
| - A Local Anomaly Synthesis (LAS) module that overlays augmented textures onto images using Perlin noise masks | ||
| - A shared discriminator trained with features from normal, global, and local synthetic samples | ||
|
|
||
| Paper: `A Unified Anomaly Synthesis Strategy with Gradient Ascent for Industrial Anomaly Detection and Localization | ||
| <https://arxiv.org/pdf/2407.09359>` | ||
| """ | ||
|
|
||
| from .lightning_model import Glass | ||
|
|
||
| __all__ = ["Glass"] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Utility functions for GLASS Model.""" | ||
|
|
||
| from .aggregator import Aggregator | ||
| from .discriminator import Discriminator | ||
| from .patch_maker import PatchMaker | ||
| from .preprocessing import Preprocessing | ||
| from .projection import Projection | ||
| from .rescale_segmentor import RescaleSegmentor | ||
|
|
||
| __all__ = ["Aggregator", | ||
| "Discriminator", | ||
| "PatchMaker", | ||
| "Preprocessing", | ||
| "Projection", | ||
| "RescaleSegmentor", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Aggregates and reshapes features to a target dimension.""" | ||
|
|
||
| import torch | ||
| import torch.nn.functional as f | ||
|
|
||
|
|
||
| class Aggregator(torch.nn.Module): | ||
| """Aggregates and reshapes features to a target dimension. | ||
|
|
||
| Input: Multi-dimensional feature tensors | ||
| Output: Reshaped and pooled features of specified target dimension | ||
| """ | ||
|
|
||
| def __init__(self, target_dim: int) -> None: | ||
| super().__init__() | ||
| self.target_dim = target_dim | ||
|
|
||
| def forward(self, features: torch.Tensor) -> torch.Tensor: | ||
| """Returns reshaped and average pooled features.""" | ||
| features = features.reshape(len(features), 1, -1) | ||
| features = f.adaptive_avg_pool1d(features, self.target_dim) | ||
| return features.reshape(len(features), -1) |
52 changes: 52 additions & 0 deletions
52
src/anomalib/models/image/glass/components/discriminator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Discriminator network for anomaly detection.""" | ||
|
|
||
| import torch | ||
|
|
||
| from .init_weight import init_weight | ||
|
|
||
|
|
||
| class Discriminator(torch.nn.Module): | ||
| """Discriminator network for anomaly detection. | ||
|
|
||
| Args: | ||
| in_planes: Input feature dimension | ||
| n_layers: Number of layers | ||
| hidden: Hidden layer dimensions | ||
| """ | ||
|
|
||
| def __init__(self, in_planes: int, n_layers: int = 2, hidden: int | None = None) -> None: | ||
| super().__init__() | ||
|
|
||
| hidden_ = in_planes if hidden is None else hidden | ||
| self.body = torch.nn.Sequential() | ||
| for i in range(n_layers - 1): | ||
| in_ = in_planes if i == 0 else hidden_ | ||
| hidden_ = int(hidden_ // 1.5) if hidden is None else hidden | ||
| self.body.add_module( | ||
| f"block{i + 1}", | ||
| torch.nn.Sequential( | ||
| torch.nn.Linear(in_, hidden_), | ||
| torch.nn.BatchNorm1d(hidden_), | ||
| torch.nn.LeakyReLU(0.2), | ||
| ), | ||
| ) | ||
| self.tail = torch.nn.Sequential( | ||
| torch.nn.Linear(hidden_, 1, bias=False), | ||
| torch.nn.Sigmoid(), | ||
| ) | ||
| self.apply(init_weight) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| """Performs a forward pass through the discriminator network. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): Input tensor of shape (B, in_planes), where B is the batch size. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Output tensor of shape (B, 1) containing probability scores. | ||
| """ | ||
| x = self.body(x) | ||
| return self.tail(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Initializes network weights using Xavier normal initialization.""" | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
|
|
||
| def init_weight(m: nn.Module) -> None: | ||
| """Initializes network weights using Xavier normal initialization. | ||
|
|
||
| Applies Xavier initialization for linear layers and normal initialization | ||
| for convolutional and batch normalization layers. | ||
| """ | ||
| if isinstance(m, torch.nn.Linear): | ||
| torch.nn.init.xavier_normal_(m.weight) | ||
| if isinstance(m, torch.nn.BatchNorm2d): | ||
| m.weight.data.normal_(1.0, 0.02) | ||
| m.bias.data.fill_(0) | ||
| elif isinstance(m, torch.nn.Conv2d): | ||
| m.weight.data.normal_(0.0, 0.02) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Handles patch-based processing of feature maps.""" | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| class PatchMaker: | ||
| """Handles patch-based processing of feature maps. | ||
|
|
||
| This class provides utilities for converting feature maps into patches, | ||
| reshaping patch scores back to original dimensions, and computing global | ||
| anomaly scores from patch-wise predictions. | ||
|
|
||
| Attributes: | ||
| patchsize (int): Size of each patch (patchsize x patchsize). | ||
| stride (int or None): Stride used for patch extraction. Defaults to patchsize if None. | ||
| """ | ||
|
|
||
| def __init__(self, patchsize: int, stride: int | None = None) -> None: | ||
| self.patchsize = patchsize | ||
| self.stride = stride if stride is not None else patchsize | ||
|
|
||
| def patchify( | ||
| self, | ||
| features: torch.Tensor, | ||
| return_spatial_info: bool = False, | ||
| ) -> tuple[torch.Tensor, list[int]] | torch.Tensor: | ||
| """Converts a batch of feature maps into patches. | ||
|
|
||
| Args: | ||
| features (torch.Tensor): Input feature maps of shape (B, C, H, W). | ||
| return_spatial_info (bool): If True, also returns spatial patch count. Default is False. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Output tensor of shape (B, N, C, patchsize, patchsize), where N is number of patches. | ||
| list[int], optional: Number of patches in (height, width) dimensions, only if return_spatial_info is True. | ||
| """ | ||
| padding = int((self.patchsize - 1) / 2) | ||
| unfolder = torch.nn.Unfold( | ||
| kernel_size=self.patchsize, | ||
| stride=self.stride, | ||
| padding=padding, | ||
| dilation=1, | ||
| ) | ||
| unfolded_features = unfolder(features) | ||
| number_of_total_patches = [] | ||
| for s in features.shape[-2:]: | ||
| n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1 | ||
| number_of_total_patches.append(int(n_patches)) | ||
| unfolded_features = unfolded_features.reshape( | ||
| *features.shape[:2], | ||
| self.patchsize, | ||
| self.patchsize, | ||
| -1, | ||
| ) | ||
| unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) | ||
|
|
||
| if return_spatial_info: | ||
| return unfolded_features, number_of_total_patches | ||
| return unfolded_features | ||
|
|
||
| @staticmethod | ||
| def unpatch_scores(x: torch.Tensor, batchsize: int) -> torch.Tensor: | ||
| """Reshapes patch scores back into per-batch format. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): Input tensor of shape (B * N, ...). | ||
| batchsize (int): Original batch size. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Reshaped tensor of shape (B, N, ...). | ||
| """ | ||
| return x.reshape(batchsize, -1, *x.shape[1:]) | ||
|
|
||
| @staticmethod | ||
| def compute_score(x: torch.Tensor) -> torch.Tensor: | ||
| """Computes final anomaly scores from patch-wise predictions. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): Patch scores of shape (B, N, 1). | ||
|
|
||
| Returns: | ||
| torch.Tensor: Final anomaly score per image, shape (B,). | ||
| """ | ||
| x = x[:, :, 0] # remove last dimension if singleton | ||
| return torch.max(x, dim=1).values |
66 changes: 66 additions & 0 deletions
66
src/anomalib/models/image/glass/components/preprocessing.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Maps input features to a fixed dimension using adaptive average pooling.""" | ||
|
|
||
| import torch | ||
| import torch.nn.functional as f | ||
|
|
||
|
|
||
| class MeanMapper(torch.nn.Module): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks very similar to the Aggregator class. |
||
| """Maps input features to a fixed dimension using adaptive average pooling. | ||
|
|
||
| Input: Variable-sized feature tensors | ||
| Output: Fixed-size feature representations | ||
| """ | ||
|
|
||
| def __init__(self, preprocessing_dim: int) -> None: | ||
| super().__init__() | ||
| self.preprocessing_dim = preprocessing_dim | ||
|
|
||
| def forward(self, features: torch.Tensor) -> torch.Tensor: | ||
| """Applies adaptive average pooling to reshape features to a fixed size. | ||
|
|
||
| Args: | ||
| features (torch.Tensor): Input tensor of shape (B, *) where * denotes | ||
| any number of remaining dimensions. It is flattened before pooling. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Output tensor of shape (B, D), where D is `preprocessing_dim`. | ||
| """ | ||
| features = features.reshape(len(features), 1, -1) | ||
| return f.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) | ||
|
|
||
|
|
||
| class Preprocessing(torch.nn.Module): | ||
| """Handles initial feature preprocessing across multiple input dimensions. | ||
|
|
||
| Input: List of features from different backbone layers | ||
| Output: Processed features with consistent dimensionality | ||
| """ | ||
|
|
||
| def __init__(self, input_dims: list[int | tuple[int, int]], output_dim: int) -> None: | ||
| super().__init__() | ||
| self.input_dims = input_dims | ||
| self.output_dim = output_dim | ||
|
|
||
| self.preprocessing_modules = torch.nn.ModuleList() | ||
| for _ in input_dims: | ||
| module = MeanMapper(output_dim) | ||
| self.preprocessing_modules.append(module) | ||
|
|
||
| def forward(self, features: list[torch.Tensor]) -> torch.Tensor: | ||
| """Applies preprocessing modules to a list of input feature tensors. | ||
|
|
||
| Args: | ||
| features (list of torch.Tensor): List of feature maps from different | ||
| layers of the backbone network. Each tensor can have a different shape. | ||
|
|
||
| Returns: | ||
| torch.Tensor: A single tensor with shape (B, N, D), where B is the batch size, | ||
| N is the number of feature maps, and D is the output dimension (`output_dim`). | ||
| """ | ||
| features_ = [] | ||
| for module, feature in zip(self.preprocessing_modules, features, strict=False): | ||
| features_.append(module(feature)) | ||
| return torch.stack(features_, dim=1) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.