ovi054 committed
Commit d590d5c · verified · 1 Parent(s): c226624

Create app.py

Files changed (1):
  1. app.py +144 -0

app.py ADDED
@@ -0,0 +1,144 @@
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import QwenImagePipeline

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Alternative base checkpoints: black-forest-labs/FLUX.1-Krea-dev, black-forest-labs/FLUX.1-dev

# Load the model pipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)

if torch.cuda.is_available():
    torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

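# Note: the Qwen-Image pipeline is large; if it does not fit in GPU memory,
# diffusers' model CPU offload is an alternative to calling .to(device)
# (a sketch, not enabled here):
# pipe.enable_model_cpu_offload()
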
@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Load optional LoRA weights and apply the requested scale
    # (set_adapters requires diffusers' PEFT backend).
    if lora_id and lora_id.strip() != "":
        pipe.unload_lora_weights()
        pipe.load_lora_weights(lora_id.strip(), adapter_name="user_lora")
        pipe.set_adapters(["user_lora"], adapter_weights=[lora_scale])

    try:
        image = pipe(
            prompt=prompt,
            negative_prompt="",
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            true_cfg_scale=guidance_scale,  # Qwen-Image applies real CFG via true_cfg_scale
            guidance_scale=1.0,  # keep the embedded (distilled) guidance neutral
        ).images[0]
        return image, seed
    finally:
        # Unload LoRA weights if they were loaded
        if lora_id and lora_id.strip() != "":
            pipe.unload_lora_weights()

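# A quick sanity check of infer() outside the UI (a sketch; the values are
# illustrative and the call runs the full pipeline, so it needs GPU memory):
# image, used_seed = infer("a cat holding a sign that says hello world",
#                          randomize_seed=True, num_inference_steps=8)
# image.save("sample.png")
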
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
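# The stylesheet above hooks into the layout below: #col-container matches the
# elem_id of the outer gr.Column, and .generate-btn matches the elem_classes
# set on the generate button.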

with gr.Blocks(css=css) as app:
    gr.HTML("<center><h1>Qwen-Image with LoRA support</h1></center>")
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=3, elem_id="prompt-text-input")
                with gr.Row():
                    custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path (optional)", placeholder="multimodalart/vintage-ads-flux")
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        lora_scale = gr.Slider(
                            label="LoRA Scale",
                            minimum=0,
                            maximum=2,
                            step=0.01,
                            value=0.95,
                        )
                        with gr.Row():
                            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=MAX_IMAGE_SIZE, step=8)
                            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=MAX_IMAGE_SIZE, step=8)
                        seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=MAX_SEED, step=1)
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                        with gr.Row():
                            steps = gr.Slider(label="Inference steps", value=28, minimum=1, maximum=100, step=1)
                            cfg = gr.Slider(label="Guidance Scale", value=3.5, minimum=1, maximum=20, step=0.5)
                with gr.Row():
                    text_button = gr.Button("✨ Generate Image", variant='primary', elem_classes=["generate-btn"])
            with gr.Column():
                with gr.Row():
                    image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
        with gr.Column():
            gr.Examples(
                examples=examples,
                inputs=[text_prompt],
            )
    gr.on(
        triggers=[text_button.click, text_prompt.submit],
        fn=infer,
        inputs=[text_prompt, seed, randomize_seed, width, height, cfg, steps, custom_lora, lora_scale],
        outputs=[image_output, seed]
    )
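
# On Spaces it can help to enable Gradio's request queue before launching, so
# concurrent users wait in line for the GPU (a sketch; queue defaults assumed):
# app.queue(max_size=20)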

app.launch(share=True)