carryover from evijit
Browse files- app.py +468 -788
- models_processed.parquet +3 -0
- preprocess.py +371 -0
app.py
CHANGED
@@ -1,846 +1,526 @@
|
|
|
|
|
|
1 |
import json
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import plotly.express as px
|
5 |
import os
|
6 |
import numpy as np
|
7 |
-
import io
|
8 |
import duckdb
|
|
|
|
|
|
|
9 |
|
10 |
-
#
|
11 |
-
PIPELINE_TAGS = [
|
12 |
-
'text-generation',
|
13 |
-
'text-to-image',
|
14 |
-
'text-classification',
|
15 |
-
'text2text-generation',
|
16 |
-
'audio-to-audio',
|
17 |
-
'feature-extraction',
|
18 |
-
'image-classification',
|
19 |
-
'translation',
|
20 |
-
'reinforcement-learning',
|
21 |
-
'fill-mask',
|
22 |
-
'text-to-speech',
|
23 |
-
'automatic-speech-recognition',
|
24 |
-
'image-text-to-text',
|
25 |
-
'token-classification',
|
26 |
-
'sentence-similarity',
|
27 |
-
'question-answering',
|
28 |
-
'image-feature-extraction',
|
29 |
-
'summarization',
|
30 |
-
'zero-shot-image-classification',
|
31 |
-
'object-detection',
|
32 |
-
'image-segmentation',
|
33 |
-
'image-to-image',
|
34 |
-
'image-to-text',
|
35 |
-
'audio-classification',
|
36 |
-
'visual-question-answering',
|
37 |
-
'text-to-video',
|
38 |
-
'zero-shot-classification',
|
39 |
-
'depth-estimation',
|
40 |
-
'text-ranking',
|
41 |
-
'image-to-video',
|
42 |
-
'multiple-choice',
|
43 |
-
'unconditional-image-generation',
|
44 |
-
'video-classification',
|
45 |
-
'text-to-audio',
|
46 |
-
'time-series-forecasting',
|
47 |
-
'any-to-any',
|
48 |
-
'video-text-to-text',
|
49 |
-
'table-question-answering',
|
50 |
-
]
|
51 |
-
|
52 |
-
# Model size categories in GB
|
53 |
MODEL_SIZE_RANGES = {
|
54 |
-
"Small (<1GB)": (0, 1),
|
55 |
-
"
|
56 |
-
"Large (5-20GB)": (5, 20),
|
57 |
-
"X-Large (20-50GB)": (20, 50),
|
58 |
-
"XX-Large (>50GB)": (50, float('inf'))
|
59 |
}
|
|
|
|
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
def is_music(row):
|
67 |
-
# Use cached column instead of recalculating
|
68 |
-
return row['has_music']
|
69 |
-
|
70 |
-
def is_robotics(row):
|
71 |
-
# Use cached column instead of recalculating
|
72 |
-
return row['has_robot']
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
def
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
def
|
83 |
-
|
84 |
-
|
|
|
85 |
|
86 |
-
def
|
87 |
-
|
88 |
-
return row['has_video']
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
if
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
return "
|
114 |
-
return False
|
115 |
-
|
116 |
-
def is_text(row):
|
117 |
-
tags = row.get("tags", [])
|
118 |
|
119 |
-
|
120 |
-
if
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
return True
|
172 |
|
173 |
-
#
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
"
|
182 |
-
|
183 |
-
|
184 |
-
"
|
185 |
-
|
186 |
-
|
187 |
|
188 |
-
def extract_org_from_id(model_id):
|
189 |
-
"""Extract organization name from model ID"""
|
190 |
-
if "/" in model_id:
|
191 |
-
return model_id.split("/")[0]
|
192 |
-
return "unaffiliated"
|
193 |
|
194 |
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
|
195 |
-
|
196 |
-
# Create a copy to avoid modifying the original
|
197 |
filtered_df = df.copy()
|
|
|
|
|
|
|
198 |
|
199 |
-
#
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
filtered_df = filtered_df[filtered_df[
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
elif tag_filter == "Video":
|
221 |
-
filtered_df = filtered_df[filtered_df['has_video']]
|
222 |
-
elif tag_filter == "Images":
|
223 |
-
filtered_df = filtered_df[filtered_df['has_image']]
|
224 |
-
elif tag_filter == "Text":
|
225 |
-
filtered_df = filtered_df[filtered_df['has_text']]
|
226 |
-
|
227 |
-
filter_stats["after_tag_filter"] = len(filtered_df)
|
228 |
-
print(f"Tag filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
|
229 |
-
start_time = pd.Timestamp.now()
|
230 |
-
|
231 |
-
# Apply pipeline filter
|
232 |
if pipeline_filter:
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
# Use the cached size_category column directly
|
244 |
-
filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
|
245 |
-
|
246 |
-
# Debug info
|
247 |
-
print(f"Size filter '{size_filter}' applied.")
|
248 |
-
print(f"Models after size filter: {len(filtered_df)}")
|
249 |
-
|
250 |
-
filter_stats["after_size_filter"] = len(filtered_df)
|
251 |
-
print(f"Size filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
|
252 |
-
start_time = pd.Timestamp.now()
|
253 |
-
|
254 |
-
# Add organization column
|
255 |
-
filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
|
256 |
-
|
257 |
-
# Skip organizations if specified
|
258 |
if skip_orgs and len(skip_orgs) > 0:
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
print("Warning: No data left after applying filters!")
|
270 |
-
return pd.DataFrame() # Return empty DataFrame
|
271 |
-
|
272 |
-
# Aggregate by organization
|
273 |
-
org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
|
274 |
-
org_totals = org_totals.sort_values(by=count_by, ascending=False)
|
275 |
-
|
276 |
-
# Get top organizations
|
277 |
-
top_orgs = org_totals.head(top_k)["organization"].tolist()
|
278 |
-
|
279 |
-
# Filter to only include models from top organizations
|
280 |
-
filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]
|
281 |
-
|
282 |
-
# Prepare data for treemap
|
283 |
-
treemap_data = filtered_df[["id", "organization", count_by]].copy()
|
284 |
-
|
285 |
-
# Add a root node
|
286 |
treemap_data["root"] = "models"
|
287 |
-
|
288 |
-
# Ensure numeric values
|
289 |
-
treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
|
290 |
-
|
291 |
-
print(f"Treemap data prepared in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
|
292 |
return treemap_data
|
293 |
|
294 |
def create_treemap(treemap_data, count_by, title=None):
|
295 |
-
"""Create a Plotly treemap from the prepared data"""
|
296 |
if treemap_data.empty:
|
297 |
-
|
298 |
-
fig =
|
299 |
-
names=["No data matches the selected filters"],
|
300 |
-
values=[1]
|
301 |
-
)
|
302 |
-
fig.update_layout(
|
303 |
-
title="No data matches the selected filters",
|
304 |
-
margin=dict(t=50, l=25, r=25, b=25)
|
305 |
-
)
|
306 |
return fig
|
307 |
-
|
308 |
-
# Create the treemap
|
309 |
fig = px.treemap(
|
310 |
-
treemap_data,
|
311 |
-
path=["root", "organization", "id"],
|
312 |
-
values=count_by,
|
313 |
title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
|
314 |
color_discrete_sequence=px.colors.qualitative.Plotly
|
315 |
)
|
316 |
-
|
317 |
-
|
318 |
-
fig.update_layout(
|
319 |
-
margin=dict(t=50, l=25, r=25, b=25)
|
320 |
-
)
|
321 |
-
|
322 |
-
# Update traces for better readability
|
323 |
-
fig.update_traces(
|
324 |
-
textinfo="label+value+percent root",
|
325 |
-
hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
|
326 |
-
)
|
327 |
-
|
328 |
return fig
|
329 |
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
try:
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
df = duckdb.sql(query).df()
|
353 |
-
except Exception as sql_error:
|
354 |
-
print(f"Error with specific column selection: {sql_error}")
|
355 |
-
# Fallback to just selecting everything and then filtering
|
356 |
-
print("Falling back to select * query...")
|
357 |
-
query = "SELECT * FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')"
|
358 |
-
raw_df = duckdb.sql(query).df()
|
359 |
-
|
360 |
-
# Now extract only the columns we need
|
361 |
-
needed_columns = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
|
362 |
-
available_columns = set(raw_df.columns)
|
363 |
-
df = pd.DataFrame()
|
364 |
-
|
365 |
-
# Copy over columns that exist
|
366 |
-
for col in needed_columns:
|
367 |
-
if col in available_columns:
|
368 |
-
df[col] = raw_df[col]
|
369 |
else:
|
370 |
-
|
371 |
-
if col in ['downloads', 'downloadsAllTime', 'likes']:
|
372 |
-
df[col] = 0
|
373 |
-
elif col == 'pipeline_tag':
|
374 |
-
df[col] = ''
|
375 |
-
elif col == 'tags':
|
376 |
-
df[col] = [[] for _ in range(len(raw_df))]
|
377 |
-
elif col == 'safetensors':
|
378 |
-
df[col] = None
|
379 |
-
elif col == 'id':
|
380 |
-
# Create IDs based on index if missing
|
381 |
-
df[col] = [f"model_{i}" for i in range(len(raw_df))]
|
382 |
-
|
383 |
-
print(f"Data fetched successfully. Shape: {df.shape}")
|
384 |
-
|
385 |
-
# Check if safetensors column exists before trying to process it
|
386 |
-
if 'safetensors' in df.columns:
|
387 |
-
# Add params column derived from safetensors.total (model size in GB)
|
388 |
-
df['params'] = df['safetensors'].apply(extract_model_size)
|
389 |
-
|
390 |
-
# Debug model sizes
|
391 |
-
size_ranges = {
|
392 |
-
"Small (<1GB)": 0,
|
393 |
-
"Medium (1-5GB)": 0,
|
394 |
-
"Large (5-20GB)": 0,
|
395 |
-
"X-Large (20-50GB)": 0,
|
396 |
-
"XX-Large (>50GB)": 0
|
397 |
-
}
|
398 |
-
|
399 |
-
# Count models in each size range
|
400 |
-
for idx, row in df.iterrows():
|
401 |
-
size_gb = row['params']
|
402 |
-
if 0 <= size_gb < 1:
|
403 |
-
size_ranges["Small (<1GB)"] += 1
|
404 |
-
elif 1 <= size_gb < 5:
|
405 |
-
size_ranges["Medium (1-5GB)"] += 1
|
406 |
-
elif 5 <= size_gb < 20:
|
407 |
-
size_ranges["Large (5-20GB)"] += 1
|
408 |
-
elif 20 <= size_gb < 50:
|
409 |
-
size_ranges["X-Large (20-50GB)"] += 1
|
410 |
-
elif size_gb >= 50:
|
411 |
-
size_ranges["XX-Large (>50GB)"] += 1
|
412 |
-
|
413 |
-
print("Model size distribution:")
|
414 |
-
for size_range, count in size_ranges.items():
|
415 |
-
print(f" {size_range}: {count} models")
|
416 |
-
|
417 |
-
# CACHE SIZE CATEGORY: Add a size_category column for faster filtering
|
418 |
-
def get_size_category(size_gb):
|
419 |
-
if 0 <= size_gb < 1:
|
420 |
-
return "Small (<1GB)"
|
421 |
-
elif 1 <= size_gb < 5:
|
422 |
-
return "Medium (1-5GB)"
|
423 |
-
elif 5 <= size_gb < 20:
|
424 |
-
return "Large (5-20GB)"
|
425 |
-
elif 20 <= size_gb < 50:
|
426 |
-
return "X-Large (20-50GB)"
|
427 |
-
elif size_gb >= 50:
|
428 |
-
return "XX-Large (>50GB)"
|
429 |
-
return None
|
430 |
-
|
431 |
-
# Add cached size category column
|
432 |
-
df['size_category'] = df['params'].apply(get_size_category)
|
433 |
-
|
434 |
-
# Remove the safetensors column as we don't need it anymore
|
435 |
-
df = df.drop(columns=['safetensors'])
|
436 |
-
else:
|
437 |
-
# If no safetensors column, add empty params column
|
438 |
-
df['params'] = 0
|
439 |
-
df['size_category'] = None
|
440 |
-
|
441 |
-
# Process tags to ensure it's in the right format - FIXED
|
442 |
-
def process_tags(tags_value):
|
443 |
-
try:
|
444 |
-
if pd.isna(tags_value) or tags_value is None:
|
445 |
-
return []
|
446 |
-
|
447 |
-
# If it's a numpy array, convert to a list of strings
|
448 |
-
if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
|
449 |
-
# Note: This is the fix for the error
|
450 |
-
return [str(tag) for tag in tags_value.tolist()]
|
451 |
-
|
452 |
-
# If already a list, ensure all elements are strings
|
453 |
-
if isinstance(tags_value, list):
|
454 |
-
return [str(tag) for tag in tags_value]
|
455 |
|
456 |
-
|
457 |
-
if
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
# Split by comma if JSON parsing fails
|
464 |
-
return [tag.strip() for tag in tags_value.split(',') if tag.strip()]
|
465 |
|
466 |
-
|
467 |
-
|
|
|
|
|
|
|
468 |
|
469 |
-
|
470 |
-
|
471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
472 |
|
473 |
-
#
|
474 |
-
if '
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
# CACHE TAG CATEGORIES: Pre-calculate tag categories for faster filtering
|
479 |
-
print("Pre-calculating cached tag categories...")
|
480 |
-
|
481 |
-
# Helper functions to check for specific tags (simplified for caching)
|
482 |
-
def has_audio_tag(tags):
|
483 |
-
if tags and isinstance(tags, list):
|
484 |
-
return any("audio" in str(tag).lower() for tag in tags)
|
485 |
-
return False
|
486 |
-
|
487 |
-
def has_speech_tag(tags):
|
488 |
-
if tags and isinstance(tags, list):
|
489 |
-
return any("speech" in str(tag).lower() for tag in tags)
|
490 |
-
return False
|
491 |
-
|
492 |
-
def has_music_tag(tags):
|
493 |
-
if tags and isinstance(tags, list):
|
494 |
-
return any("music" in str(tag).lower() for tag in tags)
|
495 |
-
return False
|
496 |
-
|
497 |
-
def has_robot_tag(tags):
|
498 |
-
if tags and isinstance(tags, list):
|
499 |
-
return any("robot" in str(tag).lower() for tag in tags)
|
500 |
-
return False
|
501 |
-
|
502 |
-
def has_bio_tag(tags):
|
503 |
-
if tags and isinstance(tags, list):
|
504 |
-
return any("bio" in str(tag).lower() for tag in tags)
|
505 |
-
return False
|
506 |
-
|
507 |
-
def has_med_tag(tags):
|
508 |
-
if tags and isinstance(tags, list):
|
509 |
-
return any("medic" in str(tag).lower() for tag in tags)
|
510 |
-
return False
|
511 |
-
|
512 |
-
def has_series_tag(tags):
|
513 |
-
if tags and isinstance(tags, list):
|
514 |
-
return any("series" in str(tag).lower() for tag in tags)
|
515 |
-
return False
|
516 |
-
|
517 |
-
def has_science_tag(tags):
|
518 |
-
if tags and isinstance(tags, list):
|
519 |
-
return any("science" in str(tag).lower() and "bigscience" not in str(tag).lower() for tag in tags)
|
520 |
-
return False
|
521 |
-
|
522 |
-
def has_video_tag(tags):
|
523 |
-
if tags and isinstance(tags, list):
|
524 |
-
return any("video" in str(tag).lower() for tag in tags)
|
525 |
-
return False
|
526 |
-
|
527 |
-
def has_image_tag(tags):
|
528 |
-
if tags and isinstance(tags, list):
|
529 |
-
return any("image" in str(tag).lower() for tag in tags)
|
530 |
-
return False
|
531 |
-
|
532 |
-
def has_text_tag(tags):
|
533 |
-
if tags and isinstance(tags, list):
|
534 |
-
return any("text" in str(tag).lower() for tag in tags)
|
535 |
-
return False
|
536 |
-
|
537 |
-
# Add cached columns for tag categories
|
538 |
-
print("Creating cached tag columns...")
|
539 |
-
df['has_audio'] = df['tags'].apply(has_audio_tag)
|
540 |
-
df['has_speech'] = df['tags'].apply(has_speech_tag)
|
541 |
-
df['has_music'] = df['tags'].apply(has_music_tag)
|
542 |
-
df['has_robot'] = df['tags'].apply(has_robot_tag)
|
543 |
-
df['has_bio'] = df['tags'].apply(has_bio_tag)
|
544 |
-
df['has_med'] = df['tags'].apply(has_med_tag)
|
545 |
-
df['has_series'] = df['tags'].apply(has_series_tag)
|
546 |
-
df['has_science'] = df['tags'].apply(has_science_tag)
|
547 |
-
df['has_video'] = df['tags'].apply(has_video_tag)
|
548 |
-
df['has_image'] = df['tags'].apply(has_image_tag)
|
549 |
-
df['has_text'] = df['tags'].apply(has_text_tag)
|
550 |
-
|
551 |
-
# Create combined category flags for faster filtering
|
552 |
-
df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
|
553 |
-
df['pipeline_tag'].str.contains('audio', case=False, na=False) |
|
554 |
-
df['pipeline_tag'].str.contains('speech', case=False, na=False))
|
555 |
-
df['is_biomed'] = df['has_bio'] | df['has_med']
|
556 |
-
|
557 |
-
print("Cached tag columns created successfully!")
|
558 |
-
else:
|
559 |
-
# If no tags column, add empty tags and set all category flags to False
|
560 |
-
df['tags'] = [[] for _ in range(len(df))]
|
561 |
-
for col in ['has_audio', 'has_speech', 'has_music', 'has_robot',
|
562 |
-
'has_bio', 'has_med', 'has_series', 'has_science',
|
563 |
-
'has_video', 'has_image', 'has_text',
|
564 |
-
'is_audio_speech', 'is_biomed']:
|
565 |
-
df[col] = False
|
566 |
-
|
567 |
-
# Fill NaN values
|
568 |
-
df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)
|
569 |
-
|
570 |
-
# Ensure pipeline_tag is a string
|
571 |
-
if 'pipeline_tag' in df.columns:
|
572 |
-
df['pipeline_tag'] = df['pipeline_tag'].fillna('')
|
573 |
-
else:
|
574 |
-
df['pipeline_tag'] = ''
|
575 |
-
|
576 |
-
# Make sure all required columns exist
|
577 |
-
for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
|
578 |
-
if col not in df.columns:
|
579 |
-
if col in ['downloads', 'downloadsAllTime', 'likes', 'params']:
|
580 |
-
df[col] = 0
|
581 |
-
elif col == 'pipeline_tag':
|
582 |
-
df[col] = ''
|
583 |
-
elif col == 'tags':
|
584 |
-
df[col] = [[] for _ in range(len(df))]
|
585 |
-
elif col == 'id':
|
586 |
-
df[col] = [f"model_{i}" for i in range(len(df))]
|
587 |
-
|
588 |
-
print(f"Successfully processed {len(df)} models with cached tag and size information")
|
589 |
-
return df, True
|
590 |
-
|
591 |
-
except Exception as e:
|
592 |
-
print(f"Error loading data: {e}")
|
593 |
-
# Return an empty DataFrame and False to indicate loading failure
|
594 |
-
return pd.DataFrame(), False
|
595 |
-
|
596 |
-
# Create Gradio interface
|
597 |
-
with gr.Blocks() as demo:
|
598 |
-
models_data = gr.State()
|
599 |
-
loading_complete = gr.State(False) # Flag to indicate data load completion
|
600 |
-
|
601 |
-
with gr.Row():
|
602 |
-
gr.Markdown("""
|
603 |
-
# HuggingFace Models TreeMap Visualization
|
604 |
-
|
605 |
-
This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
|
606 |
-
Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
|
607 |
|
608 |
-
|
609 |
-
|
610 |
-
""")
|
611 |
-
|
612 |
-
with gr.Row():
|
613 |
-
with gr.Column(scale=1):
|
614 |
-
count_by_dropdown = gr.Dropdown(
|
615 |
-
label="Metric",
|
616 |
-
choices=[
|
617 |
-
("Downloads (last 30 days)", "downloads"),
|
618 |
-
("Downloads (All Time)", "downloadsAllTime"),
|
619 |
-
("Likes", "likes")
|
620 |
-
],
|
621 |
-
value="downloads",
|
622 |
-
info="Select the metric to determine box sizes"
|
623 |
-
)
|
624 |
-
|
625 |
-
filter_choice_radio = gr.Radio(
|
626 |
-
label="Filter Type",
|
627 |
-
choices=["None", "Tag Filter", "Pipeline Filter"],
|
628 |
-
value="None",
|
629 |
-
info="Choose how to filter the models"
|
630 |
-
)
|
631 |
-
|
632 |
-
tag_filter_dropdown = gr.Dropdown(
|
633 |
-
label="Select Tag",
|
634 |
-
choices=list(TAG_FILTER_FUNCS.keys()),
|
635 |
-
value=None,
|
636 |
-
visible=False,
|
637 |
-
info="Filter models by domain/category"
|
638 |
-
)
|
639 |
-
|
640 |
-
pipeline_filter_dropdown = gr.Dropdown(
|
641 |
-
label="Select Pipeline Tag",
|
642 |
-
choices=PIPELINE_TAGS,
|
643 |
-
value=None,
|
644 |
-
visible=False,
|
645 |
-
info="Filter models by specific pipeline"
|
646 |
-
)
|
647 |
-
|
648 |
-
size_filter_dropdown = gr.Dropdown(
|
649 |
-
label="Model Size Filter",
|
650 |
-
choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
|
651 |
-
value="None",
|
652 |
-
info="Filter models by their size (using params column)"
|
653 |
-
)
|
654 |
-
|
655 |
-
top_k_slider = gr.Slider(
|
656 |
-
label="Number of Top Organizations",
|
657 |
-
minimum=5,
|
658 |
-
maximum=50,
|
659 |
-
value=25,
|
660 |
-
step=5,
|
661 |
-
info="Number of top organizations to include"
|
662 |
-
)
|
663 |
-
|
664 |
-
skip_orgs_textbox = gr.Textbox(
|
665 |
-
label="Organizations to Skip (comma-separated)",
|
666 |
-
placeholder="e.g., OpenAI, Google",
|
667 |
-
value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
|
668 |
-
)
|
669 |
-
|
670 |
-
generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
|
671 |
-
refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")
|
672 |
-
|
673 |
-
with gr.Column(scale=3):
|
674 |
-
plot_output = gr.Plot()
|
675 |
-
stats_output = gr.Markdown("*Loading data from Hugging Face...*")
|
676 |
-
data_info = gr.Markdown("")
|
677 |
-
|
678 |
-
# Button enablement after data load
|
679 |
-
def enable_plot_button(loaded):
|
680 |
-
return gr.update(interactive=loaded)
|
681 |
-
|
682 |
-
loading_complete.change(
|
683 |
-
fn=enable_plot_button,
|
684 |
-
inputs=[loading_complete],
|
685 |
-
outputs=[generate_plot_button]
|
686 |
-
)
|
687 |
-
|
688 |
-
# Show/hide tag/pipeline dropdown
|
689 |
-
def update_filter_visibility(filter_choice):
|
690 |
-
if filter_choice == "Tag Filter":
|
691 |
-
return gr.update(visible=True), gr.update(visible=False)
|
692 |
-
elif filter_choice == "Pipeline Filter":
|
693 |
-
return gr.update(visible=False), gr.update(visible=True)
|
694 |
-
else:
|
695 |
-
return gr.update(visible=False), gr.update(visible=False)
|
696 |
-
|
697 |
-
filter_choice_radio.change(
|
698 |
-
fn=update_filter_visibility,
|
699 |
-
inputs=[filter_choice_radio],
|
700 |
-
outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
|
701 |
-
)
|
702 |
-
|
703 |
-
# Function to handle data load and provide data info
|
704 |
-
def load_and_provide_info():
|
705 |
-
df, success = load_models_data()
|
706 |
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
- **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
|
713 |
-
- **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)
|
714 |
-
"""
|
715 |
-
|
716 |
-
# Return the data, loading status, and info text
|
717 |
-
return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"
|
718 |
-
else:
|
719 |
-
# Return empty data, failed loading status, and error message
|
720 |
-
return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"
|
721 |
-
|
722 |
-
# Main generate function
|
723 |
-
def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
|
724 |
-
if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
|
725 |
-
return None, "Error: Data is still loading. Please wait a moment and try again."
|
726 |
-
|
727 |
-
selected_tag_filter = None
|
728 |
-
selected_pipeline_filter = None
|
729 |
-
selected_size_filter = None
|
730 |
-
|
731 |
-
if filter_choice == "Tag Filter":
|
732 |
-
selected_tag_filter = tag_filter
|
733 |
-
elif filter_choice == "Pipeline Filter":
|
734 |
-
selected_pipeline_filter = pipeline_filter
|
735 |
-
|
736 |
-
if size_filter != "None":
|
737 |
-
selected_size_filter = size_filter
|
738 |
-
|
739 |
-
skip_orgs = []
|
740 |
-
if skip_orgs_text and skip_orgs_text.strip():
|
741 |
-
skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
|
742 |
-
|
743 |
-
treemap_data = make_treemap_data(
|
744 |
-
df=data_df,
|
745 |
-
count_by=count_by,
|
746 |
-
top_k=top_k,
|
747 |
-
tag_filter=selected_tag_filter,
|
748 |
-
pipeline_filter=selected_pipeline_filter,
|
749 |
-
size_filter=selected_size_filter,
|
750 |
-
skip_orgs=skip_orgs
|
751 |
-
)
|
752 |
-
|
753 |
-
title_labels = {
|
754 |
-
"downloads": "Downloads (last 30 days)",
|
755 |
-
"downloadsAllTime": "Downloads (All Time)",
|
756 |
-
"likes": "Likes"
|
757 |
-
}
|
758 |
-
title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"
|
759 |
-
|
760 |
-
fig = create_treemap(
|
761 |
-
treemap_data=treemap_data,
|
762 |
-
count_by=count_by,
|
763 |
-
title=title_text
|
764 |
-
)
|
765 |
-
|
766 |
-
if treemap_data.empty:
|
767 |
-
stats_md = "No data matches the selected filters."
|
768 |
else:
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
|
774 |
-
|
775 |
-
# Get top 5 individual models
|
776 |
-
top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)
|
777 |
-
|
778 |
-
# Create statistics section
|
779 |
-
stats_md = f"""
|
780 |
-
## Statistics
|
781 |
-
- **Total models shown**: {total_models:,}
|
782 |
-
- **Total {count_by}**: {int(total_value):,}
|
783 |
-
|
784 |
-
## Top Organizations by {count_by.capitalize()}
|
785 |
|
786 |
-
| Organization | {count_by.capitalize()} | % of Total |
|
787 |
-
|--------------|-------------:|----------:|
|
788 |
-
"""
|
789 |
-
|
790 |
-
# Add top organizations to the table
|
791 |
-
for org, value in top_5_orgs.items():
|
792 |
-
percentage = (value / total_value) * 100
|
793 |
-
stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"
|
794 |
-
|
795 |
-
# Add the top models table
|
796 |
-
stats_md += f"""
|
797 |
-
## Top Models by {count_by.capitalize()}
|
798 |
-
|
799 |
-
| Model | {count_by.capitalize()} | % of Total |
|
800 |
-
|-------|-------------:|----------:|
|
801 |
-
"""
|
802 |
-
|
803 |
-
# Add top models to the table
|
804 |
-
for _, row in top_5_models.iterrows():
|
805 |
-
model_id = row["id"]
|
806 |
-
value = row[count_by]
|
807 |
-
percentage = (value / total_value) * 100
|
808 |
-
stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"
|
809 |
-
|
810 |
-
# Add note about skipped organizations if any
|
811 |
-
if skip_orgs:
|
812 |
-
stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
|
813 |
-
|
814 |
-
return fig, stats_md
|
815 |
-
|
816 |
-
# Load data at startup
|
817 |
demo.load(
|
818 |
-
fn=
|
819 |
-
inputs=[],
|
820 |
-
outputs=[
|
821 |
)
|
822 |
-
|
823 |
-
# Refresh data when button is clicked
|
824 |
refresh_data_button.click(
|
825 |
-
fn=
|
826 |
-
inputs=[],
|
827 |
-
outputs=[
|
828 |
)
|
829 |
-
|
830 |
generate_plot_button.click(
|
831 |
-
fn=
|
832 |
-
inputs=[
|
833 |
-
|
834 |
-
|
835 |
-
tag_filter_dropdown,
|
836 |
-
pipeline_filter_dropdown,
|
837 |
-
size_filter_dropdown,
|
838 |
-
top_k_slider,
|
839 |
-
skip_orgs_textbox,
|
840 |
-
models_data
|
841 |
-
],
|
842 |
-
outputs=[plot_output, stats_output]
|
843 |
)
|
844 |
|
845 |
if __name__ == "__main__":
|
846 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- START OF FILE app.py ---
|
2 |
+
|
3 |
import json
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
import os
|
8 |
import numpy as np
|
|
|
9 |
import duckdb
|
10 |
+
from tqdm.auto import tqdm # Standard tqdm for console, gr.Progress will track it
|
11 |
+
import time
|
12 |
+
import ast # For safely evaluating string representations of lists/dicts
|
13 |
|
14 |
+
# --- Constants ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
MODEL_SIZE_RANGES = {
|
16 |
+
"Small (<1GB)": (0, 1), "Medium (1-5GB)": (1, 5), "Large (5-20GB)": (5, 20),
|
17 |
+
"X-Large (20-50GB)": (20, 50), "XX-Large (>50GB)": (50, float('inf'))
|
|
|
|
|
|
|
18 |
}
|
19 |
+
PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
|
20 |
+
HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet' # Added for completeness within app.py context
|
21 |
|
22 |
+
TAG_FILTER_CHOICES = [
|
23 |
+
"Audio & Speech", "Time series", "Robotics", "Music", "Video", "Images",
|
24 |
+
"Text", "Biomedical", "Sciences"
|
25 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
PIPELINE_TAGS = [
|
28 |
+
'text-generation', 'text-to-image', 'text-classification', 'text2text-generation',
|
29 |
+
'audio-to-audio', 'feature-extraction', 'image-classification', 'translation',
|
30 |
+
'reinforcement-learning', 'fill-mask', 'text-to-speech', 'automatic-speech-recognition',
|
31 |
+
'image-text-to-text', 'token-classification', 'sentence-similarity', 'question-answering',
|
32 |
+
'image-feature-extraction', 'summarization', 'zero-shot-image-classification',
|
33 |
+
'object-detection', 'image-segmentation', 'image-to-image', 'image-to-text',
|
34 |
+
'audio-classification', 'visual-question-answering', 'text-to-video',
|
35 |
+
'zero-shot-classification', 'depth-estimation', 'text-ranking', 'image-to-video',
|
36 |
+
'multiple-choice', 'unconditional-image-generation', 'video-classification',
|
37 |
+
'text-to-audio', 'time-series-forecasting', 'any-to-any', 'video-text-to-text',
|
38 |
+
'table-question-answering',
|
39 |
+
]
|
40 |
|
41 |
+
def extract_model_size(safetensors_data):
|
42 |
+
try:
|
43 |
+
if pd.isna(safetensors_data): return 0.0
|
44 |
+
data_to_parse = safetensors_data
|
45 |
+
if isinstance(safetensors_data, str):
|
46 |
+
try:
|
47 |
+
if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
|
48 |
+
(safetensors_data.startswith('[') and safetensors_data.endswith(']')):
|
49 |
+
data_to_parse = ast.literal_eval(safetensors_data)
|
50 |
+
else: data_to_parse = json.loads(safetensors_data)
|
51 |
+
except: return 0.0
|
52 |
+
if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
|
53 |
+
try:
|
54 |
+
total_bytes_val = data_to_parse['total']
|
55 |
+
size_bytes = float(total_bytes_val)
|
56 |
+
return size_bytes / (1024 * 1024 * 1024)
|
57 |
+
except (ValueError, TypeError): pass
|
58 |
+
return 0.0
|
59 |
+
except: return 0.0
|
60 |
|
61 |
+
def extract_org_from_id(model_id):
|
62 |
+
if pd.isna(model_id): return "unaffiliated"
|
63 |
+
model_id_str = str(model_id)
|
64 |
+
return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
|
65 |
|
66 |
+
def process_tags_for_series(series_of_tags_values):
|
67 |
+
processed_tags_accumulator = []
|
|
|
68 |
|
69 |
+
for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
|
70 |
+
temp_processed_list_for_row = []
|
71 |
+
current_value_for_error_msg = str(tags_value_from_series)[:200] # Truncate for long error messages
|
72 |
|
73 |
+
try:
|
74 |
+
# Order of checks is important!
|
75 |
+
# 1. Handle explicit Python lists first
|
76 |
+
if isinstance(tags_value_from_series, list):
|
77 |
+
current_tags_in_list = []
|
78 |
+
for idx_tag, tag_item in enumerate(tags_value_from_series):
|
79 |
+
try:
|
80 |
+
# Ensure item is not NaN before string conversion if it might be a float NaN in a list
|
81 |
+
if pd.isna(tag_item): continue
|
82 |
+
str_tag = str(tag_item)
|
83 |
+
stripped_tag = str_tag.strip()
|
84 |
+
if stripped_tag:
|
85 |
+
current_tags_in_list.append(stripped_tag)
|
86 |
+
except Exception as e_inner_list_proc:
|
87 |
+
print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
|
88 |
+
temp_processed_list_for_row = current_tags_in_list
|
89 |
+
|
90 |
+
# 2. Handle NumPy arrays
|
91 |
+
elif isinstance(tags_value_from_series, np.ndarray):
|
92 |
+
# Convert to list, then process elements, handling potential NaNs within the array
|
93 |
+
current_tags_in_list = []
|
94 |
+
for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()): # .tolist() is crucial
|
95 |
+
try:
|
96 |
+
if pd.isna(tag_item): continue # Check for NaN after converting to Python type
|
97 |
+
str_tag = str(tag_item)
|
98 |
+
stripped_tag = str_tag.strip()
|
99 |
+
if stripped_tag:
|
100 |
+
current_tags_in_list.append(stripped_tag)
|
101 |
+
except Exception as e_inner_array_proc:
|
102 |
+
print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
|
103 |
+
temp_processed_list_for_row = current_tags_in_list
|
104 |
+
|
105 |
+
# 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
|
106 |
+
elif tags_value_from_series is None or pd.isna(tags_value_from_series): # Now pd.isna is safe for scalars
|
107 |
+
temp_processed_list_for_row = []
|
108 |
+
|
109 |
+
# 4. Handle strings (could be JSON-like, list-like, or comma-separated)
|
110 |
+
elif isinstance(tags_value_from_series, str):
|
111 |
+
processed_str_tags = []
|
112 |
+
# Attempt ast.literal_eval for strings that look like lists/tuples
|
113 |
+
if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
|
114 |
+
(tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
|
115 |
+
try:
|
116 |
+
evaluated_tags = ast.literal_eval(tags_value_from_series)
|
117 |
+
if isinstance(evaluated_tags, (list, tuple)): # Check if eval result is a list/tuple
|
118 |
+
# Recursively process this evaluated list/tuple, as its elements could be complex
|
119 |
+
# For simplicity here, assume elements are simple strings after eval
|
120 |
+
current_eval_list = []
|
121 |
+
for tag_item in evaluated_tags:
|
122 |
+
if pd.isna(tag_item): continue
|
123 |
+
str_tag = str(tag_item).strip()
|
124 |
+
if str_tag: current_eval_list.append(str_tag)
|
125 |
+
processed_str_tags = current_eval_list
|
126 |
+
except (ValueError, SyntaxError):
|
127 |
+
pass # If ast.literal_eval fails, let it fall to JSON or comma split
|
128 |
+
|
129 |
+
# If ast.literal_eval didn't populate, try JSON
|
130 |
+
if not processed_str_tags:
|
131 |
+
try:
|
132 |
+
json_tags = json.loads(tags_value_from_series)
|
133 |
+
if isinstance(json_tags, list):
|
134 |
+
# Similar to above, assume elements are simple strings after JSON parsing
|
135 |
+
current_json_list = []
|
136 |
+
for tag_item in json_tags:
|
137 |
+
if pd.isna(tag_item): continue
|
138 |
+
str_tag = str(tag_item).strip()
|
139 |
+
if str_tag: current_json_list.append(str_tag)
|
140 |
+
processed_str_tags = current_json_list
|
141 |
+
except json.JSONDecodeError:
|
142 |
+
# If not a valid JSON list, fall back to comma splitting as the final string strategy
|
143 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
|
144 |
+
except Exception as e_json_other:
|
145 |
+
print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
|
146 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()] # Fallback
|
147 |
+
|
148 |
+
temp_processed_list_for_row = processed_str_tags
|
149 |
+
|
150 |
+
# 5. Fallback for other scalar types (e.g., int, float that are not NaN)
|
151 |
+
else:
|
152 |
+
# This path is for non-list, non-ndarray, non-None/NaN, non-string types.
|
153 |
+
# Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
|
154 |
+
if pd.isna(tags_value_from_series): # Catch any remaining NaNs like float('nan')
|
155 |
+
temp_processed_list_for_row = []
|
156 |
+
else:
|
157 |
+
str_val = str(tags_value_from_series).strip()
|
158 |
+
temp_processed_list_for_row = [str_val] if str_val else []
|
159 |
+
|
160 |
+
processed_tags_accumulator.append(temp_processed_list_for_row)
|
161 |
|
162 |
+
except Exception as e_outer_tag_proc:
|
163 |
+
print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
|
164 |
+
processed_tags_accumulator.append([])
|
165 |
+
|
166 |
+
return processed_tags_accumulator
|
167 |
+
|
168 |
+
def load_models_data(force_refresh=False, tqdm_cls=None):
|
169 |
+
if tqdm_cls is None: tqdm_cls = tqdm
|
170 |
+
overall_start_time = time.time()
|
171 |
+
print(f"Gradio load_models_data called with force_refresh={force_refresh}")
|
172 |
+
|
173 |
+
expected_cols_in_processed_parquet = [
|
174 |
+
'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params',
|
175 |
+
'size_category', 'organization', 'has_audio', 'has_speech', 'has_music',
|
176 |
+
'has_robot', 'has_bio', 'has_med', 'has_series', 'has_video', 'has_image',
|
177 |
+
'has_text', 'has_science', 'is_audio_speech', 'is_biomed',
|
178 |
+
'data_download_timestamp'
|
179 |
+
]
|
180 |
+
|
181 |
+
if not force_refresh and os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
182 |
+
print(f"Attempting to load pre-processed data from: {PROCESSED_PARQUET_FILE_PATH}")
|
183 |
+
try:
|
184 |
+
df = pd.read_parquet(PROCESSED_PARQUET_FILE_PATH)
|
185 |
+
elapsed = time.time() - overall_start_time
|
186 |
+
missing_cols = [col for col in expected_cols_in_processed_parquet if col not in df.columns]
|
187 |
+
if missing_cols:
|
188 |
+
raise ValueError(f"Pre-processed Parquet is missing columns: {missing_cols}. Please run preprocessor or refresh data in app.")
|
189 |
+
|
190 |
+
# --- Diagnostic for 'has_robot' after loading parquet ---
|
191 |
+
if 'has_robot' in df.columns:
|
192 |
+
robot_count_parquet = df['has_robot'].sum()
|
193 |
+
print(f"DIAGNOSTIC (App - Parquet Load): 'has_robot' column found. Number of True values: {robot_count_parquet}")
|
194 |
+
if 0 < robot_count_parquet < 10:
|
195 |
+
print(f"Sample 'has_robot' models (from parquet): {df[df['has_robot']]['id'].head().tolist()}")
|
196 |
+
else:
|
197 |
+
print("DIAGNOSTIC (App - Parquet Load): 'has_robot' column NOT FOUND.")
|
198 |
+
# --- End Diagnostic ---
|
199 |
+
|
200 |
+
msg = f"Successfully loaded pre-processed data in {elapsed:.2f}s. Shape: {df.shape}"
|
201 |
+
print(msg)
|
202 |
+
return df, True, msg
|
203 |
+
except Exception as e:
|
204 |
+
print(f"Could not load pre-processed Parquet: {e}. ")
|
205 |
+
if force_refresh: print("Proceeding to fetch fresh data as force_refresh=True.")
|
206 |
+
else:
|
207 |
+
err_msg = (f"Pre-processed data could not be loaded: {e}. "
|
208 |
+
"Please use 'Refresh Data from Hugging Face' button.")
|
209 |
+
return pd.DataFrame(), False, err_msg
|
210 |
+
|
211 |
+
df_raw = None
|
212 |
+
raw_data_source_msg = ""
|
213 |
+
if force_refresh:
|
214 |
+
print("force_refresh=True (Gradio). Fetching fresh data...")
|
215 |
+
fetch_start = time.time()
|
216 |
+
try:
|
217 |
+
query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')" # Ensure HF_PARQUET_URL is defined
|
218 |
+
df_raw = duckdb.sql(query).df()
|
219 |
+
if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
|
220 |
+
raw_data_source_msg = f"Fetched by Gradio in {time.time() - fetch_start:.2f}s. Rows: {len(df_raw)}"
|
221 |
+
print(raw_data_source_msg)
|
222 |
+
except Exception as e_hf:
|
223 |
+
return pd.DataFrame(), False, f"Fatal error fetching from Hugging Face (Gradio): {e_hf}"
|
224 |
+
else:
|
225 |
+
err_msg = (f"Pre-processed data '{PROCESSED_PARQUET_FILE_PATH}' not found/invalid. "
|
226 |
+
"Run preprocessor or use 'Refresh Data' button.")
|
227 |
+
return pd.DataFrame(), False, err_msg
|
228 |
+
|
229 |
+
print(f"Initiating processing for data newly fetched by Gradio. {raw_data_source_msg}")
|
230 |
+
df = pd.DataFrame()
|
231 |
+
proc_start = time.time()
|
232 |
|
233 |
+
core_cols = {'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
|
234 |
+
'pipeline_tag': str, 'tags': object, 'safetensors': object}
|
235 |
+
for col, dtype in core_cols.items():
|
236 |
+
if col in df_raw.columns:
|
237 |
+
df[col] = df_raw[col]
|
238 |
+
if dtype == float: df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
|
239 |
+
elif dtype == str: df[col] = df[col].astype(str).fillna('')
|
240 |
+
else:
|
241 |
+
if col in ['downloads', 'downloadsAllTime', 'likes']: df[col] = 0.0
|
242 |
+
elif col == 'pipeline_tag': df[col] = ''
|
243 |
+
elif col == 'tags': df[col] = pd.Series([[] for _ in range(len(df_raw))])
|
244 |
+
elif col == 'safetensors': df[col] = None
|
245 |
+
elif col == 'id': return pd.DataFrame(), False, "Critical: 'id' column missing."
|
|
|
|
|
|
|
|
|
246 |
|
247 |
+
output_filesize_col_name = 'params'
|
248 |
+
if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
|
249 |
+
df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
|
250 |
+
elif 'safetensors' in df.columns:
|
251 |
+
safetensors_iter = df['safetensors']
|
252 |
+
if tqdm_cls != tqdm :
|
253 |
+
safetensors_iter = tqdm_cls(df['safetensors'], desc="Extracting model sizes (GB)")
|
254 |
+
df[output_filesize_col_name] = [extract_model_size(s) for s in safetensors_iter]
|
255 |
+
df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
|
256 |
+
else:
|
257 |
+
df[output_filesize_col_name] = 0.0
|
258 |
+
|
259 |
+
def get_size_category_gradio(size_gb_val):
|
260 |
+
try: numeric_size_gb = float(size_gb_val)
|
261 |
+
except (ValueError, TypeError): numeric_size_gb = 0.0
|
262 |
+
if pd.isna(numeric_size_gb): numeric_size_gb = 0.0
|
263 |
+
if 0 <= numeric_size_gb < 1: return "Small (<1GB)"
|
264 |
+
elif 1 <= numeric_size_gb < 5: return "Medium (1-5GB)"
|
265 |
+
elif 5 <= numeric_size_gb < 20: return "Large (5-20GB)"
|
266 |
+
elif 20 <= numeric_size_gb < 50: return "X-Large (20-50GB)"
|
267 |
+
elif numeric_size_gb >= 50: return "XX-Large (>50GB)"
|
268 |
+
else: return "Small (<1GB)"
|
269 |
+
df['size_category'] = df[output_filesize_col_name].apply(get_size_category_gradio)
|
270 |
+
|
271 |
+
df['tags'] = process_tags_for_series(df['tags'])
|
272 |
+
df['temp_tags_joined'] = df['tags'].apply(
|
273 |
+
lambda tl: '~~~'.join(str(t).lower() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
|
274 |
+
)
|
275 |
+
tag_map = {
|
276 |
+
'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
|
277 |
+
'has_robot': ['robot', 'robotics'],
|
278 |
+
'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
|
279 |
+
'has_series': ['series', 'time-series', 'timeseries'],
|
280 |
+
'has_video': ['video'], 'has_image': ['image', 'vision'],
|
281 |
+
'has_text': ['text', 'nlp', 'llm']
|
282 |
+
}
|
283 |
+
for col, kws in tag_map.items():
|
284 |
+
pattern = '|'.join(kws)
|
285 |
+
df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
|
286 |
+
df['has_science'] = (
|
287 |
+
df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
|
288 |
+
~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
|
289 |
+
)
|
290 |
+
del df['temp_tags_joined']
|
291 |
+
df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
|
292 |
+
df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
|
293 |
+
df['is_biomed'] = df['has_bio'] | df['has_med']
|
294 |
+
df['organization'] = df['id'].apply(extract_org_from_id)
|
295 |
+
|
296 |
+
if 'safetensors' in df.columns and \
|
297 |
+
not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
|
298 |
+
df = df.drop(columns=['safetensors'], errors='ignore')
|
|
|
299 |
|
300 |
+
# --- Diagnostic for 'has_robot' after app-side processing (force_refresh path) ---
|
301 |
+
if force_refresh and 'has_robot' in df.columns:
|
302 |
+
robot_count_app_proc = df['has_robot'].sum()
|
303 |
+
print(f"DIAGNOSTIC (App - Force Refresh Processing): 'has_robot' column processed. Number of True values: {robot_count_app_proc}")
|
304 |
+
if 0 < robot_count_app_proc < 10:
|
305 |
+
print(f"Sample 'has_robot' models (App processed): {df[df['has_robot']]['id'].head().tolist()}")
|
306 |
+
# --- End Diagnostic ---
|
307 |
+
|
308 |
+
print(f"Data processing by Gradio completed in {time.time() - proc_start:.2f}s.")
|
309 |
+
|
310 |
+
total_elapsed = time.time() - overall_start_time
|
311 |
+
final_msg = f"{raw_data_source_msg}. Processing by Gradio took {time.time() - proc_start:.2f}s. Total: {total_elapsed:.2f}s. Shape: {df.shape}"
|
312 |
+
print(final_msg)
|
313 |
+
return df, True, final_msg
|
314 |
|
|
|
|
|
|
|
|
|
|
|
315 |
|
316 |
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
|
317 |
+
if df is None or df.empty: return pd.DataFrame()
|
|
|
318 |
filtered_df = df.copy()
|
319 |
+
col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
|
320 |
+
"Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
|
321 |
+
"Video": "has_video", "Images": "has_image", "Text": "has_text"}
|
322 |
|
323 |
+
# --- Diagnostic within make_treemap_data ---
|
324 |
+
if 'has_robot' in filtered_df.columns:
|
325 |
+
initial_robot_count = filtered_df['has_robot'].sum()
|
326 |
+
print(f"DIAGNOSTIC (make_treemap_data entry): Input df has {initial_robot_count} 'has_robot' models.")
|
327 |
+
else:
|
328 |
+
print("DIAGNOSTIC (make_treemap_data entry): 'has_robot' column NOT in input df.")
|
329 |
+
# --- End Diagnostic ---
|
330 |
+
|
331 |
+
if tag_filter and tag_filter in col_map:
|
332 |
+
target_col = col_map[tag_filter]
|
333 |
+
if target_col in filtered_df.columns:
|
334 |
+
# --- Diagnostic for specific 'Robotics' filter application ---
|
335 |
+
if tag_filter == "Robotics":
|
336 |
+
count_before_robot_filter = filtered_df[target_col].sum()
|
337 |
+
print(f"DIAGNOSTIC (make_treemap_data): Applying 'Robotics' filter. Models with '{target_col}'=True before this filter step: {count_before_robot_filter}")
|
338 |
+
# --- End Diagnostic ---
|
339 |
+
filtered_df = filtered_df[filtered_df[target_col]]
|
340 |
+
if tag_filter == "Robotics":
|
341 |
+
print(f"DIAGNOSTIC (make_treemap_data): After 'Robotics' filter ({target_col}), df rows: {len(filtered_df)}")
|
342 |
+
else:
|
343 |
+
print(f"Warning: Tag filter column '{col_map[tag_filter]}' not found in DataFrame.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
if pipeline_filter:
|
345 |
+
if "pipeline_tag" in filtered_df.columns:
|
346 |
+
filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
|
347 |
+
else:
|
348 |
+
print(f"Warning: 'pipeline_tag' column not found for filtering.")
|
349 |
+
if size_filter and size_filter != "None" and size_filter in MODEL_SIZE_RANGES.keys():
|
350 |
+
if 'size_category' in filtered_df.columns:
|
351 |
+
filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
|
352 |
+
else:
|
353 |
+
print("Warning: 'size_category' column not found for filtering.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
if skip_orgs and len(skip_orgs) > 0:
|
355 |
+
if "organization" in filtered_df.columns:
|
356 |
+
filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
|
357 |
+
else:
|
358 |
+
print("Warning: 'organization' column not found for filtering.")
|
359 |
+
if filtered_df.empty: return pd.DataFrame()
|
360 |
+
if count_by not in filtered_df.columns or not pd.api.types.is_numeric_dtype(filtered_df[count_by]):
|
361 |
+
filtered_df[count_by] = pd.to_numeric(filtered_df.get(count_by), errors="coerce").fillna(0.0)
|
362 |
+
org_totals = filtered_df.groupby("organization")[count_by].sum().nlargest(top_k, keep='first')
|
363 |
+
top_orgs_list = org_totals.index.tolist()
|
364 |
+
treemap_data = filtered_df[filtered_df["organization"].isin(top_orgs_list)][["id", "organization", count_by]].copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
treemap_data["root"] = "models"
|
366 |
+
treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0.0)
|
|
|
|
|
|
|
|
|
367 |
return treemap_data
|
368 |
|
369 |
def create_treemap(treemap_data, count_by, title=None):
|
|
|
370 |
if treemap_data.empty:
|
371 |
+
fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
|
372 |
+
fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
return fig
|
|
|
|
|
374 |
fig = px.treemap(
|
375 |
+
treemap_data, path=["root", "organization", "id"], values=count_by,
|
|
|
|
|
376 |
title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
|
377 |
color_discrete_sequence=px.colors.qualitative.Plotly
|
378 |
)
|
379 |
+
fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
|
380 |
+
fig.update_traces(textinfo="label+value+percent root", hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
return fig
|
382 |
|
383 |
+
with gr.Blocks(title="HuggingFace Model Explorer", fill_width=True) as demo:
|
384 |
+
models_data_state = gr.State(pd.DataFrame())
|
385 |
+
loading_complete_state = gr.State(False)
|
386 |
+
|
387 |
+
with gr.Row(): gr.Markdown("# HuggingFace Models TreeMap Visualization")
|
388 |
+
with gr.Row():
|
389 |
+
with gr.Column(scale=1):
|
390 |
+
count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
|
391 |
+
filter_choice_radio = gr.Radio(label="Filter Type", choices=["None", "Tag Filter", "Pipeline Filter"], value="None")
|
392 |
+
tag_filter_dropdown = gr.Dropdown(label="Select Tag", choices=TAG_FILTER_CHOICES, value=None, visible=False)
|
393 |
+
pipeline_filter_dropdown = gr.Dropdown(label="Select Pipeline Tag", choices=PIPELINE_TAGS, value=None, visible=False)
|
394 |
+
size_filter_dropdown = gr.Dropdown(label="Model Size Filter", choices=["None"] + list(MODEL_SIZE_RANGES.keys()), value="None")
|
395 |
+
top_k_slider = gr.Slider(label="Number of Top Organizations", minimum=5, maximum=50, value=25, step=5)
|
396 |
+
skip_orgs_textbox = gr.Textbox(label="Organizations to Skip (comma-separated)", value="TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski")
|
397 |
+
generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)
|
398 |
+
refresh_data_button = gr.Button(value="Refresh Data from Hugging Face", variant="secondary")
|
399 |
+
with gr.Column(scale=3):
|
400 |
+
plot_output = gr.Plot()
|
401 |
+
status_message_md = gr.Markdown("Initializing...")
|
402 |
+
data_info_md = gr.Markdown("")
|
403 |
+
|
404 |
+
def _update_button_interactivity(is_loaded_flag):
|
405 |
+
return gr.update(interactive=is_loaded_flag)
|
406 |
+
loading_complete_state.change(fn=_update_button_interactivity, inputs=loading_complete_state, outputs=generate_plot_button)
|
407 |
+
|
408 |
+
def _toggle_filters_visibility(choice):
|
409 |
+
return gr.update(visible=choice == "Tag Filter"), gr.update(visible=choice == "Pipeline Filter")
|
410 |
+
filter_choice_radio.change(fn=_toggle_filters_visibility, inputs=filter_choice_radio, outputs=[tag_filter_dropdown, pipeline_filter_dropdown])
|
411 |
+
|
412 |
+
def ui_load_data_controller(force_refresh_ui_trigger=False, progress=gr.Progress(track_tqdm=True)):
|
413 |
+
print(f"ui_load_data_controller called with force_refresh_ui_trigger={force_refresh_ui_trigger}")
|
414 |
+
status_msg_ui = "Loading data..."
|
415 |
+
data_info_text = ""
|
416 |
+
current_df = pd.DataFrame()
|
417 |
+
load_success_flag = False
|
418 |
+
data_as_of_date_display = "N/A"
|
419 |
try:
|
420 |
+
current_df, load_success_flag, status_msg_from_load = load_models_data(
|
421 |
+
force_refresh=force_refresh_ui_trigger, tqdm_cls=progress.tqdm
|
422 |
+
)
|
423 |
+
if load_success_flag:
|
424 |
+
if force_refresh_ui_trigger:
|
425 |
+
data_as_of_date_display = pd.Timestamp.now(tz='UTC').strftime('%B %d, %Y, %H:%M:%S %Z')
|
426 |
+
elif 'data_download_timestamp' in current_df.columns and not current_df.empty and pd.notna(current_df['data_download_timestamp'].iloc[0]):
|
427 |
+
timestamp_from_parquet = pd.to_datetime(current_df['data_download_timestamp'].iloc[0])
|
428 |
+
if timestamp_from_parquet.tzinfo is None:
|
429 |
+
timestamp_from_parquet = timestamp_from_parquet.tz_localize('UTC')
|
430 |
+
data_as_of_date_display = timestamp_from_parquet.strftime('%B %d, %Y, %H:%M:%S %Z')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
else:
|
432 |
+
data_as_of_date_display = "Pre-processed (date unavailable)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
|
434 |
+
size_dist_lines = []
|
435 |
+
if 'size_category' in current_df.columns:
|
436 |
+
for cat in MODEL_SIZE_RANGES.keys():
|
437 |
+
count = (current_df['size_category'] == cat).sum()
|
438 |
+
size_dist_lines.append(f" - {cat}: {count:,} models")
|
439 |
+
else: size_dist_lines.append(" - Size category information not available.")
|
440 |
+
size_dist = "\n".join(size_dist_lines)
|
|
|
|
|
441 |
|
442 |
+
data_info_text = (f"### Data Information\n"
|
443 |
+
f"- Overall Status: {status_msg_from_load}\n"
|
444 |
+
f"- Total models loaded: {len(current_df):,}\n"
|
445 |
+
f"- Data as of: {data_as_of_date_display}\n"
|
446 |
+
f"- Size categories:\n{size_dist}")
|
447 |
|
448 |
+
# # --- MODIFICATION: Add 'has_robot' count to UI data_info_text ---
|
449 |
+
# if not current_df.empty and 'has_robot' in current_df.columns:
|
450 |
+
# robot_true_count = current_df['has_robot'].sum()
|
451 |
+
# data_info_text += f"\n- **Models flagged 'has_robot'**: {robot_true_count}"
|
452 |
+
# if 0 < robot_true_count <= 10: # If a few are found, list some IDs
|
453 |
+
# sample_robot_ids = current_df[current_df['has_robot']]['id'].head(5).tolist()
|
454 |
+
# data_info_text += f"\n - Sample 'has_robot' model IDs: `{', '.join(sample_robot_ids)}`"
|
455 |
+
# elif not current_df.empty:
|
456 |
+
# data_info_text += "\n- **Models flagged 'has_robot'**: 'has_robot' column not found in loaded data."
|
457 |
+
# # --- END MODIFICATION ---
|
458 |
+
|
459 |
+
status_msg_ui = "Data loaded successfully. Ready to generate plot."
|
460 |
+
else:
|
461 |
+
data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
|
462 |
+
status_msg_ui = status_msg_from_load
|
463 |
+
except Exception as e:
|
464 |
+
status_msg_ui = f"An unexpected error occurred in ui_load_data_controller: {str(e)}"
|
465 |
+
data_info_text = f"### Critical Error\n- {status_msg_ui}"
|
466 |
+
print(f"Critical error in ui_load_data_controller: {e}")
|
467 |
+
load_success_flag = False
|
468 |
+
return current_df, load_success_flag, data_info_text, status_msg_ui
|
469 |
+
|
470 |
+
def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
|
471 |
+
size_choice, k_orgs, skip_orgs_input, df_current_models):
|
472 |
+
if df_current_models is None or df_current_models.empty:
|
473 |
+
empty_fig = create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded")
|
474 |
+
error_msg = "Model data is not loaded or is empty. Please load or refresh data first."
|
475 |
+
gr.Warning(error_msg)
|
476 |
+
return empty_fig, error_msg
|
477 |
+
tag_to_use = tag_choice if filter_type == "Tag Filter" else None
|
478 |
+
pipeline_to_use = pipeline_choice if filter_type == "Pipeline Filter" else None
|
479 |
+
size_to_use = size_choice if size_choice != "None" else None
|
480 |
+
orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []
|
481 |
|
482 |
+
# --- Diagnostic before calling make_treemap_data ---
|
483 |
+
if 'has_robot' in df_current_models.columns:
|
484 |
+
robot_count_before_treemap = df_current_models['has_robot'].sum()
|
485 |
+
print(f"DIAGNOSTIC (ui_generate_plot_controller): df_current_models entering make_treemap_data has {robot_count_before_treemap} 'has_robot' models.")
|
486 |
+
# --- End Diagnostic ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
|
488 |
+
treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, size_to_use, orgs_to_skip)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
|
490 |
+
title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
|
491 |
+
chart_title = f"HuggingFace Models - {title_labels.get(metric_choice, metric_choice)} by Organization"
|
492 |
+
plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
|
493 |
+
if treemap_df.empty:
|
494 |
+
plot_stats_md = "No data matches the selected filters. Try adjusting your filters."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
else:
|
496 |
+
total_items_in_plot = len(treemap_df['id'].unique())
|
497 |
+
total_value_in_plot = treemap_df[metric_choice].sum()
|
498 |
+
plot_stats_md = (f"## Plot Statistics\n- **Models shown**: {total_items_in_plot:,}\n- **Total {metric_choice}**: {int(total_value_in_plot):,}")
|
499 |
+
return plotly_fig, plot_stats_md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
501 |
demo.load(
|
502 |
+
fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=False, progress=progress),
|
503 |
+
inputs=[],
|
504 |
+
outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
|
505 |
)
|
|
|
|
|
506 |
refresh_data_button.click(
|
507 |
+
fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=True, progress=progress),
|
508 |
+
inputs=[],
|
509 |
+
outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
|
510 |
)
|
|
|
511 |
generate_plot_button.click(
|
512 |
+
fn=ui_generate_plot_controller,
|
513 |
+
inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
|
514 |
+
size_filter_dropdown, top_k_slider, skip_orgs_textbox, models_data_state],
|
515 |
+
outputs=[plot_output, status_message_md]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
)
|
517 |
|
518 |
if __name__ == "__main__":
|
519 |
+
if not os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
520 |
+
print(f"WARNING: Pre-processed data file '{PROCESSED_PARQUET_FILE_PATH}' not found.")
|
521 |
+
print("It is highly recommended to run the preprocessing script (e.g., preprocess.py) first.") # Corrected script name
|
522 |
+
else:
|
523 |
+
print(f"Found pre-processed data file: '{PROCESSED_PARQUET_FILE_PATH}'.")
|
524 |
+
demo.launch()
|
525 |
+
|
526 |
+
# --- END OF FILE app.py ---
|
models_processed.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:998afad6c0c4c64f9e98efd8609d1cbab1dd2ac281b9c2e023878ad436c2fbde
|
3 |
+
size 96033487
|
preprocess.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- START OF FILE preprocess.py ---
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
import ast
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
import duckdb
|
11 |
+
import re # Import re for the manual regex check in debug
|
12 |
+
|
13 |
+
# --- Constants ---
|
14 |
+
PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
|
15 |
+
HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'
|
16 |
+
|
17 |
+
MODEL_SIZE_RANGES = {
|
18 |
+
"Small (<1GB)": (0, 1),
|
19 |
+
"Medium (1-5GB)": (1, 5),
|
20 |
+
"Large (5-20GB)": (5, 20),
|
21 |
+
"X-Large (20-50GB)": (20, 50),
|
22 |
+
"XX-Large (>50GB)": (50, float('inf'))
|
23 |
+
}
|
24 |
+
|
25 |
+
# --- Debugging Constant ---
|
26 |
+
# <<<<<<< SET THE MODEL ID YOU WANT TO DEBUG HERE >>>>>>>
|
27 |
+
MODEL_ID_TO_DEBUG = "openvla/openvla-7b"
|
28 |
+
# Example: MODEL_ID_TO_DEBUG = "openai-community/gpt2"
|
29 |
+
# If you don't have a specific ID, the debug block will just report it's not found.
|
30 |
+
|
31 |
+
# --- Utility Functions (extract_model_file_size_gb, extract_org_from_id, process_tags_for_series, get_file_size_category - unchanged from previous correct version) ---
|
32 |
+
def extract_model_file_size_gb(safetensors_data):
|
33 |
+
try:
|
34 |
+
if pd.isna(safetensors_data): return 0.0
|
35 |
+
data_to_parse = safetensors_data
|
36 |
+
if isinstance(safetensors_data, str):
|
37 |
+
try:
|
38 |
+
if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
|
39 |
+
(safetensors_data.startswith('[') and safetensors_data.endswith(']')):
|
40 |
+
data_to_parse = ast.literal_eval(safetensors_data)
|
41 |
+
else: data_to_parse = json.loads(safetensors_data)
|
42 |
+
except Exception: return 0.0
|
43 |
+
if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
|
44 |
+
total_bytes_val = data_to_parse['total']
|
45 |
+
try:
|
46 |
+
size_bytes = float(total_bytes_val)
|
47 |
+
return size_bytes / (1024 * 1024 * 1024)
|
48 |
+
except (ValueError, TypeError): return 0.0
|
49 |
+
return 0.0
|
50 |
+
except Exception: return 0.0
|
51 |
+
|
52 |
+
def extract_org_from_id(model_id):
|
53 |
+
if pd.isna(model_id): return "unaffiliated"
|
54 |
+
model_id_str = str(model_id)
|
55 |
+
return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
|
56 |
+
|
57 |
+
def process_tags_for_series(series_of_tags_values):
|
58 |
+
processed_tags_accumulator = []
|
59 |
+
|
60 |
+
for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
|
61 |
+
temp_processed_list_for_row = []
|
62 |
+
current_value_for_error_msg = str(tags_value_from_series)[:200] # Truncate for long error messages
|
63 |
+
|
64 |
+
try:
|
65 |
+
# Order of checks is important!
|
66 |
+
# 1. Handle explicit Python lists first
|
67 |
+
if isinstance(tags_value_from_series, list):
|
68 |
+
current_tags_in_list = []
|
69 |
+
for idx_tag, tag_item in enumerate(tags_value_from_series):
|
70 |
+
try:
|
71 |
+
# Ensure item is not NaN before string conversion if it might be a float NaN in a list
|
72 |
+
if pd.isna(tag_item): continue
|
73 |
+
str_tag = str(tag_item)
|
74 |
+
stripped_tag = str_tag.strip()
|
75 |
+
if stripped_tag:
|
76 |
+
current_tags_in_list.append(stripped_tag)
|
77 |
+
except Exception as e_inner_list_proc:
|
78 |
+
print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
|
79 |
+
temp_processed_list_for_row = current_tags_in_list
|
80 |
+
|
81 |
+
# 2. Handle NumPy arrays
|
82 |
+
elif isinstance(tags_value_from_series, np.ndarray):
|
83 |
+
# Convert to list, then process elements, handling potential NaNs within the array
|
84 |
+
current_tags_in_list = []
|
85 |
+
for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()): # .tolist() is crucial
|
86 |
+
try:
|
87 |
+
if pd.isna(tag_item): continue # Check for NaN after converting to Python type
|
88 |
+
str_tag = str(tag_item)
|
89 |
+
stripped_tag = str_tag.strip()
|
90 |
+
if stripped_tag:
|
91 |
+
current_tags_in_list.append(stripped_tag)
|
92 |
+
except Exception as e_inner_array_proc:
|
93 |
+
print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
|
94 |
+
temp_processed_list_for_row = current_tags_in_list
|
95 |
+
|
96 |
+
# 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
|
97 |
+
elif tags_value_from_series is None or pd.isna(tags_value_from_series): # Now pd.isna is safe for scalars
|
98 |
+
temp_processed_list_for_row = []
|
99 |
+
|
100 |
+
# 4. Handle strings (could be JSON-like, list-like, or comma-separated)
|
101 |
+
elif isinstance(tags_value_from_series, str):
|
102 |
+
processed_str_tags = []
|
103 |
+
# Attempt ast.literal_eval for strings that look like lists/tuples
|
104 |
+
if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
|
105 |
+
(tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
|
106 |
+
try:
|
107 |
+
evaluated_tags = ast.literal_eval(tags_value_from_series)
|
108 |
+
if isinstance(evaluated_tags, (list, tuple)): # Check if eval result is a list/tuple
|
109 |
+
# Recursively process this evaluated list/tuple, as its elements could be complex
|
110 |
+
# For simplicity here, assume elements are simple strings after eval
|
111 |
+
current_eval_list = []
|
112 |
+
for tag_item in evaluated_tags:
|
113 |
+
if pd.isna(tag_item): continue
|
114 |
+
str_tag = str(tag_item).strip()
|
115 |
+
if str_tag: current_eval_list.append(str_tag)
|
116 |
+
processed_str_tags = current_eval_list
|
117 |
+
except (ValueError, SyntaxError):
|
118 |
+
pass # If ast.literal_eval fails, let it fall to JSON or comma split
|
119 |
+
|
120 |
+
# If ast.literal_eval didn't populate, try JSON
|
121 |
+
if not processed_str_tags:
|
122 |
+
try:
|
123 |
+
json_tags = json.loads(tags_value_from_series)
|
124 |
+
if isinstance(json_tags, list):
|
125 |
+
# Similar to above, assume elements are simple strings after JSON parsing
|
126 |
+
current_json_list = []
|
127 |
+
for tag_item in json_tags:
|
128 |
+
if pd.isna(tag_item): continue
|
129 |
+
str_tag = str(tag_item).strip()
|
130 |
+
if str_tag: current_json_list.append(str_tag)
|
131 |
+
processed_str_tags = current_json_list
|
132 |
+
except json.JSONDecodeError:
|
133 |
+
# If not a valid JSON list, fall back to comma splitting as the final string strategy
|
134 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
|
135 |
+
except Exception as e_json_other:
|
136 |
+
print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
|
137 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()] # Fallback
|
138 |
+
|
139 |
+
temp_processed_list_for_row = processed_str_tags
|
140 |
+
|
141 |
+
# 5. Fallback for other scalar types (e.g., int, float that are not NaN)
|
142 |
+
else:
|
143 |
+
# This path is for non-list, non-ndarray, non-None/NaN, non-string types.
|
144 |
+
# Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
|
145 |
+
if pd.isna(tags_value_from_series): # Catch any remaining NaNs like float('nan')
|
146 |
+
temp_processed_list_for_row = []
|
147 |
+
else:
|
148 |
+
str_val = str(tags_value_from_series).strip()
|
149 |
+
temp_processed_list_for_row = [str_val] if str_val else []
|
150 |
+
|
151 |
+
processed_tags_accumulator.append(temp_processed_list_for_row)
|
152 |
+
|
153 |
+
except Exception as e_outer_tag_proc:
|
154 |
+
print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
|
155 |
+
processed_tags_accumulator.append([])
|
156 |
+
|
157 |
+
return processed_tags_accumulator
|
158 |
+
|
159 |
+
def get_file_size_category(file_size_gb_val):
|
160 |
+
try:
|
161 |
+
numeric_file_size_gb = float(file_size_gb_val)
|
162 |
+
if pd.isna(numeric_file_size_gb): numeric_file_size_gb = 0.0
|
163 |
+
except (ValueError, TypeError): numeric_file_size_gb = 0.0
|
164 |
+
if 0 <= numeric_file_size_gb < 1: return "Small (<1GB)"
|
165 |
+
elif 1 <= numeric_file_size_gb < 5: return "Medium (1-5GB)"
|
166 |
+
elif 5 <= numeric_file_size_gb < 20: return "Large (5-20GB)"
|
167 |
+
elif 20 <= numeric_file_size_gb < 50: return "X-Large (20-50GB)"
|
168 |
+
elif numeric_file_size_gb >= 50: return "XX-Large (>50GB)"
|
169 |
+
else: return "Small (<1GB)"
|
170 |
+
|
171 |
+
|
172 |
+
def main_preprocessor():
|
173 |
+
print(f"Starting pre-processing script. Output: '{PROCESSED_PARQUET_FILE_PATH}'.")
|
174 |
+
overall_start_time = time.time()
|
175 |
+
|
176 |
+
print(f"Fetching fresh data from Hugging Face: {HF_PARQUET_URL}")
|
177 |
+
try:
|
178 |
+
fetch_start_time = time.time()
|
179 |
+
query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
|
180 |
+
df_raw = duckdb.sql(query).df()
|
181 |
+
data_download_timestamp = pd.Timestamp.now(tz='UTC')
|
182 |
+
|
183 |
+
if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
|
184 |
+
if 'id' not in df_raw.columns: raise ValueError("Fetched data must contain 'id' column.")
|
185 |
+
|
186 |
+
print(f"Fetched data in {time.time() - fetch_start_time:.2f}s. Rows: {len(df_raw)}. Downloaded at: {data_download_timestamp.strftime('%Y-%m-%d %H:%M:%S %Z')}")
|
187 |
+
except Exception as e_fetch:
|
188 |
+
print(f"ERROR: Could not fetch data from Hugging Face: {e_fetch}.")
|
189 |
+
return
|
190 |
+
|
191 |
+
df = pd.DataFrame()
|
192 |
+
print("Processing raw data...")
|
193 |
+
proc_start = time.time()
|
194 |
+
|
195 |
+
expected_cols_setup = {
|
196 |
+
'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
|
197 |
+
'pipeline_tag': str, 'tags': object, 'safetensors': object
|
198 |
+
}
|
199 |
+
for col_name, target_dtype in expected_cols_setup.items():
|
200 |
+
if col_name in df_raw.columns:
|
201 |
+
df[col_name] = df_raw[col_name]
|
202 |
+
if target_dtype == float: df[col_name] = pd.to_numeric(df[col_name], errors='coerce').fillna(0.0)
|
203 |
+
elif target_dtype == str: df[col_name] = df[col_name].astype(str).fillna('')
|
204 |
+
else:
|
205 |
+
if col_name in ['downloads', 'downloadsAllTime', 'likes']: df[col_name] = 0.0
|
206 |
+
elif col_name == 'pipeline_tag': df[col_name] = ''
|
207 |
+
elif col_name == 'tags': df[col_name] = pd.Series([[] for _ in range(len(df_raw))]) # Initialize with empty lists
|
208 |
+
elif col_name == 'safetensors': df[col_name] = None # Initialize with None
|
209 |
+
elif col_name == 'id': print("CRITICAL ERROR: 'id' column missing."); return
|
210 |
+
|
211 |
+
output_filesize_col_name = 'params'
|
212 |
+
if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
|
213 |
+
print(f"Using pre-existing '{output_filesize_col_name}' column as file size in GB.")
|
214 |
+
df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
|
215 |
+
elif 'safetensors' in df.columns:
|
216 |
+
print(f"Calculating '{output_filesize_col_name}' (file size in GB) from 'safetensors' data...")
|
217 |
+
df[output_filesize_col_name] = df['safetensors'].apply(extract_model_file_size_gb)
|
218 |
+
df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
|
219 |
+
else:
|
220 |
+
print(f"Cannot determine file size. Setting '{output_filesize_col_name}' to 0.0.")
|
221 |
+
df[output_filesize_col_name] = 0.0
|
222 |
+
|
223 |
+
df['data_download_timestamp'] = data_download_timestamp
|
224 |
+
print(f"Added 'data_download_timestamp' column.")
|
225 |
+
|
226 |
+
print("Categorizing models by file size...")
|
227 |
+
df['size_category'] = df[output_filesize_col_name].apply(get_file_size_category)
|
228 |
+
|
229 |
+
print("Standardizing 'tags' column...")
|
230 |
+
df['tags'] = process_tags_for_series(df['tags']) # This now uses tqdm internally
|
231 |
+
|
232 |
+
# --- START DEBUGGING BLOCK ---
|
233 |
+
# This block will execute before the main tag processing loop
|
234 |
+
if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values: # Check if ID exists
|
235 |
+
print(f"\n--- Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---")
|
236 |
+
|
237 |
+
# 1. Check the 'tags' column content after process_tags_for_series
|
238 |
+
model_specific_tags_list = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]
|
239 |
+
print(f"1. Tags from df['tags'] (after process_tags_for_series): {model_specific_tags_list}")
|
240 |
+
print(f" Type of tags: {type(model_specific_tags_list)}")
|
241 |
+
if isinstance(model_specific_tags_list, list):
|
242 |
+
for i, tag_item in enumerate(model_specific_tags_list):
|
243 |
+
print(f" Tag item {i}: '{tag_item}' (type: {type(tag_item)}, len: {len(str(tag_item))})")
|
244 |
+
# Detailed check for 'robotics' specifically
|
245 |
+
if 'robotics' in str(tag_item).lower():
|
246 |
+
print(f" DEBUG: Found 'robotics' substring in '{tag_item}'")
|
247 |
+
print(f" - str(tag_item).lower().strip(): '{str(tag_item).lower().strip()}'")
|
248 |
+
print(f" - Is it exactly 'robotics'?: {str(tag_item).lower().strip() == 'robotics'}")
|
249 |
+
print(f" - Ordinals: {[ord(c) for c in str(tag_item)]}")
|
250 |
+
|
251 |
+
# 2. Simulate temp_tags_joined for this specific model
|
252 |
+
if isinstance(model_specific_tags_list, list):
|
253 |
+
simulated_temp_tags_joined = '~~~'.join(str(t).lower().strip() for t in model_specific_tags_list if pd.notna(t) and str(t).strip())
|
254 |
+
else:
|
255 |
+
simulated_temp_tags_joined = ''
|
256 |
+
print(f"2. Simulated 'temp_tags_joined' for this model: '{simulated_temp_tags_joined}'")
|
257 |
+
|
258 |
+
# 3. Simulate 'has_robot' check for this model
|
259 |
+
robot_keywords = ['robot', 'robotics']
|
260 |
+
robot_pattern = '|'.join(robot_keywords)
|
261 |
+
manual_robot_check = bool(re.search(robot_pattern, simulated_temp_tags_joined, flags=re.IGNORECASE))
|
262 |
+
print(f"3. Manual regex check for 'has_robot' ('{robot_pattern}' in '{simulated_temp_tags_joined}'): {manual_robot_check}")
|
263 |
+
print(f"--- End Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---\n")
|
264 |
+
elif MODEL_ID_TO_DEBUG:
|
265 |
+
print(f"DEBUG: Model ID '{MODEL_ID_TO_DEBUG}' not found in DataFrame for pre-loop debugging.")
|
266 |
+
# --- END DEBUGGING BLOCK ---
|
267 |
+
|
268 |
+
|
269 |
+
print("Vectorized creation of cached tag columns...")
|
270 |
+
tag_time = time.time()
|
271 |
+
# This is the original temp_tags_joined creation:
|
272 |
+
df['temp_tags_joined'] = df['tags'].apply(
|
273 |
+
lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
|
274 |
+
)
|
275 |
+
|
276 |
+
tag_map = {
|
277 |
+
'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
|
278 |
+
'has_robot': ['robot', 'robotics','openvla','vla'],
|
279 |
+
'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
|
280 |
+
'has_series': ['series', 'time-series', 'timeseries'],
|
281 |
+
'has_video': ['video'], 'has_image': ['image', 'vision'],
|
282 |
+
'has_text': ['text', 'nlp', 'llm']
|
283 |
+
}
|
284 |
+
for col, kws in tag_map.items():
|
285 |
+
pattern = '|'.join(kws)
|
286 |
+
df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
|
287 |
+
|
288 |
+
df['has_science'] = (
|
289 |
+
df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
|
290 |
+
~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
|
291 |
+
)
|
292 |
+
del df['temp_tags_joined'] # Clean up temporary column
|
293 |
+
df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
|
294 |
+
df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
|
295 |
+
df['is_biomed'] = df['has_bio'] | df['has_med']
|
296 |
+
print(f"Vectorized tag columns created in {time.time() - tag_time:.2f}s.")
|
297 |
+
|
298 |
+
# --- POST-LOOP DIAGNOSTIC for has_robot & a specific model ---
|
299 |
+
if 'has_robot' in df.columns:
|
300 |
+
print("\n--- 'has_robot' Diagnostics (Preprocessor - Post-Loop) ---")
|
301 |
+
print(df['has_robot'].value_counts(dropna=False))
|
302 |
+
|
303 |
+
if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values:
|
304 |
+
model_has_robot_val = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'has_robot'].iloc[0]
|
305 |
+
print(f"Value of 'has_robot' for model '{MODEL_ID_TO_DEBUG}': {model_has_robot_val}")
|
306 |
+
if model_has_robot_val:
|
307 |
+
print(f" Original tags for '{MODEL_ID_TO_DEBUG}': {df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]}")
|
308 |
+
|
309 |
+
if df['has_robot'].any():
|
310 |
+
print("Sample models flagged as 'has_robot':")
|
311 |
+
print(df[df['has_robot']][['id', 'tags', 'has_robot']].head(5))
|
312 |
+
else:
|
313 |
+
print("No models were flagged as 'has_robot' after processing.")
|
314 |
+
print("--------------------------------------------------------\n")
|
315 |
+
# --- END POST-LOOP DIAGNOSTIC ---
|
316 |
+
|
317 |
+
|
318 |
+
print("Adding organization column...")
|
319 |
+
df['organization'] = df['id'].apply(extract_org_from_id)
|
320 |
+
|
321 |
+
# Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
|
322 |
+
if 'safetensors' in df.columns and \
|
323 |
+
not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
|
324 |
+
df = df.drop(columns=['safetensors'], errors='ignore')
|
325 |
+
|
326 |
+
final_expected_cols = [
|
327 |
+
'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags',
|
328 |
+
'params', 'size_category', 'organization',
|
329 |
+
'has_audio', 'has_speech', 'has_music', 'has_robot', 'has_bio', 'has_med',
|
330 |
+
'has_series', 'has_video', 'has_image', 'has_text', 'has_science',
|
331 |
+
'is_audio_speech', 'is_biomed',
|
332 |
+
'data_download_timestamp'
|
333 |
+
]
|
334 |
+
# Ensure all final columns exist, adding defaults if necessary
|
335 |
+
for col in final_expected_cols:
|
336 |
+
if col not in df.columns:
|
337 |
+
print(f"Warning: Final expected column '{col}' is missing! Defaulting appropriately.")
|
338 |
+
if col == 'params': df[col] = 0.0
|
339 |
+
elif col == 'size_category': df[col] = "Small (<1GB)" # Default size category
|
340 |
+
elif 'has_' in col or 'is_' in col : df[col] = False # Default boolean flags to False
|
341 |
+
elif col == 'data_download_timestamp': df[col] = pd.NaT # Default timestamp to NaT
|
342 |
+
|
343 |
+
print(f"Data processing completed in {time.time() - proc_start:.2f}s.")
|
344 |
+
try:
|
345 |
+
print(f"Saving processed data to: {PROCESSED_PARQUET_FILE_PATH}")
|
346 |
+
df_to_save = df[final_expected_cols].copy() # Ensure only expected columns are saved
|
347 |
+
df_to_save.to_parquet(PROCESSED_PARQUET_FILE_PATH, index=False, engine='pyarrow')
|
348 |
+
print(f"Successfully saved processed data.")
|
349 |
+
except Exception as e_save:
|
350 |
+
print(f"ERROR: Could not save processed data: {e_save}")
|
351 |
+
return
|
352 |
+
|
353 |
+
total_elapsed_script = time.time() - overall_start_time
|
354 |
+
print(f"Pre-processing finished. Total time: {total_elapsed_script:.2f}s. Final Parquet shape: {df_to_save.shape}")
|
355 |
+
|
356 |
+
if __name__ == "__main__":
|
357 |
+
if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
358 |
+
print(f"Deleting existing '{PROCESSED_PARQUET_FILE_PATH}' to ensure fresh processing...")
|
359 |
+
try: os.remove(PROCESSED_PARQUET_FILE_PATH)
|
360 |
+
except OSError as e: print(f"Error deleting file: {e}. Please delete manually and rerun."); exit()
|
361 |
+
|
362 |
+
main_preprocessor()
|
363 |
+
|
364 |
+
if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
365 |
+
print(f"\nTo verify, load parquet and check 'has_robot' and its 'tags':")
|
366 |
+
print(f"import pandas as pd; df_chk = pd.read_parquet('{PROCESSED_PARQUET_FILE_PATH}')")
|
367 |
+
print(f"print(df_chk['has_robot'].value_counts())")
|
368 |
+
if MODEL_ID_TO_DEBUG:
|
369 |
+
print(f"print(df_chk[df_chk['id'] == '{MODEL_ID_TO_DEBUG}'][['id', 'tags', 'has_robot']])")
|
370 |
+
else:
|
371 |
+
print(f"print(df_chk[df_chk['has_robot']][['id', 'tags', 'has_robot']].head())")
|