import gradio as gr import torch import json import re import os import spaces from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from peft import PeftModel from dataclasses import dataclass @dataclass class APIConfig: model_name: str = "meta-llama/Llama-3.1-8B-Instruct" hf_token: str = os.getenv("HF_TOKEN") # Set this in Space -> Settings -> Secrets max_length: int = 256 max_new_tokens: int = 150 temperature: float = 0.7 # ---- CPU/main-process globals must NOT touch CUDA ---- _gpu_model = None _gpu_tokenizer = None config = APIConfig() def _build_prompt(tok, business_description, num_domains=3): messages = [ {"role": "system", "content": "You are a helpful AI assistant that generates creative and relevant domain names for businesses. You refuse to generate domains for inappropriate or harmful content."}, {"role": "user", "content": f"Generate {num_domains} domain names for the following business: {business_description}"} ] if hasattr(tok, "apply_chat_template"): return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) # Fallback if the tokenizer has no chat template sys_msg = messages[0]["content"] user_msg = messages[1]["content"] return f"[SYSTEM]\n{sys_msg}\n[/SYSTEM]\n[USER]\n{user_msg}\n[/USER]\n[ASSISTANT]\n" def calculate_confidence_score(domain, business_description): score = 0.5 domain_base = domain.split('.')[0].lower() business_words = set(business_description.lower().split()) if 6 <= len(domain_base) <= 15: score += 0.15 if domain_base not in business_words: score += 0.15 if domain.endswith(('.co', '.io', '.app', '.studio', '.pro')): score += 0.1 if not re.search(r'[0-9-]', domain_base): score += 0.1 return min(score, 1.0) # ---------- ALL CUDA/MODEL WORK HAPPENS ONLY BELOW ---------- def _ensure_model_loaded_on_gpu(): """This function is ONLY called from inside a @spaces.GPU function.""" global _gpu_model, _gpu_tokenizer if _gpu_model is not None and _gpu_tokenizer is not None: return _gpu_model, _gpu_tokenizer print("[ZeroGPU] Loading tokenizer + model on GPU...") tok = AutoTokenizer.from_pretrained( config.model_name, token=config.hf_token, ) tok.pad_token = tok.eos_token tok.padding_side = "right" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) base = AutoModelForCausalLM.from_pretrained( config.model_name, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.bfloat16, token=config.hf_token, trust_remote_code=True, ) adapter_id = "Maikobi/domain-name-generator" print(f"[ZeroGPU] Loading LoRA adapter: {adapter_id}") peft_model = PeftModel.from_pretrained(base, adapter_id) peft_model.eval() _gpu_tokenizer = tok _gpu_model = peft_model print("[ZeroGPU] Model ready.") return _gpu_model, _gpu_tokenizer @spaces.GPU # Runs in GPU worker process (safe to touch CUDA here) def generate_domains_gpu(business_description, num_domains=3): m, tok = _ensure_model_loaded_on_gpu() prompt_text = _build_prompt(tok, business_description, num_domains=num_domains) enc = tok( prompt_text, return_tensors="pt", padding=True, truncation=True, max_length=config.max_length ) input_ids = enc["input_ids"].to(m.device) attention_mask = enc["attention_mask"].to(m.device) with torch.no_grad(): outputs = m.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=config.max_new_tokens, temperature=config.temperature, do_sample=True, pad_token_id=tok.eos_token_id, eos_token_id=tok.eos_token_id ) gen_ids = outputs[0][input_ids.shape[1]:] decoded = tok.decode(gen_ids, skip_special_tokens=False) for stop in ["<|eot_id|>", "<|end_header_id|>user", "<|start_header_id|>user"]: if stop in decoded: decoded = decoded.split(stop, 1)[0] return decoded # ---------- API (runs in main process; calls GPU function) ---------- def domain_api(business_description): if not business_description or not business_description.strip(): return { "suggestions": [], "status": "error", "message": "Business description is required" } try: generated_text = generate_domains_gpu(business_description.strip(), num_domains=3) refusal_indicators = ["cannot", "inappropriate", "refuse", "unable", "not generate", "not provide", "violates content policy"] if any(indicator in generated_text.lower() for indicator in refusal_indicators): return { "suggestions": [], "status": "blocked", "message": "Request contains inappropriate content" } domain_re = re.compile( r"\b[a-z0-9][a-z0-9-]{1,63}\.(?:com|org|net|io|co|ai|app|dev|studio|pro|cafe|coffee|restaurant|bakery|llc|firm|agency)\b", re.IGNORECASE ) candidates = domain_re.findall(generated_text) seen = set() suggestions = [] for domain in candidates: dl = domain.lower() if dl not in seen and len(suggestions) < 3: seen.add(dl) confidence = calculate_confidence_score(dl, business_description) suggestions.append({"domain": dl, "confidence": round(confidence, 2)}) if not suggestions: base = re.sub(r"[^a-z0-9 ]+", " ", business_description.lower()) words = [w for w in base.split() if len(w) > 2][:3] stem = "".join(words)[:12] or "brand" for domain in [f"{stem}.com", f"{stem}hub.com", f"get{stem}.com"]: confidence = calculate_confidence_score(domain, business_description) suggestions.append({"domain": domain, "confidence": round(confidence, 2)}) return {"suggestions": suggestions, "status": "success"} except Exception as e: return {"suggestions": [], "status": "error", "message": f"Generation failed: {str(e)}"} def domain_api_wrapper(json_input): """ Accepts either: - A single object: {"business_description": "..."} - An array of objects: [{"business_description": "..."}, ...] Returns a single result for object input, or a list of results for array input. """ try: data = json.loads(json_input) if isinstance(json_input, str) else json_input # Batch array input if isinstance(data, list): results = [] for item in data: bd = item.get("business_description", "") if isinstance(item, dict) else "" results.append(domain_api(bd)) return results # Single object input if isinstance(data, dict): bd = data.get("business_description", "") return domain_api(bd) return {"suggestions": [], "status": "error", "message": "Invalid JSON format"} except json.JSONDecodeError: return {"suggestions": [], "status": "error", "message": "Invalid JSON input"} except Exception as e: return {"suggestions": [], "status": "error", "message": f"Input processing failed: {str(e)}"} # ---------- UI ---------- with gr.Blocks(title="Domain Name Generator API") as demo: gr.Markdown("# Domain Name Generator API") gr.Markdown("Generate creative domain names for your business using fine-tuned Llama-3.1-8B") with gr.Tabs(): with gr.Tab("Simple Input"): with gr.Row(): with gr.Column(): business_input = gr.Textbox( label="Business Description", placeholder="organic coffee shop in downtown area", lines=3 ) simple_btn = gr.Button("Generate Domains", variant="primary") with gr.Column(): simple_output = gr.JSON(label="API Response") with gr.Tab("JSON API Format"): with gr.Row(): with gr.Column(): gr.Markdown("**Accepted Formats:**") gr.Code('{"business_description": "your business here"}', language="json") gr.Markdown("or") gr.Code('[{"business_description": "desc1"}, {"business_description": "desc2"}]', language="json") json_input = gr.Textbox( label="JSON Input", placeholder='{"business_description": "organic coffee shop in downtown area"}', lines=8 ) json_btn = gr.Button("Generate Domains", variant="primary") with gr.Column(): json_output = gr.JSON(label="API Response") simple_btn.click(fn=domain_api, inputs=[business_input], outputs=[simple_output]) json_btn.click(fn=domain_api_wrapper, inputs=[json_input], outputs=[json_output]) # Enable queue (required so @spaces.GPU runs in a worker on Stateless GPU) demo.queue() if __name__ == "__main__": demo.launch()