# Maikobi's picture
# Update app.py
# 4d87d33 verified
import json
import os
import re
from dataclasses import dataclass
from typing import Optional

import gradio as gr
import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
@dataclass
class APIConfig:
    """Settings for loading the base model and generating text.

    Every field has a default, so ``APIConfig()`` yields a usable
    configuration. The ``hf_token`` default is evaluated once, when the
    class body runs, so the environment variable must be set before this
    module is imported.
    """

    # Hugging Face Hub id of the base model.
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    # Access token read from the environment; None when HF_TOKEN is unset.
    hf_token: Optional[str] = os.getenv("HF_TOKEN")  # Set this in Space -> Settings -> Secrets
    # Upper bound (in tokens) applied when the prompt is tokenized.
    max_length: int = 256
    # Maximum number of tokens generated per request.
    max_new_tokens: int = 150
    # Sampling temperature passed to generate().
    temperature: float = 0.7
# ---- CPU/main-process globals must NOT touch CUDA ----
# Shared settings object read by the model-loading and generation code.
config = APIConfig()
# Cache slots for the loaded model and tokenizer; both start empty and are
# filled on first use by the loader.
_gpu_model, _gpu_tokenizer = None, None
def _build_prompt(tok, business_description, num_domains=3):
    """Build the text prompt asking for domain names for a business.

    Args:
        tok: Tokenizer-like object. When it exposes ``apply_chat_template``
            that method renders the conversation; otherwise a hand-written
            bracketed layout is used.
        business_description: Free-text description inserted into the user turn.
        num_domains: Number of domain names requested in the user turn.

    Returns:
        The prompt as a string (never token ids), ending where the
        assistant's reply should begin.
    """
    system_text = "You are a helpful AI assistant that generates creative and relevant domain names for businesses. You refuse to generate domains for inappropriate or harmful content."
    user_text = f"Generate {num_domains} domain names for the following business: {business_description}"

    if not hasattr(tok, "apply_chat_template"):
        # Fallback if the tokenizer has no chat template
        return f"<s>[SYSTEM]\n{system_text}\n[/SYSTEM]\n[USER]\n{user_text}\n[/USER]\n[ASSISTANT]\n"

    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    # tokenize=False keeps the result as text; the caller tokenizes it later.
    return tok.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
def calculate_confidence_score(domain, business_description):
    """Score a suggested domain with a few simple heuristics.

    Starts from 0.5 and adds a bonus for each rule the domain satisfies:

    * +0.15 when the first label is 6 to 15 characters long;
    * +0.15 when the first label is not itself a word of the description;
    * +0.10 when the domain ends in .co, .io, .app, .studio or .pro;
    * +0.10 when the first label has no digits or hyphens.

    Args:
        domain: Domain name such as ``"brewhaven.io"``.
        business_description: Description the domain was generated for.

    Returns:
        A float capped at 1.0.
    """
    label = domain.split('.')[0].lower()
    description_words = set(business_description.lower().split())

    # (rule satisfied?, bonus) pairs, listed in the order they are applied so
    # the floating-point accumulation order stays fixed.
    bonuses = (
        (6 <= len(label) <= 15, 0.15),
        (label not in description_words, 0.15),
        (domain.endswith(('.co', '.io', '.app', '.studio', '.pro')), 0.1),
        (re.search(r'[0-9-]', label) is None, 0.1),
    )

    total = 0.5
    for earned, bonus in bonuses:
        if earned:
            total += bonus
    return min(total, 1.0)
# ---------- ALL CUDA/MODEL WORK HAPPENS ONLY BELOW ----------
def _ensure_model_loaded_on_gpu():
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    This function is ONLY called from inside a @spaces.GPU function.

    On the first call it loads the tokenizer and the 4-bit quantized base
    model named by ``config.model_name``, attaches the LoRA adapter, switches
    the model to eval mode and stores both objects in the module-level cache.
    Later calls return the cached objects without reloading.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    global _gpu_model, _gpu_tokenizer

    already_loaded = _gpu_model is not None and _gpu_tokenizer is not None
    if not already_loaded:
        print("[ZeroGPU] Loading tokenizer + model on GPU...")
        tokenizer = AutoTokenizer.from_pretrained(
            config.model_name,
            token=config.hf_token,
        )
        # Reuse the EOS token for padding and pad on the right.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # 4-bit NF4 weights with double quantization; compute in bfloat16.
        quantization = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            quantization_config=quantization,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            token=config.hf_token,
            trust_remote_code=True,
        )

        adapter_id = "Maikobi/domain-name-generator"
        print(f"[ZeroGPU] Loading LoRA adapter: {adapter_id}")
        adapted_model = PeftModel.from_pretrained(base_model, adapter_id)
        adapted_model.eval()

        # Publish to the cache only after everything loaded successfully.
        _gpu_tokenizer = tokenizer
        _gpu_model = adapted_model
        print("[ZeroGPU] Model ready.")

    return _gpu_model, _gpu_tokenizer
@spaces.GPU  # Runs in GPU worker process (safe to touch CUDA here)
def generate_domains_gpu(business_description, num_domains=3):
    """Generate raw model text containing domain suggestions for a business.

    Args:
        business_description: Free-text description of the business.
        num_domains: Number of domain names requested in the prompt.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        excluded), cut at the first end-of-turn / next-user-turn marker.
    """
    m, tok = _ensure_model_loaded_on_gpu()
    prompt_text = _build_prompt(tok, business_description, num_domains=num_domains)

    # A rendered chat template already starts with the BOS token. Letting the
    # tokenizer add special tokens again would duplicate it, so special tokens
    # are only added when the prompt text does not already begin with BOS.
    bos_token = getattr(tok, "bos_token", None)
    prompt_has_bos = bool(bos_token) and prompt_text.startswith(bos_token)

    # NOTE(review): truncation cuts from the end of the prompt by default, so a
    # very long description could lose the trailing assistant header — confirm
    # whether max_length is generous enough for expected inputs.
    enc = tok(
        prompt_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=config.max_length,
        add_special_tokens=not prompt_has_bos,
    )
    input_ids = enc["input_ids"].to(m.device)
    attention_mask = enc["attention_mask"].to(m.device)

    with torch.no_grad():
        outputs = m.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=config.max_new_tokens,
            temperature=config.temperature,
            do_sample=True,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )

    # Keep only the tokens produced after the prompt.
    gen_ids = outputs[0][input_ids.shape[1]:]
    # Special tokens are kept in the decoded text so the stop markers below
    # can be located and used as cut points.
    decoded = tok.decode(gen_ids, skip_special_tokens=False)
    for stop in ["<|eot_id|>", "<|end_header_id|>user", "<|start_header_id|>user"]:
        if stop in decoded:
            decoded = decoded.split(stop, 1)[0]
    return decoded
# ---------- API (runs in main process; calls GPU function) ----------
# Substrings whose presence in the (lower-cased) model output is treated as a refusal.
_REFUSAL_INDICATORS = (
    "cannot",
    "inappropriate",
    "refuse",
    "unable",
    "not generate",
    "not provide",
    "violates content policy",
)

# Matches "<label>.<tld>" for a fixed list of accepted TLDs.
_DOMAIN_RE = re.compile(
    r"\b[a-z0-9][a-z0-9-]{1,63}\.(?:com|org|net|io|co|ai|app|dev|studio|pro|cafe|coffee|restaurant|bakery|llc|firm|agency)\b",
    re.IGNORECASE
)

# Number of domains requested from the model and returned to the caller.
_MAX_SUGGESTIONS = 3


def _error_response(message):
    """Build the error payload shared by every validation failure."""
    return {"suggestions": [], "status": "error", "message": message}


def domain_api(business_description):
    """Generate up to three domain suggestions for a business description.

    Args:
        business_description: Free-text description. Anything that is empty,
            whitespace-only or not a string is rejected with an error
            response instead of raising.

    Returns:
        A dict with a ``"suggestions"`` list and a ``"status"`` of:

        * ``"success"`` - each suggestion is
          ``{"domain": str, "confidence": float}``; when the model output
          contains no recognizable domain, three names derived from the
          description are returned instead;
        * ``"blocked"`` - the model output contained a refusal phrase;
        * ``"error"`` - invalid input or a failure during generation
          (details in ``"message"``).
    """
    if not business_description:
        return _error_response("Business description is required")
    if not isinstance(business_description, str):
        # e.g. a number or object supplied through the JSON API: previously
        # this raised AttributeError on .strip() and aborted the whole batch.
        return _error_response("Business description must be a string")
    description = business_description.strip()
    if not description:
        return _error_response("Business description is required")

    try:
        generated_text = generate_domains_gpu(description, num_domains=_MAX_SUGGESTIONS)

        lowered_text = generated_text.lower()
        if any(indicator in lowered_text for indicator in _REFUSAL_INDICATORS):
            return {
                "suggestions": [],
                "status": "blocked",
                "message": "Request contains inappropriate content"
            }

        # Collect unique domains in order of appearance, up to the limit.
        seen = set()
        suggestions = []
        for candidate in _DOMAIN_RE.findall(generated_text):
            domain = candidate.lower()
            if domain in seen:
                continue
            seen.add(domain)
            confidence = calculate_confidence_score(domain, business_description)
            suggestions.append({"domain": domain, "confidence": round(confidence, 2)})
            if len(suggestions) == _MAX_SUGGESTIONS:
                break

        if not suggestions:
            # Nothing parseable in the model output: derive names from the
            # first few meaningful words of the description.
            base = re.sub(r"[^a-z0-9 ]+", " ", business_description.lower())
            words = [w for w in base.split() if len(w) > 2][:3]
            stem = "".join(words)[:12] or "brand"
            for domain in [f"{stem}.com", f"{stem}hub.com", f"get{stem}.com"]:
                confidence = calculate_confidence_score(domain, business_description)
                suggestions.append({"domain": domain, "confidence": round(confidence, 2)})

        return {"suggestions": suggestions, "status": "success"}
    except Exception as e:
        # API boundary: report the failure to the caller instead of raising.
        return {"suggestions": [], "status": "error", "message": f"Generation failed: {str(e)}"}
def domain_api_wrapper(json_input):
    """Run ``domain_api`` on JSON-style input, single or batched.

    Accepts either:
    - A single object: {"business_description": "..."}
    - An array of objects: [{"business_description": "..."}, ...]

    The input may be a JSON string or an already-parsed Python value.

    Returns:
        One result dict for object input, a list of result dicts (one per
        element, in order) for array input, or an error dict when the input
        is not valid JSON or is neither an object nor an array.
    """
    try:
        payload = json_input if not isinstance(json_input, str) else json.loads(json_input)

        # Single object input
        if isinstance(payload, dict):
            return domain_api(payload.get("business_description", ""))

        # Batch array input; elements that are not objects are treated as
        # having an empty description.
        if isinstance(payload, list):
            return [
                domain_api(entry.get("business_description", "") if isinstance(entry, dict) else "")
                for entry in payload
            ]

        return {"suggestions": [], "status": "error", "message": "Invalid JSON format"}
    except json.JSONDecodeError:
        return {"suggestions": [], "status": "error", "message": "Invalid JSON input"}
    except Exception as exc:
        return {"suggestions": [], "status": "error", "message": f"Input processing failed: {str(exc)}"}
# ---------- UI ----------
# Two tabs share the same backend: a plain-text form wired to domain_api and
# a raw JSON form wired to domain_api_wrapper.
with gr.Blocks(title="Domain Name Generator API") as demo:
    gr.Markdown("# Domain Name Generator API")
    gr.Markdown("Generate creative domain names for your business using fine-tuned Llama-3.1-8B")

    with gr.Tabs():
        # Tab 1: free-text description in, JSON response out.
        with gr.Tab("Simple Input"):
            with gr.Row():
                with gr.Column():
                    business_input = gr.Textbox(
                        lines=3,
                        label="Business Description",
                        placeholder="organic coffee shop in downtown area",
                    )
                    simple_btn = gr.Button("Generate Domains", variant="primary")
                with gr.Column():
                    simple_output = gr.JSON(label="API Response")

        # Tab 2: JSON request (single object or array) in, JSON response out.
        with gr.Tab("JSON API Format"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Accepted Formats:**")
                    gr.Code('{"business_description": "your business here"}', language="json")
                    gr.Markdown("or")
                    gr.Code('[{"business_description": "desc1"}, {"business_description": "desc2"}]', language="json")
                    json_input = gr.Textbox(
                        lines=8,
                        label="JSON Input",
                        placeholder='{"business_description": "organic coffee shop in downtown area"}',
                    )
                    json_btn = gr.Button("Generate Domains", variant="primary")
                with gr.Column():
                    json_output = gr.JSON(label="API Response")

    # Wire each button to its handler.
    simple_btn.click(inputs=[business_input], outputs=[simple_output], fn=domain_api)
    json_btn.click(inputs=[json_input], outputs=[json_output], fn=domain_api_wrapper)

# Enable queue (required so @spaces.GPU runs in a worker on Stateless GPU)
demo.queue()

if __name__ == "__main__":
    demo.launch()