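"""Gradio demo for Qwen2VL-Flux.

Downloads the required checkpoints (Qwen2vl-Flux, MistoLine/Anyline, Depth-Anything-V2,
SAM 2) and exposes variation / img2img / inpaint / controlnet / controlnet-inpaint
generation modes through a Gradio Blocks UI with a before/after image slider.
"""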
from typing import Tuple
import os
import time
import glob
import shutil
import random
from io import BytesIO

import requests
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from gradio_imageslider import ImageSlider
from huggingface_hub import login, snapshot_download, hf_hub_download
MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
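# Download model weights into the checkpoint directory (CHECKPOINT_DIR, default "checkpoints"):
# the Qwen2vl-Flux snapshot, the Anyline/MTEED line detector, Depth-Anything-V2 and SAM 2.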
cp_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints')
snapshot_download("Djrango/Qwen2vl-Flux", local_dir=cp_dir)
hf_hub_download(repo_id="TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline", local_dir=f"{cp_dir}/anyline")
shutil.move(f"{cp_dir}/anyline/Anyline/MTEED.pth", f"{cp_dir}/anyline")
snapshot_download("depth-anything/Depth-Anything-V2-Large", local_dir=f"{cp_dir}/depth-anything-v2")
snapshot_download("facebook/sam2-hiera-large", local_dir=f"{cp_dir}/segment-anything-2")

# Copy the SAM 2 config YAMLs to ./sam2_configs
# (see https://github.com/facebookresearch/sam2/issues/26).
os.makedirs("sam2_configs", exist_ok=True)
for p in glob.glob(f"{cp_dir}/segment-anything-2/*.yaml"):
    shutil.copy(p, "sam2_configs")
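# modelmod.FluxModel is imported only after the checkpoints and SAM 2 configs are in place.
# is_quantization=True presumably loads quantized weights to fit into GPU memory
# (assumption -- the exact behaviour depends on the local modelmod implementation).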
from modelmod import FluxModel

model = FluxModel(device=DEVICE, is_turbo=False, required_features=['controlnet', 'depth', 'line'], is_quantization=True)  # , 'sam'

QWEN2VLFLUX_MODES = ["variation", "img2img", "inpaint", "controlnet", "controlnet-inpaint"]
QWEN2VLFLUX_ASPECT_RATIO = ["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"]
class calculateDuration:
    """Context manager that logs the wall-clock duration of a code block."""

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")
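# Illustrative usage (the activity name is arbitrary):
#     with calculateDuration("load checkpoints"):
#         ...  # timed work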
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Scale (width, height) so the longer side equals maximum_dimension, floored to multiples of 32."""
    width, height = original_resolution_wh
    # if width <= maximum_dimension and height <= maximum_dimension:
    #     width = width - (width % 32)
    #     height = height - (height % 32)
    #     return width, height
    if width > height:
        scaling_factor = maximum_dimension / width
    else:
        scaling_factor = maximum_dimension / height
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    new_width = new_width - (new_width % 32)
    new_height = new_height - (new_height % 32)
    return new_width, new_height
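# For example, resize_image_dimensions((2048, 1536)) returns (1024, 768): the longer side
# is scaled down to IMAGE_SIZE and both dimensions stay multiples of 32.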
def fetch_from_url(url: str, name: str):
    """Download an image from a URL and return it as a PIL image, or None on failure."""
    try:
        print(f"Fetching {name} from URL: {url}")
        response = requests.get(url)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        print(f"Fetched {name} successfully")
        return image
    except Exception as e:
        print(e)
        return None
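# Illustrative call (the URL is a placeholder):
#     ref = fetch_from_url("https://example.com/reference.png", "reference image")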
def process(
    mode: str,
    input_image_editor: dict,
    ref_image: Image.Image,
    image_url: str,
    mask_url: str,
    ref_url: str,
    input_text: str,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    aspect_ratio: str,
    attn_mode: bool,
    center_x: float,
    center_y: float,
    radius: float,
    line_mode: bool,
    line_strength: float,
    depth_mode: bool,
    depth_strength: float,
    progress=gr.Progress(track_tqdm=True)
):
    #if not input_text:
    #    gr.Info("Please enter a text prompt.")
    #    return None
    kwargs = {}
    image = input_image_editor['background']
    mask = input_image_editor['layers'][0]
    if image_url: image = fetch_from_url(image_url, "image")
    if mask_url: mask = fetch_from_url(mask_url, "mask")
    if ref_url: ref_image = fetch_from_url(ref_url, "reference image")
    if not image:
        gr.Info("Please upload an image.")
        return None
    if ref_image: kwargs["input_image_b"] = ref_image
    if mode in ("inpaint", "controlnet-inpaint"):
        if not mask:
            gr.Info("Please draw a mask on the image.")
            return None
        kwargs["mask_image"] = mask
    if attn_mode:
        kwargs["center_x"] = center_x
        kwargs["center_y"] = center_y
        kwargs["radius"] = radius
    with calculateDuration("run inference"):
        result = model.generate(
            input_image_a=image,
            prompt=input_text,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            aspect_ratio=aspect_ratio,
            mode=mode,
            denoise_strength=strength,
            line_mode=line_mode,
            line_strength=line_strength,
            depth_mode=depth_mode,
            depth_strength=depth_strength,
            imageCount=1,
            **kwargs
        )[0]
    #return result
    return [image, result]  # ImageSlider expects a (before, after) pair
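# Gradio UI: inputs, URL overrides and generation settings in the left column,
# a before/after ImageSlider in the right column.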
CSS = """
.title { text-align: center; }
"""

with gr.Blocks(fill_width=True, css=CSS) as demo:
    gr.Markdown("# Qwen2VL-Flux", elem_classes="title")
    with gr.Row():
        with gr.Column():
            gen_mode = gr.Radio(label="Generation mode", choices=QWEN2VLFLUX_MODES, value="variation")
            with gr.Row():
                input_image_editor = gr.ImageEditor(label='Image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB',
                                                    layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
                ref_image = gr.Image(label='Reference image', type='pil', sources=["upload", "webcam", "clipboard"], image_mode='RGB')
            with gr.Accordion("Image from URL", open=False):
                image_url = gr.Textbox(label="Image url", show_label=True, max_lines=1, placeholder="Enter your image url (Optional)")
                mask_url = gr.Textbox(label="Mask image url", show_label=True, max_lines=1, placeholder="Enter your mask image url (Optional)")
                ref_url = gr.Textbox(label="Reference image url", show_label=True, max_lines=1, placeholder="Enter your reference image url (Optional)")
            with gr.Accordion("Prompt Settings", open=True):
                input_text = gr.Textbox(label="Prompt", show_label=True, max_lines=1, placeholder="Enter your prompt")
                submit_button = gr.Button(value='Submit', variant='primary')
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    denoise_strength = gr.Slider(label="Denoise strength", minimum=0, maximum=1, step=0.01, value=0.75)
                    aspect_ratio = gr.Radio(label="Output image ratio", choices=QWEN2VLFLUX_ASPECT_RATIO, value="1:1")
                    num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=28)
                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.5, value=3.5)
            with gr.Accordion("Attention Control", open=True):
                with gr.Row():
                    attn_mode = gr.Checkbox(label="Attention Control", value=False)
                    center_x = gr.Slider(label="X coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
                    center_y = gr.Slider(label="Y coordinate of attention center", minimum=0, maximum=1, step=0.01, value=0.5)
                    radius = gr.Slider(label="Radius of attention circle", minimum=0, maximum=1, step=0.01, value=0.5)
            with gr.Accordion("ControlNet Settings", open=True):
                with gr.Row():
                    line_mode = gr.Checkbox(label="Line mode", value=True)
                    line_strength = gr.Slider(label="Line strength", minimum=0, maximum=1, step=0.01, value=0.4)
                    depth_mode = gr.Checkbox(label="Depth mode", value=True)
                    depth_strength = gr.Slider(label="Depth strength", minimum=0, maximum=1, step=0.01, value=0.2)
        with gr.Column():
            #output_image = gr.Image(label="Generated image", type="pil", format="png", show_download_button=True, show_share_button=False)
            output_image = ImageSlider(label="Generated image", type="pil")
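    # The order of `inputs` below must match the parameter order of process().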
    gr.on(triggers=[submit_button.click, input_text.submit], fn=process,
          inputs=[gen_mode, input_image_editor, ref_image, image_url, mask_url, ref_url,
                  input_text, denoise_strength, num_inference_steps, guidance_scale, aspect_ratio,
                  attn_mode, center_x, center_y, radius, line_mode, line_strength, depth_mode, depth_strength],
          outputs=[output_image], queue=True)

demo.queue().launch(debug=True, show_error=True)
#demo.queue().launch(debug=True, show_error=True, ssr_mode=False) # Gradio 5