Woziii commited on
Commit
e3b2117
·
verified ·
1 Parent(s): b59dc9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -116
app.py CHANGED
@@ -11,21 +11,11 @@ import json
11
  import os
12
 
13
  model_name = "Woziii/llama-3-8b-chat-me"
14
-
15
- # Initialiser le tokenizer
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
-
18
- # Configurer le pad token
19
- tokenizer.pad_token = tokenizer.eos_token
20
-
21
- # Initialiser le modèle
22
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
23
-
24
- # Configurer le pad token ID du modèle
25
- model.config.pad_token_id = tokenizer.pad_token_id
26
 
27
  MAX_MAX_NEW_TOKENS = 250
28
- DEFAULT_MAX_NEW_TOKENS = 50
29
  MAX_INPUT_TOKEN_LENGTH = 2048
30
 
31
 
@@ -159,53 +149,17 @@ def truncate_to_questions(text, max_questions):
159
 
160
  return ' '.join(truncated_sentences)
161
 
162
-
163
-
164
-
165
- def find_logical_stop(text, max_tokens):
166
- """
167
- Trouve un point d'arrêt logique dans le texte, sans dépasser max_tokens.
168
- """
169
- # Définir les motifs de fin logiques
170
- end_patterns = r'(?<=[.!?])\s+|\n|\. |\! |\? '
171
-
172
- # Diviser le texte en segments logiques
173
- segments = re.split(end_patterns, text)
174
-
175
- current_length = 0
176
- result = ""
177
-
178
- for segment in segments:
179
- segment_tokens = len(segment.split())
180
- if current_length + segment_tokens <= max_tokens:
181
- result += segment + " "
182
- current_length += segment_tokens
183
- else:
184
- break
185
-
186
- # Nettoyer et finaliser le résultat
187
- result = result.strip()
188
- if result and result[-1] not in ".!?":
189
- result += "."
190
-
191
- return result
192
-
193
  def post_process_response(response, is_short_response, max_questions=2):
194
- # Limiter au nombre spécifié de questions
195
  truncated_response = truncate_to_questions(response, max_questions)
196
 
197
- # Déterminer la limite de tokens en fonction du type de réponse
198
  if is_short_response:
199
- max_tokens = 70
200
- else:
201
- max_tokens = 150 # Ajustez selon vos besoins
202
 
203
- # Trouver un point d'arrêt logique
204
- final_response = find_logical_stop(truncated_response, max_tokens)
205
-
206
- return final_response
207
-
208
-
209
 
210
  def check_coherence(response):
211
  sentences = re.split(r'(?<=[.!?])\s+', response)
@@ -214,7 +168,7 @@ def check_coherence(response):
214
  return False
215
  return True
216
 
217
- @spaces.GPU(duration=180)
218
  def generate(
219
  message: str,
220
  chat_history: list[tuple[str, str]],
@@ -237,73 +191,57 @@ def generate(
237
  if response_type == "short":
238
  max_new_tokens = max(70, max_new_tokens)
239
  elif response_type == "long":
240
- max_new_tokens = min(max(200, max_new_tokens), 250)
241
  else: # medium
242
- max_new_tokens = min(max(70, max_new_tokens), 150)
243
 
244
- try:
245
- conversation = []
246
-
247
- # Ajout du system prompt et du LUCAS_KNOWLEDGE_BASE
248
- enhanced_system_prompt = f"{system_prompt}\n\n{LUCAS_KNOWLEDGE_BASE}"
249
- conversation.append({"role": "system", "content": enhanced_system_prompt})
250
-
251
- # Ajout des 5 dernières interactions complètes (user uniquement)
252
- for user, assistant in chat_history[-5:]:
253
- conversation.append({"role": "user", "content": user})
254
-
255
- # Ajout du message actuel de l'utilisateur
256
- conversation.append({"role": "user", "content": message})
257
-
258
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
259
- attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
260
 
261
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
262
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
263
- attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
264
- gr.Warning(f"L'entrée de la conversation a été tronquée car elle dépassait {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- input_ids = input_ids.to(model.device)
267
- attention_mask = attention_mask.to(model.device)
268
-
269
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
270
 
271
- generate_kwargs = dict(
272
- input_ids=input_ids,
273
- attention_mask=attention_mask,
274
- streamer=streamer,
275
- max_new_tokens=max_new_tokens,
276
- do_sample=True,
277
- top_p=top_p,
278
- temperature=temperature,
279
- num_beams=1,
280
- )
281
 
282
- t = Thread(target=model.generate, kwargs=generate_kwargs)
283
- t.start()
284
-
285
- outputs = []
286
- for text in streamer:
287
- outputs.append(text)
288
- current_output = "".join(outputs)
289
- processed_output = post_process_response(current_output, response_type == "short")
290
-
291
- if response_type == "long" and not check_coherence(processed_output):
292
- yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
293
- return
294
-
295
- yield processed_output
296
-
297
- final_output = post_process_response("".join(outputs), response_type == "short")
298
- yield final_output
299
-
300
- except Exception as e:
301
- print(f"Une erreur s'est produite : {str(e)}")
302
- yield "Désolé, une erreur s'est produite. Veuillez réessayer."
303
 
304
- finally:
305
- # Nettoyage de la mémoire GPU
306
- torch.cuda.empty_cache()
307
 
308
  def vote(data: gr.LikeData, history):
309
  user_input = history[-1][0] if history else ""
@@ -449,4 +387,4 @@ N'hésitez pas à aborder des sujets variés, allant de l'intelligence artificie
449
  chat_interface.render()
450
  chat_interface.chatbot.like(vote, [chat_interface.chatbot], None)
451
 
452
- demo.queue(max_size=20, default_concurrency_limit=2).launch(max_threads=10)
 
11
  import os
12
 
13
  model_name = "Woziii/llama-3-8b-chat-me"
 
 
 
 
 
 
 
 
14
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
16
 
17
  MAX_MAX_NEW_TOKENS = 250
18
+ DEFAULT_MAX_NEW_TOKENS = 70
19
  MAX_INPUT_TOKEN_LENGTH = 2048
20
 
21
 
 
149
 
150
  return ' '.join(truncated_sentences)
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def post_process_response(response, is_short_response, max_questions=2):
153
+ # Limiter au nombre spécifié de questions, quelle que soit la longueur de la réponse
154
  truncated_response = truncate_to_questions(response, max_questions)
155
 
156
+ # Appliquer la limitation de longueur si nécessaire pour les réponses courtes
157
  if is_short_response:
158
+ sentences = re.split(r'(?<=[.!?])\s+', truncated_response)
159
+ if len(sentences) > 2:
160
+ return ' '.join(sentences[:2]).strip()
161
 
162
+ return truncated_response.strip()
 
 
 
 
 
163
 
164
  def check_coherence(response):
165
  sentences = re.split(r'(?<=[.!?])\s+', response)
 
168
  return False
169
  return True
170
 
171
+ @spaces.GPU(duration=120)
172
  def generate(
173
  message: str,
174
  chat_history: list[tuple[str, str]],
 
191
  if response_type == "short":
192
  max_new_tokens = max(70, max_new_tokens)
193
  elif response_type == "long":
194
+ max_new_tokens = min(max(200, max_new_tokens), 300)
195
  else: # medium
196
+ max_new_tokens = min(max(100, max_new_tokens), 150)
197
 
198
+ conversation = []
199
+
200
+ # Ajout du system prompt et du LUCAS_KNOWLEDGE_BASE
201
+ enhanced_system_prompt = f"{system_prompt}\n\n{LUCAS_KNOWLEDGE_BASE}"
202
+ conversation.append({"role": "system", "content": enhanced_system_prompt})
203
+
204
+ # Ajout des 5 derniers inputs utilisateur uniquement
205
+ for user, _ in chat_history[-5:]:
206
+ conversation.append({"role": "user", "content": user})
207
+
208
+ # Ajout du message actuel de l'utilisateur
209
+ conversation.append({"role": "user", "content": message})
 
 
 
 
210
 
211
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
212
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
213
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
214
+ gr.Warning(f"L'entrée de la conversation a été tronquée car elle d��passait {MAX_INPUT_TOKEN_LENGTH} tokens.")
215
+
216
+ input_ids = input_ids.to(model.device)
217
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
218
+
219
+ generate_kwargs = dict(
220
+ input_ids=input_ids,
221
+ streamer=streamer,
222
+ max_new_tokens=max_new_tokens,
223
+ do_sample=True,
224
+ top_p=top_p,
225
+ temperature=temperature,
226
+ num_beams=1,
227
+ )
228
+
229
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
230
+ t.start()
231
+
232
+ outputs = []
233
+ for text in streamer:
234
+ outputs.append(text)
235
+ partial_output = post_process_response("".join(outputs), response_type == "short")
236
 
237
+ if response_type == "long" and not check_coherence(partial_output):
238
+ yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
239
+ return
 
240
 
241
+ yield partial_output
 
 
 
 
 
 
 
 
 
242
 
243
+ yield post_process_response("".join(outputs), response_type == "short")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
 
 
245
 
246
  def vote(data: gr.LikeData, history):
247
  user_input = history[-1][0] if history else ""
 
387
  chat_interface.render()
388
  chat_interface.chatbot.like(vote, [chat_interface.chatbot], None)
389
 
390
+ demo.queue(max_size=20, default_concurrency_limit=2).launch(max_threads=10)