Update app.py
app.py
CHANGED
@@ -4,8 +4,12 @@ from PIL import Image
 from transformers import (
     BlipProcessor,
     BlipForConditionalGeneration,
+    Blip2Processor,
+    Blip2ForConditionalGeneration,
     M2M100Tokenizer,
-    M2M100ForConditionalGeneration
+    M2M100ForConditionalGeneration,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM
 )
 from typing import Union
 from gtts import gTTS
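The new imports add a second captioning model (BLIP-2) and a second translator (the MarianMT checkpoint Helsinki-NLP/opus-mt-en-ru, loaded through the Auto classes). A minimal sketch, not part of the commit, of exercising the Helsinki path on its own with the same generation settings the diff uses:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
    model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru")

    # Same decoding parameters as translate_caption below.
    batch = tokenizer("A dog is running on the beach.", return_tensors="pt")
    ids = model.generate(**batch, max_length=50, num_beams=2, early_stopping=True)
    print(tokenizer.decode(ids[0], skip_special_tokens=True))  # prints the Russian translation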
@@ -14,60 +18,79 @@ import uuid
 import time
 import gc

-# Оптимизация CPU: установка числа потоков
 torch.set_num_threads(2)
-
-# Глобальная переменная для кэширования pipeline
 _pipeline = None

-def init_pipeline():
+def init_pipeline(caption_model: str, translator_model: str):
     global _pipeline
     if _pipeline is None:
-        _pipeline = ImageCaptionPipeline()
+        _pipeline = ImageCaptionPipeline(caption_model, translator_model)
     return _pipeline

 class ImageCaptionPipeline:
-    def __init__(self):
+    def __init__(self, caption_model: str, translator_model: str):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.caption_model = caption_model
+        self.translator_model = translator_model
+
         start_time = time.time()
-        …
+        if caption_model == "BLIP":
+            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
+            self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
+        else:
+            self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+            self.blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(self.device)
+        print(f"Время загрузки {caption_model}: {time.time() - start_time:.2f} секунд")

         start_time = time.time()
-        …
+        if translator_model == "M2M100":
+            self.translator_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
+            self.translator_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(self.device)
+        else:
+            self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
+            self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
+        print(f"Время загрузки переводчика {translator_model}: {time.time() - start_time:.2f} секунд")

-    def …
+    def generate_english_caption(self, image: Union[str, Image.Image]) -> str:
         start_time = time.time()
         if isinstance(image, str):
             image = Image.open(image)
         image = image.convert("RGB")
-        image = image.resize((384, 384))
+        image = image.resize((384, 384))
         inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
         with torch.no_grad():
             output_ids = self.blip_model.generate(**inputs, max_length=50, num_beams=2, early_stopping=True)
         english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
+        english_caption = english_caption[0].upper() + english_caption[1:] + ('.' if not english_caption.endswith('.') else '')
         print(f"Время генерации английской подписи: {time.time() - start_time:.2f} секунд")

+        gc.collect()
+        return english_caption
+
+    def translate_caption(self, english_caption: str) -> str:
         start_time = time.time()
-        self.…
-        …
+        if self.translator_model == "M2M100":
+            self.translator_tokenizer.src_lang = "en"
+            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+            with torch.no_grad():
+                translated_ids = self.translator_model.generate(
+                    **translated_inputs,
+                    forced_bos_token_id=self.translator_tokenizer.get_lang_id("ru"),
+                    max_length=50,
+                    num_beams=2,
+                    early_stopping=True
+                )
+            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+        else:
+            translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
+            with torch.no_grad():
+                translated_ids = self.translator_model.generate(**translated_inputs, max_length=50, num_beams=2, early_stopping=True)
+            russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
+        russian_caption = russian_caption[0].upper() + russian_caption[1:] + ('.' if not russian_caption.endswith('.') else '')
         print(f"Время перевода на русский: {time.time() - start_time:.2f} секунд")

-        # Освобождение памяти
         gc.collect()
-        return …
+        return russian_caption

     def generate_audio(self, text: str, language: str) -> str:
         start_time = time.time()
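Note a bug the new __init__ introduces: self.translator_model first stores the dropdown string ("M2M100" or "Helsinki") and is then overwritten a few lines later with the loaded model object. By the time translate_caption runs, self.translator_model == "M2M100" is always False, so the M2M100 branch (with src_lang = "en" and forced_bos_token_id) is unreachable and M2M100 would be decoded without a target-language token. A self-contained illustration of the collision:

    class Demo:
        def __init__(self, translator_model: str):
            self.translator_model = translator_model  # the choice string
            self.translator_model = object()           # overwritten by the loaded model

        def uses_m2m100_branch(self) -> bool:
            return self.translator_model == "M2M100"   # compares a model object to a str

    print(Demo("M2M100").uses_m2m100_branch())  # False — the branch never runs

Keeping the string under a separate attribute (say, self.translator_name — a hypothetical rename, not in the commit) and branching on that would restore the intended behavior.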
@@ -78,42 +101,57 @@ class ImageCaptionPipeline:
         print(f"Время генерации озвучки: {time.time() - start_time:.2f} секунд")
         return audio_path

-def …
+def generate_english_caption(image: Image.Image, caption_model: str, translator_model: str) -> tuple:
     if image is not None:
-        pipeline = init_pipeline()
-        english_caption …
-        return english_caption, …
-    return "Загрузите изображение.", "…
+        pipeline = init_pipeline(caption_model, translator_model)
+        english_caption = pipeline.generate_english_caption(image)
+        return english_caption, "", None
+    return "Загрузите изображение.", "", None
+
+def generate_translation(english_caption: str, caption_model: str, translator_model: str) -> str:
+    if not english_caption or english_caption == "Загрузите изображение.":
+        return ""
+    pipeline = init_pipeline(caption_model, translator_model)
+    return pipeline.translate_caption(english_caption)

-def generate_audio(english_caption: str, russian_caption: str, audio_language: str) -> str:
+def generate_audio(english_caption: str, russian_caption: str, audio_language: str, caption_model: str, translator_model: str) -> str:
     if not english_caption and not russian_caption:
         return None
-    pipeline = init_pipeline()
+    pipeline = init_pipeline(caption_model, translator_model)
     text = russian_caption if audio_language == "Русский" else english_caption
     return pipeline.generate_audio(text, audio_language)

-with gr.Blocks(css=".btn {width: 200px; background-color: #…
+with gr.Blocks(css=".btn {width: 200px; background-color: #4B0082; color: white; border: none; padding: 10px 20px; text-align: center; font-size: 16px; margin: 10px auto; display: block;} .equal-height { height: 60px; }") as iface:
     with gr.Row():
-        with gr.Column(scale=1, min_width=…
-            image = gr.Image(type="pil", label="Изображение", height=…
+        with gr.Column(scale=1, min_width=250, variant="panel"):
+            image = gr.Image(type="pil", label="Изображение", height=250, width=250)
+            caption_model = gr.Dropdown(choices=["BLIP", "BLIP-2"], label="Модель описания", value="BLIP")
             submit_button = gr.Button("Сгенерировать описание", elem_classes="btn")
         with gr.Column(scale=1, min_width=300):
             english_caption = gr.Textbox(label="Подпись English:", lines=2)
             russian_caption = gr.Textbox(label="Подпись Русский:", lines=2)
+            translator_model = gr.Dropdown(choices=["M2M100", "Helsinki"], label="Модель перевода", value="M2M100")
+            translate_button = gr.Button("Сгенерировать перевод", elem_classes="btn")
             audio_button = gr.Button("Сгенерировать озвучку", elem_classes="btn")
     with gr.Row():
         audio_language = gr.Dropdown(choices=["Русский", "English"], label="Язык озвучки", value="Русский", scale=1, min_width=150, elem_classes="equal-height")
         audio_output = gr.Audio(label="Озвучка", scale=1, min_width=150, elem_classes="equal-height")

     submit_button.click(
-        fn=…
-        inputs=[image],
+        fn=generate_english_caption,
+        inputs=[image, caption_model, translator_model],
         outputs=[english_caption, russian_caption, audio_output]
     )

+    translate_button.click(
+        fn=generate_translation,
+        inputs=[english_caption, caption_model, translator_model],
+        outputs=[russian_caption]
+    )
+
     audio_button.click(
         fn=generate_audio,
-        inputs=[english_caption, russian_caption, audio_language],
+        inputs=[english_caption, russian_caption, audio_language, caption_model, translator_model],
         outputs=[audio_output]
     )
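Also worth noting: _pipeline is cached globally and init_pipeline only constructs it once, so the caption_model and translator_model arguments are ignored after the first click — switching a dropdown from BLIP to BLIP-2 keeps serving the previously loaded models. A minimal sketch of a cache keyed by the chosen pair (an alternative, not what this commit does; it reuses ImageCaptionPipeline from this file):

    _pipelines = {}

    def init_pipeline(caption_model: str, translator_model: str):
        # One pipeline per (captioner, translator) combination.
        key = (caption_model, translator_model)
        if key not in _pipelines:
            _pipelines[key] = ImageCaptionPipeline(caption_model, translator_model)
        return _pipelines[key]

On a small CPU Space, keeping several model pairs resident can exhaust RAM, so evicting the previously cached pipeline (and calling gc.collect()) before loading a new pair may be the safer variant.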
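Finally, a sizing caveat: Salesforce/blip2-opt-2.7b bundles a ViT, a Q-Former, and OPT-2.7B — roughly an order of magnitude more parameters than blip-image-captioning-large — which is heavy for a Space pinned to two CPU threads. A hedged sketch of loading it in bfloat16 to roughly halve resident memory (an assumption about the deployment, not something this commit does):

    import torch
    from transformers import Blip2Processor, Blip2ForConditionalGeneration

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        torch_dtype=torch.bfloat16,  # about half the memory of float32
    )

Whether bfloat16 inference is fast enough on this hardware would need measuring; float32 with the smaller BLIP checkpoint may remain the practical default.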