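# Streamlit app for comparing quantized models (Intel AutoRound, AutoGPTQ,
# AutoAWQ) published under a Hugging Face account, using only Hub metadata
# and config files so it can run on CPU-only free-tier Spaces.
# Run locally with e.g. `streamlit run app.py` (filename is illustrative).
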
import streamlit as st
import pandas as pd
import plotly.express as px
from huggingface_hub import HfApi, model_info, hf_hub_download
import time
import re
import json
import signal
from contextlib import contextmanager

# Set page configuration
st.set_page_config(
    page_title="Quantized Model Comparison",
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Define a timeout context manager for safety on CPU-only environments
@contextmanager
def timeout(time_seconds=60):
    def signal_handler(signum, frame):
        raise TimeoutError("Timed out!")
    
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(time_seconds)
    try:
        yield
    finally:
        signal.alarm(0)
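
# NOTE: signal.SIGALRM is Unix-only and must be registered from the main
# thread; this helper is provided for optional use and is not invoked
# elsewhere in this file.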

# Quantization keywords for filtering
QUANTIZATION_KEYWORDS = [
    "auto_round", "auto-round", "autoround", "intel",
    "autogptq", "auto_gptq", "auto-gptq", 
    "autoawq", "auto_awq", "auto-awq"
]
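
# Note: keyword matching below is substring-based over the full repo id, so
# e.g. "intel" appearing in an organization name also counts as a match.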

# Cache API results
@st.cache_data(ttl=3600)  # Cache for 1 hour
def get_user_models(username):
    api = HfApi()
    try:
        models = list(api.list_models(author=username))
        return models
    except Exception as e:
        st.error(f"Error fetching models: {str(e)}")
        return []

# Get model metadata without loading the model
@st.cache_data(ttl=3600)
def get_model_metadata(model_id):
    try:
        api = HfApi()
        model_meta = model_info(repo_id=model_id)
        return model_meta
    except Exception as e:
        st.warning(f"Failed to fetch metadata for {model_id}: {str(e)}")
        return None

# Function to check if a model matches the quantization keywords
def model_matches_keywords(model_id):
    model_name = model_id.lower()
    return any(keyword.lower() in model_name for keyword in QUANTIZATION_KEYWORDS)

# Function to extract quantization method from model name
def extract_quantization_method(model_id):
    model_name = model_id.lower()
    
    if any(kw in model_name for kw in ["auto_round", "auto-round", "autoround", "intel"]):
        return "Intel AutoRound"
    elif any(kw in model_name for kw in ["autogptq", "auto_gptq", "auto-gptq"]):
        return "AutoGPTQ"
    elif any(kw in model_name for kw in ["autoawq", "auto_awq", "auto-awq"]):
        return "AutoAWQ"
    else:
        return "Unknown"

# Function to extract model metadata from name and repo
def extract_model_metadata(model_id, repo_metadata=None):
    model_name = model_id.split("/")[-1]
    
    # Extract quantization method
    quant_method = extract_quantization_method(model_id)
    
    # Extract precision
    precision = "Unknown"
    if "int8" in model_name.lower():
        precision = "INT8"
    elif "int4" in model_name.lower():
        precision = "INT4"
    elif "fp16" in model_name.lower():
        precision = "FP16"
    elif "fp32" in model_name.lower():
        precision = "FP32"
            
    # Extract group size
    group_size = None
    gs_match = re.search(r'gs(\d+)', model_name.lower())
    if gs_match:
        group_size = int(gs_match.group(1))
    
    # Extract parameter count if encoded in the name (e.g. "7b" or "350m");
    # the lookahead avoids false matches such as the "4b" inside "4bit"
    size_patterns = [r'(\d+(\.\d+)?)b(?![a-z])', r'(\d+(\.\d+)?)m(?![a-z])']
    model_size = None
    
    for pattern in size_patterns:
        match = re.search(pattern, model_name.lower())
        if match:
            size = float(match.group(1))
            unit = match.group(0)[-1].lower()
            if unit == 'b':
                model_size = size
            elif unit == 'm':
                model_size = size / 1000  # Convert to billions
            break
    
    # Extract base model name by stripping quantization markers
    base_model = re.sub(r'[-_]?(auto_?round|auto_?gptq|auto_?awq|intel)[-_]?', '', model_name, flags=re.IGNORECASE)
    base_model = re.sub(r'[-_]?(int4|int8|fp16|fp32)[-_]?', '', base_model, flags=re.IGNORECASE)
    base_model = re.sub(r'[-_]?gs\d+[-_]?', '', base_model, flags=re.IGNORECASE)
    base_model = base_model.strip('-_')  # drop any leftover separators
    
    # Add repository metadata if available
    downloads = None
    likes = None
    last_modified = None
    library_name = None
    model_tags = []
    
    if repo_metadata:
        downloads = getattr(repo_metadata, "downloads", None)
        likes = getattr(repo_metadata, "likes", None)
        last_modified = getattr(repo_metadata, "last_modified", None)
        
        # Try to determine library from tags
        if hasattr(repo_metadata, "tags") and repo_metadata.tags:
            model_tags = repo_metadata.tags
            
            library_mapping = {
                "autoawq": "AutoAWQ",
                "gptq": "AutoGPTQ",
                "autogptq": "AutoGPTQ",
                "auto-gptq": "AutoGPTQ",
                "awq": "AutoAWQ",
                "quantization": "Quantized",
                "quantized": "Quantized",
                "intel": "Intel",
                "auto-round": "Intel AutoRound",
                "autoround": "Intel AutoRound"
            }
            
            for tag in model_tags:
                if tag.lower() in library_mapping:
                    library_name = library_mapping[tag.lower()]
                    break
    
    # If we couldn't determine the library from tags, use the name-based method
    if not library_name:
        library_name = quant_method
    
    return {
        "model_name": model_name,
        "base_model": base_model,
        "quant_method": quant_method,
        "precision": precision,
        "group_size": group_size,
        "model_size": model_size,
        "downloads": downloads,
        "likes": likes,
        "last_modified": last_modified,
        "library": library_name,
        "tags": model_tags
    }
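
# Example with a hypothetical model id:
#   extract_model_metadata("user/Llama-2-7B-Instruct-autoround-int4-gs128")
# yields quant_method="Intel AutoRound", precision="INT4", group_size=128,
# model_size=7.0, and base_model="Llama-2-7B-Instruct".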

# Get model stats without loading the entire model
@st.cache_data(ttl=3600)
def get_model_stats(model_id):
    try:
        api = HfApi()
        sibling_files = api.list_repo_files(repo_id=model_id)
        
        # Look for the model's config.json (exact match so that files like
        # generation_config.json are not picked up)
        config_file = None
        for file in sibling_files:
            if file == "config.json" or file.endswith("/config.json"):
                config_file = file
                break
        
        if config_file:
            # Download just the config file; hf_hub_download returns a local path
            config_path = hf_hub_download(repo_id=model_id, filename=config_file)
            
            with open(config_path, 'r') as f:
                config = json.load(f)
            
            # Extract useful info
            stats = {}
            
            # Get hidden size
            if "hidden_size" in config:
                stats["hidden_size"] = config["hidden_size"]
            
            # Get vocab size
            if "vocab_size" in config:
                stats["vocab_size"] = config["vocab_size"]
            
            # Get number of layers/blocks
            for key in ["num_hidden_layers", "n_layer", "num_layers"]:
                if key in config:
                    stats["num_layers"] = config[key]
                    break
            
            # Get attention details
            if "num_attention_heads" in config:
                stats["num_attention_heads"] = config["num_attention_heads"]
            
            # Get sequence length
            for key in ["max_position_embeddings", "n_positions", "max_seq_len"]:
                if key in config:
                    stats["max_seq_len"] = config[key]
                    break
            
            return stats
        
        return {}
    except Exception as e:
        st.warning(f"Failed to fetch stats for {model_id}: {str(e)}")
        return {}

# Function to estimate model size (without loading the model)
def estimate_model_size_from_files(model_id):
    try:
        api = HfApi()
        sibling_files = list(api.list_repo_files(repo_id=model_id))
        
        # Look for weight files (PyTorch .bin shards or safetensors)
        model_files = [f for f in sibling_files if f.endswith('.bin') or f.endswith('.safetensors')]
        
        # Fetch size metadata for all weight files in a single Hub API call
        total_size = 0
        if model_files:
            for file_info in api.get_paths_info(repo_id=model_id, paths=model_files):
                total_size += file_info.size or 0
        
        # Convert to GB
        size_gb = total_size / (1024 ** 3)
        return size_gb
    except Exception as e:
        st.warning(f"Failed to estimate size for {model_id}: {str(e)}")
        return None
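
# An alternative (assuming a recent huggingface_hub) is a single call to
# api.model_info(model_id, files_metadata=True), whose info.siblings entries
# carry per-file .size values; the approach above only requests metadata for
# the weight files it actually needs.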

# Main function
def main():
    st.title("🔍 Quantized Model Comparison Tool")
    st.write("Compare Intel AutoRound, AutoGPTQ, and AutoAWQ models (optimized for free tier Space)")
    
    # Sidebar for configuration
    st.sidebar.header("Configuration")
    username = st.sidebar.text_input("HuggingFace Username", "fbaldassarri")
    
    # Fetch all models
    with st.spinner("Fetching models..."):
        all_models = get_user_models(username)
        all_model_ids = [model.id for model in all_models]
    
    # Filter models with quantization keywords
    quantized_model_ids = [model_id for model_id in all_model_ids if model_matches_keywords(model_id)]
    
    st.sidebar.write(f"Found {len(quantized_model_ids)} quantized models out of {len(all_model_ids)} total models")
    
    # Quantization method filtering
    quant_methods = ["Intel AutoRound", "AutoGPTQ", "AutoAWQ"]
    selected_quant_methods = st.sidebar.multiselect(
        "Filter by quantization method", 
        options=quant_methods,
        default=quant_methods
    )
    
    # Additional filtering
    additional_filter = st.sidebar.text_input("Additional model name filter", "")
    
    # Apply filters
    filtered_models = []
    for model_id in quantized_model_ids:
        quant_method = extract_quantization_method(model_id)
        if quant_method in selected_quant_methods:
            if not additional_filter or additional_filter.lower() in model_id.lower():
                filtered_models.append(model_id)
    
    # Group models by base model name
    model_groups = {}
    for model_id in filtered_models:
        metadata = extract_model_metadata(model_id)
        base_model = metadata["base_model"]
        if base_model not in model_groups:
            model_groups[base_model] = []
        model_groups[base_model].append(model_id)
    
    # Select base model group
    base_model_options = list(model_groups.keys())
    base_model_options.sort()
    
    selected_base_model = st.sidebar.selectbox(
        "Select base model to compare", 
        options=["All"] + base_model_options
    )
    
    # Final model selection
    if selected_base_model == "All":
        model_selection_options = filtered_models
    else:
        model_selection_options = model_groups[selected_base_model]
    
    # Limit selection to prevent resource issues
    max_models_comparison = st.sidebar.slider("Maximum models to compare", 2, 10, 5)
    default_models = model_selection_options[:min(max_models_comparison, len(model_selection_options))]
    
    selected_models = st.sidebar.multiselect(
        "Select models to compare", 
        options=model_selection_options,
        default=default_models
    )
    
    # Limit selection if exceeded
    if len(selected_models) > max_models_comparison:
        st.warning(f"⚠️ Limited to {max_models_comparison} models for comparison (CPU constraints)")
        selected_models = selected_models[:max_models_comparison]
    
    # Comparison method
    st.sidebar.header("Comparison Method")
    
    compare_method = st.sidebar.radio(
        "Choose comparison method",
        ["Metadata Comparison Only", "Metadata + Estimated Size"]
    )
    
    if st.button("Run Comparison") and selected_models:
        # Record that a comparison has run so the intro text below stays hidden
        st.session_state['comparison_run'] = True
        # Progress tracking
        progress_bar = st.progress(0)
        status_text = st.empty()
        
        results = []
        
        # Analyze each model
        for i, model_id in enumerate(selected_models):
            status_text.text(f"Analyzing {model_id} ({i+1}/{len(selected_models)})")
            
            # Get repository metadata
            repo_meta = get_model_metadata(model_id)
            
            # Extract metadata
            metadata = extract_model_metadata(model_id, repo_meta)
            model_result = metadata.copy()
            
            # Get model architecture stats
            model_stats = get_model_stats(model_id)
            model_result.update(model_stats)
            
            # Get estimated size if needed
            if compare_method == "Metadata + Estimated Size":
                with st.spinner(f"Estimating size for {model_id}..."):
                    try:
                        estimated_size = estimate_model_size_from_files(model_id)
                        model_result["estimated_size_gb"] = estimated_size
                    except Exception as e:
                        st.warning(f"Size estimation failed for {model_id}: {str(e)}")
            
            # Add to results
            results.append(model_result)
            
            # Update progress
            progress_bar.progress((i + 1) / len(selected_models))
        
        # Clear progress indicators
        progress_bar.empty()
        status_text.empty()
        
        # Display results
        if results:
            # Convert to DataFrame
            results_df = pd.DataFrame(results)
            
            # Normalize dates; Hub timestamps are timezone-aware, so compare
            # against a UTC "now" to avoid naive/aware subtraction errors
            if "last_modified" in results_df.columns:
                results_df["last_modified"] = pd.to_datetime(results_df["last_modified"], utc=True)
                results_df["days_since_update"] = (pd.Timestamp.now(tz="UTC") - results_df["last_modified"]).dt.days
            
            # Sort by quantization method and model name
            if "quant_method" in results_df.columns and "model_name" in results_df.columns:
                results_df = results_df.sort_values(["quant_method", "model_name"])
            
            # Display results in tabs
            results_tabs = st.tabs(["Model Comparison", "Model Details", "Visualizations"])
            
            with results_tabs[0]:
                st.subheader("Model Comparison")
                
                # Define columns to display
                basic_cols = ["model_name", "quant_method", "precision", "group_size"]
                
                size_cols = []
                if "model_size" in results_df.columns:
                    size_cols.append("model_size")
                if "estimated_size_gb" in results_df.columns:
                    size_cols.append("estimated_size_gb")
                
                arch_cols = []
                for col in ["num_layers", "hidden_size", "num_attention_heads", "max_seq_len"]:
                    if col in results_df.columns:
                        arch_cols.append(col)
                
                stats_cols = []
                for col in ["downloads", "likes", "days_since_update"]:
                    if col in results_df.columns:
                        stats_cols.append(col)
                
                # Create display dataframe
                display_cols = basic_cols + size_cols + arch_cols + stats_cols
                display_df = results_df[display_cols].copy()
                
                # Format columns
                if "estimated_size_gb" in display_df.columns:
                    display_df["estimated_size_gb"] = display_df["estimated_size_gb"].apply(
                        lambda x: f"{x:.2f} GB" if pd.notna(x) else "Unknown"
                    )
                
                if "model_size" in display_df.columns:
                    display_df["model_size"] = display_df["model_size"].apply(
                        lambda x: f"{x:.2f}B" if pd.notna(x) else "Unknown"
                    )
                
                # Display the table
                st.dataframe(display_df)
            
            with results_tabs[1]:
                st.subheader("Detailed Model Information")
                
                # Create tabs for each model
                model_tabs = st.tabs([m.split("/")[-1] for m in selected_models])
                
                for i, model_id in enumerate(selected_models):
                    with model_tabs[i]:
                        # Get the model row
                        model_row = results_df[results_df["model_name"] == model_id.split("/")[-1]].iloc[0]
                        
                        # Display model info in columns
                        col1, col2 = st.columns(2)
                        
                        with col1:
                            st.markdown("#### Model Information")
                            st.markdown(f"**Repository:** {model_id}")
                            st.markdown(f"**Base Model:** {model_row.get('base_model', 'Unknown')}")
                            st.markdown(f"**Quantization:** {model_row.get('quant_method', 'Unknown')}")
                            st.markdown(f"**Precision:** {model_row.get('precision', 'Unknown')}")
                            
                            if "group_size" in model_row and pd.notna(model_row["group_size"]):
                                st.markdown(f"**Group Size:** {int(model_row['group_size'])}")
                                
                            if "estimated_size_gb" in model_row and pd.notna(model_row["estimated_size_gb"]):
                                st.markdown(f"**Model Size:** {model_row['estimated_size_gb']:.2f} GB")
                        
                        with col2:
                            st.markdown("#### Architecture Details")
                            
                            for col in ["hidden_size", "num_layers", "num_attention_heads", "max_seq_len", "vocab_size"]:
                                if col in model_row and pd.notna(model_row[col]):
                                    st.markdown(f"**{col.replace('_', ' ').title()}:** {int(model_row[col])}")
                        
                        # Repository stats
                        st.markdown("#### Repository Statistics")
                        stat_cols = st.columns(3)
                        
                        with stat_cols[0]:
                            if "downloads" in model_row and pd.notna(model_row["downloads"]):
                                st.metric("Downloads", f"{int(model_row['downloads']):,}")
                        
                        with stat_cols[1]:
                            if "likes" in model_row and pd.notna(model_row["likes"]):
                                st.metric("Likes", f"{int(model_row['likes']):,}")
                        
                        with stat_cols[2]:
                            if "days_since_update" in model_row and pd.notna(model_row["days_since_update"]):
                                st.metric("Days Since Update", f"{int(model_row['days_since_update'])}")
                        
                        # Tags
                        if "tags" in model_row and model_row["tags"]:
                            st.markdown("#### Model Tags")
                            tags_html = " ".join([f"<span style='background-color: #eee; padding: 0.2rem 0.5rem; border-radius: 0.5rem; margin-right: 0.5rem;'>{tag}</span>" for tag in model_row["tags"]])
                            st.markdown(tags_html, unsafe_allow_html=True)
                        
                        # Add a link to the model
                        st.markdown(f"[View on HuggingFace 🤗]({'https://huggingface.co/' + model_id})")
            
            with results_tabs[2]:
                st.subheader("Visualizations")
                
                viz_tabs = st.tabs(["Quantization Methods", "Model Architecture", "Repository Stats"])
                
                with viz_tabs[0]:
                    # Quantization method distribution
                    if "quant_method" in results_df.columns:
                        method_counts = results_df["quant_method"].value_counts().reset_index()
                        method_counts.columns = ["Method", "Count"]
                        
                        fig = px.pie(
                            method_counts, 
                            names="Method", 
                            values="Count",
                            title="Distribution of Quantization Methods",
                            color="Method",
                            color_discrete_map={
                                "Intel AutoRound": "#0071c5",
                                "AutoGPTQ": "#ff4b4b",
                                "AutoAWQ": "#1e88e5"
                            }
                        )
                        st.plotly_chart(fig, use_container_width=True)
                    
                    # Precision distribution
                    if "precision" in results_df.columns:
                        precision_counts = results_df["precision"].value_counts().reset_index()
                        precision_counts.columns = ["Precision", "Count"]
                        
                        fig = px.bar(
                            precision_counts,
                            x="Precision",
                            y="Count",
                            title="Distribution of Precision Formats",
                            color="Precision"
                        )
                        st.plotly_chart(fig, use_container_width=True)
                    
                    # Group size distribution (if available)
                    if "group_size" in results_df.columns and results_df["group_size"].notna().any():
                        valid_gs_data = results_df[results_df["group_size"].notna()]
                        gs_counts = valid_gs_data["group_size"].value_counts().reset_index()
                        gs_counts.columns = ["Group Size", "Count"]
                        
                        fig = px.bar(
                            gs_counts,
                            x="Group Size",
                            y="Count",
                            title="Distribution of Group Sizes",
                            color="Group Size"
                        )
                        st.plotly_chart(fig, use_container_width=True)
                
                with viz_tabs[1]:
                    # Model size comparison
                    if "estimated_size_gb" in results_df.columns and results_df["estimated_size_gb"].notna().any():
                        valid_size_data = results_df[results_df["estimated_size_gb"].notna()].sort_values("estimated_size_gb")
                        
                        fig = px.bar(
                            valid_size_data,
                            x="model_name",
                            y="estimated_size_gb",
                            color="quant_method",
                            title="Model Size Comparison (GB)",
                            labels={"estimated_size_gb": "Size (GB)", "model_name": "Model", "quant_method": "Method"}
                        )
                        fig.update_layout(xaxis_tickangle=-45)
                        st.plotly_chart(fig, use_container_width=True)
                    
                    # Architecture comparison
                    for arch_col in ["num_layers", "hidden_size", "num_attention_heads"]:
                        if arch_col in results_df.columns and results_df[arch_col].notna().any():
                            valid_data = results_df[results_df[arch_col].notna()].sort_values(arch_col)
                            
                            fig = px.bar(
                                valid_data,
                                x="model_name",
                                y=arch_col,
                                color="quant_method",
                                title=f"{arch_col.replace('_', ' ').title()} Comparison",
                                labels={arch_col: arch_col.replace('_', ' ').title(), "model_name": "Model", "quant_method": "Method"}
                            )
                            fig.update_layout(xaxis_tickangle=-45)
                            st.plotly_chart(fig, use_container_width=True)
                
                with viz_tabs[2]:
                    # Downloads comparison
                    if "downloads" in results_df.columns and results_df["downloads"].notna().any():
                        valid_data = results_df[results_df["downloads"].notna()].sort_values("downloads", ascending=False)
                        
                        fig = px.bar(
                            valid_data,
                            x="model_name",
                            y="downloads",
                            color="quant_method",
                            title="Downloads Comparison",
                            labels={"downloads": "Downloads", "model_name": "Model", "quant_method": "Method"}
                        )
                        fig.update_layout(xaxis_tickangle=-45)
                        st.plotly_chart(fig, use_container_width=True)
                    
                    # Likes comparison
                    if "likes" in results_df.columns and results_df["likes"].notna().any():
                        valid_data = results_df[results_df["likes"].notna()].sort_values("likes", ascending=False)
                        
                        fig = px.bar(
                            valid_data,
                            x="model_name",
                            y="likes",
                            color="quant_method",
                            title="Likes Comparison",
                            labels={"likes": "Likes", "model_name": "Model", "quant_method": "Method"}
                        )
                        fig.update_layout(xaxis_tickangle=-45)
                        st.plotly_chart(fig, use_container_width=True)
                    
                    # Last updated comparison
                    if "days_since_update" in results_df.columns and results_df["days_since_update"].notna().any():
                        valid_data = results_df[results_df["days_since_update"].notna()].sort_values("days_since_update")
                        
                        fig = px.bar(
                            valid_data,
                            x="model_name",
                            y="days_since_update",
                            color="quant_method",
                            title="Days Since Last Update",
                            labels={"days_since_update": "Days", "model_name": "Model", "quant_method": "Method"}
                        )
                        fig.update_layout(xaxis_tickangle=-45)
                        st.plotly_chart(fig, use_container_width=True)
            
            # Export options
            st.subheader("Export Results")
            
            # Prepare download data
            csv_data = results_df.to_csv(index=False)
            
            st.download_button(
                "Download Results as CSV",
                data=csv_data,
                file_name=f"quantized_model_comparison_{username}_{time.strftime('%Y%m%d_%H%M')}.csv",
                mime="text/csv"
            )
        else:
            st.warning("No results were obtained. Please check for errors and try again.")
    
    # Show instructions if no comparison run
    if not st.session_state.get('comparison_run', False):
        st.info("""
        ## CPU-Optimized Model Comparison
        
        This tool compares your quantized models without loading them, so it needs no GPU and runs comfortably on a free-tier Hugging Face Space.
        
        ### Features:
        
        - **Metadata Analysis**: Compare model architectures without loading models
        - **Repository Stats**: View downloads, likes, and update frequency
        - **Visualization**: Compare models across multiple dimensions
        - **Filtering**: Focus on specific quantization methods or model families
        
        ### Supported Quantization Methods:
        
        - **Intel AutoRound**: Intel's quantization solution
        - **AutoGPTQ**: Automatic GPTQ quantization
        - **AutoAWQ**: Activation-aware weight quantization
        
        ### Instructions:
        
        1. Select models using the sidebar filters
        2. Click "Run Comparison" to analyze without loading full models
        3. View results in the tabs and charts
        4. Download results as CSV for further analysis
        """)

if __name__ == "__main__":
    main()