longvnhue1 committed
Commit b701a5b · 1 Parent(s): e263c34
Files changed (1)
  1. app.py +32 -58
app.py CHANGED
@@ -1,47 +1,32 @@
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
-import time
 from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 import torch
 import re
-#fastapi
-app = FastAPI()
+
+app = FastAPI()
 
-def simple_sentence_tokenize(text):
-    # Split into sentences at '.', '?' or '!' followed by whitespace or a newline
-    sentence_endings = re.compile(r'(?<=[.!?])\s+')
-    return sentence_endings.split(text)
 
-def split_text_by_sentences(text, min_words=150, max_words=200, fallback_words=180):
-    sentences = simple_sentence_tokenize(text)
+def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
+    words = re.findall(r'\S+|\n', text)  # keep '\n' as its own "word"
     chunks = []
-    current_chunk = []
-    current_word_count = 0
-
-    for sentence in sentences:
-        sentence = sentence.strip()
-        if not sentence:
-            continue
-
-        word_count = len(sentence.split())
-
-        if current_word_count + word_count <= max_words:
-            current_chunk.append(sentence)
-            current_word_count += word_count
+    start = 0
+    while start < len(words):
+        end = min(start + max_words, len(words))
+        # Find the last period in the min_words..max_words window
+        dot_idx = -1
+        for i in range(start + min_words, min(start + max_words, len(words))):
+            if words[i] == '.' or (words[i].endswith('.') and words[i] != '\n'):
+                dot_idx = i
+        if dot_idx != -1:
+            chunk_end = dot_idx + 1
+        elif end - start > fallback_words:
+            chunk_end = start + fallback_words
         else:
-            if current_word_count >= min_words:
-                chunks.append(' '.join(current_chunk))
-                current_chunk = [sentence]
-                current_word_count = word_count
-            else:
-                if current_chunk:
-                    chunks.append(' '.join(current_chunk))
-                current_chunk = [sentence]
-                current_word_count = word_count
-
-    if current_chunk:
-        chunks.append(' '.join(current_chunk))
-
+            chunk_end = end
+        chunk = ' '.join(words[start:chunk_end]).replace(' \n ', '\n').replace(' \n', '\n').replace('\n ', '\n')
+        chunks.append(chunk.strip())
+        start = chunk_end
     return chunks
 
 # Load model
@@ -80,31 +65,20 @@ class TranslateRequest(BaseModel):
 @app.post("/translate")
 def translate_text(req: TranslateRequest):
     tokenizer.src_lang = req.source_lang
-    text_chunks = split_text_by_sentences(req.text, min_words=150, max_words=200, fallback_words=180)
-
+    text_chunks = split_by_words_and_dot(req.text, min_words=125, max_words=160, fallback_words=150)
     translated_chunks = []
-    total_gen_time = 0
-
-    with torch.inference_mode():
-        for chunk in text_chunks:
-            encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
-
-            start = time.time()
-            generated_tokens = model.generate(
-                **encoded,
-                forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
-                max_length=256,
-                num_beams=1,
-                no_repeat_ngram_size=3,
-            )
-            total_gen_time += time.time() - start
-
-            translation = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-            translated_chunks.append(translation)
-            print(f"Translated chunk: {translation}")
-        print("Total generating time:", total_gen_time)
+    for chunk in text_chunks:
+        encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
+        generated_tokens = model.generate(
+            **encoded,
+            forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
+            max_length=256,
+            num_beams=2,
+            no_repeat_ngram_size=3,
+        )
+        translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+        translated_chunks.append(translated_text)
     full_translation = "\n".join(translated_chunks)
-
     return {
         "source_text": req.text,
         "translated_text": full_translation,
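
For a quick sanity check of the new chunker, here is a minimal sketch that exercises split_by_words_and_dot with small, illustrative limits (the sample text and parameter values are mine, not from the commit). Note that importing from app.py also runs its top level, which builds the FastAPI app and loads the model.

# Minimal sketch, not part of the commit: small limits make the
# sentence-boundary behavior visible.
from app import split_by_words_and_dot  # side effect: loads the model

text = "One two three four. Five six seven.\nEight nine ten eleven twelve."
for chunk in split_by_words_and_dot(text, min_words=2, max_words=6, fallback_words=5):
    print(repr(chunk))
# Each window closes at the last '.'-terminated word it contains:
# 'One two three four.'
# 'Five six seven.'
# 'Eight nine ten eleven twelve.'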