Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import json | |
| import re | |
| import os | |
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel | |
| from dataclasses import dataclass | |
@dataclass
class APIConfig:
    """Runtime settings for model loading and text generation.

    Every field has a default, so ``APIConfig()`` works as-is; individual
    values can be overridden with keyword arguments (e.g. in tests).
    """

    # Base model the LoRA adapter is applied on top of.
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    # Hugging Face access token for the gated base model. Read once, when the
    # class body executes; None when HF_TOKEN is not set.
    hf_token: "str | None" = os.getenv("HF_TOKEN")  # Set this in Space -> Settings -> Secrets
    # Maximum prompt length in tokens (used as the tokenizer's max_length).
    max_length: int = 256
    # Upper bound on newly generated tokens per request.
    max_new_tokens: int = 150
    # Sampling temperature passed to generate().
    temperature: float = 0.7
# ---- CPU/main-process globals must NOT touch CUDA ----
# Shared settings instance read by model loading and generation.
config = APIConfig()
# Filled in lazily by _ensure_model_loaded_on_gpu(); None until the first load.
_gpu_model = _gpu_tokenizer = None
def _build_prompt(tok, business_description, num_domains=3):
    """Render the prompt asking the model for domain-name suggestions.

    Args:
        tok: Tokenizer used to format the conversation.
        business_description: Free-text description of the business.
        num_domains: How many domain names to request.

    Returns:
        The prompt string. It uses the tokenizer's chat template (with the
        generation prompt appended) when one is configured, and a plain-text
        ``[SYSTEM]/[USER]/[ASSISTANT]`` layout otherwise.
    """
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant that generates creative and relevant domain names for businesses. You refuse to generate domains for inappropriate or harmful content."},
        {"role": "user", "content": f"Generate {num_domains} domain names for the following business: {business_description}"}
    ]
    # Modern tokenizers always *have* an apply_chat_template method, even when
    # no template is configured (it then errors or uses a generic default).
    # So check for an actual template before relying on it.
    if hasattr(tok, "apply_chat_template") and getattr(tok, "chat_template", None):
        return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # Fallback if the tokenizer has no chat template
    sys_msg = messages[0]["content"]
    user_msg = messages[1]["content"]
    return f"<s>[SYSTEM]\n{sys_msg}\n[/SYSTEM]\n[USER]\n{user_msg}\n[/USER]\n[ASSISTANT]\n"
def calculate_confidence_score(domain, business_description):
    """Heuristically rate a suggested domain; the result lies in [0.5, 1.0].

    Starting from a 0.5 baseline, bonuses are added for a 6-15 character
    name part (+0.15), a name that is not simply one word of the description
    (+0.15), one of a few "modern" TLDs (+0.1, case-sensitive check on the raw
    domain), and a name free of digits and hyphens (+0.1).
    """
    label = domain.split('.')[0].lower()
    description_words = set(business_description.lower().split())

    has_good_length = 6 <= len(label) <= 15
    is_not_copied = label not in description_words
    has_modern_tld = domain.endswith(('.co', '.io', '.app', '.studio', '.pro'))
    is_clean = re.search(r'[0-9-]', label) is None

    # Bonuses are accumulated in a fixed order so float rounding is stable.
    total = 0.5
    total += 0.15 if has_good_length else 0
    total += 0.15 if is_not_copied else 0
    total += 0.1 if has_modern_tld else 0
    total += 0.1 if is_clean else 0
    return min(total, 1.0)
# ---------- ALL CUDA/MODEL WORK HAPPENS ONLY BELOW ----------
def _ensure_model_loaded_on_gpu():
    """Load (once) and return the 4-bit base model with the LoRA adapter, plus its tokenizer.

    This function is ONLY called from inside a @spaces.GPU function.

    Returns:
        ``(model, tokenizer)``: the eval-mode PeftModel and its tokenizer.
        Both are cached in the module globals ``_gpu_model`` /
        ``_gpu_tokenizer`` and reused on later calls in the same process.
    """
    global _gpu_model, _gpu_tokenizer
    # NOTE(review): if the @spaces.GPU function runs in a separate worker
    # process, these globals may not survive between requests -- confirm the
    # cache actually avoids reloading.
    if _gpu_model is not None and _gpu_tokenizer is not None:
        return _gpu_model, _gpu_tokenizer
    print("[ZeroGPU] Loading tokenizer + model on GPU...")
    tok = AutoTokenizer.from_pretrained(
        config.model_name,
        token=config.hf_token,
    )
    # Reuse EOS as the pad token (generate() is also given eos as pad_token_id).
    tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    # 4-bit NF4 quantization with bf16 compute to fit the 8B model in memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # Llama is a built-in transformers architecture, so trust_remote_code is
    # deliberately left off: no hub-hosted Python code needs to be executed.
    base = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=config.hf_token,
    )
    adapter_id = "Maikobi/domain-name-generator"
    print(f"[ZeroGPU] Loading LoRA adapter: {adapter_id}")
    peft_model = PeftModel.from_pretrained(base, adapter_id)
    peft_model.eval()
    _gpu_tokenizer = tok
    _gpu_model = peft_model
    print("[ZeroGPU] Model ready.")
    return _gpu_model, _gpu_tokenizer
def _needs_bos(tok, prompt_text):
    """Return True if the tokenizer should prepend BOS when encoding *prompt_text*.

    Chat templates such as Llama 3.x's already emit the BOS token themselves;
    letting the tokenizer add another one would duplicate it.
    """
    bos = getattr(tok, "bos_token", None)
    return not (bos and prompt_text.startswith(bos))


def _fit_description(tok, business_description, num_domains, max_length):
    """Shorten *business_description* so the rendered prompt fits in *max_length* tokens.

    Truncating the rendered prompt itself would cut off its tail -- the
    assistant header the model needs in order to start answering -- so the
    token budget is taken out of the free-text description instead.
    """

    def prompt_len(description):
        prompt = _build_prompt(tok, description, num_domains=num_domains)
        return len(tok(prompt, add_special_tokens=_needs_bos(tok, prompt))["input_ids"])

    if prompt_len(business_description) <= max_length:
        return business_description
    budget = max(max_length - prompt_len(""), 0)
    desc_ids = tok(business_description, add_special_tokens=False)["input_ids"][:budget]
    return tok.decode(desc_ids, skip_special_tokens=True)


# Runs in GPU worker process (safe to touch CUDA here)
@spaces.GPU
def generate_domains_gpu(business_description, num_domains=3):
    """Generate raw model text containing domain-name suggestions.

    Args:
        business_description: Free-text description of the business.
        num_domains: Number of domain names requested in the prompt.

    Returns:
        The decoded completion (prompt tokens excluded), cut at the first
        end-of-turn / new-user-turn marker.
    """
    # NOTE(review): the first call also loads the model inside this GPU slot;
    # if that exceeds the ZeroGPU time limit, consider @spaces.GPU(duration=...).
    m, tok = _ensure_model_loaded_on_gpu()
    description = _fit_description(tok, business_description, num_domains, config.max_length)
    prompt_text = _build_prompt(tok, description, num_domains=num_domains)
    enc = tok(
        prompt_text,
        return_tensors="pt",
        add_special_tokens=_needs_bos(tok, prompt_text),
        # Safety net only: the description was already trimmed to fit.
        truncation=True,
        max_length=config.max_length,
    )
    input_ids = enc["input_ids"].to(m.device)
    attention_mask = enc["attention_mask"].to(m.device)
    with torch.no_grad():
        outputs = m.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=config.max_new_tokens,
            temperature=config.temperature,
            do_sample=True,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    # Keep only the newly generated tokens.
    gen_ids = outputs[0][input_ids.shape[1]:]
    decoded = tok.decode(gen_ids, skip_special_tokens=False)
    # Drop anything after the end of the assistant turn.
    for stop in ["<|eot_id|>", "<|end_header_id|>user", "<|start_header_id|>user"]:
        if stop in decoded:
            decoded = decoded.split(stop, 1)[0]
    return decoded
# ---------- API (runs in main process; calls GPU function) ----------
# Phrases whose presence in the model output is treated as a refusal.
_REFUSAL_INDICATORS = (
    "cannot", "inappropriate", "refuse", "unable",
    "not generate", "not provide", "violates content policy",
)
# A 2-63 character label (letters, digits, inner hyphens; no leading or
# trailing hyphen) followed by one of the TLDs the model is expected to use.
_DOMAIN_RE = re.compile(
    r"\b[a-z0-9][a-z0-9-]{0,61}[a-z0-9]\.(?:com|org|net|io|co|ai|app|dev|studio|pro|cafe|coffee|restaurant|bakery|llc|firm|agency)\b",
    re.IGNORECASE,
)


def _extract_domains(text, limit=3):
    """Return up to *limit* distinct lowercase domains found in *text*, in order of appearance."""
    unique = dict.fromkeys(match.lower() for match in _DOMAIN_RE.findall(text))
    return list(unique)[:limit]


def _fallback_domains(business_description):
    """Build three simple domains from the description's first words (len > 2)."""
    base = re.sub(r"[^a-z0-9 ]+", " ", business_description.lower())
    words = [w for w in base.split() if len(w) > 2][:3]
    stem = "".join(words)[:12] or "brand"
    return [f"{stem}.com", f"{stem}hub.com", f"get{stem}.com"]


def domain_api(business_description):
    """Generate up to three scored domain suggestions for a business description.

    Args:
        business_description: Free-text description; must be a non-empty string.

    Returns:
        A dict with ``suggestions`` (list of ``{"domain", "confidence"}``) and
        ``status``: ``"success"``, ``"blocked"`` (the model output looked like a
        refusal) or ``"error"`` (with a ``message``). Failures are reported in
        the returned dict rather than raised.
    """
    # JSON callers can send numbers, lists, etc.; reject them with an error
    # payload instead of crashing on .strip().
    if business_description is not None and not isinstance(business_description, str):
        return {
            "suggestions": [],
            "status": "error",
            "message": "Business description must be a string"
        }
    if not business_description or not business_description.strip():
        return {
            "suggestions": [],
            "status": "error",
            "message": "Business description is required"
        }
    description = business_description.strip()
    try:
        generated_text = generate_domains_gpu(description, num_domains=3)
        lowered = generated_text.lower()
        if any(indicator in lowered for indicator in _REFUSAL_INDICATORS):
            return {
                "suggestions": [],
                "status": "blocked",
                "message": "Request contains inappropriate content"
            }
        # If nothing domain-like was parsed, fall back to description-based names.
        domains = _extract_domains(generated_text, limit=3) or _fallback_domains(description)
        suggestions = [
            {"domain": domain, "confidence": round(calculate_confidence_score(domain, description), 2)}
            for domain in domains
        ]
        return {"suggestions": suggestions, "status": "success"}
    except Exception as e:
        # API boundary: report any generation failure to the caller as data.
        return {"suggestions": [], "status": "error", "message": f"Generation failed: {str(e)}"}
def domain_api_wrapper(json_input):
    """Route a JSON request to domain_api.

    Accepts either:
      - A single object: {"business_description": "..."}
      - An array of objects: [{"business_description": "..."}, ...]

    String input is parsed as JSON; any other value is treated as already
    parsed. Returns one result dict for an object. For an array it returns a
    list of result dicts, one per element, with non-object elements treated as
    an empty description. Malformed or unsupported input yields an error dict.
    """

    def description_of(entry):
        # Only objects can carry a description; anything else counts as empty.
        if isinstance(entry, dict):
            return entry.get("business_description", "")
        return ""

    try:
        if isinstance(json_input, str):
            payload = json.loads(json_input)
        else:
            payload = json_input

        if isinstance(payload, dict):
            return domain_api(description_of(payload))
        if isinstance(payload, list):
            return [domain_api(description_of(entry)) for entry in payload]
        return {"suggestions": [], "status": "error", "message": "Invalid JSON format"}
    except json.JSONDecodeError:
        return {"suggestions": [], "status": "error", "message": "Invalid JSON input"}
    except Exception as e:
        return {"suggestions": [], "status": "error", "message": f"Input processing failed: {str(e)}"}
# ---------- UI ----------
# Example payloads shown in the "JSON API Format" tab.
_SINGLE_FORMAT_EXAMPLE = '{"business_description": "your business here"}'
_BATCH_FORMAT_EXAMPLE = '[{"business_description": "desc1"}, {"business_description": "desc2"}]'
_JSON_PLACEHOLDER = '{"business_description": "organic coffee shop in downtown area"}'

with gr.Blocks(title="Domain Name Generator API") as demo:
    gr.Markdown("# Domain Name Generator API")
    gr.Markdown("Generate creative domain names for your business using fine-tuned Llama-3.1-8B")

    with gr.Tabs():
        # Free-text description, answered by domain_api.
        with gr.Tab("Simple Input"):
            with gr.Row():
                with gr.Column():
                    business_input = gr.Textbox(
                        lines=3,
                        placeholder="organic coffee shop in downtown area",
                        label="Business Description",
                    )
                    simple_btn = gr.Button("Generate Domains", variant="primary")
                with gr.Column():
                    simple_output = gr.JSON(label="API Response")

        # Raw JSON object or array, answered by domain_api_wrapper.
        with gr.Tab("JSON API Format"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**Accepted Formats:**")
                    gr.Code(_SINGLE_FORMAT_EXAMPLE, language="json")
                    gr.Markdown("or")
                    gr.Code(_BATCH_FORMAT_EXAMPLE, language="json")
                    json_input = gr.Textbox(
                        lines=8,
                        placeholder=_JSON_PLACEHOLDER,
                        label="JSON Input",
                    )
                    json_btn = gr.Button("Generate Domains", variant="primary")
                with gr.Column():
                    json_output = gr.JSON(label="API Response")

    # Each button sends its tab's input to the matching API function.
    simple_btn.click(fn=domain_api, inputs=[business_input], outputs=[simple_output])
    json_btn.click(fn=domain_api_wrapper, inputs=[json_input], outputs=[json_output])

# Enable queue (required so @spaces.GPU runs in a worker on Stateless GPU)
demo.queue()

if __name__ == "__main__":
    demo.launch()