# ZenCtrl / app.py
import os
import base64
import io
from typing import TypedDict
import requests
import gradio as gr
from PIL import Image
import examples_db
# Read Baseten configuration from environment variables.
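# Both API_KEY and URL must be set (e.g. as Hugging Face Space secrets) for requests to succeed.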
BTEN_API_KEY = os.getenv("API_KEY")
URL = os.getenv("URL")
def image_to_base64(image: Image.Image) -> str:
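    """Encode a PIL image as a base64-encoded PNG string for the JSON payload."""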
with io.BytesIO() as buffer:
image.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def ensure_image(img) -> Image.Image:
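    """Coerce Gradio input (PIL Image, file path, or upload dict) into a PIL Image."""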
if isinstance(img, Image.Image):
return img
elif isinstance(img, str):
return Image.open(img)
elif isinstance(img, dict) and "name" in img:
return Image.open(img["name"])
else:
raise ValueError("Cannot convert input to a PIL Image.")
def call_baseten_generate(
image: Image.Image,
prompt: str,
steps: int,
strength: float,
height: int,
width: int,
lora_name: str,
remove_bg: bool,
) -> Image.Image | None:
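    """Send a generation request to the Baseten endpoint.

    Returns the decoded output image, or None if the request fails or the
    response contains no "generated_image" field.
    """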
image = ensure_image(image)
b64_image = image_to_base64(image)
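    # Build the request body; the backend expects the background-removal flag under the key "bgrm".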
payload = {
"image": b64_image,
"prompt": prompt,
"steps": steps,
"strength": strength,
"height": height,
"width": width,
"lora_name": lora_name,
"bgrm": remove_bg,
}
headers = {"Authorization": f"Api-Key {BTEN_API_KEY or os.getenv('API_KEY')}"}
try:
if not URL:
raise ValueError("The URL environment variable is not set.")
response = requests.post(URL, headers=headers, json=payload)
if response.status_code == 200:
data = response.json()
gen_b64 = data.get("generated_image", None)
if gen_b64:
return Image.open(io.BytesIO(base64.b64decode(gen_b64)))
else:
return None
else:
print(f"Error: HTTP {response.status_code}\n{response.text}")
return None
except Exception as e:
print(f"Error: {e}")
return None
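# Minimal sketch of calling the endpoint outside the Gradio UI (assumes API_KEY and URL
# are set; "input.png"/"output.png" are illustrative file names, and the parameter values
# mirror the Background Generation defaults below):
#
#   img = Image.open("input.png")
#   result = call_baseten_generate(
#       img,
#       prompt="A vibrant background with dynamic lighting and textures",
#       steps=10,
#       strength=1.2,
#       height=1024,
#       width=1024,
#       lora_name="bg_canny_58000_1024",
#       remove_bg=True,
#   )
#   if result is not None:
#       result.save("output.png")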
# ================== MODE CONFIG =====================
Mode = TypedDict(
"Mode",
{
"model": str,
"prompt": str,
"default_strength": float,
"default_height": int,
"default_width": int,
"models": list[str],
"remove_bg": bool,
},
)
MODE_DEFAULTS: dict[str, Mode] = {
"Background Generation": {
"model": "bg_canny_58000_1024",
"prompt": "A vibrant background with dynamic lighting and textures",
"default_strength": 1.2,
"default_height": 1024,
"default_width": 1024,
"models": ["bgwlight_15000_1024", "bg_canny_58000_1024", "gen_back_7000_1024"],
"remove_bg": True,
},
"Subject Generation": {
"model": "subject_99000_512",
"prompt": "A detailed portrait with soft lighting",
"default_strength": 1.2,
"default_height": 512,
"default_width": 512,
"models": ["zendsd_512_146000", "subject_99000_512", "zen_26000_512"],
"remove_bg": True,
},
"Canny": {
"model": "canny_21000_1024",
"prompt": "A futuristic cityscape with neon lights",
"default_strength": 1.2,
"default_height": 1024,
"default_width": 1024,
"models": ["canny_21000_1024"],
"remove_bg": True,
},
"Depth": {
"model": "depth_9800_1024",
"prompt": "A scene with pronounced depth and perspective",
"default_strength": 1.2,
"default_height": 1024,
"default_width": 1024,
"models": ["depth_9800_1024"],
"remove_bg": True,
},
"Deblurring": {
"model": "deblurr_1024_10000",
"prompt": "A scene with pronounced depth and perspective",
"default_strength": 1.2,
"default_height": 1024,
"default_width": 1024,
"models": ["deblurr_1024_10000"],
"remove_bg": False,
},
}
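# Each entry above defines one UI tab. A hypothetical sketch of adding another mode
# (the checkpoint name below is a placeholder, not a real model):
#
#   MODE_DEFAULTS["Sketch Colorization"] = {
#       "model": "sketch_color_1024",
#       "prompt": "A colorized illustration with soft, natural lighting",
#       "default_strength": 1.2,
#       "default_height": 1024,
#       "default_width": 1024,
#       "models": ["sketch_color_1024"],
#       "remove_bg": False,
#   }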
# ================== PRESET EXAMPLES =====================
# Per-mode preset examples (input image / prompt / output image) live in examples_db.MODE_EXAMPLES.
# ================== UI =====================
header = """
<h1>🌍 ZenCtrl / FLUX</h1>
<div align="center" style="line-height: 1;">
<a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg"></a>
<a href="https://huggingface.co/spaces/fotographerai/ZenCtrl" target="_blank"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg"></a>
<a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord"></a>
</div>
"""
with gr.Blocks(title="🌍 ZenCtrl") as demo:
gr.HTML(header)
gr.Markdown("# ZenCtrl Demo")
with gr.Tabs():
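        # Build one tab per mode; each tab owns its own components and click handler.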
for mode in MODE_DEFAULTS:
with gr.Tab(mode):
defaults = MODE_DEFAULTS[mode]
gr.Markdown(f"### {mode} Mode")
with gr.Row():
with gr.Column(scale=2):
input_image = gr.Image(label="Input Image", type="pil")
generate_button = gr.Button("Generate")
                        with gr.Group():  # container for the model controls
model_dropdown = gr.Dropdown(
label="Model",
choices=defaults["models"],
value=defaults["model"],
interactive=True,
)
remove_bg_checkbox = gr.Checkbox(
label="Remove Background", value=defaults["remove_bg"]
)
with gr.Column(scale=2):
output_image = gr.Image(label="Generated Image", type="pil")
prompt_box = gr.Textbox(
label="Prompt", value=defaults["prompt"], lines=2
)
with gr.Accordion("Generation Parameters", open=False):
with gr.Row():
step_slider = gr.Slider(
minimum=2, maximum=28, value=10, step=2, label="Steps"
)
strength_slider = gr.Slider(
minimum=0.5,
maximum=2.0,
value=defaults["default_strength"],
step=0.1,
label="Strength",
)
with gr.Row():
height_slider = gr.Slider(
minimum=512,
maximum=1360,
value=defaults["default_height"],
step=1,
label="Height",
)
width_slider = gr.Slider(
minimum=512,
maximum=1360,
value=defaults["default_width"],
step=1,
label="Width",
)
def on_generate_click(
model_name, prompt, steps, strength, height, width, remove_bg, image
):
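                    """Reorder this tab's widget values into call_baseten_generate's signature."""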
return call_baseten_generate(
image,
prompt,
steps,
strength,
height,
width,
model_name,
remove_bg,
)
generate_button.click(
fn=on_generate_click,
inputs=[
model_dropdown,
prompt_box,
step_slider,
strength_slider,
height_slider,
width_slider,
remove_bg_checkbox,
input_image,
],
outputs=[output_image],
)
# ---------------- Templates --------------------
if examples_db.MODE_EXAMPLES.get(mode):
gr.Examples(
examples=examples_db.MODE_EXAMPLES.get(mode, []),
inputs=[input_image, prompt_box, output_image],
label="Presets (Input / Prompt / Output)",
examples_per_page=6,
)
if __name__ == "__main__":
demo.launch()