Arifzyn committed on
Commit 7bf34a0 · verified · 1 Parent(s): d3914ef

Update app.py

Files changed (1): app.py +6 -40
app.py CHANGED
@@ -7,7 +7,6 @@ import gc
 import logging
 from typing import List, Dict, Any, Optional
 
-# Logging configuration
 logging.basicConfig(
     level=logging.INFO,
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
@@ -16,47 +15,38 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI(title="TinyLlama API", description="API untuk model TinyLlama-1.1B-Chat")
 
-# Use an open-source model that does not require a login
-model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # TinyLlama Chat model
-model_dir = "model_cache"  # Directory for caching the model
+model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+model_dir = "model_cache"
 
-# Global variables holding the model and tokenizer
 tokenizer = None
 model = None
 is_loading = False
 
 def load_model():
-    """Load or download the model when needed"""
     global tokenizer, model, is_loading
 
-    # Avoid concurrent loading
     if is_loading:
         logger.info("Model sedang dimuat oleh proses lain")
         return
 
-    # Check whether the model has already been loaded
    if tokenizer is None or model is None:
        try:
            is_loading = True
            logger.info(f"Memuat model {model_id}...")
 
-            # Create the cache directory if it does not exist
            os.makedirs(model_dir, exist_ok=True)
 
-            # Free memory if a previous model exists
            if model is not None:
                del model
                torch.cuda.empty_cache()
                gc.collect()
 
-            # Load the tokenizer with caching
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=model_dir,
                use_fast=True,
            )
 
-            # Load the model with caching and memory-saving settings
            device_map = "auto" if torch.cuda.is_available() else None
 
            model = AutoModelForCausalLM.from_pretrained(
@@ -101,29 +91,22 @@ async def chat(req: ChatRequest):
         raise HTTPException(status_code=500, detail="Gagal memuat model")
 
     try:
-        # Format for Phi-1.5
-        # Phi can use a simple format with <|user|>, <|assistant|>
         system_content = ""
 
-        # Find the system prompt, if any
         for msg in req.messages:
             if msg.role.lower() == "system":
                 system_content = msg.content
                 break
 
-        # Combine the messages in a format suitable for Phi
         messages_text = []
 
-        # Add the system prompt, if any
         if system_content:
             messages_text.append(f"<|system|>\n{system_content}")
 
-        # Add the user and assistant messages
         for msg in req.messages:
             role = msg.role.lower()
             content = msg.content
 
-            # Skip the system prompt since it has already been handled
             if role == "system":
                 continue
 
@@ -132,64 +115,50 @@ async def chat(req: ChatRequest):
             elif role == "assistant":
                 messages_text.append(f"<|assistant|>\n{content}")
 
-        # Add the token that starts the AI response
         messages_text.append("<|assistant|>")
 
-        # Join everything with newlines
         prompt = "\n".join(messages_text)
 
-        # Encode the prompt
         inputs = tokenizer(prompt, return_tensors="pt")
         input_length = len(inputs.input_ids[0])
 
-        # Move the inputs to the same device as the model
         if hasattr(model, 'device'):
             inputs = {key: value.to(model.device) for key, value in inputs.items()}
 
-        # Set more suitable generation parameters
         generation_config = {
             'max_new_tokens': req.max_tokens,
-            'temperature': 0.7,
-            'top_p': 0.9,
-            'do_sample': False,
+            'temperature': req.temperature,
+            'top_p': req.top_p,
+            'do_sample': True if req.temperature > 0 else False,
             'pad_token_id': tokenizer.eos_token_id
         }
 
-        # Generate a response
         with torch.no_grad():
             output = model.generate(
                 inputs['input_ids'],
                 **generation_config
             )
 
-        # Decode the output
         result = tokenizer.decode(output[0], skip_special_tokens=True)
 
-        # Find the response after the last <|assistant|> token
         assistants = result.split("<|assistant|>")
         if len(assistants) > 1:
             response = assistants[-1].strip()
         else:
-            # There is no <|assistant|> token,
-            # so take the response after the last prompt
             user_tokens = result.split("<|user|>")
             if len(user_tokens) > 1:
                 last_part = user_tokens[-1]
                 if "\n" in last_part:
-                    # Take the text after the first line (which holds the user prompt)
                     response = "\n".join(last_part.split("\n")[1:]).strip()
                 else:
                     response = last_part.strip()
             else:
-                # Fall back to a simple method
                 prompt_length = len(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))
                 response = result[prompt_length:].strip()
 
-        # If the response is empty, return a default message
         if not response:
             response = "Maaf, tidak dapat menghasilkan respons yang valid."
 
-        # Count token usage
         output_length = len(output[0])
         new_tokens = output_length - input_length
 
@@ -226,7 +195,6 @@ async def force_load_model(background_tasks: BackgroundTasks):
     if model is not None:
         return {"status": "already_loaded", "message": f"Model {model_id} sudah dimuat"}
 
-    # Do the loading in the background so the API is not blocked
     background_tasks.add_task(load_model)
     return {"status": "loading_started", "message": f"Proses memuat model {model_id} telah dimulai"}
 
@@ -247,9 +215,7 @@ async def root():
     }
 
 
-# To run with uvicorn
 if __name__ == "__main__":
     import uvicorn
-    # Start the API server
     logger.info(f"Memulai server API untuk model {model_id}")
-    uvicorn.run(app, host="0.0.0.0", port=7860)  # Port 7860 is the default port on HF Spaces
+    uvicorn.run(app, host="0.0.0.0", port=7860)
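
With this change the sampling settings are no longer hard-coded: temperature and top_p are taken from the request, and sampling is only enabled when the requested temperature is greater than zero. As a minimal sketch, a client call against the updated API could look like the snippet below; the /chat path and the exact ChatRequest field names (messages, max_tokens, temperature, top_p) are assumptions inferred from the attributes the handler reads, since the route decorator and the Pydantic model are not shown in this diff.

import requests

# Hypothetical payload matching the fields the handler reads from ChatRequest;
# the field names and the /chat route are assumptions, not confirmed by this commit.
payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "max_tokens": 128,
    "temperature": 0.7,  # > 0, so the server sets do_sample=True
    "top_p": 0.9,
}

resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
print(resp.json())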
 