Zguin committed
Commit d98c8c7 · verified · 1 Parent(s): e5b771f

Update app.py

Files changed (1)
  1. app.py +79 -27
app.py CHANGED
@@ -10,55 +10,107 @@ from transformers import (
 from typing import Union
 from gtts import gTTS
 import os
+import uuid
+import time
+import gc
+
+# CPU optimization: set the number of threads
+torch.set_num_threads(2)
+
+# Global variable for caching the pipeline
+_pipeline = None
+
+def init_pipeline():
+    global _pipeline
+    if _pipeline is None:
+        _pipeline = ImageCaptionPipeline()
+    return _pipeline
 
 class ImageCaptionPipeline:
     def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)
-        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(self.device)
+        start_time = time.time()
+        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
+        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
+        print(f"BLIP load time: {time.time() - start_time:.2f} seconds")
+
+        start_time = time.time()
         self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
         self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
+        print(f"Translator load time: {time.time() - start_time:.2f} seconds")
 
-    def generate_caption(self, image: Union[str, Image.Image], language: str = "Русский") -> str:
+    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
+        start_time = time.time()
         if isinstance(image, str):
             image = Image.open(image)
         image = image.convert("RGB")
         inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
         with torch.no_grad():
-            output_ids = self.blip_model.generate(**inputs, max_length=200, num_beams=4)
+            output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
         english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
-        if language == "Русский":
-            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
-            with torch.no_grad():
-                translated_ids = self.translator_model.generate(**translated_inputs, max_length=200, num_beams=4)
-            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-            return russian_caption
-        return english_caption
+        print(f"English caption generation time: {time.time() - start_time:.2f} seconds")
+
+        start_time = time.time()
+        translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+        with torch.no_grad():
+            translated_ids = self.translator_model.generate(**translated_inputs, max_length=50, num_beams=2, early_stopping=True)
+        russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+        print(f"Russian translation time: {time.time() - start_time:.2f} seconds")
+
+        # Release memory
+        gc.collect()
+        return english_caption, russian_caption
 
-def app(image: Image.Image, language: str) -> tuple:
-    if image is not None:
-        pipeline = ImageCaptionPipeline()
-        caption = pipeline.generate_caption(image, language=language)
+    def generate_audio(self, text: str, language: str) -> str:
+        start_time = time.time()
         lang_code = "ru" if language == "Русский" else "en"
-        tts = gTTS(text=caption, lang=lang_code)
-        audio_path = "caption_audio.mp3"
+        tts = gTTS(text=text, lang=lang_code)
+        audio_path = f"caption_audio_{uuid.uuid4()}.mp3"
         tts.save(audio_path)
-        return caption, audio_path
-    return "Загрузите изображение и выберите язык для получения подписи.", None
+        print(f"TTS generation time: {time.time() - start_time:.2f} seconds")
+        return audio_path
+
+def generate_captions(image: Image.Image) -> tuple:
+    if image is not None:
+        pipeline = init_pipeline()
+        english_caption, russian_caption = pipeline.generate_captions(image)
+        return f"English: {english_caption}", f"Русский: {russian_caption}", None
+    return "Загрузите изображение.", "Загрузите изображение.", None
+
+def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> str:
+    if not english_caption and not russian_caption:
+        return None
+    pipeline = init_pipeline()
+    text = russian_caption.replace("Русский: ", "") if audio_language == "Русский" else english_caption.replace("English: ", "")
+    return pipeline.generate_audio(text, audio_language)
 
 with gr.Blocks() as iface:
     gr.Markdown("# Генератор подписей")
-    gr.Markdown("Загрузите изображение и выберите язык.")
-    language = gr.Dropdown(choices=["Русский", "English"], label="Язык", value="Русский")
+    gr.Markdown("Загрузите изображение для получения подписей на двух языках.")
+
     image = gr.Image(type="pil", label="Изображение", height=400, width=400)
-    submit_button = gr.Button("Сгенерировать", elem_classes="btn")
-    caption_output = gr.Textbox(label="Подпись")
+    submit_button = gr.Button("Сгенерировать подписи", elem_classes="btn")
+
+    with gr.Row():
+        english_caption = gr.Textbox(label="Подпись (English)")
+        russian_caption = gr.Textbox(label="Подпись (Русский)")
+
+    with gr.Row():
+        audio_language = gr.Dropdown(choices=["Русский", "English"], label="Язык озвучки", value="Русский")
+        audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")
+
     audio_output = gr.Audio(label="Озвучка")
-
+
    submit_button.click(
-        fn=app,
-        inputs=[image, language],
-        outputs=[caption_output, audio_output]
+        fn=generate_captions,
+        inputs=[image],
+        outputs=[english_caption, russian_caption, audio_output]
+    )
+
+    audio_button.click(
+        fn=generate_audio,
+        inputs=[english_caption, russian_caption, audio_language],
+        outputs=[audio_output]
     )
 
 if __name__ == "__main__":
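
For reference, a minimal sketch of how the refactored two-step flow could be exercised outside the Gradio UI. This is not part of the commit: it assumes the updated app.py is importable from the working directory, and the test image path example.jpg is a placeholder.

    # Minimal sketch, assuming the updated app.py is on the import path
    # and a test image exists at the placeholder path "example.jpg".
    from PIL import Image
    import app

    img = Image.open("example.jpg")  # hypothetical test image

    # Step 1 (the "Сгенерировать подписи" button): both captions in one pass.
    # The first call loads BLIP and the translator via init_pipeline().
    english, russian, _ = app.generate_captions(img)
    print(english)  # e.g. "English: a dog sitting on a bench"
    print(russian)

    # Step 2 (the "Сгенерировать озвучку" button): TTS for the chosen language.
    audio_path = app.generate_audio(english, russian, "Русский")
    print(audio_path)  # caption_audio_<uuid>.mp3

The first call to either module-level function builds the models through init_pipeline(); subsequent calls reuse the cached instance, which is the point of the new _pipeline global.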