Update app.py
app.py CHANGED
@@ -7,7 +7,6 @@ import gc
 import logging
 from typing import List, Dict, Any, Optional
 
-# Konfigurasi logging
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
@@ -16,47 +15,38 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI(title="TinyLlama API", description="API untuk model TinyLlama-1.1B-Chat")
 
-
-
-model_dir = "model_cache"  # Direktori untuk menyimpan model
+model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+model_dir = "model_cache"
 
-# Variabel global untuk menyimpan model dan tokenizer
 tokenizer = None
 model = None
 is_loading = False
 
 def load_model():
-    """Fungsi untuk memuat atau mengunduh model saat dibutuhkan"""
     global tokenizer, model, is_loading
 
-    # Hindari loading bersamaan
     if is_loading:
         logger.info("Model sedang dimuat oleh proses lain")
         return
 
-    # Cek apakah model telah dimuat
     if tokenizer is None or model is None:
         try:
             is_loading = True
             logger.info(f"Memuat model {model_id}...")
 
-            # Buat direktori cache jika belum ada
             os.makedirs(model_dir, exist_ok=True)
 
-            # Bersihkan memori jika ada model sebelumnya
             if model is not None:
                 del model
                 torch.cuda.empty_cache()
                 gc.collect()
 
-            # Muat tokenizer dengan cache
             tokenizer = AutoTokenizer.from_pretrained(
                 model_id,
                 cache_dir=model_dir,
                 use_fast=True,
             )
 
-            # Muat model dengan cache dan pengaturan hemat memori
             device_map = "auto" if torch.cuda.is_available() else None
 
             model = AutoModelForCausalLM.from_pretrained(
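The hunk above ends mid-statement, so the arguments passed to AutoModelForCausalLM.from_pretrained are not visible in this diff. A minimal sketch of what a cache-aware, memory-conscious load along these lines could look like; the torch_dtype and low_cpu_mem_usage arguments are assumptions, not taken from the file:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model_dir = "model_cache"

    # Tokenizer load mirrors the diff: cached locally, fast implementation.
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=model_dir, use_fast=True)

    # Device placement as computed in the diff, plus assumed memory-saving options.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        cache_dir=model_dir,                                                         # reuse the local cache directory
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,   # assumed dtype choice
        low_cpu_mem_usage=True,                                                      # assumed; lowers peak RAM during loading
    )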
@@ -101,29 +91,22 @@ async def chat(req: ChatRequest):
         raise HTTPException(status_code=500, detail="Gagal memuat model")
 
     try:
-        # Format untuk Phi-1.5
-        # Phi dapat menggunakan format sederhana dengan <|user|>, <|assistant|>
         system_content = ""
 
-        # Cari system prompt jika ada
         for msg in req.messages:
             if msg.role.lower() == "system":
                 system_content = msg.content
                 break
 
-        # Gabungkan pesan dalam format yang sesuai untuk Phi
         messages_text = []
 
-        # Tambahkan system prompt jika ada
         if system_content:
             messages_text.append(f"<|system|>\n{system_content}")
 
-        # Tambahkan pesan user dan assistant
         for msg in req.messages:
             role = msg.role.lower()
             content = msg.content
 
-            # Lewati system prompt karena sudah diproses
             if role == "system":
                 continue
 
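The loop above assembles a Zephyr/TinyLlama-style chat prompt from <|system|>, <|user|> and <|assistant|> markers (the user branch sits between this hunk and the next and is assumed to append <|user|> in the same way). A small illustration of the string the handler would build, using invented example messages:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Halo, apa kabar?"},
    ]

    # Roughly what the concatenation logic produces before generation:
    prompt = (
        "<|system|>\nYou are a helpful assistant.\n"
        "<|user|>\nHalo, apa kabar?\n"
        "<|assistant|>"
    )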
@@ -132,64 +115,50 @@ async def chat(req: ChatRequest):
             elif role == "assistant":
                 messages_text.append(f"<|assistant|>\n{content}")
 
-        # Tambahkan token untuk memulai respons AI
         messages_text.append("<|assistant|>")
 
-        # Gabungkan semua dengan newline
         prompt = "\n".join(messages_text)
 
-        # Encode the prompt
         inputs = tokenizer(prompt, return_tensors="pt")
         input_length = len(inputs.input_ids[0])
 
-        # Pindahkan input ke device yang sama dengan model
         if hasattr(model, 'device'):
             inputs = {key: value.to(model.device) for key, value in inputs.items()}
 
-        # Set parameter generasi yang lebih sesuai
         generation_config = {
             'max_new_tokens': req.max_tokens,
-            'temperature':
-            'top_p':
-            'do_sample': False,
+            'temperature': req.temperature,
+            'top_p': req.top_p,
+            'do_sample': True if req.temperature > 0 else False,
             'pad_token_id': tokenizer.eos_token_id
         }
 
-        # Generate a response
         with torch.no_grad():
             output = model.generate(
                 inputs['input_ids'],
                 **generation_config
             )
 
-        # Decode the output
         result = tokenizer.decode(output[0], skip_special_tokens=True)
 
-        # Cari respons setelah token <|assistant|> terakhir
        assistants = result.split("<|assistant|>")
         if len(assistants) > 1:
             response = assistants[-1].strip()
         else:
-            # Jika tidak ada token <|assistant|>
-            # Ambil respons setelah prompt terakhir
             user_tokens = result.split("<|user|>")
             if len(user_tokens) > 1:
                 last_part = user_tokens[-1]
                 if "\n" in last_part:
-                    # Ambil teks setelah baris pertama (yang berisi prompt user)
                     response = "\n".join(last_part.split("\n")[1:]).strip()
                 else:
                     response = last_part.strip()
             else:
-                # Fallback ke metode sederhana
                 prompt_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))
                 response = result[prompt_length:].strip()
 
-        # Jika respons kosong, berikan pesan default
         if not response:
             response = "Maaf, tidak dapat menghasilkan respons yang valid."
 
-        # Hitung penggunaan token
         output_length = len(output[0])
         new_tokens = output_length - input_length
 
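The handler reads req.messages, req.max_tokens, req.temperature and req.top_p, but the ChatRequest model itself is outside this diff. A hypothetical Pydantic sketch of the shape those accesses imply; the class layout and the default values are assumptions:

    from typing import List
    from pydantic import BaseModel

    class Message(BaseModel):        # hypothetical name; only msg.role and msg.content appear in the diff
        role: str
        content: str

    class ChatRequest(BaseModel):    # fields inferred from the req.* accesses above
        messages: List[Message]
        max_tokens: int = 256        # defaults are assumptions
        temperature: float = 0.7
        top_p: float = 0.9

Because the new code ties do_sample to temperature > 0, a request with temperature 0 falls back to greedy decoding and the top_p value has no effect.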
@@ -226,7 +195,6 @@ async def force_load_model(background_tasks: BackgroundTasks):
     if model is not None:
         return {"status": "already_loaded", "message": f"Model {model_id} sudah dimuat"}
 
-    # Lakukan loading di background untuk tidak memblokir API
     background_tasks.add_task(load_model)
     return {"status": "loading_started", "message": f"Proses memuat model {model_id} telah dimulai"}
 
@@ -247,9 +215,7 @@ async def root():
     }
 
 
-# Untuk menjalankan dengan uvicorn
 if __name__ == "__main__":
     import uvicorn
-    # Mulai server API
     logger.info(f"Memulai server API untuk model {model_id}")
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
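Once the server is running on port 7860, a request to the chat endpoint might look like the following; the /chat route path is not visible in this diff, so it is only an assumption, and the payload mirrors the ChatRequest sketch above:

    import requests

    payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Halo, apa kabar?"},
        ],
        "max_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,
    }

    # "/chat" is a guess at the route; adjust to whatever path the @app.post decorator uses.
    resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
    print(resp.status_code, resp.text)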