zhiqing0205 committed
Commit · 74acc06
Parent(s): 4a80644

Add basic Python scripts and documentation
Browse files
- LogSAD技术详解.md +621 -0
- README.md +102 -0
- compute_coreset.py +121 -0
- environment.yml +116 -0
- evaluation.py +257 -0
- imagenet_template.py +82 -0
- model_ensemble.py +1034 -0
- model_ensemble_few_shot.py +935 -0
- prompt_ensemble.py +121 -0
- requirements.txt +77 -0
LogSAD技术详解.md
ADDED
@@ -0,0 +1,621 @@
# LogSAD: A Detailed Guide to Training-free Anomaly Detection with Vision and Language Foundation Models

## Project Overview

LogSAD (Towards Training-free Anomaly Detection with Vision and Language Foundation Models) is a training-free anomaly detection method published at CVPR 2025. By combining several pre-trained vision and language foundation models, it detects both logical and structural anomalies on the MVTec LOCO dataset.

## Overall Architecture and Pipeline

### Core Idea
The core idea of LogSAD is to exploit the strong representations of pre-trained models and to detect anomalies through multi-modal feature fusion and logical reasoning, without training on the target dataset.

### System Architecture
```
Input image (448x448)
        ↓
┌─────────────────────────────────────────────────────┐
│  Multi-modal feature extraction                      │
│   ├─ CLIP ViT-L-14 (image + text features)           │
│   ├─ DINOv2 ViT-L-14 (image features)                │
│   └─ SAM ViT-H (instance segmentation)               │
└─────────────────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────────────────┐
│  Feature processing and fusion                       │
│   ├─ K-means clustering segmentation                 │
│   ├─ Text-guided semantic segmentation               │
│   └─ Multi-scale feature fusion                      │
└─────────────────────────────────────────────────────┘
        ↓
┌─────────────────────────────────────────────────────┐
│  Anomaly detection                                   │
│   ├─ Structural anomaly detection (PatchCore)        │
│   ├─ Logical anomaly detection (histogram matching)  │
│   └─ Instance matching (Hungarian algorithm)         │
└─────────────────────────────────────────────────────┘
        ↓
Final anomaly score
```

## Pre-trained Models

### 1. CLIP ViT-L-14
**Role**: the core of vision-language understanding
- **Model**: `hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K`
- **Input size**: 448×448
- **Feature extraction layers**: [6, 12, 18, 24]
- **Feature dimension**: 1024
- **Output feature map size**: 32×32 → 64×64 (interpolated)

**Implementation**:
```python
# model_ensemble.py:96-97
self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
self.feature_list = [6, 12, 18, 24]
```

**Role in the ensemble**:
- Provides semantic feature representations of the image
- Encodes the semantics of different objects through text prompts
- Used for semantic segmentation and anomaly classification

### 2. DINOv2 ViT-L-14
**Role**: provides richer visual features
- **Model**: `dinov2_vitl14`
- **Feature extraction layers**: [6, 12, 18, 24]
- **Feature dimension**: 1024
- **Output feature map size**: 32×32 → 64×64 (interpolated)

**Implementation**:
```python
# model_ensemble.py:181-186
from dinov2.dinov2.hub.backbones import dinov2_vitl14
self.model_dinov2 = dinov2_vitl14()
self.feature_list_dinov2 = [6, 12, 18, 24]
```

**Role in the ensemble**:
- Provides stronger visual features for certain categories (splicing_connectors, breakfast_box, juice_bottle)
- Complements the CLIP features and improves detection accuracy

### 3. SAM (Segment Anything Model)
**Role**: instance segmentation
- **Model**: ViT-H variant
- **Checkpoint**: `./checkpoint/sam_vit_h_4b8939.pth`
- **Function**: automatically generates object masks

**Implementation**:
```python
# model_ensemble.py:102-103
self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth")
self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
```

**Role in the ensemble**:
- Provides precise object boundaries
- Used for instance-level anomaly detection
- Fused with the semantic segmentation results

## Data Processing and Feature Map Resizing

### Image Preprocessing

1. **Input size standardization**:
```python
# evaluation.py:184
datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
```

2. **Normalization**:
```python
# model_ensemble.py:88-92
self.transform = v2.Compose([
    v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                 std=(0.26862954, 0.26130258, 0.27577711)),
])
```

3. **Feature map resizing**:
```python
# model_ensemble.py:155-156
self.feat_size = 64      # target feature map size
self.ori_feat_size = 32  # original feature map size
```

### Resize Flow in Detail

**CLIP feature processing**:
```python
# model_ensemble.py:245-255
# 1. Interpolate from 32x32 to 64x64
patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size),
                                  mode=self.inter_mode, align_corners=self.align_corners)
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
```

**DINOv2 feature processing**:
```python
# model_ensemble.py:253-263
# same interpolation flow
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size),
                                    mode=self.inter_mode, align_corners=self.align_corners)
```

**Interpolation parameters** (a standalone sketch of this resize step follows this list):
- **Mode**: bilinear (`bilinear`)
- **Corner alignment**: `align_corners=True`
- **Anti-aliasing**: `antialias=True`

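The resize step can be reproduced in isolation. The sketch below is a minimal, self-contained example of reshaping 32×32 patch tokens into a 2D grid and bilinearly upsampling them to 64×64; the shapes and the `bilinear`/`align_corners=True` settings follow the description above, while the random tensor is only a stand-in for real CLIP patch tokens.

```python
import torch
import torch.nn.functional as F

# Stand-in for one layer of CLIP patch tokens: 32*32 tokens, 1024-dim each.
ori_feat_size, feat_size, dim = 32, 64, 1024
patch_tokens = torch.randn(1, ori_feat_size * ori_feat_size, dim)

# (1, N, C) -> (1, H, W, C) -> (1, C, H, W) so F.interpolate can treat it as an image.
grid = patch_tokens.view(1, ori_feat_size, ori_feat_size, dim).permute(0, 3, 1, 2)
up = F.interpolate(grid, size=(feat_size, feat_size), mode="bilinear", align_corners=True)

# Back to a flat list of per-location features: (64*64, 1024).
up_tokens = up.permute(0, 2, 3, 1).reshape(-1, dim)
print(up_tokens.shape)  # torch.Size([4096, 1024])
```
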
## Handling Multiple SAM Masks

### Processing the Masks Generated by SAM

**Mask generation**:
```python
# model_ensemble.py:394
masks = self.mask_generator.generate(raw_image)
sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
```

**Mask fusion strategy**:
```python
# model_ensemble.py:347-367
def merge_segmentations(a, b, background_class):
    """Fuse the SAM masks with the semantic segmentation result."""
    # assign each SAM region a semantic label by majority voting
    for label_a in unique_labels_a:
        mask_a = (a == label_a)
        labels_b = b[mask_a]
        if labels_b.size > 0:
            count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
            label_map[label_a] = np.argmax(count_b)  # majority vote
```

**Multi-mask workflow** (a runnable sketch of the voting step follows this list):
1. SAM generates all candidate instance masks
2. K-means clustering produces a semantic segmentation mask
3. Text guidance produces a patch-level semantic mask
4. Masks from the different sources are fused by majority voting
5. Small noisy regions are filtered out (threshold: 32 pixels)

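The excerpt above omits its surrounding setup. The following is a minimal, self-contained re-implementation of the same majority-vote merge for illustration; the name and signature mirror the excerpt, but this is not the repository's exact code.

```python
import numpy as np

def merge_segmentations(a: np.ndarray, b: np.ndarray, background_class: int) -> np.ndarray:
    """Relabel each region of segmentation `a` (e.g. SAM instances) with the
    label of semantic segmentation `b` that covers most of its pixels."""
    merged = np.full_like(a, background_class)
    for label_a in np.unique(a):
        mask_a = (a == label_a)
        labels_b = b[mask_a]
        if labels_b.size > 0:
            counts = np.bincount(labels_b, minlength=b.max() + 1)
            merged[mask_a] = np.argmax(counts)  # majority vote
    return merged

# Toy example: two SAM regions, one semantic map with labels {0, 1}.
sam = np.array([[0, 0, 1], [0, 1, 1], [1, 1, 1]])
sem = np.array([[0, 0, 0], [0, 1, 1], [1, 1, 1]])
print(merge_segmentations(sam, sem, background_class=0))
```
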
## Handling Multiple Ground-Truth Masks

### Mask Layout in the MVTec LOCO Dataset

**File structure**:
```
dataset/
├── test/category/image_filename.png        # test image
├── ground_truth/category/image_filename/   # directory of GT masks for that image
│   ├── 000.png                             # mask of the first anomalous region
│   ├── 001.png                             # mask of the second anomalous region
│   ├── 002.png                             # mask of the third anomalous region
│   └── ...                                 # further anomalous regions
```

**Aggregating multiple masks at data-loading time**:
```python
# anomalib/data/image/mvtec_loco.py:142-148
mask_samples = (
    mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"]
    .agg(list)  # aggregate the mask paths of one image into a list
    .reset_index()
    .rename(columns={"image_path": "mask_path"})
)
```

### Multi-Mask Fusion Strategy

**Step 1: mask path handling**:
```python
# anomalib/data/image/mvtec_loco.py:279-280
if isinstance(mask_path, str):
    mask_path = [mask_path]  # make sure mask_path is a list
```

**Step 2: stacking the semantic masks**:
```python
# anomalib/data/image/mvtec_loco.py:281-285
semantic_mask = (
    Mask(torch.zeros(image.shape[-2:])).to(torch.uint8)  # normal image: all-zero mask
    if label_index == LabelName.NORMAL
    else Mask(torch.stack([self._read_mask(path) for path in mask_path]))  # anomalous image: stack all masks
)
```

**Step 3: binary mask generation**:
```python
# anomalib/data/image/mvtec_loco.py:287
binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8))
```

### How the Fusion Works

**Shape transformations**:
- Input: several masks, each of shape (H, W)
- After stacking: (N, H, W), where N is the number of masks
- `view(-1, H, W)`: reshape to (N, H, W)
- `any(dim=0)`: logical OR along the first dimension, giving (H, W)

**Fusion logic**:
```python
# pseudo-code example
mask1 = [[0, 1, 0],        mask2 = [[0, 0, 1],
         [1, 0, 1],                 [0, 1, 0],
         [0, 1, 0]]                 [1, 0, 0]]

# stack: shape (2, 3, 3)
stacked = torch.stack([mask1, mask2])

# any: element-wise OR
result = [[0, 1, 1],  # max(0,0), max(1,0), max(0,1)
          [1, 1, 1],  # max(1,0), max(0,1), max(1,0)
          [1, 1, 0]]  # max(0,1), max(1,0), max(0,0)
```

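The pseudo-code above can be checked directly in PyTorch. This small sketch stacks two binary masks and reduces them with `any(dim=0)`, the same operation as the anomalib excerpt; the literal masks are the toy values from the pseudo-code, not dataset data.

```python
import torch

mask1 = torch.tensor([[0, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0]], dtype=torch.uint8)
mask2 = torch.tensor([[0, 0, 1],
                      [0, 1, 0],
                      [1, 0, 0]], dtype=torch.uint8)

# (N, H, W) stack of per-region masks, then pixel-wise OR across the N masks.
semantic_mask = torch.stack([mask1, mask2])
binary_mask = semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8)

print(binary_mask)
# tensor([[0, 1, 1],
#         [1, 1, 1],
#         [1, 1, 0]], dtype=torch.uint8)
```
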
### Full Data-Loading Flow

**Structure of an MVTec LOCO data item**:
```python
# normal sample
item = {
    "image_path": "/path/to/normal_image.png",
    "label": 0,
    "image": torch.Tensor(...),
    "mask": torch.zeros(H, W),          # all-zero mask
    "mask_path": [],                    # empty list
    "semantic_mask": torch.zeros(H, W)  # all-zero mask
}

# anomalous sample
item = {
    "image_path": "/path/to/abnormal_image.png",
    "label": 1,
    "image": torch.Tensor(...),
    "mask": torch.Tensor(...),          # fused binary mask
    "mask_path": [                      # list of mask paths
        "/path/to/ground_truth/image/000.png",
        "/path/to/ground_truth/image/001.png",
        "/path/to/ground_truth/image/002.png"
    ],
    "semantic_mask": torch.Tensor(...)  # original stacked masks, shape (N, H, W)
}
```

### How the Masks Are Used at Evaluation Time

**Important**: LogSAD does **not** use the ground-truth masks during inference; detection is based on the input image alone. The ground-truth masks are used only for:

1. **Performance evaluation**: computing AUROC, F1 and related metrics
2. **Visual comparison**: contrasting predictions with the annotations
3. **Metric computation**: pixel-level and semantic-level detection performance

**Consistency check**:
```python
# anomalib/data/image/mvtec_loco.py:158-174
# verify the correspondence between mask files and image files
image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem)
mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply(
    lambda x: {Path(mask_path).parent.stem for mask_path in x},
)
# ensures that image '005.png' corresponds to masks '005/000.png', '005/001.png', ...
```

### Multi-Mask Scenarios in Practice

**Typical cases**:
1. **Splicing Connectors**: connectors, cables and clamps may be annotated separately
2. **Juice Bottle**: liquid, label and bottle defects may be annotated separately
3. **Breakfast Box**: missing items of different foods may be annotated separately
4. **Screw Bag**: anomalies of different screws, nuts and washers are annotated separately

**Benefits of this handling**:
- Detailed information about each anomalous region is preserved
- Joint evaluation of multiple anomaly types is supported
- Fine-grained performance analysis becomes possible
- It remains compatible with conventional binary anomaly-detection evaluation

## Category-Specific Rules

The code contains **five main category-specific branches**, one per dataset category:

### 1. Pushpins

**Location**: `model_ensemble.py:432-479`

**Logic**:
```python
if self.class_name == 'pushpins':
    # 1. object counting
    pushpins_count = num_labels - 1
    if self.few_shot_inited and pushpins_count != self.pushpins_count:
        self.anomaly_flag = True

    # 2. patch histogram matching
    clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
    patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
    score = 1 - patch_hist_similarity.max()
```

**Anomaly types covered**:
- Wrong number of pushpins (nominal count: 15)
- Abnormal color distribution

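To make the histogram-matching score concrete, the sketch below recomputes it on toy data: a patch label map is turned into a normalized label histogram and compared against a small memory of reference histograms by dot product, with `1 - max similarity` as the anomaly score. Shapes and names loosely follow the excerpt above; the L2 normalization of the histograms is an assumption made for this toy example.

```python
import numpy as np

num_classes = 4  # number of patch-level query labels (toy value)

def patch_histogram(patch_mask: np.ndarray, num_classes: int) -> np.ndarray:
    """Normalized histogram of patch labels over the feature map."""
    hist = np.bincount(patch_mask.reshape(-1), minlength=num_classes).astype(np.float64)
    return hist / (np.linalg.norm(hist) + 1e-8)

# Reference histograms collected from normal images (memory bank), one per row.
memory_hists = np.stack([
    patch_histogram(np.random.randint(0, num_classes, size=(64, 64)), num_classes)
    for _ in range(4)
])

# Test image: compare its histogram against every reference histogram.
test_hist = patch_histogram(np.random.randint(0, num_classes, size=(64, 64)), num_classes)
similarity = test_hist @ memory_hists.T   # cosine-like similarity per reference
score = 1 - similarity.max()              # low similarity => high anomaly score
print(score)
```
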
### 2. Splicing Connectors

**Location**: `model_ensemble.py:481-615`

**Logic** (several checks combined):
```python
elif self.class_name == 'splicing_connectors':
    # 1. connected-component check
    if count != 1:
        self.anomaly_flag = True

    # 2. cable color vs. clamp count check
    foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
    ratio = foreground_pixel_count / self.foreground_pixel_hist_splicing_connectors
    if ratio > 1.2 or ratio < 0.8:
        self.anomaly_flag = True

    # 3. left/right symmetry check
    ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
    if ratio > 1.2 or ratio < 0.8:
        self.anomaly_flag = True

    # 4. distance check
    distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
    ratio = distance / self.splicing_connectors_distance
    if ratio < 0.6 or ratio > 1.4:
        self.anomaly_flag = True
```

**Anomaly types covered**:
- Broken or missing cable
- Cable color not matching the clamp count (yellow: 2 clamps, blue: 3, red: 5)
- Asymmetric left/right clamps
- Abnormal cable length

### 3. Screw Bag

**Location**: `model_ensemble.py:617-670`

**Logic**:
```python
elif self.class_name == 'screw_bag':
    # foreground pixel statistics
    foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])])
    ratio = foreground_pixel_count / self.foreground_pixel_hist_screw_bag
    if ratio < 0.94 or ratio > 1.06:
        self.anomaly_flag = True
```

**Anomaly types covered**:
- Wrong number of screws, nuts or washers
- Abnormal foreground pixel ratio (threshold: ±6%)

### 4. Juice Bottle

**Location**: `model_ensemble.py:715-771`

**Logic**:
```python
elif self.class_name == 'juice_bottle':
    # liquid vs. fruit-label consistency check
    liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
    fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
    if liquid_idx != fruit_idx:
        self.anomaly_flag = True
```

**Anomaly types covered**:
- Liquid color not matching the fruit on the label
- Misplaced label

### 5. Breakfast Box

**Location**: `model_ensemble.py:672-713`

**Logic**:
```python
elif self.class_name == 'breakfast_box':
    # mainly relies on patch histogram matching
    sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
    patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
    score = 1 - patch_hist_similarity.max()
```

**Anomaly types covered**:
- Abnormal food layout
- Missing or extra items

## Few-shot vs. Full-data Mode

### Differences in Data Handling

**Few-shot mode** (`model_ensemble_few_shot.py`):
```python
# use all few-shot samples directly
FEW_SHOT_SAMPLES = [0, 1, 2, 3]  # fixed set of 4 samples
self.k_shot = few_shot_samples.size(0)
```

**Full-data mode** (`model_ensemble.py`):
```python
# build a coreset from the full training set
FEW_SHOT_SAMPLES = range(len(datamodule.train_data))  # all training samples
self.k_shot = 4 if self.total_size > 4 else self.total_size
```

### Coreset Subsampling

**Few-shot mode**: no coreset; the raw features are used directly
```python
# model_ensemble_few_shot.py:852
self.mem_patch_feature_clip_coreset = patch_tokens_clip
self.mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
```

**Full-data mode**: coreset subsampling with the K-Center Greedy algorithm
```python
# model_ensemble.py:892-896
clip_sampler = KCenterGreedy(embedding=mem_patch_feature_clip_coreset, sampling_ratio=0.25)
mem_patch_feature_clip_coreset = clip_sampler.sample_coreset()

dinov2_sampler = KCenterGreedy(embedding=mem_patch_feature_dinov2_coreset, sampling_ratio=0.25)
mem_patch_feature_dinov2_coreset = dinov2_sampler.sample_coreset()
```

### Differences in Statistics

**Few-shot mode**:
```python
# model_ensemble_few_shot.py:185
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_few_shot_val.pkl", "rb"))
```

**Full-data mode**:
```python
# model_ensemble.py:188
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_val.pkl", "rb"))
```

### Differences in the Computation Flow

**Few-shot mode**:
1. Compute the features of the 4 samples directly
2. No coreset computation
3. Run anomaly detection directly

**Full-data mode** (a K-Center Greedy sketch follows this list):
1. Compute the features of all training samples (`compute_coreset.py`)
2. Select representative features with the K-Center Greedy algorithm
3. Save the coreset to the `memory_bank/` directory
4. Load the pre-computed coreset for anomaly detection

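The `KCenterGreedy` class used above comes from anomalib. For reference, a minimal greedy k-center (farthest-point) selection over a feature matrix can be sketched as follows; this is an illustrative re-implementation of the idea, not anomalib's exact class.

```python
import torch

def kcenter_greedy(embedding: torch.Tensor, sampling_ratio: float = 0.25) -> torch.Tensor:
    """Greedy k-center (farthest-point) selection: returns the selected rows."""
    n = embedding.shape[0]
    k = max(1, int(n * sampling_ratio))
    # distance of every point to the current coreset, initialised from a random seed point
    selected = [torch.randint(n, (1,)).item()]
    min_dist = torch.cdist(embedding, embedding[selected]).squeeze(1)
    for _ in range(k - 1):
        idx = int(torch.argmax(min_dist))  # farthest point from the current coreset
        selected.append(idx)
        new_dist = torch.cdist(embedding, embedding[idx:idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, new_dist)
    return embedding[selected]

features = torch.randn(1000, 1024)   # stand-in for memory-bank patch features
coreset = kcenter_greedy(features, 0.25)
print(coreset.shape)                 # torch.Size([250, 1024])
```
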
## Implementation Details and Optimizations

### Memory Optimizations

**Batched processing**:
```python
# model_ensemble.py:926-928
for i in range(self.total_size//self.k_shot):
    self.process(class_name, few_shot_samples[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)],
                 few_shot_paths[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)])
```

**Feature caching**:
- Pre-computed coreset features are stored in the `memory_bank/` directory
- Statistics are pre-computed and cached

### Multi-modal Feature Fusion

**Feature layer selection**:
- **Clustering features**: CLIP layers 0 and 1 (`cluster_feature_id = [0, 1]`)
- **Detection features**: the full features from layers 6, 12, 18 and 24

**Per-category backbone selection**:
```python
# model_ensemble.py:290-310
if self.class_name in ['pushpins', 'screw_bag']:
    # PatchCore-style detection on CLIP features
    len_feature_list = len(self.feature_list)
    for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1),
                                                mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):

if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']:
    # PatchCore-style detection on DINOv2 features
    len_feature_list = len(self.feature_list_dinov2)
    for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1),
                                                mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
```

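The loop bodies are omitted in the excerpt above. A minimal PatchCore-style score for one feature chunk is the distance from each test patch to its nearest memory-bank patch, with the image score taken as the maximum over patches; the sketch below illustrates that step under those assumptions (the exact aggregation used in `model_ensemble.py` may differ).

```python
import torch

def patchcore_score(patch_feature: torch.Tensor, mem_patch_feature: torch.Tensor) -> torch.Tensor:
    """Distance of each test patch to its nearest memory patch; image score = max."""
    dists = torch.cdist(patch_feature, mem_patch_feature)  # (num_patches, num_memory)
    anomaly_map = dists.min(dim=1).values                  # nearest-neighbour distance per patch
    return anomaly_map.max()                               # image-level structural score

patch_feature = torch.randn(4096, 1024)      # test-image patch features for one layer chunk
mem_patch_feature = torch.randn(2500, 1024)  # coreset features for the same chunk
print(patchcore_score(patch_feature, mem_patch_feature))
```
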
## Text Prompt Engineering

### Semantic Query Dictionaries

**Object-level queries**:
```python
# model_ensemble.py:123-136
self.query_words_dict = {
    "breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
    "juice_bottle": ['bottle', ['black background', 'background']],
    "pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
    "screw_bag": [['screw'], 'plastic bag', 'background'],
    "splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
}
```

**Patch-level queries**:
```python
# model_ensemble.py:138-145
self.patch_query_words_dict = {
    "juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
    "screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']],
    # ...
}
```

### Text Encoding Strategy

**Multi-template encoding**:
```python
# prompt_ensemble.py:98-120
def encode_obj_text(model, query_words, tokenizer, device):
    for qw in query_words:
        if type(qw) == list:
            for qw2 in qw:
                token_input.extend([temp(qw2) for temp in openai_imagenet_template])
        else:
            token_input = [temp(qw) for temp in openai_imagenet_template]
```

Each query word is expanded with the 82 ImageNet prompt templates, which makes the text features more robust.

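A self-contained version of this prompt-ensembling step is sketched below: a query word is expanded with every template from `imagenet_template.py`, the prompts are encoded with the CLIP text encoder, and the per-template embeddings are averaged and re-normalized into one query vector. The open_clip calls follow its standard API; the averaging step is an assumption about how the ensemble is reduced, not a quote of `prompt_ensemble.py`.

```python
import torch
import open_clip
from imagenet_template import openai_imagenet_template

model_name = 'hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K'
model, _, _ = open_clip.create_model_and_transforms(model_name)
tokenizer = open_clip.get_tokenizer(model_name)

@torch.no_grad()
def encode_query_word(word: str) -> torch.Tensor:
    """Encode one query word with all prompt templates and average the embeddings."""
    prompts = [template(word) for template in openai_imagenet_template]
    tokens = tokenizer(prompts)
    text_features = model.encode_text(tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    query = text_features.mean(dim=0)
    return query / query.norm()

query_pushpin = encode_query_word("pushpin")
print(query_pushpin.shape)  # e.g. torch.Size([768]) for ViT-L-14
```
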
## Evaluation

### Metrics

**Image-level metrics**:
- F1-Max (Image)
- AUROC (Image)

**Per-anomaly-type metrics**:
- F1-Max (Logical): logical anomalies
- AUROC (Logical): logical anomalies
- F1-Max (Structural): structural anomalies
- AUROC (Structural): structural anomalies

### Evaluation Flow

**Splitting the test data**:
```python
# evaluation.py:222-227
if 'logical' not in image_path[0]:
    image_metric_structure.update(output["pred_score"].cpu(), data["label"])
if 'structural' not in image_path[0]:
    image_metric_logical.update(output["pred_score"].cpu(), data["label"])
```

**Score fusion**:
```python
# model_ensemble.py:227-231
standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]

pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
pred_score = sigmoid(pred_score)
```

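In other words, each raw score is standardized with the mean and unbiased standard deviation cached in the statistics file, the larger of the two standardized scores wins, and a sigmoid maps it into [0, 1]. A minimal numeric illustration, with made-up statistics, follows.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical cached statistics for one category (illustrative values only).
stats = {
    "structural_scores": {"mean": 0.20, "unbiased_std": 0.05},
    "instance_hungarian_match_scores": {"mean": 0.60, "unbiased_std": 0.10},
}

structural_score = 0.23
instance_hungarian_match_score = 0.95

z_structural = (structural_score - stats["structural_scores"]["mean"]) / stats["structural_scores"]["unbiased_std"]
z_instance = (instance_hungarian_match_score - stats["instance_hungarian_match_scores"]["mean"]) / stats["instance_hungarian_match_scores"]["unbiased_std"]

pred_score = sigmoid(max(z_structural, z_instance))
print(z_structural, z_instance, pred_score)  # 0.6, 3.5, ~0.97
```
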
## Summary

LogSAD combines the strengths of several pre-trained models to detect anomalies without any training:

1. **Multi-modal collaboration**: CLIP provides semantic understanding, DINOv2 provides visual features, and SAM provides precise segmentation
2. **Logical reasoning**: category-specific rules encoding domain knowledge catch complex logical anomalies
3. **Feature fusion**: multi-scale feature extraction and fusion improve detection accuracy
4. **Efficiency**: coreset subsampling and feature caching keep the method practical

The method achieves strong results on the MVTec LOCO dataset and demonstrates the potential of pre-trained foundation models for anomaly detection.
README.md
ADDED
@@ -0,0 +1,102 @@
# Towards Training-free Anomaly Detection with Vision and Language Foundation Models (CVPR 2025)

<div>
<a href="https://arxiv.org/abs/2503.18325"><img src="https://img.shields.io/static/v1?label=Arxiv&message=LogSAD&color=red&logo=arxiv"></a>
</div>

## System Requirements

**Hardware Requirements:**
- **GPU Memory:** 32GB VRAM (for running complete experiments)
- **Storage:** 70GB free disk space (for models, datasets, and results)
- **CUDA:** Compatible GPU with CUDA 12.1 support

**Software Requirements:**
- Python 3.10
- Conda (recommended for environment management)
- CUDA 12.1 runtime

> **Note:** The memory and storage requirements are for running the full experimental pipeline on all categories with visualization enabled. Smaller experiments on individual categories may require fewer resources.

## Installation

### Automated Setup (Recommended)

Run the setup script to automatically configure the complete environment:

```bash
bash scripts/setup_environment.sh
```

This script will:
- Create a conda environment named `logsad` with Python 3.10
- Install PyTorch with CUDA 12.1 support
- Install all required dependencies from `requirements.txt`
- Configure numpy compatibility

### Manual Setup

If you prefer manual setup, download the checkpoint for the [ViT-H SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and put it in the checkpoint folder.

After installation, activate the environment:
```bash
conda activate logsad
```


## Instructions for the MVTec LOCO dataset

### Quick Start (Recommended)

Run evaluation for all categories using the provided shell scripts:

**Few-shot Protocol:**
```bash
bash scripts/run_few_shot.sh
```

**Full-data Protocol:**
```bash
bash scripts/run_full_data.sh
```

### Manual Execution

#### Few-shot Protocol
Run the script for the few-shot protocol:

```
python evaluation.py --module_path model_ensemble_few_shot --category CATEGORY --dataset_path DATASET_PATH
```

#### Full-data Protocol
Run the script to compute the coreset for full-data scenarios:

```
python compute_coreset.py --module_path model_ensemble --category CATEGORY --dataset_path DATASET_PATH
```

Run the script for the full-data protocol:

```
python evaluation.py --module_path model_ensemble --category CATEGORY --dataset_path DATASET_PATH
```

**Available categories:** breakfast_box, juice_bottle, pushpins, screw_bag, splicing_connectors


## Acknowledgement
We are grateful to the following awesome projects when implementing LogSAD:
* [SAM](https://github.com/facebookresearch/segment-anything), [OpenCLIP](https://github.com/mlfoundations/open_clip), [DINOv2](https://github.com/facebookresearch/dinov2) and [NACLIP](https://github.com/sinahmr/NACLIP).


## Citation
If you find our paper helpful in your research or applications, please cite it with
```
@inproceedings{zhang2025logsad,
    title={Towards Training-free Anomaly Detection with Vision and Language Foundation Models},
    author={Jinjin Zhang and Guodong Wang and Yizhou Jin and Di Huang},
    year={2025},
    booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
}
```
compute_coreset.py
ADDED
@@ -0,0 +1,121 @@
"""Coreset computation script for track 2."""

import os

# Set cache directories to use checkpoint folder for model downloads
os.environ['TORCH_HOME'] = './checkpoint'
os.environ['HF_HOME'] = './checkpoint/huggingface'
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'

# Create checkpoint subdirectories if they don't exist
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)

import argparse
import importlib
import importlib.util

import torch
import logging
from torch import nn

# NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
# It will be available in v1.1.0 which will be released on April 29th, 2024.
# If you are using an earlier version of anomalib, you could install anomalib
# from the anomalib source code from the following branch:
# https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
from anomalib.data import MVTecLoco
from anomalib.metrics.f1_max import F1Max
from anomalib.metrics.auroc import AUROC
from tabulate import tabulate
import numpy as np

# FEW_SHOT_SAMPLES = [0, 1, 2, 3]

def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--module_path", type=str, required=True)
    parser.add_argument("--class_name", default='MyModel', type=str, required=False)
    parser.add_argument("--weights_path", type=str, required=False)
    parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', type=str, required=False)
    parser.add_argument("--category", type=str, required=True)
    parser.add_argument("--viz", action='store_true', default=False)
    return parser.parse_args()


def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Load model.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.

    Returns:
        nn.Module: Loaded model.
    """
    # get model class
    model_class = getattr(importlib.import_module(module_path), class_name)
    # instantiate model
    model = model_class()
    # load weights
    if weights_path:
        model.load_state_dict(torch.load(weights_path))
    return model


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Run the coreset computation.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.
        dataset_path (str): Path to the dataset.
        category (str): Category of the dataset.
        viz (bool): Whether to enable visualization.
    """
    # Disable verbose logging from all libraries
    logging.getLogger().setLevel(logging.ERROR)
    logging.getLogger('anomalib').setLevel(logging.ERROR)
    logging.getLogger('lightning').setLevel(logging.ERROR)
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Instantiate the model class and load it from the checkpoint.
    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    # Create the dataset
    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)
    model.set_save_coreset_features(True)

    FEW_SHOT_SAMPLES = range(len(datamodule.train_data))  # traverse the whole training set to build the coreset

    # pass the training images and dataset category to the model
    setup_data = {
        "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
        "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
        "dataset_category": category,
    }
    model.setup(setup_data)

    print(f"✓ Coreset computation completed for {category}")
    print("  Memory bank features saved to the memory_bank/ directory")


if __name__ == "__main__":
    args = parse_args()
    run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)
environment.yml
ADDED
@@ -0,0 +1,116 @@
name: logsad
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - bzip2=1.0.8=h4bc722e_7
  - ca-certificates=2025.8.3=hbd8a1cb_0
  - ld_impl_linux-64=2.44=h1423503_1
  - libexpat=2.7.1=hecca717_0
  - libffi=3.4.6=h2dba641_1
  - libgcc=15.1.0=h767d61c_4
  - libgcc-ng=15.1.0=h69a702a_4
  - libgomp=15.1.0=h767d61c_4
  - liblzma=5.8.1=hb9d3cd8_2
  - libnsl=2.0.1=hb9d3cd8_1
  - libsqlite=3.50.4=h0c1763c_0
  - libuuid=2.38.1=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libzlib=1.3.1=hb9d3cd8_2
  - ncurses=6.5=h2d0b736_3
  - openssl=3.5.2=h26f9b46_0
  - pip=25.2=pyh8b19718_0
  - python=3.10.18=hd6af730_0_cpython
  - readline=8.2=h8c095d6_2
  - setuptools=80.9.0=pyhff2d567_0
  - tk=8.6.13=noxft_hd72426e_102
  - wheel=0.45.1=pyhd8ed1ab_1
  - pip:
      - aiohappyeyeballs==2.6.1
      - aiohttp==3.12.11
      - aiosignal==1.3.2
      - antlr4-python3-runtime==4.9.3
      - async-timeout==5.0.1
      - attrs==25.3.0
      - certifi==2025.4.26
      - charset-normalizer==3.4.2
      - contourpy==1.3.2
      - cycler==0.12.1
      - einops==0.6.1
      - faiss-cpu==1.8.0
      - filelock==3.18.0
      - fonttools==4.58.2
      - freia==0.2
      - frozenlist==1.6.2
      - fsspec==2024.12.0
      - ftfy==6.3.1
      - hf-xet==1.1.3
      - huggingface-hub==0.32.4
      - idna==3.10
      - imageio==2.37.0
      - imgaug==0.4.0
      - jinja2==3.1.6
      - joblib==1.5.1
      - jsonargparse==4.29.0
      - kiwisolver==1.4.8
      - kmeans-pytorch==0.3
      - kornia==0.7.0
      - lazy-loader==0.4
      - lightning==2.2.5
      - lightning-utilities==0.14.3
      - markdown-it-py==3.0.0
      - markupsafe==3.0.2
      - matplotlib==3.10.3
      - mdurl==0.1.2
      - mpmath==1.3.0
      - multidict==6.4.4
      - networkx==3.4.2
      - numpy==1.23.1
      - omegaconf==2.3.0
      - open-clip-torch==2.24.0
      - opencv-python==4.8.1.78
      - packaging==24.2
      - pandas==2.0.3
      - pillow==11.2.1
      - propcache==0.3.1
      - protobuf==6.31.1
      - pygments==2.19.1
      - pyparsing==3.2.3
      - python-dateutil==2.9.0.post0
      - pytorch-lightning==2.5.1.post0
      - pytz==2025.2
      - pyyaml==6.0.2
      - regex==2024.11.6
      - requests==2.32.3
      - rich==13.7.1
      - safetensors==0.5.3
      - scikit-image==0.25.2
      - scikit-learn==1.2.2
      - scipy==1.15.3
      - segment-anything==1.0
      - sentencepiece==0.2.0
      - shapely==2.1.1
      - six==1.17.0
      - sympy==1.14.0
      - tabulate==0.9.0
      - threadpoolctl==3.6.0
      - tifffile==2025.5.10
      - timm==1.0.15
      - torch==2.1.2+cu121
      - torchmetrics==1.7.2
      - torchvision==0.16.2+cu121
      - tqdm==4.67.1
      - triton==2.1.0
      - typing-extensions==4.14.0
      - tzdata==2025.2
      - urllib3==2.4.0
      - wcwidth==0.2.13
      - yarl==1.20.0
prefix: /opt/conda/envs/logsad
evaluation.py
ADDED
@@ -0,0 +1,257 @@
"""Sample evaluation script for track 2."""

import os
from datetime import datetime
from pathlib import Path

# Set cache directories to use checkpoint folder for model downloads
os.environ['TORCH_HOME'] = './checkpoint'
os.environ['HF_HOME'] = './checkpoint/huggingface'
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'

# Create checkpoint subdirectories if they don't exist
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)

import argparse
import importlib
import importlib.util

import torch
import logging
from torch import nn

# NOTE: The following MVTecLoco import is not available in anomalib v1.0.1.
# It will be available in v1.1.0 which will be released on April 29th, 2024.
# If you are using an earlier version of anomalib, you could install anomalib
# from the anomalib source code from the following branch:
# https://github.com/openvinotoolkit/anomalib/tree/feature/mvtec-loco
from anomalib.data import MVTecLoco
from anomalib.metrics.f1_max import F1Max
from anomalib.metrics.auroc import AUROC
from tabulate import tabulate
import numpy as np

FEW_SHOT_SAMPLES = [0, 1, 2, 3]

def write_results_to_markdown(category, results_data, module_path):
    """Write evaluation results to markdown file.

    Args:
        category (str): Dataset category name
        results_data (dict): Dictionary containing all metrics
        module_path (str): Model module path (for protocol identification)
    """
    # Determine protocol type from module path
    protocol = "Few-shot" if "few_shot" in module_path else "Full-data"

    # Create results directory
    results_dir = Path("results")
    results_dir.mkdir(exist_ok=True)

    # Combined results file with simple protocol name
    protocol_suffix = "few_shot" if "few_shot" in module_path else "full_data"
    combined_file = results_dir / f"{protocol_suffix}_results.md"

    # Read existing results if file exists
    existing_results = {}
    if combined_file.exists():
        with open(combined_file, 'r') as f:
            content = f.read()
        # Parse existing results (basic parsing)
        lines = content.split('\n')
        for line in lines:
            if '|' in line and line.count('|') >= 8:
                parts = [p.strip() for p in line.split('|')]
                if len(parts) >= 8 and parts[1] != 'Category' and parts[1] != '-----':
                    existing_results[parts[1]] = {
                        'k_shots': parts[2],
                        'f1_image': parts[3],
                        'auroc_image': parts[4],
                        'f1_logical': parts[5],
                        'auroc_logical': parts[6],
                        'f1_structural': parts[7],
                        'auroc_structural': parts[8]
                    }

    # Add current results
    existing_results[category] = {
        'k_shots': str(len(FEW_SHOT_SAMPLES)),
        'f1_image': f"{results_data['f1_image']:.2f}",
        'auroc_image': f"{results_data['auroc_image']:.2f}",
        'f1_logical': f"{results_data['f1_logical']:.2f}",
        'auroc_logical': f"{results_data['auroc_logical']:.2f}",
        'f1_structural': f"{results_data['f1_structural']:.2f}",
        'auroc_structural': f"{results_data['auroc_structural']:.2f}"
    }

    # Write combined results
    with open(combined_file, 'w') as f:
        f.write(f"# All Categories - {protocol} Protocol Results\n\n")
        f.write(f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(f"**Protocol:** {protocol}\n")
        f.write(f"**Available Categories:** {', '.join(sorted(existing_results.keys()))}\n\n")

        f.write("## Summary Table\n\n")
        f.write("| Category | K-shots | F1-Max (Image) | AUROC (Image) | F1-Max (Logical) | AUROC (Logical) | F1-Max (Structural) | AUROC (Structural) |\n")
        f.write("|----------|---------|----------------|---------------|------------------|-----------------|---------------------|-------------------|\n")

        # Sort categories alphabetically
        for cat in sorted(existing_results.keys()):
            data = existing_results[cat]
            f.write(f"| {cat} | {data['k_shots']} | {data['f1_image']} | {data['auroc_image']} | {data['f1_logical']} | {data['auroc_logical']} | {data['f1_structural']} | {data['auroc_structural']} |\n")

    print(f"\n✓ Results saved to:")
    print(f"  - Combined: {combined_file}")

def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--module_path", type=str, required=True)
    parser.add_argument("--class_name", default='MyModel', type=str, required=False)
    parser.add_argument("--weights_path", type=str, required=False)
    parser.add_argument("--dataset_path", default='/home/bhu/Project/datasets/mvtec_loco_anomaly_detection/', type=str, required=False)
    parser.add_argument("--category", type=str, required=True)
    parser.add_argument("--viz", action='store_true', default=False)
    return parser.parse_args()


def load_model(module_path: str, class_name: str, weights_path: str) -> nn.Module:
    """Load model.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.

    Returns:
        nn.Module: Loaded model.
    """
    # get model class
    model_class = getattr(importlib.import_module(module_path), class_name)
    # instantiate model
    model = model_class()
    # load weights
    if weights_path:
        model.load_state_dict(torch.load(weights_path))
    return model


def run(module_path: str, class_name: str, weights_path: str, dataset_path: str, category: str, viz: bool) -> None:
    """Run the evaluation script.

    Args:
        module_path (str): Path to the module containing the model class.
        class_name (str): Name of the model class.
        weights_path (str): Path to the model weights.
        dataset_path (str): Path to the dataset.
        category (str): Category of the dataset.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Instantiate model class here
    # Load the model here from checkpoint.
    model = load_model(module_path, class_name, weights_path)
    model.to(device)

    # Create the dataset
    datamodule = MVTecLoco(root=dataset_path, eval_batch_size=1, image_size=(448, 448), category=category)
    datamodule.setup()

    model.set_viz(viz)

    # Create the metrics
    image_metric = F1Max()
    pixel_metric = F1Max()

    image_metric_logical = F1Max()
    image_metric_structure = F1Max()

    image_metric_auroc = AUROC()
    pixel_metric_auroc = AUROC()

    image_metric_auroc_logical = AUROC()
    image_metric_auroc_structure = AUROC()

    # pass few-shot images and dataset category to model
    setup_data = {
        "few_shot_samples": torch.stack([datamodule.train_data[idx]["image"] for idx in FEW_SHOT_SAMPLES]).to(device),
        "few_shot_samples_path": [datamodule.train_data[idx]["image_path"] for idx in FEW_SHOT_SAMPLES],
        "dataset_category": category,
    }
    model.setup(setup_data)

    # Loop over the test set and compute the metrics
    for data in datamodule.test_dataloader():
        with torch.no_grad():
            image_path = data['image_path']
            output = model(data["image"].to(device), data['image_path'])

            image_metric.update(output["pred_score"].cpu(), data["label"])
            image_metric_auroc.update(output["pred_score"].cpu(), data["label"])

            if 'logical' not in image_path[0]:
                image_metric_structure.update(output["pred_score"].cpu(), data["label"])
                image_metric_auroc_structure.update(output["pred_score"].cpu(), data["label"])
            if 'structural' not in image_path[0]:
                image_metric_logical.update(output["pred_score"].cpu(), data["label"])
                image_metric_auroc_logical.update(output["pred_score"].cpu(), data["label"])

    # Disable verbose logging from all libraries
    logging.getLogger().setLevel(logging.ERROR)
    logging.getLogger('anomalib').setLevel(logging.ERROR)
    logging.getLogger('lightning').setLevel(logging.ERROR)
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    # Set up our own logger for results only
    logger = logging.getLogger('evaluation')
    logger.handlers.clear()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    table_ls = [[category,
                 str(len(FEW_SHOT_SAMPLES)),
                 str(np.round(image_metric.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_auroc.compute().item() * 100, decimals=2)),
                 # str(np.round(pixel_metric.compute().item() * 100, decimals=2)),
                 # str(np.round(pixel_metric_auroc.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_logical.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_auroc_logical.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_structure.compute().item() * 100, decimals=2)),
                 str(np.round(image_metric_auroc_structure.compute().item() * 100, decimals=2)),
                 ]]

    results = tabulate(table_ls, headers=['category', 'K-shots', 'F1-Max(image)', 'AUROC(image)', 'F1-Max (logical)', 'AUROC (logical)', 'F1-Max (structural)', 'AUROC (structural)'], tablefmt="pipe")

    logger.info("\n%s", results)

    # Save results to markdown
    results_data = {
        'f1_image': np.round(image_metric.compute().item() * 100, decimals=2),
        'auroc_image': np.round(image_metric_auroc.compute().item() * 100, decimals=2),
        'f1_logical': np.round(image_metric_logical.compute().item() * 100, decimals=2),
        'auroc_logical': np.round(image_metric_auroc_logical.compute().item() * 100, decimals=2),
        'f1_structural': np.round(image_metric_structure.compute().item() * 100, decimals=2),
        'auroc_structural': np.round(image_metric_auroc_structure.compute().item() * 100, decimals=2)
    }
    write_results_to_markdown(category, results_data, module_path)


if __name__ == "__main__":
    args = parse_args()
    run(args.module_path, args.class_name, args.weights_path, args.dataset_path, args.category, args.viz)
imagenet_template.py
ADDED
@@ -0,0 +1,82 @@
openai_imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
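`openai_imagenet_template` is the standard 80-prompt ensemble published with OpenAI's CLIP ImageNet evaluation. A typical way to consume it, and roughly what the `encode_*` helpers imported from `prompt_ensemble.py` are expected to do (the exact implementation may differ), is to embed every templated sentence for a class name and average the L2-normalized text features into a single query vector:

import torch
import open_clip_local as open_clip  # the repository's local copy; upstream open_clip exposes the same API
from imagenet_template import openai_imagenet_template

model, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')

@torch.no_grad()
def embed_class_name(name: str) -> torch.Tensor:
    texts = [template(name) for template in openai_imagenet_template]  # e.g. "a photo of a pushpin."
    text_features = model.encode_text(tokenizer(texts))
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    query = text_features.mean(dim=0)   # prompt-ensemble average
    return query / query.norm()         # re-normalize for cosine similarity

query_vector = embed_class_name('pushpin')  # shape: (embed_dim,)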
model_ensemble.py
ADDED
@@ -0,0 +1,1034 @@
import os

# Set cache directories to use checkpoint folder for model downloads
os.environ['TORCH_HOME'] = './checkpoint'
os.environ['HF_HOME'] = './checkpoint/huggingface'
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'

# Create checkpoint subdirectories if they don't exist
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)

import torch
from torch import nn
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import resize
import cv2
import json
import torch
import random
import logging
import argparse
import numpy as np
from PIL import Image
from skimage import measure
from tabulate import tabulate
from torchvision.ops.focal_loss import sigmoid_focal_loss
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
from sklearn.mixture import GaussianMixture
import faiss
import open_clip_local as open_clip

from torch.utils.data.dataset import ConcatDataset
from scipy.optimize import linear_sum_assignment
from sklearn.random_projection import SparseRandomProjection
import cv2
from torchvision.transforms import InterpolationMode
from PIL import Image
import string

from prompt_ensemble import encode_text_with_prompt_ensemble, encode_normal_text, encode_abnormal_text, encode_general_text, encode_obj_text
from kmeans_pytorch import kmeans, kmeans_predict
from scipy.optimize import linear_sum_assignment
from scipy.stats import norm

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from matplotlib import pyplot as plt

import pickle
from scipy.stats import norm

from open_clip_local.pos_embed import get_2d_sincos_pos_embed

from anomalib.models.components import KCenterGreedy

def to_np_img(m):
    m = m.permute(1, 2, 0).cpu().numpy()
    mean = np.array([[[0.48145466, 0.4578275, 0.40821073]]])
    std = np.array([[[0.26862954, 0.26130258, 0.27577711]]])
    m = m * std + mean
    return np.clip((m * 255.), 0, 255).astype(np.uint8)


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

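# Note: to_np_img() inverts the CLIP normalization that MyModel.transform applies below
# (identical mean/std constants), so histogram() can hand a plain uint8 RGB image back to
# SAM's mask generator and to matplotlib for visualization. setup_seed() additionally pins
# cuDNN to deterministic kernels, keeping the k-means clustering and coreset sub-sampling
# reproducible across runs.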
class MyModel(nn.Module):
|
| 77 |
+
"""Example model class for track 2.
|
| 78 |
+
|
| 79 |
+
This class applies few-shot logical and structural anomaly detection by ensembling CLIP, DINOv2 and SAM features with memory-bank matching.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self) -> None:
|
| 83 |
+
super().__init__()
|
| 84 |
+
|
| 85 |
+
setup_seed(42)
|
| 86 |
+
# NOTE: Create your transformation pipeline (if needed).
|
| 87 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 88 |
+
self.transform = v2.Compose(
|
| 89 |
+
[
|
| 90 |
+
v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
|
| 91 |
+
],
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# NOTE: Create your model.
|
| 95 |
+
|
| 96 |
+
self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
| 97 |
+
self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
| 98 |
+
self.feature_list = [6, 12, 18, 24]
|
| 99 |
+
self.embed_dim = 768
|
| 100 |
+
self.vision_width = 1024
|
| 101 |
+
|
| 102 |
+
self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth").to(self.device)
|
| 103 |
+
self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
|
| 104 |
+
|
| 105 |
+
self.memory_size = 2048
|
| 106 |
+
self.n_neighbors = 2
|
| 107 |
+
|
| 108 |
+
self.model_clip.eval()
|
| 109 |
+
self.test_args = None
|
| 110 |
+
self.align_corners = True # False
|
| 111 |
+
self.antialias = True # False
|
| 112 |
+
self.inter_mode = 'bilinear' # bilinear/bicubic
|
| 113 |
+
|
| 114 |
+
self.cluster_feature_id = [0, 1]
|
| 115 |
+
|
| 116 |
+
self.cluster_num_dict = {
|
| 117 |
+
"breakfast_box": 3, # unused
|
| 118 |
+
"juice_bottle": 8, # unused
|
| 119 |
+
"splicing_connectors": 10, # unused
|
| 120 |
+
"pushpins": 10,
|
| 121 |
+
"screw_bag": 10,
|
| 122 |
+
}
|
| 123 |
+
self.query_words_dict = {
|
| 124 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
| 125 |
+
"juice_bottle": ['bottle', ['black background', 'background']],
|
| 126 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
| 127 |
+
"screw_bag": [['screw'], 'plastic bag', 'background'],
|
| 128 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
| 129 |
+
}
|
| 130 |
+
self.foreground_label_idx = { # for query_words_dict
|
| 131 |
+
"breakfast_box": [0, 1, 2, 3, 4, 5],
|
| 132 |
+
"juice_bottle": [0],
|
| 133 |
+
"pushpins": [0],
|
| 134 |
+
"screw_bag": [0],
|
| 135 |
+
"splicing_connectors":[0, 1]
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
self.patch_query_words_dict = {
|
| 139 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
| 140 |
+
"juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
| 141 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
| 142 |
+
"screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']], # 79.71
|
| 143 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
self.query_threshold_dict = {
|
| 148 |
+
"breakfast_box": [0., 0., 0., 0., 0., 0., 0.], # unused
|
| 149 |
+
"juice_bottle": [0., 0., 0.], # unused
|
| 150 |
+
"splicing_connectors": [0.15, 0.15, 0.15, 0., 0.], # unused
|
| 151 |
+
"pushpins": [0.2, 0., 0., 0.],
|
| 152 |
+
"screw_bag": [0., 0., 0.,],
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
self.feat_size = 64
|
| 156 |
+
self.ori_feat_size = 32
|
| 157 |
+
|
| 158 |
+
self.visualization = False #False # True #False
|
| 159 |
+
|
| 160 |
+
self.pushpins_count = 15
|
| 161 |
+
|
| 162 |
+
self.splicing_connectors_count = [2, 3, 5] # corresponding to yellow, blue, and red
|
| 163 |
+
self.splicing_connectors_distance = 0
|
| 164 |
+
self.splicing_connectors_cable_color_query_words_dict = [['yellow cable', 'yellow wire'], ['blue cable', 'blue wire'], ['red cable', 'red wire']]
|
| 165 |
+
|
| 166 |
+
self.juice_bottle_liquid_query_words_dict = [['red liquid', 'cherry juice'], ['yellow liquid', 'orange juice'], ['milky liquid']]
|
| 167 |
+
self.juice_bottle_fruit_query_words_dict = ['cherry', ['tangerine', 'orange'], 'banana']
|
| 168 |
+
|
| 169 |
+
# query words
|
| 170 |
+
self.foreground_pixel_hist = 0
|
| 171 |
+
self.foreground_pixel_hist_screw_bag = 366.0 # 4-shot statistics
|
| 172 |
+
self.foreground_pixel_hist_splicing_connectors = 4249.666666666667 # 4-shot statistics
|
| 173 |
+
# patch query words
|
| 174 |
+
self.patch_token_hist = []
|
| 175 |
+
|
| 176 |
+
self.few_shot_inited = False
|
| 177 |
+
|
| 178 |
+
self.save_coreset_features = False
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
from dinov2.dinov2.hub.backbones import dinov2_vitl14
|
| 182 |
+
self.model_dinov2 = dinov2_vitl14()
|
| 183 |
+
self.model_dinov2.to(self.device)
|
| 184 |
+
self.model_dinov2.eval()
|
| 185 |
+
self.feature_list_dinov2 = [6, 12, 18, 24]
|
| 186 |
+
self.vision_width_dinov2 = 1024
|
| 187 |
+
|
| 188 |
+
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_val.pkl", "rb"))
|
| 189 |
+
|
| 190 |
+
self.mem_instance_masks = None
|
| 191 |
+
|
| 192 |
+
self.anomaly_flag = False
|
| 193 |
+
self.validation = False #True #False
|
| 194 |
+
|
    def set_save_coreset_features(self, save_coreset_features):
        self.save_coreset_features = save_coreset_features

    def set_viz(self, viz):
        self.visualization = viz

    def set_val(self, val):
        self.validation = val

    def forward(self, batch: torch.Tensor, batch_path: list) -> dict[str, torch.Tensor]:
        """Transform the input batch and pass it through the model.

        This model returns a dictionary with the following keys:
        - ``pred_score`` - Predicted anomaly score.
        - ``hist_score``, ``structural_score``, ``instance_hungarian_match_score`` - the
          intermediate scores that ``pred_score`` is fused from.
        """
        self.anomaly_flag = False
        batch = self.transform(batch).to(self.device)
        results = self.forward_one_sample(batch, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset, batch_path[0])

        hist_score = results['hist_score']
        structural_score = results['structural_score']
        instance_hungarian_match_score = results['instance_hungarian_match_score']

        if self.validation:
            return {"hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}

        def sigmoid(z):
            return 1 / (1 + np.exp(-z))

        # standardization
        standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
        standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]

        pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
        pred_score = sigmoid(pred_score)

        if self.anomaly_flag:
            pred_score = 1.
            self.anomaly_flag = False

        return {"pred_score": torch.tensor(pred_score), "hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}

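    # Score fusion recap for forward():
    #   - structural_score comes from the PatchCore-style memory-bank comparison,
    #     instance_hungarian_match_score from Hungarian matching of instance features,
    #     and hist_score from the class-specific rules in histogram().
    #   - Each score is standardized with the per-class mean / unbiased std loaded from
    #     memory_bank/statistic_scores_model_ensemble_val.pkl, so the two branches are
    #     comparable before taking their max.
    #   - sigmoid() maps the fused z-score into (0, 1); e.g. a sample sitting two standard
    #     deviations above the validation mean yields sigmoid(2.0) ~= 0.88.
    #   - Any hard rule that set self.anomaly_flag overrides the fused value with 1.0.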
def forward_one_sample(self, batch: torch.Tensor, mem_patch_feature_clip_coreset: torch.Tensor, mem_patch_feature_dinov2_coreset: torch.Tensor, path: str):
|
| 243 |
+
|
| 244 |
+
with torch.no_grad():
|
| 245 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(batch, self.feature_list)
|
| 246 |
+
# image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 247 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
| 248 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
| 249 |
+
|
| 250 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (1, 1024, 1024x4)
|
| 251 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (1, 1024, 1024x2)
|
| 252 |
+
patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 253 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 254 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
| 255 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (1x64x64, 1024x4)
|
| 256 |
+
|
| 257 |
+
with torch.no_grad():
|
| 258 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(batch, out_layer_list=self.feature_list)
|
| 259 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (1, 1024, 1024x4)
|
| 260 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 261 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 262 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
| 263 |
+
patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (1x64x64, 1024x4)
|
| 264 |
+
|
| 265 |
+
'''adding for kmeans seg '''
|
| 266 |
+
if self.feat_size != self.ori_feat_size:
|
| 267 |
+
proj_patch_tokens = proj_patch_tokens.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 268 |
+
proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 269 |
+
proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(self.feat_size * self.feat_size, self.embed_dim)
|
| 270 |
+
proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
|
| 271 |
+
|
| 272 |
+
mid_features = None
|
| 273 |
+
for layer in self.cluster_feature_id:
|
| 274 |
+
temp_feat = patch_tokens[layer]
|
| 275 |
+
mid_features = temp_feat if mid_features is None else torch.cat((mid_features, temp_feat), -1)
|
| 276 |
+
|
| 277 |
+
if self.feat_size != self.ori_feat_size:
|
| 278 |
+
mid_features = mid_features.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 279 |
+
mid_features = F.interpolate(mid_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 280 |
+
mid_features = mid_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
|
| 281 |
+
mid_features = F.normalize(mid_features, p=2, dim=-1)
|
| 282 |
+
|
| 283 |
+
results = self.histogram(batch, mid_features, proj_patch_tokens, self.class_name, os.path.dirname(path).split('/')[-1] + "_" + os.path.basename(path).split('.')[0])
|
| 284 |
+
|
| 285 |
+
hist_score = results['score']
|
| 286 |
+
|
| 287 |
+
'''calculate patchcore'''
|
| 288 |
+
anomaly_maps_patchcore = []
|
| 289 |
+
|
| 290 |
+
if self.class_name in ['pushpins', 'screw_bag']: # clip feature for patchcore
|
| 291 |
+
len_feature_list = len(self.feature_list)
|
| 292 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
|
| 293 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
| 294 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
| 295 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
| 296 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
| 297 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
| 298 |
+
|
| 299 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
| 300 |
+
|
| 301 |
+
if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']: # dinov2 feature for patchcore
|
| 302 |
+
len_feature_list = len(self.feature_list_dinov2)
|
| 303 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1), mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
|
| 304 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
| 305 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
| 306 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
| 307 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
| 308 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
| 309 |
+
|
| 310 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
| 311 |
+
|
| 312 |
+
structural_score = np.stack(anomaly_maps_patchcore).mean(0).max()
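# The block above is the PatchCore-style structural check: test patch embeddings and the
# coreset memory of few-shot patch embeddings are both L2-normalized, so the matrix product
# is cosine similarity; 1 - (max similarity over the memory) is the per-patch anomaly score,
# and averaging over the feature stages then taking the max over patches gives the
# image-level structural_score.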
|
| 313 |
+
# anomaly_map_structural = np.stack(anomaly_maps_patchcore).mean(0).reshape(self.feat_size, self.feat_size)
|
| 314 |
+
|
| 315 |
+
instance_masks = results["instance_masks"]
|
| 316 |
+
anomaly_instances_hungarian = []
|
| 317 |
+
instance_hungarian_match_score = 1.
|
| 318 |
+
if self.mem_instance_masks is not None and len(instance_masks) != 0:
|
| 319 |
+
for patch_feature, mem_instance_features_single_stage in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), self.mem_instance_features_multi_stage.chunk(len_feature_list, dim=1)):
|
| 320 |
+
instance_features = [patch_feature[mask, :].mean(0, keepdim=True) for mask in instance_masks]
|
| 321 |
+
instance_features = torch.cat(instance_features, dim=0)
|
| 322 |
+
instance_features = F.normalize(instance_features, dim=-1)
|
| 323 |
+
|
| 324 |
+
normal_instance_hungarian = (instance_features @ mem_instance_features_single_stage.T)
|
| 325 |
+
cost_matrix = (1 - normal_instance_hungarian).cpu().numpy()
|
| 326 |
+
|
| 327 |
+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
| 328 |
+
cost = cost_matrix[row_ind, col_ind].sum()
|
| 329 |
+
cost = cost / min(cost_matrix.shape)
|
| 330 |
+
anomaly_instances_hungarian.append(cost)
|
| 331 |
+
|
| 332 |
+
instance_hungarian_match_score = np.mean(anomaly_instances_hungarian)
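# Instance-level logical check: mean-pooled features of each detected instance are compared
# with the instance features remembered from the few-shot normal images, and
# linear_sum_assignment solves the optimal one-to-one (Hungarian) matching on the
# cosine-distance cost matrix; the assignment cost, normalized by the smaller side of the
# matrix, becomes instance_hungarian_match_score, so missing, extra or swapped parts raise it.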
|
| 333 |
+
|
| 334 |
+
results = {'hist_score': hist_score, 'structural_score': structural_score, 'instance_hungarian_match_score': instance_hungarian_match_score}
|
| 335 |
+
|
| 336 |
+
return results
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def histogram(self, image, cluster_feature, proj_patch_token, class_name, path):
|
| 340 |
+
def plot_results_only(sorted_anns):
|
| 341 |
+
cur = 1
|
| 342 |
+
img_color = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]))
|
| 343 |
+
for ann in sorted_anns:
|
| 344 |
+
m = ann['segmentation']
|
| 345 |
+
img_color[m] = cur
|
| 346 |
+
cur += 1
|
| 347 |
+
return img_color
|
| 348 |
+
|
| 349 |
+
def merge_segmentations(a, b, background_class):
|
| 350 |
+
unique_labels_a = np.unique(a)
|
| 351 |
+
unique_labels_b = np.unique(b)
|
| 352 |
+
|
| 353 |
+
max_label_a = unique_labels_a.max()
|
| 354 |
+
label_map = np.zeros(max_label_a + 1, dtype=int)
|
| 355 |
+
|
| 356 |
+
for label_a in unique_labels_a:
|
| 357 |
+
mask_a = (a == label_a)
|
| 358 |
+
|
| 359 |
+
labels_b = b[mask_a]
|
| 360 |
+
if labels_b.size > 0:
|
| 361 |
+
count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
|
| 362 |
+
label_map[label_a] = np.argmax(count_b)
|
| 363 |
+
else:
|
| 364 |
+
label_map[label_a] = background_class # default background
|
| 365 |
+
|
| 366 |
+
merged_a = label_map[a]
|
| 367 |
+
return merged_a
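# merge_segmentations relabels every SAM segment by majority vote over the semantic labels
# (k-means clusters or patch-level text matches) that fall inside it, combining SAM's precise
# object boundaries with CLIP-derived semantics; segments with no overlapping labels fall back
# to the background class.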
|
| 368 |
+
|
| 369 |
+
pseudo_labels = kmeans_predict(cluster_feature, self.cluster_centers, 'euclidean', device=self.device)
|
| 370 |
+
kmeans_mask = torch.ones_like(pseudo_labels) * (self.classes - 1) # default to background
|
| 371 |
+
|
| 372 |
+
for pl in pseudo_labels.unique():
|
| 373 |
+
mask = (pseudo_labels == pl).reshape(-1)
|
| 374 |
+
# filter small region
|
| 375 |
+
binary = mask.cpu().numpy().reshape(self.feat_size, self.feat_size).astype(np.uint8)
|
| 376 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 377 |
+
for i in range(1, num_labels):
|
| 378 |
+
temp_mask = labels == i
|
| 379 |
+
if np.sum(temp_mask) <= 8:
|
| 380 |
+
mask[temp_mask.reshape(-1)] = False
|
| 381 |
+
|
| 382 |
+
if mask.any():
|
| 383 |
+
region_feature = proj_patch_token[mask, :].mean(0, keepdim=True)
|
| 384 |
+
similarity = (region_feature @ self.query_obj.T)
|
| 385 |
+
prob, index = torch.max(similarity, dim=-1)
|
| 386 |
+
temp_label = index.squeeze(0).item()
|
| 387 |
+
temp_prob = prob.squeeze(0).item()
|
| 388 |
+
if temp_prob > self.query_threshold_dict[class_name][temp_label]: # threshold for each class
|
| 389 |
+
kmeans_mask[mask] = temp_label
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
raw_image = to_np_img(image[0])
|
| 393 |
+
height, width = raw_image.shape[:2]
|
| 394 |
+
masks = self.mask_generator.generate(raw_image)
|
| 395 |
+
# self.predictor.set_image(raw_image)
|
| 396 |
+
|
| 397 |
+
kmeans_label = pseudo_labels.view(self.feat_size, self.feat_size).cpu().numpy()
|
| 398 |
+
kmeans_mask = kmeans_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
| 399 |
+
|
| 400 |
+
patch_similarity = (proj_patch_token @ self.patch_query_obj.T)
|
| 401 |
+
patch_mask = patch_similarity.argmax(-1)
|
| 402 |
+
patch_mask = patch_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
| 403 |
+
|
| 404 |
+
sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
|
| 405 |
+
sam_mask = plot_results_only(sorted_masks).astype(np.int)
|
| 406 |
+
|
| 407 |
+
resized_mask = cv2.resize(kmeans_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
| 408 |
+
merge_sam = merge_segmentations(sam_mask, resized_mask, background_class=self.classes-1)
|
| 409 |
+
|
| 410 |
+
resized_patch_mask = cv2.resize(patch_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
| 411 |
+
patch_merge_sam = merge_segmentations(sam_mask, resized_patch_mask, background_class=self.patch_query_obj.shape[0]-1)
|
| 412 |
+
|
| 413 |
+
# filter small region for merge sam
|
| 414 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
| 415 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 416 |
+
for i in range(1, num_labels):
|
| 417 |
+
temp_mask = labels == i
|
| 418 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
| 419 |
+
merge_sam[temp_mask] = self.classes - 1 # set to background
|
| 420 |
+
|
| 421 |
+
# filter small region for patch merge sam
|
| 422 |
+
binary = (patch_merge_sam != (self.patch_query_obj.shape[0]-1) ).astype(np.uint8) # foreground 1 background 0
|
| 423 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 424 |
+
for i in range(1, num_labels):
|
| 425 |
+
temp_mask = labels == i
|
| 426 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
| 427 |
+
patch_merge_sam[temp_mask] = self.patch_query_obj.shape[0]-1 # set to background
|
| 428 |
+
|
| 429 |
+
score = 0. # default to normal
|
| 430 |
+
self.anomaly_flag = False
|
| 431 |
+
instance_masks = []
|
| 432 |
+
if self.class_name == 'pushpins':
|
| 433 |
+
# object count hist
|
| 434 |
+
kernel = np.ones((3, 3), dtype=np.uint8) # dilate for robustness
|
| 435 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
| 436 |
+
dilate_binary = cv2.dilate(binary, kernel)
|
| 437 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilate_binary, connectivity=8)
|
| 438 |
+
pushpins_count = num_labels - 1 # number of pushpins
|
| 439 |
+
|
| 440 |
+
for i in range(1, num_labels):
|
| 441 |
+
instance_mask = (labels == i).astype(np.uint8)
|
| 442 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 443 |
+
if instance_mask.any():
|
| 444 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
| 445 |
+
|
| 446 |
+
if self.few_shot_inited and pushpins_count != self.pushpins_count and self.anomaly_flag is False:
|
| 447 |
+
self.anomaly_flag = True
|
| 448 |
+
print('number of pushpins: {}, but canonical number of pushpins: {}'.format(pushpins_count, self.pushpins_count))
|
| 449 |
+
|
| 450 |
+
# patch hist
|
| 451 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
| 452 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
| 453 |
+
|
| 454 |
+
if self.few_shot_inited:
|
| 455 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
| 456 |
+
score = 1 - patch_hist_similarity.max()
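# Logical score for pushpins: the L2-normalized histogram of patch-level semantic labels is
# compared (dot product = cosine similarity) against the histograms stored from the few-shot
# normal images; 1 - the best similarity grows when the composition of the image, e.g. the
# number of pushpin patches, drifts away from the normal references.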
|
| 457 |
+
|
| 458 |
+
binary_foreground = dilate_binary.astype(np.uint8)
|
| 459 |
+
|
| 460 |
+
if len(instance_masks) != 0:
|
| 461 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 462 |
+
|
| 463 |
+
if self.visualization:
|
| 464 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
| 465 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
| 466 |
+
plt.figure(figsize=(20, 3))
|
| 467 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 468 |
+
plt.subplot(1, len(image_list), ind)
|
| 469 |
+
plt.imshow(temp_image)
|
| 470 |
+
plt.title(temp_title)
|
| 471 |
+
plt.margins(0, 0)
|
| 472 |
+
plt.axis('off')
|
| 473 |
+
# Extract relative path from class_name onwards
|
| 474 |
+
if class_name in path:
|
| 475 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 476 |
+
if relative_path.startswith('/'):
|
| 477 |
+
relative_path = relative_path[1:]
|
| 478 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
| 479 |
+
else:
|
| 480 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
| 481 |
+
|
| 482 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 483 |
+
plt.tight_layout()
|
| 484 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 485 |
+
plt.close()
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# todo: same number in total but in different boxes or broken box
|
| 489 |
+
return {"score": score, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
| 490 |
+
|
| 491 |
+
elif self.class_name == 'splicing_connectors':
|
| 492 |
+
# object count hist for default
|
| 493 |
+
sam_mask_max_area = sorted_masks[0]['segmentation'] # background
|
| 494 |
+
binary = (sam_mask_max_area == 0).astype(np.uint8) # sam_mask_max_area is background, background 0 foreground 1
|
| 495 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 496 |
+
count = 0
|
| 497 |
+
for i in range(1, num_labels):
|
| 498 |
+
temp_mask = labels == i
|
| 499 |
+
if np.sum(temp_mask) <= 64: # 448x448 64
|
| 500 |
+
binary[temp_mask] = 0 # set to background
|
| 501 |
+
else:
|
| 502 |
+
count += 1
|
| 503 |
+
if count != 1 and self.anomaly_flag is False: # cable cut or no cable or no connector
|
| 504 |
+
print('number of connected components in splicing_connectors: {}, but the default number of connected components is 1.'.format(count))
|
| 505 |
+
self.anomaly_flag = True
|
| 506 |
+
|
| 507 |
+
merge_sam[~(binary.astype(np.bool))] = self.query_obj.shape[0] - 1 # remove noise
|
| 508 |
+
patch_merge_sam[~(binary.astype(np.bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
| 509 |
+
|
| 510 |
+
# erode the cable and divide into left and right parts
|
| 511 |
+
kernel = np.ones((23, 23), dtype=np.uint8)
|
| 512 |
+
erode_binary = cv2.erode(binary, kernel)
|
| 513 |
+
h, w = erode_binary.shape
|
| 514 |
+
distance = 0
|
| 515 |
+
|
| 516 |
+
left, right = erode_binary[:, :int(w/2)], erode_binary[:, int(w/2):]
|
| 517 |
+
left_count = np.bincount(left.reshape(-1), minlength=self.classes)[1] # foreground
|
| 518 |
+
right_count = np.bincount(right.reshape(-1), minlength=self.classes)[1] # foreground
|
| 519 |
+
|
| 520 |
+
# binary_cable = (merge_sam == 1).astype(np.uint8)
|
| 521 |
+
binary_cable = (patch_merge_sam == 1).astype(np.uint8)
|
| 522 |
+
|
| 523 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
| 524 |
+
binary_cable = cv2.erode(binary_cable, kernel)
|
| 525 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_cable, connectivity=8)
|
| 526 |
+
for i in range(1, num_labels):
|
| 527 |
+
temp_mask = labels == i
|
| 528 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
| 529 |
+
binary_cable[temp_mask] = 0 # set to background
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
binary_cable = cv2.resize(binary_cable, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 533 |
+
|
| 534 |
+
binary_clamps = (patch_merge_sam == 0).astype(np.uint8)
|
| 535 |
+
|
| 536 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
| 537 |
+
binary_clamps = cv2.erode(binary_clamps, kernel)
|
| 538 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_clamps, connectivity=8)
|
| 539 |
+
for i in range(1, num_labels):
|
| 540 |
+
temp_mask = labels == i
|
| 541 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
| 542 |
+
binary_clamps[temp_mask] = 0 # set to background
|
| 543 |
+
else:
|
| 544 |
+
instance_mask = temp_mask.astype(np.uint8)
|
| 545 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 546 |
+
if instance_mask.any():
|
| 547 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
| 548 |
+
|
| 549 |
+
binary_clamps = cv2.resize(binary_clamps, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 550 |
+
|
| 551 |
+
binary_connector = cv2.resize(binary, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 552 |
+
|
| 553 |
+
query_cable_color = encode_obj_text(self.model_clip, self.splicing_connectors_cable_color_query_words_dict, self.tokenizer, self.device)
|
| 554 |
+
cable_feature = proj_patch_token[binary_cable.astype(np.bool).reshape(-1), :].mean(0, keepdim=True)
|
| 555 |
+
idx_color = (cable_feature @ query_cable_color.T).argmax(-1).squeeze(0).item()
|
| 556 |
+
foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
slice_cable = binary[:, int(w/2)-1: int(w/2)+1]
|
| 560 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(slice_cable, connectivity=8)
|
| 561 |
+
cable_count = num_labels - 1
|
| 562 |
+
if cable_count != 1 and self.anomaly_flag is False: # two cables
|
| 563 |
+
print('cable count in splicing_connectors: {}, but the default cable count is 1.'.format(cable_count))
|
| 564 |
+
self.anomaly_flag = True
|
| 565 |
+
|
| 566 |
+
# {2-clamp: yellow 3-clamp: blue 5-clamp: red} cable color and clamp number mismatch
|
| 567 |
+
if self.few_shot_inited and self.foreground_pixel_hist_splicing_connectors != 0 and self.anomaly_flag is False:
|
| 568 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist_splicing_connectors
|
| 569 |
+
if (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # color and number mismatch
|
| 570 |
+
print('cable color and number of clamps mismatch, cable color idx: {} (0: yellow 2-clamp, 1: blue 3-clamp, 2: red 5-clamp), foreground_pixel_count :{}, canonical foreground_pixel_hist: {}.'.format(idx_color, foreground_pixel_count, self.foreground_pixel_hist_splicing_connectors))
|
| 571 |
+
self.anomaly_flag = True
|
| 572 |
+
|
| 573 |
+
# left right hist for symmetry
|
| 574 |
+
ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
|
| 575 |
+
if self.few_shot_inited and (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # left right asymmetry in clamp
|
| 576 |
+
print('left and right connectors are not symmetric.')
|
| 577 |
+
self.anomaly_flag = True
|
| 578 |
+
|
| 579 |
+
# left and right centroids distance
|
| 580 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(erode_binary, connectivity=8)
|
| 581 |
+
if num_labels - 1 == 2:
|
| 582 |
+
centroids = centroids[1:]
|
| 583 |
+
x1, y1 = centroids[0]
|
| 584 |
+
x2, y2 = centroids[1]
|
| 585 |
+
distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
|
| 586 |
+
if self.few_shot_inited and self.splicing_connectors_distance != 0 and self.anomaly_flag is False:
|
| 587 |
+
ratio = distance / self.splicing_connectors_distance
|
| 588 |
+
if ratio < 0.6 or ratio > 1.4: # too short or too long centroids distance (cable) # 0.6 1.4
|
| 589 |
+
print('cable is too short or too long.')
|
| 590 |
+
self.anomaly_flag = True
|
| 591 |
+
|
| 592 |
+
# patch hist
|
| 593 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])#[:-1] # ignore background (grid) for statistic
|
| 594 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
| 595 |
+
|
| 596 |
+
if self.few_shot_inited:
|
| 597 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
| 598 |
+
score = 1 - patch_hist_similarity.max()
|
| 599 |
+
|
| 600 |
+
# todo mismatch cable link
|
| 601 |
+
binary_foreground = binary.astype(np.uint8) # only 1 instance, so additionally separate cable and clamps
|
| 602 |
+
if binary_connector.any():
|
| 603 |
+
instance_masks.append(binary_connector.astype(np.bool).reshape(-1))
|
| 604 |
+
if binary_clamps.any():
|
| 605 |
+
instance_masks.append(binary_clamps.astype(np.bool).reshape(-1))
|
| 606 |
+
if binary_cable.any():
|
| 607 |
+
instance_masks.append(binary_cable.astype(np.bool).reshape(-1))
|
| 608 |
+
|
| 609 |
+
if len(instance_masks) != 0:
|
| 610 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 611 |
+
|
| 612 |
+
if self.visualization:
|
| 613 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, binary_connector, merge_sam, patch_merge_sam, erode_binary, binary_cable, binary_clamps]
|
| 614 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'binary_connector', 'merge sam', 'patch merge sam', 'erode binary', 'binary_cable', 'binary_clamps']
|
| 615 |
+
plt.figure(figsize=(25, 3))
|
| 616 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 617 |
+
plt.subplot(1, len(image_list), ind)
|
| 618 |
+
plt.imshow(temp_image)
|
| 619 |
+
plt.title(temp_title)
|
| 620 |
+
plt.margins(0, 0)
|
| 621 |
+
plt.axis('off')
|
| 622 |
+
# Extract relative path from class_name onwards
|
| 623 |
+
if class_name in path:
|
| 624 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 625 |
+
if relative_path.startswith('/'):
|
| 626 |
+
relative_path = relative_path[1:]
|
| 627 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
| 628 |
+
else:
|
| 629 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
| 630 |
+
|
| 631 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 632 |
+
plt.tight_layout()
|
| 633 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 634 |
+
plt.close()
|
| 635 |
+
|
| 636 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "distance": distance, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
| 637 |
+
|
| 638 |
+
elif self.class_name == 'screw_bag':
|
| 639 |
+
# pixel hist of kmeans mask
|
| 640 |
+
foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])]) # foreground pixel
|
| 641 |
+
if self.few_shot_inited and self.foreground_pixel_hist_screw_bag != 0 and self.anomaly_flag is False:
|
| 642 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist_screw_bag
|
| 643 |
+
# todo: optimize
|
| 644 |
+
if ratio < 0.94 or ratio > 1.06: # 82.95 | 81.3
|
| 645 |
+
print('foreground pixel histogram of screw bag: {}, the canonical foreground pixel histogram of screw bag in few shot: {}'.format(foreground_pixel_count, self.foreground_pixel_hist_screw_bag))
|
| 646 |
+
self.anomaly_flag = True
|
| 647 |
+
|
| 648 |
+
# patch hist
|
| 649 |
+
binary_screw = np.isin(kmeans_mask, self.foreground_label_idx[self.class_name])
|
| 650 |
+
patch_mask[~binary_screw] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
| 651 |
+
resized_binary_screw = cv2.resize(binary_screw.astype(np.uint8), (patch_merge_sam.shape[1], patch_merge_sam.shape[0]), interpolation = cv2.INTER_NEAREST)
|
| 652 |
+
patch_merge_sam[~(resized_binary_screw.astype(np.bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
| 653 |
+
|
| 654 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])[:-1]
|
| 655 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
| 656 |
+
|
| 657 |
+
if self.few_shot_inited:
|
| 658 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
| 659 |
+
score = 1 - patch_hist_similarity.max()
|
| 660 |
+
|
| 661 |
+
for i in range(self.patch_query_obj.shape[0]-1):
|
| 662 |
+
binary_foreground = (patch_merge_sam == i).astype(np.uint8)
|
| 663 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
| 664 |
+
for i in range(1, num_labels):
|
| 665 |
+
instance_mask = (labels == i).astype(np.uint8)
|
| 666 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 667 |
+
if instance_mask.any():
|
| 668 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
| 669 |
+
|
| 670 |
+
if len(instance_masks) != 0:
|
| 671 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 672 |
+
|
| 673 |
+
if self.visualization:
|
| 674 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
| 675 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
| 676 |
+
plt.figure(figsize=(20, 3))
|
| 677 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 678 |
+
plt.subplot(1, len(image_list), ind)
|
| 679 |
+
plt.imshow(temp_image)
|
| 680 |
+
plt.title(temp_title)
|
| 681 |
+
plt.margins(0, 0)
|
| 682 |
+
plt.axis('off')
|
| 683 |
+
# Extract relative path from class_name onwards
|
| 684 |
+
if class_name in path:
|
| 685 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 686 |
+
if relative_path.startswith('/'):
|
| 687 |
+
relative_path = relative_path[1:]
|
| 688 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
| 689 |
+
else:
|
| 690 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
| 691 |
+
|
| 692 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 693 |
+
plt.tight_layout()
|
| 694 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 695 |
+
plt.close()
|
| 696 |
+
|
| 697 |
+
# plt.axis('off')
|
| 698 |
+
# plt.imshow(patch_merge_sam)
|
| 699 |
+
|
| 700 |
+
# plt.savefig('pic/vis/{}_seg_{}.png'.format(class_name, path), bbox_inches='tight', pad_inches = 0) # pad_inches = 0
|
| 701 |
+
# plt.close()
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
| 705 |
+
|
| 706 |
+
elif self.class_name == 'breakfast_box':
|
| 707 |
+
# patch hist
|
| 708 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
| 709 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
| 710 |
+
|
| 711 |
+
if self.few_shot_inited:
|
| 712 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
| 713 |
+
score = 1 - patch_hist_similarity.max()
|
| 714 |
+
|
| 715 |
+
# todo: exist of foreground
|
| 716 |
+
|
| 717 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
|
| 718 |
+
|
| 719 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
| 720 |
+
for i in range(1, num_labels):
|
| 721 |
+
instance_mask = (labels == i).astype(np.uint8)
|
| 722 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 723 |
+
if instance_mask.any():
|
| 724 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
if len(instance_masks) != 0:
|
| 728 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 729 |
+
|
| 730 |
+
if self.visualization:
|
| 731 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
| 732 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
| 733 |
+
plt.figure(figsize=(20, 3))
|
| 734 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 735 |
+
plt.subplot(1, len(image_list), ind)
|
| 736 |
+
plt.imshow(temp_image)
|
| 737 |
+
plt.title(temp_title)
|
| 738 |
+
plt.margins(0, 0)
|
| 739 |
+
plt.axis('off')
|
| 740 |
+
# Extract relative path from class_name onwards
|
| 741 |
+
if class_name in path:
|
| 742 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 743 |
+
if relative_path.startswith('/'):
|
| 744 |
+
relative_path = relative_path[1:]
|
| 745 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
| 746 |
+
else:
|
| 747 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
| 748 |
+
|
| 749 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 750 |
+
plt.tight_layout()
|
| 751 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 752 |
+
plt.close()
|
| 753 |
+
|
| 754 |
+
# plt.axis('off')
|
| 755 |
+
# plt.imshow(patch_merge_sam)
|
| 756 |
+
|
| 757 |
+
# plt.savefig('pic/vis/{}_seg_{}.png'.format(class_name, path), bbox_inches='tight', pad_inches = 0) # pad_inches = 0
|
| 758 |
+
# plt.close()
|
| 759 |
+
|
| 760 |
+
return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
| 761 |
+
|
| 762 |
+
elif self.class_name == 'juice_bottle':
|
| 763 |
+
# remove noise due to non sam mask
|
| 764 |
+
merge_sam[sam_mask == 0] = self.classes - 1
|
| 765 |
+
patch_merge_sam[sam_mask == 0] = self.patch_query_obj.shape[0] - 1 # 79.5
|
| 766 |
+
|
| 767 |
+
# [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
| 768 |
+
# fruit and liquid mismatch (todo if exist)
|
| 769 |
+
resized_patch_merge_sam = cv2.resize(patch_merge_sam, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 770 |
+
binary_liquid = (resized_patch_merge_sam == 1)
|
| 771 |
+
binary_fruit = (resized_patch_merge_sam == 2)
|
| 772 |
+
|
| 773 |
+
query_liquid = encode_obj_text(self.model_clip, self.juice_bottle_liquid_query_words_dict, self.tokenizer, self.device)
|
| 774 |
+
query_fruit = encode_obj_text(self.model_clip, self.juice_bottle_fruit_query_words_dict, self.tokenizer, self.device)
|
| 775 |
+
|
| 776 |
+
liquid_feature = proj_patch_token[binary_liquid.reshape(-1), :].mean(0, keepdim=True)
|
| 777 |
+
liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()
|
| 778 |
+
|
| 779 |
+
fruit_feature = proj_patch_token[binary_fruit.reshape(-1), :].mean(0, keepdim=True)
|
| 780 |
+
fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()
|
| 781 |
+
|
| 782 |
+
if (liquid_idx != fruit_idx) and self.anomaly_flag is False:
|
| 783 |
+
print('liquid: {}, but fruit: {}.'.format(self.juice_bottle_liquid_query_words_dict[liquid_idx], self.juice_bottle_fruit_query_words_dict[fruit_idx]))
|
| 784 |
+
self.anomaly_flag = True
|
| 785 |
+
|
| 786 |
+
# # todo centroid of fruit and tag_0 mismatch (if exist) , only one tag, center
|
| 787 |
+
|
| 788 |
+
# patch hist
|
| 789 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
| 790 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
| 791 |
+
|
| 792 |
+
if self.few_shot_inited:
|
| 793 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
| 794 |
+
score = 1 - patch_hist_similarity.max()
|
| 795 |
+
|
| 796 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1) ).astype(np.uint8)
|
| 797 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
| 798 |
+
for i in range(1, num_labels):
|
| 799 |
+
instance_mask = (labels == i).astype(np.uint8)
|
| 800 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 801 |
+
if instance_mask.any():
|
| 802 |
+
instance_masks.append(instance_mask.astype(np.bool).reshape(-1))
|
| 803 |
+
|
| 804 |
+
if len(instance_masks) != 0:
|
| 805 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 806 |
+
|
| 807 |
+
if self.visualization:
|
| 808 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
| 809 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
| 810 |
+
plt.figure(figsize=(20, 3))
|
| 811 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 812 |
+
plt.subplot(1, len(image_list), ind)
|
| 813 |
+
plt.imshow(temp_image)
|
| 814 |
+
plt.title(temp_title)
|
| 815 |
+
plt.margins(0, 0)
|
| 816 |
+
plt.axis('off')
|
| 817 |
+
# Extract relative path from class_name onwards
|
| 818 |
+
if class_name in path:
|
| 819 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 820 |
+
if relative_path.startswith('/'):
|
| 821 |
+
relative_path = relative_path[1:]
|
| 822 |
+
save_path = f'visualization/full_data/{class_name}/{relative_path}.png'
|
| 823 |
+
else:
|
| 824 |
+
save_path = f'visualization/full_data/{class_name}/{path}.png'
|
| 825 |
+
|
| 826 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 827 |
+
plt.tight_layout()
|
| 828 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 829 |
+
plt.close()
|
| 830 |
+
|
| 831 |
+
return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
| 832 |
+
|
| 833 |
+
return {"score": score, "instance_masks": instance_masks}
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
def process_k_shot(self, class_name, few_shot_samples, few_shot_paths):
|
| 837 |
+
few_shot_samples = F.interpolate(few_shot_samples, size=(448, 448), mode=self.inter_mode, align_corners=self.align_corners, antialias=self.antialias)
|
| 838 |
+
|
| 839 |
+
with torch.no_grad():
|
| 840 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(few_shot_samples, self.feature_list)
|
| 841 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
| 842 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
| 843 |
+
|
| 844 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (bs, 1024, 1024x4)
|
| 845 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (bs, 1024, 1024x2)
|
| 846 |
+
patch_tokens_clip = patch_tokens_clip.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 847 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 848 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
| 849 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (bsx64x64, 1024x4)
|
| 850 |
+
|
| 851 |
+
with torch.no_grad():
|
| 852 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(few_shot_samples, out_layer_list=self.feature_list_dinov2) # 4 x [bs, 32x32, 1024]
|
| 853 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (bs, 1024, 1024x4)
|
| 854 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 855 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 856 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
| 857 |
+
patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (bsx64x64, 1024x4)
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
cluster_features = None
|
| 861 |
+
for layer in self.cluster_feature_id:
|
| 862 |
+
temp_feat = patch_tokens[layer]
|
| 863 |
+
cluster_features = temp_feat if cluster_features is None else torch.cat((cluster_features, temp_feat), 1)
|
| 864 |
+
if self.feat_size != self.ori_feat_size:
|
| 865 |
+
cluster_features = cluster_features.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 866 |
+
cluster_features = F.interpolate(cluster_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 867 |
+
cluster_features = cluster_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
|
| 868 |
+
cluster_features = F.normalize(cluster_features, p=2, dim=-1)
|
| 869 |
+
|
| 870 |
+
if self.feat_size != self.ori_feat_size:
|
| 871 |
+
proj_patch_tokens = proj_patch_tokens.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 872 |
+
proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 873 |
+
proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(-1, self.embed_dim)
|
| 874 |
+
proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
|
| 875 |
+
|
| 876 |
+
if not self.cluster_init:
|
| 877 |
+
num_clusters = self.cluster_num_dict[class_name]
|
| 878 |
+
_, self.cluster_centers = kmeans(X=cluster_features, num_clusters=num_clusters, device=self.device)
|
| 879 |
+
|
| 880 |
+
self.query_obj = encode_obj_text(self.model_clip, self.query_words_dict[class_name], self.tokenizer, self.device)
|
| 881 |
+
self.patch_query_obj = encode_obj_text(self.model_clip, self.patch_query_words_dict[class_name], self.tokenizer, self.device)
|
| 882 |
+
self.classes = self.query_obj.shape[0]
|
| 883 |
+
|
| 884 |
+
self.cluster_init = True
|
| 885 |
+
|
| 886 |
+
scores = []
|
| 887 |
+
foreground_pixel_hist = []
|
| 888 |
+
splicing_connectors_distance = []
|
| 889 |
+
patch_token_hist = []
|
| 890 |
+
mem_instance_masks = []
|
| 891 |
+
|
| 892 |
+
for image, cluster_feature, proj_patch_token, few_shot_path in zip(few_shot_samples.chunk(self.k_shot), cluster_features.chunk(self.k_shot), proj_patch_tokens.chunk(self.k_shot), few_shot_paths):
|
| 893 |
+
# path = os.path.dirname(few_shot_path).split('/')[-1] + "_" + os.path.basename(few_shot_path).split('.')[0]
|
| 894 |
+
self.anomaly_flag = False
|
| 895 |
+
results = self.histogram(image, cluster_feature, proj_patch_token, class_name, "few_shot_" + os.path.basename(few_shot_path).split('.')[0])
|
| 896 |
+
if self.class_name == 'pushpins':
|
| 897 |
+
patch_token_hist.append(results["clip_patch_hist"])
|
| 898 |
+
mem_instance_masks.append(results['instance_masks'])
|
| 899 |
+
|
| 900 |
+
elif self.class_name == 'splicing_connectors':
|
| 901 |
+
foreground_pixel_hist.append(results["foreground_pixel_count"])
|
| 902 |
+
splicing_connectors_distance.append(results["distance"])
|
| 903 |
+
patch_token_hist.append(results["sam_patch_hist"])
|
| 904 |
+
mem_instance_masks.append(results['instance_masks'])
|
| 905 |
+
|
| 906 |
+
elif self.class_name == 'screw_bag':
|
| 907 |
+
foreground_pixel_hist.append(results["foreground_pixel_count"])
|
| 908 |
+
patch_token_hist.append(results["clip_patch_hist"])
|
| 909 |
+
mem_instance_masks.append(results['instance_masks'])
|
| 910 |
+
|
| 911 |
+
elif self.class_name == 'breakfast_box':
|
| 912 |
+
patch_token_hist.append(results["sam_patch_hist"])
|
| 913 |
+
mem_instance_masks.append(results['instance_masks'])
|
| 914 |
+
|
| 915 |
+
elif self.class_name == 'juice_bottle':
|
| 916 |
+
patch_token_hist.append(results["sam_patch_hist"])
|
| 917 |
+
mem_instance_masks.append(results['instance_masks'])
|
| 918 |
+
|
| 919 |
+
scores.append(results["score"])
|
| 920 |
+
|
| 921 |
+
if len(foreground_pixel_hist) != 0:
|
| 922 |
+
self.foreground_pixel_hist = np.mean(foreground_pixel_hist)
|
| 923 |
+
if len(splicing_connectors_distance) != 0:
|
| 924 |
+
self.splicing_connectors_distance = np.mean(splicing_connectors_distance)
|
| 925 |
+
if len(patch_token_hist) != 0: # patch hist
|
| 926 |
+
self.patch_token_hist = np.stack(patch_token_hist)
|
| 927 |
+
if len(mem_instance_masks) != 0:
|
| 928 |
+
self.mem_instance_masks = mem_instance_masks
|
| 929 |
+
|
| 930 |
+
# instance-of-interest matching: build per-stage memory features from the few-shot instance masks
|
| 931 |
+
len_feature_list = len(self.feature_list)
|
| 932 |
+
for idx, batch_mem_patch_feature in enumerate(patch_tokens_clip.chunk(len_feature_list, dim=-1)): # 4 stages batch_mem_patch_feature (bsx64x64, 1024)
|
| 933 |
+
mem_instance_features = []
|
| 934 |
+
for mem_patch_feature, mem_instance_masks in zip(batch_mem_patch_feature.chunk(self.k_shot), self.mem_instance_masks): # k shot mem_patch_feature (64x64, 1024)
|
| 935 |
+
mem_instance_features.extend([mem_patch_feature[mask, :].mean(0, keepdim=True) for mask in mem_instance_masks])
|
| 936 |
+
mem_instance_features = torch.cat(mem_instance_features, dim=0)
|
| 937 |
+
mem_instance_features = F.normalize(mem_instance_features, dim=-1) # 4 stages
|
| 938 |
+
# mem_instance_features_multi_stage.append(mem_instance_features)
|
| 939 |
+
self.mem_instance_features_multi_stage[idx].append(mem_instance_features)
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
mem_patch_feature_clip_coreset = patch_tokens_clip
|
| 943 |
+
mem_patch_feature_dinov2_coreset = patch_tokens_dinov2
|
| 944 |
+
|
| 945 |
+
return scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset
|
| 946 |
+
|
| 947 |
+
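# Process one k-shot batch, then reduce the CLIP and DINOv2 patch features with greedy k-center (coreset) subsampling.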
def process(self, class_name: str, few_shot_samples: list[torch.Tensor], few_shot_paths: list[str]):
|
| 948 |
+
few_shot_samples = self.transform(few_shot_samples).to(self.device)
|
| 949 |
+
|
| 950 |
+
scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset = self.process_k_shot(class_name, few_shot_samples, few_shot_paths)
|
| 951 |
+
|
| 952 |
+
clip_sampler = KCenterGreedy(embedding=mem_patch_feature_clip_coreset, sampling_ratio=0.25)
|
| 953 |
+
mem_patch_feature_clip_coreset = clip_sampler.sample_coreset()
|
| 954 |
+
|
| 955 |
+
dinov2_sampler = KCenterGreedy(embedding=mem_patch_feature_dinov2_coreset, sampling_ratio=0.25)
|
| 956 |
+
mem_patch_feature_dinov2_coreset = dinov2_sampler.sample_coreset()
|
| 957 |
+
|
| 958 |
+
self.mem_patch_feature_clip_coreset.append(mem_patch_feature_clip_coreset)
|
| 959 |
+
self.mem_patch_feature_dinov2_coreset.append(mem_patch_feature_dinov2_coreset)
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def setup(self, data: dict) -> None:
|
| 963 |
+
"""Setup the few-shot samples for the model.
|
| 964 |
+
|
| 965 |
+
The evaluation script will call this method to pass the k images for few shot learning and the object class
|
| 966 |
+
name. In the case of MVTec LOCO this will be the dataset category name (e.g. breakfast_box). Please contact
|
| 967 |
+
the organizing committee if your model requires any additional dataset-related information at setup time.
|
| 968 |
+
"""
|
| 969 |
+
few_shot_samples = data.get("few_shot_samples")
|
| 970 |
+
class_name = data.get("dataset_category")
|
| 971 |
+
few_shot_paths = data.get("few_shot_samples_path")
|
| 972 |
+
self.class_name = class_name
|
| 973 |
+
|
| 974 |
+
print(few_shot_samples.shape)
|
| 975 |
+
|
| 976 |
+
self.total_size = few_shot_samples.size(0)
|
| 977 |
+
|
| 978 |
+
self.k_shot = 4 if self.total_size > 4 else self.total_size
|
| 979 |
+
|
| 980 |
+
self.cluster_init = False
|
| 981 |
+
self.mem_instance_features_multi_stage = [[],[],[],[]]
|
| 982 |
+
|
| 983 |
+
self.mem_patch_feature_clip_coreset = []
|
| 984 |
+
self.mem_patch_feature_dinov2_coreset = []
|
| 985 |
+
|
| 986 |
+
# Check if coreset files already exist
|
| 987 |
+
clip_file = 'memory_bank/mem_patch_feature_clip_{}.pt'.format(self.class_name)
|
| 988 |
+
dinov2_file = 'memory_bank/mem_patch_feature_dinov2_{}.pt'.format(self.class_name)
|
| 989 |
+
instance_file = 'memory_bank/mem_instance_features_multi_stage_{}.pt'.format(self.class_name)
|
| 990 |
+
|
| 991 |
+
files_exist = os.path.exists(clip_file) and os.path.exists(dinov2_file) and os.path.exists(instance_file)
|
| 992 |
+
|
| 993 |
+
if self.save_coreset_features and not files_exist:
|
| 994 |
+
print(f"Coreset files not found for {self.class_name}, computing and saving...")
|
| 995 |
+
for i in range(self.total_size//self.k_shot):
|
| 996 |
+
self.process(class_name, few_shot_samples[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)], few_shot_paths[self.k_shot*i : min(self.k_shot*(i+1), self.total_size)])
|
| 997 |
+
|
| 998 |
+
# Coreset Subsampling
|
| 999 |
+
self.mem_patch_feature_clip_coreset = torch.cat(self.mem_patch_feature_clip_coreset, dim=0)
|
| 1000 |
+
torch.save(self.mem_patch_feature_clip_coreset, clip_file)
|
| 1001 |
+
|
| 1002 |
+
self.mem_patch_feature_dinov2_coreset = torch.cat(self.mem_patch_feature_dinov2_coreset, dim=0)
|
| 1003 |
+
torch.save(self.mem_patch_feature_dinov2_coreset, dinov2_file)
|
| 1004 |
+
|
| 1005 |
+
print(self.mem_patch_feature_dinov2_coreset.shape, self.mem_patch_feature_clip_coreset.shape)
|
| 1006 |
+
|
| 1007 |
+
self.mem_instance_features_multi_stage = [ torch.cat(mem_instance_features, dim=0) for mem_instance_features in self.mem_instance_features_multi_stage ]
|
| 1008 |
+
self.mem_instance_features_multi_stage = torch.cat(self.mem_instance_features_multi_stage, dim=1)
|
| 1009 |
+
torch.save(self.mem_instance_features_multi_stage, instance_file)
|
| 1010 |
+
|
| 1011 |
+
print(self.mem_instance_features_multi_stage.shape)
|
| 1012 |
+
|
| 1013 |
+
elif self.save_coreset_features and files_exist:
|
| 1014 |
+
print(f"Coreset files found for {self.class_name}, loading existing files...")
|
| 1015 |
+
self.process(class_name, few_shot_samples[0 : self.k_shot], few_shot_paths[0 : self.k_shot])
|
| 1016 |
+
|
| 1017 |
+
self.mem_patch_feature_clip_coreset = torch.load(clip_file)
|
| 1018 |
+
self.mem_patch_feature_dinov2_coreset = torch.load(dinov2_file)
|
| 1019 |
+
self.mem_instance_features_multi_stage = torch.load(instance_file)
|
| 1020 |
+
|
| 1021 |
+
print(self.mem_patch_feature_dinov2_coreset.shape, self.mem_patch_feature_clip_coreset.shape)
|
| 1022 |
+
print(self.mem_instance_features_multi_stage.shape)
|
| 1023 |
+
|
| 1024 |
+
else:
|
| 1025 |
+
self.process(class_name, few_shot_samples[0 : self.k_shot], few_shot_paths[0 : self.k_shot])
|
| 1026 |
+
|
| 1027 |
+
self.mem_patch_feature_clip_coreset = torch.load(clip_file)
|
| 1028 |
+
self.mem_patch_feature_dinov2_coreset = torch.load(dinov2_file)
|
| 1029 |
+
self.mem_instance_features_multi_stage = torch.load(instance_file)
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
self.few_shot_inited = True
|
| 1033 |
+
|
| 1034 |
+
|
model_ensemble_few_shot.py
ADDED
|
@@ -0,0 +1,935 @@
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Set cache directories to use checkpoint folder for model downloads
|
| 4 |
+
os.environ['TORCH_HOME'] = './checkpoint'
|
| 5 |
+
os.environ['HF_HOME'] = './checkpoint/huggingface'
|
| 6 |
+
os.environ['TRANSFORMERS_CACHE'] = './checkpoint/huggingface/transformers'
|
| 7 |
+
os.environ['HF_HUB_CACHE'] = './checkpoint/huggingface/hub'
|
| 8 |
+
|
| 9 |
+
# Create checkpoint subdirectories if they don't exist
|
| 10 |
+
os.makedirs('./checkpoint/huggingface/transformers', exist_ok=True)
|
| 11 |
+
os.makedirs('./checkpoint/huggingface/hub', exist_ok=True)
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torchvision.transforms import v2
|
| 16 |
+
from torchvision.transforms.v2.functional import resize
|
| 17 |
+
import cv2
|
| 18 |
+
import json
|
| 19 |
+
import torch
|
| 20 |
+
import random
|
| 21 |
+
import logging
|
| 22 |
+
import argparse
|
| 23 |
+
import numpy as np
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from skimage import measure
|
| 26 |
+
from tabulate import tabulate
|
| 27 |
+
from torchvision.ops.focal_loss import sigmoid_focal_loss
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import torchvision.transforms as transforms
|
| 30 |
+
import torchvision.transforms.functional as TF
|
| 31 |
+
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
|
| 32 |
+
from sklearn.mixture import GaussianMixture
|
| 33 |
+
import faiss
|
| 34 |
+
import open_clip_local as open_clip
|
| 35 |
+
|
| 36 |
+
from torch.utils.data.dataset import ConcatDataset
|
| 37 |
+
from scipy.optimize import linear_sum_assignment
|
| 38 |
+
from sklearn.random_projection import SparseRandomProjection
|
| 39 |
+
import cv2
|
| 40 |
+
from torchvision.transforms import InterpolationMode
|
| 41 |
+
from PIL import Image
|
| 42 |
+
import string
|
| 43 |
+
|
| 44 |
+
from prompt_ensemble import encode_text_with_prompt_ensemble, encode_normal_text, encode_abnormal_text, encode_general_text, encode_obj_text
|
| 45 |
+
from kmeans_pytorch import kmeans, kmeans_predict
|
| 46 |
+
from scipy.optimize import linear_sum_assignment
|
| 47 |
+
from scipy.stats import norm
|
| 48 |
+
|
| 49 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
| 50 |
+
from matplotlib import pyplot as plt
|
| 51 |
+
|
| 52 |
+
import matplotlib
|
| 53 |
+
matplotlib.use('Agg')
|
| 54 |
+
|
| 55 |
+
import pickle
|
| 56 |
+
from scipy.stats import norm
|
| 57 |
+
|
| 58 |
+
from open_clip_local.pos_embed import get_2d_sincos_pos_embed
|
| 59 |
+
|
| 60 |
+
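# Undo the CLIP normalization and convert a CHW tensor back to a uint8 RGB image.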
def to_np_img(m):
|
| 61 |
+
m = m.permute(1, 2, 0).cpu().numpy()
|
| 62 |
+
mean = np.array([[[0.48145466, 0.4578275, 0.40821073]]])
|
| 63 |
+
std = np.array([[[0.26862954, 0.26130258, 0.27577711]]])
|
| 64 |
+
m = m * std + mean
|
| 65 |
+
return np.clip((m * 255.), 0, 255).astype(np.uint8)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def setup_seed(seed):
|
| 69 |
+
torch.manual_seed(seed)
|
| 70 |
+
torch.cuda.manual_seed_all(seed)
|
| 71 |
+
np.random.seed(seed)
|
| 72 |
+
random.seed(seed)
|
| 73 |
+
torch.backends.cudnn.deterministic = True
|
| 74 |
+
torch.backends.cudnn.benchmark = False
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class MyModel(nn.Module):
|
| 78 |
+
"""Example model class for track 2.
|
| 79 |
+
|
| 80 |
+
This class performs few-shot logical anomaly detection with an ensemble of CLIP, DINOv2 and SAM features.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(self) -> None:
|
| 84 |
+
super().__init__()
|
| 85 |
+
|
| 86 |
+
setup_seed(42)
|
| 87 |
+
# NOTE: Create your transformation pipeline (if needed).
|
| 88 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 89 |
+
self.transform = v2.Compose(
|
| 90 |
+
[
|
| 91 |
+
v2.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
|
| 92 |
+
],
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# NOTE: Create your model.
|
| 96 |
+
|
| 97 |
+
self.model_clip, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
| 98 |
+
self.tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
|
| 99 |
+
self.feature_list = [6, 12, 18, 24]
|
| 100 |
+
self.embed_dim = 768
|
| 101 |
+
self.vision_width = 1024
|
| 102 |
+
|
| 103 |
+
self.model_sam = sam_model_registry["vit_h"](checkpoint = "./checkpoint/sam_vit_h_4b8939.pth").to(self.device)
|
| 104 |
+
self.mask_generator = SamAutomaticMaskGenerator(model = self.model_sam)
|
| 105 |
+
|
| 106 |
+
self.memory_size = 2048
|
| 107 |
+
self.n_neighbors = 2
|
| 108 |
+
|
| 109 |
+
self.model_clip.eval()
|
| 110 |
+
self.test_args = None
|
| 111 |
+
self.align_corners = True # False
|
| 112 |
+
self.antialias = True # False
|
| 113 |
+
self.inter_mode = 'bilinear' # bilinear/bicubic
|
| 114 |
+
|
| 115 |
+
self.cluster_feature_id = [0, 1]
|
| 116 |
+
|
| 117 |
+
self.cluster_num_dict = {
|
| 118 |
+
"breakfast_box": 3, # unused
|
| 119 |
+
"juice_bottle": 8, # unused
|
| 120 |
+
"splicing_connectors": 10, # unused
|
| 121 |
+
"pushpins": 10,
|
| 122 |
+
"screw_bag": 10,
|
| 123 |
+
}
|
| 124 |
+
self.query_words_dict = {
|
| 125 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
| 126 |
+
"juice_bottle": ['bottle', ['black background', 'background']],
|
| 127 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
| 128 |
+
"screw_bag": [['screw'], 'plastic bag', 'background'],
|
| 129 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
| 130 |
+
}
|
| 131 |
+
self.foreground_label_idx = { # for query_words_dict
|
| 132 |
+
"breakfast_box": [0, 1, 2, 3, 4, 5],
|
| 133 |
+
"juice_bottle": [0],
|
| 134 |
+
"pushpins": [0],
|
| 135 |
+
"screw_bag": [0],
|
| 136 |
+
"splicing_connectors":[0, 1]
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
self.patch_query_words_dict = {
|
| 140 |
+
"breakfast_box": ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box', 'black background'],
|
| 141 |
+
"juice_bottle": [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
|
| 142 |
+
"pushpins": [['pushpin', 'pin'], ['plastic box', 'black background']],
|
| 143 |
+
"screw_bag": [['hex screw', 'hexagon bolt'], ['hex nut', 'hexagon nut'], ['ring washer', 'ring gasket'], ['plastic bag', 'background']],
|
| 144 |
+
"splicing_connectors": [['splicing connector', 'splice connector',], ['cable', 'wire'], ['grid']],
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
self.query_threshold_dict = {
|
| 149 |
+
"breakfast_box": [0., 0., 0., 0., 0., 0., 0.], # unused
|
| 150 |
+
"juice_bottle": [0., 0., 0.], # unused
|
| 151 |
+
"splicing_connectors": [0.15, 0.15, 0.15, 0., 0.], # unused
|
| 152 |
+
"pushpins": [0.2, 0., 0., 0.],
|
| 153 |
+
"screw_bag": [0., 0., 0.,],
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
self.feat_size = 64
|
| 157 |
+
self.ori_feat_size = 32
|
| 158 |
+
|
| 159 |
+
self.visualization = False
|
| 160 |
+
|
| 161 |
+
self.pushpins_count = 15
|
| 162 |
+
|
| 163 |
+
self.splicing_connectors_count = [2, 3, 5] # corresponding to yellow, blue, and red
|
| 164 |
+
self.splicing_connectors_distance = 0
|
| 165 |
+
self.splicing_connectors_cable_color_query_words_dict = [['yellow cable', 'yellow wire'], ['blue cable', 'blue wire'], ['red cable', 'red wire']]
|
| 166 |
+
|
| 167 |
+
self.juice_bottle_liquid_query_words_dict = [['red liquid', 'cherry juice'], ['yellow liquid', 'orange juice'], ['milky liquid']]
|
| 168 |
+
self.juice_bottle_fruit_query_words_dict = ['cherry', ['tangerine', 'orange'], 'banana']
|
| 169 |
+
|
| 170 |
+
# query words
|
| 171 |
+
self.foreground_pixel_hist = 0
|
| 172 |
+
# patch query words
|
| 173 |
+
self.patch_token_hist = []
|
| 174 |
+
|
| 175 |
+
self.few_shot_inited = False
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
from dinov2.dinov2.hub.backbones import dinov2_vitl14
|
| 179 |
+
self.model_dinov2 = dinov2_vitl14()
|
| 180 |
+
self.model_dinov2.to(self.device)
|
| 181 |
+
self.model_dinov2.eval()
|
| 182 |
+
self.feature_list_dinov2 = [6, 12, 18, 24]
|
| 183 |
+
self.vision_width_dinov2 = 1024
|
| 184 |
+
|
| 185 |
+
self.stats = pickle.load(open("memory_bank/statistic_scores_model_ensemble_few_shot_val.pkl", "rb"))
|
| 186 |
+
|
| 187 |
+
self.mem_instance_masks = None
|
| 188 |
+
|
| 189 |
+
self.anomaly_flag = False
|
| 190 |
+
self.validation = False #True #False
|
| 191 |
+
|
| 192 |
+
def set_viz(self, viz):
|
| 193 |
+
self.visualization = viz
|
| 194 |
+
|
| 195 |
+
def set_val(self, val):
|
| 196 |
+
self.validation = val
|
| 197 |
+
|
| 198 |
+
def forward(self, batch: torch.Tensor, batch_path: list) -> dict[str, torch.Tensor]:
|
| 199 |
+
"""Transform the input batch and pass it through the model.
|
| 200 |
+
|
| 201 |
+
This model returns a dictionary with the following keys
|
| 202 |
+
- ``anomaly_map`` - Anomaly map.
|
| 203 |
+
- ``pred_score`` - Predicted anomaly score.
|
| 204 |
+
"""
|
| 205 |
+
self.anomaly_flag = False
|
| 206 |
+
batch = self.transform(batch).to(self.device)
|
| 207 |
+
results = self.forward_one_sample(batch, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset, batch_path[0])
|
| 208 |
+
|
| 209 |
+
hist_score = results['hist_score']
|
| 210 |
+
structural_score = results['structural_score']
|
| 211 |
+
instance_hungarian_match_score = results['instance_hungarian_match_score']
|
| 212 |
+
|
| 213 |
+
anomaly_map_structural = results['anomaly_map_structural']
|
| 214 |
+
|
| 215 |
+
if self.validation:
|
| 216 |
+
return {"hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
|
| 217 |
+
|
| 218 |
+
def sigmoid(z):
|
| 219 |
+
return 1/(1 + np.exp(-z))
|
| 220 |
+
|
| 221 |
+
# standardization
|
| 222 |
+
standard_structural_score = (structural_score - self.stats[self.class_name]["structural_scores"]["mean"]) / self.stats[self.class_name]["structural_scores"]["unbiased_std"]
|
| 223 |
+
standard_instance_hungarian_match_score = (instance_hungarian_match_score - self.stats[self.class_name]["instance_hungarian_match_scores"]["mean"]) / self.stats[self.class_name]["instance_hungarian_match_scores"]["unbiased_std"]
|
| 224 |
+
|
| 225 |
+
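# Fuse scores: take the larger of the two standardized scores and squash it to (0, 1) with a sigmoid.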
pred_score = max(standard_instance_hungarian_match_score, standard_structural_score)
|
| 226 |
+
pred_score = sigmoid(pred_score)
|
| 227 |
+
|
| 228 |
+
if self.anomaly_flag:
|
| 229 |
+
pred_score = 1.
|
| 230 |
+
self.anomaly_flag = False
|
| 231 |
+
|
| 232 |
+
return {"pred_score": torch.tensor(pred_score), "anomaly_map": torch.tensor(anomaly_map_structural), "hist_score": torch.tensor(hist_score), "structural_score": torch.tensor(structural_score), "instance_hungarian_match_score": torch.tensor(instance_hungarian_match_score)}
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
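# Score a single image: logical/histogram checks plus PatchCore-style nearest-neighbor matching against the coreset memory and Hungarian matching of instance features.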
def forward_one_sample(self, batch: torch.Tensor, mem_patch_feature_clip_coreset: torch.Tensor, mem_patch_feature_dinov2_coreset: torch.Tensor, path: str):
|
| 236 |
+
|
| 237 |
+
with torch.no_grad():
|
| 238 |
+
image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(batch, self.feature_list)
|
| 239 |
+
# image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 240 |
+
patch_tokens = [p[:, 1:, :] for p in patch_tokens]
|
| 241 |
+
patch_tokens = [p.reshape(p.shape[0]*p.shape[1], p.shape[2]) for p in patch_tokens]
|
| 242 |
+
|
| 243 |
+
patch_tokens_clip = torch.cat(patch_tokens, dim=-1) # (1, 1024, 1024x4)
|
| 244 |
+
# patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1) # (1, 1024, 1024x2)
|
| 245 |
+
patch_tokens_clip = patch_tokens_clip.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 246 |
+
patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 247 |
+
patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
|
| 248 |
+
patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1) # (1x64x64, 1024x4)
|
| 249 |
+
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
patch_tokens_dinov2 = self.model_dinov2.forward_features(batch, out_layer_list=self.feature_list)
|
| 252 |
+
patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1) # (1, 1024, 1024x4)
|
| 253 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 254 |
+
patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 255 |
+
patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
|
| 256 |
+
patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1) # (1x64x64, 1024x4)
|
| 257 |
+
|
| 258 |
+
'''resize projected patch tokens for k-means segmentation'''
|
| 259 |
+
if self.feat_size != self.ori_feat_size:
|
| 260 |
+
proj_patch_tokens = proj_patch_tokens.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 261 |
+
proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 262 |
+
proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(self.feat_size * self.feat_size, self.embed_dim)
|
| 263 |
+
proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)
|
| 264 |
+
|
| 265 |
+
mid_features = None
|
| 266 |
+
for layer in self.cluster_feature_id:
|
| 267 |
+
temp_feat = patch_tokens[layer]
|
| 268 |
+
mid_features = temp_feat if mid_features is None else torch.cat((mid_features, temp_feat), -1)
|
| 269 |
+
|
| 270 |
+
if self.feat_size != self.ori_feat_size:
|
| 271 |
+
mid_features = mid_features.view(1, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
|
| 272 |
+
mid_features = F.interpolate(mid_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
|
| 273 |
+
mid_features = mid_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
|
| 274 |
+
mid_features = F.normalize(mid_features, p=2, dim=-1)
|
| 275 |
+
|
| 276 |
+
results = self.histogram(batch, mid_features, proj_patch_tokens, self.class_name, os.path.dirname(path).split('/')[-1] + "_" + os.path.basename(path).split('.')[0])
|
| 277 |
+
|
| 278 |
+
hist_score = results['score']
|
| 279 |
+
|
| 280 |
+
'''calculate patchcore'''
|
| 281 |
+
anomaly_maps_patchcore = []
|
| 282 |
+
|
| 283 |
+
if self.class_name in ['pushpins', 'screw_bag']: # clip feature for patchcore
|
| 284 |
+
len_feature_list = len(self.feature_list)
|
| 285 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
|
| 286 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
| 287 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
| 288 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
| 289 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
| 290 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
| 291 |
+
|
| 292 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
| 293 |
+
|
| 294 |
+
if self.class_name in ['splicing_connectors', 'breakfast_box', 'juice_bottle']: # dinov2 feature for patchcore
|
| 295 |
+
len_feature_list = len(self.feature_list_dinov2)
|
| 296 |
+
for patch_feature, mem_patch_feature in zip(patch_tokens_dinov2.chunk(len_feature_list, dim=-1), mem_patch_feature_dinov2_coreset.chunk(len_feature_list, dim=-1)):
|
| 297 |
+
patch_feature = F.normalize(patch_feature, dim=-1)
|
| 298 |
+
mem_patch_feature = F.normalize(mem_patch_feature, dim=-1)
|
| 299 |
+
normal_map_patchcore = (patch_feature @ mem_patch_feature.T)
|
| 300 |
+
normal_map_patchcore = (normal_map_patchcore.max(1)[0]).cpu().numpy() # 1: normal 0: abnormal
|
| 301 |
+
anomaly_map_patchcore = 1 - normal_map_patchcore
|
| 302 |
+
|
| 303 |
+
anomaly_maps_patchcore.append(anomaly_map_patchcore)
|
| 304 |
+
|
| 305 |
+
structural_score = np.stack(anomaly_maps_patchcore).mean(0).max()
|
| 306 |
+
anomaly_map_structural = np.stack(anomaly_maps_patchcore).mean(0).reshape(self.feat_size, self.feat_size)
|
| 307 |
+
|
| 308 |
+
instance_masks = results["instance_masks"]
|
| 309 |
+
anomaly_instances_hungarian = []
|
| 310 |
+
instance_hungarian_match_score = 1.
|
| 311 |
+
if self.mem_instance_masks is not None and len(instance_masks) != 0:
|
| 312 |
+
for patch_feature, batch_mem_patch_feature in zip(patch_tokens_clip.chunk(len_feature_list, dim=-1), mem_patch_feature_clip_coreset.chunk(len_feature_list, dim=-1)):
|
| 313 |
+
instance_features = [patch_feature[mask, :].mean(0, keepdim=True) for mask in instance_masks]
|
| 314 |
+
instance_features = torch.cat(instance_features, dim=0)
|
| 315 |
+
instance_features = F.normalize(instance_features, dim=-1)
|
| 316 |
+
mem_instance_features = []
|
| 317 |
+
for mem_patch_feature, mem_instance_masks in zip(batch_mem_patch_feature.chunk(self.k_shot), self.mem_instance_masks):
|
| 318 |
+
mem_instance_features.extend([mem_patch_feature[mask, :].mean(0, keepdim=True) for mask in mem_instance_masks])
|
| 319 |
+
mem_instance_features = torch.cat(mem_instance_features, dim=0)
|
| 320 |
+
mem_instance_features = F.normalize(mem_instance_features, dim=-1)
|
| 321 |
+
|
| 322 |
+
normal_instance_hungarian = (instance_features @ mem_instance_features.T)
|
| 323 |
+
cost_matrix = (1 - normal_instance_hungarian).cpu().numpy()
|
| 324 |
+
|
| 325 |
+
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
| 326 |
+
cost = cost_matrix[row_ind, col_ind].sum()
|
| 327 |
+
cost = cost / min(cost_matrix.shape)
|
| 328 |
+
anomaly_instances_hungarian.append(cost)
|
| 329 |
+
|
| 330 |
+
instance_hungarian_match_score = np.mean(anomaly_instances_hungarian)
|
| 331 |
+
|
| 332 |
+
results = {'hist_score': hist_score, 'structural_score': structural_score, 'instance_hungarian_match_score': instance_hungarian_match_score, "anomaly_map_structural": anomaly_map_structural}
|
| 333 |
+
|
| 334 |
+
return results
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
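# Category-specific segmentation and logical checks: k-means labels and CLIP text matching are merged with SAM masks, then histograms, counts and geometric rules are compared against the few-shot references.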
def histogram(self, image, cluster_feature, proj_patch_token, class_name, path):
|
| 338 |
+
def plot_results_only(sorted_anns):
|
| 339 |
+
cur = 1
|
| 340 |
+
img_color = np.zeros((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1]))
|
| 341 |
+
for ann in sorted_anns:
|
| 342 |
+
m = ann['segmentation']
|
| 343 |
+
img_color[m] = cur
|
| 344 |
+
cur += 1
|
| 345 |
+
return img_color
|
| 346 |
+
|
| 347 |
+
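# Relabel each segment of a (the SAM mask) with the majority label it overlaps in b (the reference segmentation).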
def merge_segmentations(a, b, background_class):
|
| 348 |
+
unique_labels_a = np.unique(a)
|
| 349 |
+
unique_labels_b = np.unique(b)
|
| 350 |
+
|
| 351 |
+
max_label_a = unique_labels_a.max()
|
| 352 |
+
label_map = np.zeros(max_label_a + 1, dtype=int)
|
| 353 |
+
|
| 354 |
+
for label_a in unique_labels_a:
|
| 355 |
+
mask_a = (a == label_a)
|
| 356 |
+
|
| 357 |
+
labels_b = b[mask_a]
|
| 358 |
+
if labels_b.size > 0:
|
| 359 |
+
count_b = np.bincount(labels_b, minlength=unique_labels_b.max() + 1)
|
| 360 |
+
label_map[label_a] = np.argmax(count_b)
|
| 361 |
+
else:
|
| 362 |
+
label_map[label_a] = background_class # default background
|
| 363 |
+
|
| 364 |
+
merged_a = label_map[a]
|
| 365 |
+
return merged_a
|
| 366 |
+
|
| 367 |
+
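# Assign each patch to its nearest few-shot k-means center, then map sufficiently confident regions to text-defined object classes.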
pseudo_labels = kmeans_predict(cluster_feature, self.cluster_centers, 'euclidean', device=self.device)
|
| 368 |
+
kmeans_mask = torch.ones_like(pseudo_labels) * (self.classes - 1) # default to background
|
| 369 |
+
|
| 370 |
+
for pl in pseudo_labels.unique():
|
| 371 |
+
mask = (pseudo_labels == pl).reshape(-1)
|
| 372 |
+
# filter small region
|
| 373 |
+
binary = mask.cpu().numpy().reshape(self.feat_size, self.feat_size).astype(np.uint8)
|
| 374 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 375 |
+
for i in range(1, num_labels):
|
| 376 |
+
temp_mask = labels == i
|
| 377 |
+
if np.sum(temp_mask) <= 8:
|
| 378 |
+
mask[temp_mask.reshape(-1)] = False
|
| 379 |
+
|
| 380 |
+
if mask.any():
|
| 381 |
+
region_feature = proj_patch_token[mask, :].mean(0, keepdim=True)
|
| 382 |
+
similarity = (region_feature @ self.query_obj.T)
|
| 383 |
+
prob, index = torch.max(similarity, dim=-1)
|
| 384 |
+
temp_label = index.squeeze(0).item()
|
| 385 |
+
temp_prob = prob.squeeze(0).item()
|
| 386 |
+
if temp_prob > self.query_threshold_dict[class_name][temp_label]: # threshold for each class
|
| 387 |
+
kmeans_mask[mask] = temp_label
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
raw_image = to_np_img(image[0])
|
| 391 |
+
height, width = raw_image.shape[:2]
|
| 392 |
+
masks = self.mask_generator.generate(raw_image)
|
| 393 |
+
# self.predictor.set_image(raw_image)
|
| 394 |
+
|
| 395 |
+
kmeans_label = pseudo_labels.view(self.feat_size, self.feat_size).cpu().numpy()
|
| 396 |
+
kmeans_mask = kmeans_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
| 397 |
+
|
| 398 |
+
patch_similarity = (proj_patch_token @ self.patch_query_obj.T)
|
| 399 |
+
patch_mask = patch_similarity.argmax(-1)
|
| 400 |
+
patch_mask = patch_mask.view(self.feat_size, self.feat_size).cpu().numpy()
|
| 401 |
+
|
| 402 |
+
sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)
|
| 403 |
+
sam_mask = plot_results_only(sorted_masks).astype(int)
|
| 404 |
+
|
| 405 |
+
resized_mask = cv2.resize(kmeans_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
| 406 |
+
merge_sam = merge_segmentations(sam_mask, resized_mask, background_class=self.classes-1)
|
| 407 |
+
|
| 408 |
+
resized_patch_mask = cv2.resize(patch_mask, (width, height), interpolation = cv2.INTER_NEAREST)
|
| 409 |
+
patch_merge_sam = merge_segmentations(sam_mask, resized_patch_mask, background_class=self.patch_query_obj.shape[0]-1)
|
| 410 |
+
|
| 411 |
+
# filter small region for merge sam
|
| 412 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
| 413 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 414 |
+
for i in range(1, num_labels):
|
| 415 |
+
temp_mask = labels == i
|
| 416 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
| 417 |
+
merge_sam[temp_mask] = self.classes - 1 # set to background
|
| 418 |
+
|
| 419 |
+
# filter small region for patch merge sam
|
| 420 |
+
binary = (patch_merge_sam != (self.patch_query_obj.shape[0]-1) ).astype(np.uint8) # foreground 1 background 0
|
| 421 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 422 |
+
for i in range(1, num_labels):
|
| 423 |
+
temp_mask = labels == i
|
| 424 |
+
if np.sum(temp_mask) <= 32: # 448x448
|
| 425 |
+
patch_merge_sam[temp_mask] = self.patch_query_obj.shape[0]-1 # set to background
|
| 426 |
+
|
| 427 |
+
score = 0. # default to normal
|
| 428 |
+
self.anomaly_flag = False
|
| 429 |
+
instance_masks = []
|
| 430 |
+
if self.class_name == 'pushpins':
|
| 431 |
+
# object count hist
|
| 432 |
+
kernel = np.ones((3, 3), dtype=np.uint8) # dilate for robustness
|
| 433 |
+
binary = np.isin(merge_sam, self.foreground_label_idx[self.class_name]).astype(np.uint8) # foreground 1 background 0
|
| 434 |
+
dilate_binary = cv2.dilate(binary, kernel)
|
| 435 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(dilate_binary, connectivity=8)
|
| 436 |
+
pushpins_count = num_labels - 1 # number of pushpins
|
| 437 |
+
|
| 438 |
+
for i in range(1, num_labels):
|
| 439 |
+
instance_mask = (labels == i).astype(np.uint8)
|
| 440 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 441 |
+
if instance_mask.any():
|
| 442 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
| 443 |
+
|
| 444 |
+
if self.few_shot_inited and pushpins_count != self.pushpins_count and self.anomaly_flag is False:
|
| 445 |
+
self.anomaly_flag = True
|
| 446 |
+
print('number of pushpins: {}, but canonical number of pushpins: {}'.format(pushpins_count, self.pushpins_count))
|
| 447 |
+
|
| 448 |
+
# patch hist
|
| 449 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])
|
| 450 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
| 451 |
+
|
| 452 |
+
if self.few_shot_inited:
|
| 453 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
| 454 |
+
score = 1 - patch_hist_similarity.max()
|
| 455 |
+
|
| 456 |
+
binary_foreground = dilate_binary.astype(np.uint8)
|
| 457 |
+
|
| 458 |
+
if len(instance_masks) != 0:
|
| 459 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 460 |
+
|
| 461 |
+
if self.visualization:
|
| 462 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
| 463 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
| 464 |
+
plt.figure(figsize=(20, 3))
|
| 465 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 466 |
+
plt.subplot(1, len(image_list), ind)
|
| 467 |
+
plt.imshow(temp_image)
|
| 468 |
+
plt.title(temp_title)
|
| 469 |
+
plt.margins(0, 0)
|
| 470 |
+
plt.axis('off')
|
| 471 |
+
# Extract relative path from class_name onwards
|
| 472 |
+
if class_name in path:
|
| 473 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 474 |
+
if relative_path.startswith('/'):
|
| 475 |
+
relative_path = relative_path[1:]
|
| 476 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
| 477 |
+
else:
|
| 478 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
| 479 |
+
|
| 480 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 481 |
+
plt.tight_layout()
|
| 482 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 483 |
+
plt.close()
|
| 484 |
+
|
| 485 |
+
# todo: same number in total but in different boxes or broken box
|
| 486 |
+
return {"score": score, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
| 487 |
+
|
| 488 |
+
elif self.class_name == 'splicing_connectors':
|
| 489 |
+
# object count hist for default
|
| 490 |
+
sam_mask_max_area = sorted_masks[0]['segmentation'] # background
|
| 491 |
+
binary = (sam_mask_max_area == 0).astype(np.uint8) # sam_mask_max_area is background, background 0 foreground 1
|
| 492 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)
|
| 493 |
+
count = 0
|
| 494 |
+
for i in range(1, num_labels):
|
| 495 |
+
temp_mask = labels == i
|
| 496 |
+
if np.sum(temp_mask) <= 64: # 448x448 64
|
| 497 |
+
binary[temp_mask] = 0 # set to background
|
| 498 |
+
else:
|
| 499 |
+
count += 1
|
| 500 |
+
if count != 1 and self.anomaly_flag is False: # cable cut or no cable or no connector
|
| 501 |
+
print('number of connected components in splicing_connectors: {}, but the default is 1.'.format(count))
|
| 502 |
+
self.anomaly_flag = True
|
| 503 |
+
|
| 504 |
+
merge_sam[~(binary.astype(bool))] = self.query_obj.shape[0] - 1 # remove noise
|
| 505 |
+
patch_merge_sam[~(binary.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
| 506 |
+
|
| 507 |
+
# erode the cable and divide into left and right parts
|
| 508 |
+
kernel = np.ones((23, 23), dtype=np.uint8)
|
| 509 |
+
erode_binary = cv2.erode(binary, kernel)
|
| 510 |
+
h, w = erode_binary.shape
|
| 511 |
+
distance = 0
|
| 512 |
+
|
| 513 |
+
left, right = erode_binary[:, :int(w/2)], erode_binary[:, int(w/2):]
|
| 514 |
+
left_count = np.bincount(left.reshape(-1), minlength=self.classes)[1] # foreground
|
| 515 |
+
right_count = np.bincount(right.reshape(-1), minlength=self.classes)[1] # foreground
|
| 516 |
+
|
| 517 |
+
binary_cable = (patch_merge_sam == 1).astype(np.uint8)
|
| 518 |
+
|
| 519 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
| 520 |
+
binary_cable = cv2.erode(binary_cable, kernel)
|
| 521 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_cable, connectivity=8)
|
| 522 |
+
for i in range(1, num_labels):
|
| 523 |
+
temp_mask = labels == i
|
| 524 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
| 525 |
+
binary_cable[temp_mask] = 0 # set to background
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
binary_cable = cv2.resize(binary_cable, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 529 |
+
|
| 530 |
+
binary_clamps = (patch_merge_sam == 0).astype(np.uint8)
|
| 531 |
+
|
| 532 |
+
kernel = np.ones((5, 5), dtype=np.uint8)
|
| 533 |
+
binary_clamps = cv2.erode(binary_clamps, kernel)
|
| 534 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_clamps, connectivity=8)
|
| 535 |
+
for i in range(1, num_labels):
|
| 536 |
+
temp_mask = labels == i
|
| 537 |
+
if np.sum(temp_mask) <= 64: # 448x448
|
| 538 |
+
binary_clamps[temp_mask] = 0 # set to background
|
| 539 |
+
else:
|
| 540 |
+
instance_mask = temp_mask.astype(np.uint8)
|
| 541 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 542 |
+
if instance_mask.any():
|
| 543 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
| 544 |
+
|
| 545 |
+
binary_clamps = cv2.resize(binary_clamps, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 546 |
+
|
| 547 |
+
binary_connector = cv2.resize(binary, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 548 |
+
|
| 549 |
+
query_cable_color = encode_obj_text(self.model_clip, self.splicing_connectors_cable_color_query_words_dict, self.tokenizer, self.device)
|
| 550 |
+
cable_feature = proj_patch_token[binary_cable.astype(bool).reshape(-1), :].mean(0, keepdim=True)
|
| 551 |
+
idx_color = (cable_feature @ query_cable_color.T).argmax(-1).squeeze(0).item()
|
| 552 |
+
foreground_pixel_count = np.sum(erode_binary) / self.splicing_connectors_count[idx_color]
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
slice_cable = binary[:, int(w/2)-1: int(w/2)+1]
|
| 556 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(slice_cable, connectivity=8)
|
| 557 |
+
cable_count = num_labels - 1
|
| 558 |
+
if cable_count != 1 and self.anomaly_flag is False: # two cables
|
| 559 |
+
print('cable count in splicing_connectors: {}, but the default cable count is 1.'.format(cable_count))
|
| 560 |
+
self.anomaly_flag = True
|
| 561 |
+
|
| 562 |
+
# {2-clamp: yellow 3-clamp: blue 5-clamp: red} cable color and clamp number mismatch
|
| 563 |
+
if self.few_shot_inited and self.foreground_pixel_hist != 0 and self.anomaly_flag is False:
|
| 564 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist
|
| 565 |
+
if (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # color and number mismatch
|
| 566 |
+
print('cable color and number of clamps mismatch, cable color idx: {} (0: yellow 2-clamp, 1: blue 3-clamp, 2: red 5-clamp), foreground_pixel_count :{}, canonical foreground_pixel_hist: {}.'.format(idx_color, foreground_pixel_count, self.foreground_pixel_hist))
|
| 567 |
+
self.anomaly_flag = True
|
| 568 |
+
|
| 569 |
+
# left right hist for symmetry
|
| 570 |
+
ratio = np.sum(left_count) / (np.sum(right_count) + 1e-5)
|
| 571 |
+
if self.few_shot_inited and (ratio > 1.2 or ratio < 0.8) and self.anomaly_flag is False: # left right asymmetry in clamp
|
| 572 |
+
print('left and right connectors are not symmetric.')
|
| 573 |
+
self.anomaly_flag = True
|
| 574 |
+
|
| 575 |
+
# left and right centroids distance
|
| 576 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(erode_binary, connectivity=8)
|
| 577 |
+
if num_labels - 1 == 2:
|
| 578 |
+
centroids = centroids[1:]
|
| 579 |
+
x1, y1 = centroids[0]
|
| 580 |
+
x2, y2 = centroids[1]
|
| 581 |
+
distance = np.sqrt((x1/w - x2/w)**2 + (y1/h - y2/h)**2)
|
| 582 |
+
if self.few_shot_inited and self.splicing_connectors_distance != 0 and self.anomaly_flag is False:
|
| 583 |
+
ratio = distance / self.splicing_connectors_distance
|
| 584 |
+
if ratio < 0.6 or ratio > 1.4: # too short or too long centroids distance (cable) # 0.6 1.4
|
| 585 |
+
print('cable is too short or too long.')
|
| 586 |
+
self.anomaly_flag = True
|
| 587 |
+
|
| 588 |
+
# patch hist
|
| 589 |
+
sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])#[:-1] # ignore background (grid) for statistic
|
| 590 |
+
sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)
|
| 591 |
+
|
| 592 |
+
if self.few_shot_inited:
|
| 593 |
+
patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
|
| 594 |
+
score = 1 - patch_hist_similarity.max()
|
| 595 |
+
|
| 596 |
+
# todo mismatch cable link
|
| 597 |
+
binary_foreground = binary.astype(np.uint8) # only 1 instance, so additionally separate cable and clamps
|
| 598 |
+
if binary_connector.any():
|
| 599 |
+
instance_masks.append(binary_connector.astype(bool).reshape(-1))
|
| 600 |
+
if binary_clamps.any():
|
| 601 |
+
instance_masks.append(binary_clamps.astype(bool).reshape(-1))
|
| 602 |
+
if binary_cable.any():
|
| 603 |
+
instance_masks.append(binary_cable.astype(bool).reshape(-1))
|
| 604 |
+
|
| 605 |
+
if len(instance_masks) != 0:
|
| 606 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 607 |
+
|
| 608 |
+
if self.visualization:
|
| 609 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, binary_connector, merge_sam, patch_merge_sam, erode_binary, binary_cable, binary_clamps]
|
| 610 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'binary_connector', 'merge sam', 'patch merge sam', 'erode binary', 'binary_cable', 'binary_clamps']
|
| 611 |
+
plt.figure(figsize=(25, 3))
|
| 612 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 613 |
+
plt.subplot(1, len(image_list), ind)
|
| 614 |
+
plt.imshow(temp_image)
|
| 615 |
+
plt.title(temp_title)
|
| 616 |
+
plt.margins(0, 0)
|
| 617 |
+
plt.axis('off')
|
| 618 |
+
# Extract relative path from class_name onwards
|
| 619 |
+
if class_name in path:
|
| 620 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 621 |
+
if relative_path.startswith('/'):
|
| 622 |
+
relative_path = relative_path[1:]
|
| 623 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
| 624 |
+
else:
|
| 625 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
| 626 |
+
|
| 627 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 628 |
+
plt.tight_layout()
|
| 629 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 630 |
+
plt.close()
|
| 631 |
+
|
| 632 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "distance": distance, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
|
| 633 |
+
|
| 634 |
+
elif self.class_name == 'screw_bag':
|
| 635 |
+
# pixel hist of kmeans mask
|
| 636 |
+
foreground_pixel_count = np.sum(np.bincount(kmeans_mask.reshape(-1))[:len(self.foreground_label_idx[self.class_name])]) # foreground pixel
|
| 637 |
+
if self.few_shot_inited and self.foreground_pixel_hist != 0 and self.anomaly_flag is False:
|
| 638 |
+
ratio = foreground_pixel_count / self.foreground_pixel_hist
|
| 639 |
+
# todo: optimize
|
| 640 |
+
if ratio < 0.94 or ratio > 1.06:
|
| 641 |
+
print('foreground pixel histogram of screw bag: {}, the canonical foreground pixel histogram of screw bag in few shot: {}'.format(foreground_pixel_count, self.foreground_pixel_hist))
|
| 642 |
+
self.anomaly_flag = True
|
| 643 |
+
|
| 644 |
+
# patch hist
|
| 645 |
+
binary_screw = np.isin(kmeans_mask, self.foreground_label_idx[self.class_name])
|
| 646 |
+
patch_mask[~binary_screw] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
| 647 |
+
resized_binary_screw = cv2.resize(binary_screw.astype(np.uint8), (patch_merge_sam.shape[1], patch_merge_sam.shape[0]), interpolation = cv2.INTER_NEAREST)
|
| 648 |
+
patch_merge_sam[~(resized_binary_screw.astype(bool))] = self.patch_query_obj.shape[0] - 1 # remove patch noise
|
| 649 |
+
|
| 650 |
+
clip_patch_hist = np.bincount(patch_mask.reshape(-1), minlength=self.patch_query_obj.shape[0])[:-1]
|
| 651 |
+
clip_patch_hist = clip_patch_hist / np.linalg.norm(clip_patch_hist)
|
| 652 |
+
|
| 653 |
+
if self.few_shot_inited:
|
| 654 |
+
patch_hist_similarity = (clip_patch_hist @ self.patch_token_hist.T)
|
| 655 |
+
score = 1 - patch_hist_similarity.max()
|
| 656 |
+
|
| 657 |
+
# # todo: count of screw, nut and washer, screw of different length
|
| 658 |
+
binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
|
| 659 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
|
| 660 |
+
for i in range(1, num_labels):
|
| 661 |
+
instance_mask = (labels == i).astype(np.uint8)
|
| 662 |
+
instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation = cv2.INTER_NEAREST)
|
| 663 |
+
if instance_mask.any():
|
| 664 |
+
instance_masks.append(instance_mask.astype(bool).reshape(-1))
|
| 665 |
+
|
| 666 |
+
if len(instance_masks) != 0:
|
| 667 |
+
instance_masks = np.stack(instance_masks) #[N, 64x64]
|
| 668 |
+
|
| 669 |
+
if self.visualization:
|
| 670 |
+
image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
|
| 671 |
+
title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
|
| 672 |
+
plt.figure(figsize=(20, 3))
|
| 673 |
+
for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
|
| 674 |
+
plt.subplot(1, len(image_list), ind)
|
| 675 |
+
plt.imshow(temp_image)
|
| 676 |
+
plt.title(temp_title)
|
| 677 |
+
plt.margins(0, 0)
|
| 678 |
+
plt.axis('off')
|
| 679 |
+
# Extract relative path from class_name onwards
|
| 680 |
+
if class_name in path:
|
| 681 |
+
relative_path = path.split(class_name, 1)[-1]
|
| 682 |
+
if relative_path.startswith('/'):
|
| 683 |
+
relative_path = relative_path[1:]
|
| 684 |
+
save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
|
| 685 |
+
else:
|
| 686 |
+
save_path = f'visualization/few_shot/{class_name}/{path}.png'
|
| 687 |
+
|
| 688 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 689 |
+
plt.tight_layout()
|
| 690 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=150)
|
| 691 |
+
plt.close()
|
| 692 |
+
|
| 693 |
+
return {"score": score, "foreground_pixel_count": foreground_pixel_count, "clip_patch_hist": clip_patch_hist, "instance_masks": instance_masks}
|
        elif self.class_name == 'breakfast_box':
            # patch hist
            sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
            sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)

            if self.few_shot_inited:
                patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
                score = 1 - patch_hist_similarity.max()

            # todo: existence of foreground
            binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)

            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
            for i in range(1, num_labels):
                instance_mask = (labels == i).astype(np.uint8)
                instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation=cv2.INTER_NEAREST)
                if instance_mask.any():
                    instance_masks.append(instance_mask.astype(bool).reshape(-1))

            if len(instance_masks) != 0:
                instance_masks = np.stack(instance_masks)  # [N, 64x64]

            if self.visualization:
                image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
                title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
                plt.figure(figsize=(20, 3))
                for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
                    plt.subplot(1, len(image_list), ind)
                    plt.imshow(temp_image)
                    plt.title(temp_title)
                    plt.margins(0, 0)
                    plt.axis('off')
                # Extract relative path from class_name onwards
                if class_name in path:
                    relative_path = path.split(class_name, 1)[-1]
                    if relative_path.startswith('/'):
                        relative_path = relative_path[1:]
                    save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
                else:
                    save_path = f'visualization/few_shot/{class_name}/{path}.png'

                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                plt.tight_layout()
                plt.savefig(save_path, bbox_inches='tight', dpi=150)
                plt.close()

            return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}
        elif self.class_name == 'juice_bottle':
            # remove noise due to non sam mask
            merge_sam[sam_mask == 0] = self.classes - 1
            patch_merge_sam[sam_mask == 0] = self.patch_query_obj.shape[0] - 1  # 79.5

            # [['glass'], ['liquid in bottle'], ['fruit'], ['label', 'tag'], ['black background', 'background']],
            # fruit and liquid mismatch (todo if exist)
            resized_patch_merge_sam = cv2.resize(patch_merge_sam, (self.feat_size, self.feat_size), interpolation=cv2.INTER_NEAREST)
            binary_liquid = (resized_patch_merge_sam == 1)
            binary_fruit = (resized_patch_merge_sam == 2)

            query_liquid = encode_obj_text(self.model_clip, self.juice_bottle_liquid_query_words_dict, self.tokenizer, self.device)
            query_fruit = encode_obj_text(self.model_clip, self.juice_bottle_fruit_query_words_dict, self.tokenizer, self.device)

            liquid_feature = proj_patch_token[binary_liquid.reshape(-1), :].mean(0, keepdim=True)
            liquid_idx = (liquid_feature @ query_liquid.T).argmax(-1).squeeze(0).item()

            fruit_feature = proj_patch_token[binary_fruit.reshape(-1), :].mean(0, keepdim=True)
            fruit_idx = (fruit_feature @ query_fruit.T).argmax(-1).squeeze(0).item()

            if (liquid_idx != fruit_idx) and self.anomaly_flag is False:
                print('liquid: {}, but fruit: {}.'.format(self.juice_bottle_liquid_query_words_dict[liquid_idx], self.juice_bottle_fruit_query_words_dict[fruit_idx]))
                self.anomaly_flag = True

            # todo: centroid of fruit and tag_0 mismatch (if exist), only one tag, center

            # patch hist
            sam_patch_hist = np.bincount(patch_merge_sam.reshape(-1), minlength=self.patch_query_obj.shape[0])
            sam_patch_hist = sam_patch_hist / np.linalg.norm(sam_patch_hist)

            if self.few_shot_inited:
                patch_hist_similarity = (sam_patch_hist @ self.patch_token_hist.T)
                score = 1 - patch_hist_similarity.max()

            binary_foreground = (patch_merge_sam != (self.patch_query_obj.shape[0] - 1)).astype(np.uint8)
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_foreground, connectivity=8)
            for i in range(1, num_labels):
                instance_mask = (labels == i).astype(np.uint8)
                instance_mask = cv2.resize(instance_mask, (self.feat_size, self.feat_size), interpolation=cv2.INTER_NEAREST)
                if instance_mask.any():
                    instance_masks.append(instance_mask.astype(bool).reshape(-1))

            if len(instance_masks) != 0:
                instance_masks = np.stack(instance_masks)  # [N, 64x64]

            if self.visualization:
                image_list = [raw_image, kmeans_label, kmeans_mask, patch_mask, sam_mask, merge_sam, patch_merge_sam, binary_foreground]
                title_list = ['raw image', 'k-means', 'kmeans mask', 'patch mask', 'sam mask', 'merge sam mask', 'patch merge sam', 'binary_foreground']
                plt.figure(figsize=(20, 3))
                for ind, (temp_title, temp_image) in enumerate(zip(title_list, image_list), start=1):
                    plt.subplot(1, len(image_list), ind)
                    plt.imshow(temp_image)
                    plt.title(temp_title)
                    plt.margins(0, 0)
                    plt.axis('off')
                # Extract relative path from class_name onwards
                if class_name in path:
                    relative_path = path.split(class_name, 1)[-1]
                    if relative_path.startswith('/'):
                        relative_path = relative_path[1:]
                    save_path = f'visualization/few_shot/{class_name}/{relative_path}.png'
                else:
                    save_path = f'visualization/few_shot/{class_name}/{path}.png'

                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                plt.tight_layout()
                plt.savefig(save_path, bbox_inches='tight', dpi=150)
                plt.close()

            return {"score": score, "sam_patch_hist": sam_patch_hist, "instance_masks": instance_masks}

        return {"score": score, "instance_masks": instance_masks}
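Every branch of the histogram routine above reduces to the same two checks: an L2-normalised histogram of patch labels is compared against the k few-shot reference histograms by cosine similarity (the anomaly score is one minus the best match), and for classes such as screw_bag a foreground pixel count is gated against the canonical few-shot count. The sketch below restates those two checks in plain NumPy; it is illustrative only, and the function and argument names are not part of the committed file.

```python
import numpy as np

def patch_hist_score(patch_labels, ref_hists, num_classes):
    """Anomaly score from a patch-label map, mirroring the clip_patch_hist branch:
    drop the background bin, L2-normalise, compare against the few-shot histograms."""
    hist = np.bincount(patch_labels.reshape(-1), minlength=num_classes)[:-1].astype(np.float64)
    hist /= (np.linalg.norm(hist) + 1e-12)      # L2-normalise (guard against an empty mask)
    similarity = hist @ ref_hists.T             # cosine similarity to each reference histogram
    return float(1.0 - similarity.max())        # low similarity -> high anomaly score

def foreground_ratio_anomalous(fg_count, canonical_count, low=0.94, high=1.06):
    """Pixel-count gate used for screw_bag: flag when the ratio leaves [0.94, 1.06]."""
    ratio = fg_count / canonical_count
    return ratio < low or ratio > high
```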
    def process_k_shot(self, class_name, few_shot_samples, few_shot_paths):
        few_shot_samples = F.interpolate(few_shot_samples, size=(448, 448), mode=self.inter_mode, align_corners=self.align_corners, antialias=self.antialias)

        with torch.no_grad():
            image_features, patch_tokens, proj_patch_tokens = self.model_clip.encode_image(few_shot_samples, self.feature_list)
            patch_tokens = [p[:, 1:, :] for p in patch_tokens]
            patch_tokens = [p.reshape(p.shape[0] * p.shape[1], p.shape[2]) for p in patch_tokens]

            patch_tokens_clip = torch.cat(patch_tokens, dim=-1)  # (bs, 1024, 1024x4)
            # patch_tokens_clip = torch.cat(patch_tokens[2:], dim=-1)  # (bs, 1024, 1024x2)
            patch_tokens_clip = patch_tokens_clip.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
            patch_tokens_clip = F.interpolate(patch_tokens_clip, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
            patch_tokens_clip = patch_tokens_clip.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.feature_list))
            patch_tokens_clip = F.normalize(patch_tokens_clip, p=2, dim=-1)  # (bsx64x64, 1024x4)

        with torch.no_grad():
            patch_tokens_dinov2 = self.model_dinov2.forward_features(few_shot_samples, out_layer_list=self.feature_list_dinov2)  # 4 x [bs, 32x32, 1024]
            patch_tokens_dinov2 = torch.cat(patch_tokens_dinov2, dim=-1)  # (bs, 1024, 1024x4)
            patch_tokens_dinov2 = patch_tokens_dinov2.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
            patch_tokens_dinov2 = F.interpolate(patch_tokens_dinov2, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
            patch_tokens_dinov2 = patch_tokens_dinov2.permute(0, 2, 3, 1).view(-1, self.vision_width_dinov2 * len(self.feature_list_dinov2))
            patch_tokens_dinov2 = F.normalize(patch_tokens_dinov2, p=2, dim=-1)  # (bsx64x64, 1024x4)

        cluster_features = None
        for layer in self.cluster_feature_id:
            temp_feat = patch_tokens[layer]
            cluster_features = temp_feat if cluster_features is None else torch.cat((cluster_features, temp_feat), 1)
        if self.feat_size != self.ori_feat_size:
            cluster_features = cluster_features.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
            cluster_features = F.interpolate(cluster_features, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
            cluster_features = cluster_features.permute(0, 2, 3, 1).view(-1, self.vision_width * len(self.cluster_feature_id))
        cluster_features = F.normalize(cluster_features, p=2, dim=-1)

        if self.feat_size != self.ori_feat_size:
            proj_patch_tokens = proj_patch_tokens.view(self.k_shot, self.ori_feat_size, self.ori_feat_size, -1).permute(0, 3, 1, 2)
            proj_patch_tokens = F.interpolate(proj_patch_tokens, size=(self.feat_size, self.feat_size), mode=self.inter_mode, align_corners=self.align_corners)
            proj_patch_tokens = proj_patch_tokens.permute(0, 2, 3, 1).view(-1, self.embed_dim)
        proj_patch_tokens = F.normalize(proj_patch_tokens, p=2, dim=-1)

        num_clusters = self.cluster_num_dict[class_name]
        _, self.cluster_centers = kmeans(X=cluster_features, num_clusters=num_clusters, device=self.device)

        self.query_obj = encode_obj_text(self.model_clip, self.query_words_dict[class_name], self.tokenizer, self.device)
        self.patch_query_obj = encode_obj_text(self.model_clip, self.patch_query_words_dict[class_name], self.tokenizer, self.device)
        self.classes = self.query_obj.shape[0]

        scores = []
        foreground_pixel_hist = []
        splicing_connectors_distance = []
        patch_token_hist = []
        mem_instance_masks = []

        for image, cluster_feature, proj_patch_token, few_shot_path in zip(few_shot_samples.chunk(self.k_shot), cluster_features.chunk(self.k_shot), proj_patch_tokens.chunk(self.k_shot), few_shot_paths):
            # path = os.path.dirname(few_shot_path).split('/')[-1] + "_" + os.path.basename(few_shot_path).split('.')[0]
            self.anomaly_flag = False
            results = self.histogram(image, cluster_feature, proj_patch_token, class_name, "few_shot_" + os.path.basename(few_shot_path).split('.')[0])
            if self.class_name == 'pushpins':
                patch_token_hist.append(results["clip_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'splicing_connectors':
                foreground_pixel_hist.append(results["foreground_pixel_count"])
                splicing_connectors_distance.append(results["distance"])
                patch_token_hist.append(results["sam_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'screw_bag':
                foreground_pixel_hist.append(results["foreground_pixel_count"])
                patch_token_hist.append(results["clip_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'breakfast_box':
                patch_token_hist.append(results["sam_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            elif self.class_name == 'juice_bottle':
                patch_token_hist.append(results["sam_patch_hist"])
                mem_instance_masks.append(results['instance_masks'])

            scores.append(results["score"])

        if len(foreground_pixel_hist) != 0:
            self.foreground_pixel_hist = np.mean(foreground_pixel_hist)
        if len(splicing_connectors_distance) != 0:
            self.splicing_connectors_distance = np.mean(splicing_connectors_distance)
        if len(patch_token_hist) != 0:  # patch hist
            self.patch_token_hist = np.stack(patch_token_hist)
        if len(mem_instance_masks) != 0:
            self.mem_instance_masks = mem_instance_masks

        mem_patch_feature_clip_coreset = patch_tokens_clip
        mem_patch_feature_dinov2_coreset = patch_tokens_dinov2

        return scores, mem_patch_feature_clip_coreset, mem_patch_feature_dinov2_coreset

    def process(self, class_name: str, few_shot_samples: list[torch.Tensor], few_shot_paths: list[str]):
        few_shot_samples = self.transform(few_shot_samples).to(self.device)
        scores, self.mem_patch_feature_clip_coreset, self.mem_patch_feature_dinov2_coreset = self.process_k_shot(class_name, few_shot_samples, few_shot_paths)

    def setup(self, data: dict) -> None:
        """Setup the few-shot samples for the model.

        The evaluation script will call this method to pass the k images for few shot learning and the object class
        name. In the case of MVTec LOCO this will be the dataset category name (e.g. breakfast_box). Please contact
        the organizing committee if your model requires any additional dataset-related information at setup-time.
        """
        few_shot_samples = data.get("few_shot_samples")
        class_name = data.get("dataset_category")
        few_shot_paths = data.get("few_shot_samples_path")
        self.class_name = class_name

        self.k_shot = few_shot_samples.size(0)
        self.process(class_name, few_shot_samples, few_shot_paths)
        self.few_shot_inited = True
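process_k_shot above fits a per-category k-means over the few-shot patch features via kmeans-pytorch (pinned in requirements.txt) and keeps the cluster centers, together with the CLIP and DINOv2 patch features, as the few-shot memory. A minimal, self-contained sketch of that clustering step is shown below; the tensor sizes are made up for illustration and are not taken from the commit.

```python
import torch
import torch.nn.functional as F
from kmeans_pytorch import kmeans

# Stand-alone sketch of the clustering step in process_k_shot: k-means over
# L2-normalised patch features; cluster_ids label every patch, cluster_centers
# are what the model keeps as self.cluster_centers.
features = F.normalize(torch.rand(4 * 64 * 64, 1024), p=2, dim=-1)  # 4 shots x 64x64 patches (dummy sizes)
cluster_ids, cluster_centers = kmeans(X=features, num_clusters=5, device=torch.device("cpu"))
print(cluster_ids.shape, cluster_centers.shape)  # (16384,), (5, 1024)
```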
prompt_ensemble.py ADDED
@@ -0,0 +1,121 @@
import os
from typing import Union, List
import torch
import numpy as np
from tqdm import tqdm
from imagenet_template import openai_imagenet_template


def encode_text_with_prompt_ensemble(model, objs, tokenizer, device):
    prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
    prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
    prompt_state = [prompt_normal, prompt_abnormal]
    prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
    text_prompts = {}
    for obj in objs:
        text_features = []
        for i in range(len(prompt_state)):
            prompted_state = [state.format(obj) for state in prompt_state[i]]
            prompted_sentence = []
            for s in prompted_state:
                for template in prompt_templates:
                    prompted_sentence.append(template.format(s))
            prompted_sentence = tokenizer(prompted_sentence).to(device)
            class_embeddings = model.encode_text(prompted_sentence)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            text_features.append(class_embedding)

        text_features = torch.stack(text_features, dim=1).to(device)
        text_prompts[obj] = text_features

    return text_prompts


def encode_general_text(model, obj_list, tokenizer, device):
    text_dir = '/data/yizhou/VAND2.0/wgd/general_texts/train2014'
    text_name_list = sorted(os.listdir(text_dir))
    bs = 100
    sentences = []
    embeddings = []
    all_sentences = []
    for text_name in tqdm(text_name_list):
        with open(os.path.join(text_dir, text_name), 'r') as f:
            for line in f.readlines():
                sentences.append(line.strip())
        if len(sentences) > bs:
            prompted_sentences = tokenizer(sentences).to(device)
            class_embeddings = model.encode_text(prompted_sentences)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            embeddings.append(class_embeddings)
            all_sentences.extend(sentences)
            sentences = []
        # if len(all_sentences) > 10000:
        #     break
    embeddings = torch.cat(embeddings, 0)
    print(embeddings.size(0))
    embeddings_dict = {}
    for obj in obj_list:
        embeddings_dict[obj] = embeddings
    return embeddings_dict, all_sentences


def encode_abnormal_text(model, obj_list, tokenizer, device):
    embeddings = {}
    sentences = {}
    for obj in obj_list:
        sentence_abnormal = []
        with open(os.path.join('text_prompt', 'v1', obj + '_abnormal.txt'), 'r') as f:
            for line in f.readlines():
                sentence_abnormal.append(line.strip().lower())

        prompted_sentences = tokenizer(sentence_abnormal).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings[obj] = class_embeddings
        sentences[obj] = sentence_abnormal
    return embeddings, sentences


def encode_normal_text(model, obj_list, tokenizer, device):
    embeddings = {}
    sentences = {}
    for obj in obj_list:
        sentence_normal = []
        with open(os.path.join('text_prompt', 'v1', obj + '_normal.txt'), 'r') as f:
            for line in f.readlines():
                sentence_normal.append(line.strip().lower())

        prompted_sentences = tokenizer(sentence_normal).to(device)
        class_embeddings = model.encode_text(prompted_sentences)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        embeddings[obj] = class_embeddings
        sentences[obj] = sentence_normal
    return embeddings, sentences


def encode_obj_text(model, query_words, tokenizer, device):
    # query_words = ['orange', "nectarine", "cereals", "banana chips", 'almonds', 'white box']
    # query_words = ['liquid', 'glass', "top", 'black background']
    # query_words = ["connector", "grid"]
    # query_words = [['screw'], 'plastic bag', 'background']
    # query_words = [['pushpin', 'pin'], ['plastic box'], 'box', 'black background']
    query_features = []
    with torch.no_grad():
        for qw in query_words:
            token_input = []
            if isinstance(qw, list):
                for qw2 in qw:
                    token_input.extend([temp(qw2) for temp in openai_imagenet_template])
            else:
                token_input = [temp(qw) for temp in openai_imagenet_template]
            query = tokenizer(token_input).to(device)
            feature = model.encode_text(query)
            feature /= feature.norm(dim=-1, keepdim=True)
            feature = feature.mean(dim=0)
            feature /= feature.norm()
            query_features.append(feature.unsqueeze(0))
    query_features = torch.cat(query_features, dim=0)
    return query_features
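All of the encoders in prompt_ensemble.py expect an open_clip-style model and tokenizer (open-clip-torch is pinned in requirements.txt below). A usage sketch for encode_obj_text follows; the backbone name and pretrained tag are assumptions for illustration only, and the query words are taken from the commented examples above.

```python
import torch
import open_clip
from prompt_ensemble import encode_obj_text

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed backbone/tag for illustration; any open_clip model exposing encode_text works.
model, _, _ = open_clip.create_model_and_transforms("ViT-L-14-336", pretrained="openai")
tokenizer = open_clip.get_tokenizer("ViT-L-14-336")
model = model.to(device).eval()

# Nested lists are averaged into a single query embedding, mirroring the commented examples.
query_words = [['pushpin', 'pin'], ['plastic box'], 'box', 'black background']
query_features = encode_obj_text(model, query_words, tokenizer, device)
print(query_features.shape)  # (len(query_words), embed_dim); each row is L2-normalised
```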
requirements.txt ADDED
@@ -0,0 +1,77 @@
aiohappyeyeballs==2.6.1
aiohttp==3.12.11
aiosignal==1.3.2
antlr4-python3-runtime==4.9.3
async-timeout==5.0.1
attrs==25.3.0
certifi==2025.4.26
charset-normalizer==3.4.2
contourpy==1.3.2
cycler==0.12.1
einops==0.6.1
faiss-cpu==1.8.0
filelock==3.18.0
fonttools==4.58.2
FrEIA==0.2
frozenlist==1.6.2
fsspec==2024.12.0
ftfy==6.3.1
hf-xet==1.1.3
huggingface-hub==0.32.4
idna==3.10
imageio==2.37.0
imgaug==0.4.0
Jinja2==3.1.6
joblib==1.5.1
jsonargparse==4.29.0
kiwisolver==1.4.8
kmeans-pytorch==0.3
kornia==0.7.0
lazy_loader==0.4
lightning==2.2.5
lightning-utilities==0.14.3
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.4.4
networkx==3.4.2
omegaconf==2.3.0
open-clip-torch==2.24.0
opencv-python==4.8.1.78
packaging==24.2
pandas==2.0.3
pillow==11.2.1
propcache==0.3.1
protobuf==6.31.1
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.1.post0
pytz==2025.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==13.7.1
safetensors==0.5.3
scikit-image==0.25.2
scikit-learn==1.2.2
scipy==1.15.3
segment-anything==1.0
sentencepiece==0.2.0
shapely==2.1.1
six==1.17.0
sympy==1.14.0
tabulate==0.9.0
threadpoolctl==3.6.0
tifffile==2025.5.10
timm==1.0.15
torchmetrics==1.7.2
tqdm==4.67.1
triton==2.1.0
typing_extensions==4.14.0
tzdata==2025.2
urllib3==2.4.0
wcwidth==0.2.13
yarl==1.20.0