import spaces
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
# from diffusers import StableDiffusionImageVariationPipeline
from inference import InferenceModel
from pytorch_lightning import seed_everything
import numpy as np
import os
import rembg
import sys
from loguru import logger
# Imports for downloading weights
from huggingface_hub import hf_hub_download
import shutil
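# IDs of the result tabs; callbacks return gr.Tabs(selected=<id>) to switch
# the result view to the matching tab after generation.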
_SAMPLE_TAB_ID_ = 0
_HIGHRES_TAB_ID_ = 1
_FOREGROUND_TAB_ID_ = 2
def download_model_weights():
    """Download model weights from the original repository if they don't exist."""
    # Original repository ID
    repo_id = "LittleFrog/IntrinsicAnything"
    # Paths and files to download
    model_configs = [
        {
            "local_dir": "weights/albedo",
            "files": [
                "weights/albedo/checkpoints/last.ckpt",
                "weights/albedo/configs/albedo_project.yaml",
            ],
        },
        {
            "local_dir": "weights/specular",
            "files": [
                "weights/specular/checkpoints/last.ckpt",
                "weights/specular/configs/specular_project.yaml",
            ],
        },
    ]
    for config in model_configs:
        local_dir = config["local_dir"]
        # Create directories if they don't exist
        os.makedirs(f"{local_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{local_dir}/configs", exist_ok=True)
        for file_path in config["files"]:
            local_file_path = file_path
            # Check that the file exists and is not an LFS pointer (> 1KB)
            if not os.path.exists(local_file_path) or os.path.getsize(local_file_path) < 1024:
                logger.info(f"Downloading {file_path} from HuggingFace...")
                try:
                    # Download the file
                    downloaded_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=file_path,
                        repo_type="space",
                        cache_dir="./hf_cache",
                    )
                    # Copy to the expected location
                    shutil.copy2(downloaded_file, local_file_path)
                    logger.info(f"Successfully downloaded {file_path}")
                except Exception as e:
                    logger.error(f"Failed to download {file_path}: {e}")
                    raise
            else:
                logger.info(f"{local_file_path} already exists and appears to be valid")
def set_loggers(level):
    logger.remove()
    logger.add(sys.stderr, level=level)
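# Gallery .select() callbacks. For a gallery of (image, caption) pairs,
# evt.value is a dict with an "image" entry (file info incl. "path") and a
# "caption" entry.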
def on_guide_select(evt: gr.SelectData):
    logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
    return [evt.value["image"]["path"], f"Sample {evt.index}"]


def on_input_select(evt: gr.SelectData):
    logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
    return evt.value["image"]["path"]
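# High-resolution sampling: splits the input into vert_split x hor_split
# overlapping patches and conditions each patch on a previously generated
# low-res sample (guidance strength is passed to the model as dps_scale).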
@spaces.GPU  # ZeroGPU Spaces: callbacks that need the GPU must be decorated with spaces.GPU
def sample_fine(
    input_im,
    domain="Albedo",
    require_mask=False,
    steps=25,
    n_samples=4,
    seed=0,
    guid_img=None,
    vert_split=2,
    hor_split=2,
    overlaps=2,
    guidance_scale=2,
):
    # A guiding image is required; fail early with a UI error instead of a crash.
    if guid_img is None:
        raise gr.Error("Generate low-res samples first and select one as the guiding image.")
    if require_mask:
        input_im = remove_bg(input_im)
    seed_everything(int(seed))
    model = model_dict[domain]
    inp = tform(input_im).to(device).permute(1, 2, 0)
    guid_img = tform(guid_img).to(device).permute(1, 2, 0)
    images = model.generation(
        (vert_split, hor_split),
        overlaps,
        guid_img[..., :3],
        inp[..., :3],
        inp[..., 3:],
        dps_scale=guidance_scale,
        uc_score=1.0,
        ddim_steps=steps,
        batch_size=1,
        n_samples=1,
    )
    images["guid_images"] = [(guid_img.detach().cpu().numpy() * 255).astype(np.uint8)]
    output = images["out_images"][0]
    return [[(output, "High-res")], gr.Tabs(selected=_HIGHRES_TAB_ID_)]
def remove_bg(input_im):
    output = rembg.remove(input_im, session=model_dict["remove_bg"])
    return output
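# Low-resolution sampling: runs the selected diffusion model on the whole
# image as a single tile and returns the foreground object, its mask, and
# n_samples generated intrinsic images, switching the UI to the samples tab.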
@spaces.GPU  # see the note on sample_fine above
def sampling(input_im, domain="Albedo", require_mask=False,
             steps=25, n_samples=4, seed=0):
    seed_everything(int(seed))
    model = model_dict[domain]
    if require_mask:
        input_im = remove_bg(input_im)
    inp = tform(input_im).to(device).permute(1, 2, 0)
    images = model.generation(
        (1, 1), 1, None, inp[..., :3], inp[..., 3:],
        dps_scale=0, uc_score=1, ddim_steps=steps,
        batch_size=1, n_samples=n_samples,
    )
    output = [
        [(images["input_image"][0], "Foreground Object"),
         (images["input_maskes"][0], "Foreground Mask")],
        [(img, f"Sample {idx}") for idx, img in enumerate(images["out_images"])],
        gr.Tabs(selected=_SAMPLE_TAB_ID_),
    ]
    return output
title = "IntrinsicAnything: Learning Diffusion Priors for Inverse Rendering Under Unknown Illumination"
description = """
#### Generate intrinsic images (Albedo, Specular Shading) from a single image.
##### Tips
- Check the "Auto Mask" box if the input image needs a foreground mask, or supply your own mask via an RGBA input.
- You can optionally generate a high-resolution sample if the input image is of high resolution. The original image is split into `Vertical Splits` x `Horizontal Splits` patches with some `Overlaps` in between. Due to the compute constraints of the online demo, we recommend keeping `Vertical Splits` x `Horizontal Splits` at 6 or fewer and setting `Overlaps` to 2. Set the denoising steps to at least 80 for high-resolution samples.
"""
set_loggers("INFO")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Download model weights if needed
logger.info("Checking and downloading model weights if needed...")
download_model_weights()
logger.info("Loading models...")
model_dict = {
    "Albedo": InferenceModel(ckpt_path="weights/albedo",
                             use_ddim=True,
                             gpu_id=0),
    "Specular": InferenceModel(ckpt_path="weights/specular",
                               use_ddim=True,
                               gpu_id=0),
    "remove_bg": rembg.new_session(),
}
logger.info("All models loaded!")
tform = transforms.Compose([
    transforms.ToTensor(),
])
examples_dir = "examples"
examples = [[os.path.join(examples_dir, img_name)] for img_name in sorted(os.listdir(examples_dir))]
# theme definition
theme = gr.Theme.from_hub("NoCrypt/miku")
theme.body_background_fill = "#FFFFFF"
theme.body_background_fill_dark = "#000000"
demo = gr.Blocks(title=title, theme=theme)
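# UI layout: title and description on top; a row with the input image, the
# option tabs, and the result galleries; generate buttons and examples below.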
with demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + title)
            gr.Markdown(description)
    with gr.Column():
        with gr.Row():
            with gr.Column(scale=0.8):
                image_input = [gr.Image(image_mode='RGBA', height=256)]
            with gr.Column():
                with gr.Tabs():
                    with gr.TabItem("Options"):
                        with gr.Column():
                            with gr.Row():
                                domain_box = gr.Radio([("Albedo", "Albedo"), ("Specular", "Specular")],
                                                      value="Albedo",
                                                      label="Type")
                                with gr.Column():
                                    gr.Markdown("### Automatic foreground segmentation")
                                    mask_box = gr.Checkbox(False, label="Auto Mask")
                            options_tab = [
                                domain_box,
                                mask_box,
                                gr.Slider(5, 200, value=50, step=5, label="Denoising Steps (larger values give better results)"),
                                gr.Slider(1, 10, value=2, step=1, label="Number of Samples"),
                                gr.Number(75424, label="Seed", precision=0),
                            ]
                    with gr.TabItem("Advanced (High-res)"):
                        with gr.Column():
                            guiding_img = gr.Image(image_mode='RGBA', label="Guiding Image", interactive=False, height=256, visible=False)
                            sample_idx = gr.Textbox(placeholder="Select one of the generated low-res samples", lines=1, interactive=False, label="Guiding Image")
                            options_advanced_tab = [
                                # high resolution options
                                guiding_img,
                                gr.Slider(1, 4, value=2, step=1, label="Vertical Splits"),
                                gr.Slider(1, 4, value=2, step=1, label="Horizontal Splits"),
                                gr.Slider(1, 5, value=2, step=1, label="Overlaps"),
                                gr.Slider(0, 5, value=3, step=1, label="Guidance Scale"),
                            ]
            with gr.Column(scale=1.0):
                with gr.Tabs() as res_tabs:
                    with gr.TabItem("Generated Samples", id=_SAMPLE_TAB_ID_):
                        image_output = gr.Gallery(label="Generated Samples", object_fit="contain", columns=[2], rows=[2], height=512, selected_index=0)
                    with gr.TabItem("High Resolution Sample", id=_HIGHRES_TAB_ID_):
                        image_output_high = gr.Gallery(label="High Resolution Sample", object_fit="contain", columns=[1], rows=[1], height=512, selected_index=0)
                    with gr.TabItem("Foreground Object", id=_FOREGROUND_TAB_ID_):
                        foreground_output = gr.Gallery(label="Foreground Object", object_fit="contain", columns=[2], rows=[1], height=512, selected_index=0)
        with gr.Row():
            generate_button = gr.Button("Generate")
            generate_button_fine = gr.Button("Generate High-Res")
        examples_gr = gr.Examples(examples=examples, inputs=image_input,
                                  cache_examples=False, examples_per_page=30,
                                  label='Examples (Click one to start!)')
        with gr.Row():
            pass
            # foreground_output = gr.Gallery(label="Inputs", preview=False, columns=[2], rows=[1], height=512, selected_index=0)
            # image_output = gr.Gallery(label="Generated Samples", object_fit="cover", columns=[1], rows=[6], height=512, selected_index=0)
            # image_output_high = gr.Gallery(label="High Resolution Sample", object_fit="cover", columns=[1], rows=[1], height=512, selected_index=0)
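    # Wire the generate buttons and the low-res gallery selection to the
    # callbacks defined above.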
    generate_button.click(sampling, inputs=image_input + options_tab,
                          outputs=[foreground_output, image_output, res_tabs])
    generate_button_fine.click(sample_fine,
                               inputs=image_input + options_tab + options_advanced_tab,
                               outputs=[image_output_high, res_tabs])
    image_output.select(on_guide_select, None, [guiding_img, sample_idx])
| logger.info(f"Demo Initilized, Starting...") | |
| demo.queue().launch() | |