import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import gc
import logging
from typing import List, Dict, Any, Optional

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(title="TinyLlama API", description="API for the TinyLlama-1.1B-Chat model")

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_dir = "model_cache"

tokenizer = None
model = None
is_loading = False


def load_model():
    """Load the tokenizer and model into the module-level globals, guarding against concurrent loads."""
    global tokenizer, model, is_loading

    if is_loading:
        logger.info("Model is already being loaded by another process")
        return

    if tokenizer is None or model is None:
        try:
            is_loading = True
            logger.info(f"Loading model {model_id}...")
            os.makedirs(model_dir, exist_ok=True)

            # Free any previously loaded model before reloading
            if model is not None:
                del model
                torch.cuda.empty_cache()
                gc.collect()

            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=model_dir,
                use_fast=True,
            )

            device_map = "auto" if torch.cuda.is_available() else None
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                cache_dir=model_dir,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                device_map=device_map
            )
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise e
        finally:
            is_loading = False


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 500
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9


class ChatResponse(BaseModel):
    response: str
    usage: Dict[str, Any]


@app.on_event("startup")
async def startup_event():
    load_model()


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    if model is None:
        raise HTTPException(status_code=500, detail="Model is not loaded")

    try:
        # Extract the (first) system message, if any
        system_content = ""
        for msg in req.messages:
            if msg.role.lower() == "system":
                system_content = msg.content
                break

        # Build the prompt in the TinyLlama chat format
        messages_text = []
        if system_content:
            messages_text.append(f"<|system|>\n{system_content}")

        for msg in req.messages:
            role = msg.role.lower()
            content = msg.content
            if role == "system":
                continue
            if role == "user":
                messages_text.append(f"<|user|>\n{content}")
            elif role == "assistant":
                messages_text.append(f"<|assistant|>\n{content}")

        messages_text.append("<|assistant|>")
        prompt = "\n".join(messages_text)

        inputs = tokenizer(prompt, return_tensors="pt")
        input_length = len(inputs['input_ids'][0])

        # Move tensors to the model's device (inputs becomes a plain dict here)
        if hasattr(model, 'device'):
            inputs = {key: value.to(model.device) for key, value in inputs.items()}

        generation_config = {
            'max_new_tokens': req.max_tokens,
            'temperature': req.temperature,
            'top_p': req.top_p,
            'do_sample': req.temperature > 0,
            'pad_token_id': tokenizer.eos_token_id
        }

        with torch.no_grad():
            output = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                **generation_config
            )

        result = tokenizer.decode(output[0], skip_special_tokens=True)

        # Take the text after the last assistant marker; otherwise fall back to stripping the prompt
        assistants = result.split("<|assistant|>")
        if len(assistants) > 1:
            response = assistants[-1].strip()
        else:
            user_tokens = result.split("<|user|>")
            if len(user_tokens) > 1:
                last_part = user_tokens[-1]
                if "\n" in last_part:
                    response = "\n".join(last_part.split("\n")[1:]).strip()
                else:
                    response = last_part.strip()
            else:
                prompt_length = len(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True))
                response = result[prompt_length:].strip()

        if not response:
            response = "Sorry, a valid response could not be generated."
        # Token accounting for the usage field
        output_length = len(output[0])
        new_tokens = output_length - input_length
        usage_info = {
            "prompt_tokens": input_length,
            "completion_tokens": new_tokens,
            "total_tokens": output_length
        }

        return ChatResponse(response=response, usage=usage_info)
    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to generate a response: {str(e)}")


@app.get("/model-status")
async def model_status():
    status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
    return {
        "status": status,
        "model_id": model_id,
        "device": str(model.device) if model is not None and hasattr(model, 'device') else "not available"
    }


@app.post("/load-model")
async def force_load_model(background_tasks: BackgroundTasks):
    if is_loading:
        return {"status": "loading", "message": f"Model {model_id} is currently being loaded"}
    if model is not None:
        return {"status": "already_loaded", "message": f"Model {model_id} is already loaded"}
    background_tasks.add_task(load_model)
    return {"status": "loading_started", "message": f"Loading of model {model_id} has started"}


@app.get("/")
async def root():
    status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
    return {
        "message": "TinyLlama API is running",
        "model": model_id,
        "status": status,
        "endpoints": [
            {"path": "/", "method": "GET", "description": "API information"},
            {"path": "/chat", "method": "POST", "description": "Chat with the model"},
            {"path": "/model-status", "method": "GET", "description": "Check model status"},
            {"path": "/load-model", "method": "POST", "description": "Load the model if it is not loaded yet"}
        ]
    }


if __name__ == "__main__":
    import uvicorn
    logger.info(f"Starting API server for model {model_id}")
    uvicorn.run(app, host="0.0.0.0", port=7860)
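
# Example client (a minimal sketch, not part of the server): it assumes the server above
# is running locally on port 7860 and that the `requests` package is installed. The
# payload fields match the ChatRequest model; run this from a separate Python process.
#
#   import requests
#
#   payload = {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "What is TinyLlama?"}
#       ],
#       "max_tokens": 200,
#       "temperature": 0.7,
#       "top_p": 0.9
#   }
#   resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
#   resp.raise_for_status()
#   data = resp.json()
#   print(data["response"])   # generated text
#   print(data["usage"])      # prompt/completion/total token counts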