sdfsdsdsd / appxx.py
Uhhy's picture
Update appxx.py
331aa5a verified
raw
history blame
2.61 kB
import gradio as gr
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import spaces
# Funci贸n para inicializar el pipeline de Stable Diffusion utilizando GPU
@spaces.GPU(duration=120) # Solicita el uso de GPU por 120 minutos
def initialize_pipeline():
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "CompVis/stable-diffusion-v1-4" # Cambia esto al ID del modelo adecuado
try:
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id).to(device)
except ValueError as e:
return f"Error cargando el pipeline: {e}"
return pipe
# Configuraci贸n inicial del modelo
pipe = None
def setup_pipeline():
global pipe
if pipe is None:
pipe = initialize_pipeline()
def cartoonize_image(image, prompt):
setup_pipeline() # Aseg煤rate de que el pipeline est茅 inicializado
try:
# Preprocesar imagen
image = image.convert("RGB")
image = image.resize((512, 512)) # Ajusta el tama帽o seg煤n sea necesario
# Generar imagen con Stable Diffusion Img2Img
with torch.no_grad():
result = pipe(prompt=prompt, init_image=image, strength=0.75).images[0]
return result
except Exception as e:
return f"Error procesando la imagen: {e}"
def main():
with gr.Blocks() as demo:
gr.Markdown("## Conversi贸n de Imagen a Caricatura Estilo Pixar")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Sube tu imagen")
prompt = gr.Textbox(label="Prompt", value="A Pixar-style cartoon character")
with gr.Row():
submit_button = gr.Button("Submit")
stop_button = gr.Button("Stop")
output_image = gr.Image(type="pil", label="Imagen Caricaturizada")
def process_image(image, prompt):
return cartoonize_image(image, prompt)
submit_button.click(
fn=process_image,
inputs=[input_image, prompt],
outputs=output_image
)
# La funcionalidad de stop en un entorno de Gradio no se puede manejar f谩cilmente,
# pero puedes usar controladores adicionales o l贸gica para detener la ejecuci贸n si es necesario.
stop_button.click(
fn=lambda: "El proceso ha sido detenido", # Mensaje de placeholder
outputs=output_image
)
demo.launch()
if __name__ == "__main__":
main()