Woziii commited on
Commit
b59dc9b
·
verified ·
1 Parent(s): ed870be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -27
app.py CHANGED
@@ -25,7 +25,7 @@ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torc
25
  model.config.pad_token_id = tokenizer.pad_token_id
26
 
27
  MAX_MAX_NEW_TOKENS = 250
28
- DEFAULT_MAX_NEW_TOKENS = 70
29
  MAX_INPUT_TOKEN_LENGTH = 2048
30
 
31
 
@@ -161,35 +161,47 @@ def truncate_to_questions(text, max_questions):
161
 
162
 
163
 
164
- def post_process_response(response, is_short_response, max_questions=2):
165
- # Limiter au nombre spécifié de questions, quelle que soit la longueur de la réponse
166
- truncated_response = truncate_to_questions(response, max_questions)
 
 
 
 
 
 
 
167
 
168
- # Diviser la réponse en phrases
169
- sentences = re.split(r'(?<=[.!?])\s+', truncated_response)
170
 
171
- # Fonction pour compter les tokens (approximation)
172
- def count_tokens(text):
173
- return len(text.split())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # Déterminer la limite de tokens en fonction du type de réponse
176
  if is_short_response:
177
- token_limit = 70
178
  else:
179
- token_limit = 150 # Pour les réponses moyennes, ajustez si nécessaire
180
-
181
- # Construire la réponse finale
182
- final_response = ""
183
- for sentence in sentences:
184
- if count_tokens(final_response + sentence) <= token_limit:
185
- final_response += sentence + " "
186
- else:
187
- break
188
 
189
- # S'assurer que la réponse se termine par une ponctuation appropriée
190
- final_response = final_response.strip()
191
- if final_response and final_response[-1] not in ".!?":
192
- final_response += "."
193
 
194
  return final_response
195
 
@@ -273,15 +285,17 @@ def generate(
273
  outputs = []
274
  for text in streamer:
275
  outputs.append(text)
276
- partial_output = post_process_response("".join(outputs), response_type == "short")
 
277
 
278
- if response_type == "long" and not check_coherence(partial_output):
279
  yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
280
  return
281
 
282
- yield partial_output
283
 
284
- yield post_process_response("".join(outputs), response_type == "short")
 
285
 
286
  except Exception as e:
287
  print(f"Une erreur s'est produite : {str(e)}")
 
25
  model.config.pad_token_id = tokenizer.pad_token_id
26
 
27
  MAX_MAX_NEW_TOKENS = 250
28
+ DEFAULT_MAX_NEW_TOKENS = 50
29
  MAX_INPUT_TOKEN_LENGTH = 2048
30
 
31
 
 
161
 
162
 
163
 
164
+
165
+ def find_logical_stop(text, max_tokens):
166
+ """
167
+ Trouve un point d'arrêt logique dans le texte, sans dépasser max_tokens.
168
+ """
169
+ # Définir les motifs de fin logiques
170
+ end_patterns = r'(?<=[.!?])\s+|\n|\. |\! |\? '
171
+
172
+ # Diviser le texte en segments logiques
173
+ segments = re.split(end_patterns, text)
174
 
175
+ current_length = 0
176
+ result = ""
177
 
178
+ for segment in segments:
179
+ segment_tokens = len(segment.split())
180
+ if current_length + segment_tokens <= max_tokens:
181
+ result += segment + " "
182
+ current_length += segment_tokens
183
+ else:
184
+ break
185
+
186
+ # Nettoyer et finaliser le résultat
187
+ result = result.strip()
188
+ if result and result[-1] not in ".!?":
189
+ result += "."
190
+
191
+ return result
192
+
193
+ def post_process_response(response, is_short_response, max_questions=2):
194
+ # Limiter au nombre spécifié de questions
195
+ truncated_response = truncate_to_questions(response, max_questions)
196
 
197
  # Déterminer la limite de tokens en fonction du type de réponse
198
  if is_short_response:
199
+ max_tokens = 70
200
  else:
201
+ max_tokens = 150 # Ajustez selon vos besoins
 
 
 
 
 
 
 
 
202
 
203
+ # Trouver un point d'arrêt logique
204
+ final_response = find_logical_stop(truncated_response, max_tokens)
 
 
205
 
206
  return final_response
207
 
 
285
  outputs = []
286
  for text in streamer:
287
  outputs.append(text)
288
+ current_output = "".join(outputs)
289
+ processed_output = post_process_response(current_output, response_type == "short")
290
 
291
+ if response_type == "long" and not check_coherence(processed_output):
292
  yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
293
  return
294
 
295
+ yield processed_output
296
 
297
+ final_output = post_process_response("".join(outputs), response_type == "short")
298
+ yield final_output
299
 
300
  except Exception as e:
301
  print(f"Une erreur s'est produite : {str(e)}")