# BRIA-2.3-T5 / app.py
# Last commit: e12754a "Update app.py" by Eyalgut (verified)
import gradio as gr
import os
hf_token = os.environ.get("HF_TOKEN")
import spaces
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download
import torch
import os, sys
import time
class Dummy:
    """Bare attribute container with no behavior of its own.

    Only referenced by the commented-out torch.compile code in this file,
    where it is used to fake ``add_embedding.linear_1.in_features`` on a
    compiled UNet.
    """
# Download (or reuse the cached copy of) the briaai/BRIA-2.3-T5 repo snapshot.
# The snapshot provides the LoRA and extra weight files loaded further down,
# plus the `ella_xl_pipeline` module imported here.
pipeline_path = snapshot_download(repo_id='briaai/BRIA-2.3-T5')
# Make the snapshot directory importable so its pipeline module can be loaded.
sys.path.append(pipeline_path)
# Assumes ella_xl_pipeline.py sits at the snapshot root — must run after the
# sys.path change above.
from ella_xl_pipeline import EllaXLPipeline
# Output sizes offered in the "Resolution" dropdown, as "<width> <height>"
# strings (infer() splits them back into two ints, width first).
resolutions = [
    f"{width} {height}"
    for width, height in ((1024, 1024), (1280, 768), (1344, 768), (768, 1344), (768, 1280))
]
# Default negative prompt pre-filled in the "Negative Prompt" textbox:
# concepts to steer away from, joined with bare commas (no spaces).
default_negative_prompt = ",".join((
    "Logo",
    "Watermark",
    "Text",
    "Ugly",
    "Morbid",
    "Extra fingers",
    "Poorly drawn hands",
    "Mutation",
    "Blurry",
    "Extra limbs",
    "Gross proportions",
    "Missing arms",
    "Mutated hands",
    "Long neck",
    "Duplicate",
    "Mutilated",
    "Mutilated hands",
    "Poorly drawn face",
    "Deformed",
    "Bad anatomy",
    "Cloned face",
    "Malformed limbs",
    "Missing legs",
    "Too many fingers",
))
# Load the base BRIA-2.3 pipeline with fp16 weights read from safetensors files.
pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.3", torch_dtype=torch.float16, use_safetensors=True)
# Merge the LoRA shipped in the BRIA-2.3-T5 snapshot into the base weights,
# then unload the separate LoRA layers (the documented diffusers pattern for
# keeping only the fused weights).
pipe.load_lora_weights(f'{pipeline_path}/pytorch_lora_weights.safetensors')
pipe.fuse_lora()
pipe.unload_lora_weights()
# NOTE(review): presumably makes an empty (negative) prompt go through the
# text encoder instead of being replaced by zero embeddings — confirm against
# the diffusers SDXL pipeline docs.
pipe.force_zeros_for_empty_prompt = False
pipe.to("cuda")
# Wrap the CUDA-resident diffusers pipeline in the snapshot's EllaXLPipeline,
# passing it the extra weights file from the snapshot. From here on `pipe` is
# the wrapper object that infer() calls, not the diffusers pipeline itself.
pipe = EllaXLPipeline(pipe,f'{pipeline_path}/pytorch_model.bin')
# Disabled code kept for reference: a helper that moves sub-modules to CUDA,
# and a torch.compile warm-up of the UNet (compilation takes ~600 s).
# def tocuda():
#     pipe.pipe.vae.to('cuda')
#     pipe.t5_encoder.to('cuda')
#     pipe.pipe.unet.unet.to('cuda')
#     pipe.pipe.unet.ella.to('cuda')
# print("Optimizing BRIA-2.3-T5 - this could take a while")
# t=time.time()
# pipe.unet = torch.compile(
#     pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
# )
# with torch.no_grad():
#     outputs = pipe(
#         prompt="an apple",
#         num_inference_steps=30,
#     )
# # This will avoid future compilations on different shapes
# unet_compiled = torch._dynamo.run(pipe.unet)
# unet_compiled.config=pipe.unet.config
# unet_compiled.add_embedding = Dummy()
# unet_compiled.add_embedding.linear_1 = Dummy()
# unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
# pipe.unet = unet_compiled
# print(f"Optimizing finished successfully after {time.time()-t} secs")
def _make_generator(seed):
    """Build the RNG for one generation from the raw Seed textbox value.

    Args:
        seed: The seed as entered in the UI, normally a string such as
            ``"-1"`` or ``"42"``. An int is accepted too.

    Returns:
        ``None`` when the seed is ``-1`` or cannot be used. The pipeline then
        runs unseeded. Otherwise, a CUDA ``torch.Generator`` seeded with
        that integer.
    """
    try:
        value = int(seed)  # int() already tolerates surrounding whitespace
    except (TypeError, ValueError):
        # Best-effort: an unparsable seed means "random" rather than an error.
        print(f"Invalid seed {seed!r}; using a random seed")
        return None
    # -1 is the UI's "random seed" sentinel. Comparing the parsed int also
    # catches " -1" or an int -1. Those would otherwise reach manual_seed(-1),
    # which torch remaps to a fixed seed.
    if value == -1:
        return None
    try:
        return torch.Generator("cuda").manual_seed(value)
    except (RuntimeError, OverflowError) as err:
        # e.g. a seed outside the range torch accepts; keep the old fallback.
        print(f"Could not use seed {value}: {err}; using a random seed")
        return None


def _parse_steps(steps):
    """Convert the Steps textbox value to a positive int.

    Raises:
        gr.Error: if the value is not an integer or is smaller than 1. Gradio
            shows the message to the user.
    """
    try:
        value = int(steps)
    except (TypeError, ValueError):
        raise gr.Error('Steps must be an integer') from None
    if value < 1:
        raise gr.Error('Steps must be a positive integer')
    return value


@spaces.GPU(enable_queue=True)
def infer(prompt,negative_prompt,seed,resolution, steps):
    """Generate one image with the BRIA-2.3-T5 pipeline.

    Args:
        prompt: Text prompt.
        negative_prompt: Concepts to steer away from, comma-separated.
        seed: Seed textbox value. ``-1``, or anything unusable, means random.
        resolution: ``"<width> <height>"`` string from the Resolution dropdown.
        steps: Number of inference steps, as entered in the Steps textbox.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: if ``steps`` is not a positive integer. ``gr.Error`` is an
            ``Exception`` subclass, so existing ``except Exception`` callers
            still work.
    """
    print(f"\n—\n{prompt}\n")
    start = time.time()
    generator = _make_generator(seed)
    steps = _parse_steps(steps)
    width, height = map(int, resolution.split())
    image = pipe(
        prompt,
        num_inference_steps=steps,
        negative_prompt=negative_prompt,
        generator=generator,
        width=width,
        height=height,
    ).images[0]
    print(f'gen time is {time.time()-start} secs')
    # TODO: add an NSFW check, e.g.
    # if nsfw:
    #     raise gr.Error("Generated image is NSFW")
    return image
# Page stylesheet: centre the main column (elem_id "col-container") and cap
# its width at 580px.
css = (
    "\n"
    "#col-container{\n"
    "margin: 0 auto;\n"
    "max-width: 580px;\n"
    "}\n"
)
# Build the Gradio UI. A single centred column holds the title, an intro
# blurb, the input controls, and the output image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA 2.3 T5")
        gr.HTML('''
<p style="margin-bottom: 10px; font-size: 94%">
This is a demo for
<a href="https://huggingface.co/briaai/BRIA-2.3-T5" target="_blank">BRIA 2.3 T5 text-to-image </a>.
</p>
''')
        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="A smiling man with wavy brown hair and a trimmed beard")
                resolution = gr.Dropdown(value=resolutions[0], show_label=True, label="Resolution", choices=resolutions)
                # Textbox values are strings. Pass string defaults rather than
                # ints; "-1" is the "random seed" sentinel understood by infer().
                seed = gr.Textbox(label="Seed", value="-1")
                steps = gr.Textbox(label="Steps", value="30")
                negative_prompt = gr.Textbox(label="Negative Prompt", value=default_negative_prompt)
                submit_btn = gr.Button("Generate")
        result = gr.Image(label="BRIA-2.3-T5 Result")
        # gr.Examples(
        #     examples = [
        #         "Dragon, digital art, by Greg Rutkowski",
        #         "Armored knight holding sword",
        #         "A flat roof villa near a river with black walls and huge windows",
        #         "A calm and peaceful office",
        #         "Pirate guinea pig"
        #     ],
        #     fn = infer,
        #     inputs = [
        #         prompt_in
        #     ],
        #     outputs = [
        #         result
        #     ]
        # )
    # The input order must match infer(prompt, negative_prompt, seed, resolution, steps).
    submit_btn.click(
        fn = infer,
        inputs = [
            prompt_in,
            negative_prompt,
            seed,
            resolution,
            steps,
        ],
        outputs = [
            result
        ]
    )
demo.queue().launch(show_api=False)