import os
import html
import tempfile

import gradio as gr
import torch
from transformers import AutoModel, BitsAndBytesConfig
from huggingface_hub import HfApi, list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from bitsandbytes.nn import Linear4bit
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # Gradio injects the user's gr.OAuthProfile from the login session;
    # it is None when the user is not logged in.
    if profile is None:
        return "Hello! Please log in to Hugging Face to use the BitsAndBytes Quantizer."
    return f"Hello {profile.name}! Welcome to the BitsAndBytes Quantizer."
def check_model_exists(oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name):
"""Check if a model exists in the user's Hugging Face repository."""
try:
models = list_models(author=username, token=oauth_token.token)
model_names = [model.id for model in models]
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
if repo_name in model_names:
return f"Model '{repo_name}' already exists in your repository."
else:
return None # Model does not exist
except Exception as e:
return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4):
    """Build the README for the quantized repo. `model_name` is the original
    model id; `repo_name` is the repo the quantized weights are pushed to."""
    model_card = f"""---
base_model:
- {model_name}
---

# {repo_name} (Quantized)

## Description

This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}). It was quantized to int4 using bitsandbytes.

## Quantization Details

- **Quantization Type**: int4
- **bnb_4bit_quant_type**: {quant_type_4}
- **bnb_4bit_use_double_quant**: {double_quant_4}
- **bnb_4bit_compute_dtype**: {compute_type_4}
- **bnb_4bit_quant_storage**: {quant_storage_4}

## Usage

You can use this model in your applications by loading it directly from the Hugging Face Hub:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{repo_name}")
```
"""
    return model_card
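# A minimal sketch of how this helper is called (hypothetical names):
#
#   card = create_model_card(
#       "meta-llama/Llama-3.2-1B",      # original model id (hypothetical)
#       "alice/Llama-3.2-1B-bnb-4bit",  # destination repo (hypothetical)
#       "nf4", "True", "bfloat16", "uint8",
#   )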
DTYPE_MAPPING = {
"int8": torch.int8,
"uint8": torch.uint8,
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
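# For reference, the UI defaults below assemble into this bitsandbytes config
# (a sketch; see quantize_model_with_progress for the real construction):
#
#   BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_use_double_quant=True,
#       bnb_4bit_compute_dtype=torch.bfloat16,  # DTYPE_MAPPING["bfloat16"]
#       bnb_4bit_quant_storage=torch.uint8,     # DTYPE_MAPPING["uint8"]
#   )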
css="""/* Custom CSS to allow scrolling */
.gradio-container {overflow-y: auto;}
/* Fix alignment for radio buttons and checkboxes */
.gradio-radio {
display: flex !important;
align-items: center !important;
margin: 10px 0 !important;
}
.gradio-checkbox {
display: flex !important;
align-items: center !important;
margin: 10px 0 !important;
}
/* Ensure consistent spacing and alignment */
.gradio-dropdown, .gradio-textbox, .gradio-radio, .gradio-checkbox {
margin-bottom: 12px !important;
width: 100% !important;
}
/* Align radio buttons and checkboxes horizontally */
.option-row {
display: flex !important;
justify-content: space-between !important;
align-items: center !important;
gap: 20px !important;
margin-bottom: 12px !important;
}
.option-row .gradio-radio, .option-row .gradio-checkbox {
margin: 0 !important;
flex: 1 !important;
}
/* Horizontally align radio button options with text */
.gradio-radio label {
display: flex !important;
align-items: center !important;
}
.gradio-radio input[type="radio"] {
margin-right: 5px !important;
}
/* Remove padding and margin from model name textbox for better alignment */
.model-name-textbox {
padding-left: 0 !important;
padding-right: 0 !important;
margin-left: 0 !important;
margin-right: 0 !important;
}
/* Quantize button styling with glow effect */
button[variant="primary"] {
background: linear-gradient(135deg, #3B82F6, #10B981) !important;
color: white !important;
padding: 16px 32px !important;
font-size: 1.1rem !important;
font-weight: 700 !important;
border: none !important;
border-radius: 12px !important;
box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
position: relative;
overflow: hidden;
animation: glow 1.5s ease-in-out infinite alternate;
}
button[variant="primary"]::before {
content: "✨ ";
}
button[variant="primary"]:hover {
transform: translateY(-5px) scale(1.05) !important;
box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important;
}
@keyframes glow {
from {
box-shadow: 0 0 10px rgba(59, 130, 246, 0.5);
}
to {
box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5);
}
}
/* Login button styling with glow effect */
#login-button {
background: linear-gradient(135deg, #3B82F6, #10B981) !important;
color: white !important;
font-weight: 700 !important;
border: none !important;
border-radius: 12px !important;
box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
position: relative;
overflow: hidden;
animation: glow 1.5s ease-in-out infinite alternate;
max-width: 300px !important;
margin: 0 auto !important;
}
#login-button::before {
content: "🔑 ";
display: inline-block !important;
vertical-align: middle !important;
margin-right: 5px !important;
line-height: normal !important;
}
#login-button:hover {
transform: translateY(-3px) scale(1.03) !important;
box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important;
}
#login-button::after {
content: "";
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
transition: 0.5s;
}
#login-button:hover::after {
left: 100%;
}
/* Toggle instructions button styling */
#toggle-button {
background: linear-gradient(135deg, #3B82F6, #10B981) !important;
color: white !important;
font-size: 0.85rem !important;
font-weight: 600 !important;
padding: 8px 16px !important;
border: none !important;
border-radius: 8px !important;
box-shadow: 0 2px 10px rgba(59, 130, 246, 0.3) !important;
transition: all 0.3s ease !important;
margin: 0.5rem auto 1.5rem auto !important;
display: block !important;
max-width: 200px !important;
text-align: center !important;
position: relative;
overflow: hidden;
}
#toggle-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 12px rgba(59, 130, 246, 0.5) !important;
}
#toggle-button::after {
content: "";
position: absolute;
top: 0;
left: -100%;
width: 100%;
height: 100%;
background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
transition: 0.5s;
}
#toggle-button:hover::after {
left: 100%;
}
/* Progress Bar Styles */
.progress-container {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
padding: 20px;
background: white;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.progress-stage {
font-size: 0.9rem;
font-weight: 600;
color: #64748b;
}
.progress-stage .stage {
position: relative;
padding: 8px 12px;
border-radius: 6px;
background: #f1f5f9;
transition: all 0.3s ease;
}
.progress-stage .stage.completed {
background: #ecfdf5;
}
.progress-bar {
box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1);
}
.progress {
transition: width 0.8s cubic-bezier(0.4, 0, 0.2, 1);
box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3);
}
"""
def quantize_model_with_progress(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, auth_token, progress=gr.Progress()):
    """Quantize a model with progress updates."""
    progress(0, desc="Loading model")
    # Configure 4-bit quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=quant_type_4,
        bnb_4bit_use_double_quant=double_quant_4 == "True",
        bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4],
        bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
    )
    # Load the model; bitsandbytes quantizes the weights on the fly
    model = AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token,
        torch_dtype=torch.bfloat16,
    )
    progress(0.33, desc="Quantizing")
    # Round-trip each Linear4bit module through the GPU to materialize its 4-bit weights
    modules = list(model.named_modules())
    for idx, (_, module) in enumerate(modules):
        if isinstance(module, Linear4bit):
            module.to("cuda")
            module.to("cpu")
        progress(0.33 + (0.33 * idx / len(modules)), desc="Quantizing")
    progress(0.66, desc="Quantized successfully")
    return model
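# Loading a pushed checkpoint back is a one-liner; a sketch with a hypothetical
# repo id (the quantization config travels with the saved weights):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("alice/Llama-3.2-1B-bnb-4bit")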
def save_model_with_progress(model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, username=None, auth_token=None, quantized_model_name=None, public=False, progress=gr.Progress()):
    """Save the quantized model and push it to the Hub, with progress updates."""
    progress(0.67, desc="Preparing to push")
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Save the quantized weights locally
        model.save_pretrained(tmpdirname, safe_serialization=True)
        progress(0.75, desc="Preparing to push")
        # Prepare repo name and model card
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
        model_card = create_model_card(model_name, repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
        progress(0.80, desc="Model card created")
        # Push to Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.85, desc="Pushing to Hub")
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
        progress(1.00, desc="Model pushed to Hub")
    # Capture the model architecture, escape HTML characters, and format with line breaks
    model_architecture_str_html = html.escape(str(model)).replace("\n", "<br/>")
    model_architecture_info = f"""
<div style="font-family: monospace; font-size: 0.85rem;">{model_architecture_str_html}</div>
"""
    return f"""## 🔗 Quantized Model

🤗 DONE — find your repo here: [{repo_name}](https://huggingface.co/{repo_name})

### 📊 Model Architecture
{model_architecture_info}"""
def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public, progress=gr.Progress()):
    if profile is None or oauth_token is None:
        return """
## ❌ Authentication Error

Please sign in to your Hugging Face account to use the quantizer.
"""
    exists_message = check_model_exists(oauth_token, profile.username, model_name, quantized_model_name)
    if exists_message:
        return f"""
## ⚠️ Model Already Exists

{exists_message}
"""
    try:
        # Download and quantize phase
        progress(0, desc="Starting quantization process")
        quantized_model = quantize_model_with_progress(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, oauth_token, progress)
        # Save and push phase
        final_message = save_model_with_progress(quantized_model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, profile.username, oauth_token, quantized_model_name, public, progress)
        return final_message
    except Exception as e:
        error_message = str(e).replace("\n", "<br/>")
        return f"""
## ❌ Error Occurred

{error_message}
"""
with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
    gr.Markdown(
        """
        # 🤗 LLM BitsAndBytes Quantizer ✨
        """
    )
gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
m1 = gr.Markdown()
demo.load(hello, inputs=None, outputs=m1)
with gr.Row():
with gr.Column():
with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="🔍 Hub Model ID",
                    placeholder="Search for a model ID on the Hugging Face Hub",
                    search_type="model",
                )
with gr.Row():
with gr.Column():
gr.Markdown(
"""
### ⚙️ Model Quantization Type Settings
"""
)
                    quant_type_4 = gr.Dropdown(
                        info="The quantization data type used in the bnb.nn.Linear4bit layers",
                        choices=["fp4", "nf4"],
                        value="nf4",
                        visible=True,
                        show_label=False,
                    )
                    compute_type_4 = gr.Dropdown(
                        info="The compute dtype used for calculations on the 4-bit weights",
                        choices=["float16", "bfloat16", "float32"],
                        value="bfloat16",
                        visible=True,
                        show_label=False,
                    )
                    quant_storage_4 = gr.Dropdown(
                        info="The dtype used to pack and store the quantized 4-bit weights",
                        choices=["float16", "float32", "int8", "uint8", "bfloat16"],
                        value="uint8",
                        visible=True,
                        show_label=False,
                    )
gr.Markdown(
"""
### 🔄 Double Quantization Settings
"""
)
with gr.Row(elem_classes="option-row"):
double_quant_4 = gr.Radio(
["True", "False"],
info="Use Double Quant",
visible=True,
value="True",
show_label=False
)
gr.Markdown(
"""
### 💾 Saving Settings
"""
)
with gr.Row():
                        quantized_model_name = gr.Textbox(
                            label="✏️ Model Name",
                            info="Optional: overrides the default '<model>-bnb-4bit' repo name",
                            value="",
                            interactive=True,
                            elem_classes="model-name-textbox",
                            show_label=False,
                        )
with gr.Row():
public = gr.Checkbox(
label="🌐 Make model public",
info="If checked, the model will be publicly accessible",
value=True,
interactive=True,
show_label=True
)
with gr.Column():
quantize_button = gr.Button("🚀 Quantize and Push to the Hub", variant="primary")
output_link = gr.Markdown("🔗 Quantized Model", container=True, min_height=100)
quantize_button.click(
fn=quantize_and_save,
inputs=[model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public],
outputs=[output_link],
)
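    # Note: the gr.OAuthProfile and gr.OAuthToken parameters of quantize_and_save
    # are injected by Gradio from the login session, so they are deliberately
    # absent from the `inputs` list above.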
if __name__ == "__main__":
demo.launch(share=True)