File size: 4,841 Bytes
bccf6a9
 
 
 
 
 
 
 
 
 
 
 
 
 
028fa1b
bccf6a9
 
 
 
 
 
 
 
 
bb10c2b
fa999ed
 
 
 
 
f2b506b
 
 
5bf77c5
cd3c12d
f2b506b
 
dc35a57
 
 
 
 
bb10c2b
bccf6a9
028fa1b
bccf6a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e12754a
7457dcf
d34d00a
f2b506b
 
bb10c2b
bccf6a9
 
 
 
 
 
bb10c2b
99c26e2
bccf6a9
 
 
 
 
 
 
 
 
ffd0df2
 
 
 
 
bccf6a9
 
398ef2b
bccf6a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
028fa1b
bccf6a9
 
 
028fa1b
bccf6a9
 
 
 
 
 
 
e12754a
bccf6a9
 
028fa1b
bccf6a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398ef2b
d34d00a
bccf6a9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
import os
# NOTE(review): HF_TOKEN is read here but not referenced anywhere else in the
# visible code (e.g. not passed to snapshot_download/from_pretrained) —
# confirm whether it is still needed.
hf_token = os.environ.get("HF_TOKEN")
import spaces
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download
import torch
# `os` is already imported above; the re-import is harmless.
import os, sys
import time

class Dummy:
    """Bare attribute container with no behavior of its own.

    Only referenced by the commented-out ``torch.compile`` experiment in this
    file, where it fakes ``unet.add_embedding.linear_1.in_features`` on the
    compiled UNet.
    """


# Fetch the BRIA-2.3-T5 repo snapshot: it supplies the `ella_xl_pipeline`
# module imported below, plus the LoRA and ELLA weight files loaded further down.
pipeline_path = snapshot_download(repo_id='briaai/BRIA-2.3-T5')    
# Put the snapshot directory on the import path so `ella_xl_pipeline` resolves.
sys.path.append(pipeline_path)
from ella_xl_pipeline import EllaXLPipeline

# Selectable output sizes as "<width> <height>" strings; `infer` splits each
# one on whitespace and converts both parts with int().
resolutions = ["1024 1024","1280 768","1344 768","768 1344","768 1280"] 

# Default negative prompt, pre-filled in the UI's "Negative Prompt" textbox.
default_negative_prompt= "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"

# Load pipeline: BRIA-2.3 base model in float16, with the LoRA weights from
# the BRIA-2.3-T5 snapshot fused into it.

pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.3", torch_dtype=torch.float16, use_safetensors=True)
pipe.load_lora_weights(f'{pipeline_path}/pytorch_lora_weights.safetensors')
# Fuse-then-unload: merge the LoRA into the base weights, then remove the
# adapter bookkeeping (diffusers' documented pattern after fuse_lora()).
pipe.fuse_lora()
pipe.unload_lora_weights()
# NOTE(review): presumably makes empty prompts get encoded rather than
# replaced by zero embeddings — confirm against the diffusers pipeline in use.
pipe.force_zeros_for_empty_prompt = False

pipe.to("cuda")

# Wrap the CUDA pipeline with the ELLA weights from the snapshot; from here on
# `pipe` is the EllaXLPipeline instance that `infer` calls.
pipe = EllaXLPipeline(pipe,f'{pipeline_path}/pytorch_model.bin')



# def tocuda():
#     pipe.pipe.vae.to('cuda')
#     pipe.t5_encoder.to('cuda')
#     pipe.pipe.unet.unet.to('cuda')
#     pipe.pipe.unet.ella.to('cuda')


# print("Optimizing BRIA-2.3-T5 - this could take a while")
# t=time.time()
# pipe.unet = torch.compile(
#     pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
# )
# with torch.no_grad():
#     outputs = pipe(
#         prompt="an apple",
#         num_inference_steps=30,
#     )

#     # This will avoid future compilations on different shapes
#     unet_compiled = torch._dynamo.run(pipe.unet)
#     unet_compiled.config=pipe.unet.config
#     unet_compiled.add_embedding = Dummy()
#     unet_compiled.add_embedding.linear_1 = Dummy()
#     unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
#     pipe.unet = unet_compiled

# print(f"Optimizing finished successfully after {time.time()-t} secs")

@spaces.GPU(enable_queue=True)
def infer(prompt, negative_prompt, seed, resolution, steps):
    """Generate one image with the module-level BRIA-2.3-T5 ``pipe``.

    Args:
        prompt: Text prompt to render.
        negative_prompt: Text passed as the pipeline's ``negative_prompt``.
        seed: Seed as entered in the UI. ``"-1"`` (surrounding whitespace
            ignored; an int ``-1`` is accepted too) means "random". A value
            that is not an integer, or that torch rejects, also falls back to
            random. Any other integer seeds a CUDA ``torch.Generator``.
        resolution: ``"<width> <height>"`` string, e.g. ``"1024 1024"``.
        steps: Number of inference steps; must parse as a positive integer.

    Returns:
        The first image of the pipeline's ``.images`` output.

    Raises:
        gr.Error: If ``steps`` is not a positive integer. ``gr.Error`` is an
            ``Exception`` subclass, so broad handlers still catch it, and
            Gradio shows its message to the user.
        ValueError: If ``resolution`` is not exactly two integers.
    """
    print(f"—\n{prompt}")
    start = time.perf_counter()

    # Normalise before comparing so "-1", " -1 " and -1 all mean "random".
    seed_text = str(seed).strip()
    generator = None
    if seed_text != "-1":
        try:
            generator = torch.Generator("cuda").manual_seed(int(seed_text))
        except (ValueError, RuntimeError):
            # Best effort, as before: a non-integer (ValueError) or
            # out-of-range (RuntimeError from torch) seed falls back to an
            # unseeded, random generation.
            generator = None

    try:
        steps = int(steps)
    except (TypeError, ValueError) as err:
        raise gr.Error("Steps must be an integer") from err
    if steps < 1:
        raise gr.Error("Steps must be a positive integer")

    width, height = (int(part) for part in resolution.split())
    image = pipe(
        prompt,
        num_inference_steps=steps,
        negative_prompt=negative_prompt,
        generator=generator,
        width=width,
        height=height,
    ).images[0]
    print(f"gen time is {time.perf_counter() - start} secs")

    # TODO: reject unsafe output once an NSFW check is wired in, e.g.
    #     raise gr.Error("Generated image is NSFW")
    return image

# Centre the single-column layout and cap its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 580px;
}
"""
# UI: text inputs for prompt / seed / steps / negative prompt, a resolution
# dropdown, a Generate button wired to `infer`, and the resulting image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA 2.3 T5")
        gr.HTML('''
          <p style="margin-bottom: 10px; font-size: 94%">
            This is a demo for 
            <a href="https://huggingface.co/briaai/BRIA-2.3-T5" target="_blank">BRIA 2.3 T5 text-to-image </a>. 
          </p>
        ''')
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="A smiling man with wavy brown hair and a trimmed beard")
                resolution = gr.Dropdown(value=resolutions[0], show_label=True, label="Resolution", choices=resolutions)
                # Textbox values are strings; give string defaults so the UI
                # and `infer` (which compares seed against "-1") agree on type.
                seed = gr.Textbox(label="Seed", value="-1")
                steps = gr.Textbox(label="Steps", value="30")
                negative_prompt = gr.Textbox(label="Negative Prompt", value=default_negative_prompt)
                submit_btn = gr.Button("Generate")
        result = gr.Image(label="BRIA-2.3-T5 Result")

        # Examples are disabled. NOTE(review): if re-enabled, `infer` takes five
        # arguments (prompt, negative_prompt, seed, resolution, steps), so
        # `inputs` would need all five components (and matching example rows)
        # for `fn=infer` to be callable.
        # gr.Examples(
        #     examples = [ 
        #         "Dragon, digital art, by Greg Rutkowski",
        #         "Armored knight holding sword",
        #         "A flat roof villa near a river with black walls and huge windows",
        #         "A calm and peaceful office",
        #         "Pirate guinea pig"
        #     ],
        #     fn = infer, 
        #     inputs = [
        #         prompt_in
        #     ],
        #     outputs = [
        #         result
        #     ]
        # )

    # Input order must match infer(prompt, negative_prompt, seed, resolution, steps).
    submit_btn.click(
        fn = infer,
        inputs = [
            prompt_in,
            negative_prompt,
            seed,
            resolution,
            steps,
        ],
        outputs = [
            result
        ]
    )

demo.queue().launch(show_api=False)