import gradio as gr import torch from transformers import AutoModel, BitsAndBytesConfig import tempfile from huggingface_hub import HfApi from huggingface_hub import list_models from gradio_huggingfacehub_search import HuggingfaceHubSearch from bitsandbytes.nn import Linear4bit from packaging import version import os from tqdm import tqdm def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str: # ^ expect a gr.OAuthProfile object as input to get the user's profile # if the user is not logged in, profile will be None if profile is None: return "Hello Please Login to HuggingFace to use the BitsAndBytes Quantizer!" return f"Hello {profile.name} ! Welcome to BitsAndBytes Quantizer" def check_model_exists(oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name): """Check if a model exists in the user's Hugging Face repository.""" try: models = list_models(author=username, token=oauth_token.token) model_names = [model.id for model in models] if quantized_model_name : repo_name = f"{username}/{quantized_model_name}" else : repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit" if repo_name in model_names: return f"Model '{repo_name}' already exists in your repository." else: return None # Model does not exist except Exception as e: return f"Error checking model existence: {str(e)}" def create_model_card(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4): model_card = f"""--- base_model: - {model_name} --- # {model_name} (Quantized) ## Description This model is a quantized version of the original model `{model_name}`. It has been quantized using int4 quantization with bitsandbytes. ## Quantization Details - **Quantization Type**: int4 - **bnb_4bit_quant_type**: {quant_type_4} - **bnb_4bit_use_double_quant**: {double_quant_4} - **bnb_4bit_compute_dtype**: {compute_type_4} - **bnb_4bit_quant_storage**: {quant_storage_4} ## Usage You can use this model in your applications by loading it directly from the Hugging Face Hub: ```python from transformers import AutoModel model = AutoModel.from_pretrained("{model_name}")""" return model_card DTYPE_MAPPING = { "int8": torch.int8, "uint8": torch.uint8, "float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16, } def quantize_model(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, auth_token=None, progress=gr.Progress()): progress(0, desc="Starting") print(f"Quantizing model: {quant_type_4}") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type=quant_type_4, bnb_4bit_use_double_quant=True if double_quant_4 == "True" else False, bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4], bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4], ) model = AutoModel.from_pretrained(model_name, quantization_config=quantization_config, device_map="cpu", use_auth_token=auth_token.token, torch_dtype=torch.bfloat16) for _ , module in progress.tqdm(model.named_modules(), desc="Quantizing model", total=len(list(model.named_modules())), unit="layers"): if isinstance(module, Linear4bit): module.to("cuda") module.to("cpu") return model def save_model(model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, username=None, auth_token=None, quantized_model_name=None, public=False): print("Saving quantized model") with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname, safe_serialization=True, use_auth_token=auth_token.token) if quantized_model_name : repo_name = f"{username}/{quantized_model_name}" else : repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit" model_card = create_model_card(repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4) with open(os.path.join(tmpdirname, "README.md"), "w") as f: f.write(model_card) # Push to Hub api = HfApi(token=auth_token.token) api.create_repo(repo_name, exist_ok=True, private=not public) api.upload_folder( folder_path=tmpdirname, repo_id=repo_name, repo_type="model", ) # Get model architecture as string import io from contextlib import redirect_stdout import html # Capture the model architecture string f = io.StringIO() with redirect_stdout(f): print(model) model_architecture_str = f.getvalue() # Escape HTML characters and format with line breaks model_architecture_str_html = html.escape(model_architecture_str).replace('\n', '
') # Format it for display in markdown with proper styling model_architecture_info = f"""
{model_architecture_str_html}
""" return f'🔗 Quantized Model

🤗 DONE


Find your repo here: {repo_name}

📊 Model Architecture
{model_architecture_info}' def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public): if oauth_token is None : return """

❌ Authentication Error

Please sign in to your HuggingFace account to use the quantizer.

""" if not profile: return """

❌ Authentication Error

Please sign in to your HuggingFace account to use the quantizer.

""" exists_message = check_model_exists(oauth_token, profile.username, model_name, quantized_model_name) if exists_message : return f"""

⚠️ Model Already Exists

{exists_message}

""" try: # Download phase quantized_model = quantize_model(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, oauth_token) final_message = save_model(quantized_model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, profile.username, oauth_token, quantized_model_name, public) return final_message except Exception as e : error_message = str(e).replace('\n', '
') return f"""

❌ Error Occurred

{error_message}

""" css="""/* Custom CSS to allow scrolling */ .gradio-container {overflow-y: auto;} /* Fix alignment for radio buttons and checkboxes */ .gradio-radio { display: flex !important; align-items: center !important; margin: 10px 0 !important; } .gradio-checkbox { display: flex !important; align-items: center !important; margin: 10px 0 !important; } /* Ensure consistent spacing and alignment */ .gradio-dropdown, .gradio-textbox, .gradio-radio, .gradio-checkbox { margin-bottom: 12px !important; width: 100% !important; } /* Align radio buttons and checkboxes horizontally */ .option-row { display: flex !important; justify-content: space-between !important; align-items: center !important; gap: 20px !important; margin-bottom: 12px !important; } .option-row .gradio-radio, .option-row .gradio-checkbox { margin: 0 !important; flex: 1 !important; } /* Horizontally align radio button options with text */ .gradio-radio label { display: flex !important; align-items: center !important; } .gradio-radio input[type="radio"] { margin-right: 5px !important; } /* Remove padding and margin from model name textbox for better alignment */ .model-name-textbox { padding-left: 0 !important; padding-right: 0 !important; margin-left: 0 !important; margin-right: 0 !important; } /* Quantize button styling with glow effect */ button[variant="primary"] { background: linear-gradient(135deg, #3B82F6, #10B981) !important; color: white !important; padding: 16px 32px !important; font-size: 1.1rem !important; font-weight: 700 !important; border: none !important; border-radius: 12px !important; box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important; transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; position: relative; overflow: hidden; animation: glow 1.5s ease-in-out infinite alternate; } button[variant="primary"]::before { content: "✨ "; } button[variant="primary"]:hover { transform: translateY(-5px) scale(1.05) !important; box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important; } @keyframes glow { from { box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); } to { box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5); } } /* Login button styling with glow effect */ #login-button { background: linear-gradient(135deg, #3B82F6, #10B981) !important; color: white !important; font-weight: 700 !important; border: none !important; border-radius: 12px !important; box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important; transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; position: relative; overflow: hidden; animation: glow 1.5s ease-in-out infinite alternate; max-width: 300px !important; margin: 0 auto !important; } #login-button::before { content: "🔑 "; display: inline-block !important; vertical-align: middle !important; margin-right: 5px !important; line-height: normal !important; } #login-button:hover { transform: translateY(-3px) scale(1.03) !important; box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important; } #login-button::after { content: ""; position: absolute; top: 0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); transition: 0.5s; } #login-button:hover::after { left: 100%; } /* Toggle instructions button styling */ #toggle-button { background: linear-gradient(135deg, #3B82F6, #10B981) !important; color: white !important; font-size: 0.85rem !important; font-weight: 600 !important; padding: 8px 16px !important; border: none !important; border-radius: 8px !important; box-shadow: 0 2px 10px rgba(59, 130, 246, 0.3) !important; transition: all 0.3s ease !important; margin: 0.5rem auto 1.5rem auto !important; display: block !important; max-width: 200px !important; text-align: center !important; position: relative; overflow: hidden; } #toggle-button:hover { transform: translateY(-2px) !important; box-shadow: 0 4px 12px rgba(59, 130, 246, 0.5) !important; } #toggle-button::after { content: ""; position: absolute; top: 0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); transition: 0.5s; } #toggle-button:hover::after { left: 100%; } /* Progress Bar Styles */ .progress-container { font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; padding: 20px; background: white; border-radius: 12px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); } .progress-stage { font-size: 0.9rem; font-weight: 600; color: #64748b; } .progress-stage .stage { position: relative; padding: 8px 12px; border-radius: 6px; background: #f1f5f9; transition: all 0.3s ease; } .progress-stage .stage.completed { background: #ecfdf5; } .progress-bar { box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1); } .progress { transition: width 0.8s cubic-bezier(0.4, 0, 0.2, 1); box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3); } """ def quantize_model_with_progress(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, auth_token, progress=gr.Progress()): """Quantize model with progress updates.""" progress(0, desc="Loading model") # Configure quantization quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type=quant_type_4, bnb_4bit_use_double_quant=True if double_quant_4 == "True" else False, bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4], bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4], ) # Load model model = AutoModel.from_pretrained(model_name, quantization_config=quantization_config, device_map="cpu", use_auth_token=auth_token.token, torch_dtype=torch.bfloat16) progress(0.33, desc="Quantizing") # Quantize model modules = list(model.named_modules()) for idx, (_, module) in enumerate(modules): if isinstance(module, Linear4bit): module.to("cuda") module.to("cpu") progress(0.33 + (0.33 * idx / len(modules)), desc="Quantizing") progress(0.66, desc="Quantized successfully") return model def save_model_with_progress(model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, username=None, auth_token=None, quantized_model_name=None, public=False, progress=gr.Progress()): """Save model with progress updates.""" progress(0.67, desc="Preparing to push") with tempfile.TemporaryDirectory() as tmpdirname: # Save model model.save_pretrained(tmpdirname, safe_serialization=True, use_auth_token=auth_token.token) progress(0.75, desc="Preparing to push") # Prepare repo name and model card if quantized_model_name: repo_name = f"{username}/{quantized_model_name}" else: repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit" model_card = create_model_card(repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4) with open(os.path.join(tmpdirname, "README.md"), "w") as f: f.write(model_card) progress(0.80, desc="Model card created") # Push to Hub api = HfApi(token=auth_token.token) api.create_repo(repo_name, exist_ok=True, private=not public) progress(0.85, desc="Pushing to Hub") # Upload files api.upload_folder( folder_path=tmpdirname, repo_id=repo_name, repo_type="model", ) progress(1.00, desc="Model pushed to Hub") # Get model architecture as string import io from contextlib import redirect_stdout import html # Capture the model architecture string f = io.StringIO() with redirect_stdout(f): print(model) model_architecture_str = f.getvalue() # Escape HTML characters and format with line breaks model_architecture_str_html = html.escape(model_architecture_str).replace('\n', '
') # Format it for display in markdown with proper styling model_architecture_info = f"""
{model_architecture_str_html}
""" return f'🔗 Quantized Model

🤗 DONE


Find your repo here: {repo_name}

📊 Model Architecture
{model_architecture_info}' def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public, progress=gr.Progress()): if oauth_token is None: return """

❌ Authentication Error

Please sign in to your HuggingFace account to use the quantizer.

""" if not profile: return """

❌ Authentication Error

Please sign in to your HuggingFace account to use the quantizer.

""" exists_message = check_model_exists(oauth_token, profile.username, model_name, quantized_model_name) if exists_message: return f"""

⚠️ Model Already Exists

{exists_message}

""" try: # Download and quantize phase progress(0, desc="Starting quantization process") quantized_model = quantize_model_with_progress(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, oauth_token, progress) # Save and push phase final_message = save_model_with_progress(quantized_model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, profile.username, oauth_token, quantized_model_name, public, progress) return final_message except Exception as e: error_message = str(e).replace('\n', '
') return f"""

❌ Error Occurred

{error_message}

""" with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo: gr.Markdown( """ # 🤗 LLM Model BitsAndBytes Quantizer ✨ """ ) gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250) m1 = gr.Markdown() demo.load(hello, inputs=None, outputs=m1) instructions_visible = gr.State(False) with gr.Row(): with gr.Column(): with gr.Row(): model_name = HuggingfaceHubSearch( label="🔍 Hub Model ID", placeholder="Search for model id on Huggingface", search_type="model", ) with gr.Row(): with gr.Column(): gr.Markdown( """ ### ⚙️ Model Quantization Type Settings """ ) quant_type_4 = gr.Dropdown( info="The quantization data type in the bnb.nn.Linear4Bit layers", choices=["fp4", "nf4"], value="nf4", visible=True, show_label=False ) compute_type_4 = gr.Dropdown( info="The compute type for the model", choices=["float16", "bfloat16", "float32"], value="bfloat16", visible=True, show_label=False ) quant_storage_4 = gr.Dropdown( info="The storage type for the model", choices=["float16", "float32", "int8", "uint8", "bfloat16"], value="uint8", visible=True, show_label=False ) gr.Markdown( """ ### 🔄 Double Quantization Settings """ ) with gr.Row(elem_classes="option-row"): double_quant_4 = gr.Radio( ["True", "False"], info="Use Double Quant", visible=True, value="True", show_label=False ) gr.Markdown( """ ### 💾 Saving Settings """ ) with gr.Row(): quantized_model_name = gr.Textbox( label="✏️ Model Name", info="Model Name (optional : to override default)", value="", interactive=True, elem_classes="model-name-textbox", show_label=False, ) with gr.Row(): public = gr.Checkbox( label="🌐 Make model public", info="If checked, the model will be publicly accessible", value=True, interactive=True, show_label=True ) with gr.Column(): quantize_button = gr.Button("🚀 Quantize and Push to the Hub", variant="primary") output_link = gr.Markdown("🔗 Quantized Model", container=True, min_height=100) quantize_button.click( fn=quantize_and_save, inputs=[model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public], outputs=[output_link], ) if __name__ == "__main__": demo.launch(share=True)