import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # example model
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load in half precision on GPU; fall back to full precision on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)


class ChatRequest(BaseModel):
    messages: list[dict]


@app.post("/chat")
def chat(req: ChatRequest):
    # Plain `def` (not `async def`): FastAPI runs sync handlers in a thread
    # pool, so the blocking generate() call does not stall the event loop.
    prompt = ""
    for msg in req.messages:
        role = msg["role"]
        content = msg["content"]
        prompt += f"[{role.capitalize()}]: {content}\n"
    prompt += "[Assistant]:"

    # Encode the prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate a response (pass the attention mask along with the input ids;
    # pad_token_id is set explicitly since this tokenizer defines no pad token)
    output = model.generate(
        **inputs,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens; slicing by token count is more
    # robust than stripping the prompt string from the decoded text, which can
    # fail when decoding alters whitespace.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return {"response": result.strip()}
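# A minimal sketch of how to exercise the endpoint once the server is up.
# Assumptions not taken from the code above: the file is saved as main.py,
# the server was started with `uvicorn main:app`, and uvicorn is listening
# on its default host and port (127.0.0.1:8000).
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8000/chat",
#       json={"messages": [{"role": "user", "content": "Hello, who are you?"}]},
#   )
#   print(resp.json()["response"])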