Zguin committed
Commit deaf141 · verified · 1 Parent(s): eb2f678

Update app.py

Files changed (1):
  1. app.py +64 -64
app.py CHANGED
@@ -1,65 +1,65 @@
 import gradio as gr
 import torch
 from PIL import Image
 from transformers import (
     BlipProcessor,
     BlipForConditionalGeneration,
     AutoTokenizer,
     AutoModelForSeq2SeqLM
 )
 from typing import Union
 from gtts import gTTS
 import os
 
 class ImageCaptionPipeline:
     def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", use_fast=True)
         self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)
         self.translator_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
         self.translator_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru").to(self.device)
 
     def generate_caption(self, image: Union[str, Image.Image], language: str = "Русский") -> str:
         if isinstance(image, str):
             image = Image.open(image)
         image = image.convert("RGB")
         inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
         with torch.no_grad():
             output_ids = self.blip_model.generate(**inputs, max_length=200, num_beams=4)
         english_caption = self.blip_processor.decode(output_ids[0], skip_special_tokens=True)
         if language == "Русский":
             translated_inputs = self.translator_tokenizer(english_caption, return_tensors="pt", padding=True).to(self.device)
             with torch.no_grad():
                 translated_ids = self.translator_model.generate(**translated_inputs, max_length=200, num_beams=4)
             russian_caption = self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
             return russian_caption
         return english_caption
 
 def app(image: Image.Image, language: str) -> tuple:
     if image is not None:
         pipeline = ImageCaptionPipeline()
         caption = pipeline.generate_caption(image, language=language)
         lang_code = "ru" if language == "Русский" else "en"
         tts = gTTS(text=caption, lang=lang_code)
         audio_path = "caption_audio.mp3"
         tts.save(audio_path)
         return caption, audio_path
     return "Загрузите изображение и выберите язык для получения подписи.", None
 
 with gr.Blocks() as iface:
     gr.Markdown("# Генератор подписей")
     gr.Markdown("Загрузите изображение и выберите язык.")
     language = gr.Dropdown(choices=["Русский", "English"], label="Язык", value="Русский")
     image = gr.Image(type="pil", label="Изображение", height=400, width=400)
     submit_button = gr.Button("Сгенерировать", elem_classes="btn")
     caption_output = gr.Textbox(label="Подпись")
     audio_output = gr.Audio(label="Озвучка")
 
     submit_button.click(
         fn=app,
         inputs=[image, language],
         outputs=[caption_output, audio_output]
     )
 
 if __name__ == "__main__":
     iface.launch()
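
The only substantive change in this commit is the `use_fast=True` argument to `BlipProcessor.from_pretrained`. In recent transformers releases this opts in to the fast image-processor implementation where the checkpoint provides one, and avoids the warning about the default value of `use_fast` changing. A minimal sketch (not part of the commit) for checking which variant the flag selects; the exact class names printed depend on the installed transformers version:

```python
# Illustrative check, assuming a transformers release that ships fast
# image processors; class names below are typical, not guaranteed.
from transformers import BlipProcessor

slow = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large", use_fast=False
)
fast = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large", use_fast=True
)

print(type(slow.image_processor).__name__)  # e.g. BlipImageProcessor
print(type(fast.image_processor).__name__)  # e.g. BlipImageProcessorFast
```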
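For completeness, a quick headless smoke test of the pipeline defined above, assuming the file is saved as app.py and importable from the working directory. Importing the module builds the Gradio Blocks but does not start the server, since `iface.launch()` sits behind the `__main__` guard; the BLIP and opus-mt-en-ru weights are downloaded on first use.

```python
# Hypothetical smoke test; not part of the commit.
from PIL import Image

from app import ImageCaptionPipeline  # builds the Gradio UI on import, but does not launch it

pipeline = ImageCaptionPipeline()  # first run downloads BLIP and opus-mt-en-ru
image = Image.new("RGB", (384, 384), color="white")  # placeholder; any RGB image works

print(pipeline.generate_caption(image, language="English"))  # raw BLIP caption
print(pipeline.generate_caption(image, language="Русский"))  # caption translated to Russian
```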