# Qwen-Image-Dev / app.py
import json
import os
import random
from http import HTTPStatus
from pathlib import PurePosixPath
from urllib.parse import urlparse, unquote

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline
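
# Load the Qwen-Image diffusion pipeline in bfloat16 and move it to the GPU.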
model_name = "Qwen/Qwen-Image"
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
pipe.to('cuda')
MAX_SEED = np.iinfo(np.int32).max
#MAX_IMAGE_SIZE = 1440
with open("examples.json") as f:
    examples = json.load(f)
# (1664, 928), (1472, 1140), (1328, 1328)
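# Map the UI aspect-ratio choice to an output resolution (max dimension 1920, matching the UI label).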
def get_image_size(aspect_ratio):
if aspect_ratio == "1:1":
return 1920, 1920
elif aspect_ratio == "16:9":
return 1920, 1080
elif aspect_ratio == "9:16":
return 1080, 1920
elif aspect_ratio == "4:3":
return 1920, 1440
elif aspect_ratio == "3:4":
return 1440, 1920
else:
return 640, 640
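
# Optional prompt-rewriting helper: relies on an external `api` helper (not defined in
# this file) to call a text model such as qwen-plus. It is not wired into the UI below.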
def polish_prompt_en(original_prompt):
SYSTEM_PROMPT = open("improve_prompt.txt").read()
original_prompt = original_prompt.strip()
prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {original_prompt}\n\n Rewritten Prompt:"
    success = False
    # Retry until the rewrite succeeds (note: this loops indefinitely if the API keeps failing).
    while not success:
try:
polished_prompt = api(prompt, model='qwen-plus')
polished_prompt = polished_prompt.strip()
polished_prompt = polished_prompt.replace("\n", " ")
success = True
except Exception as e:
print(f"Error during API call: {e}")
return polished_prompt
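
# ZeroGPU entry point: `@spaces.GPU` allocates a GPU for each call (45-second duration budget).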
@spaces.GPU(duration=45)
def infer(
prompt,
negative_prompt=" ",
seed=42,
randomize_seed=False,
aspect_ratio="16:9",
guidance_scale=4,
num_inference_steps=50,
progress=gr.Progress(track_tqdm=True),
):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
width, height = get_image_size(aspect_ratio)
print("Generating for prompt:", prompt)
    # Pass the UI-selected settings through to the pipeline (previously the steps,
    # CFG scale, and seed were hardcoded, so the sliders had no effect).
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        true_cfg_scale=guidance_scale,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]
#image.save("example.png")
return image, seed
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
# gr.Markdown('<div style="text-align: center;"><a href="https://huggingface.co/Qwen/Qwen-Image"><img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png" width="400"/></a></div>')
        gr.Markdown('<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png" alt="Qwen-Image logo" width="400" style="display: block; margin: 0 auto;">')
gr.Markdown("[Learn more](https://github.com/QwenLM/Qwen-Image) about the Qwen-Image series. Try on [Qwen Chat](https://chat.qwen.ai/), or [download model](https://huggingface.co/Qwen/Qwen-Image) to run locally with ComfyUI or diffusers.")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0, variant="primary")
result = gr.Image(label="Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
visible=True,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
aspect_ratio = gr.Radio(
label="Image size (ratio, max dim 1920)",
choices=["1:1", "16:9", "9:16", "4:3", "3:4"],
value="16:9",
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=7.5,
step=0.1,
value=4.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=35,
)
gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=False, cache_mode="lazy")
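
    # Run inference when the button is clicked or the prompt is submitted.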
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[
prompt,
negative_prompt,
seed,
randomize_seed,
aspect_ratio,
guidance_scale,
num_inference_steps,
],
outputs=[result, seed],
)
if __name__ == "__main__":
demo.launch()