chatbotv2 / main.py
sahil239's picture
Upload main.py
521cf23 verified
raw
history blame
2.43 kB
import os
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import HTMLResponse, FileResponse
from starlette.staticfiles import StaticFiles
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel
# One-time, import-time setup: the tokenizer, the base model and the PEFT
# adapter are loaded here so that every request reuses the same objects.
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
adapter_repo = "sahil239/chatbot-v2"

# NOTE(review): trust_remote_code=True allows code shipped in the hub repo to
# run locally — confirm this checkpoint actually needs it.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Half precision when CUDA is available, full precision otherwise.
if torch.cuda.is_available():
    _model_dtype = torch.float16
else:
    _model_dtype = torch.float32

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype=_model_dtype,
    trust_remote_code=True,
)

# Attach the adapter weights on top of the base model, then switch the
# combined model to evaluation mode.
model = PeftModel.from_pretrained(base_model, adapter_repo)
model.eval()
def _trim_reply(text: str) -> str:
    """Return the assistant's reply extracted from raw generated *text*.

    If an ``<|assistant|>`` marker is present, only the text after its last
    occurrence is kept. The result is then cut at the first stop marker
    (``<|user|>``, ``<|end|>`` or ``</s>``) and stripped of whitespace.
    """
    assistant_prefix = "<|assistant|>"
    if assistant_prefix in text:
        reply = text.split(assistant_prefix)[-1].strip()
    else:
        reply = text.strip()

    for stop in ("<|user|>", "<|end|>", "</s>"):
        if stop in reply:
            reply = reply.split(stop)[0].strip()
    return reply


def generate_response(user_message: str) -> str:
    """Generate the assistant's reply to a single user message.

    The message is wrapped in a one-turn chat, rendered with the tokenizer's
    chat template, and completed with greedy decoding (at most 256 new
    tokens).

    Args:
        user_message: The raw text sent by the user.

    Returns:
        The generated reply with chat-template markers removed.
    """
    chat = [{"role": "user", "content": user_message}]

    # add_generation_prompt=True appends the assistant header to the prompt so
    # the model answers the user instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding: temperature/top_p are omitted because they are ignored
    # (and warned about) when do_sample=False.
    outputs = model.generate(
        input_ids=input_ids,
        # Single unpadded sequence, so every position is a real token. Passing
        # the mask explicitly avoids ambiguity from pad_token == eos_token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=256,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the newly generated tokens; the prompt (which contains the
    # user's own text and role markers) is sliced off by length.
    prompt_length = input_ids.shape[-1]
    generated = outputs[0][prompt_length:]
    decoded = tokenizer.decode(generated, skip_special_tokens=True)

    return _trim_reply(decoded)
# Application object; the route decorators below register handlers on it.
app = FastAPI()

# CORS is left fully open: any origin, any HTTP method, any request header.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class ChatRequest(BaseModel):
    # Request body of POST /chat.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays unchanged.)

    # The user's message; handed as-is to generate_response().
    prompt: str
@app.post("/chat")
def chat_endpoint(req: ChatRequest):
    # Run the model on the posted prompt and wrap the reply under the
    # "response" key of the returned mapping.
    return {"response": generate_response(req.prompt)}
@app.get("/", response_class=HTMLResponse)
async def homepage():
    # Serve index.html located next to this module as the landing page.
    html_path = Path(__file__).parent / "index.html"
    # Decode explicitly as UTF-8 so the page content does not depend on the
    # host's locale default encoding.
    html = html_path.read_text(encoding="utf-8")
    return HTMLResponse(content=html, status_code=200)