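# Open Ghibli Studio
# Pipeline overview: an uploaded photo is captioned with a Florence-2 model, the
# caption is suffixed with "ghibli style", and the resulting prompt drives
# FLUX.1-dev plus the openfree/flux-chatgpt-ghibli-lora adapter to render the
# stylized image in a Gradio UI.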
import os
import random
import torch
import gradio as gr
import numpy as np
import spaces
from diffusers import DiffusionPipeline
from PIL import Image

# --- [Optional Patch] ---------------------------------------------------------
# This patch fixes potential JSON schema parsing issues in Gradio/Gradio-Client.
import gradio_client.utils

original_json_schema = gradio_client.utils._json_schema_to_python_type
from PIL import ImageOps, ExifTags

def preprocess_image(image):
    # Rotate the image according to its EXIF orientation tag
    try:
        image = ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"EXIF transpose error: {e}")
    # Downscale very large images (the model may not handle them well)
    if max(image.width, image.height) > 1024:
        image.thumbnail((1024, 1024), Image.LANCZOS)
    # Check the image mode and convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
def patched_json_schema(schema, defs=None):
    # Handle boolean schema directly
    if isinstance(schema, bool):
        return "bool"
    # If 'additionalProperties' is a boolean, replace it with a generic type
    try:
        if "additionalProperties" in schema and isinstance(schema["additionalProperties"], bool):
            schema["additionalProperties"] = {"type": "any"}
    except (TypeError, KeyError):
        pass
    # Attempt to parse normally; fall back to "any" on error
    try:
        return original_json_schema(schema, defs)
    except Exception:
        return "any"

gradio_client.utils._json_schema_to_python_type = patched_json_schema
# -----------------------------------------------------------------------------
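# Note: the assignment above swaps in the patched parser process-wide, so any
# schema gradio_client cannot handle is reported as "any" instead of raising.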
# ----------------------------- Model Loading ----------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "openfree/flux-chatgpt-ghibli-lora"

def load_model_with_retry(max_retries=5):
    for attempt in range(max_retries):
        try:
            print(f"Loading model attempt {attempt+1}/{max_retries}...")
            pipeline = DiffusionPipeline.from_pretrained(
                repo_id,
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
                resume_download=True
            )
            print("Base model loaded successfully, now loading LoRA weights...")
            pipeline.load_lora_weights(adapter_id)
            pipeline = pipeline.to(device)
            print("Pipeline is ready!")
            return pipeline
        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 10 * (attempt + 1)
                print(f"Error loading model: {e}. Retrying in {wait_time} seconds...")
                import time
                time.sleep(wait_time)
            else:
                raise Exception(f"Failed to load model after {max_retries} attempts: {e}")

pipeline = load_model_with_retry()
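# The FLUX pipeline and Ghibli LoRA are loaded once at startup and reused for every request.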
# ----------------------------- Inference Function -----------------------------
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Assumption: since the Space runs on ZeroGPU and imports `spaces`, the generation
# call is decorated so a GPU slot is allocated for its duration.
@spaces.GPU
def inference(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    lora_scale: float,
):
    # If "randomize_seed" is selected, choose a random seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    print(f"Running inference with prompt: {prompt}")
    try:
        image = pipeline(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        return image, seed
    except Exception as e:
        print(f"Error during inference: {e}")
        # Return a red error image of the specified size and the used seed
        error_img = Image.new('RGB', (width, height), color='red')
        return error_img, seed
# ----------------------------- Florence-2 Captioner ---------------------------
import subprocess

try:
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        # Merge with the current environment so pip stays on PATH; only the
        # FLASH_ATTENTION_SKIP_CUDA_BUILD flag is added to avoid compiling the
        # CUDA kernels from source.
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        shell=True
    )
except Exception as e:
    print(f"Warning: Could not install flash-attn: {e}")

from transformers import AutoProcessor, AutoModelForCausalLM
# Function to safely load models
def load_caption_model(model_name):
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).eval()
        processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True
        )
        return model, processor
    except Exception as e:
        print(f"Error loading caption model {model_name}: {e}")
        return None, None

# Pre-load models and processors
print("Loading captioning models...")
default_caption_model = 'microsoft/Florence-2-large'
models = {}
processors = {}

# Try to load the default model
default_model, default_processor = load_caption_model(default_caption_model)
if default_model is not None and default_processor is not None:
    models[default_caption_model] = default_model
    processors[default_caption_model] = default_processor
    print(f"Successfully loaded default caption model: {default_caption_model}")
else:
    # Fall back to a simpler model
    fallback_model = 'gokaygokay/Florence-2-Flux'
    fallback_model_obj, fallback_processor = load_caption_model(fallback_model)
    if fallback_model_obj is not None and fallback_processor is not None:
        models[fallback_model] = fallback_model_obj
        processors[fallback_model] = fallback_processor
        default_caption_model = fallback_model
        print(f"Loaded fallback caption model: {fallback_model}")
    else:
        print("WARNING: Failed to load any caption model!")
def caption_image(image, model_name=default_caption_model):
    """
    Runs the selected Florence-2 model to generate a detailed caption.
    """
    from PIL import Image as PILImage
    import numpy as np
    print(f"Starting caption generation with model: {model_name}")
    # Handle case where image is already a PIL image
    if isinstance(image, PILImage.Image):
        pil_image = image
    else:
        # Convert numpy array to PIL
        if isinstance(image, np.ndarray):
            pil_image = PILImage.fromarray(image)
        else:
            print(f"Unexpected image type: {type(image)}")
            return "Error: Unsupported image type"
    # Convert input to RGB if needed
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    # Check if model is available
    if model_name not in models or model_name not in processors:
        available_models = list(models.keys())
        if available_models:
            model_name = available_models[0]
            print(f"Requested model not available, using: {model_name}")
        else:
            return "Error: No caption models available"
    model = models[model_name]
    processor = processors[model_name]
    task_prompt = "<DESCRIPTION>"
    user_prompt = task_prompt + "Describe this image in great detail."
    try:
        inputs = processor(text=user_prompt, images=pil_image, return_tensors="pt")
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            repetition_penalty=1.10,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=task_prompt, image_size=(pil_image.width, pil_image.height)
        )
        # Extract the caption
        caption = parsed_answer.get("<DESCRIPTION>", "")
        print(f"Generated caption: {caption}")
        return caption
    except Exception as e:
        print(f"Error during captioning: {e}")
        return f"Error generating caption: {str(e)}"
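# Minimal local sanity check for the captioner (hypothetical file name, not part
# of the app flow):
#   from PIL import Image
#   print(caption_image(Image.open("example.jpg")))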
# --------- Process uploaded image and generate Ghibli style image ---------
def process_uploaded_image(
    image,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lora_scale
):
    if image is None:
        print("No image provided")
        return None, None, "No image provided", "No image provided"
    print("Starting image processing workflow")
    # Normalize the upload (EXIF rotation, max size, RGB) before captioning
    image = preprocess_image(image)
    # Step 1: Generate caption from the uploaded image
    try:
        caption = caption_image(image)
        # Covers both "Error:" and "Error generating caption:" return values
        if caption.startswith("Error"):
            print(f"Captioning failed: {caption}")
            # Use a default caption as fallback
            caption = "A beautiful scene"
    except Exception as e:
        print(f"Exception during captioning: {e}")
        caption = "A beautiful scene"
    # Step 2: Append "ghibli style" to the caption
    ghibli_prompt = f"{caption}, ghibli style"
    print(f"Final prompt for Ghibli generation: {ghibli_prompt}")
    # Step 3: Generate Ghibli-style image based on the caption
    try:
        generated_image, used_seed = inference(
            prompt=ghibli_prompt,
            seed=seed,
            randomize_seed=randomize_seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            lora_scale=lora_scale
        )
        print(f"Image generation complete with seed: {used_seed}")
        return generated_image, used_seed, caption, ghibli_prompt
    except Exception as e:
        print(f"Error generating image: {e}")
        error_img = Image.new('RGB', (width, height), color='red')
        return error_img, seed, caption, ghibli_prompt
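# The (image, seed, caption, prompt) return order above must match the `outputs`
# lists wired to upload_img.upload() and transform_button.click() further below.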
# Define Ghibli Studio Theme
ghibli_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Nunito"), "ui-sans-serif", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
).set(
    body_background_fill="#f0f9ff",
    body_background_fill_dark="#0f172a",
    button_primary_background_fill="#6366f1",
    button_primary_background_fill_hover="#4f46e5",
    button_primary_text_color="#ffffff",
    block_title_text_weight="600",
    block_border_width="1px",
    block_shadow="0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1)",
)

# Custom CSS for enhanced visuals
custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 1rem;
    font-weight: 800;
    font-size: 2.5rem;
    background: linear-gradient(90deg, #4338ca, #3b82f6);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    padding: 0.5rem;
}
.tagline {
    text-align: center;
    font-size: 1.2rem;
    margin-bottom: 2rem;
    color: #4b5563;
}
.image-preview {
    border-radius: 12px;
    overflow: hidden;
    box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1);
}
.panel-box {
    border-radius: 12px;
    background-color: rgba(255, 255, 255, 0.8);
    padding: 1rem;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
.control-panel {
    padding: 1rem;
    border-radius: 12px;
    background-color: rgba(255, 255, 255, 0.9);
    margin-bottom: 1rem;
    border: 1px solid #e2e8f0;
}
.section-header {
    font-weight: 600;
    font-size: 1.1rem;
    margin-bottom: 0.5rem;
    color: #4338ca;
}
.transform-button {
    font-weight: 600 !important;
    margin-top: 1rem !important;
}
.footer {
    text-align: center;
    color: #6b7280;
    margin-top: 2rem;
    font-size: 0.9rem;
}
.output-panel {
    background: linear-gradient(135deg, #f0f9ff, #e0f2fe);
    border-radius: 12px;
    padding: 1rem;
    border: 1px solid #bfdbfe;
}
"""
# ----------------------------- Gradio UI --------------------------------------
with gr.Blocks(analytics_enabled=False, theme=ghibli_theme, css=custom_css) as demo:
    gr.HTML(
        """
        <div class="main-header">Open Ghibli Studio</div>
        <div class="tagline">Transform your photos into magical Ghibli-inspired artwork</div>
        """
    )
    # Background image for the app
    gr.HTML(
        """
        <style>
        body {
            background-image: url('https://i.imgur.com/LxPQPR1.jpg');
            background-size: cover;
            background-position: center;
            background-attachment: fixed;
            background-repeat: no-repeat;
            background-color: #f0f9ff;
        }
        @media (max-width: 768px) {
            body {
                background-size: contain;
            }
        }
        </style>
        """
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel-box"):
                gr.HTML('<div class="section-header">Upload Image</div>')
                upload_img = gr.Image(
                    label="Drop your image here",
                    type="pil",
                    elem_classes="image-preview",
                    height=400
                )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Group(elem_classes="control-panel"):
                    gr.HTML('<div class="section-header">Generation Controls</div>')
                    with gr.Row():
                        img2img_seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                            info="Set a specific seed for reproducible results"
                        )
                        img2img_randomize_seed = gr.Checkbox(
                            label="Randomize Seed",
                            value=True,
                            info="Enable to get different results each time"
                        )
                with gr.Group():
                    gr.HTML('<div class="section-header">Image Size</div>')
                    with gr.Row():
                        img2img_width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            info="Image width in pixels"
                        )
                        img2img_height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1024,
                            info="Image height in pixels"
                        )
                with gr.Group():
                    gr.HTML('<div class="section-header">Generation Parameters</div>')
                    with gr.Row():
                        img2img_guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=3.5,
                            info="Higher values follow the prompt more closely"
                        )
                        img2img_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=30,
                            info="More steps = more detailed but slower generation"
                        )
                    img2img_lora_scale = gr.Slider(
                        label="Ghibli Style Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=1.0,
                        info="Controls the intensity of the Ghibli style"
                    )
            transform_button = gr.Button("Transform to Ghibli Style", variant="primary", elem_classes="transform-button")
        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-panel"):
                gr.HTML('<div class="section-header">Ghibli Magic Result</div>')
                ghibli_output_image = gr.Image(
                    label="Generated Ghibli Style Image",
                    elem_classes="image-preview",
                    height=400
                )
                ghibli_output_seed = gr.Number(label="Seed Used", interactive=False)
                # Debug elements
                with gr.Accordion("Image Details", open=False):
                    extracted_caption = gr.Textbox(
                        label="Detected Image Content",
                        placeholder="The AI will analyze your image and describe it here...",
                        info="AI-generated description of your uploaded image"
                    )
                    ghibli_prompt = gr.Textbox(
                        label="Generation Prompt",
                        placeholder="The prompt used to create your Ghibli image will appear here...",
                        info="Final prompt used for the Ghibli transformation"
                    )
    gr.HTML(
        """
        <div class="footer">
            <p>Open Ghibli Studio uses AI to transform your images into Ghibli-inspired artwork.</p>
            <p>Powered by FLUX.1 and Florence-2 models.</p>
        </div>
        """
    )
    # Auto-process when image is uploaded
    upload_img.upload(
        process_uploaded_image,
        inputs=[
            upload_img,
            img2img_seed,
            img2img_randomize_seed,
            img2img_width,
            img2img_height,
            img2img_guidance_scale,
            img2img_steps,
            img2img_lora_scale,
        ],
        outputs=[
            ghibli_output_image,
            ghibli_output_seed,
            extracted_caption,
            ghibli_prompt,
        ]
    )
    # Manual process button
    transform_button.click(
        process_uploaded_image,
        inputs=[
            upload_img,
            img2img_seed,
            img2img_randomize_seed,
            img2img_width,
            img2img_height,
            img2img_guidance_scale,
            img2img_steps,
            img2img_lora_scale,
        ],
        outputs=[
            ghibli_output_image,
            ghibli_output_seed,
            extracted_caption,
            ghibli_prompt,
        ]
    )

demo.launch(debug=True)