# Hugging Face Spaces status: Running
import functools
import os
import tempfile
from typing import Union

import gradio as gr
import torch
from gtts import gTTS
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)
class ImageCaptionPipeline:
    """Caption images with BLIP and optionally translate the caption.

    Both models are loaded once in ``__init__`` and placed on CUDA when it is
    available, otherwise on the CPU.
    """

    def __init__(
        self,
        caption_model_name: str = "Salesforce/blip-image-captioning-large",
        translation_model_name: str = "Helsinki-NLP/opus-mt-en-ru",
    ):
        """Load the captioning and translation models.

        Args:
            caption_model_name: Hugging Face model id for the BLIP processor
                and captioning model.
            translation_model_name: Hugging Face model id for the tokenizer and
                seq2seq model used to produce the Russian caption.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.blip_processor = BlipProcessor.from_pretrained(caption_model_name, use_fast=True)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(caption_model_name).to(self.device)
        self.translator_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
        self.translator_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name).to(self.device)

    def generate_caption(self, image: Union[str, Image.Image], language: str = "Русский") -> str:
        """Return a caption for ``image``.

        Args:
            image: Path to an image file, or an already-loaded PIL image.
            language: ``"Русский"`` translates the English caption with the
                translation model; any other value returns the English caption.

        Returns:
            The caption text.
        """
        if isinstance(image, str):
            # Convert inside ``with`` so the file handle is closed right away
            # rather than whenever the lazily-loaded image is garbage-collected.
            with Image.open(image) as opened:
                rgb_image = opened.convert("RGB")
        else:
            rgb_image = image.convert("RGB")

        english_caption = self._caption_english(rgb_image)
        if language == "Русский":
            return self._translate(english_caption)
        return english_caption

    def _caption_english(self, image: Image.Image) -> str:
        """Run BLIP beam-search generation on ``image`` and decode the caption."""
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.blip_model.generate(**inputs, max_length=200, num_beams=4)
        return self.blip_processor.decode(output_ids[0], skip_special_tokens=True)

    def _translate(self, text: str) -> str:
        """Translate ``text`` with the loaded translation model."""
        translated_inputs = self.translator_tokenizer(text, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            translated_ids = self.translator_model.generate(**translated_inputs, max_length=200, num_beams=4)
        return self.translator_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
@functools.lru_cache(maxsize=1)
def _get_pipeline() -> "ImageCaptionPipeline":
    """Return a shared pipeline, loading the models only on first use."""
    return ImageCaptionPipeline()


def app(image: Image.Image, language: str) -> tuple:
    """Gradio click handler: caption ``image`` and speak the caption.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        language: ``"Русский"`` yields a Russian caption and Russian speech;
            any other value yields English for both.

    Returns:
        ``(caption_or_prompt, audio_path)``. ``audio_path`` is ``None`` when no
        image was provided or the caption is empty.
    """
    if image is None:
        return "Загрузите изображение и выберите язык для получения подписи.", None

    # Constructing ImageCaptionPipeline loads two large models; reuse one
    # instance across requests instead of reloading on every click.
    caption = _get_pipeline().generate_caption(image, language=language)
    if not caption.strip():
        # gTTS cannot synthesize empty text, so return the caption without audio.
        return caption, None

    lang_code = "ru" if language == "Русский" else "en"
    # A unique file per request, so overlapping requests never overwrite each
    # other's audio.
    fd, audio_path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)
    try:
        gTTS(text=caption, lang=lang_code).save(audio_path)
    except Exception:
        # Don't leave an empty temp file behind when synthesis fails.
        os.remove(audio_path)
        raise
    return caption, audio_path
with gr.Blocks() as iface:
    # Page header.
    gr.Markdown("# Генератор подписей")
    gr.Markdown("Загрузите изображение и выберите язык.")

    # Inputs: the caption language and the image to describe.
    language = gr.Dropdown(
        label="Язык",
        choices=["Русский", "English"],
        value="Русский",
    )
    image = gr.Image(
        label="Изображение",
        type="pil",
        width=400,
        height=400,
    )
    submit_button = gr.Button("Сгенерировать", elem_classes="btn")

    # Outputs populated by ``app``: the caption text and its spoken version.
    caption_output = gr.Textbox(label="Подпись")
    audio_output = gr.Audio(label="Озвучка")

    # List order must match app's parameters and the tuple it returns.
    handler_inputs = [image, language]
    handler_outputs = [caption_output, audio_output]
    submit_button.click(fn=app, inputs=handler_inputs, outputs=handler_outputs)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    iface.launch()