from pathlib import Path

import torch
from fastapi import FastAPI
from peft import PeftModel
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and LoRA adapter once at startup.
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
adapter_repo = "sahil239/chatbot-v2"

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, adapter_repo)
model.eval()


def generate_response(user_message: str) -> str:
    # Format the message with the model's chat template and append the
    # assistant prefix so the model continues as the assistant.
    chat = [{"role": "user", "content": user_message}]
    inputs = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding; temperature/top_p only apply when do_sample=True.
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=256,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the assistant's last turn; the chat-template markers are plain
    # text in this template, so they survive decoding.
    assistant_prefix = "<|assistant|>"
    if assistant_prefix in decoded:
        reply = decoded.split(assistant_prefix)[-1].strip()
    else:
        reply = decoded.strip()

    # Truncate at any marker that would start a new turn.
    stop_strings = ["<|user|>", "<|end|>"]
    for stop in stop_strings:
        if stop in reply:
            reply = reply.split(stop)[0].strip()
    return reply


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatRequest(BaseModel):
    prompt: str


@app.post("/chat")
def chat_endpoint(req: ChatRequest):
    reply = generate_response(req.prompt)
    return {"response": reply}


@app.get("/", response_class=HTMLResponse)
async def homepage():
    html_path = Path(__file__).parent / "index.html"
    return HTMLResponse(content=html_path.read_text(), status_code=200)
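
# Example usage (a sketch; assumes this file is saved as main.py, which the
# source does not specify):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, who are you?"}'
#
# The /chat endpoint returns JSON of the form {"response": "..."}.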