zjuJish committed
Commit fbc778e · verified · 1 Parent(s): ac43f48

Upload VITON-HD/eval/tryon-fid.py with huggingface_hub

Files changed (1)
  1. VITON-HD/eval/tryon-fid.py +268 -0
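
The commit title says the file was pushed with huggingface_hub. For reference, a minimal sketch of such an upload via the library's HfApi.upload_file; the repo id below is a hypothetical placeholder (not taken from this commit), and the client reads the access token saved by `huggingface-cli login` or the HF_TOKEN environment variable:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="VITON-HD/eval/tryon-fid.py",  # local file to push
    path_in_repo="VITON-HD/eval/tryon-fid.py",     # destination path inside the repo
    repo_id="zjuJish/example-repo",                # hypothetical placeholder repo id
    commit_message="Upload VITON-HD/eval/tryon-fid.py with huggingface_hub",
)

Each upload_file call produces one commit on the Hub, which matches the single-file change listed above.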
VITON-HD/eval/tryon-fid.py ADDED
@@ -0,0 +1,268 @@
+import argparse
+import json
+import os
+from typing import List, Tuple, Dict
+
+import PIL.Image
+import torch
+from cleanfid import fid
+from torch.utils.data import DataLoader
+from torchmetrics import StructuralSimilarityIndexMeasure
+from torchmetrics.image.inception import InceptionScore
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from torchvision import transforms
+from tqdm import tqdm
+
+from .generate_fid_stats import make_custom_stats
+
+
+class GTTestDataset(torch.utils.data.Dataset):
+    def __init__(self, dataroot: str, dataset: str, category: str, transform: transforms.Compose):
+        """
+        Dataset for the ground-truth test images.
+        """
+
+        # Validate inputs
+        assert dataset in ['dresscode', 'vitonhd'], 'Unsupported dataset'
+        assert category in ['all', 'dresses', 'lower_body', 'upper_body'], 'Unsupported category'
+
+        self.dataset = dataset
+        self.category = category
+        self.transform = transform
+        self.dataroot = dataroot
+
+        # Get the paths to the images
+        if dataset == 'dresscode':
+            filepath = os.path.join(dataroot, "test_pairs_paired.txt")
+            with open(filepath, 'r') as f:
+                lines = f.read().splitlines()
+
+            if category in ['lower_body', 'upper_body', 'dresses']:
+                self.paths = sorted(
+                    [os.path.join(dataroot, category, 'images', line.strip().split()[0]) for line in lines if
+                     os.path.exists(os.path.join(dataroot, category, 'images', line.strip().split()[0]))])
+            else:  # 'all': look for each listed image in every category folder
+                self.paths = sorted(
+                    [os.path.join(dataroot, category, 'images', line.strip().split()[0]) for line in lines for
+                     category in ['lower_body', 'upper_body', 'dresses'] if
+                     os.path.exists(os.path.join(dataroot, category, 'images', line.strip().split()[0]))])
+        else:  # vitonhd
+            filepath = os.path.join(dataroot, "test_pairs.txt")
+            with open(filepath, 'r') as f:
+                lines = f.read().splitlines()
+            self.paths = sorted([os.path.join(dataroot, 'test', 'image', line.strip().split()[0]) for line in lines])
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        path = self.paths[idx]
+        name = os.path.splitext(os.path.basename(path))[0]
+        img = self.transform(PIL.Image.open(path).convert('RGB'))
+        return img, name
+
+
+class GenTestDataset(torch.utils.data.Dataset):
+    def __init__(self, gen_folder: str, category: str, transform: transforms.Compose):
+        """
+        Dataset for the generated test images.
+        """
+
+        # Validate inputs
+        assert category in ['all', 'dresses', 'lower_body', 'upper_body'], 'Unsupported category'
+
+        self.category = category
+        self.transform = transform
+        self.gen_folder = gen_folder
+
+        # Get the paths to the images
+        if category in ['lower_body', 'upper_body', 'dresses']:
+            self.paths = sorted(
+                [os.path.join(gen_folder, category, name) for name in os.listdir(os.path.join(gen_folder, category))])
+        elif category == 'all':
+            existing_categories = []
+            for category in ['lower_body', 'upper_body', 'dresses']:
+                if os.path.exists(os.path.join(gen_folder, category)):
+                    existing_categories.append(category)
+
+            self.paths = sorted(
+                [os.path.join(gen_folder, category, name) for category in existing_categories for
+                 name in os.listdir(os.path.join(gen_folder, category)) if
+                 os.path.exists(os.path.join(gen_folder, category, name))])
+        else:
+            raise ValueError('Unsupported category')
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, idx):
+        path = self.paths[idx]
+        name = os.path.splitext(os.path.basename(path))[0]
+        img = self.transform(PIL.Image.open(path).convert('RGB'))
+        return img, name
+
+
+def compute_metrics(gen_folder: str, test_order: str, dataset: str, category: str, metrics2compute: List[str],
+                    dresscode_dataroot: str, vitonhd_dataroot: str, generated_size: Tuple[int, int] = (512, 384),
+                    batch_size: int = 32, workers: int = 8) -> Dict[str, float]:
+    """
+    Compute the requested metrics for the generated images in gen_folder.
+    """
+
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # Input validation
+    assert test_order in ['paired', 'unpaired'], 'Unsupported test order'
+    assert dataset in ['dresscode', 'vitonhd'], 'Unsupported dataset'
+    assert category in ['all', 'dresses', 'lower_body', 'upper_body'], 'Unsupported category'
+
+    if dataset == 'dresscode':
+        gt_folder = dresscode_dataroot
+    elif dataset == 'vitonhd':
+        gt_folder = vitonhd_dataroot
+    else:
+        raise ValueError('Unsupported dataset')
+
+    for m in metrics2compute:
+        assert m in ['all', 'ssim_score', 'lpips_score', 'fid_score', 'kid_score', 'is_score'], 'Unsupported metric'
+
+    if metrics2compute == ['all']:
+        metrics2compute = ['ssim_score', 'lpips_score', 'fid_score', 'kid_score', 'is_score']
+
+    # Compute FID and KID scores
+    if category == 'all':
+        if "fid_score" in metrics2compute or "all" in metrics2compute:
+            # Check if the FID stats exist; if not, compute them
+            if not fid.test_stats_exists(f"{dataset}_all", mode='clean'):
+                make_custom_stats(dresscode_dataroot, vitonhd_dataroot)
+
+            # Compute the FID score
+            fid_score = fid.compute_fid(gen_folder, dataset_name=f"{dataset}_all", mode='clean',
+                                        dataset_split="custom", verbose=True, use_dataparallel=False)
+        if "kid_score" in metrics2compute or "all" in metrics2compute:
+            # Check if the KID stats exist; if not, compute them
+            if not fid.test_stats_exists(f"{dataset}_all", mode='clean'):
+                make_custom_stats(dresscode_dataroot, vitonhd_dataroot)
+
+            # Compute the KID score
+            kid_score = fid.compute_kid(gen_folder, dataset_name=f"{dataset}_all", mode='clean',
+                                        dataset_split="custom", verbose=True, use_dataparallel=False)
+    else:  # single category
+        if "fid_score" in metrics2compute or "all" in metrics2compute:
+            # Check if the FID stats exist; if not, compute them
+            if not fid.test_stats_exists(f"{dataset}_{category}", mode='clean'):
+                make_custom_stats(dresscode_dataroot, vitonhd_dataroot)
+
+            # Compute the FID score
+            fid_score = fid.compute_fid(os.path.join(gen_folder, category), dataset_name=f"{dataset}_{category}",
+                                        mode='clean', verbose=True, dataset_split="custom", use_dataparallel=False)
+        if "kid_score" in metrics2compute or "all" in metrics2compute:
+            # Check if the KID stats exist; if not, compute them
+            if not fid.test_stats_exists(f"{dataset}_{category}", mode='clean'):
+                make_custom_stats(dresscode_dataroot, vitonhd_dataroot)
+
+            # Compute the KID score
+            kid_score = fid.compute_kid(os.path.join(gen_folder, category),
+                                        dataset_name=f"{dataset}_{category}", mode='clean', verbose=True,
+                                        dataset_split="custom", use_dataparallel=False)
+
+    # Define transforms, datasets and loaders
+    trans = transforms.Compose([
+        transforms.Resize(generated_size),
+        transforms.ToTensor(),
+    ])
+
+    gen_dataset = GenTestDataset(gen_folder, category, transform=trans)
+    gt_dataset = GTTestDataset(gt_folder, dataset, category, trans)
+
+    gen_loader = DataLoader(gen_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
+    gt_loader = DataLoader(gt_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
+
+    # Define the metric models
+    if "is_score" in metrics2compute or "all" in metrics2compute:
+        model_is = InceptionScore(normalize=True).to(device)
+
+    if "ssim_score" in metrics2compute or "all" in metrics2compute:
+        ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
+
+    if "lpips_score" in metrics2compute or "all" in metrics2compute:
+        lpips = LearnedPerceptualImagePatchSimilarity(net='alex', normalize=True).to(device)
+
+    for idx, (gen_batch, gt_batch) in tqdm(enumerate(zip(gen_loader, gt_loader)), total=len(gt_loader)):
+        gen_images, gen_names = gen_batch
+        gt_images, gt_names = gt_batch
+
+        assert gen_names == gt_names  # Make sure the images are in the same order
+
+        gen_images = gen_images.to(device)
+        gt_images = gt_images.to(device)
+
+        if "is_score" in metrics2compute or "all" in metrics2compute:
+            model_is.update(gen_images)
+
+        if "ssim_score" in metrics2compute or "all" in metrics2compute:
+            ssim.update(gen_images, gt_images)
+
+        if "lpips_score" in metrics2compute or "all" in metrics2compute:
+            lpips.update(gen_images, gt_images)
+
+    if "is_score" in metrics2compute or "all" in metrics2compute:
+        is_score, is_std = model_is.compute()
+    if "ssim_score" in metrics2compute or "all" in metrics2compute:
+        ssim_score = ssim.compute()
+    if "lpips_score" in metrics2compute or "all" in metrics2compute:
+        lpips_score = lpips.compute()
+
+    results = {}
+
+    # Look up each computed score by name in the local scope
+    for m in metrics2compute:
+        if torch.is_tensor(locals()[m]):
+            results[m] = locals()[m].item()
+        else:
+            results[m] = locals()[m]
+    return results
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Compute the metrics for the generated images")
+    parser.add_argument("--gen_folder", type=str, required=True, help="Path to the generated images")
+    parser.add_argument('--dresscode_dataroot', type=str, help='DressCode dataroot')
+    parser.add_argument('--vitonhd_dataroot', type=str, help='VitonHD dataroot')
+    parser.add_argument("--test_order", type=str, required=True, choices=['paired', 'unpaired'])
+    parser.add_argument("--dataset", type=str, required=True, choices=['dresscode', 'vitonhd'],
+                        help="Dataset to use for the metrics")
+    parser.add_argument("--category", type=str, choices=['all', 'lower_body', 'upper_body', 'dresses'], default='all')
+    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for the dataloaders")
+    parser.add_argument("--workers", type=int, default=8, help="Number of workers for the dataloaders")
+
+    args = parser.parse_args()
+
+    # Check that the dataset dataroot is provided
+    if args.dataset == "vitonhd" and args.vitonhd_dataroot is None:
+        raise ValueError("VitonHD dataroot must be provided")
+    if args.dataset == "dresscode" and args.dresscode_dataroot is None:
+        raise ValueError("DressCode dataroot must be provided")
+
+    # Check that the generated images folder exists
+    if not os.path.exists(args.gen_folder):
+        raise ValueError("The generated images folder does not exist")
+
+    metrics = compute_metrics(args.gen_folder, args.test_order, args.dataset, args.category, ['all'],
+                              args.dresscode_dataroot, args.vitonhd_dataroot, batch_size=args.batch_size,
+                              workers=args.workers)
+
+    # Print the metrics
+    for k, v in metrics.items():
+        if isinstance(v, float):
+            print(f"{k}: {v:.4f}")
+        else:
+            print(f"{k}: {v}")
+
+    with open(os.path.join(args.gen_folder, f"metrics_{args.test_order}_{args.category}.json"), "w+") as f:
+        json.dump(metrics, f, indent=4)
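
Two notes for anyone running this script. It imports make_custom_stats from a sibling generate_fid_stats module that is not part of this commit. The sketch below shows what that helper plausibly does, using only clean-fid's actual make_custom_stats/test_stats_exists API; the stat names and folder layout are inferred from how this script calls them, not confirmed by the commit:

import os

from cleanfid import fid


def make_custom_stats(dresscode_dataroot: str, vitonhd_dataroot: str):
    # VITON-HD ground-truth test images live under test/image (per GTTestDataset above)
    if vitonhd_dataroot is not None and not fid.test_stats_exists("vitonhd_all", mode='clean'):
        fid.make_custom_stats("vitonhd_all", os.path.join(vitonhd_dataroot, 'test', 'image'), mode='clean')

    # DressCode keeps one images/ folder per category; per-category stats back the
    # f"{dataset}_{category}" names used in compute_metrics
    # (a "dresscode_all" variant would need the three category folders merged first)
    if dresscode_dataroot is not None:
        for category in ['lower_body', 'upper_body', 'dresses']:
            if not fid.test_stats_exists(f"dresscode_{category}", mode='clean'):
                fid.make_custom_stats(f"dresscode_{category}",
                                      os.path.join(dresscode_dataroot, category, 'images'), mode='clean')

Second, because of the relative import the file cannot be run directly as a loose script. With the import made absolute, or the eval folder importable as a package, a VITON-HD run would look something like: python tryon-fid.py --dataset vitonhd --vitonhd_dataroot /data/VITON-HD --gen_folder /results/tryon --test_order paired --category all (paths here are placeholders).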