import gc
import os

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

model_id = "mistralai/Mistral-7B-Instruct-v0.1"
model_dir = "model_cache"

# Lazily initialized globals; populated on first use by load_model().
tokenizer = None
model = None


def load_model():
    """Load or download the model when it is needed."""
    global tokenizer, model

    if tokenizer is None or model is None:
        print(f"Loading model {model_id}...")

        os.makedirs(model_dir, exist_ok=True)

        # Release any previously loaded model before reloading.
        if model is not None:
            model = None
            gc.collect()
            torch.cuda.empty_cache()

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=model_dir,
            use_fast=True
        )

        # Use automatic GPU placement when CUDA is available; otherwise load on CPU.
        device_map = "auto" if torch.cuda.is_available() else None

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            cache_dir=model_dir,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            device_map=device_map
        )

        print("Model loaded successfully!")


class ChatRequest(BaseModel):
    messages: list


@app.post("/chat")
async def chat(req: ChatRequest):
    load_model()

    # Build a simple role-tagged prompt from the chat history.
    prompt = ""
    for msg in req.messages:
        role = msg['role']
        content = msg['content']
        prompt += f"[{role.capitalize()}]: {content}\n"
    prompt += "[Assistant]:"

    inputs = tokenizer(prompt, return_tensors="pt")

    # Move the input tensors to the same device as the model.
    if hasattr(model, 'device'):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}

    generation_config = {
        'max_new_tokens': 500,
        'temperature': 0.7,
        'top_p': 0.9,
        'do_sample': True,
        'pad_token_id': tokenizer.eos_token_id
    }

    with torch.no_grad():
        output = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            **generation_config
        )

    # Decode only the newly generated tokens so the prompt is not echoed back.
    generated_tokens = output[0][inputs['input_ids'].shape[-1]:]
    result = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return {"response": result.strip()}


@app.get("/model-status")
async def model_status():
    if model is None:
        return {"status": "not_loaded", "model_id": model_id}
    return {"status": "loaded", "model_id": model_id}


@app.post("/load-model")
async def force_load_model():
    load_model()
    return {"status": "success", "message": f"Model {model_id} loaded successfully"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
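
# Example request against the /chat endpoint, assuming the server is running
# locally on port 7860 (the payload shape follows the ChatRequest model above):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'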