Spaces:
Running
Running
jhj0517
commited on
Commit
·
a0d6f10
1
Parent(s):
3f9463b
Add model path check fucntion and `local_files_only`
Browse files
modules/translation/nllb_inference.py
CHANGED
|
@@ -40,10 +40,13 @@ class NLLBInference(TranslationBase):
|
|
| 40 |
print("\nInitializing NLLB Model..\n")
|
| 41 |
progress(0, desc="Initializing NLLB Model..")
|
| 42 |
self.current_model_size = model_size
|
|
|
|
| 43 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
| 44 |
-
cache_dir=self.model_dir
|
|
|
|
| 45 |
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
| 46 |
-
cache_dir=os.path.join(self.model_dir, "tokenizers")
|
|
|
|
| 47 |
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
| 48 |
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
| 49 |
self.pipeline = pipeline("translation",
|
|
@@ -53,6 +56,18 @@ class NLLBInference(TranslationBase):
|
|
| 53 |
tgt_lang=tgt_lang,
|
| 54 |
device=self.device)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
NLLB_AVAILABLE_LANGS = {
|
| 57 |
"Acehnese (Arabic script)": "ace_Arab",
|
| 58 |
"Acehnese (Latin script)": "ace_Latn",
|
|
|
|
| 40 |
print("\nInitializing NLLB Model..\n")
|
| 41 |
progress(0, desc="Initializing NLLB Model..")
|
| 42 |
self.current_model_size = model_size
|
| 43 |
+
local_files_only = self.is_model_exists(self.current_model_size)
|
| 44 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
| 45 |
+
cache_dir=self.model_dir,
|
| 46 |
+
local_files_only=local_files_only)
|
| 47 |
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
| 48 |
+
cache_dir=os.path.join(self.model_dir, "tokenizers"),
|
| 49 |
+
local_files_only=local_files_only)
|
| 50 |
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
| 51 |
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
| 52 |
self.pipeline = pipeline("translation",
|
|
|
|
| 56 |
tgt_lang=tgt_lang,
|
| 57 |
device=self.device)
|
| 58 |
|
| 59 |
+
def is_model_exists(self,
|
| 60 |
+
model_size: str):
|
| 61 |
+
"""Check if model exists or not (Only facebook model)"""
|
| 62 |
+
prefix = "models--facebook--"
|
| 63 |
+
_id, model_size_name = model_size.split("/")
|
| 64 |
+
model_dir_name = prefix + model_size_name
|
| 65 |
+
model_dir_path = os.path.join(self.model_dir, model_dir_name)
|
| 66 |
+
if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
|
| 67 |
+
return True
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
NLLB_AVAILABLE_LANGS = {
|
| 72 |
"Acehnese (Arabic script)": "ace_Arab",
|
| 73 |
"Acehnese (Latin script)": "ace_Latn",
|