jocoandonob committed on
Commit
6b8c8e9
1 Parent(s): d9fec20

Deploy custom Gradio project

Files changed (1)
  1. app.py +440 -4
app.py CHANGED
@@ -1,7 +1,443 @@
  import gradio as gr
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ import numpy as np
  import gradio as gr
+ from diffusers import (
+     StableDiffusionXLPipeline,
+     AutoPipelineForInpainting,
+     TCDScheduler,
+     ControlNetModel,
+     StableDiffusionXLControlNetPipeline,
+     MotionAdapter,
+     AnimateDiffPipeline
+ )
+ from diffusers.utils import make_image_grid, export_to_gif
+ from PIL import Image
+ import io
+ import requests
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
 
+ # Available models
+ AVAILABLE_MODELS = {
+     "Stable Diffusion XL": "stabilityai/stable-diffusion-xl-base-1.0",
+     "Animagine XL 3.0": "cagliostrolab/animagine-xl-3.0",
+ }
 
+ # Available LoRA styles
+ AVAILABLE_LORAS = {
+     "TCD": "h1t/TCD-SDXL-LoRA",
+     "Papercut": "TheLastBen/Papercut_SDXL",
+ }
+
+ def get_depth_map(image):
+     # Initialize depth estimator
+     depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
+     feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+
+     # Process image
+     image = feature_extractor(images=image, return_tensors="pt").pixel_values
+     with torch.no_grad():
+         depth_map = depth_estimator(image).predicted_depth
+
+     # Resize and normalize depth map
+     depth_map = torch.nn.functional.interpolate(
+         depth_map.unsqueeze(1),
+         size=(1024, 1024),
+         mode="bicubic",
+         align_corners=False,
+     )
+     depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+     depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+     depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+     image = torch.cat([depth_map] * 3, dim=1)
+
+     # Convert to PIL Image
+     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
+     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
+     return image
+
+ def load_image_from_url(url):
+     response = requests.get(url)
+     return Image.open(io.BytesIO(response.content)).convert("RGB")
+
+ def generate_image(prompt, seed, num_steps, guidance_scale, eta):
+     # Initialize the pipeline
+     base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+     tcd_lora_id = "h1t/TCD-SDXL-LoRA"
+
+     # Use CPU for inference
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         base_model_id,
+         torch_dtype=torch.float32  # Use float32 for CPU
+     )
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+     # Load and fuse LoRA weights
+     pipe.load_lora_weights(tcd_lora_id)
+     pipe.fuse_lora()
+
+     # Generate the image
+     generator = torch.Generator().manual_seed(seed)
+     image = pipe(
+         prompt=prompt,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         eta=eta,
+         generator=generator,
+     ).images[0]
+
+     return image
+
+ def generate_community_image(prompt, model_name, seed, num_steps, guidance_scale, eta):
+     # Initialize the pipeline
+     base_model_id = AVAILABLE_MODELS[model_name]
+     tcd_lora_id = "h1t/TCD-SDXL-LoRA"
+
+     # Use CPU for inference
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         base_model_id,
+         torch_dtype=torch.float32  # Use float32 for CPU
+     )
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+     # Load and fuse LoRA weights
+     pipe.load_lora_weights(tcd_lora_id)
+     pipe.fuse_lora()
+
+     # Generate the image
+     generator = torch.Generator().manual_seed(seed)
+     image = pipe(
+         prompt=prompt,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         eta=eta,
+         generator=generator,
+     ).images[0]
+
+     return image
+
+ def generate_style_mix(prompt, seed, num_steps, guidance_scale, eta, style_weight):
+     # Initialize the pipeline
+     base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+     tcd_lora_id = "h1t/TCD-SDXL-LoRA"
+     styled_lora_id = "TheLastBen/Papercut_SDXL"
+
+     # Use CPU for inference
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         base_model_id,
+         torch_dtype=torch.float32  # Use float32 for CPU
+     )
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+     # Load multiple LoRA weights
+     pipe.load_lora_weights(tcd_lora_id, adapter_name="tcd")
+     pipe.load_lora_weights(styled_lora_id, adapter_name="style")
+
+     # Set adapter weights
+     pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, style_weight])
+
+     # Generate the image
+     generator = torch.Generator().manual_seed(seed)
+     image = pipe(
+         prompt=prompt,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         eta=eta,
+         generator=generator,
+     ).images[0]
+
+     return image
+
+ def generate_controlnet(prompt, init_image, seed, num_steps, guidance_scale, eta, controlnet_scale):
+     # Initialize the pipeline
+     base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+     controlnet_id = "diffusers/controlnet-depth-sdxl-1.0"
+     tcd_lora_id = "h1t/TCD-SDXL-LoRA"
+
+     # Initialize ControlNet
+     controlnet = ControlNetModel.from_pretrained(
+         controlnet_id,
+         torch_dtype=torch.float32  # Use float32 for CPU
+     )
+
+     # Initialize pipeline
+     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+         base_model_id,
+         controlnet=controlnet,
+         torch_dtype=torch.float32  # Use float32 for CPU
+     )
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+     # Load and fuse LoRA weights
+     pipe.load_lora_weights(tcd_lora_id)
+     pipe.fuse_lora()
+
+     # Generate depth map
+     depth_image = get_depth_map(init_image)
+
+     # Generate the image
+     generator = torch.Generator().manual_seed(seed)
+     image = pipe(
+         prompt=prompt,
+         image=depth_image,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         eta=eta,
+         controlnet_conditioning_scale=controlnet_scale,
+         generator=generator,
+     ).images[0]
+
+     # Create a grid of the depth map and result
+     grid = make_image_grid([depth_image, image], rows=1, cols=2)
+     return grid
+
+ def inpaint_image(prompt, init_image, mask_image, seed, num_steps, guidance_scale, eta, strength):
+     # Initialize the pipeline
+     base_model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
+     tcd_lora_id = "h1t/TCD-SDXL-LoRA"
+
+     # Use CPU for inference
+     pipe = AutoPipelineForInpainting.from_pretrained(
+         base_model_id,
+         torch_dtype=torch.float32  # Use float32 for CPU
+     )
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+     # Load and fuse LoRA weights
+     pipe.load_lora_weights(tcd_lora_id)
+     pipe.fuse_lora()
+
+     # Generate the image
+     generator = torch.Generator().manual_seed(seed)
+     image = pipe(
+         prompt=prompt,
+         image=init_image,
+         mask_image=mask_image,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         eta=eta,
+         strength=strength,
+         generator=generator,
+     ).images[0]
+
+     # Create a grid of the original image, mask, and result
+     grid = make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+     return grid
+
+ def generate_animation(prompt, seed, num_steps, guidance_scale, eta, num_frames, motion_scale):
+     # Initialize the pipeline
+     base_model_id = "frankjoshua/toonyou_beta6"
+     motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5"
+     tcd_lora_id = "h1t/TCD-SD15-LoRA"
+     motion_lora_id = "guoyww/animatediff-motion-lora-zoom-in"
+
+     # Load motion adapter
+     adapter = MotionAdapter.from_pretrained(motion_adapter_id)
+
+     # Initialize pipeline with CPU optimization
+     pipe = AnimateDiffPipeline.from_pretrained(
+         base_model_id,
+         motion_adapter=adapter,
+         torch_dtype=torch.float32,  # Use float32 for CPU
+         low_cpu_mem_usage=True,  # Enable low CPU memory usage
+         use_safetensors=False  # Use standard PyTorch weights
+     )
+
+     # Set TCD scheduler
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+     # Load LoRA weights
+     pipe.load_lora_weights(tcd_lora_id, adapter_name="tcd")
+     pipe.load_lora_weights(
+         motion_lora_id,
+         adapter_name="motion-lora"
+     )
+
+     # Set adapter weights
+     pipe.set_adapters(["tcd", "motion-lora"], adapter_weights=[1.0, motion_scale])
+
+     # Generate animation
+     generator = torch.Generator().manual_seed(seed)
+     frames = pipe(
+         prompt=prompt,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         cross_attention_kwargs={"scale": 1},
+         num_frames=num_frames,
+         eta=eta,
+         generator=generator
+     ).frames[0]
+
+     # Export to GIF
+     gif_path = "animation.gif"
+     export_to_gif(frames, gif_path)
+     return gif_path
+
+ # Create the Gradio interface
+ with gr.Blocks(title="TCD-SDXL Image Generator") as demo:
+     gr.Markdown("# TCD-SDXL Image Generator")
+     gr.Markdown("Generate images using Trajectory Consistency Distillation with Stable Diffusion XL. Note: This runs on CPU, so generation may take some time.")
+
+     with gr.Tabs():
+         with gr.TabItem("Text to Image"):
+             with gr.Row():
+                 with gr.Column():
+                     text_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.",
+                         lines=3
+                     )
+                     text_seed = gr.Slider(minimum=0, maximum=2147483647, value=0, label="Seed", step=1)
+                     text_steps = gr.Slider(minimum=1, maximum=10, value=4, label="Number of Steps", step=1)
+                     text_guidance = gr.Slider(minimum=0, maximum=1, value=0, label="Guidance Scale")
+                     text_eta = gr.Slider(minimum=0, maximum=1, value=0.3, label="Eta")
+                     text_button = gr.Button("Generate")
+                 with gr.Column():
+                     text_output = gr.Image(label="Generated Image")
+
+             text_button.click(
+                 fn=generate_image,
+                 inputs=[text_prompt, text_seed, text_steps, text_guidance, text_eta],
+                 outputs=text_output
+             )
+
+         with gr.TabItem("Inpainting"):
+             with gr.Row():
+                 with gr.Column():
+                     inpaint_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="a tiger sitting on a park bench",
+                         lines=3
+                     )
+                     init_image = gr.Image(label="Initial Image", type="pil")
+                     mask_image = gr.Image(label="Mask Image", type="pil")
+                     inpaint_seed = gr.Slider(minimum=0, maximum=2147483647, value=0, label="Seed", step=1)
+                     inpaint_steps = gr.Slider(minimum=1, maximum=10, value=8, label="Number of Steps", step=1)
+                     inpaint_guidance = gr.Slider(minimum=0, maximum=1, value=0, label="Guidance Scale")
+                     inpaint_eta = gr.Slider(minimum=0, maximum=1, value=0.3, label="Eta")
+                     inpaint_strength = gr.Slider(minimum=0, maximum=1, value=0.99, label="Strength")
+                     inpaint_button = gr.Button("Inpaint")
+                 with gr.Column():
+                     inpaint_output = gr.Image(label="Result (Original | Mask | Generated)")
+
+             inpaint_button.click(
+                 fn=inpaint_image,
+                 inputs=[
+                     inpaint_prompt, init_image, mask_image, inpaint_seed,
+                     inpaint_steps, inpaint_guidance, inpaint_eta, inpaint_strength
+                 ],
+                 outputs=inpaint_output
+             )
+
+         with gr.TabItem("Community Models"):
+             with gr.Row():
+                 with gr.Column():
+                     community_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="A man, clad in a meticulously tailored military uniform, stands with unwavering resolve. The uniform boasts intricate details, and his eyes gleam with determination. Strands of vibrant, windswept hair peek out from beneath the brim of his cap.",
+                         lines=3
+                     )
+                     model_dropdown = gr.Dropdown(
+                         choices=list(AVAILABLE_MODELS.keys()),
+                         value="Animagine XL 3.0",
+                         label="Select Model"
+                     )
+                     community_seed = gr.Slider(minimum=0, maximum=2147483647, value=0, label="Seed", step=1)
+                     community_steps = gr.Slider(minimum=1, maximum=10, value=8, label="Number of Steps", step=1)
+                     community_guidance = gr.Slider(minimum=0, maximum=1, value=0, label="Guidance Scale")
+                     community_eta = gr.Slider(minimum=0, maximum=1, value=0.3, label="Eta")
+                     community_button = gr.Button("Generate")
+                 with gr.Column():
+                     community_output = gr.Image(label="Generated Image")
+
+             community_button.click(
+                 fn=generate_community_image,
+                 inputs=[
+                     community_prompt, model_dropdown, community_seed,
+                     community_steps, community_guidance, community_eta
+                 ],
+                 outputs=community_output
+             )
+
+         with gr.TabItem("Style Mixing"):
+             with gr.Row():
+                 with gr.Column():
+                     style_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="papercut of a winter mountain, snow",
+                         lines=3
+                     )
+                     style_seed = gr.Slider(minimum=0, maximum=2147483647, value=0, label="Seed", step=1)
+                     style_steps = gr.Slider(minimum=1, maximum=10, value=4, label="Number of Steps", step=1)
+                     style_guidance = gr.Slider(minimum=0, maximum=1, value=0, label="Guidance Scale")
+                     style_eta = gr.Slider(minimum=0, maximum=1, value=0.3, label="Eta")
+                     style_weight = gr.Slider(minimum=0, maximum=2, value=1.0, label="Style Weight", step=0.1)
+                     style_button = gr.Button("Generate")
+                 with gr.Column():
+                     style_output = gr.Image(label="Generated Image")
+
+             style_button.click(
+                 fn=generate_style_mix,
+                 inputs=[
+                     style_prompt, style_seed, style_steps,
+                     style_guidance, style_eta, style_weight
+                 ],
+                 outputs=style_output
+             )
+
+         with gr.TabItem("ControlNet"):
+             with gr.Row():
+                 with gr.Column():
+                     control_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="stormtrooper lecture, photorealistic",
+                         lines=3
+                     )
+                     control_image = gr.Image(label="Input Image", type="pil")
+                     control_seed = gr.Slider(minimum=0, maximum=2147483647, value=0, label="Seed", step=1)
+                     control_steps = gr.Slider(minimum=1, maximum=10, value=4, label="Number of Steps", step=1)
+                     control_guidance = gr.Slider(minimum=0, maximum=1, value=0, label="Guidance Scale")
+                     control_eta = gr.Slider(minimum=0, maximum=1, value=0.3, label="Eta")
+                     control_scale = gr.Slider(minimum=0, maximum=1, value=0.5, label="ControlNet Scale", step=0.1)
+                     control_button = gr.Button("Generate")
+                 with gr.Column():
+                     control_output = gr.Image(label="Result (Depth Map | Generated)")
+
+             control_button.click(
+                 fn=generate_controlnet,
+                 inputs=[
+                     control_prompt, control_image, control_seed,
+                     control_steps, control_guidance, control_eta, control_scale
+                 ],
+                 outputs=control_output
+             )
+
+         with gr.TabItem("Animation"):
+             with gr.Row():
+                 with gr.Column():
+                     anim_prompt = gr.Textbox(
+                         label="Prompt",
+                         value="best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress",
+                         lines=3
+                     )
+                     anim_seed = gr.Slider(minimum=0, maximum=2147483647, value=0, label="Seed", step=1)
+                     anim_steps = gr.Slider(minimum=1, maximum=10, value=5, label="Number of Steps", step=1)
+                     anim_guidance = gr.Slider(minimum=0, maximum=1, value=0, label="Guidance Scale")
+                     anim_eta = gr.Slider(minimum=0, maximum=1, value=0.3, label="Eta")
+                     anim_frames = gr.Slider(minimum=8, maximum=32, value=24, label="Number of Frames", step=1)
+                     anim_motion_scale = gr.Slider(minimum=0, maximum=2, value=1.2, label="Motion Scale", step=0.1)
+                     anim_button = gr.Button("Generate Animation")
+                 with gr.Column():
+                     anim_output = gr.Image(label="Generated Animation", format="gif")
+
+             anim_button.click(
+                 fn=generate_animation,
+                 inputs=[
+                     anim_prompt, anim_seed, anim_steps,
+                     anim_guidance, anim_eta, anim_frames,
+                     anim_motion_scale
+                 ],
+                 outputs=anim_output
+             )
+
+ if __name__ == "__main__":
+     demo.launch()