Jinglong Xiong committed
Commit c0f4df5 · 1 Parent(s): 8111433

add analysis script

Files changed (1)
  1. eval_analysis.py +299 -0
eval_analysis.py CHANGED
@@ -0,0 +1,299 @@
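+ """
+ Analysis of image-generation evaluation results.
+
+ Loads the evaluation summary CSV and detailed JSON, compares models on VQA,
+ aesthetic, and fidelity scores (plus generation time), and writes all plots
+ and printed statistics to the results/ directory.
+ """
+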
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import json
+ from pathlib import Path
+
+ # Set style
+ plt.style.use('ggplot')
+ sns.set_palette("Set2")
+ plt.rcParams['figure.figsize'] = (12, 8)
+
+ # Load the data
+ results_csv = "results/summary_20250421_230054.csv"
+ results_json = "results/results_20250421_230054.json"
+
+ df = pd.read_csv(results_csv)
+
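+ # Expected input schema (inferred from the column references below): the summary
+ # CSV should provide at least model, category, description, vqa_score,
+ # aesthetic_score, fidelity_score, and generation_time.
+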
+ # Extract category from description if not already available
+ def extract_category(row):
+     """
+     Determines the category of an image based on its description or existing category.
+
+     Args:
+         row: A pandas DataFrame row containing 'category' and 'description' fields
+
+     Returns:
+         str: The determined category ('fashion', 'landscape', 'abstract', or 'unknown')
+     """
+     if pd.notna(row['category']) and row['category'] != 'unknown':
+         return row['category']
+
+     # Try to extract from description
+     desc = row['description'].lower()
+     if any(keyword in desc for keyword in ['coat', 'pants', 'shirt', 'dress', 'scarf', 'shoes']):
+         return 'fashion'
+     elif any(keyword in desc for keyword in ['forest', 'beach', 'mountain', 'ocean', 'lake', 'sky']):
+         return 'landscape'
+     elif any(keyword in desc for keyword in ['rectangle', 'circle', 'triangle', 'shape', 'spiral']):
+         return 'abstract'
+     else:
+         return 'unknown'
+
+ # Clean the data
+ df['category'] = df.apply(extract_category, axis=1)
+ df['generation_time'] = pd.to_numeric(df['generation_time'], errors='coerce')
+
+ # 1. Model Performance Comparison
+ def plot_model_comparison():
+     """
+     Creates boxplots comparing model performance across three metrics:
+     VQA score, aesthetic score, and fidelity score.
+
+     Saves the resulting plot to 'results/model_comparison.png'.
+     """
+     fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+
+     metrics = ['vqa_score', 'aesthetic_score', 'fidelity_score']
+     titles = ['VQA Score', 'Aesthetic Score', 'Fidelity Score']
+
+     for i, (metric, title) in enumerate(zip(metrics, titles)):
+         sns.boxplot(x='model', y=metric, data=df, ax=axes[i])
+         axes[i].set_title(f'{title} by Model')
+         axes[i].set_ylim([0, 1])
+
+     plt.tight_layout()
+     plt.savefig('results/model_comparison.png')
+     plt.close()
+
+ # 2. Category Performance Analysis
+ def plot_category_performance():
+     """
+     Creates boxplots showing performance by category and model for three metrics:
+     VQA score, aesthetic score, and fidelity score.
+
+     Saves the resulting plot to 'results/category_performance.png'.
+     """
+     fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+
+     metrics = ['vqa_score', 'aesthetic_score', 'fidelity_score']
+     titles = ['VQA Score', 'Aesthetic Score', 'Fidelity Score']
+
+     for i, (metric, title) in enumerate(zip(metrics, titles)):
+         sns.boxplot(x='category', y=metric, hue='model', data=df, ax=axes[i])
+         axes[i].set_title(f'{title} by Category and Model')
+         axes[i].set_ylim([0, 1])
+         if i > 0:
+             axes[i].get_legend().remove()
+
+     axes[0].legend(title='Model')
+     plt.tight_layout()
+     plt.savefig('results/category_performance.png')
+     plt.close()
+
+ # 3. Generation Time Analysis
+ def plot_generation_time():
+     """
+     Creates visualizations of generation time analysis:
+     1. A boxplot showing generation time by model
+     2. Scatter plots showing the relationship between generation time and quality metrics
+
+     Saves the resulting plots to 'results/generation_time.png' and 'results/quality_vs_time.png'.
+     """
+     plt.figure(figsize=(10, 6))
+     sns.boxplot(x='model', y='generation_time', data=df)
+     plt.title('Generation Time by Model')
+     plt.ylabel('Time (seconds)')
+     plt.tight_layout()
+     plt.savefig('results/generation_time.png')
+     plt.close()
+
+     # Generation time vs quality scatter plot
+     fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+
+     metrics = ['vqa_score', 'aesthetic_score', 'fidelity_score']
+     titles = ['VQA Score', 'Aesthetic Score', 'Fidelity Score']
+
+     for i, (metric, title) in enumerate(zip(metrics, titles)):
+         # The fixed two-color palette assumes exactly two models in the results
+         for model, color in zip(df['model'].unique(), ['#1f77b4', '#ff7f0e']):
+             model_data = df[df['model'] == model]
+             axes[i].scatter(model_data['generation_time'], model_data[metric],
+                             alpha=0.6, label=model, c=color)
+
+         axes[i].set_title(f'{title} vs. Generation Time')
+         axes[i].set_xlabel('Generation Time (seconds)')
+         axes[i].set_ylabel(title)
+         axes[i].legend()
+
+     plt.tight_layout()
+     plt.savefig('results/quality_vs_time.png')
+     plt.close()
+
+ # 4. Description complexity vs performance
+ def plot_complexity_performance():
+     """
+     Analyzes the relationship between description complexity (word count) and
+     performance metrics, creating scatter plots with trend lines.
+
+     Saves the resulting plot to 'results/complexity_performance.png'.
+     """
+     df['description_length'] = df['description'].str.len()
+     df['word_count'] = df['description'].str.split().str.len()
+
+     fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+
+     metrics = ['vqa_score', 'aesthetic_score', 'fidelity_score']
+     titles = ['VQA Score', 'Aesthetic Score', 'Fidelity Score']
+
+     for i, (metric, title) in enumerate(zip(metrics, titles)):
+         for model, color in zip(df['model'].unique(), ['#1f77b4', '#ff7f0e']):
+             model_data = df[df['model'] == model]
+             axes[i].scatter(model_data['word_count'], model_data[metric],
+                             alpha=0.6, label=model, c=color)
+
+             # Add trendline
+             z = np.polyfit(model_data['word_count'], model_data[metric], 1)
+             p = np.poly1d(z)
+             axes[i].plot(sorted(model_data['word_count']), p(sorted(model_data['word_count'])),
+                          c=color, linestyle='--')
+
+         axes[i].set_title(f'{title} vs. Description Complexity')
+         axes[i].set_xlabel('Word Count')
+         axes[i].set_ylabel(title)
+         axes[i].legend()
+
+     plt.tight_layout()
+     plt.savefig('results/complexity_performance.png')
+     plt.close()
+
+ # 5. Success and failure examples
+ def analyze_best_worst_examples():
+     """
+     Identifies and prints the top 10 most successful and least successful generations
+     based on fidelity score.
+
+     Creates directories for sample SVG and PNG files if they don't exist.
+
+     Returns:
+         tuple: (success_df, failure_df) DataFrames containing the best and worst examples
+     """
+     # Create directory for result samples
+     Path("results/sample_svg").mkdir(exist_ok=True)
+     Path("results/sample_png").mkdir(exist_ok=True)
+
+     # Load detailed results (loaded for reference; not used further in this function)
+     with open(results_json, 'r') as f:
+         results_data = json.load(f)
+
+     # Create success/failure dataframes
+     success_df = df.nlargest(10, 'fidelity_score')
+     failure_df = df.nsmallest(10, 'fidelity_score')
+
+     # Print success examples
+     print("Top 10 Successful Generations:")
+     print(success_df[['model', 'description', 'vqa_score', 'aesthetic_score', 'fidelity_score']].to_string(index=False))
+
+     # Print failure examples
+     print("\nTop 10 Failed Generations:")
+     print(failure_df[['model', 'description', 'vqa_score', 'aesthetic_score', 'fidelity_score']].to_string(index=False))
+
+     return success_df, failure_df
+
+ # 6. Summary statistics
+ def print_summary_stats():
+     """
+     Calculates and prints summary statistics for model performance:
+     1. Overall stats by model (mean, std, min, max for each metric)
+     2. Performance by category and model
+
+     Also creates a radar chart visualizing fidelity scores by category and model,
+     saved to 'results/category_radar.png'.
+     """
+     # Overall stats by model
+     model_stats = df.groupby('model').agg({
+         'vqa_score': ['mean', 'std', 'min', 'max'],
+         'aesthetic_score': ['mean', 'std', 'min', 'max'],
+         'fidelity_score': ['mean', 'std', 'min', 'max'],
+         'generation_time': ['mean', 'std', 'min', 'max']
+     })
+
+     print("Overall Model Performance:")
+     print(model_stats)
+
+     # Stats by category and model
+     category_stats = df.groupby(['model', 'category']).agg({
+         'vqa_score': 'mean',
+         'aesthetic_score': 'mean',
+         'fidelity_score': 'mean',
+         'generation_time': 'mean'
+     }).reset_index()
+
+     print("\nPerformance by Category and Model:")
+     print(category_stats.to_string())
+
+     # Create a radar chart for category performance
+     categories = category_stats['category'].unique()
+     models = category_stats['model'].unique()
+
+     plt.figure(figsize=(10, 8))
+     angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False).tolist()
+     angles += angles[:1]  # Close the loop
+
+     ax = plt.subplot(111, polar=True)
+
+     for model in models:
+         model_data = category_stats[category_stats['model'] == model]
+         values = []
+         for category in categories:
+             cat_data = model_data[model_data['category'] == category]
+             if not cat_data.empty:
+                 values.append(cat_data['fidelity_score'].values[0])
+             else:
+                 values.append(0)
+         values += values[:1]  # Close the loop
+
+         ax.plot(angles, values, linewidth=2, label=model)
+         ax.fill(angles, values, alpha=0.25)
+
+     ax.set_xticks(angles[:-1])
+     ax.set_xticklabels(categories)
+     ax.set_title('Fidelity Score by Category and Model')
+     ax.legend(loc='upper right')
+
+     plt.tight_layout()
+     plt.savefig('results/category_radar.png')
+     plt.close()
+
+ # Main analysis function
+ def run_analysis():
+     """
+     Main function that runs the complete analysis pipeline:
+     1. Creates necessary directories
+     2. Generates all visualization plots
+     3. Prints summary statistics
+     4. Analyzes best and worst examples
+
+     All results are saved to the 'results/' directory.
+     """
+     print("Starting analysis of evaluation results...")
+
+     # Create plots directory if it doesn't exist
+     Path("results").mkdir(exist_ok=True)
+
+     # Generate all plots
+     plot_model_comparison()
+     plot_category_performance()
+     plot_generation_time()
+     plot_complexity_performance()
+
+     # Print summary statistics
+     print_summary_stats()
+
+     # Analyze best and worst examples
+     success_df, failure_df = analyze_best_worst_examples()
+
+     print("\nAnalysis complete. Visualizations saved to 'results/' directory.")
+
+ if __name__ == "__main__":
+     run_analysis()
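
For reference, a minimal sketch of how the script would be invoked, assuming the hard-coded results files (results/summary_20250421_230054.csv and results/results_20250421_230054.json) already exist; importing the module triggers the data loading at the top of the file:

    # Run the full pipeline from another script or a notebook cell;
    # equivalent to `python eval_analysis.py` via the __main__ guard.
    from eval_analysis import run_analysis
    run_analysis()  # writes all figures and prints summaries under results/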