jhj0517 committed
Commit: 184dab0
Parent(s): 736cf38

add `max_length` parameter
app.py CHANGED

@@ -20,7 +20,7 @@ class App:
         print(f"Device \"{self.whisper_inf.device}\" is detected")
         self.nllb_inf = NLLBInference(
             model_dir=self.args.nllb_model_dir,
-            output_dir=self.args.output_dir
+            output_dir=os.path.join(self.args.output_dir, "translations")
         )
         self.deepl_api = DeepLAPI(
             output_dir=self.args.output_dir

@@ -375,6 +375,8 @@ class App:
                                                      choices=self.nllb_inf.available_source_langs)
                     dd_nllb_targetlang = gr.Dropdown(label="Target Language",
                                                      choices=self.nllb_inf.available_target_langs)
+                    with gr.Row():
+                        nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
                     with gr.Row():
                         cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
                                                    interactive=True)

@@ -388,7 +390,7 @@ class App:
                 md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")

                 btn_run.click(fn=self.nllb_inf.translate_file,
-                              inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, cb_timestamp],
+                              inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, nb_max_length, cb_timestamp],
                               outputs=[tb_indicator, files_subtitles])

                 btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),

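For context, the new nb_max_length component reaches translate_file the same way the other controls do: Gradio hands the component values to the click handler positionally, in the order they appear in `inputs`. Below is a minimal standalone sketch of that wiring; the component names and the fake handler are illustrative assumptions, not the Space's real code.

import gradio as gr

def fake_translate(text, max_length):
    # Stand-in for self.nllb_inf.translate_file; just truncates to max_length chars.
    return text[:int(max_length)]

with gr.Blocks() as demo:
    tb_text = gr.Textbox(label="Text")
    nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
    btn_run = gr.Button("Run")
    tb_out = gr.Textbox(label="Output")
    # Gradio passes component values to fn in the order listed in `inputs`.
    btn_run.click(fn=fake_translate, inputs=[tb_text, nb_max_length], outputs=[tb_out])

if __name__ == "__main__":
    demo.launch()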
modules/translation/nllb_inference.py CHANGED

@@ -21,9 +21,13 @@ class NLLBInference(TranslationBase):
         self.pipeline = None

     def translate(self,
-                  text: str
+                  text: str,
+                  max_length: int
                   ):
-        result = self.pipeline(text)
+        result = self.pipeline(
+            text,
+            max_length=max_length
+        )
         return result[0]['translation_text']

     def update_model(self,

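If self.pipeline is a Hugging Face transformers translation pipeline, as the call shape suggests, max_length is forwarded as a generation argument that caps the number of tokens produced for each line. A hedged sketch of that underlying call follows; the checkpoint name and language codes are examples, not values from this commit.

from transformers import pipeline

# Example NLLB checkpoint and language codes (assumptions for illustration).
translator = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="kor_Hang",
)

# max_length caps the number of generated tokens per translated line,
# which is what the new parameter forwards from the UI.
result = translator("Hello, world!", max_length=200)
print(result[0]["translation_text"])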
modules/translation/translation_base.py CHANGED

@@ -24,7 +24,8 @@ class TranslationBase(ABC):

     @abstractmethod
     def translate(self,
-                  text: str
+                  text: str,
+                  max_length: int
                   ):
         pass


@@ -42,6 +43,7 @@ class TranslationBase(ABC):
                        model_size: str,
                        src_lang: str,
                        tgt_lang: str,
+                       max_length: int,
                        add_timestamp: bool,
                        progress=gr.Progress()) -> list:
         """

@@ -57,6 +59,8 @@ class TranslationBase(ABC):
             Source language of the file to translate from gr.Dropdown()
         tgt_lang: str
             Target language of the file to translate from gr.Dropdown()
+        max_length: int
+            Max length per line to translate
         add_timestamp: bool
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
         progress: gr.Progress

@@ -84,7 +88,7 @@ class TranslationBase(ABC):
                 total_progress = len(parsed_dicts)
                 for index, dic in enumerate(parsed_dicts):
                     progress(index / total_progress, desc="Translating..")
-                    translated_text = self.translate(dic["sentence"])
+                    translated_text = self.translate(dic["sentence"], max_length=max_length)
                     dic["sentence"] = translated_text
                 subtitle = get_serialized_srt(parsed_dicts)


@@ -99,7 +103,7 @@ class TranslationBase(ABC):
                 total_progress = len(parsed_dicts)
                 for index, dic in enumerate(parsed_dicts):
                     progress(index / total_progress, desc="Translating..")
-                    translated_text = self.translate(dic["sentence"])
+                    translated_text = self.translate(dic["sentence"], max_length=max_length)
                     dic["sentence"] = translated_text
                 subtitle = get_serialized_vtt(parsed_dicts)


@@ -124,7 +128,6 @@ class TranslationBase(ABC):
                 print(f"Error: {str(e)}")
             finally:
                 self.release_cuda_memory()
-                self.remove_input_files([fileobj.name for fileobj in fileobjs])

     @staticmethod
     def get_device():
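The effect of the change in translate_file is that every parsed subtitle entry is translated with the same max_length cap. A hedged sketch of that per-entry loop is below; the helper and the dummy translator are assumptions for illustration and do not reuse the project's own parse/serialize helpers.

from typing import Callable, Dict, List

def translate_entries(parsed_dicts: List[Dict[str, str]],
                      translate: Callable[[str, int], str],
                      max_length: int) -> List[Dict[str, str]]:
    for dic in parsed_dicts:
        # Each "sentence" is translated independently with the shared cap.
        dic["sentence"] = translate(dic["sentence"], max_length)
    return parsed_dicts

# Usage with a dummy translator:
entries = [{"sentence": "Hello"}, {"sentence": "World"}]
print(translate_entries(entries, lambda s, n: s.upper()[:n], max_length=200))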