Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -5,7 +5,9 @@ import logging | |
| 5 | 
             
            import torch
         | 
| 6 | 
             
            from PIL import Image
         | 
| 7 | 
             
            import spaces
         | 
| 8 | 
            -
            from diffusers import DiffusionPipeline
         | 
|  | |
|  | |
| 9 | 
             
            from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
         | 
| 10 | 
             
            import copy
         | 
| 11 | 
             
            import random
         | 
| @@ -16,11 +18,18 @@ with open('loras.json', 'r') as f: | |
| 16 | 
             
                loras = json.load(f)
         | 
| 17 |  | 
| 18 | 
             
            # Initialize the base model
         | 
|  | |
|  | |
| 19 | 
             
            base_model = "black-forest-labs/FLUX.1-dev"
         | 
| 20 | 
            -
             | 
|  | |
|  | |
|  | |
| 21 |  | 
| 22 | 
             
            MAX_SEED = 2**32-1
         | 
| 23 |  | 
|  | |
|  | |
| 24 | 
             
            class calculateDuration:
         | 
| 25 | 
             
                def __init__(self, activity_name=""):
         | 
| 26 | 
             
                    self.activity_name = activity_name
         | 
| @@ -61,10 +70,9 @@ def update_selection(evt: gr.SelectData, width, height): | |
| 61 | 
             
            def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
         | 
| 62 | 
             
                pipe.to("cuda")
         | 
| 63 | 
             
                generator = torch.Generator(device="cuda").manual_seed(seed)
         | 
| 64 | 
            -
                
         | 
| 65 | 
             
                with calculateDuration("Generating image"):
         | 
| 66 | 
             
                    # Generate image
         | 
| 67 | 
            -
                     | 
| 68 | 
             
                        prompt=prompt_mash,
         | 
| 69 | 
             
                        num_inference_steps=steps,
         | 
| 70 | 
             
                        guidance_scale=cfg_scale,
         | 
| @@ -72,13 +80,14 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal | |
| 72 | 
             
                        height=height,
         | 
| 73 | 
             
                        generator=generator,
         | 
| 74 | 
             
                        joint_attention_kwargs={"scale": lora_scale},
         | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
|  | |
|  | |
| 77 |  | 
| 78 | 
             
            def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
         | 
| 79 | 
             
                if selected_index is None:
         | 
| 80 | 
             
                    raise gr.Error("You must select a LoRA before proceeding.")
         | 
| 81 | 
            -
             | 
| 82 | 
             
                selected_lora = loras[selected_index]
         | 
| 83 | 
             
                lora_path = selected_lora["repo"]
         | 
| 84 | 
             
                trigger_word = selected_lora["trigger_word"]
         | 
| @@ -92,24 +101,31 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid | |
| 92 | 
             
                        prompt_mash = f"{trigger_word} {prompt}"
         | 
| 93 | 
             
                else:
         | 
| 94 | 
             
                    prompt_mash = prompt
         | 
|  | |
| 95 | 
             
                # Load LoRA weights
         | 
| 96 | 
             
                with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
         | 
| 97 | 
             
                    if "weights" in selected_lora:
         | 
| 98 | 
             
                        pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
         | 
| 99 | 
            -
                        #pipe.fuse_lora()
         | 
| 100 | 
             
                    else:
         | 
| 101 | 
             
                        pipe.load_lora_weights(lora_path)
         | 
| 102 | 
            -
             | 
| 103 | 
             
                # Set random seed for reproducibility
         | 
| 104 | 
             
                with calculateDuration("Randomizing seed"):
         | 
| 105 | 
             
                    if randomize_seed:
         | 
| 106 | 
             
                        seed = random.randint(0, MAX_SEED)
         | 
|  | |
|  | |
| 107 |  | 
| 108 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 109 | 
             
                pipe.to("cpu")
         | 
| 110 | 
            -
                #pipe.unfuse_lora()
         | 
| 111 | 
             
                pipe.unload_lora_weights()
         | 
| 112 | 
            -
                 | 
|  | |
| 113 |  | 
| 114 | 
             
            def get_huggingface_safetensors(link):
         | 
| 115 | 
             
              split_link = link.split("/")
         | 
|  | |
| 5 | 
             
            import torch
         | 
| 6 | 
             
            from PIL import Image
         | 
| 7 | 
             
            import spaces
         | 
| 8 | 
            +
            from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
         | 
| 9 | 
            +
            from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
         | 
| 10 | 
            +
             | 
| 11 | 
             
            from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
         | 
| 12 | 
             
            import copy
         | 
| 13 | 
             
            import random
         | 
|  | |
| 18 | 
             
                loras = json.load(f)
         | 
| 19 |  | 
| 20 | 
             
            # Initialize the base model
         | 
| 21 | 
            +
            dtype = torch.bfloat16
         | 
| 22 | 
            +
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 23 | 
             
            base_model = "black-forest-labs/FLUX.1-dev"
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
         | 
| 26 | 
            +
            good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
         | 
| 27 | 
            +
            pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
         | 
| 28 |  | 
| 29 | 
             
            MAX_SEED = 2**32-1
         | 
| 30 |  | 
| 31 | 
            +
            pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
         | 
| 32 | 
            +
             | 
| 33 | 
             
            class calculateDuration:
         | 
| 34 | 
             
                def __init__(self, activity_name=""):
         | 
| 35 | 
             
                    self.activity_name = activity_name
         | 
|  | |
| 70 | 
             
            def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
         | 
| 71 | 
             
                pipe.to("cuda")
         | 
| 72 | 
             
                generator = torch.Generator(device="cuda").manual_seed(seed)
         | 
|  | |
| 73 | 
             
                with calculateDuration("Generating image"):
         | 
| 74 | 
             
                    # Generate image
         | 
| 75 | 
            +
                    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
         | 
| 76 | 
             
                        prompt=prompt_mash,
         | 
| 77 | 
             
                        num_inference_steps=steps,
         | 
| 78 | 
             
                        guidance_scale=cfg_scale,
         | 
|  | |
| 80 | 
             
                        height=height,
         | 
| 81 | 
             
                        generator=generator,
         | 
| 82 | 
             
                        joint_attention_kwargs={"scale": lora_scale},
         | 
| 83 | 
            +
                        output_type="pil",
         | 
| 84 | 
            +
                        good_vae=good_vae,
         | 
| 85 | 
            +
                    ):
         | 
| 86 | 
            +
                        yield img
         | 
| 87 |  | 
| 88 | 
             
            def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
         | 
| 89 | 
             
                if selected_index is None:
         | 
| 90 | 
             
                    raise gr.Error("You must select a LoRA before proceeding.")
         | 
|  | |
| 91 | 
             
                selected_lora = loras[selected_index]
         | 
| 92 | 
             
                lora_path = selected_lora["repo"]
         | 
| 93 | 
             
                trigger_word = selected_lora["trigger_word"]
         | 
|  | |
| 101 | 
             
                        prompt_mash = f"{trigger_word} {prompt}"
         | 
| 102 | 
             
                else:
         | 
| 103 | 
             
                    prompt_mash = prompt
         | 
| 104 | 
            +
             | 
| 105 | 
             
                # Load LoRA weights
         | 
| 106 | 
             
                with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
         | 
| 107 | 
             
                    if "weights" in selected_lora:
         | 
| 108 | 
             
                        pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
         | 
|  | |
| 109 | 
             
                    else:
         | 
| 110 | 
             
                        pipe.load_lora_weights(lora_path)
         | 
| 111 | 
            +
             | 
| 112 | 
             
                # Set random seed for reproducibility
         | 
| 113 | 
             
                with calculateDuration("Randomizing seed"):
         | 
| 114 | 
             
                    if randomize_seed:
         | 
| 115 | 
             
                        seed = random.randint(0, MAX_SEED)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
         | 
| 118 |  | 
| 119 | 
            +
                # Consume the generator to get the final image
         | 
| 120 | 
            +
                final_image = None
         | 
| 121 | 
            +
                for image in image_generator:
         | 
| 122 | 
            +
                    final_image = image
         | 
| 123 | 
            +
                    yield image, seed  # Yield intermediate images and seed
         | 
| 124 | 
            +
             | 
| 125 | 
             
                pipe.to("cpu")
         | 
|  | |
| 126 | 
             
                pipe.unload_lora_weights()
         | 
| 127 | 
            +
                
         | 
| 128 | 
            +
                return final_image, seed  # Return the final image and seed
         | 
| 129 |  | 
| 130 | 
             
            def get_huggingface_safetensors(link):
         | 
| 131 | 
             
              split_link = link.split("/")
         | 
