from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import re
import time

# Limit PyTorch to a single CPU thread
torch.set_num_threads(1)

app = FastAPI()
@app.on_event("startup")
def startup_event():
    # Load the fine-tuned M2M100 model once, before the first request is served
    global tokenizer, model, device
    model_path = "longvnhue1/facebook-m2m100_418M-fine_tuning"
    tokenizer = M2M100Tokenizer.from_pretrained(model_path)
    model = M2M100ForConditionalGeneration.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded and ready.")
def split_by_words_and_dot(text, min_words=125, max_words=160, fallback_words=150):
    """Split text into chunks of roughly min_words..max_words tokens, preferring
    to break just after the last sentence-ending token in that window."""
    words = re.findall(r'\S+|\n', text)  # keep each '\n' as its own "word" so line breaks survive
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + max_words, len(words))
        # Remember the LAST sentence-ending token in the window, so chunks stay as long as possible
        dot_idx = -1
        for i in range(start + min_words, end):
            if words[i].endswith(('.', '?', '!')):
                dot_idx = i
        if dot_idx != -1:
            chunk_end = dot_idx + 1  # cut right after the sentence end
        elif end - start > fallback_words:
            chunk_end = start + fallback_words  # no sentence end found: hard cut at fallback_words
        else:
            chunk_end = end  # short tail: take whatever is left
        chunk = ' '.join(words[start:chunk_end])
        # Drop the spaces the join introduced around newline tokens
        chunk = chunk.replace(' \n ', '\n').replace(' \n', '\n').replace('\n ', '\n')
        chunks.append(chunk.strip())
        start = chunk_end
    return chunks
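# Illustrative example (hypothetical input, not part of the service): a 200-word
# text of short sentences splits into a 160-word chunk that ends on a sentence
# boundary inside the 125-160-word window, plus the 40-word tail.
#   chunks = split_by_words_and_dot("First sentence. " * 100)
#   print([len(c.split()) for c in chunks])  # -> [160, 40]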
class TranslateRequest(BaseModel):
    text: str
    source_lang: str
    target_lang: str
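# Example request body (assuming the fine-tuned checkpoint keeps M2M100's
# language codes, e.g. "en", "vi", "fr"):
#   {"text": "Hello world. How are you?", "source_lang": "en", "target_lang": "vi"}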
@app.post("/translate")
def translate_text(req: TranslateRequest):
    tokenizer.src_lang = req.source_lang
    text_chunks = split_by_words_and_dot(req.text)
    translated_chunks = []
    timing_info = []
    global_start = time.perf_counter()
    for idx, chunk in enumerate(text_chunks):
        start_time = time.perf_counter()
        encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=256).to(device)
        with torch.inference_mode():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(req.target_lang),
                max_length=256,        # keep output length at a moderate cap
                num_beams=2,           # light beam search: quality/speed trade-off
                no_repeat_ngram_size=3,
            )
        print(f"forced_bos_token_id: {tokenizer.get_lang_id(req.target_lang)}")
        print(f"tokenizer.src_lang: {tokenizer.src_lang}")
        translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        translated_chunks.append(translated_text)
        end_time = time.perf_counter()
        timing_info.append(f"Translated chunk {idx+1}/{len(text_chunks)} in {end_time - start_time:.3f} seconds")
    global_end = time.perf_counter()
    print(f"⚡️ Total translation time: {global_end - global_start:.3f} seconds")
    print(timing_info)
    return {
        "source_text": req.text,
        "translated_text": "\n".join(translated_chunks),
        "src_lang": req.source_lang,
        "tgt_lang": req.target_lang,
    }
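
# Minimal local entry point (a sketch; on Hugging Face Spaces the platform
# typically launches uvicorn itself, and 7860 is the conventional Spaces port):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call (illustrative):
#   curl -X POST http://localhost:7860/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello world.", "source_lang": "en", "target_lang": "vi"}'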