Upload VITON-HD/eval/single_object_evaluation.py with huggingface_hub
VITON-HD/eval/single_object_evaluation.py
ADDED
@@ -0,0 +1,172 @@
+from evaluation.single_object.data import get_combinations
+from evaluation.clip_eval import CLIPEvaluator
+from torchvision.transforms import ToTensor
+from accelerate import Accelerator
+from typing import List
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+import argparse
+import torch
+import glob
+import os
+
+
+def read_reference_images(folder_path: str) -> List[Image.Image]:
+    # Load every file in the folder as an RGB PIL image.
+    images = []
+    for filename in os.listdir(folder_path):
+        image_path = os.path.join(folder_path, filename)
+        image = Image.open(image_path).convert("RGB")
+        images.append(image)
+    return images
+
+
+def compute_similarity_matrix(
+    evaluator, cropped_images: List[Image.Image], reference_images: List[Image.Image]
+) -> np.ndarray:
+    # Pairwise embedding similarity between every crop and every reference.
+    similarity_matrix = np.zeros((len(cropped_images), len(reference_images)))
+    for i, cropped_image in enumerate(cropped_images):
+        for j, reference_image in enumerate(reference_images):
+            embed1 = evaluator(cropped_image)
+            embed2 = evaluator(reference_image)
+            similarity_matrix[i, j] = embed1 @ embed2.T
+
+    print(similarity_matrix)
+    return similarity_matrix
+
+
+def greedy_matching(scores):
+    # Greedily pair rows with columns by repeatedly taking the highest
+    # remaining score; returns the weakest similarity in the matching.
+    n, m = scores.shape
+    assert n == m
+    res = []
+    for _ in range(m):
+        pos = np.argmax(scores)
+        i, j = pos // m, pos % m
+
+        res.append(scores[i, j])
+        scores[i, :] = -1
+        scores[:, j] = -1
+
+    return min(res)
+
+
+def save_image(tensor, path):
+    # Map from [-1, 1] back to [0, 255] and save as an 8-bit RGB image.
+    tensor = (tensor[0] * 0.5 + 0.5).clamp(min=0, max=1).permute(1, 2, 0) * 255.0
+    tensor = tensor.cpu().numpy().astype(np.uint8)
+
+    Image.fromarray(tensor).save(path)
+
+
+def compute_average_similarity(
+    idx, face_detector, face_similarity, generated_image, reference_image
+) -> float:
+    # Identity similarity between the first detected face in the generated
+    # image and the first face in the reference; 0.0 if no face is found.
+    generated_face = face_detector(generated_image)
+
+    if generated_face is None:
+        return 0.0
+    generated_face = generated_face[:1]
+
+    reference_face = face_detector(reference_image)[:1]
+    assert len(reference_face) == 1, "no reference face detected in reference image"
+
+    generated_face = generated_face.to(face_detector.device).reshape(1, 3, 160, 160)
+    reference_face = reference_face.to(face_detector.device).reshape(1, 3, 160, 160)
+
+    similarity = face_similarity(generated_face) @ face_similarity(reference_face).T
+    return max(similarity.item(), 0.0)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_images_per_prompt", type=int, default=4)
+    parser.add_argument("--prediction_folder", type=str)
+    parser.add_argument("--reference_folder", type=str)
+
+    args = parser.parse_args()
+    return args
+
+
+def load_reference_image(reference_folder, image_id):
+    # Use the first .jpg (sorted by name) in the subject's folder.
+    path = os.path.join(reference_folder, image_id)
+    image_path = sorted(glob.glob(os.path.join(path, "*.jpg")))[0]
+    image = Image.open(image_path).convert("RGB")
+    return image
+
+
+@torch.no_grad()
+def main():
+    args = parse_args()
+
+    accelerator = Accelerator()
+
+    from facenet_pytorch import MTCNN, InceptionResnetV1
+
+    face_detector = MTCNN(
+        image_size=160,
+        margin=0,
+        min_face_size=20,
+        thresholds=[0.6, 0.7, 0.7],
+        factor=0.709,
+        post_process=True,
+        device=accelerator.device,
+        keep_all=True,
+    )
+    face_similarity = (
+        InceptionResnetV1(pretrained="vggface2").eval().to(accelerator.device)
+    )
+
+    text_evaluator = CLIPEvaluator(device=accelerator.device, clip_model="ViT-L/14")
+
+    # get subject
+    prompt_subject_pairs = get_combinations("", is_fastcomposer=True, split="eval")
+    image_alignments, text_alignments = [], []
+
+    for case_id, (prompt_list, subject) in enumerate(tqdm(prompt_subject_pairs)):
+        # TODO: Load reference images using image_ids from subjects
+        ref_image = load_reference_image(args.reference_folder, subject)
+
+        for prompt_id, prompt in enumerate(prompt_list):
+            for instance_id in range(args.num_images_per_prompt):
+                generated_image_path = os.path.join(
+                    args.prediction_folder,
+                    f"subject_{case_id:04d}_prompt_{prompt_id:04d}_instance_{instance_id:04d}.jpg",
+                )
+                generated_image = Image.open(generated_image_path).convert("RGB")
+
+                identity_similarity = compute_average_similarity(
+                    case_id, face_detector, face_similarity, generated_image, ref_image
+                )
+
+                # Rescale from [0, 1] to [-1, 1] before passing to the CLIP evaluator.
+                generated_image_tensor = (
+                    ToTensor()(generated_image).unsqueeze(0) * 2.0 - 1.0
+                )
+                prompt_similarity = text_evaluator.txt_to_img_similarity(
+                    prompt, generated_image_tensor
+                )
+
+                image_alignments.append(float(identity_similarity))
+                text_alignments.append(float(prompt_similarity))
+
+    image_alignment = sum(image_alignments) / len(image_alignments)
+    text_alignment = sum(text_alignments) / len(text_alignments)
+    image_std = np.std(image_alignments)
+    text_std = np.std(text_alignments)
+
+    print(f"Image Alignment: {image_alignment} +- {image_std}")
+    print(f"Text Alignment: {text_alignment} +- {text_std}")
+    with open(os.path.join(args.prediction_folder, "score.txt"), "w") as f:
+        f.write(f"Image Alignment: {image_alignment} Text Alignment: {text_alignment}\n")
+        f.write(f"Image Alignment Std: {image_std} Text Alignment Std: {text_std}\n")
+
+
+if __name__ == "__main__":
+    main()
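For context: the script expects generated images named subject_XXXX_prompt_XXXX_instance_XXXX.jpg inside --prediction_folder and one subfolder of .jpg reference photos per subject inside --reference_folder; it writes the aggregate scores to score.txt in the prediction folder. Note that read_reference_images, compute_similarity_matrix, greedy_matching, and save_image are defined but not called from main(). A plausible invocation, assuming the repository root is on PYTHONPATH (the folder paths below are illustrative, not taken from the repo):

python VITON-HD/eval/single_object_evaluation.py \
    --prediction_folder outputs/predictions \
    --reference_folder data/references \
    --num_images_per_prompt 4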