Update app.py
app.py
CHANGED
@@ -1,20 +1,118 @@
Removed (old app.py, lines 1-20):

from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import re
import time

app = FastAPI()

def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
    import re
    words = re.findall(r'\S+|\n', text)  # keep \n as a "word"
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + max_words, len(words))
        # Find a period between min_words and max_words
        dot_idx = -1
        for i in range(start + min_words, min(start + max_words, len(words))):
            if words[i] == '.' or (words[i].endswith('.') and words[i] != '\n'):
@@ -30,71 +128,45 @@ def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
Removed (old app.py, lines 30-100):

        start = chunk_end
    return chunks

# Load model
model_path = "longvnhue1/facebook-m2m100_418M-fine_tuning"
tokenizer = M2M100Tokenizer.from_pretrained(model_path)
model = M2M100ForConditionalGeneration.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

class TranslateRequest(BaseModel):
    text: str
    source_lang: str
    target_lang: str

# @app.post("/translate")
# def translate_text(req: TranslateRequest):
#     tokenizer.src_lang = req.source_lang
#     encoded = tokenizer(req.text, return_tensors="pt", truncation=True, max_length=512).to(device)
#     generated_tokens = model.generate(
#         **encoded,
#         forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
#         max_length=512,  # raise if longer passages must be translated, but not too large
#         num_beams=2,  # fewer beams for speed
#         no_repeat_ngram_size=3,
#         early_stopping=True
#     )
#     translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
#     return {
#         "source_text": req.text,
#         "translated_text": translated_text,
#         "src_lang": req.source_lang,
#         "tgt_lang": req.target_lang
#     }


@app.post("/translate")
def translate_text(req: TranslateRequest):
    tokenizer.src_lang = req.source_lang
    text_chunks = split_by_words_and_dot(req.text, min_words=125, max_words=160, fallback_words=150)
    translated_chunks = []
    timing_info = []

    for idx, chunk in enumerate(text_chunks):
        start_time = time.perf_counter()

        encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
        with torch.inference_mode():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
                max_length=256,
                num_beams=2,
                no_repeat_ngram_size=3,
            )
        translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        translated_chunks.append(translated_text)

        end_time = time.perf_counter()
        elapsed = end_time - start_time
        timing_info.append(f"Translated chunk {idx+1}/{len(text_chunks)} in {elapsed:.3f} seconds")

    full_translation = "\n".join(translated_chunks)
    print(timing_info)

    return {
        "source_text": req.text,
        "translated_text": full_translation,
        "src_lang": req.source_lang,
        "tgt_lang": req.target_lang,
    }
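For reference, split_by_words_and_dot cuts the input into chunks of roughly min_words to max_words whitespace-separated tokens, preferring to end a chunk right after the last period found in that window and otherwise falling back to a fixed length. A quick illustrative call, assuming the full helper as defined in this file; the sample text and the smaller thresholds are made up for demonstration:

# Illustration only: assumes split_by_words_and_dot from app.py is in scope.
sample = (
    "FastAPI serves the model over HTTP. "
    "M2M100 handles many language pairs. "
    "Long requests are split into chunks before translation so each "
    "generate() call stays within the max_length budget."
)
# With these toy thresholds, each chunk ends just after the last '.' seen in
# the scanned window; if no period is found, the fallback length is used.
for i, chunk in enumerate(split_by_words_and_dot(sample, min_words=4, max_words=10, fallback_words=8)):
    print(i, repr(chunk))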
Added (new app.py, lines 1-172):

# from fastapi import FastAPI, Request
# from pydantic import BaseModel
# from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# import torch
# import re
# import time

# app = FastAPI()

# def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
#     import re
#     words = re.findall(r'\S+|\n', text)  # keep \n as a "word"
#     chunks = []
#     start = 0
#     while start < len(words):
#         end = min(start + max_words, len(words))
#         # Find a period between min_words and max_words
#         dot_idx = -1
#         for i in range(start + min_words, min(start + max_words, len(words))):
#             if words[i] == '.' or (words[i].endswith('.') and words[i] != '\n'):
#                 dot_idx = i
#         if dot_idx != -1:
#             chunk_end = dot_idx + 1
#         elif end - start > fallback_words:
#             chunk_end = start + fallback_words
#         else:
#             chunk_end = end
#         chunk = ' '.join([w if w != '\n' else '\n' for w in words[start:chunk_end]]).replace(' \n ', '\n').replace(' \n', '\n').replace('\n ', '\n')
#         chunks.append(chunk.strip())
#         start = chunk_end
#     return chunks

# # Load model
# model_path = "longvnhue1/facebook-m2m100_418M-fine_tuning"
# tokenizer = M2M100Tokenizer.from_pretrained(model_path)
# model = M2M100ForConditionalGeneration.from_pretrained(model_path)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# class TranslateRequest(BaseModel):
#     text: str
#     source_lang: str
#     target_lang: str


# @app.post("/translate")
# def translate_text(req: TranslateRequest):
#     tokenizer.src_lang = req.source_lang
#     text_chunks = split_by_words_and_dot(req.text, min_words=125, max_words=160, fallback_words=150)
#     translated_chunks = []
#     timing_info = []

#     for idx, chunk in enumerate(text_chunks):
#         start_time = time.perf_counter()  # start timing

#         encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
#         with torch.inference_mode():
#             generated_tokens = model.generate(
#                 **encoded,
#                 forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
#                 max_length=256,
#                 num_beams=2,
#                 no_repeat_ngram_size=3,
#             )
#         translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
#         translated_chunks.append(translated_text)

#         end_time = time.perf_counter()  # stop timing
#         elapsed = end_time - start_time
#         timing_info.append(f"Translated chunk {idx+1}/{len(text_chunks)} in {elapsed:.3f} seconds")

#     full_translation = "\n".join(translated_chunks)
#     print(timing_info)

#     return {
#         "source_text": req.text,
#         "translated_text": full_translation,
#         "src_lang": req.source_lang,
#         "tgt_lang": req.target_lang,
#     }


from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import re
import time

# Limit CPU threads when not running on a GPU
torch.set_num_threads(1)

app = FastAPI()

@app.on_event("startup")
def startup_event():
    print("🔁 Loading model...")

    global tokenizer, model, device

    model_path = "longvnhue1/facebook-m2m100_418M-fine_tuning"
    tokenizer = M2M100Tokenizer.from_pretrained(model_path)
    model = M2M100ForConditionalGeneration.from_pretrained(model_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print("✅ Model loaded and ready.")

def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
    words = re.findall(r'\S+|\n', text)  # keep \n as a "word"
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + max_words, len(words))
        dot_idx = -1
        for i in range(start + min_words, min(start + max_words, len(words))):
            if words[i] == '.' or (words[i].endswith('.') and words[i] != '\n'):
                # ... (lines 119-127, unchanged from the commented-out copy above, are collapsed in this diff)
        start = chunk_end
    return chunks

class TranslateRequest(BaseModel):
    text: str
    source_lang: str
    target_lang: str

@app.post("/translate")
def translate_text(req: TranslateRequest):
    tokenizer.src_lang = req.source_lang
    text_chunks = split_by_words_and_dot(req.text)
    translated_chunks = []
    timing_info = []

    global_start = time.perf_counter()

    for idx, chunk in enumerate(text_chunks):
        start_time = time.perf_counter()

        encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
        with torch.inference_mode():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
                max_length=256,  # <-- keep this moderate
                num_beams=2,     # <-- simpler beam search
                no_repeat_ngram_size=3,
            )
        translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        translated_chunks.append(translated_text)

        end_time = time.perf_counter()
        timing_info.append(f"Translated chunk {idx+1}/{len(text_chunks)} in {end_time - start_time:.3f} seconds")

    global_end = time.perf_counter()
    print(f"⚡️ Total translation time: {global_end - global_start:.3f} seconds")
    print(timing_info)

    return {
        "source_text": req.text,
        "translated_text": "\n".join(translated_chunks),
        "src_lang": req.source_lang,
        "tgt_lang": req.target_lang,
    }
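For completeness, a minimal client sketch for the updated /translate endpoint, assuming the app is served locally (for example with uvicorn app:app --port 8000). The host, port, and language codes below are assumptions for illustration; M2M100 uses short ISO-style codes such as "en" and "vi", and this commit does not pin them down:

# Hypothetical client; assumes a local server on port 8000 and that
# "en"/"vi" are valid language codes for this fine-tuned checkpoint.
import requests

payload = {
    "text": "FastAPI makes it easy to serve translation models.",
    "source_lang": "en",
    "target_lang": "vi",
}
resp = requests.post("http://localhost:8000/translate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["translated_text"])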