import time
import uuid

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Quick check of which hardware we ended up on
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device in use: {device}")
# Load a small chat-tuned model (swap in Mistral, Llama 2, etc. if the Space's
# hardware can hold it; TinyLlama fits comfortably on free/ZeroGPU tiers)
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
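# Note: device_map="auto" relies on the accelerate package being installed.
# If it isn't, a minimal fallback sketch that places the model manually
# (an assumption, not part of the original setup):
# model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto").to(device)
# generator = pipeline("text-generation", model=model, tokenizer=tokenizer)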
# Create app
app = FastAPI()
# Data format matching the OpenAI chat completions API
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: list[Message]
    temperature: float = 0.7
    top_p: float = 1.0
    max_tokens: int = 256
    stream: bool = False
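# Example body this schema accepts (illustrative values, not from a real client):
# {
#   "model": "tinyllama",
#   "messages": [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Hello!"}
#   ],
#   "temperature": 0.7,
#   "max_tokens": 128
# }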
# Register the endpoint at the path OpenAI clients expect
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    # Combine chat messages into a single prompt string
    prompt = ""
    for msg in request.messages:
        prompt += f"{msg.role}: {msg.content}\n"
    prompt += "assistant:"
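    # The plain "role: content" concatenation above is fine for a demo, but
    # chat-tuned models usually respond better to their own prompt template.
    # A sketch of the template-aware alternative, assuming the tokenizer ships
    # a chat template (TinyLlama-Chat's does) and pydantic v2's model_dump():
    # prompt = tokenizer.apply_chat_template(
    #     [m.model_dump() for m in request.messages],
    #     tokenize=False,
    #     add_generation_prompt=True,
    # )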
    # Generate output (do_sample=True so temperature/top_p actually take effect)
    output = generator(
        prompt,
        max_new_tokens=request.max_tokens,
        do_sample=True,
        temperature=request.temperature,
        top_p=request.top_p,
    )[0]["generated_text"]
    # The pipeline returns the prompt too; keep only the text after the
    # final "assistant:" marker
    assistant_reply = output.split("assistant:")[-1].strip()
    # Count tokens so the usage block is populated instead of zeroed out
    prompt_tokens = len(tokenizer.encode(prompt))
    completion_tokens = len(tokenizer.encode(assistant_reply))
    # Build OpenAI-compatible response
    return JSONResponse({
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": assistant_reply
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    })
# Run the server when executed locally (on Spaces, the container's start
# command launches the server on port 7860, the port Spaces expects)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
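# A quick smoke test once the server is running (hypothetical local call,
# model/message values are placeholders):
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "tinyllama", "messages": [{"role": "user", "content": "Hi!"}]}'
#
# The official openai Python client can also point at this server by passing
# base_url="http://localhost:7860/v1" (any api_key value works; it is not checked).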