akane-ai / app.py
Arifzyn19
Add application file
d819961
raw
history blame
1.15 kB
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# FastAPI application object; route handlers below register themselves on it.
app = FastAPI()

model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # example model

# Precision and placement are decided together: float16 is only selected when
# a CUDA device is present, and in that case the weights must also live on the
# GPU. Loading float16 weights but leaving them on the CPU makes generation
# run in half precision on the CPU.
_cuda_available = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if _cuda_available else torch.float32,
)
if _cuda_available:
    # Request handlers move their inputs to ``model.device``, so relocating
    # the model here is enough to run the whole pipeline on the GPU.
    model.to("cuda")
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Conversation turns in order; the chat handler reads the 'role' and
    # 'content' entries of every item.
    messages: list
def _build_prompt(messages):
    """Render chat messages as ``[Role]: content`` lines plus an open assistant turn.

    Raises:
        ValueError: if an item is not a mapping with a string 'role' and a
            'content' entry.
    """
    parts = []
    for index, msg in enumerate(messages):
        if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
            raise ValueError(
                f"messages[{index}] must be an object with 'role' and 'content'"
            )
        role = msg['role']
        if not isinstance(role, str):
            raise ValueError(f"messages[{index}]['role'] must be a string")
        content = msg['content']
        parts.append(f"[{role.capitalize()}]: {content}\n")
    # Leave the assistant turn open so the model continues from here.
    parts.append("[Assistant]:")
    return "".join(parts)


@app.post("/chat")
async def chat(req: ChatRequest):
    """Generate an assistant reply for the conversation in ``req.messages``.

    Args:
        req: Request body whose ``messages`` items each carry a 'role' and a
            'content' entry.

    Returns:
        A dict ``{"response": <text>}`` holding only the newly generated
        text, with surrounding whitespace stripped.

    Raises:
        HTTPException: status 422 when a message item is malformed.
    """
    # Imported locally so this handler does not depend on the file-level
    # import list beyond what is already there.
    from fastapi import HTTPException

    try:
        prompt = _build_prompt(req.messages)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    # Encode the prompt and place the tensors on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    # Pass every encoded tensor (including the attention mask), not just the
    # input ids.
    # NOTE(review): generate() is a synchronous call inside an async handler,
    # so the event loop is blocked while it runs - confirm this is acceptable.
    output = model.generate(**inputs, max_new_tokens=100)

    # The output sequence starts with the prompt tokens. Slice them off by
    # length rather than string-replacing the prompt text, because the decoded
    # text is not guaranteed to reproduce the prompt byte-for-byte.
    prompt_length = inputs['input_ids'].shape[1]
    generated_tokens = output[0][prompt_length:]
    result = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return {"response": result.strip()}