File size: 1,980 Bytes
b3b0b53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
from transformers import MarianMTModel, MarianTokenizer

def translate_text(text: str, source_lang: str, target_lang: str) -> str:
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    The checkpoint name is built as
    ``Helsinki-NLP/opus-mt-{source_lang}-{target_lang}`` and loaded through
    ``MarianTokenizer`` and ``MarianMTModel``.

    Args:
        text: Text to translate.
        source_lang: Source language code inserted into the checkpoint name.
        target_lang: Target language code inserted into the checkpoint name.

    Returns:
        The decoded translation. An empty string is returned for empty or
        whitespace-only input. If anything raises while loading the model or
        translating, the string ``"Error during translation: <message>"`` is
        returned instead of propagating the exception.
    """
    # Nothing to translate: skip loading the model and running generation
    # on an input that carries no content.
    if isinstance(text, str) and not text.strip():
        return ""

    try:
        # The format is 'Helsinki-NLP/opus-mt-{source}-{target}'
        model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{target_lang}'

        # The first time a model is used, it will be downloaded from
        # Hugging Face. Subsequent uses will load from cache.
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)

        # Tokenize the input text into PyTorch tensors.
        tokenized_text = tokenizer(text, return_tensors="pt", padding=True)

        # Generate the translation.
        translated_tokens = model.generate(**tokenized_text)

        # Only the first generated sequence is decoded.
        return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    except Exception as e:
        # Covers a missing direct model for the language pair (e.g., zh-es)
        # as well as any other failure; callers receive the message as text.
        return f"Error during translation: {str(e)}"


if __name__ == "__main__":
    # Command-line contract: exactly three positional arguments after the
    # script name -- the text, the source language and the target language.
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python translate.py <text_to_translate> <source_lang> <target_lang>")
        sys.exit(1)

    raw_text, raw_source, raw_target = cli_args

    # Keep only the part before the first '-' (e.g., 'zh-CN' -> 'zh'); the
    # model names are built from short codes such as 'en', 'zh', 'es'.
    short_source, _, _ = raw_source.partition('-')
    short_target, _, _ = raw_target.partition('-')

    print(translate_text(raw_text, short_source, short_target))