Woziii commited on
Commit
8ebcf37
·
verified ·
1 Parent(s): 4a00ab1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -43
app.py CHANGED
@@ -10,6 +10,9 @@ from huggingface_hub import HfApi, hf_hub_download
10
  import json
11
  import os
12
 
 
 
 
13
  model_name = "Woziii/llama-3-8b-chat-me"
14
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -168,7 +171,7 @@ def check_coherence(response):
168
  return False
169
  return True
170
 
171
- @spaces.GPU(duration=120)
172
  def generate(
173
  message: str,
174
  chat_history: list[tuple[str, str]],
@@ -195,53 +198,67 @@ def generate(
195
  else: # medium
196
  max_new_tokens = min(max(100, max_new_tokens), 150)
197
 
198
- conversation = []
199
-
200
- # Ajout du system prompt et du LUCAS_KNOWLEDGE_BASE
201
- enhanced_system_prompt = f"{system_prompt}\n\n{LUCAS_KNOWLEDGE_BASE}"
202
- conversation.append({"role": "system", "content": enhanced_system_prompt})
203
-
204
- # Ajout des 5 derniers inputs utilisateur uniquement
205
- for user, _ in chat_history[-5:]:
206
- conversation.append({"role": "user", "content": user})
207
-
208
- # Ajout du message actuel de l'utilisateur
209
- conversation.append({"role": "user", "content": message})
 
210
 
211
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
212
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
213
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
214
- gr.Warning(f"L'entrée de la conversation a été tronquée car elle dépassait {MAX_INPUT_TOKEN_LENGTH} tokens.")
215
-
216
- input_ids = input_ids.to(model.device)
217
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
218
-
219
- generate_kwargs = dict(
220
- input_ids=input_ids,
221
- streamer=streamer,
222
- max_new_tokens=max_new_tokens,
223
- do_sample=True,
224
- top_p=top_p,
225
- temperature=temperature,
226
- num_beams=1,
227
- )
228
-
229
- t = Thread(target=model.generate, kwargs=generate_kwargs)
230
- t.start()
231
-
232
- outputs = []
233
- for text in streamer:
234
- outputs.append(text)
235
- partial_output = post_process_response("".join(outputs), response_type == "short")
236
 
237
- if response_type == "long" and not check_coherence(partial_output):
238
- yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
239
- return
 
240
 
241
- yield partial_output
 
 
 
 
 
 
 
 
 
242
 
243
- yield post_process_response("".join(outputs), response_type == "short")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
 
 
245
 
246
  def vote(data: gr.LikeData, history):
247
  user_input = history[-1][0] if history else ""
 
10
  import json
11
  import os
12
 
13
+ tokenizer.pad_token = tokenizer.eos_token
14
+ model.config.pad_token_id = tokenizer.pad_token_id
15
+
16
  model_name = "Woziii/llama-3-8b-chat-me"
17
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
18
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
171
  return False
172
  return True
173
 
174
+ @spaces.GPU(duration=180)
175
  def generate(
176
  message: str,
177
  chat_history: list[tuple[str, str]],
 
198
  else: # medium
199
  max_new_tokens = min(max(100, max_new_tokens), 150)
200
 
201
+ try:
202
+ conversation = []
203
+
204
+ # Ajout du system prompt et du LUCAS_KNOWLEDGE_BASE
205
+ enhanced_system_prompt = f"{system_prompt}\n\n{LUCAS_KNOWLEDGE_BASE}"
206
+ conversation.append({"role": "system", "content": enhanced_system_prompt})
207
+
208
+ # Ajout des 5 dernières interactions complètes (user uniquement)
209
+ for user, assistant in chat_history[-MAX_HISTORY_LENGTH:]:
210
+ conversation.append({"role": "user", "content": user})
211
+
212
+ # Ajout du message actuel de l'utilisateur
213
+ conversation.append({"role": "user", "content": message})
214
 
215
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
216
+ attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
217
+
218
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
219
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
220
+ attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
221
+ gr.Warning(f"L'entrée de la conversation a été tronquée car elle dépassait {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
+ input_ids = input_ids.to(model.device)
224
+ attention_mask = attention_mask.to(model.device)
225
+
226
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
227
 
228
+ generate_kwargs = dict(
229
+ input_ids=input_ids,
230
+ attention_mask=attention_mask,
231
+ streamer=streamer,
232
+ max_new_tokens=max_new_tokens,
233
+ do_sample=True,
234
+ top_p=top_p,
235
+ temperature=temperature,
236
+ num_beams=1,
237
+ )
238
 
239
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
240
+ t.start()
241
+
242
+ outputs = []
243
+ for text in streamer:
244
+ outputs.append(text)
245
+ partial_output = post_process_response("".join(outputs), response_type == "short")
246
+
247
+ if response_type == "long" and not check_coherence(partial_output):
248
+ yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
249
+ return
250
+
251
+ yield partial_output
252
+
253
+ yield post_process_response("".join(outputs), response_type == "short")
254
+
255
+ except Exception as e:
256
+ print(f"Une erreur s'est produite : {str(e)}")
257
+ yield "Désolé, une erreur s'est produite. Veuillez réessayer."
258
 
259
+ finally:
260
+ # Nettoyage de la mémoire GPU
261
+ torch.cuda.empty_cache()
262
 
263
  def vote(data: gr.LikeData, history):
264
  user_input = history[-1][0] if history else ""