Daraphan commited on
Commit
3b96ed7
·
verified ·
1 Parent(s): 829020a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import spaces
3
+ import torch
4
+ import gradio as gr
5
+
6
+ from modeling.dmm_pipeline import StableDiffusionDMMPipeline
7
+ from huggingface_hub import snapshot_download
8
+
9
+
10
+ ckpt_path = "ckpt"
11
+ snapshot_download(repo_id="MCG-NJU/DMM", local_dir=ckpt_path)
12
+
13
+ pipe = StableDiffusionDMMPipeline.from_pretrained(
14
+ ckpt_path,
15
+ torch_dtype=torch.float16,
16
+ use_safetensors=True
17
+ )
18
+ pipe.to("cuda")
19
+
20
+
21
+ @spaces.GPU
22
+ def generate(prompt: str,
23
+ negative_prompt: str,
24
+ model_id: int,
25
+ seed: int = 1234,
26
+ height: int = 512,
27
+ width: int = 512,
28
+ all: bool = True):
29
+ if all:
30
+ outputs = []
31
+ for i in range(pipe.unet.get_num_models()):
32
+ output = pipe(
33
+ prompt=prompt,
34
+ negative_prompt=negative_prompt,
35
+ width=width,
36
+ height=height,
37
+ num_inference_steps=25,
38
+ guidance_scale=7,
39
+ model_id=i,
40
+ generator=torch.Generator().manual_seed(seed),
41
+ ).images[0]
42
+ outputs.append(output)
43
+ return outputs
44
+ else:
45
+ output = pipe(
46
+ prompt=prompt,
47
+ negative_prompt=negative_prompt,
48
+ width=width,
49
+ height=height,
50
+ num_inference_steps=25,
51
+ guidance_scale=7,
52
+ model_id=int(model_id),
53
+ generator=torch.Generator().manual_seed(seed),
54
+ ).images[0]
55
+ return [output,]
56
+
57
+
58
+ candidates = [
59
+ "0. [JuggernautReborn] realistic",
60
+ "1. [MajicmixRealisticV7] realistic, Asia portrait",
61
+ "2. [EpicRealismV5] realistic",
62
+ "3. [RealisticVisionV5] realistic",
63
+ "4. [MajicmixFantasyV3] animation",
64
+ "5. [MinimalismV2] illustration",
65
+ "6. [RealCartoon3dV17] cartoon 3d",
66
+ "7. [AWPaintingV1.4] animation",
67
+ ]
68
+
69
+ def main():
70
+ with gr.Blocks() as demo:
71
+ gr.Markdown(
72
+ """
73
+ # DMM Demo
74
+ The checkpoint is https://huggingface.co/MCG-NJU/DMM.
75
+ """
76
+ )
77
+ with gr.Row():
78
+ with gr.Column():
79
+ with gr.Column():
80
+ model_id = gr.Dropdown(candidates, label="Model Index", type="index")
81
+ all_check = gr.Checkbox(label="All (ignore the selection above)")
82
+ prompt = gr.Textbox("portrait photo of a girl, long golden hair, flowers, best quality", label="Prompt")
83
+ negative_prompt = gr.Textbox("worst quality,low quality,normal quality,lowres,watermark,nsfw", label="Negative Prompt")
84
+ with gr.Row():
85
+ seed = gr.Number(0, label="Seed", precision=0, scale=3)
86
+ update_seed_btn = gr.Button("🎲", scale=1)
87
+ with gr.Row():
88
+ height = gr.Number(768, step=8, label="Height (suggest 512~768)")
89
+ width = gr.Number(512, step=8, label="Width")
90
+ submit_btn = gr.Button("Submit", variant="primary")
91
+ output = gr.Gallery(label="images")
92
+
93
+ submit_btn.click(generate,
94
+ inputs=[prompt, negative_prompt, model_id, seed, height, width, all_check],
95
+ outputs=[output])
96
+ update_seed_btn.click(lambda: random.randint(0, 1000000),
97
+ outputs=[seed])
98
+
99
+ demo.launch()
100
+
101
+
102
+ if __name__ == "__main__":
103
+ main()