import gradio as gr
import torch
from transformers import AutoModel, BitsAndBytesConfig, AutoTokenizer
import tempfile
from huggingface_hub import HfApi
from huggingface_hub import list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from bitsandbytes.nn import Linear4bit
import os
from huggingface_hub import snapshot_download
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    """Build the greeting shown to the visitor.

    ``profile`` is the logged-in user's ``gr.OAuthProfile``, or ``None`` when
    nobody is logged in. ``oauth_token`` is accepted but not used here.
    """
    logged_in = profile is not None
    if logged_in:
        return f"Hello {profile.name} ! Welcome to BitsAndBytes Quantizer"
    return "Hello ! Please Login to your HuggingFace account to use the BitsAndBytes Quantizer!"
def check_model_exists(
    oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name, upload_to_community
):
    """Check whether the quantized repo this app would create already exists on the Hub.

    The target repo id is ``bnb-community/<model>-bnb-4bit`` when
    ``upload_to_community`` is set. Otherwise it is ``<username>/<quantized_model_name>``,
    or ``<username>/<model>-bnb-4bit`` when no custom name is given. Here ``<model>``
    is the last path segment of ``model_name``.

    Args:
        oauth_token: OAuth token of the logged-in user, or ``None``.
        username: Hub namespace of the user.
        model_name: repo id of the original model.
        quantized_model_name: optional custom name for the quantized repo.
        upload_to_community: target the ``bnb-community`` organization instead of the user.

    Returns:
        A message string when the repo already exists or the check could not be
        performed; ``None`` when the repo does not exist.
    """
    if oauth_token is None:
        # Return a message, not None, so "not logged in" is never mistaken for
        # "repo does not exist".
        return "Error checking model existence: you are not logged in to Hugging Face."
    try:
        base_name = model_name.split('/')[-1]
        if upload_to_community:
            namespace = "bnb-community"
            repo_name = f"bnb-community/{base_name}-bnb-4bit"
        else:
            namespace = username
            suffix = quantized_model_name or f"{base_name}-bnb-4bit"
            repo_name = f"{username}/{suffix}"
        # The repo can only exist in the namespace it would be created in, so list
        # just that one instead of both the user's and the whole community org's models.
        existing = {model.id for model in list_models(author=namespace, token=oauth_token.token)}
    except Exception as e:
        # Top-level UI boundary: report the failure as a message instead of raising.
        return f"Error checking model existence: {str(e)}"
    if repo_name not in existing:
        return None  # Model does not exist
    if upload_to_community:
        return f"Model '{repo_name}' already exists in the bnb-community organization."
    return f"Model '{repo_name}' already exists in your repository."
def create_model_card(
    model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, token=None
):
    """Build the README.md (model card) for the 4-bit quantized copy of ``model_name``.

    The card is made of three parts:

    1. YAML front matter whose ``base_model`` is ``model_name``. Every other
       top-level field of the original card's front matter is carried over,
       and ``bnb-my-repo`` is added to ``tags`` (creating the field if needed).
    2. A section listing the quantization settings.
    3. The original README body, when it could be downloaded.

    Downloading the original README is best-effort: any failure is printed and
    the card is built without it.

    Args:
        model_name: Hub repo id of the original model.
        quant_type_4: value reported as ``bnb_4bit_quant_type``.
        double_quant_4: value reported as ``bnb_4bit_use_double_quant``.
        compute_type_4: value reported as ``bnb_4bit_compute_dtype``.
        quant_storage_4: value reported as ``bnb_4bit_quant_storage``.
        token: optional Hub token forwarded to the README download (e.g. for
            gated or private source repos). Defaults to anonymous access.

    Returns:
        The complete model card text.
    """
    original_yaml_header, original_readme = _read_original_readme(model_name, token)
    yaml_header = _build_yaml_header(model_name, original_yaml_header)
    # Create the quantization info section
    quant_info = f"""
# {model_name} (Quantized)
## Description
This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}).
It's quantized using the BitsAndBytes library to 4-bit using the [bnb-my-repo](https://huggingface.co/spaces/bnb-community/bnb-my-repo) space.
## Quantization Details
- **Quantization Type**: int4
- **bnb_4bit_quant_type**: {quant_type_4}
- **bnb_4bit_use_double_quant**: {double_quant_4}
- **bnb_4bit_compute_dtype**: {compute_type_4}
- **bnb_4bit_quant_storage**: {quant_storage_4}
"""
    model_card = yaml_header + quant_info
    # Append original README content if available
    if original_readme and not original_readme.isspace():
        model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme
    return model_card


# Tag added to every generated card.
_CARD_TAG = "bnb-my-repo"


def _read_original_readme(model_name, token=None):
    """Best-effort download of ``model_name``'s README.md, split into (front matter, body).

    Returns ``("", "")`` when the README is missing or cannot be fetched.
    """
    try:
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=token
        )
        readme_path = os.path.join(model_path, "README.md")
        if not os.path.exists(readme_path):
            return "", ""
        # utf-8-sig drops a leading BOM, which would otherwise hide the opening '---'.
        with open(readme_path, "r", encoding="utf-8-sig") as f:
            content = f.read()
    except Exception as e:
        # Deliberate best-effort: the card is still useful without the original README.
        print(f"Error reading original README: {str(e)}")
        return "", ""
    return _split_front_matter(content)


def _split_front_matter(content):
    """Split markdown ``content`` into (YAML front matter, body).

    The front matter must open with a ``---`` line at the very top and close at
    the next line that is exactly ``---``. Otherwise ``("", content)`` is returned.
    Matching whole fence lines, instead of splitting on the substring, keeps a
    ``---`` inside a YAML value from truncating the header.
    """
    lines = content.splitlines(keepends=True)
    if not lines or lines[0].rstrip() != "---":
        return "", content
    for idx, line in enumerate(lines[1:], start=1):
        if line.rstrip() == "---":
            return "".join(lines[1:idx]), "".join(lines[idx + 1:])
    return "", content


def _build_yaml_header(model_name, original_yaml_header):
    """Return the new front matter, including both ``---`` fences, for the quantized card.

    ``base_model`` is set to ``model_name``. The original top-level ``base_model``
    entry (inline or list form) is dropped and every other line is copied.
    ``bnb-my-repo`` is then added to the top-level ``tags`` as follows:

    * block list (``tags:`` then ``- x`` lines): inserted before the first item,
      using that item's indentation so the list stays valid YAML;
    * inline flow list (``tags: [a, b]``): appended inside the brackets;
    * inline scalar (``tags: a``): rewritten as a block list;
    * no ``tags`` field: one is appended.
    """
    out = ["---", "base_model:", f"- {model_name}"]
    found_tags = False
    skipping_base_model = False  # inside the original base_model's list items
    pending_tag = False  # bare `tags:` emitted; our item goes before its first entry

    for line in original_yaml_header.strip().splitlines():
        stripped = line.strip()
        if skipping_base_model:
            # Blank, indented, or `- item` lines still belong to base_model.
            if not stripped or line[0].isspace() or stripped.startswith("-"):
                continue
            skipping_base_model = False
        if pending_tag and stripped and not stripped.startswith("#"):
            # Reuse the first entry's indentation; mixing indents would turn the
            # following items into a continuation of our plain scalar.
            indent = line[: len(line) - len(line.lstrip())] if stripped.startswith("-") else ""
            out.append(f"{indent}- {_CARD_TAG}")
            pending_tag = False
        # Only top-level keys (no leading whitespace) are rewritten.
        if line.startswith("base_model:"):
            skipping_base_model = not line.split(":", 1)[1].strip()
            continue
        if line.startswith("tags:"):
            found_tags = True
            value = line.split(":", 1)[1].strip()
            if not value or value.startswith("#") or value in ("null", "~"):
                out.append("tags:")
                pending_tag = True
            elif value.startswith("[") and value.endswith("]"):
                items = value[1:-1].strip()
                out.append(f"tags: [{items}, {_CARD_TAG}]" if items else f"tags: [{_CARD_TAG}]")
            elif value.startswith("["):
                # Flow list spanning several lines: copied unchanged rather than risk breaking it.
                out.append(line)
            else:
                out.extend(["tags:", f"- {value}", f"- {_CARD_TAG}"])
            continue
        out.append(line)

    if pending_tag:
        out.append(f"- {_CARD_TAG}")
    if not found_tags:
        out.extend(["tags:", f"- {_CARD_TAG}"])
    out.append("---")
    return "\n".join(out)
# Lookup from dtype name (string) to the matching torch dtype object.
# Each key is the exact attribute name on the torch module, so the table is
# derived from the names instead of spelling each dtype twice.
DTYPE_MAPPING = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in ("int8", "uint8", "float16", "float32", "bfloat16")
}
def quantize_model(
    model_name,
    quant_type_4,
    double_quant_4,
    compute_type_4,
    quant_storage_4,
    auth_token=None,
    progress=gr.Progress(),
):
    """Load ``model_name`` with a 4-bit BitsAndBytes config and process its Linear4bit layers.

    Args:
        model_name: Hub repo id of the model to load.
        quant_type_4: passed as ``bnb_4bit_quant_type``.
        double_quant_4: ``"True"`` (or the bool ``True``) enables
            ``bnb_4bit_use_double_quant``; anything else disables it.
        compute_type_4: key into ``DTYPE_MAPPING`` for ``bnb_4bit_compute_dtype``.
        quant_storage_4: key into ``DTYPE_MAPPING`` for ``bnb_4bit_quant_storage``.
        auth_token: object exposing a ``.token`` attribute, or ``None`` for
            anonymous access.
        progress: Gradio progress callback; this function reports from 0 to 0.66.

    Returns:
        ``(model, original_size_gb)``, where ``original_size_gb`` is what
        ``get_model_size`` reports right after loading.

    Raises:
        KeyError: if ``compute_type_4`` or ``quant_storage_4`` is not a key of
            ``DTYPE_MAPPING``.
    """
    progress(0, desc="Loading model")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=quant_type_4,
        # The UI sends "True"/"False" strings; str() also accepts a real bool.
        bnb_4bit_use_double_quant=str(double_quant_4) == "True",
        bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4],
        bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
    )
    # auth_token defaults to None, so only dereference it when one was given.
    token = auth_token.token if auth_token is not None else None
    model = AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        token=token,  # `use_auth_token` is deprecated in transformers
        torch_dtype="auto",
    )
    progress(0.33, desc="Quantizing")
    # Calculate original model size (measured before the CUDA round-trip below).
    original_size_gb = get_model_size(model)
    modules = list(model.modules())
    total = len(modules)
    for idx, module in enumerate(modules):
        if isinstance(module, Linear4bit):
            # Round-trip each Linear4bit through CUDA. Presumably this move is what
            # triggers bitsandbytes' 4-bit packing (hence "Quantizing"); it requires
            # a CUDA device — confirm.
            module.to("cuda")
            module.to("cpu")
        progress(0.33 + (0.33 * idx / total), desc="Quantizing")
    progress(0.66, desc="Quantized successfully")
    return model, original_size_gb
def save_model(
model,
model_name,
original_size_gb,
quant_type_4,
double_quant_4,
compute_type_4,
quant_storage_4,
username=None,
auth_token=None,
quantized_model_name=None,
public=False,
upload_to_community=False,
progress=gr.Progress(),
):
progress(0.67, desc="Preparing to push")
with tempfile.TemporaryDirectory() as tmpdirname:
# Save model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token.token)
tokenizer.save_pretrained(tmpdirname, safe_serialization=True, use_auth_token=auth_token.token)
model.save_pretrained(
tmpdirname, safe_serialization=True, use_auth_token=auth_token.token
)
progress(0.75, desc="Preparing to push")
# Prepare repo name and model card
if upload_to_community:
repo_name = f"bnb-community/{model_name.split('/')[-1]}-bnb-4bit"
else:
if quantized_model_name:
repo_name = f"{username}/{quantized_model_name}"
else:
repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
model_card = create_model_card(
model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
)
with open(os.path.join(tmpdirname, "README.md"), "w") as f:
f.write(model_card)
progress(0.80, desc="Model card created")
# Push to Hub
api = HfApi(token=auth_token.token)
api.create_repo(repo_name, exist_ok=True, private=not public)
progress(0.85, desc="Pushing to Hub")
# Upload files
api.upload_folder(
folder_path=tmpdirname,
repo_id=repo_name,
repo_type="model",
)
progress(0.95, desc="Model pushed to Hub")
# Get model architecture as string
import io
from contextlib import redirect_stdout
import html
# Capture the model architecture string
f = io.StringIO()
with redirect_stdout(f):
print(model)
model_architecture_str = f.getvalue()
# Escape HTML characters and format with line breaks
model_architecture_str_html = html.escape(model_architecture_str).replace(
"\n", "
"
)
# Format it for display in markdown with proper styling
model_architecture_info = f"""
Original (bf16)≈ {original_size_gb} GB → Quantized ≈ {get_model_size(model)} GB
Find your repo here: {repo_name}
Please sign in to your HuggingFace account to use the quantizer.
Please sign in to your HuggingFace account to use the quantizer.
{exists_message}
{error_message}