evijit (HF Staff) committed
Commit 9d2f4f2 · verified · 1 Parent(s): 3610b1c

carryover from evijit

Files changed (3):
  1. app.py +468 -788
  2. models_processed.parquet +3 -0
  3. preprocess.py +371 -0
app.py CHANGED
@@ -1,846 +1,526 @@
1
  import json
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
5
  import os
6
  import numpy as np
7
- import io
8
  import duckdb
 
 
 
9
 
10
- # Define pipeline tags
11
- PIPELINE_TAGS = [
12
- 'text-generation',
13
- 'text-to-image',
14
- 'text-classification',
15
- 'text2text-generation',
16
- 'audio-to-audio',
17
- 'feature-extraction',
18
- 'image-classification',
19
- 'translation',
20
- 'reinforcement-learning',
21
- 'fill-mask',
22
- 'text-to-speech',
23
- 'automatic-speech-recognition',
24
- 'image-text-to-text',
25
- 'token-classification',
26
- 'sentence-similarity',
27
- 'question-answering',
28
- 'image-feature-extraction',
29
- 'summarization',
30
- 'zero-shot-image-classification',
31
- 'object-detection',
32
- 'image-segmentation',
33
- 'image-to-image',
34
- 'image-to-text',
35
- 'audio-classification',
36
- 'visual-question-answering',
37
- 'text-to-video',
38
- 'zero-shot-classification',
39
- 'depth-estimation',
40
- 'text-ranking',
41
- 'image-to-video',
42
- 'multiple-choice',
43
- 'unconditional-image-generation',
44
- 'video-classification',
45
- 'text-to-audio',
46
- 'time-series-forecasting',
47
- 'any-to-any',
48
- 'video-text-to-text',
49
- 'table-question-answering',
50
- ]
51
-
52
- # Model size categories in GB
53
  MODEL_SIZE_RANGES = {
54
- "Small (<1GB)": (0, 1),
55
- "Medium (1-5GB)": (1, 5),
56
- "Large (5-20GB)": (5, 20),
57
- "X-Large (20-50GB)": (20, 50),
58
- "XX-Large (>50GB)": (50, float('inf'))
59
  }
 
 
60
 
61
- # Filter functions for tags - UPDATED to use cached columns
62
- def is_audio_speech(row):
63
- # Use cached column instead of recalculating
64
- return row['is_audio_speech']
65
-
66
- def is_music(row):
67
- # Use cached column instead of recalculating
68
- return row['has_music']
69
-
70
- def is_robotics(row):
71
- # Use cached column instead of recalculating
72
- return row['has_robot']
73
 
74
- def is_biomed(row):
75
- # Use cached column instead of recalculating
76
- return row['is_biomed']
77
 
78
- def is_timeseries(row):
79
- # Use cached column instead of recalculating
80
- return row['has_series']
81
 
82
- def is_science(row):
83
- # Use cached column instead of recalculating
84
- return row['has_science']
 
85
 
86
- def is_video(row):
87
- # Use cached column instead of recalculating
88
- return row['has_video']
89
 
90
- def is_image(row):
91
- # Use cached column instead of recalculating
92
- return row['has_image']
93
 
94
- def is_text(row):
95
- # Use cached column instead of recalculating
96
- return row['has_text']
97
 
98
- def is_image(row):
99
- tags = row.get("tags", [])
100
 
101
- # Check if tags exists and is not empty
102
- if tags is not None:
103
- # For numpy arrays
104
- if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
105
- # Convert numpy array to list
106
- tags_list = tags.tolist()
107
- return any("image" in str(tag).lower() for tag in tags_list)
108
- # For regular lists
109
- elif isinstance(tags, list):
110
- return any("image" in str(tag).lower() for tag in tags)
111
- # For string tags
112
- elif isinstance(tags, str):
113
- return "image" in tags.lower()
114
- return False
115
-
116
- def is_text(row):
117
- tags = row.get("tags", [])
118
 
119
- # Check if tags exists and is not empty
120
- if tags is not None:
121
- # For numpy arrays
122
- if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
123
- # Convert numpy array to list
124
- tags_list = tags.tolist()
125
- return any("text" in str(tag).lower() for tag in tags_list)
126
- # For regular lists
127
- elif isinstance(tags, list):
128
- return any("text" in str(tag).lower() for tag in tags)
129
- # For string tags
130
- elif isinstance(tags, str):
131
- return "text" in tags.lower()
132
- return False
133
-
134
- def extract_model_size(safetensors_data):
135
- """Extract model size in GB from safetensors data"""
136
- try:
137
- if pd.isna(safetensors_data):
138
- return 0
139
-
140
- # If it's already a dictionary, use it directly
141
- if isinstance(safetensors_data, dict):
142
- if 'total' in safetensors_data:
143
- try:
144
- size_bytes = float(safetensors_data['total'])
145
- return size_bytes / (1024 * 1024 * 1024) # Convert to GB
146
- except (ValueError, TypeError):
147
- pass
148
-
149
- # If it's a string, try to parse it as JSON
150
- elif isinstance(safetensors_data, str):
151
- try:
152
- data_dict = json.loads(safetensors_data)
153
- if 'total' in data_dict:
154
- try:
155
- size_bytes = float(data_dict['total'])
156
- return size_bytes / (1024 * 1024 * 1024) # Convert to GB
157
- except (ValueError, TypeError):
158
- pass
159
- except:
160
- pass
161
-
162
- return 0
163
- except Exception as e:
164
- print(f"Error extracting model size: {e}")
165
- return 0
166
-
167
- # Add model size filter function - UPDATED to use cached size_category column
168
- def is_in_size_range(row, size_range):
169
- """Check if a model is in the specified size range using pre-calculated size category"""
170
- if size_range is None or size_range == "None":
171
- return True
172
 
173
- # Simply compare with cached size_category
174
- return row['size_category'] == size_range
175
-
176
- TAG_FILTER_FUNCS = {
177
- "Audio & Speech": is_audio_speech,
178
- "Time series": is_timeseries,
179
- "Robotics": is_robotics,
180
- "Music": is_music,
181
- "Video": is_video,
182
- "Images": is_image,
183
- "Text": is_text,
184
- "Biomedical": is_biomed,
185
- "Sciences": is_science,
186
- }
187
 
188
- def extract_org_from_id(model_id):
189
- """Extract organization name from model ID"""
190
- if "/" in model_id:
191
- return model_id.split("/")[0]
192
- return "unaffiliated"
193
 
194
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
195
- """Process DataFrame into treemap format with filters applied - OPTIMIZED with cached columns"""
196
- # Create a copy to avoid modifying the original
197
  filtered_df = df.copy()
 
 
 
198
 
199
- # Apply filters
200
- filter_stats = {"initial": len(filtered_df)}
201
- start_time = pd.Timestamp.now()
202
-
203
- # Apply tag filter - OPTIMIZED to use cached columns
204
- if tag_filter and tag_filter in TAG_FILTER_FUNCS:
205
- print(f"Applying tag filter: {tag_filter}")
206
-
207
- # Use direct column filtering instead of applying a function to each row
208
- if tag_filter == "Audio & Speech":
209
- filtered_df = filtered_df[filtered_df['is_audio_speech']]
210
- elif tag_filter == "Music":
211
- filtered_df = filtered_df[filtered_df['has_music']]
212
- elif tag_filter == "Robotics":
213
- filtered_df = filtered_df[filtered_df['has_robot']]
214
- elif tag_filter == "Biomedical":
215
- filtered_df = filtered_df[filtered_df['is_biomed']]
216
- elif tag_filter == "Time series":
217
- filtered_df = filtered_df[filtered_df['has_series']]
218
- elif tag_filter == "Sciences":
219
- filtered_df = filtered_df[filtered_df['has_science']]
220
- elif tag_filter == "Video":
221
- filtered_df = filtered_df[filtered_df['has_video']]
222
- elif tag_filter == "Images":
223
- filtered_df = filtered_df[filtered_df['has_image']]
224
- elif tag_filter == "Text":
225
- filtered_df = filtered_df[filtered_df['has_text']]
226
-
227
- filter_stats["after_tag_filter"] = len(filtered_df)
228
- print(f"Tag filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
229
- start_time = pd.Timestamp.now()
230
-
231
- # Apply pipeline filter
232
  if pipeline_filter:
233
- print(f"Applying pipeline filter: {pipeline_filter}")
234
- filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
235
- filter_stats["after_pipeline_filter"] = len(filtered_df)
236
- print(f"Pipeline filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
237
- start_time = pd.Timestamp.now()
238
-
239
- # Apply size filter - OPTIMIZED to use cached size_category column
240
- if size_filter and size_filter in MODEL_SIZE_RANGES:
241
- print(f"Applying size filter: {size_filter}")
242
-
243
- # Use the cached size_category column directly
244
- filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
245
-
246
- # Debug info
247
- print(f"Size filter '{size_filter}' applied.")
248
- print(f"Models after size filter: {len(filtered_df)}")
249
-
250
- filter_stats["after_size_filter"] = len(filtered_df)
251
- print(f"Size filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
252
- start_time = pd.Timestamp.now()
253
-
254
- # Add organization column
255
- filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
256
-
257
- # Skip organizations if specified
258
  if skip_orgs and len(skip_orgs) > 0:
259
- filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
260
- filter_stats["after_skip_orgs"] = len(filtered_df)
261
-
262
- # Print filter stats
263
- print("Filter statistics:")
264
- for stage, count in filter_stats.items():
265
- print(f" {stage}: {count} models")
266
-
267
- # Check if we have any data left
268
- if filtered_df.empty:
269
- print("Warning: No data left after applying filters!")
270
- return pd.DataFrame() # Return empty DataFrame
271
-
272
- # Aggregate by organization
273
- org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
274
- org_totals = org_totals.sort_values(by=count_by, ascending=False)
275
-
276
- # Get top organizations
277
- top_orgs = org_totals.head(top_k)["organization"].tolist()
278
-
279
- # Filter to only include models from top organizations
280
- filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]
281
-
282
- # Prepare data for treemap
283
- treemap_data = filtered_df[["id", "organization", count_by]].copy()
284
-
285
- # Add a root node
286
  treemap_data["root"] = "models"
287
-
288
- # Ensure numeric values
289
- treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
290
-
291
- print(f"Treemap data prepared in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
292
  return treemap_data
293
 
294
  def create_treemap(treemap_data, count_by, title=None):
295
- """Create a Plotly treemap from the prepared data"""
296
  if treemap_data.empty:
297
- # Create an empty figure with a message
298
- fig = px.treemap(
299
- names=["No data matches the selected filters"],
300
- values=[1]
301
- )
302
- fig.update_layout(
303
- title="No data matches the selected filters",
304
- margin=dict(t=50, l=25, r=25, b=25)
305
- )
306
  return fig
307
-
308
- # Create the treemap
309
  fig = px.treemap(
310
- treemap_data,
311
- path=["root", "organization", "id"],
312
- values=count_by,
313
  title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
314
  color_discrete_sequence=px.colors.qualitative.Plotly
315
  )
316
-
317
- # Update layout
318
- fig.update_layout(
319
- margin=dict(t=50, l=25, r=25, b=25)
320
- )
321
-
322
- # Update traces for better readability
323
- fig.update_traces(
324
- textinfo="label+value+percent root",
325
- hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
326
- )
327
-
328
  return fig
329
 
330
- def load_models_data():
331
- """Load models data from Hugging Face using DuckDB with caching for improved performance"""
332
- try:
333
- # The URL to the parquet file
334
- parquet_url = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"
335
-
336
- print("Fetching data from Hugging Face models.parquet...")
337
-
338
- # Based on the column names provided, we can directly select the columns we need
339
- # Note: We need to select safetensors to get the model size information
340
  try:
341
- query = """
342
- SELECT
343
- id,
344
- downloads,
345
- downloadsAllTime,
346
- likes,
347
- pipeline_tag,
348
- tags,
349
- safetensors
350
- FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')
351
- """
352
- df = duckdb.sql(query).df()
353
- except Exception as sql_error:
354
- print(f"Error with specific column selection: {sql_error}")
355
- # Fallback to just selecting everything and then filtering
356
- print("Falling back to select * query...")
357
- query = "SELECT * FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')"
358
- raw_df = duckdb.sql(query).df()
359
-
360
- # Now extract only the columns we need
361
- needed_columns = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
362
- available_columns = set(raw_df.columns)
363
- df = pd.DataFrame()
364
-
365
- # Copy over columns that exist
366
- for col in needed_columns:
367
- if col in available_columns:
368
- df[col] = raw_df[col]
369
  else:
370
- # Create empty columns for missing data
371
- if col in ['downloads', 'downloadsAllTime', 'likes']:
372
- df[col] = 0
373
- elif col == 'pipeline_tag':
374
- df[col] = ''
375
- elif col == 'tags':
376
- df[col] = [[] for _ in range(len(raw_df))]
377
- elif col == 'safetensors':
378
- df[col] = None
379
- elif col == 'id':
380
- # Create IDs based on index if missing
381
- df[col] = [f"model_{i}" for i in range(len(raw_df))]
382
-
383
- print(f"Data fetched successfully. Shape: {df.shape}")
384
-
385
- # Check if safetensors column exists before trying to process it
386
- if 'safetensors' in df.columns:
387
- # Add params column derived from safetensors.total (model size in GB)
388
- df['params'] = df['safetensors'].apply(extract_model_size)
389
-
390
- # Debug model sizes
391
- size_ranges = {
392
- "Small (<1GB)": 0,
393
- "Medium (1-5GB)": 0,
394
- "Large (5-20GB)": 0,
395
- "X-Large (20-50GB)": 0,
396
- "XX-Large (>50GB)": 0
397
- }
398
-
399
- # Count models in each size range
400
- for idx, row in df.iterrows():
401
- size_gb = row['params']
402
- if 0 <= size_gb < 1:
403
- size_ranges["Small (<1GB)"] += 1
404
- elif 1 <= size_gb < 5:
405
- size_ranges["Medium (1-5GB)"] += 1
406
- elif 5 <= size_gb < 20:
407
- size_ranges["Large (5-20GB)"] += 1
408
- elif 20 <= size_gb < 50:
409
- size_ranges["X-Large (20-50GB)"] += 1
410
- elif size_gb >= 50:
411
- size_ranges["XX-Large (>50GB)"] += 1
412
-
413
- print("Model size distribution:")
414
- for size_range, count in size_ranges.items():
415
- print(f" {size_range}: {count} models")
416
-
417
- # CACHE SIZE CATEGORY: Add a size_category column for faster filtering
418
- def get_size_category(size_gb):
419
- if 0 <= size_gb < 1:
420
- return "Small (<1GB)"
421
- elif 1 <= size_gb < 5:
422
- return "Medium (1-5GB)"
423
- elif 5 <= size_gb < 20:
424
- return "Large (5-20GB)"
425
- elif 20 <= size_gb < 50:
426
- return "X-Large (20-50GB)"
427
- elif size_gb >= 50:
428
- return "XX-Large (>50GB)"
429
- return None
430
-
431
- # Add cached size category column
432
- df['size_category'] = df['params'].apply(get_size_category)
433
-
434
- # Remove the safetensors column as we don't need it anymore
435
- df = df.drop(columns=['safetensors'])
436
- else:
437
- # If no safetensors column, add empty params column
438
- df['params'] = 0
439
- df['size_category'] = None
440
-
441
- # Process tags to ensure it's in the right format - FIXED
442
- def process_tags(tags_value):
443
- try:
444
- if pd.isna(tags_value) or tags_value is None:
445
- return []
446
-
447
- # If it's a numpy array, convert to a list of strings
448
- if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
449
- # Note: This is the fix for the error
450
- return [str(tag) for tag in tags_value.tolist()]
451
-
452
- # If already a list, ensure all elements are strings
453
- if isinstance(tags_value, list):
454
- return [str(tag) for tag in tags_value]
455
 
456
- # If string, try to parse as JSON or split by comma
457
- if isinstance(tags_value, str):
458
- try:
459
- tags_list = json.loads(tags_value)
460
- if isinstance(tags_list, list):
461
- return [str(tag) for tag in tags_list]
462
- except:
463
- # Split by comma if JSON parsing fails
464
- return [tag.strip() for tag in tags_value.split(',') if tag.strip()]
465
 
466
- # Last resort, convert to string and return as a single tag
467
- return [str(tags_value)]
 
 
 
468
 
469
- except Exception as e:
470
- print(f"Error processing tags: {e}")
471
- return []
472
 
473
- # Check if tags column exists before trying to process it
474
- if 'tags' in df.columns:
475
- # Process tags column
476
- df['tags'] = df['tags'].apply(process_tags)
477
-
478
- # CACHE TAG CATEGORIES: Pre-calculate tag categories for faster filtering
479
- print("Pre-calculating cached tag categories...")
480
-
481
- # Helper functions to check for specific tags (simplified for caching)
482
- def has_audio_tag(tags):
483
- if tags and isinstance(tags, list):
484
- return any("audio" in str(tag).lower() for tag in tags)
485
- return False
486
-
487
- def has_speech_tag(tags):
488
- if tags and isinstance(tags, list):
489
- return any("speech" in str(tag).lower() for tag in tags)
490
- return False
491
-
492
- def has_music_tag(tags):
493
- if tags and isinstance(tags, list):
494
- return any("music" in str(tag).lower() for tag in tags)
495
- return False
496
-
497
- def has_robot_tag(tags):
498
- if tags and isinstance(tags, list):
499
- return any("robot" in str(tag).lower() for tag in tags)
500
- return False
501
-
502
- def has_bio_tag(tags):
503
- if tags and isinstance(tags, list):
504
- return any("bio" in str(tag).lower() for tag in tags)
505
- return False
506
-
507
- def has_med_tag(tags):
508
- if tags and isinstance(tags, list):
509
- return any("medic" in str(tag).lower() for tag in tags)
510
- return False
511
-
512
- def has_series_tag(tags):
513
- if tags and isinstance(tags, list):
514
- return any("series" in str(tag).lower() for tag in tags)
515
- return False
516
-
517
- def has_science_tag(tags):
518
- if tags and isinstance(tags, list):
519
- return any("science" in str(tag).lower() and "bigscience" not in str(tag).lower() for tag in tags)
520
- return False
521
-
522
- def has_video_tag(tags):
523
- if tags and isinstance(tags, list):
524
- return any("video" in str(tag).lower() for tag in tags)
525
- return False
526
-
527
- def has_image_tag(tags):
528
- if tags and isinstance(tags, list):
529
- return any("image" in str(tag).lower() for tag in tags)
530
- return False
531
-
532
- def has_text_tag(tags):
533
- if tags and isinstance(tags, list):
534
- return any("text" in str(tag).lower() for tag in tags)
535
- return False
536
-
537
- # Add cached columns for tag categories
538
- print("Creating cached tag columns...")
539
- df['has_audio'] = df['tags'].apply(has_audio_tag)
540
- df['has_speech'] = df['tags'].apply(has_speech_tag)
541
- df['has_music'] = df['tags'].apply(has_music_tag)
542
- df['has_robot'] = df['tags'].apply(has_robot_tag)
543
- df['has_bio'] = df['tags'].apply(has_bio_tag)
544
- df['has_med'] = df['tags'].apply(has_med_tag)
545
- df['has_series'] = df['tags'].apply(has_series_tag)
546
- df['has_science'] = df['tags'].apply(has_science_tag)
547
- df['has_video'] = df['tags'].apply(has_video_tag)
548
- df['has_image'] = df['tags'].apply(has_image_tag)
549
- df['has_text'] = df['tags'].apply(has_text_tag)
550
-
551
- # Create combined category flags for faster filtering
552
- df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
553
- df['pipeline_tag'].str.contains('audio', case=False, na=False) |
554
- df['pipeline_tag'].str.contains('speech', case=False, na=False))
555
- df['is_biomed'] = df['has_bio'] | df['has_med']
556
-
557
- print("Cached tag columns created successfully!")
558
- else:
559
- # If no tags column, add empty tags and set all category flags to False
560
- df['tags'] = [[] for _ in range(len(df))]
561
- for col in ['has_audio', 'has_speech', 'has_music', 'has_robot',
562
- 'has_bio', 'has_med', 'has_series', 'has_science',
563
- 'has_video', 'has_image', 'has_text',
564
- 'is_audio_speech', 'is_biomed']:
565
- df[col] = False
566
-
567
- # Fill NaN values
568
- df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)
569
-
570
- # Ensure pipeline_tag is a string
571
- if 'pipeline_tag' in df.columns:
572
- df['pipeline_tag'] = df['pipeline_tag'].fillna('')
573
- else:
574
- df['pipeline_tag'] = ''
575
-
576
- # Make sure all required columns exist
577
- for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
578
- if col not in df.columns:
579
- if col in ['downloads', 'downloadsAllTime', 'likes', 'params']:
580
- df[col] = 0
581
- elif col == 'pipeline_tag':
582
- df[col] = ''
583
- elif col == 'tags':
584
- df[col] = [[] for _ in range(len(df))]
585
- elif col == 'id':
586
- df[col] = [f"model_{i}" for i in range(len(df))]
587
-
588
- print(f"Successfully processed {len(df)} models with cached tag and size information")
589
- return df, True
590
-
591
- except Exception as e:
592
- print(f"Error loading data: {e}")
593
- # Return an empty DataFrame and False to indicate loading failure
594
- return pd.DataFrame(), False
595
-
596
- # Create Gradio interface
597
- with gr.Blocks() as demo:
598
- models_data = gr.State()
599
- loading_complete = gr.State(False) # Flag to indicate data load completion
600
-
601
- with gr.Row():
602
- gr.Markdown("""
603
- # HuggingFace Models TreeMap Visualization
604
-
605
- This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
606
- Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
607
 
608
- The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.
609
-
610
- """)
611
-
612
- with gr.Row():
613
- with gr.Column(scale=1):
614
- count_by_dropdown = gr.Dropdown(
615
- label="Metric",
616
- choices=[
617
- ("Downloads (last 30 days)", "downloads"),
618
- ("Downloads (All Time)", "downloadsAllTime"),
619
- ("Likes", "likes")
620
- ],
621
- value="downloads",
622
- info="Select the metric to determine box sizes"
623
- )
624
-
625
- filter_choice_radio = gr.Radio(
626
- label="Filter Type",
627
- choices=["None", "Tag Filter", "Pipeline Filter"],
628
- value="None",
629
- info="Choose how to filter the models"
630
- )
631
-
632
- tag_filter_dropdown = gr.Dropdown(
633
- label="Select Tag",
634
- choices=list(TAG_FILTER_FUNCS.keys()),
635
- value=None,
636
- visible=False,
637
- info="Filter models by domain/category"
638
- )
639
-
640
- pipeline_filter_dropdown = gr.Dropdown(
641
- label="Select Pipeline Tag",
642
- choices=PIPELINE_TAGS,
643
- value=None,
644
- visible=False,
645
- info="Filter models by specific pipeline"
646
- )
647
-
648
- size_filter_dropdown = gr.Dropdown(
649
- label="Model Size Filter",
650
- choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
651
- value="None",
652
- info="Filter models by their size (using params column)"
653
- )
654
-
655
- top_k_slider = gr.Slider(
656
- label="Number of Top Organizations",
657
- minimum=5,
658
- maximum=50,
659
- value=25,
660
- step=5,
661
- info="Number of top organizations to include"
662
- )
663
-
664
- skip_orgs_textbox = gr.Textbox(
665
- label="Organizations to Skip (comma-separated)",
666
- placeholder="e.g., OpenAI, Google",
667
- value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
668
- )
669
-
670
- generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
671
- refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")
672
-
673
- with gr.Column(scale=3):
674
- plot_output = gr.Plot()
675
- stats_output = gr.Markdown("*Loading data from Hugging Face...*")
676
- data_info = gr.Markdown("")
677
-
678
- # Button enablement after data load
679
- def enable_plot_button(loaded):
680
- return gr.update(interactive=loaded)
681
-
682
- loading_complete.change(
683
- fn=enable_plot_button,
684
- inputs=[loading_complete],
685
- outputs=[generate_plot_button]
686
- )
687
-
688
- # Show/hide tag/pipeline dropdown
689
- def update_filter_visibility(filter_choice):
690
- if filter_choice == "Tag Filter":
691
- return gr.update(visible=True), gr.update(visible=False)
692
- elif filter_choice == "Pipeline Filter":
693
- return gr.update(visible=False), gr.update(visible=True)
694
- else:
695
- return gr.update(visible=False), gr.update(visible=False)
696
-
697
- filter_choice_radio.change(
698
- fn=update_filter_visibility,
699
- inputs=[filter_choice_radio],
700
- outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
701
- )
702
-
703
- # Function to handle data load and provide data info
704
- def load_and_provide_info():
705
- df, success = load_models_data()
706
 
707
- if success:
708
- # Generate information about the loaded data
709
- info_text = f"""
710
- ### Data Information
711
- - **Total models loaded**: {len(df):,}
712
- - **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
713
- - **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)
714
- """
715
-
716
- # Return the data, loading status, and info text
717
- return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"
718
- else:
719
- # Return empty data, failed loading status, and error message
720
- return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"
721
-
722
- # Main generate function
723
- def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
724
- if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
725
- return None, "Error: Data is still loading. Please wait a moment and try again."
726
-
727
- selected_tag_filter = None
728
- selected_pipeline_filter = None
729
- selected_size_filter = None
730
-
731
- if filter_choice == "Tag Filter":
732
- selected_tag_filter = tag_filter
733
- elif filter_choice == "Pipeline Filter":
734
- selected_pipeline_filter = pipeline_filter
735
-
736
- if size_filter != "None":
737
- selected_size_filter = size_filter
738
-
739
- skip_orgs = []
740
- if skip_orgs_text and skip_orgs_text.strip():
741
- skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
742
-
743
- treemap_data = make_treemap_data(
744
- df=data_df,
745
- count_by=count_by,
746
- top_k=top_k,
747
- tag_filter=selected_tag_filter,
748
- pipeline_filter=selected_pipeline_filter,
749
- size_filter=selected_size_filter,
750
- skip_orgs=skip_orgs
751
- )
752
-
753
- title_labels = {
754
- "downloads": "Downloads (last 30 days)",
755
- "downloadsAllTime": "Downloads (All Time)",
756
- "likes": "Likes"
757
- }
758
- title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"
759
-
760
- fig = create_treemap(
761
- treemap_data=treemap_data,
762
- count_by=count_by,
763
- title=title_text
764
- )
765
-
766
- if treemap_data.empty:
767
- stats_md = "No data matches the selected filters."
768
  else:
769
- total_models = len(treemap_data)
770
- total_value = treemap_data[count_by].sum()
771
-
772
- # Get top 5 organizations
773
- top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
774
-
775
- # Get top 5 individual models
776
- top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)
777
-
778
- # Create statistics section
779
- stats_md = f"""
780
- ## Statistics
781
- - **Total models shown**: {total_models:,}
782
- - **Total {count_by}**: {int(total_value):,}
783
-
784
- ## Top Organizations by {count_by.capitalize()}
785
 
786
- | Organization | {count_by.capitalize()} | % of Total |
787
- |--------------|-------------:|----------:|
788
- """
789
-
790
- # Add top organizations to the table
791
- for org, value in top_5_orgs.items():
792
- percentage = (value / total_value) * 100
793
- stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"
794
-
795
- # Add the top models table
796
- stats_md += f"""
797
- ## Top Models by {count_by.capitalize()}
798
-
799
- | Model | {count_by.capitalize()} | % of Total |
800
- |-------|-------------:|----------:|
801
- """
802
-
803
- # Add top models to the table
804
- for _, row in top_5_models.iterrows():
805
- model_id = row["id"]
806
- value = row[count_by]
807
- percentage = (value / total_value) * 100
808
- stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"
809
-
810
- # Add note about skipped organizations if any
811
- if skip_orgs:
812
- stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
813
-
814
- return fig, stats_md
815
-
816
- # Load data at startup
817
  demo.load(
818
- fn=load_and_provide_info,
819
- inputs=[],
820
- outputs=[models_data, loading_complete, data_info, stats_output]
821
  )
822
-
823
- # Refresh data when button is clicked
824
  refresh_data_button.click(
825
- fn=load_and_provide_info,
826
- inputs=[],
827
- outputs=[models_data, loading_complete, data_info, stats_output]
828
  )
829
-
830
  generate_plot_button.click(
831
- fn=generate_plot_on_click,
832
- inputs=[
833
- count_by_dropdown,
834
- filter_choice_radio,
835
- tag_filter_dropdown,
836
- pipeline_filter_dropdown,
837
- size_filter_dropdown,
838
- top_k_slider,
839
- skip_orgs_textbox,
840
- models_data
841
- ],
842
- outputs=[plot_output, stats_output]
843
  )
844
 
845
  if __name__ == "__main__":
846
- demo.launch()
1
+ # --- START OF FILE app.py ---
2
+
3
  import json
4
  import gradio as gr
5
  import pandas as pd
6
  import plotly.express as px
7
  import os
8
  import numpy as np
 
9
  import duckdb
10
+ from tqdm.auto import tqdm # Standard tqdm for console, gr.Progress will track it
11
+ import time
12
+ import ast # For safely evaluating string representations of lists/dicts
13
 
14
+ # --- Constants ---
15
  MODEL_SIZE_RANGES = {
16
+ "Small (<1GB)": (0, 1), "Medium (1-5GB)": (1, 5), "Large (5-20GB)": (5, 20),
17
+ "X-Large (20-50GB)": (20, 50), "XX-Large (>50GB)": (50, float('inf'))
 
 
 
18
  }
19
+ PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
20
+ HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet' # Added for completeness within app.py context
21
 
22
+ TAG_FILTER_CHOICES = [
23
+ "Audio & Speech", "Time series", "Robotics", "Music", "Video", "Images",
24
+ "Text", "Biomedical", "Sciences"
25
+ ]
26
 
27
+ PIPELINE_TAGS = [
28
+ 'text-generation', 'text-to-image', 'text-classification', 'text2text-generation',
29
+ 'audio-to-audio', 'feature-extraction', 'image-classification', 'translation',
30
+ 'reinforcement-learning', 'fill-mask', 'text-to-speech', 'automatic-speech-recognition',
31
+ 'image-text-to-text', 'token-classification', 'sentence-similarity', 'question-answering',
32
+ 'image-feature-extraction', 'summarization', 'zero-shot-image-classification',
33
+ 'object-detection', 'image-segmentation', 'image-to-image', 'image-to-text',
34
+ 'audio-classification', 'visual-question-answering', 'text-to-video',
35
+ 'zero-shot-classification', 'depth-estimation', 'text-ranking', 'image-to-video',
36
+ 'multiple-choice', 'unconditional-image-generation', 'video-classification',
37
+ 'text-to-audio', 'time-series-forecasting', 'any-to-any', 'video-text-to-text',
38
+ 'table-question-answering',
39
+ ]
40
 
41
+ def extract_model_size(safetensors_data):
42
+ try:
43
+ if pd.isna(safetensors_data): return 0.0
44
+ data_to_parse = safetensors_data
45
+ if isinstance(safetensors_data, str):
46
+ try:
47
+ if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
48
+ (safetensors_data.startswith('[') and safetensors_data.endswith(']')):
49
+ data_to_parse = ast.literal_eval(safetensors_data)
50
+ else: data_to_parse = json.loads(safetensors_data)
51
+ except: return 0.0
52
+ if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
53
+ try:
54
+ total_bytes_val = data_to_parse['total']
55
+ size_bytes = float(total_bytes_val)
56
+ return size_bytes / (1024 * 1024 * 1024)
57
+ except (ValueError, TypeError): pass
58
+ return 0.0
59
+ except: return 0.0
60
 
61
+ def extract_org_from_id(model_id):
62
+ if pd.isna(model_id): return "unaffiliated"
63
+ model_id_str = str(model_id)
64
+ return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
65
 
66
+ def process_tags_for_series(series_of_tags_values):
67
+ processed_tags_accumulator = []
 
68
 
69
+ for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
70
+ temp_processed_list_for_row = []
71
+ current_value_for_error_msg = str(tags_value_from_series)[:200] # Truncate for long error messages
72
 
73
+ try:
74
+ # Order of checks is important!
75
+ # 1. Handle explicit Python lists first
76
+ if isinstance(tags_value_from_series, list):
77
+ current_tags_in_list = []
78
+ for idx_tag, tag_item in enumerate(tags_value_from_series):
79
+ try:
80
+ # Ensure item is not NaN before string conversion if it might be a float NaN in a list
81
+ if pd.isna(tag_item): continue
82
+ str_tag = str(tag_item)
83
+ stripped_tag = str_tag.strip()
84
+ if stripped_tag:
85
+ current_tags_in_list.append(stripped_tag)
86
+ except Exception as e_inner_list_proc:
87
+ print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
88
+ temp_processed_list_for_row = current_tags_in_list
89
+
90
+ # 2. Handle NumPy arrays
91
+ elif isinstance(tags_value_from_series, np.ndarray):
92
+ # Convert to list, then process elements, handling potential NaNs within the array
93
+ current_tags_in_list = []
94
+ for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()): # .tolist() is crucial
95
+ try:
96
+ if pd.isna(tag_item): continue # Check for NaN after converting to Python type
97
+ str_tag = str(tag_item)
98
+ stripped_tag = str_tag.strip()
99
+ if stripped_tag:
100
+ current_tags_in_list.append(stripped_tag)
101
+ except Exception as e_inner_array_proc:
102
+ print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
103
+ temp_processed_list_for_row = current_tags_in_list
104
+
105
+ # 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
106
+ elif tags_value_from_series is None or pd.isna(tags_value_from_series): # Now pd.isna is safe for scalars
107
+ temp_processed_list_for_row = []
108
+
109
+ # 4. Handle strings (could be JSON-like, list-like, or comma-separated)
110
+ elif isinstance(tags_value_from_series, str):
111
+ processed_str_tags = []
112
+ # Attempt ast.literal_eval for strings that look like lists/tuples
113
+ if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
114
+ (tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
115
+ try:
116
+ evaluated_tags = ast.literal_eval(tags_value_from_series)
117
+ if isinstance(evaluated_tags, (list, tuple)): # Check if eval result is a list/tuple
118
+ # Recursively process this evaluated list/tuple, as its elements could be complex
119
+ # For simplicity here, assume elements are simple strings after eval
120
+ current_eval_list = []
121
+ for tag_item in evaluated_tags:
122
+ if pd.isna(tag_item): continue
123
+ str_tag = str(tag_item).strip()
124
+ if str_tag: current_eval_list.append(str_tag)
125
+ processed_str_tags = current_eval_list
126
+ except (ValueError, SyntaxError):
127
+ pass # If ast.literal_eval fails, let it fall to JSON or comma split
128
+
129
+ # If ast.literal_eval didn't populate, try JSON
130
+ if not processed_str_tags:
131
+ try:
132
+ json_tags = json.loads(tags_value_from_series)
133
+ if isinstance(json_tags, list):
134
+ # Similar to above, assume elements are simple strings after JSON parsing
135
+ current_json_list = []
136
+ for tag_item in json_tags:
137
+ if pd.isna(tag_item): continue
138
+ str_tag = str(tag_item).strip()
139
+ if str_tag: current_json_list.append(str_tag)
140
+ processed_str_tags = current_json_list
141
+ except json.JSONDecodeError:
142
+ # If not a valid JSON list, fall back to comma splitting as the final string strategy
143
+ processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
144
+ except Exception as e_json_other:
145
+ print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
146
+ processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()] # Fallback
147
+
148
+ temp_processed_list_for_row = processed_str_tags
149
+
150
+ # 5. Fallback for other scalar types (e.g., int, float that are not NaN)
151
+ else:
152
+ # This path is for non-list, non-ndarray, non-None/NaN, non-string types.
153
+ # Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
154
+ if pd.isna(tags_value_from_series): # Catch any remaining NaNs like float('nan')
155
+ temp_processed_list_for_row = []
156
+ else:
157
+ str_val = str(tags_value_from_series).strip()
158
+ temp_processed_list_for_row = [str_val] if str_val else []
159
+
160
+ processed_tags_accumulator.append(temp_processed_list_for_row)
161
 
162
+ except Exception as e_outer_tag_proc:
163
+ print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
164
+ processed_tags_accumulator.append([])
165
+
166
+ return processed_tags_accumulator
167
+
168
+ def load_models_data(force_refresh=False, tqdm_cls=None):
169
+ if tqdm_cls is None: tqdm_cls = tqdm
170
+ overall_start_time = time.time()
171
+ print(f"Gradio load_models_data called with force_refresh={force_refresh}")
172
+
173
+ expected_cols_in_processed_parquet = [
174
+ 'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params',
175
+ 'size_category', 'organization', 'has_audio', 'has_speech', 'has_music',
176
+ 'has_robot', 'has_bio', 'has_med', 'has_series', 'has_video', 'has_image',
177
+ 'has_text', 'has_science', 'is_audio_speech', 'is_biomed',
178
+ 'data_download_timestamp'
179
+ ]
180
+
181
+ if not force_refresh and os.path.exists(PROCESSED_PARQUET_FILE_PATH):
182
+ print(f"Attempting to load pre-processed data from: {PROCESSED_PARQUET_FILE_PATH}")
183
+ try:
184
+ df = pd.read_parquet(PROCESSED_PARQUET_FILE_PATH)
185
+ elapsed = time.time() - overall_start_time
186
+ missing_cols = [col for col in expected_cols_in_processed_parquet if col not in df.columns]
187
+ if missing_cols:
188
+ raise ValueError(f"Pre-processed Parquet is missing columns: {missing_cols}. Please run preprocessor or refresh data in app.")
189
+
190
+ # --- Diagnostic for 'has_robot' after loading parquet ---
191
+ if 'has_robot' in df.columns:
192
+ robot_count_parquet = df['has_robot'].sum()
193
+ print(f"DIAGNOSTIC (App - Parquet Load): 'has_robot' column found. Number of True values: {robot_count_parquet}")
194
+ if 0 < robot_count_parquet < 10:
195
+ print(f"Sample 'has_robot' models (from parquet): {df[df['has_robot']]['id'].head().tolist()}")
196
+ else:
197
+ print("DIAGNOSTIC (App - Parquet Load): 'has_robot' column NOT FOUND.")
198
+ # --- End Diagnostic ---
199
+
200
+ msg = f"Successfully loaded pre-processed data in {elapsed:.2f}s. Shape: {df.shape}"
201
+ print(msg)
202
+ return df, True, msg
203
+ except Exception as e:
204
+ print(f"Could not load pre-processed Parquet: {e}. ")
205
+ if force_refresh: print("Proceeding to fetch fresh data as force_refresh=True.")
206
+ else:
207
+ err_msg = (f"Pre-processed data could not be loaded: {e}. "
208
+ "Please use 'Refresh Data from Hugging Face' button.")
209
+ return pd.DataFrame(), False, err_msg
210
+
211
+ df_raw = None
212
+ raw_data_source_msg = ""
213
+ if force_refresh:
214
+ print("force_refresh=True (Gradio). Fetching fresh data...")
215
+ fetch_start = time.time()
216
+ try:
217
+ query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')" # Ensure HF_PARQUET_URL is defined
218
+ df_raw = duckdb.sql(query).df()
219
+ if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
220
+ raw_data_source_msg = f"Fetched by Gradio in {time.time() - fetch_start:.2f}s. Rows: {len(df_raw)}"
221
+ print(raw_data_source_msg)
222
+ except Exception as e_hf:
223
+ return pd.DataFrame(), False, f"Fatal error fetching from Hugging Face (Gradio): {e_hf}"
224
+ else:
225
+ err_msg = (f"Pre-processed data '{PROCESSED_PARQUET_FILE_PATH}' not found/invalid. "
226
+ "Run preprocessor or use 'Refresh Data' button.")
227
+ return pd.DataFrame(), False, err_msg
228
+
229
+ print(f"Initiating processing for data newly fetched by Gradio. {raw_data_source_msg}")
230
+ df = pd.DataFrame()
231
+ proc_start = time.time()
232
 
233
+ core_cols = {'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
234
+ 'pipeline_tag': str, 'tags': object, 'safetensors': object}
235
+ for col, dtype in core_cols.items():
236
+ if col in df_raw.columns:
237
+ df[col] = df_raw[col]
238
+ if dtype == float: df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
239
+ elif dtype == str: df[col] = df[col].astype(str).fillna('')
240
+ else:
241
+ if col in ['downloads', 'downloadsAllTime', 'likes']: df[col] = 0.0
242
+ elif col == 'pipeline_tag': df[col] = ''
243
+ elif col == 'tags': df[col] = pd.Series([[] for _ in range(len(df_raw))])
244
+ elif col == 'safetensors': df[col] = None
245
+ elif col == 'id': return pd.DataFrame(), False, "Critical: 'id' column missing."
 
 
 
 
246
 
247
+ output_filesize_col_name = 'params'
248
+ if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
249
+ df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
250
+ elif 'safetensors' in df.columns:
251
+ safetensors_iter = df['safetensors']
252
+ if tqdm_cls != tqdm :
253
+ safetensors_iter = tqdm_cls(df['safetensors'], desc="Extracting model sizes (GB)")
254
+ df[output_filesize_col_name] = [extract_model_size(s) for s in safetensors_iter]
255
+ df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
256
+ else:
257
+ df[output_filesize_col_name] = 0.0
258
+
259
+ def get_size_category_gradio(size_gb_val):
260
+ try: numeric_size_gb = float(size_gb_val)
261
+ except (ValueError, TypeError): numeric_size_gb = 0.0
262
+ if pd.isna(numeric_size_gb): numeric_size_gb = 0.0
263
+ if 0 <= numeric_size_gb < 1: return "Small (<1GB)"
264
+ elif 1 <= numeric_size_gb < 5: return "Medium (1-5GB)"
265
+ elif 5 <= numeric_size_gb < 20: return "Large (5-20GB)"
266
+ elif 20 <= numeric_size_gb < 50: return "X-Large (20-50GB)"
267
+ elif numeric_size_gb >= 50: return "XX-Large (>50GB)"
268
+ else: return "Small (<1GB)"
269
+ df['size_category'] = df[output_filesize_col_name].apply(get_size_category_gradio)
270
+
271
+ df['tags'] = process_tags_for_series(df['tags'])
272
+ df['temp_tags_joined'] = df['tags'].apply(
273
+ lambda tl: '~~~'.join(str(t).lower() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
274
+ )
275
+ tag_map = {
276
+ 'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
277
+ 'has_robot': ['robot', 'robotics'],
278
+ 'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
279
+ 'has_series': ['series', 'time-series', 'timeseries'],
280
+ 'has_video': ['video'], 'has_image': ['image', 'vision'],
281
+ 'has_text': ['text', 'nlp', 'llm']
282
+ }
283
+ for col, kws in tag_map.items():
284
+ pattern = '|'.join(kws)
285
+ df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
286
+ df['has_science'] = (
287
+ df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
288
+ ~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
289
+ )
290
+ del df['temp_tags_joined']
291
+ df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
292
+ df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
293
+ df['is_biomed'] = df['has_bio'] | df['has_med']
294
+ df['organization'] = df['id'].apply(extract_org_from_id)
295
+
296
+ if 'safetensors' in df.columns and \
297
+ not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
298
+ df = df.drop(columns=['safetensors'], errors='ignore')
 
299
 
300
+ # --- Diagnostic for 'has_robot' after app-side processing (force_refresh path) ---
301
+ if force_refresh and 'has_robot' in df.columns:
302
+ robot_count_app_proc = df['has_robot'].sum()
303
+ print(f"DIAGNOSTIC (App - Force Refresh Processing): 'has_robot' column processed. Number of True values: {robot_count_app_proc}")
304
+ if 0 < robot_count_app_proc < 10:
305
+ print(f"Sample 'has_robot' models (App processed): {df[df['has_robot']]['id'].head().tolist()}")
306
+ # --- End Diagnostic ---
307
+
308
+ print(f"Data processing by Gradio completed in {time.time() - proc_start:.2f}s.")
309
+
310
+ total_elapsed = time.time() - overall_start_time
311
+ final_msg = f"{raw_data_source_msg}. Processing by Gradio took {time.time() - proc_start:.2f}s. Total: {total_elapsed:.2f}s. Shape: {df.shape}"
312
+ print(final_msg)
313
+ return df, True, final_msg
314
 
 
 
 
 
 
315
 
316
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
317
+ if df is None or df.empty: return pd.DataFrame()
 
318
  filtered_df = df.copy()
319
+ col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
320
+ "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
321
+ "Video": "has_video", "Images": "has_image", "Text": "has_text"}
322
 
323
+ # --- Diagnostic within make_treemap_data ---
324
+ if 'has_robot' in filtered_df.columns:
325
+ initial_robot_count = filtered_df['has_robot'].sum()
326
+ print(f"DIAGNOSTIC (make_treemap_data entry): Input df has {initial_robot_count} 'has_robot' models.")
327
+ else:
328
+ print("DIAGNOSTIC (make_treemap_data entry): 'has_robot' column NOT in input df.")
329
+ # --- End Diagnostic ---
330
+
331
+ if tag_filter and tag_filter in col_map:
332
+ target_col = col_map[tag_filter]
333
+ if target_col in filtered_df.columns:
334
+ # --- Diagnostic for specific 'Robotics' filter application ---
335
+ if tag_filter == "Robotics":
336
+ count_before_robot_filter = filtered_df[target_col].sum()
337
+ print(f"DIAGNOSTIC (make_treemap_data): Applying 'Robotics' filter. Models with '{target_col}'=True before this filter step: {count_before_robot_filter}")
338
+ # --- End Diagnostic ---
339
+ filtered_df = filtered_df[filtered_df[target_col]]
340
+ if tag_filter == "Robotics":
341
+ print(f"DIAGNOSTIC (make_treemap_data): After 'Robotics' filter ({target_col}), df rows: {len(filtered_df)}")
342
+ else:
343
+ print(f"Warning: Tag filter column '{col_map[tag_filter]}' not found in DataFrame.")
 
 
 
 
 
 
 
 
 
 
 
 
344
  if pipeline_filter:
345
+ if "pipeline_tag" in filtered_df.columns:
346
+ filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
347
+ else:
348
+ print(f"Warning: 'pipeline_tag' column not found for filtering.")
349
+ if size_filter and size_filter != "None" and size_filter in MODEL_SIZE_RANGES.keys():
350
+ if 'size_category' in filtered_df.columns:
351
+ filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
352
+ else:
353
+ print("Warning: 'size_category' column not found for filtering.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  if skip_orgs and len(skip_orgs) > 0:
355
+ if "organization" in filtered_df.columns:
356
+ filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
357
+ else:
358
+ print("Warning: 'organization' column not found for filtering.")
359
+ if filtered_df.empty: return pd.DataFrame()
360
+ if count_by not in filtered_df.columns or not pd.api.types.is_numeric_dtype(filtered_df[count_by]):
361
+ filtered_df[count_by] = pd.to_numeric(filtered_df.get(count_by), errors="coerce").fillna(0.0)
362
+ org_totals = filtered_df.groupby("organization")[count_by].sum().nlargest(top_k, keep='first')
363
+ top_orgs_list = org_totals.index.tolist()
364
+ treemap_data = filtered_df[filtered_df["organization"].isin(top_orgs_list)][["id", "organization", count_by]].copy()
 
365
  treemap_data["root"] = "models"
366
+ treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0.0)
 
 
 
 
367
  return treemap_data
368
 
369
  def create_treemap(treemap_data, count_by, title=None):
 
370
  if treemap_data.empty:
371
+ fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
372
+ fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
 
373
  return fig
 
 
374
  fig = px.treemap(
375
+ treemap_data, path=["root", "organization", "id"], values=count_by,
 
 
376
  title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
377
  color_discrete_sequence=px.colors.qualitative.Plotly
378
  )
379
+ fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
380
+ fig.update_traces(textinfo="label+value+percent root", hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>")
 
381
  return fig
382
 
383
+ with gr.Blocks(title="HuggingFace Model Explorer", fill_width=True) as demo:
384
+ models_data_state = gr.State(pd.DataFrame())
385
+ loading_complete_state = gr.State(False)
386
+
387
+ with gr.Row(): gr.Markdown("# HuggingFace Models TreeMap Visualization")
388
+ with gr.Row():
389
+ with gr.Column(scale=1):
390
+ count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
391
+ filter_choice_radio = gr.Radio(label="Filter Type", choices=["None", "Tag Filter", "Pipeline Filter"], value="None")
392
+ tag_filter_dropdown = gr.Dropdown(label="Select Tag", choices=TAG_FILTER_CHOICES, value=None, visible=False)
393
+ pipeline_filter_dropdown = gr.Dropdown(label="Select Pipeline Tag", choices=PIPELINE_TAGS, value=None, visible=False)
394
+ size_filter_dropdown = gr.Dropdown(label="Model Size Filter", choices=["None"] + list(MODEL_SIZE_RANGES.keys()), value="None")
395
+ top_k_slider = gr.Slider(label="Number of Top Organizations", minimum=5, maximum=50, value=25, step=5)
396
+ skip_orgs_textbox = gr.Textbox(label="Organizations to Skip (comma-separated)", value="TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski")
397
+ generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)
398
+ refresh_data_button = gr.Button(value="Refresh Data from Hugging Face", variant="secondary")
399
+ with gr.Column(scale=3):
400
+ plot_output = gr.Plot()
401
+ status_message_md = gr.Markdown("Initializing...")
402
+ data_info_md = gr.Markdown("")
403
+
404
+ def _update_button_interactivity(is_loaded_flag):
405
+ return gr.update(interactive=is_loaded_flag)
406
+ loading_complete_state.change(fn=_update_button_interactivity, inputs=loading_complete_state, outputs=generate_plot_button)
407
+
408
+ def _toggle_filters_visibility(choice):
409
+ return gr.update(visible=choice == "Tag Filter"), gr.update(visible=choice == "Pipeline Filter")
410
+ filter_choice_radio.change(fn=_toggle_filters_visibility, inputs=filter_choice_radio, outputs=[tag_filter_dropdown, pipeline_filter_dropdown])
411
+
412
+ def ui_load_data_controller(force_refresh_ui_trigger=False, progress=gr.Progress(track_tqdm=True)):
413
+ print(f"ui_load_data_controller called with force_refresh_ui_trigger={force_refresh_ui_trigger}")
414
+ status_msg_ui = "Loading data..."
415
+ data_info_text = ""
416
+ current_df = pd.DataFrame()
417
+ load_success_flag = False
418
+ data_as_of_date_display = "N/A"
419
  try:
420
+ current_df, load_success_flag, status_msg_from_load = load_models_data(
421
+ force_refresh=force_refresh_ui_trigger, tqdm_cls=progress.tqdm
422
+ )
423
+ if load_success_flag:
424
+ if force_refresh_ui_trigger:
425
+ data_as_of_date_display = pd.Timestamp.now(tz='UTC').strftime('%B %d, %Y, %H:%M:%S %Z')
426
+ elif 'data_download_timestamp' in current_df.columns and not current_df.empty and pd.notna(current_df['data_download_timestamp'].iloc[0]):
427
+ timestamp_from_parquet = pd.to_datetime(current_df['data_download_timestamp'].iloc[0])
428
+ if timestamp_from_parquet.tzinfo is None:
429
+ timestamp_from_parquet = timestamp_from_parquet.tz_localize('UTC')
430
+ data_as_of_date_display = timestamp_from_parquet.strftime('%B %d, %Y, %H:%M:%S %Z')
 
431
  else:
432
+ data_as_of_date_display = "Pre-processed (date unavailable)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
 
434
+ size_dist_lines = []
435
+ if 'size_category' in current_df.columns:
436
+ for cat in MODEL_SIZE_RANGES.keys():
437
+ count = (current_df['size_category'] == cat).sum()
438
+ size_dist_lines.append(f" - {cat}: {count:,} models")
439
+ else: size_dist_lines.append(" - Size category information not available.")
440
+ size_dist = "\n".join(size_dist_lines)
 
 
441
 
442
+ data_info_text = (f"### Data Information\n"
443
+ f"- Overall Status: {status_msg_from_load}\n"
444
+ f"- Total models loaded: {len(current_df):,}\n"
445
+ f"- Data as of: {data_as_of_date_display}\n"
446
+ f"- Size categories:\n{size_dist}")
447
 
448
+ # # --- MODIFICATION: Add 'has_robot' count to UI data_info_text ---
449
+ # if not current_df.empty and 'has_robot' in current_df.columns:
450
+ # robot_true_count = current_df['has_robot'].sum()
451
+ # data_info_text += f"\n- **Models flagged 'has_robot'**: {robot_true_count}"
452
+ # if 0 < robot_true_count <= 10: # If a few are found, list some IDs
453
+ # sample_robot_ids = current_df[current_df['has_robot']]['id'].head(5).tolist()
454
+ # data_info_text += f"\n - Sample 'has_robot' model IDs: `{', '.join(sample_robot_ids)}`"
455
+ # elif not current_df.empty:
456
+ # data_info_text += "\n- **Models flagged 'has_robot'**: 'has_robot' column not found in loaded data."
457
+ # # --- END MODIFICATION ---
458
+
459
+ status_msg_ui = "Data loaded successfully. Ready to generate plot."
460
+ else:
461
+ data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
462
+ status_msg_ui = status_msg_from_load
463
+ except Exception as e:
464
+ status_msg_ui = f"An unexpected error occurred in ui_load_data_controller: {str(e)}"
465
+ data_info_text = f"### Critical Error\n- {status_msg_ui}"
466
+ print(f"Critical error in ui_load_data_controller: {e}")
467
+ load_success_flag = False
468
+ return current_df, load_success_flag, data_info_text, status_msg_ui
+
+ def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
+                                 size_choice, k_orgs, skip_orgs_input, df_current_models):
+     if df_current_models is None or df_current_models.empty:
+         empty_fig = create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded")
+         error_msg = "Model data is not loaded or is empty. Please load or refresh data first."
+         gr.Warning(error_msg)
+         return empty_fig, error_msg
+     tag_to_use = tag_choice if filter_type == "Tag Filter" else None
+     pipeline_to_use = pipeline_choice if filter_type == "Pipeline Filter" else None
+     size_to_use = size_choice if size_choice != "None" else None
+     orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []
+
+     # --- Diagnostic before calling make_treemap_data ---
+     if 'has_robot' in df_current_models.columns:
+         robot_count_before_treemap = df_current_models['has_robot'].sum()
+         print(f"DIAGNOSTIC (ui_generate_plot_controller): df_current_models entering make_treemap_data has {robot_count_before_treemap} 'has_robot' models.")
+     # --- End Diagnostic ---
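# Editor's note - illustrative sketch (not part of the diff): how the filter inputs above are
# resolved. The "skip organizations" textbox is a comma-separated string; blanks are dropped.
# The organization names below are made up for the example.
skip_orgs_input = "TheBloke, , google "
orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []
print(orgs_to_skip)   # ['TheBloke', 'google']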
+
+     treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, size_to_use, orgs_to_skip)
+
+     title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
+     chart_title = f"HuggingFace Models - {title_labels.get(metric_choice, metric_choice)} by Organization"
+     plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
+     if treemap_df.empty:
+         plot_stats_md = "No data matches the selected filters. Try adjusting your filters."
      else:
+         total_items_in_plot = len(treemap_df['id'].unique())
+         total_value_in_plot = treemap_df[metric_choice].sum()
+         plot_stats_md = (f"## Plot Statistics\n- **Models shown**: {total_items_in_plot:,}\n- **Total {metric_choice}**: {int(total_value_in_plot):,}")
+     return plotly_fig, plot_stats_md
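# Editor's note - illustrative sketch (not part of the diff): the same helpers can be exercised
# outside the Gradio callbacks, assuming models_processed.parquet exists and that
# make_treemap_data / create_treemap (defined earlier in app.py) keep the signatures used above.
# The filter values are arbitrary examples.
import pandas as pd
df_models = pd.read_parquet("models_processed.parquet")
tm_df = make_treemap_data(df_models, "downloads", 25, None, "text-generation", None, [])
fig = create_treemap(tm_df, "downloads", "HuggingFace Models - Downloads (last 30 days) by Organization")
fig.show()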
     demo.load(
+         fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=False, progress=progress),
+         inputs=[],
+         outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
     )
     refresh_data_button.click(
+         fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=True, progress=progress),
+         inputs=[],
+         outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
     )
 
     generate_plot_button.click(
+         fn=ui_generate_plot_controller,
+         inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
+                 size_filter_dropdown, top_k_slider, skip_orgs_textbox, models_data_state],
+         outputs=[plot_output, status_message_md]
     )
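# Editor's note - optional variant (not part of the diff): the two load paths above differ only
# in force_refresh_ui_trigger, so named wrappers could replace the inline lambdas if that reads
# better or needs testing in isolation. The wrapper names below are hypothetical.
def _load_cached(progress=gr.Progress(track_tqdm=True)):
    return ui_load_data_controller(force_refresh_ui_trigger=False, progress=progress)

def _force_refresh(progress=gr.Progress(track_tqdm=True)):
    return ui_load_data_controller(force_refresh_ui_trigger=True, progress=progress)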

  if __name__ == "__main__":
+     if not os.path.exists(PROCESSED_PARQUET_FILE_PATH):
+         print(f"WARNING: Pre-processed data file '{PROCESSED_PARQUET_FILE_PATH}' not found.")
+         print("It is highly recommended to run the preprocessing script (e.g., preprocess.py) first.") # Corrected script name
+     else:
+         print(f"Found pre-processed data file: '{PROCESSED_PARQUET_FILE_PATH}'.")
+     demo.launch()
+
+ # --- END OF FILE app.py ---
models_processed.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:998afad6c0c4c64f9e98efd8609d1cbab1dd2ac281b9c2e023878ad436c2fbde
+ size 96033487
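# Editor's note (not part of the diff): the three lines above are a Git LFS pointer; the ~96 MB
# parquet itself lives in LFS storage. Once checked out, it can be inspected directly, e.g.:
import pandas as pd
df_chk = pd.read_parquet("models_processed.parquet")
print(df_chk.shape)
print(df_chk[['id', 'organization', 'size_category', 'has_robot']].head())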
preprocess.py ADDED
@@ -0,0 +1,371 @@
+ # --- START OF FILE preprocess.py ---
+
+ import pandas as pd
+ import numpy as np
+ import json
+ import ast
+ from tqdm.auto import tqdm
+ import time
+ import os
+ import duckdb
+ import re # Import re for the manual regex check in debug
+
+ # --- Constants ---
+ PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
+ HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'
+
+ MODEL_SIZE_RANGES = {
+     "Small (<1GB)": (0, 1),
+     "Medium (1-5GB)": (1, 5),
+     "Large (5-20GB)": (5, 20),
+     "X-Large (20-50GB)": (20, 50),
+     "XX-Large (>50GB)": (50, float('inf'))
+ }
+
+ # --- Debugging Constant ---
+ # <<<<<<< SET THE MODEL ID YOU WANT TO DEBUG HERE >>>>>>>
+ MODEL_ID_TO_DEBUG = "openvla/openvla-7b"
+ # Example: MODEL_ID_TO_DEBUG = "openai-community/gpt2"
+ # If you don't have a specific ID, the debug block will just report it's not found.
+
+ # --- Utility Functions (extract_model_file_size_gb, extract_org_from_id, process_tags_for_series, get_file_size_category - unchanged from previous correct version) ---
+ def extract_model_file_size_gb(safetensors_data):
+     try:
+         if pd.isna(safetensors_data): return 0.0
+         data_to_parse = safetensors_data
+         if isinstance(safetensors_data, str):
+             try:
+                 if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
+                    (safetensors_data.startswith('[') and safetensors_data.endswith(']')):
+                     data_to_parse = ast.literal_eval(safetensors_data)
+                 else: data_to_parse = json.loads(safetensors_data)
+             except Exception: return 0.0
+         if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
+             total_bytes_val = data_to_parse['total']
+             try:
+                 size_bytes = float(total_bytes_val)
+                 return size_bytes / (1024 * 1024 * 1024)
+             except (ValueError, TypeError): return 0.0
+         return 0.0
+     except Exception: return 0.0
+
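# Editor's note - illustrative sketch (not part of the diff): expected behaviour of
# extract_model_file_size_gb for the shapes of 'safetensors' values seen in the hub dump.
# The byte counts below are made up.
print(extract_model_file_size_gb({'total': 2 * 1024**3}))    # 2.0  (dict with a 'total' byte count)
print(extract_model_file_size_gb('{"total": 1073741824}'))   # 1.0  (same data, but as a string)
print(extract_model_file_size_gb(None))                      # 0.0  (missing/unparseable falls back to 0.0)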
+ def extract_org_from_id(model_id):
+     if pd.isna(model_id): return "unaffiliated"
+     model_id_str = str(model_id)
+     return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
+
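# Editor's note - illustrative sketch (not part of the diff): the organization is simply the
# namespace portion of the repo id; ids without a namespace fall back to "unaffiliated".
print(extract_org_from_id("openvla/openvla-7b"))   # "openvla"
print(extract_org_from_id("gpt2"))                 # "unaffiliated"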
+ def process_tags_for_series(series_of_tags_values):
+     processed_tags_accumulator = []
+
+     for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
+         temp_processed_list_for_row = []
+         current_value_for_error_msg = str(tags_value_from_series)[:200] # Truncate for long error messages
+
+         try:
+             # Order of checks is important!
+             # 1. Handle explicit Python lists first
+             if isinstance(tags_value_from_series, list):
+                 current_tags_in_list = []
+                 for idx_tag, tag_item in enumerate(tags_value_from_series):
+                     try:
+                         # Ensure item is not NaN before string conversion if it might be a float NaN in a list
+                         if pd.isna(tag_item): continue
+                         str_tag = str(tag_item)
+                         stripped_tag = str_tag.strip()
+                         if stripped_tag:
+                             current_tags_in_list.append(stripped_tag)
+                     except Exception as e_inner_list_proc:
+                         print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
+                 temp_processed_list_for_row = current_tags_in_list
+
+             # 2. Handle NumPy arrays
+             elif isinstance(tags_value_from_series, np.ndarray):
+                 # Convert to list, then process elements, handling potential NaNs within the array
+                 current_tags_in_list = []
+                 for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()): # .tolist() is crucial
+                     try:
+                         if pd.isna(tag_item): continue # Check for NaN after converting to Python type
+                         str_tag = str(tag_item)
+                         stripped_tag = str_tag.strip()
+                         if stripped_tag:
+                             current_tags_in_list.append(stripped_tag)
+                     except Exception as e_inner_array_proc:
+                         print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
+                 temp_processed_list_for_row = current_tags_in_list
+
+             # 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
+             elif tags_value_from_series is None or pd.isna(tags_value_from_series): # Now pd.isna is safe for scalars
+                 temp_processed_list_for_row = []
+
+             # 4. Handle strings (could be JSON-like, list-like, or comma-separated)
+             elif isinstance(tags_value_from_series, str):
+                 processed_str_tags = []
+                 # Attempt ast.literal_eval for strings that look like lists/tuples
+                 if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
+                    (tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
+                     try:
+                         evaluated_tags = ast.literal_eval(tags_value_from_series)
+                         if isinstance(evaluated_tags, (list, tuple)): # Check if eval result is a list/tuple
+                             # Recursively process this evaluated list/tuple, as its elements could be complex
+                             # For simplicity here, assume elements are simple strings after eval
+                             current_eval_list = []
+                             for tag_item in evaluated_tags:
+                                 if pd.isna(tag_item): continue
+                                 str_tag = str(tag_item).strip()
+                                 if str_tag: current_eval_list.append(str_tag)
+                             processed_str_tags = current_eval_list
+                     except (ValueError, SyntaxError):
+                         pass # If ast.literal_eval fails, let it fall to JSON or comma split
+
+                 # If ast.literal_eval didn't populate, try JSON
+                 if not processed_str_tags:
+                     try:
+                         json_tags = json.loads(tags_value_from_series)
+                         if isinstance(json_tags, list):
+                             # Similar to above, assume elements are simple strings after JSON parsing
+                             current_json_list = []
+                             for tag_item in json_tags:
+                                 if pd.isna(tag_item): continue
+                                 str_tag = str(tag_item).strip()
+                                 if str_tag: current_json_list.append(str_tag)
+                             processed_str_tags = current_json_list
+                     except json.JSONDecodeError:
+                         # If not a valid JSON list, fall back to comma splitting as the final string strategy
+                         processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
+                     except Exception as e_json_other:
+                         print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
+                         processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()] # Fallback
+
+                 temp_processed_list_for_row = processed_str_tags
+
+             # 5. Fallback for other scalar types (e.g., int, float that are not NaN)
+             else:
+                 # This path is for non-list, non-ndarray, non-None/NaN, non-string types.
+                 # Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
+                 if pd.isna(tags_value_from_series): # Catch any remaining NaNs like float('nan')
+                     temp_processed_list_for_row = []
+                 else:
+                     str_val = str(tags_value_from_series).strip()
+                     temp_processed_list_for_row = [str_val] if str_val else []
+
+             processed_tags_accumulator.append(temp_processed_list_for_row)
+
+         except Exception as e_outer_tag_proc:
+             print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
+             processed_tags_accumulator.append([])
+
+     return processed_tags_accumulator
+
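# Editor's note - illustrative sketch (not part of the diff): the kinds of raw 'tags' values the
# hub parquet can contain and how process_tags_for_series normalizes each row to a plain list of
# stripped strings. The example values are made up.
import numpy as np
import pandas as pd
raw = pd.Series([
    ['robotics', ' vision '],        # already a Python list
    np.array(['text', None]),        # NumPy array with a missing entry
    '["audio", "speech"]',           # list-like / JSON string
    'nlp, llm',                      # comma-separated string
    None,                            # missing value
])
print(process_tags_for_series(raw))
# [['robotics', 'vision'], ['text'], ['audio', 'speech'], ['nlp', 'llm'], []]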
+ def get_file_size_category(file_size_gb_val):
+     try:
+         numeric_file_size_gb = float(file_size_gb_val)
+         if pd.isna(numeric_file_size_gb): numeric_file_size_gb = 0.0
+     except (ValueError, TypeError): numeric_file_size_gb = 0.0
+     if 0 <= numeric_file_size_gb < 1: return "Small (<1GB)"
+     elif 1 <= numeric_file_size_gb < 5: return "Medium (1-5GB)"
+     elif 5 <= numeric_file_size_gb < 20: return "Large (5-20GB)"
+     elif 20 <= numeric_file_size_gb < 50: return "X-Large (20-50GB)"
+     elif numeric_file_size_gb >= 50: return "XX-Large (>50GB)"
+     else: return "Small (<1GB)"
+
+
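# Editor's note - illustrative sketch (not part of the diff): boundary behaviour of
# get_file_size_category; anything non-numeric or missing falls back to "Small (<1GB)".
for gb in [0.2, 1, 4.9, 20, 75, None]:
    print(gb, "->", get_file_size_category(gb))
# 0.2 -> Small (<1GB), 1 -> Medium (1-5GB), 4.9 -> Medium (1-5GB),
# 20 -> X-Large (20-50GB), 75 -> XX-Large (>50GB), None -> Small (<1GB)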
+ def main_preprocessor():
+     print(f"Starting pre-processing script. Output: '{PROCESSED_PARQUET_FILE_PATH}'.")
+     overall_start_time = time.time()
+
+     print(f"Fetching fresh data from Hugging Face: {HF_PARQUET_URL}")
+     try:
+         fetch_start_time = time.time()
+         query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
+         df_raw = duckdb.sql(query).df()
+         data_download_timestamp = pd.Timestamp.now(tz='UTC')
+
+         if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
+         if 'id' not in df_raw.columns: raise ValueError("Fetched data must contain 'id' column.")
+
+         print(f"Fetched data in {time.time() - fetch_start_time:.2f}s. Rows: {len(df_raw)}. Downloaded at: {data_download_timestamp.strftime('%Y-%m-%d %H:%M:%S %Z')}")
+     except Exception as e_fetch:
+         print(f"ERROR: Could not fetch data from Hugging Face: {e_fetch}.")
+         return
+
+     df = pd.DataFrame()
+     print("Processing raw data...")
+     proc_start = time.time()
+
+     expected_cols_setup = {
+         'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
+         'pipeline_tag': str, 'tags': object, 'safetensors': object
+     }
+     for col_name, target_dtype in expected_cols_setup.items():
+         if col_name in df_raw.columns:
+             df[col_name] = df_raw[col_name]
+             if target_dtype == float: df[col_name] = pd.to_numeric(df[col_name], errors='coerce').fillna(0.0)
+             elif target_dtype == str: df[col_name] = df[col_name].astype(str).fillna('')
+         else:
+             if col_name in ['downloads', 'downloadsAllTime', 'likes']: df[col_name] = 0.0
+             elif col_name == 'pipeline_tag': df[col_name] = ''
+             elif col_name == 'tags': df[col_name] = pd.Series([[] for _ in range(len(df_raw))]) # Initialize with empty lists
+             elif col_name == 'safetensors': df[col_name] = None # Initialize with None
+             elif col_name == 'id': print("CRITICAL ERROR: 'id' column missing."); return
+
+     output_filesize_col_name = 'params'
+     if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
+         print(f"Using pre-existing '{output_filesize_col_name}' column as file size in GB.")
+         df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
+     elif 'safetensors' in df.columns:
+         print(f"Calculating '{output_filesize_col_name}' (file size in GB) from 'safetensors' data...")
+         df[output_filesize_col_name] = df['safetensors'].apply(extract_model_file_size_gb)
+         df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
+     else:
+         print(f"Cannot determine file size. Setting '{output_filesize_col_name}' to 0.0.")
+         df[output_filesize_col_name] = 0.0
+
+     df['data_download_timestamp'] = data_download_timestamp
+     print(f"Added 'data_download_timestamp' column.")
+
+     print("Categorizing models by file size...")
+     df['size_category'] = df[output_filesize_col_name].apply(get_file_size_category)
+
+     print("Standardizing 'tags' column...")
+     df['tags'] = process_tags_for_series(df['tags']) # This now uses tqdm internally
+
+     # --- START DEBUGGING BLOCK ---
+     # This block will execute before the main tag processing loop
+     if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values: # Check if ID exists
+         print(f"\n--- Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---")
+
+         # 1. Check the 'tags' column content after process_tags_for_series
+         model_specific_tags_list = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]
+         print(f"1. Tags from df['tags'] (after process_tags_for_series): {model_specific_tags_list}")
+         print(f"   Type of tags: {type(model_specific_tags_list)}")
+         if isinstance(model_specific_tags_list, list):
+             for i, tag_item in enumerate(model_specific_tags_list):
+                 print(f"   Tag item {i}: '{tag_item}' (type: {type(tag_item)}, len: {len(str(tag_item))})")
+                 # Detailed check for 'robotics' specifically
+                 if 'robotics' in str(tag_item).lower():
+                     print(f"     DEBUG: Found 'robotics' substring in '{tag_item}'")
+                     print(f"       - str(tag_item).lower().strip(): '{str(tag_item).lower().strip()}'")
+                     print(f"       - Is it exactly 'robotics'?: {str(tag_item).lower().strip() == 'robotics'}")
+                     print(f"       - Ordinals: {[ord(c) for c in str(tag_item)]}")
+
+         # 2. Simulate temp_tags_joined for this specific model
+         if isinstance(model_specific_tags_list, list):
+             simulated_temp_tags_joined = '~~~'.join(str(t).lower().strip() for t in model_specific_tags_list if pd.notna(t) and str(t).strip())
+         else:
+             simulated_temp_tags_joined = ''
+         print(f"2. Simulated 'temp_tags_joined' for this model: '{simulated_temp_tags_joined}'")
+
+         # 3. Simulate 'has_robot' check for this model
+         robot_keywords = ['robot', 'robotics']
+         robot_pattern = '|'.join(robot_keywords)
+         manual_robot_check = bool(re.search(robot_pattern, simulated_temp_tags_joined, flags=re.IGNORECASE))
+         print(f"3. Manual regex check for 'has_robot' ('{robot_pattern}' in '{simulated_temp_tags_joined}'): {manual_robot_check}")
+         print(f"--- End Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---\n")
+     elif MODEL_ID_TO_DEBUG:
+         print(f"DEBUG: Model ID '{MODEL_ID_TO_DEBUG}' not found in DataFrame for pre-loop debugging.")
+     # --- END DEBUGGING BLOCK ---
+
+
+     print("Vectorized creation of cached tag columns...")
+     tag_time = time.time()
+     # This is the original temp_tags_joined creation:
+     df['temp_tags_joined'] = df['tags'].apply(
+         lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
+     )
+
+     tag_map = {
+         'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
+         'has_robot': ['robot', 'robotics','openvla','vla'],
+         'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
+         'has_series': ['series', 'time-series', 'timeseries'],
+         'has_video': ['video'], 'has_image': ['image', 'vision'],
+         'has_text': ['text', 'nlp', 'llm']
+     }
+     for col, kws in tag_map.items():
+         pattern = '|'.join(kws)
+         df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
+
+     df['has_science'] = (
+         df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
+         ~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
+     )
+     del df['temp_tags_joined'] # Clean up temporary column
+     df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
+                              df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
+     df['is_biomed'] = df['has_bio'] | df['has_med']
+     print(f"Vectorized tag columns created in {time.time() - tag_time:.2f}s.")
+
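# Editor's note - illustrative sketch (not part of the diff): what the vectorized flags above do
# for a single row. The tags list is made up; the keyword patterns come from tag_map.
import re
tags_row = ['OpenVLA', 'Robotics', 'image-text-to-text']
joined = '~~~'.join(t.lower().strip() for t in tags_row)   # 'openvla~~~robotics~~~image-text-to-text'
print(bool(re.search('robot|robotics|openvla|vla', joined, re.IGNORECASE)))   # True  -> has_robot
print(bool(re.search('image|vision', joined, re.IGNORECASE)))                 # True  -> has_image
print(bool(re.search('music', joined, re.IGNORECASE)))                        # False -> has_music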
+     # --- POST-LOOP DIAGNOSTIC for has_robot & a specific model ---
+     if 'has_robot' in df.columns:
+         print("\n--- 'has_robot' Diagnostics (Preprocessor - Post-Loop) ---")
+         print(df['has_robot'].value_counts(dropna=False))
+
+         if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values:
+             model_has_robot_val = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'has_robot'].iloc[0]
+             print(f"Value of 'has_robot' for model '{MODEL_ID_TO_DEBUG}': {model_has_robot_val}")
+             if model_has_robot_val:
+                 print(f"  Original tags for '{MODEL_ID_TO_DEBUG}': {df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]}")
+
+         if df['has_robot'].any():
+             print("Sample models flagged as 'has_robot':")
+             print(df[df['has_robot']][['id', 'tags', 'has_robot']].head(5))
+         else:
+             print("No models were flagged as 'has_robot' after processing.")
+         print("--------------------------------------------------------\n")
+     # --- END POST-LOOP DIAGNOSTIC ---
+
+
+     print("Adding organization column...")
+     df['organization'] = df['id'].apply(extract_org_from_id)
+
+     # Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
+     if 'safetensors' in df.columns and \
+        not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
+         df = df.drop(columns=['safetensors'], errors='ignore')
+
+     final_expected_cols = [
+         'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags',
+         'params', 'size_category', 'organization',
+         'has_audio', 'has_speech', 'has_music', 'has_robot', 'has_bio', 'has_med',
+         'has_series', 'has_video', 'has_image', 'has_text', 'has_science',
+         'is_audio_speech', 'is_biomed',
+         'data_download_timestamp'
+     ]
+     # Ensure all final columns exist, adding defaults if necessary
+     for col in final_expected_cols:
+         if col not in df.columns:
+             print(f"Warning: Final expected column '{col}' is missing! Defaulting appropriately.")
+             if col == 'params': df[col] = 0.0
+             elif col == 'size_category': df[col] = "Small (<1GB)" # Default size category
+             elif 'has_' in col or 'is_' in col : df[col] = False # Default boolean flags to False
+             elif col == 'data_download_timestamp': df[col] = pd.NaT # Default timestamp to NaT
+
+     print(f"Data processing completed in {time.time() - proc_start:.2f}s.")
+     try:
+         print(f"Saving processed data to: {PROCESSED_PARQUET_FILE_PATH}")
+         df_to_save = df[final_expected_cols].copy() # Ensure only expected columns are saved
+         df_to_save.to_parquet(PROCESSED_PARQUET_FILE_PATH, index=False, engine='pyarrow')
+         print(f"Successfully saved processed data.")
+     except Exception as e_save:
+         print(f"ERROR: Could not save processed data: {e_save}")
+         return
+
+     total_elapsed_script = time.time() - overall_start_time
+     print(f"Pre-processing finished. Total time: {total_elapsed_script:.2f}s. Final Parquet shape: {df_to_save.shape}")
+
+ if __name__ == "__main__":
+     if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
+         print(f"Deleting existing '{PROCESSED_PARQUET_FILE_PATH}' to ensure fresh processing...")
+         try: os.remove(PROCESSED_PARQUET_FILE_PATH)
+         except OSError as e: print(f"Error deleting file: {e}. Please delete manually and rerun."); exit()
+
+     main_preprocessor()
+
+     if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
+         print(f"\nTo verify, load parquet and check 'has_robot' and its 'tags':")
+         print(f"import pandas as pd; df_chk = pd.read_parquet('{PROCESSED_PARQUET_FILE_PATH}')")
+         print(f"print(df_chk['has_robot'].value_counts())")
+         if MODEL_ID_TO_DEBUG:
+             print(f"print(df_chk[df_chk['id'] == '{MODEL_ID_TO_DEBUG}'][['id', 'tags', 'has_robot']])")
+         else:
+             print(f"print(df_chk[df_chk['has_robot']][['id', 'tags', 'has_robot']].head())")