# Import libraries
import torch
import re
import random
import requests
import shutil
from clip_interrogator import Config, Interrogator
from transformers import pipeline, set_seed, AutoTokenizer, AutoModelForSeq2SeqLM
from PIL import Image
import gradio as gr
# Configure CLIP Interrogator
config = Config()
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
config.blip_offload = not torch.cuda.is_available()
config.chunk_size = 2048
config.flavor_intermediate_count = 512
config.blip_num_beams = 64
config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
ci = Interrogator(config)
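# Minimal smoke-test sketch for the interrogator, assuming a local file named
# 'example.jpg' exists (the path is hypothetical); kept commented out so it
# does not run when the Space starts:
# test_image = Image.open('example.jpg').convert('RGB')
# print(ci.interrogate_fast(test_image))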
# Function to generate a prompt from an image
def get_prompt_from_image(image, mode):
    image = image.convert('RGB')
    if mode == 'best':
        prompt = ci.interrogate(image)
    elif mode == 'classic':
        prompt = ci.interrogate_classic(image)
    elif mode == 'fast':
        prompt = ci.interrogate_fast(image)
    elif mode == 'negative':
        prompt = ci.interrogate_negative(image)
    return prompt
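# Rough guide to the modes above: 'best' is the slowest and most thorough,
# 'fast' and 'classic' are quicker, lighter variants, and 'negative' produces
# a negative prompt rather than a descriptive one.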
# Text-generation pipeline to expand a short idea into full prompts
text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')

def text_generate(starting_text):
    seed = random.randint(100, 1000000)
    set_seed(seed)
    for count in range(6):
        sequences = text_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=8)
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Keep only completions that extend the input and do not end mid-phrase
            if line != starting_text and len(line) > (len(starting_text) + 4) and not line.endswith((':', '-', '—')):
                lines.append(line)
        result = "\n".join(lines)
        result = re.sub(r'[^ ]+\.[^ ]+', '', result)  # drop tokens containing a dot (e.g. URLs)
        result = result.replace('<', '').replace('>', '')
        if result != '':
            return result
        if count == 5:
            return result
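# Quick standalone check of the generator, assuming the model has been
# downloaded (the starting text is only an illustration):
# print(text_generate('a castle in the mountains at sunset'))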
# Build the Gradio interface
with gr.Blocks() as block:
    with gr.Column():
        gr.HTML('<h1>MidJourney / SD2 Helper Tool</h1>')
        with gr.Tab('Generate from Image'):
            with gr.Row():
                input_image = gr.Image(type='pil')
                with gr.Column():
                    input_mode = gr.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Mode')
                    img_btn = gr.Button('Discover Image Prompt')
            output_image = gr.Textbox(lines=6, label='Generated Prompt')
        with gr.Tab('Generate from Text'):
            input_text = gr.Textbox(lines=6, label='Your Idea', placeholder='Enter your content here...')
            output_text = gr.Textbox(lines=6, label='Generated Prompt')
            text_btn = gr.Button('Generate Prompt')
    img_btn.click(fn=get_prompt_from_image, inputs=[input_image, input_mode], outputs=output_image)
    text_btn.click(fn=text_generate, inputs=input_text, outputs=output_text)

block.queue(max_size=64).launch(show_api=False, enable_queue=True, debug=True, share=True, server_name='0.0.0.0')
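# A Space running this script would also need the packages imported above; a
# minimal requirements.txt sketch (names only, versions left unpinned as an
# assumption):
#   torch
#   transformers
#   clip-interrogator
#   gradio
#   Pillow
#   requests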