import os
import tempfile

import gradio as gr
import torch
from bitsandbytes.nn import Linear4bit
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi, list_models, snapshot_download
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig


def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # Expect a gr.OAuthProfile object as input to get the user's profile;
    # if the user is not logged in, profile will be None.
    if profile is None:
        return "Hello! Please log in to your Hugging Face account to use the BitsAndBytes Quantizer!"
    return f"Hello {profile.name}! Welcome to the BitsAndBytes Quantizer."


def check_model_exists(
    oauth_token: gr.OAuthToken | None,
    username,
    model_name,
    quantized_model_name,
    upload_to_community,
):
    """Check whether the target repo already exists for the user or the bnb-community organization."""
    try:
        models = list_models(author=username, token=oauth_token.token)
        community_models = list_models(author="bnb-community", token=oauth_token.token)
        model_names = [model.id for model in models]
        community_model_names = [model.id for model in community_models]

        if upload_to_community:
            repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
        elif quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"

        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        elif repo_name in community_model_names:
            return f"Model '{repo_name}' already exists in the bnb-community organization."
        else:
            return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {str(e)}"


def create_model_card(
    model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
):
    # Try to download the original README
    original_readme = ""
    original_yaml_header = ""
    try:
        # Download only the README.md file from the original model repo
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
        )
        readme_path = os.path.join(model_path, "README.md")
        if os.path.exists(readme_path):
            with open(readme_path, "r", encoding="utf-8") as f:
                content = f.read()
            # Split out the YAML front matter (between the leading '---' markers), if any
            if content.startswith("---"):
                parts = content.split("---", 2)
                if len(parts) >= 3:
                    original_yaml_header = parts[1]
                    original_readme = "---".join(parts[2:])
                else:
                    original_readme = content
            else:
                original_readme = content
    except Exception as e:
        print(f"Error reading original README: {str(e)}")
        original_readme = ""

    # Create a new YAML header with a base_model field pointing at the source model
    yaml_header = f"""---
base_model:
- {model_name}"""

    # Copy over the original YAML fields, except base_model (replaced above)
    if original_yaml_header:
        in_base_model_section = False
        found_tags = False
        for line in original_yaml_header.strip().split("\n"):
            # Skip continuation lines of a multi-line base_model section
            if in_base_model_section:
                if line.strip().startswith("-") or not line.strip() or line.startswith(" "):
                    continue
                else:
                    in_base_model_section = False

            # Check for the base_model field
            if line.strip().startswith("base_model:"):
                in_base_model_section = True
                # If base_model has an inline value (like "base_model: model_name")
                if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
                    in_base_model_section = False
                continue

            # Check for the tags field and append the bnb-my-repo tag
            if line.strip().startswith("tags:"):
                found_tags = True
                yaml_header += f"\n{line}"
                yaml_header += "\n- bnb-my-repo"
                continue

            yaml_header += f"\n{line}"

        # If the tags field wasn't found, add it
        if not found_tags:
            yaml_header += "\ntags:"
            yaml_header += "\n- bnb-my-repo"

    # Complete the YAML header
    yaml_header += "\n---"

    # Create the quantization info section
    quant_info = f"""
# {model_name} (Quantized)
## Description
This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).

It's quantized using the BitsAndBytes library to 4-bit using the [bnb-my-repo](https://huggingface.co/spaces/bnb-community/bnb-my-repo) space.

## Quantization Details
- **Quantization Type**: int4
- **bnb_4bit_quant_type**: {quant_type_4}
- **bnb_4bit_use_double_quant**: {double_quant_4}
- **bnb_4bit_compute_dtype**: {compute_type_4}
- **bnb_4bit_quant_storage**: {quant_storage_4}
"""

    # Combine everything
    model_card = yaml_header + quant_info

    # Append the original README content if available
    if original_readme and not original_readme.isspace():
        model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme

    return model_card


DTYPE_MAPPING = {
    "int8": torch.int8,
    "uint8": torch.uint8,
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def quantize_model(
    model_name,
    quant_type_4,
    double_quant_4,
    compute_type_4,
    quant_storage_4,
    auth_token=None,
    progress=gr.Progress(),
):
    progress(0, desc="Loading model")

    # Configure 4-bit quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=quant_type_4,
        bnb_4bit_use_double_quant=double_quant_4 == "True",
        bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4],
        bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
    )

    # Load the model with the quantization config
    model = AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        use_auth_token=auth_token.token,
        torch_dtype="auto",
    )

    progress(0.33, desc="Quantizing")

    # Calculate the original model size before quantization runs
    original_size_gb = get_model_size(model)

    # Quantize the model: moving each Linear4bit module to GPU triggers the actual
    # 4-bit quantization of its weights; it is then moved back to CPU.
    modules = list(model.named_modules())
    for idx, (_, module) in enumerate(modules):
        if isinstance(module, Linear4bit):
            module.to("cuda")
            module.to("cpu")
        progress(0.33 + (0.33 * idx / len(modules)), desc="Quantizing")

    progress(0.66, desc="Quantized successfully")
    return model, original_size_gb


def save_model(
    model,
    model_name,
    original_size_gb,
    quant_type_4,
    double_quant_4,
    compute_type_4,
    quant_storage_4,
    username=None,
    auth_token=None,
    quantized_model_name=None,
    public=False,
    upload_to_community=False,
    progress=gr.Progress(),
):
    progress(0.67, desc="Preparing to push")

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Save the tokenizer and the quantized model to a temporary directory
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token.token)
        tokenizer.save_pretrained(tmpdirname, safe_serialization=True, use_auth_token=auth_token.token)
        model.save_pretrained(
            tmpdirname, safe_serialization=True, use_auth_token=auth_token.token
        )

        progress(0.75, desc="Preparing to push")

        # Prepare the repo name and the model card
        if upload_to_community:
            repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
        elif quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"

        model_card = create_model_card(
            model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
        )
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)

        progress(0.80, desc="Model card created")

        # Push to the Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True, private=not public)

        progress(0.85, desc="Pushing to Hub")
        # Upload the files
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )

        progress(0.95, desc="Model pushed to Hub")

        # Get the model architecture as a string
        import io
        import html
        from contextlib import redirect_stdout

        # Capture the printed model architecture
        f = io.StringIO()
        with redirect_stdout(f):
            print(model)
        model_architecture_str = f.getvalue()

        # Escape HTML characters and format with line breaks
        model_architecture_str_html = html.escape(model_architecture_str).replace("\n", "<br/>")

        # Collapsible section showing the full model architecture
        model_architecture_info = f"""
<details>
<summary>📋 Model Architecture</summary>
<pre style="white-space: pre-wrap;">{model_architecture_str_html}</pre>
</details>
"""

        # Before/after size summary
        model_size_info = f"""
<p>📦 <b>Model Size</b></p>
<p>Original (bf16) ≈ {original_size_gb} GB → Quantized ≈ {get_model_size(model)} GB</p>
"""

        # Link to the newly created repo
        repo_link = f"""
<p>🔗 <a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a></p>
"""

        return f"""
<h2>🎉 Quantization Completed</h2>
{repo_link}{model_size_info}{model_architecture_info}
"""

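# Illustrative sketch (not executed): for a hypothetical source model "org/some-model" whose
# README carries a YAML header without a "tags:" field, the card built by create_model_card()
# above would start roughly like:
#
#   ---
#   base_model:
#   - org/some-model
#   ...original YAML fields...
#   tags:
#   - bnb-my-repo
#   ---
#   # org/some-model (Quantized)
#   ## Description
#   ...
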
def quantize_and_save(
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    model_name,
    quant_type_4,
    double_quant_4,
    compute_type_4,
    quant_storage_4,
    quantized_model_name,
    public,
    upload_to_community,
    progress=gr.Progress(),
):
    if oauth_token is None:
        return """
<h3>❌ Authentication Error</h3>
<p>Please sign in to your HuggingFace account to use the quantizer.</p>
"""
    if not profile:
        return """
<h3>❌ Authentication Error</h3>
<p>Please sign in to your HuggingFace account to use the quantizer.</p>
"""

    exists_message = check_model_exists(
        oauth_token, profile.username, model_name, quantized_model_name, upload_to_community
    )
    if exists_message:
        return f"""
<h3>⚠️ Model Already Exists</h3>
<p>{exists_message}</p>
"""

    try:
        # Download and quantization phase
        progress(0, desc="Starting quantization process")
        quantized_model, original_size_gb = quantize_model(
            model_name,
            quant_type_4,
            double_quant_4,
            compute_type_4,
            quant_storage_4,
            oauth_token,
            progress,
        )

        # Save and push phase
        final_message = save_model(
            quantized_model,
            model_name,
            original_size_gb,
            quant_type_4,
            double_quant_4,
            compute_type_4,
            quant_storage_4,
            profile.username,
            oauth_token,
            quantized_model_name,
            public,
            upload_to_community,
            progress,
        )

        # Clean up the model to free memory
        del quantized_model

        # Force garbage collection to release memory
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        progress(1.0, desc="Memory cleaned")
        return final_message

    except Exception as e:
        error_message = str(e).replace("\n", "<br/>")
        return f"""
<h3>❌ Error Occurred</h3>
<p>{error_message}</p>
"""

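# Rough size arithmetic behind the numbers reported by get_model_size() below (illustrative
# only, assuming a hypothetical 7B-parameter model): bf16 weights use about
# 7e9 * 2 bytes ≈ 13.0 GB, while 4-bit weights use about 7e9 * 0.5 bytes ≈ 3.3 GB plus small
# quantization constants, i.e. roughly a 4x reduction.
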
""" def get_model_size(model): """ Calculate the size of a PyTorch model in gigabytes. Args: model: PyTorch model Returns: float: Size of the model in GB """ # Get model state dict state_dict = model.state_dict() # Calculate total size in bytes total_size = 0 for param in state_dict.values(): # Calculate bytes for each parameter total_size += param.nelement() * param.element_size() # Convert bytes to gigabytes (1 GB = 1,073,741,824 bytes) size_gb = total_size / (1024 ** 3) size_gb = round(size_gb, 2) return size_gb css = """/* Custom CSS to allow scrolling */ .gradio-container {overflow-y: auto;} /* Fix alignment for radio buttons and checkboxes */ .gradio-radio { display: flex !important; align-items: center !important; margin: 10px 0 !important; } .gradio-checkbox { display: flex !important; align-items: center !important; margin: 10px 0 !important; } /* Ensure consistent spacing and alignment */ .gradio-dropdown, .gradio-textbox, .gradio-radio, .gradio-checkbox { margin-bottom: 12px !important; width: 100% !important; } /* Align radio buttons and checkboxes horizontally */ .option-row { display: flex !important; justify-content: space-between !important; align-items: center !important; gap: 20px !important; margin-bottom: 12px !important; } .option-row .gradio-radio, .option-row .gradio-checkbox { margin: 0 !important; flex: 1 !important; } /* Horizontally align radio button options with text */ .gradio-radio label { display: flex !important; align-items: center !important; } .gradio-radio input[type="radio"] { margin-right: 5px !important; } /* Remove padding and margin from model name textbox for better alignment */ .model-name-textbox { padding-left: 0 !important; padding-right: 0 !important; margin-left: 0 !important; margin-right: 0 !important; } /* Quantize button styling with glow effect */ button[variant="primary"] { background: linear-gradient(135deg, #3B82F6, #10B981) !important; color: white !important; padding: 16px 32px !important; font-size: 1.1rem !important; font-weight: 700 !important; border: none !important; border-radius: 12px !important; box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important; transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; position: relative; overflow: hidden; animation: glow 1.5s ease-in-out infinite alternate; } button[variant="primary"]::before { content: "✨ "; } button[variant="primary"]:hover { transform: translateY(-5px) scale(1.05) !important; box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important; } @keyframes glow { from { box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); } to { box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5); } } /* Login button styling with glow effect */ #login-button { background: linear-gradient(135deg, #3B82F6, #10B981) !important; color: white !important; font-weight: 700 !important; border: none !important; border-radius: 12px !important; box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important; transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; position: relative; overflow: hidden; animation: glow 1.5s ease-in-out infinite alternate; max-width: 300px !important; margin: 0 auto !important; } #login-button::before { content: "🔑 "; display: inline-block !important; vertical-align: middle !important; margin-right: 5px !important; line-height: normal !important; } #login-button:hover { transform: translateY(-3px) scale(1.03) !important; box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important; } #login-button::after { content: ""; position: absolute; top: 
0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); transition: 0.5s; } #login-button:hover::after { left: 100%; } /* Toggle instructions button styling */ #toggle-button { background: linear-gradient(135deg, #3B82F6, #10B981) !important; color: white !important; font-size: 0.85rem !important; font-weight: 600 !important; padding: 8px 16px !important; border: none !important; border-radius: 8px !important; box-shadow: 0 2px 10px rgba(59, 130, 246, 0.3) !important; transition: all 0.3s ease !important; margin: 0.5rem auto 1.5rem auto !important; display: block !important; max-width: 200px !important; text-align: center !important; position: relative; overflow: hidden; } #toggle-button:hover { transform: translateY(-2px) !important; box-shadow: 0 4px 12px rgba(59, 130, 246, 0.5) !important; } #toggle-button::after { content: ""; position: absolute; top: 0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); transition: 0.5s; } #toggle-button:hover::after { left: 100%; } /* Progress Bar Styles */ .progress-container { font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; padding: 20px; background: white; border-radius: 12px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); } .progress-stage { font-size: 0.9rem; font-weight: 600; color: #64748b; } .progress-stage .stage { position: relative; padding: 8px 12px; border-radius: 6px; background: #f1f5f9; transition: all 0.3s ease; } .progress-stage .stage.completed { background: #ecfdf5; } .progress-bar { box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1); } .progress { transition: width 0.8s cubic-bezier(0.4, 0, 0.2, 1); box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3); } """ with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo: gr.Markdown( """ # 🤗 BitsAndBytes Quantizer : Create your own BNB Quants ! ✨

""" ) gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250) m1 = gr.Markdown() demo.load(hello, inputs=None, outputs=m1) instructions_visible = gr.State(False) with gr.Row(): with gr.Column(): with gr.Row(): model_name = HuggingfaceHubSearch( label="🔍 Hub Model ID", placeholder="Search for model id on Huggingface", search_type="model", ) with gr.Row(): with gr.Column(): gr.Markdown( """ ### ⚙️ Model Quantization Type Settings """ ) quant_type_4 = gr.Dropdown( info="The quantization data type in the bnb.nn.Linear4Bit layers", choices=["fp4", "nf4"], value="nf4", visible=True, show_label=False, ) compute_type_4 = gr.Dropdown( info="The compute type for the model", choices=["float16", "bfloat16", "float32"], value="bfloat16", visible=True, show_label=False, ) quant_storage_4 = gr.Dropdown( info="The storage type for the model", choices=["float16", "float32", "int8", "uint8", "bfloat16"], value="uint8", visible=True, show_label=False, ) gr.Markdown( """ ### 🔄 Double Quantization Settings """ ) with gr.Row(elem_classes="option-row"): double_quant_4 = gr.Radio( ["True", "False"], info="Use Double Quant", visible=True, value="True", show_label=False, ) gr.Markdown( """ ### 💾 Saving Settings """ ) with gr.Row(): quantized_model_name = gr.Textbox( label="✏️ Model Name", info="Model Name (optional : to override default)", value="", interactive=True, elem_classes="model-name-textbox", show_label=False, ) with gr.Row(): public = gr.Checkbox( label="🌐 Make model public", info="If checked, the model will be publicly accessible", value=True, interactive=True, show_label=True, ) with gr.Row(): upload_to_community = gr.Checkbox( label="🤗 Upload to bnb-community", info="If checked, the model will be uploaded to the bnb-community organization \n(Give the space access to the bnb-community, if not already done revoke the token and login again)", value=False, interactive=True, show_label=True, ) # Add event handler to disable and clear model name when uploading to community def toggle_model_name(upload_to_community_checked): return gr.update( interactive=not upload_to_community_checked, value="Can't change model name when uploading to community" if upload_to_community_checked else quantized_model_name.value ) upload_to_community.change( fn=toggle_model_name, inputs=[upload_to_community], outputs=quantized_model_name ) with gr.Column(): quantize_button = gr.Button( "🚀 Quantize and Push to the Hub", variant="primary" ) output_link = gr.Markdown( "🔗 Quantized Model Info", container=True, min_height=200 ) quantize_button.click( fn=quantize_and_save, inputs=[ model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public, upload_to_community, ], outputs=[output_link], show_progress="full", ) # Add information section about the app options with gr.Accordion("📚 About this app", open=True): gr.Markdown( """ ## 📝 Notes on Quantization Options ### Quantization Type (bnb_4bit_quant_type) - **fp4**: Floating-point 4-bit quantization. - **nf4**: Normal float 4-bit quantization. ### Double Quantization - **True**: Applies a second round of quantization to the quantization constants, further reducing memory usage. - **False**: Uses standard quantization only. ### Model Saving Options - **Model Name**: Custom name for your quantized model on the Hub. If left empty, a default name will be generated. - **Make model public**: If checked, anyone can access your quantized model. If unchecked, only you can access it. 

            ## 🔍 How It Works
            This app uses the BitsAndBytes library to perform 4-bit quantization on Transformer models. The process:
            1. Downloads the original model
            2. Applies the selected quantization settings
            3. Uploads the quantized model to your HuggingFace account

            ## 📊 Memory Usage
            4-bit quantization can reduce model size by up to ≈75% compared to FP16 for big models.
            """
        )


if __name__ == "__main__":
    demo.launch(share=True)
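
# Usage note (illustrative sketch, not executed by this app): a repository produced by this
# Space stores its BitsAndBytes settings in config.json, so it can be loaded directly with
# transformers (bitsandbytes installed). The repo id below is a hypothetical example.
#
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("your-username/your-model-bnb-4bit", device_map="auto")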