zjuJish committed on
Commit ac43f48 · verified · 1 Parent(s): 4f4476c

Upload VITON-HD/eval/single_object_evaluation.py with huggingface_hub

VITON-HD/eval/single_object_evaluation.py ADDED
@@ -0,0 +1,163 @@
+ from evaluation.single_object.data import get_combinations
+ from evaluation.clip_eval import CLIPEvaluator
+ from torchvision.transforms import ToTensor
+ from accelerate import Accelerator
+ from typing import List
+ from PIL import Image
+ from tqdm import tqdm
+ import numpy as np
+ import argparse
+ import torch
+ import glob
+ import os
+
+
+ def read_reference_images(folder_path: str) -> List[Image.Image]:
+     images = []
+     for filename in os.listdir(folder_path):
+         image_path = os.path.join(folder_path, filename)
+         image = Image.open(image_path).convert("RGB")
+         images.append(image)
+     return images
+
+
+ def compute_similarity_matrix(
+     evaluator, cropped_images: List[np.ndarray], reference_images: List[np.ndarray]
+ ) -> np.ndarray:
+     similarity_matrix = np.zeros((len(cropped_images), len(reference_images)))
+     for i, cropped_image in enumerate(cropped_images):
+         for j, reference_image in enumerate(reference_images):
+             embed1 = evaluator(cropped_image)
+             embed2 = evaluator(reference_image)
+             similarity_matrix[i, j] = embed1 @ embed2.T
+
+     print(similarity_matrix)
+     return similarity_matrix
+
+
+ def greedy_matching(scores):
+     # Greedily match rows to columns: repeatedly take the highest remaining
+     # score, mask out its row and column, and return the worst matched score.
+     n, m = scores.shape
+     assert n == m
+     res = []
+     for _ in range(m):
+         pos = np.argmax(scores)
+         i, j = pos // m, pos % m
+
+         res.append(scores[i, j])
+         scores[i, :] = -1
+         scores[:, j] = -1
+
+     return min(res)
+
+
+ def save_image(tensor, path):
+     # Map the first image in the batch from [-1, 1] back to [0, 255] and save it.
+     tensor = (tensor[0] * 0.5 + 0.5).clamp(min=0, max=1).permute(1, 2, 0) * 255.0
+     tensor = tensor.cpu().numpy().astype(np.uint8)
+
+     Image.fromarray(tensor).save(path)
+
+
+ def compute_average_similarity(
+     idx, face_detector, face_similarity, generated_image, reference_image
+ ) -> float:
+     generated_face = face_detector(generated_image)
+
+     if generated_face is None:
+         return 0.0
+     generated_face = generated_face[:1]
+
+     reference_face = face_detector(reference_image)[:1]
+     assert len(reference_face) == 1, "no reference face detected in reference image"
+
+     generated_face = generated_face.to(face_detector.device).reshape(1, 3, 160, 160)
+     reference_face = reference_face.to(face_detector.device).reshape(1, 3, 160, 160)
+
+     similarity = face_similarity(generated_face) @ face_similarity(reference_face).T
+     return max(similarity.item(), 0.0)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--num_images_per_prompt", type=int, default=4)
+     parser.add_argument("--prediction_folder", type=str)
+     parser.add_argument("--reference_folder", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def load_reference_image(reference_folder, image_id):
+     path = os.path.join(reference_folder, image_id)
+     image_path = sorted(glob.glob(os.path.join(path, "*.jpg")))[0]
+     image = Image.open(image_path).convert("RGB")
+     return image
+
+
+ @torch.no_grad()
+ def main():
+     args = parse_args()
+
+     accelerator = Accelerator()
+
+     from facenet_pytorch import MTCNN, InceptionResnetV1
+
+     face_detector = MTCNN(
+         image_size=160,
+         margin=0,
+         min_face_size=20,
+         thresholds=[0.6, 0.7, 0.7],
+         factor=0.709,
+         post_process=True,
+         device=accelerator.device,
+         keep_all=True,
+     )
+     face_similarity = (
+         InceptionResnetV1(pretrained="vggface2").eval().to(accelerator.device)
+     )
+
+     text_evaluator = CLIPEvaluator(device=accelerator.device, clip_model="ViT-L/14")
+
+     # get (prompt list, subject) pairs for the evaluation split
+     prompt_subject_pairs = get_combinations("", is_fastcomposer=True, split="eval")
+     image_alignments, text_alignments = [], []
+
+     for case_id, (prompt_list, subject) in enumerate(tqdm(prompt_subject_pairs)):
+         # TODO: Load reference images using image_ids from subjects
+         ref_image = load_reference_image(args.reference_folder, subject)
+
+         for prompt_id, prompt in enumerate(prompt_list):
+             for instance_id in range(args.num_images_per_prompt):
+                 generated_image_path = os.path.join(
+                     args.prediction_folder,
+                     f"subject_{case_id:04d}_prompt_{prompt_id:04d}_instance_{instance_id:04d}.jpg",
+                 )
+                 generated_image = Image.open(generated_image_path).convert("RGB")
+
+                 identity_similarity = compute_average_similarity(
+                     case_id, face_detector, face_similarity, generated_image, ref_image
+                 )
+
+                 generated_image_tensor = (
+                     ToTensor()(generated_image).unsqueeze(0) * 2.0 - 1.0
+                 )
+                 prompt_similarity = text_evaluator.txt_to_img_similarity(
+                     prompt, generated_image_tensor
+                 )
+
+                 image_alignments.append(float(identity_similarity))
+                 text_alignments.append(float(prompt_similarity))
+
+     image_alignment = sum(image_alignments) / len(image_alignments)
+     text_alignment = sum(text_alignments) / len(text_alignments)
+     image_std = np.std(image_alignments)
+     text_std = np.std(text_alignments)
+
+     print(f"Image Alignment: {image_alignment} +- {image_std}")
+     print(f"Text Alignment: {text_alignment} +- {text_std}")
+     with open(os.path.join(args.prediction_folder, "score.txt"), "w") as f:
+         f.write(f"Image Alignment: {image_alignment} Text Alignment: {text_alignment}\n")
+         f.write(f"Image Alignment Std: {image_std} Text Alignment Std: {text_std}\n")
+
+
+ if __name__ == "__main__":
+     main()
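
A minimal invocation sketch, assuming the repository layout above (the folder paths below are placeholders, not part of this commit). The script expects generated images named subject_{case:04d}_prompt_{prompt:04d}_instance_{instance:04d}.jpg inside --prediction_folder, one sub-folder of reference .jpg files per subject inside --reference_folder, and it writes a score.txt summary into the prediction folder.

import subprocess

# Hypothetical example run; replace the placeholder paths with real folders.
subprocess.run(
    [
        "python", "VITON-HD/eval/single_object_evaluation.py",
        "--prediction_folder", "results/generated",  # placeholder path
        "--reference_folder", "data/references",     # placeholder path
        "--num_images_per_prompt", "4",
    ],
    check=True,
)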