Commits
29 commits
5b4931b
Initial Implementation of GLASS Model
code-dev05 Mar 26, 2025
4789f49
Created the trainer class for glass model
code-dev05 Apr 14, 2025
050fd4c
Added suggested changes
code-dev05 Apr 27, 2025
cdd0984
Modified forward method for model
code-dev05 Apr 27, 2025
381eec6
Fixed backbone loading logic
code-dev05 Apr 30, 2025
9b1c51a
Added type for input shape
code-dev05 May 4, 2025
161005c
Fixed bugs
code-dev05 May 4, 2025
3d78beb
Merge branch 'main' into feature/model/glass
samet-akcay May 7, 2025
617cf49
Changed files as needed
code-dev05 May 13, 2025
f9d3207
Merge remote-tracking branch 'origin/feature/model/glass' into featur…
code-dev05 May 13, 2025
7fea20f
Matched code to the original implementation
code-dev05 Jun 19, 2025
1beedf5
Added support for gpu
code-dev05 Jun 23, 2025
838bc50
Refactored code from lightning model to torch model
code-dev05 Jul 1, 2025
1baa0b7
GPU bug fixed
code-dev05 Jul 2, 2025
f066b3c
used image device in torch model
code-dev05 Jul 2, 2025
6e780b0
fixed bug
code-dev05 Jul 2, 2025
b1be6f5
Added validation step
code-dev05 Jul 11, 2025
20d97dd
Merge branch 'main' into feature/model/glass
samet-akcay Jul 14, 2025
d5affe4
Refactored code for better readability
code-dev05 Jul 28, 2025
f008537
Merge remote-tracking branch 'origin/feature/model/glass' into featur…
code-dev05 Jul 28, 2025
a1097e5
Set automatic optimization to False and made component functions
code-dev05 Jul 31, 2025
7e9d4d4
Resolved third-party-programs.txt conflict
code-dev05 Aug 5, 2025
44dcd60
Added automated download for dtd dataset in Glass Model
code-dev05 Aug 12, 2025
da57095
Removed some input args
code-dev05 Aug 14, 2025
ba5a6dd
Change in default parameters
code-dev05 Aug 14, 2025
714a3c3
Fixed default backbone name
code-dev05 Aug 14, 2025
1a3519c
Changed configure pre_processor method
code-dev05 Aug 14, 2025
9e12285
Merge remote-tracking branch 'up/main' into feature/model/glass
code-dev05 Sep 13, 2025
5466d46
Made some changes to the workflow of GLASS Model
code-dev05 Sep 13, 2025
2 changes: 1 addition & 1 deletion src/anomalib/models/components/__init__.py
@@ -38,7 +38,7 @@

 from .base import AnomalibModule, BufferListMixin, DynamicBufferMixin, MemoryBankMixin
 from .dimensionality_reduction import PCA, SparseRandomProjection
-from .feature_extractors import TimmFeatureExtractor
+from .feature_extractors import TimmFeatureExtractor, NetworkFeatureAggregator
 from .filters import GaussianBlur2d
 from .sampling import KCenterGreedy
 from .stats import GaussianKDE, MultiVariateGaussian
src/anomalib/models/components/feature_extractors/__init__.py
@@ -28,7 +28,7 @@

 from .timm import TimmFeatureExtractor
 from .utils import dryrun_find_featuremap_dims

+from .network_feature_extractor import NetworkFeatureAggregator
 __all__ = [
     "dryrun_find_featuremap_dims",
     "TimmFeatureExtractor",
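With both export changes in place, the aggregator resolves from the components package. A minimal import check (a sketch, not part of the diff):

# The re-export above makes this import path available.
from anomalib.models.components import NetworkFeatureAggregator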
91 changes: 91 additions & 0 deletions src/anomalib/models/components/feature_extractors/network_feature_extractor.py
@@ -0,0 +1,91 @@
import torch
from torch import nn


class NetworkFeatureAggregator(nn.Module):
    """Efficient extraction of network features.

    Runs the network only up to the last layer in the list of layers from
    which features should be extracted.

    Args:
        backbone: torchvision model to hook into.
        layers_to_extract_from: list of layer names (dotted names allowed).
        pre_trained: whether the backbone weights are pre-trained and frozen.
    """

    def __init__(self, backbone, layers_to_extract_from, pre_trained=False):
        super().__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.pre_trained = pre_trained
        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        # Remove stale hooks left over from a previous aggregator instance.
        for handle in self.backbone.hook_handles:
            handle.remove()
        self.outputs = {}

        for extract_layer in layers_to_extract_from:
            self.register_hook(extract_layer)

        # nn.Module has no `device` attribute of its own; take it from the
        # backbone weights.
        self.device = next(backbone.parameters()).device
        self.to(self.device)

    def forward(self, images, eval=True):  # noqa: A002
        self.outputs.clear()
        if not self.pre_trained and not eval:
            # Backbone is being trained: keep the autograd graph.
            try:
                self.backbone(images)
            except LastLayerToExtractReachedException:
                pass
        else:
            with torch.no_grad():
                try:
                    _ = self.backbone(images)
                except LastLayerToExtractReachedException:
                    pass
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the channel dimension of every extracted layer for input_shape."""
        _input = torch.ones([1] + list(input_shape)).to(self.device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]

    def register_hook(self, layer_name):
        module = self.find_module(self.backbone, layer_name)
        if module is None:
            raise ValueError(f"Module {layer_name} not found in the model")
        forward_hook = ForwardHook(self.outputs, layer_name, self.layers_to_extract_from[-1])
        if isinstance(module, nn.Sequential):
            # Hook the last block so the captured output equals the layer's output.
            hook = module[-1].register_forward_hook(forward_hook)
        else:
            hook = module.register_forward_hook(forward_hook)
        self.backbone.hook_handles.append(hook)

    def find_module(self, model, module_name):
        """Resolves a (possibly dotted) module name to a module object, or None."""
        for name, module in model.named_modules():
            if name == module_name:
                return module
            if "." in module_name:
                father, child = module_name.split(".", 1)
                if name == father:
                    return self.find_module(module, child)
        return None


class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = layer_name == last_layer_to_extract

    def __call__(self, module, input, output):  # noqa: A002
        self.hook_dict[self.layer_name] = output
        # Abort the forward pass once the last requested layer has fired; the
        # aggregator's forward() catches this exception.
        if self.raise_exception_to_break:
            raise LastLayerToExtractReachedException
        return None


class LastLayerToExtractReachedException(Exception):
    """Raised by ForwardHook to end the forward pass early."""
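For reviewers, a minimal usage sketch of the aggregator (not part of the diff; the backbone choice and layer names are illustrative, and shapes assume a torchvision resnet18 at 224x224):

# Usage sketch: hook two residual stages and read their outputs.
import torch
import torchvision.models as models

backbone = models.resnet18(pretrained=True)
aggregator = NetworkFeatureAggregator(backbone, ["layer2", "layer3"], pre_trained=True)

images = torch.rand(4, 3, 224, 224)
features = aggregator(images)  # dict: layer name -> feature tensor
# resnet18: features["layer2"] is [4, 128, 28, 28], features["layer3"] is [4, 256, 14, 14]
print(aggregator.feature_dimensions((3, 224, 224)))  # [128, 256]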
1 change: 1 addition & 0 deletions src/anomalib/models/image/glass/__init__.py
@@ -0,0 +1 @@
from .lightning_model import Glass
50 changes: 50 additions & 0 deletions src/anomalib/models/image/glass/backbones.py
@@ -0,0 +1,50 @@
import torchvision.models as models  # noqa: F401  (referenced by the eval() strings below)
import timm  # noqa: F401  (referenced by the eval() strings below)

_BACKBONES = {
    "alexnet": "models.alexnet(pretrained=True)",
    "resnet18": "models.resnet18(pretrained=True)",
    "resnet50": "models.resnet50(pretrained=True)",
    "resnet101": "models.resnet101(pretrained=True)",
    "resnext101": "models.resnext101_32x8d(pretrained=True)",
    "resnet200": 'timm.create_model("resnet200", pretrained=True)',
    "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)',
    "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)',
    "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)',
    "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)',
    "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)',
    "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)',
    "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)',
    "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)',
    "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)',
    "vgg11": "models.vgg11(pretrained=True)",
    "vgg19": "models.vgg19(pretrained=True)",
    "vgg19_bn": "models.vgg19_bn(pretrained=True)",
    "wideresnet50": "models.wide_resnet50_2(pretrained=True)",
    "wideresnet101": "models.wide_resnet101_2(pretrained=True)",
    "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)',
    "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)',
    "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)',
    "densenet121": 'timm.create_model("densenet121", pretrained=True)',
    "densenet201": 'timm.create_model("densenet201", pretrained=True)',
    "inception_v4": 'timm.create_model("inception_v4", pretrained=True)',
    "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)',
    "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)',
    "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)',
    "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)',
    "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)',
    "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)',
    "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)',
    "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)',
    "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)',
    "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)',
    "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)',
    "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)',
    "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)',
    "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)',
    "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)',
}


def load(name):
    """Instantiates the pre-trained backbone registered under ``name``."""
    return eval(_BACKBONES[name])  # noqa: S307  (the table above only holds trusted constructor strings)
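The registry maps a short name to a constructor string that load() evaluates on demand, putting torchvision and timm models behind one interface. A usage sketch (not part of the diff):

# Resolving a backbone by registry name.
from anomalib.models.image.glass.backbones import load

backbone = load("wideresnet50")  # evaluates models.wide_resnet50_2(pretrained=True)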
161 changes: 161 additions & 0 deletions src/anomalib/models/image/glass/lightning_model.py
@@ -0,0 +1,161 @@
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import nn, optim

from anomalib.data import Batch
from anomalib.metrics import Evaluator
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import FocalLoss
from .torch_model import GlassModel


class Glass(AnomalibModule):
    """GLASS: Global and Local Anomaly co-Synthesis Strategy.

    Trains a discriminator on normal features, Gaussian-perturbed features
    refined by gradient ascent, and features from augmented images.
    """

    def __init__(
        self,
        backbone,
        input_shape,
        pretrain_embed_dim,
        target_embed_dim,
        patchsize: int = 3,
        patchstride: int = 1,
        pre_trained: bool = True,
        layers: list[str] = ["layer1", "layer2", "layer3"],  # noqa: B006
        pre_proj: int = 1,
        dsc_layers: int = 2,
        dsc_hidden: int = 1024,
        dsc_margin: float = 0.5,
        pre_processor: PreProcessor | bool = True,
        post_processor: PostProcessor | bool = True,
        evaluator: Evaluator | bool = True,
        visualizer: Visualizer | bool = True,
        mining: int = 1,
        noise: float = 0.015,
        radius: float = 0.75,
        p: float = 0.5,
        lr: float = 0.0001,
        step: int = 0,
    ) -> None:
        super().__init__(
            pre_processor=pre_processor,
            post_processor=post_processor,
            evaluator=evaluator,
            visualizer=visualizer,
        )

        # training_step drives three optimizers by hand, so Lightning's
        # automatic optimization must be switched off.
        self.automatic_optimization = False

        self.model = GlassModel(
            input_shape=input_shape,
            pretrain_embed_dim=pretrain_embed_dim,
            target_embed_dim=target_embed_dim,
            backbone=backbone,
            pre_trained=pre_trained,
            patchsize=patchsize,
            patchstride=patchstride,
            layers=layers,
            pre_proj=pre_proj,
            dsc_layers=dsc_layers,
            dsc_hidden=dsc_hidden,
            dsc_margin=dsc_margin,
        )

        self.p = p
        self.radius = radius
        self.mining = mining
        self.noise = noise
        self.distribution = 0
        self.lr = lr
        self.step = step

        self.focal_loss = FocalLoss()

    def configure_optimizers(self) -> list[optim.Optimizer | None]:
        optimizers = []
        if not self.model.pre_trained:
            backbone_opt = optim.AdamW(self.model.forward_modules["feature_aggregator"].backbone.parameters(), self.lr)
            optimizers.append(backbone_opt)
        else:
            # None placeholders keep the positional unpacking in training_step stable.
            optimizers.append(None)

        if self.model.pre_proj > 0:
            proj_opt = optim.AdamW(self.model.pre_projection.parameters(), self.lr, weight_decay=1e-5)
            optimizers.append(proj_opt)
        else:
            optimizers.append(None)

        # The discriminator is trained with a higher learning rate than the
        # projection and backbone.
        dsc_opt = optim.AdamW(self.model.discriminator.parameters(), lr=self.lr * 2)
        optimizers.append(dsc_opt)

        return optimizers

    def training_step(self, batch: Batch, batch_idx: int) -> STEP_OUTPUT:
        backbone_opt, proj_opt, dsc_opt = self.optimizers()

        self.model.forward_modules.eval()
        if self.model.pre_proj > 0:
            self.model.pre_projection.train()
        self.model.discriminator.train()

        dsc_opt.zero_grad()
        if proj_opt is not None:
            proj_opt.zero_grad()
        if backbone_opt is not None:
            backbone_opt.zero_grad()

        img = batch.image
        aug = batch.aug

        true_feats, fake_feats = self.model(img, aug)

        mask_s_gt = batch.mask_s.reshape(-1, 1)
        noise = torch.normal(0, self.noise, true_feats.shape).to(true_feats.device)
        gaus_feats = true_feats + noise

        for step in range(self.step + 1):
            scores = self.model.discriminator(torch.cat([true_feats, gaus_feats]))
            true_scores = scores[: len(true_feats)]
            gaus_scores = scores[len(true_feats):]
            true_loss = nn.BCELoss()(true_scores, torch.zeros_like(true_scores))
            gaus_loss = nn.BCELoss()(gaus_scores, torch.ones_like(gaus_scores))
            bce_loss = true_loss + gaus_loss

            if step == self.step:
                break

            # Gradient ascent: move the Gaussian-perturbed features towards the
            # discriminator's decision boundary before the next scoring pass.
            grad = torch.autograd.grad(gaus_loss, [gaus_feats])[0]
            grad_norm = torch.norm(grad, dim=1).view(-1, 1)
            grad_normalized = grad / (grad_norm + 1e-10)

            with torch.no_grad():
                gaus_feats.add_(0.001 * grad_normalized)

        fake_scores = self.model.discriminator(fake_feats)

        if self.p > 0:
            # Hard-sample mining: keep only scores whose squared distance to the
            # ground-truth mask lies above the p-quantile.
            fake_dist = (fake_scores - mask_s_gt) ** 2
            d_hard = torch.quantile(fake_dist, q=self.p)
            fake_scores_ = fake_scores[fake_dist >= d_hard].unsqueeze(1)
            mask_ = mask_s_gt[fake_dist >= d_hard].unsqueeze(1)
        else:
            fake_scores_ = fake_scores
            mask_ = mask_s_gt
        output = torch.cat([1 - fake_scores_, fake_scores_], dim=1)
        focal_loss = self.focal_loss(output, mask_)

        loss = bce_loss + focal_loss
        # Manual optimization: use manual_backward so precision plugins are honoured.
        self.manual_backward(loss)

        if proj_opt is not None:
            proj_opt.step()
        if backbone_opt is not None:
            backbone_opt.step()
        dsc_opt.step()
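A hypothetical instantiation sketch for the new model (not part of the diff; the backbone name, input shape, and embedding widths below are illustrative, not defaults from this PR):

# Sketch: the first four arguments have no defaults and must be supplied.
from anomalib.models.image.glass import Glass

model = Glass(
    backbone="resnet50",        # assumed to be a key in backbones._BACKBONES
    input_shape=(3, 288, 288),  # illustrative input resolution
    pretrain_embed_dim=1536,    # illustrative embedding widths
    target_embed_dim=1536,
)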