import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import gc
import logging
from typing import List, Dict, Any, Optional

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(title="TinyLlama API", description="API for the TinyLlama-1.1B-Chat model")

# Use an open-source model that does not require a login
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # TinyLlama Chat model
model_dir = "model_cache"  # Directory used to cache the model

# Global variables holding the model and tokenizer
tokenizer = None
model = None
is_loading = False


def load_model():
    """Load (or download) the model on demand."""
    global tokenizer, model, is_loading

    # Avoid concurrent loading
    if is_loading:
        logger.info("Model is already being loaded by another request")
        return

    # Only load if the model has not been loaded yet
    if tokenizer is None or model is None:
        try:
            is_loading = True
            logger.info(f"Loading model {model_id}...")

            # Create the cache directory if it does not exist
            os.makedirs(model_dir, exist_ok=True)

            # Free memory if a previous model is still around
            if model is not None:
                del model
                torch.cuda.empty_cache()
                gc.collect()

            # Load the tokenizer, using the local cache
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=model_dir,
                use_fast=True,
            )

            # Load the model with caching and memory-saving settings
            device_map = "auto" if torch.cuda.is_available() else None
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                cache_dir=model_dir,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
                device_map=device_map,
            )

            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise e
        finally:
            is_loading = False


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 500
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9


class ChatResponse(BaseModel):
    response: str
    usage: Dict[str, Any]


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    # Make sure the model is loaded before it is used
    if model is None:
        load_model()
        if model is None:
            raise HTTPException(status_code=500, detail="Failed to load model")

    try:
        # Build the prompt in the TinyLlama chat format (Zephyr-style),
        # which uses simple <|system|>, <|user|> and <|assistant|> markers
        system_content = ""

        # Look for a system prompt, if any
        for msg in req.messages:
            if msg.role.lower() == "system":
                system_content = msg.content
                break

        # Assemble the messages in the format expected by the model
        messages_text = []

        # Add the system prompt, if present
        if system_content:
            messages_text.append(f"<|system|>\n{system_content}")

        # Add the user and assistant messages
        for msg in req.messages:
            role = msg.role.lower()
            content = msg.content

            # Skip the system prompt, it has already been handled
            if role == "system":
                continue

            if role == "user":
                messages_text.append(f"<|user|>\n{content}")
            elif role == "assistant":
                messages_text.append(f"<|assistant|>\n{content}")

        # Add the marker that starts the model's response
        messages_text.append("<|assistant|>")

        # Join everything with newlines
        prompt = "\n".join(messages_text)

        # Encode the prompt
        inputs = tokenizer(prompt, return_tensors="pt")
        input_length = len(inputs["input_ids"][0])

        # Move the inputs to the same device as the model
        if hasattr(model, 'device'):
            inputs = {key: value.to(model.device) for key, value in inputs.items()}
        # Generation parameters
        generation_config = {
            'max_new_tokens': req.max_tokens,
            'temperature': req.temperature,
            'top_p': req.top_p,
            'do_sample': True,
            'pad_token_id': tokenizer.eos_token_id
        }

        # Generate a response
        with torch.no_grad():
            output = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                **generation_config
            )

        # Decode the output
        result = tokenizer.decode(output[0], skip_special_tokens=True)

        # Take the text after the last <|assistant|> marker
        assistants = result.split("<|assistant|>")
        if len(assistants) > 1:
            response = assistants[-1].strip()
        else:
            # No <|assistant|> marker found:
            # take the text after the last user prompt instead
            user_tokens = result.split("<|user|>")
            if len(user_tokens) > 1:
                last_part = user_tokens[-1]
                if "\n" in last_part:
                    # Take the text after the first line (which contains the user prompt)
                    response = "\n".join(last_part.split("\n")[1:]).strip()
                else:
                    response = last_part.strip()
            else:
                # Fall back to a simple length-based slice
                prompt_length = len(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True))
                response = result[prompt_length:].strip()

        # If the response is empty, return a default message
        if not response:
            response = "Sorry, a valid response could not be generated."

        # Compute token usage
        output_length = len(output[0])
        new_tokens = output_length - input_length

        usage_info = {
            "prompt_tokens": input_length,
            "completion_tokens": new_tokens,
            "total_tokens": output_length
        }

        return ChatResponse(response=response, usage=usage_info)
    except Exception as e:
        logger.error(f"Error during chat: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to generate a response: {str(e)}")


@app.get("/model-status")
async def model_status():
    status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
    return {
        "status": status,
        "model_id": model_id,
        "device": str(model.device) if model is not None and hasattr(model, 'device') else "not available"
    }


@app.post("/load-model")
async def force_load_model(background_tasks: BackgroundTasks):
    global is_loading
    if is_loading:
        return {"status": "loading", "message": f"Model {model_id} is currently loading"}
    if model is not None:
        return {"status": "already_loaded", "message": f"Model {model_id} is already loaded"}

    # Load in the background so the API is not blocked
    background_tasks.add_task(load_model)
    return {"status": "loading_started", "message": f"Loading of model {model_id} has started"}


@app.get("/")
async def root():
    status = "loading" if is_loading else "not_loaded" if model is None else "loaded"
    return {
        "message": "TinyLlama API is running",
        "model": model_id,
        "status": status,
        "endpoints": [
            {"path": "/", "method": "GET", "description": "API information"},
            {"path": "/chat", "method": "POST", "description": "Chat with the model"},
            {"path": "/model-status", "method": "GET", "description": "Check model status"},
            {"path": "/load-model", "method": "POST", "description": "Load the model if not yet loaded"}
        ]
    }


# Run with uvicorn
if __name__ == "__main__":
    import uvicorn
    # Start the API server
    logger.info(f"Starting API server for model {model_id}")
    uvicorn.run(app, host="0.0.0.0", port=7860)  # Port 7860 is the default port on HF Spaces
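

# --- Example client (sketch, not part of the server) ---
# A minimal sketch of how the /chat endpoint above could be called, assuming the
# server is running locally on port 7860 and the third-party `requests` package
# is installed. The payload fields mirror the ChatRequest model defined above.
#
#   import requests
#
#   payload = {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Explain what FastAPI is in one sentence."}
#       ],
#       "max_tokens": 128,
#       "temperature": 0.7
#   }
#   r = requests.post("http://localhost:7860/chat", json=payload, timeout=300)
#   r.raise_for_status()
#   data = r.json()
#   print(data["response"])   # generated text
#   print(data["usage"])      # prompt/completion/total token counts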