Upload VITON-HD/eval/tryon-fid.py with huggingface_hub
Browse files- VITON-HD/eval/tryon-fid.py +268 -0
VITON-HD/eval/tryon-fid.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from typing import List, Tuple, Dict
|
5 |
+
|
6 |
+
import PIL.Image
|
7 |
+
import torch
|
8 |
+
from cleanfid import fid
|
9 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
10 |
+
from torchmetrics import StructuralSimilarityIndexMeasure
|
11 |
+
from torchmetrics.image.inception import InceptionScore
|
12 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
13 |
+
from torchvision import transforms
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from .generate_fid_stats import make_custom_stats
|
17 |
+
|
18 |
+
|
19 |
+
class GTTestDataset(torch.utils.data.Dataset):
    def __init__(self, dataroot: str, dataset: str, category: str, transform: transforms.Compose):
        """
        Index the ground-truth test images of a try-on dataset.

        For ``dresscode`` the image names are the first whitespace-separated
        token of every line of ``<dataroot>/test_pairs_paired.txt``; each name is
        looked up under ``<dataroot>/<category>/images/`` (all three garment
        categories when ``category == 'all'``) and only existing files are kept.
        For ``vitonhd`` the names come from ``<dataroot>/test_pairs.txt`` and are
        resolved under ``<dataroot>/test/image/`` without an existence check.

        The resulting paths are stored sorted in ``self.paths``.
        """

        # Validate inputs
        assert dataset in ['dresscode', 'vitonhd'], 'Unsupported dataset'
        assert category in ['all', 'dresses', 'lower_body', 'upper_body'], 'Unsupported category'

        self.dataset = dataset
        self.category = category
        self.transform = transform
        self.dataroot = dataroot

        garment_categories = ['lower_body', 'upper_body', 'dresses']

        if dataset == 'dresscode':
            with open(os.path.join(dataroot, "test_pairs_paired.txt"), 'r') as pair_file:
                image_names = [row.strip().split()[0] for row in pair_file.read().splitlines()]

            # A single category is searched on its own, anything else means
            # "search every garment category".
            if category in garment_categories:
                searched = [category]
            else:
                searched = garment_categories

            candidates = [os.path.join(dataroot, folder, 'images', image_name)
                          for image_name in image_names
                          for folder in searched]
            found = [candidate for candidate in candidates if os.path.exists(candidate)]
            found.sort()
            self.paths = found
        else:  # vitonhd
            with open(os.path.join(dataroot, "test_pairs.txt"), 'r') as pair_file:
                image_names = [row.strip().split()[0] for row in pair_file.read().splitlines()]

            resolved = [os.path.join(dataroot, 'test', 'image', image_name) for image_name in image_names]
            resolved.sort()
            self.paths = resolved

    def __len__(self):
        """Number of indexed ground-truth images."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, file name without extension)`` for ``idx``."""
        image_path = self.paths[idx]
        stem, _extension = os.path.splitext(os.path.basename(image_path))
        rgb_image = PIL.Image.open(image_path).convert('RGB')
        return self.transform(rgb_image), stem
|
63 |
+
|
64 |
+
|
65 |
+
class GenTestDataset(torch.utils.data.Dataset):
    def __init__(self, gen_folder: str, category: str, transform: transforms.Compose):
        """
        Dataset for the generated (try-on result) images.

        Images are read from ``<gen_folder>/<category>/``. When ``category`` is
        ``'all'`` every garment sub-folder (``lower_body``, ``upper_body``,
        ``dresses``) that exists inside ``gen_folder`` is used. Only regular
        files are indexed, so nested directories are ignored instead of
        failing later when the image is opened.

        The resulting paths are stored sorted in ``self.paths``.

        Raises:
            ValueError: if ``category`` is not supported (also reached when
                assertions are disabled).
        """

        # Validate inputs
        assert category in ['all', 'dresses', 'lower_body', 'upper_body'], 'Unsupported category'

        self.category = category
        self.transform = transform
        self.gen_folder = gen_folder

        garment_categories = ['lower_body', 'upper_body', 'dresses']

        # Decide which sub-folders of gen_folder hold the images to evaluate
        if category in garment_categories:
            subfolders = [category]
        elif category == 'all':
            # Use a dedicated name so the `category` argument is not overwritten
            subfolders = [sub for sub in garment_categories if os.path.isdir(os.path.join(gen_folder, sub))]
        else:
            raise ValueError('Unsupported category')

        self.paths = sorted(
            os.path.join(gen_folder, sub, name)
            for sub in subfolders
            for name in os.listdir(os.path.join(gen_folder, sub))
            if os.path.isfile(os.path.join(gen_folder, sub, name)))

    def __len__(self):
        """Number of indexed generated images."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, file name without extension)`` for ``idx``."""
        path = self.paths[idx]
        name = os.path.splitext(os.path.basename(path))[0]
        img = self.transform(PIL.Image.open(path).convert('RGB'))
        return img, name
|
103 |
+
|
104 |
+
|
105 |
+
# metrics = compute_metrics(args.gen_folder, args.test_order, args.dataset, args.category, ['all'],
|
106 |
+
# args.dresscode_dataroot, args.vitonhd_dataroot, batch_size=args.batch_size,
|
107 |
+
# workers=args.workers)
|
108 |
+
|
109 |
+
def compute_metrics(gen_folder: str, test_order: str, dataset: str, category: str, metrics2compute: List[str],
                    dresscode_dataroot: str, vitonhd_dataroot: str, generated_size: Tuple[int, int] = (512, 384),
                    batch_size: int = 32, workers: int = 8) -> Dict[str, float]:
    """
    Computes the metrics for the generated images in gen_folder

    Args:
        gen_folder: folder with the generated images (one sub-folder per garment category).
        test_order: 'paired' or 'unpaired'; it is validated but not otherwise used here.
        dataset: 'dresscode' or 'vitonhd'; selects which dataroot holds the ground truth.
        category: 'all', 'dresses', 'lower_body' or 'upper_body'.
        metrics2compute: any of 'ssim_score', 'lpips_score', 'fid_score', 'kid_score',
            'is_score'; if 'all' appears anywhere in the list every metric is computed.
        dresscode_dataroot: DressCode dataroot.
        vitonhd_dataroot: VitonHD dataroot.
        generated_size: size handed to ``transforms.Resize`` for both image sets.
        batch_size: batch size of the dataloaders.
        workers: number of dataloader workers.

    Returns:
        A dict mapping every requested metric name to its value. For 'is_score'
        only the first element returned by ``InceptionScore.compute()`` is reported.

    Raises:
        AssertionError: on unsupported arguments or if generated and ground-truth
            file names are not aligned.
        ValueError: on an unsupported dataset, or if the number of generated images
            differs from the number of ground-truth images.
    """

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Input validation
    assert test_order in ['paired', 'unpaired']
    assert dataset in ['dresscode', 'vitonhd'], 'Unsupported dataset'
    assert category in ['all', 'dresses', 'lower_body', 'upper_body'], 'Unsupported category'

    if dataset == 'dresscode':
        gt_folder = dresscode_dataroot
    elif dataset == 'vitonhd':
        gt_folder = vitonhd_dataroot
    else:
        raise ValueError('Unsupported dataset')

    supported_metrics = ['ssim_score', 'lpips_score', 'fid_score', 'kid_score', 'is_score']
    for m in metrics2compute:
        assert m in ['all'] + supported_metrics, 'Unsupported metric'

    # 'all' expands to every metric even when it is mixed with explicit names;
    # otherwise 'all' would survive as a key that no computation ever fills in.
    if 'all' in metrics2compute:
        metrics2compute = list(supported_metrics)

    computed = {}

    # FID / KID are computed by clean-fid against pre-computed reference statistics
    if category == 'all':
        stats_name = f"{dataset}_all"
        fid_folder = gen_folder
    else:  # single category
        stats_name = f"{dataset}_{category}"
        fid_folder = os.path.join(gen_folder, category)

    def _ensure_reference_stats():
        # Build the custom reference statistics the first time they are needed
        if not fid.test_stats_exists(stats_name, mode='clean'):
            make_custom_stats(dresscode_dataroot, vitonhd_dataroot)

    if "fid_score" in metrics2compute:
        _ensure_reference_stats()
        computed['fid_score'] = fid.compute_fid(fid_folder, dataset_name=stats_name, mode='clean',
                                                dataset_split="custom", verbose=True, use_dataparallel=False)

    if "kid_score" in metrics2compute:
        _ensure_reference_stats()
        computed['kid_score'] = fid.compute_kid(fid_folder, dataset_name=stats_name, mode='clean',
                                                dataset_split="custom", verbose=True, use_dataparallel=False)

    # Define metrics models (only the requested ones)
    image_metrics = {}
    if "is_score" in metrics2compute:
        image_metrics['is_score'] = InceptionScore(normalize=True).to(device)
    if "ssim_score" in metrics2compute:
        image_metrics['ssim_score'] = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    if "lpips_score" in metrics2compute:
        image_metrics['lpips_score'] = LearnedPerceptualImagePatchSimilarity(net='alex', normalize=True).to(device)

    # The pass over the images is only needed for IS / SSIM / LPIPS
    if image_metrics:
        # Define transforms, datasets and loaders
        trans = transforms.Compose([
            transforms.Resize(generated_size),
            transforms.ToTensor(),
        ])

        gen_dataset = GenTestDataset(gen_folder, category, transform=trans)
        gt_dataset = GTTestDataset(gt_folder, dataset, category, trans)

        # zip() below stops at the shorter loader, which would silently score a subset
        if len(gen_dataset) != len(gt_dataset):
            raise ValueError(f"Found {len(gen_dataset)} generated images but {len(gt_dataset)} ground truth images")

        gen_loader = DataLoader(gen_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
        gt_loader = DataLoader(gt_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

        # Evaluation only: no gradients are needed
        with torch.no_grad():
            for gen_batch, gt_batch in tqdm(zip(gen_loader, gt_loader), total=len(gt_loader)):
                gen_images, gen_names = gen_batch
                gt_images, gt_names = gt_batch

                assert gen_names == gt_names  # Be sure that the images are in the same order

                gen_images = gen_images.to(device)
                gt_images = gt_images.to(device)

                if 'is_score' in image_metrics:
                    image_metrics['is_score'].update(gen_images)

                if 'ssim_score' in image_metrics:
                    image_metrics['ssim_score'].update(gen_images, gt_images)

                if 'lpips_score' in image_metrics:
                    image_metrics['lpips_score'].update(gen_images, gt_images)

        if 'is_score' in image_metrics:
            # compute() returns a pair; only the first element (the score) is reported
            is_score, _is_std = image_metrics['is_score'].compute()
            computed['is_score'] = is_score
        if 'ssim_score' in image_metrics:
            computed['ssim_score'] = image_metrics['ssim_score'].compute()
        if 'lpips_score' in image_metrics:
            computed['lpips_score'] = image_metrics['lpips_score'].compute()

    # Report the metrics in the requested order, unwrapping tensors to plain numbers
    results = {}
    for m in metrics2compute:
        value = computed[m]
        results[m] = value.item() if torch.is_tensor(value) else value
    return results
|
230 |
+
|
231 |
+
|
232 |
+
if __name__ == '__main__':
    # Command line interface
    cli = argparse.ArgumentParser(description="Compute the metrics for the generated images")
    cli.add_argument("--gen_folder", type=str, required=True, help="Path to the generated images")
    cli.add_argument('--dresscode_dataroot', type=str, help='DressCode dataroot')
    cli.add_argument('--vitonhd_dataroot', type=str, help='VitonHD dataroot')
    cli.add_argument("--test_order", type=str, required=True, choices=['paired', 'unpaired'])
    cli.add_argument("--dataset", type=str, required=True, choices=['dresscode', 'vitonhd'],
                     help="Dataset to use for the metrics")
    cli.add_argument("--category", type=str, choices=['all', 'lower_body', 'upper_body', 'dresses'], default='all')
    cli.add_argument("--batch_size", type=int, default=32, help="Batch size for the dataloaders")
    cli.add_argument("--workers", type=int, default=8, help="Number of workers for the dataloaders")

    args = cli.parse_args()

    # The dataroot of the selected dataset is mandatory
    if args.vitonhd_dataroot is None and args.dataset == "vitonhd":
        raise ValueError("VitonHD dataroot must be provided")
    if args.dresscode_dataroot is None and args.dataset == "dresscode":
        raise ValueError("DressCode dataroot must be provided")

    # The generated images must be on disk before anything is computed
    if not os.path.exists(args.gen_folder):
        raise ValueError("The generated images folder does not exist")

    # Every supported metric is always requested from the command line
    metrics = compute_metrics(
        args.gen_folder,
        args.test_order,
        args.dataset,
        args.category,
        ['all'],
        args.dresscode_dataroot,
        args.vitonhd_dataroot,
        batch_size=args.batch_size,
        workers=args.workers,
    )

    # Print the metrics (floats with four decimals, anything else verbatim)
    for metric_name, metric_value in metrics.items():
        print(f"{metric_name}: {metric_value:.4f}" if isinstance(metric_value, float)
              else f"{metric_name}: {metric_value}")

    # Store the metrics next to the generated images
    report_path = os.path.join(args.gen_folder, f"metrics_{args.test_order}_{args.category}.json")
    with open(report_path, "w+") as report_file:
        json.dump(metrics, report_file, indent=4)
|