# akane-ai / app.py
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import gc

app = FastAPI()

# Model configuration
model_id = "mistralai/Mistral-7B-Instruct-v0.1"
model_dir = "model_cache"  # Directory used to cache the downloaded model

# Global variables holding the loaded model and tokenizer
tokenizer = None
model = None
def load_model():
    """Load (or download) the model the first time it is needed."""
    global tokenizer, model

    # Check whether the model has already been loaded
    if tokenizer is None or model is None:
        print(f"Loading model {model_id}...")

        # Create the cache directory if it does not exist yet
        os.makedirs(model_dir, exist_ok=True)

        # Free memory if a previous model instance is still around
        if model is not None:
            del model
            torch.cuda.empty_cache()
            gc.collect()

        # Load the tokenizer, caching it on disk
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=model_dir,
            use_fast=True
        )

        # Load the model with disk caching and memory-saving settings
        device_map = "auto" if torch.cuda.is_available() else None
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            cache_dir=model_dir,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            device_map=device_map
        )
        print("Model loaded successfully!")

class ChatRequest(BaseModel):
    messages: list
@app.post("/chat")
async def chat(req: ChatRequest):
# Pastikan model dimuat sebelum digunakan
load_model()
prompt = ""
for msg in req.messages:
role = msg['role']
content = msg['content']
prompt += f"[{role.capitalize()}]: {content}\n"
prompt += "[Assistant]:"
# Encode the prompt
inputs = tokenizer(prompt, return_tensors="pt")
# Pindahkan input ke device yang sama dengan model
if hasattr(model, 'device'):
inputs = {key: value.to(model.device) for key, value in inputs.items()}
# Set parameter generasi yang lebih sesuai
generation_config = {
'max_new_tokens': 500,
'temperature': 0.7,
'top_p': 0.9,
'do_sample': True,
'pad_token_id': tokenizer.eos_token_id
}
# Generate a response
with torch.no_grad():
output = model.generate(
inputs['input_ids'],
**generation_config
)
# Decode the output
result = tokenizer.decode(output[0], skip_special_tokens=True)
# Return the response, removing the prompt part
return {"response": result.replace(prompt, "").strip()}

@app.get("/model-status")
async def model_status():
    if model is None:
        return {"status": "not_loaded", "model_id": model_id}
    return {"status": "loaded", "model_id": model_id}
@app.post("/load-model")
async def force_load_model():
load_model()
return {"status": "success", "message": f"Model {model_id} dimuat berhasil"}

# To run with uvicorn
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # Port 7860 is the default port on HF Spaces
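
# Example usage (a sketch; the address assumes the server is running locally on port 7860):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
#
#   curl http://localhost:7860/model-status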