from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import re
import time

# Limit CPU threads so the model does not oversubscribe cores on a shared host.
torch.set_num_threads(1)

app = FastAPI()


@app.on_event("startup")
def startup_event():
    # Load the fine-tuned M2M100 model once at startup and share it globally.
    global tokenizer, model, device
    model_path = "longvnhue1/facebook-m2m100_418M-fine_tuning"
    tokenizer = M2M100Tokenizer.from_pretrained(model_path)
    model = M2M100ForConditionalGeneration.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded and ready.")
def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
    """Split text into chunks of at most max_words words, preferring to cut
    at the last sentence-ending token found after min_words."""
    words = re.findall(r'\S+|\n', text)  # keep each \n as its own "word"
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + max_words, len(words))
        # Remember the LAST sentence-ending token in the [min_words, max_words) window.
        dot_idx = -1
        for i in range(start + min_words, end):
            if words[i] in ['.', '?', '!'] or (words[i].endswith(('.', '?', '!')) and words[i] != '\n'):
                dot_idx = i
        if dot_idx != -1:
            chunk_end = dot_idx + 1
        elif end - start > fallback_words:
            chunk_end = start + fallback_words
        else:
            chunk_end = end
        chunk = ' '.join(words[start:chunk_end])
        # Remove the spaces the join added around literal newlines.
        chunk = chunk.replace(' \n ', '\n').replace(' \n', '\n').replace('\n ', '\n')
        chunks.append(chunk.strip())
        start = chunk_end
    return chunks
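
# A minimal sketch of the chunker's behaviour (illustrative, assumed input):
# each chunk holds at most max_words tokens, cut at the last sentence-ending
# token in the window when one exists, e.g.
#   >>> chunks = split_by_words_and_dot("This is a sentence. " * 60)
#   >>> [len(c.split()) for c in chunks]
#   [160, 80]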
class TranslateRequest(BaseModel):
    text: str
    source_lang: str
    target_lang: str
@app.post("/translate")
def translate_text(req: TranslateRequest):
    tokenizer.src_lang = req.source_lang
    # Resolve the target-language BOS id once, instead of per chunk.
    target_lang_id = tokenizer.get_lang_id(req.target_lang)
    print(f"forced_bos_token_id: {target_lang_id}")
    print(f"tokenizer.src_lang: {tokenizer.src_lang}")
    text_chunks = split_by_words_and_dot(req.text)
    translated_chunks = []
    timing_info = []
    global_start = time.perf_counter()
    for idx, chunk in enumerate(text_chunks):
        start_time = time.perf_counter()
        encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
        with torch.inference_mode():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=target_lang_id,
                max_length=256,  # keep output length moderate
                num_beams=2,  # light beam search
                no_repeat_ngram_size=3,
            )
        translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        translated_chunks.append(translated_text)
        end_time = time.perf_counter()
        timing_info.append(f"Translated chunk {idx+1}/{len(text_chunks)} in {end_time - start_time:.3f} seconds")
    global_end = time.perf_counter()
    print(f"⚡️ Total translation time: {global_end - global_start:.3f} seconds")
    print(timing_info)
    return {
        "source_text": req.text,
        "translated_text": "\n".join(translated_chunks),
        "src_lang": req.source_lang,
        "tgt_lang": req.target_lang,
    }
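
# Illustrative request (assumed host/port matching the launch command above;
# language codes follow M2M100 conventions, e.g. "en", "vi"):
#   curl -X POST http://localhost:8000/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello world.", "source_lang": "en", "target_lang": "vi"}'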