Zguin committed
Commit 262fa8a · verified · 1 Parent(s): ac10dbc

Update app.py

Files changed (1): app.py (+79 -41)
app.py CHANGED
@@ -4,8 +4,12 @@ from PIL import Image
 from transformers import (
     BlipProcessor,
     BlipForConditionalGeneration,
+    Blip2Processor,
+    Blip2ForConditionalGeneration,
     M2M100Tokenizer,
-    M2M100ForConditionalGeneration
+    M2M100ForConditionalGeneration,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM
 )
 from typing import Union
 from gtts import gTTS
@@ -14,60 +18,79 @@ import uuid
 import time
 import gc
 
-# CPU optimization: set the number of threads
 torch.set_num_threads(2)
-
-# Global variable for caching the pipeline
 _pipeline = None
 
-def init_pipeline():
+def init_pipeline(caption_model: str, translator_model: str):
     global _pipeline
     if _pipeline is None:
-        _pipeline = ImageCaptionPipeline()
+        _pipeline = ImageCaptionPipeline(caption_model, translator_model)
     return _pipeline
 
 class ImageCaptionPipeline:
-    def __init__(self):
+    def __init__(self, caption_model: str, translator_model: str):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.caption_model_name = caption_model
+        self.translator_model_name = translator_model
+
         start_time = time.time()
-        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
-        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
-        print(f"Время загрузки BLIP: {time.time() - start_time:.2f} секунд")
+        if caption_model == "BLIP":
+            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
+            self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
+        else:
+            self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+            self.blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(self.device)
+        print(f"Время загрузки {caption_model}: {time.time() - start_time:.2f} секунд")
 
         start_time = time.time()
-        self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
-        self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
-        print(f"Время загрузки переводчика: {time.time() - start_time:.2f} секунд")
+        if translator_model == "M2M100":
+            self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+            self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
+        else:
+            self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
+            self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
+        print(f"Время загрузки переводчика {translator_model}: {time.time() - start_time:.2f} секунд")
 
-    def generate_captions(self, image: Union[str, Image.Image]) -> tuple:
+    def generate_english_caption(self, image: Union[str, Image.Image]) -> str:
         start_time = time.time()
         if isinstance(image, str):
             image = Image.open(image)
         image = image.convert("RGB")
-        image = image.resize((384, 384))  # Recommended input size for BLIP-large
+        image = image.resize((384, 384))
         inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
         with torch.no_grad():
             output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
         english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
+        english_caption = english_caption[:1].upper() + english_caption[1:] + ('.' if not english_caption.endswith('.') else '')
         print(f"Время генерации английской подписи: {time.time() - start_time:.2f} секунд")
 
+        gc.collect()
+        return english_caption
+
+    def translate_caption(self, english_caption: str) -> str:
         start_time = time.time()
-        self.translator_tokenizer.src_lang = "en"
-        translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
-        with torch.no_grad():
-            translated_ids = self.translator_model.generate(
-                **translated_inputs,
-                forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
-                max_length=50,
-                num_beams=2,
-                early_stopping=True
-            )
-        russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+        if self.translator_model_name == "M2M100":
+            self.translator_tokenizer.src_lang = "en"
+            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+            with torch.no_grad():
+                translated_ids = self.translator_model.generate(
+                    **translated_inputs,
+                    forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
+                    max_length=50,
+                    num_beams=2,
+                    early_stopping=True
+                )
+            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+        else:
+            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+            with torch.no_grad():
+                translated_ids = self.translator_model.generate(**translated_inputs, max_length=50, num_beams=2, early_stopping=True)
+            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+        russian_caption = russian_caption[:1].upper() + russian_caption[1:] + ('.' if not russian_caption.endswith('.') else '')
         print(f"Время перевода на русский: {time.time() - start_time:.2f} секунд")
 
-        # Free memory
         gc.collect()
-        return english_caption, russian_caption
+        return russian_caption
 
     def generate_audio(self, text: str, language: str) -> str:
         start_time = time.time()
@@ -78,42 +101,57 @@ class ImageCaptionPipeline:
         print(f"Время генерации озвучки: {time.time() - start_time:.2f} секунд")
         return audio_path
 
-def generate_captions(image: Image.Image) -> tuple:
+def generate_english_caption(image: Image.Image, caption_model: str, translator_model: str) -> tuple:
     if image is not None:
-        pipeline = init_pipeline()
-        english_caption, russian_caption = pipeline.generate_captions(image)
-        return english_caption, russian_caption, None
-    return "Загрузите изображение.", "Загрузите изображение.", None
+        pipeline = init_pipeline(caption_model, translator_model)
+        english_caption = pipeline.generate_english_caption(image)
+        return english_caption, "", None
+    return "Загрузите изображение.", "", None
+
+def generate_translation(english_caption: str, caption_model: str, translator_model: str) -> str:
+    if not english_caption or english_caption == "Загрузите изображение.":
+        return ""
+    pipeline = init_pipeline(caption_model, translator_model)
+    return pipeline.translate_caption(english_caption)
 
-def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> str:
+def generate_audio(english_caption: str, russian_caption: str, audio_language: str, caption_model: str, translator_model: str) -> str:
     if not english_caption and not russian_caption:
         return None
-    pipeline = init_pipeline()
+    pipeline = init_pipeline(caption_model, translator_model)
     text = russian_caption if audio_language == "Русский" else english_caption
     return pipeline.generate_audio(text, audio_language)
 
-with gr.Blocks(css=".btn {width: 200px; background-color: #4682B4; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px;} .equal-height { height: 40px; }") as iface:
+with gr.Blocks(css=".btn {width: 200px; background-color: #4B0082; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px; margin: 10px auto; display: block;} .equal-height { height: 60px; }") as iface:
     with gr.Row():
-        with gr.Column(scale=1, min_width=400, variant="panel"):
-            image = gr.Image(type="pil", label="Изображение", height=400, width=400)
+        with gr.Column(scale=1, min_width=250, variant="panel"):
+            image = gr.Image(type="pil", label="Изображение", height=250, width=250)
+            caption_model = gr.Dropdown(choices=["BLIP", "BLIP-2"], label="Модель описания", value="BLIP")
             submit_button = gr.Button("Сгенерировать описание", elem_classes="btn")
         with gr.Column(scale=1, min_width=300):
            english_caption = gr.Textbox(label="Подпись English:", lines=2)
            russian_caption = gr.Textbox(label="Подпись Русский:", lines=2)
+            translator_model = gr.Dropdown(choices=["M2M100", "Helsinki"], label="Модель перевода", value="M2M100")
+            translate_button = gr.Button("Сгенерировать перевод", elem_classes="btn")
             audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")
             with gr.Row():
                 audio_language = gr.Dropdown(choices=["Русский", "English"], label="Язык озвучки", value="Русский", scale=1, min_width=150, elem_classes="equal-height")
                 audio_output = gr.Audio(label="Озвучка", scale=1, min_width=150, elem_classes="equal-height")
 
     submit_button.click(
-        fn=generate_captions,
-        inputs=[image],
+        fn=generate_english_caption,
+        inputs=[image, caption_model, translator_model],
        outputs=[english_caption, russian_caption, audio_output]
     )
 
+    translate_button.click(
+        fn=generate_translation,
+        inputs=[english_caption, caption_model, translator_model],
+        outputs=[russian_caption]
+    )
+
     audio_button.click(
         fn=generate_audio,
-        inputs=[english_caption, russian_caption, audio_language],
+        inputs=[english_caption, russian_caption, audio_language, caption_model, translator_model],
         outputs=[audio_output]
     )
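
For orientation, here is a minimal sketch of how the reworked two-step pipeline can be exercised directly, without the Gradio UI. This is a hypothetical example, not part of the commit: the path cat.jpg is a placeholder, the model choices mirror the dropdown values, and it assumes app.py can be imported without launching the interface.

# Hypothetical usage sketch (not in the commit); assumes the Hugging Face
# checkpoints can be downloaded and that app.py does not call iface.launch()
# at import time.
from app import ImageCaptionPipeline

pipeline = ImageCaptionPipeline("BLIP", "M2M100")  # or ("BLIP-2", "Helsinki")

english = pipeline.generate_english_caption("cat.jpg")  # accepts a file path or a PIL.Image
russian = pipeline.translate_caption(english)
audio_path = pipeline.generate_audio(russian, "Русский")  # returns the path of the generated audio file

print(english)
print(russian)
print(audio_path)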
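
One design consequence of the new signature: init_pipeline still caches a single global _pipeline, so the caption and translator dropdowns only take effect on the very first request; selecting BLIP-2 or Helsinki afterwards silently keeps the models that were loaded first. A cache keyed by the model choices would honor later selections, at the cost of holding several models in memory. A minimal sketch of that variant (hypothetical, not in the commit):

# Hypothetical variant: one cached pipeline per (caption_model, translator_model)
# pair, so changing the dropdowns after the first request loads the newly
# selected models instead of reusing the original ones.
_pipelines = {}

def init_pipeline(caption_model: str, translator_model: str):
    key = (caption_model, translator_model)
    if key not in _pipelines:
        _pipelines[key] = ImageCaptionPipeline(caption_model, translator_model)
    return _pipelines[key]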