RO-Rtechs committed
Commit 2230a38 · verified · 1 Parent(s): 565cffc

Update soni_translate/translate_segments.py

soni_translate/translate_segments.py CHANGED
@@ -7,6 +7,8 @@ from .logging_setup import logger
 import re
 import json
 import time
+import os
+import google.generativeai as genai
 
 TRANSLATION_PROCESS_OPTIONS = [
     "google_translator_batch",
@@ -15,12 +17,15 @@ TRANSLATION_PROCESS_OPTIONS = [
     "gpt-3.5-turbo-0125",
     "gpt-4-turbo-preview_batch",
     "gpt-4-turbo-preview",
+    "gemini-pro",
+    "gemini-pro_batch",
     "disable_translation",
 ]
 DOCS_TRANSLATION_PROCESS_OPTIONS = [
     "google_translator",
     "gpt-3.5-turbo-0125",
     "gpt-4-turbo-preview",
+    "gemini-pro",
     "disable_translation",
 ]
 
@@ -418,6 +423,74 @@ def gpt_batch(segments, model, target, token_batch_limit=900, source=None):
     )
 
 
+def check_gemini_api_key():
+    """Check if Gemini API key is set in environment variables."""
+    if not os.environ.get("GOOGLE_API_KEY"):
+        raise ValueError(
+            "Gemini API key not found. Please set the GOOGLE_API_KEY environment variable."
+        )
+
+def translate_with_gemini(text, target_lang, source_lang=None):
+    """Translate text using Google's Gemini API."""
+    check_gemini_api_key()
+    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+    model = genai.GenerativeModel('gemini-pro')
+
+    prompt = f"""Translate the following text to {target_lang}.
+    Keep the same tone and style. Preserve any special characters or formatting.
+
+    Text to translate: {text}
+    """
+    if source_lang:
+        prompt = f"Translate from {source_lang} to {target_lang}: {text}"
+
+    response = model.generate_content(prompt)
+    return response.text.strip()
+
+def gemini_sequential(segments, target, source=None):
+    """Translate segments sequentially using Gemini."""
+    segments_ = copy.deepcopy(segments)
+
+    for line in tqdm(range(len(segments_))):
+        text = segments_[line]["text"]
+        translated_line = translate_with_gemini(text.strip(), target, source)
+        segments_[line]["text"] = translated_line
+
+    return segments_
+
+def gemini_batch(segments, target, token_batch_limit=1000, source=None):
+    """Translate segments in batches using Gemini."""
+    segments_ = copy.deepcopy(segments)
+    batch_texts = []
+    current_batch = []
+    current_length = 0
+
+    # Group texts into batches
+    for segment in segments_:
+        text_length = len(segment["text"])
+        if current_length + text_length > token_batch_limit:
+            batch_texts.append(current_batch)
+            current_batch = []
+            current_length = 0
+        current_batch.append(segment["text"])
+        current_length += text_length
+
+    if current_batch:
+        batch_texts.append(current_batch)
+
+    # Translate each batch
+    for i, batch in enumerate(tqdm(batch_texts)):
+        batch_text = "\n---\n".join(batch)
+        translated_batch = translate_with_gemini(batch_text, target, source)
+        translated_segments = translated_batch.split("\n---\n")
+
+        # Update segments with translations
+        start_idx = sum(len(b) for b in batch_texts[:i])
+        for j, translation in enumerate(translated_segments):
+            segments_[start_idx + j]["text"] = translation.strip()
+
+    return segments_
+
 def translate_text(
     segments,
     target,
@@ -443,7 +516,7 @@ def translate_text(
         )
         case model if model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-preview"]:
            return gpt_sequential(segments, model, target, source)
-        case model if model in ["gpt-3.5-turbo-0125_batch", "gpt-4-turbo-preview_batch",]:
+        case model if model in ["gpt-3.5-turbo-0125_batch", "gpt-4-turbo-preview_batch"]:
            return gpt_batch(
                segments,
                translation_process.replace("_batch", ""),
@@ -451,6 +524,10 @@ def translate_text(
                token_batch_limit,
                source
            )
+        case "gemini-pro":
+            return gemini_sequential(segments, target, source)
+        case "gemini-pro_batch":
+            return gemini_batch(segments, target, token_batch_limit, source)
         case "disable_translation":
            return segments
         case _:
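
For reference, a minimal usage sketch (not part of the commit) of the two new Gemini paths added in this diff. The segment shape (dicts with a "text" key) follows how gemini_sequential and gemini_batch index their input above; the target value shown ("Spanish") and the direct imports are assumptions for illustration, and GOOGLE_API_KEY must already be set, as check_gemini_api_key enforces.

# Hypothetical usage sketch, not part of the commit above.
import os

from soni_translate.translate_segments import gemini_batch, gemini_sequential

# check_gemini_api_key() raises if this is missing; export it before running.
assert os.environ.get("GOOGLE_API_KEY"), "export GOOGLE_API_KEY before running"

# Assumed segment shape: dicts with a "text" key, as both new functions expect.
segments = [
    {"text": "Hello, how are you?"},
    {"text": "See you tomorrow."},
]

# Sequential path: one Gemini request per segment.
seq_result = gemini_sequential(segments, "Spanish")

# Batch path: segments grouped into roughly 1000-character chunks,
# joined with "\n---\n" and translated one chunk per request.
batch_result = gemini_batch(segments, "Spanish", token_batch_limit=1000)

print([seg["text"] for seg in seq_result])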