import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # example model
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model once at startup; use half precision on GPU to
# cut memory use, and move the model to the selected device explicitly
# (from_pretrained alone leaves it on the CPU).
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model.eval()

class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[Message]

@app.post("/chat")
def chat(req: ChatRequest):  # plain def so FastAPI runs this blocking call in its threadpool
    # Flatten the chat history into a simple role-tagged prompt.
    prompt = ""
    for msg in req.messages:
        prompt += f"[{msg.role.capitalize()}]: {msg.content}\n"
    prompt += "[Assistant]:"

    # Encode the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate a response; **inputs passes the attention mask along with the
    # input IDs, and an explicit pad token avoids warnings on models that
    # define none.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens; slicing by token count is more
    # robust than string-replacing the prompt out of the decoded text.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return {"response": tokenizer.decode(new_tokens, skip_special_tokens=True).strip()}
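
# Minimal usage sketch. Assumes this file is saved as main.py and uvicorn is
# installed; the filename and port are illustrative assumptions, not part of
# the original.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'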