import gradio as gr import numpy as np import random import spaces import torch import re import transformers import open_clip from optim_utils import optimize_prompt from utils import ( clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message, get_personalized_simplified, clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS ) # ========================= # Constants / Defaults # ========================= CLIP_MODEL = "ViT-H-14" PRETRAINED_CLIP = "laion2b_s32b_b79k" default_t2i_model = "black-forest-labs/FLUX.1-dev" default_llm_model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" MAX_SEED = np.iinfo(np.int32).max MAX_IMAGE_SIZE = 1024 NUM_IMAGES = 4 MAX_ROUND = 5 device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 clean_cache() selected_pipe = setup_model(default_t2i_model, torch_dtype, device) clip_model, _, preprocess = open_clip.create_model_and_transforms(CLIP_MODEL, pretrained=PRETRAINED_CLIP, device=device) llm_pipe = None inverted_prompt = "" torch.cuda.empty_cache() METHOD = "Experimental" counter = 1 enable_submit = False redesign_flag = False responses_memory = {METHOD: {}} example_data = [ [ PROMPTS["Tourist promotion"], IMAGES["Tourist promotion"]["ours"] ], [ PROMPTS["Fictional character generation"], IMAGES["Fictional character generation"]["ours"] ], [ PROMPTS["Interior Design"], IMAGES["Interior Design"]["ours"] ], ] # ========================= # Image Generation Helpers # ========================= @spaces.GPU(duration=65) def infer( prompt, negative_prompt="", seed=42, randomize_seed=True, width=256, height=256, guidance_scale=5, num_inference_steps=18, progress=gr.Progress(track_tqdm=True), ): if randomize_seed: seed = random.randint(0, MAX_SEED) generator = torch.Generator().manual_seed(seed) with torch.no_grad(): image = selected_pipe( prompt=prompt, negative_prompt=negative_prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=width, height=height, generator=generator, ).images[0] return image def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9): seed = random.randint(0, MAX_SEED) client = init_gpt_api() messages = get_refine_msg(prompt, num_prompts) outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens, temperature, top_p) prompt_list = clean_response_gpt(outputs) return prompt_list def personalize_prompt(prompt, history, feedback, like_image, dislike_image): seed = random.randint(0, MAX_SEED) client = init_gpt_api() # messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image) messages = get_personalized_simplified(prompt, like_image, dislike_image) outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9) return outputs @spaces.GPU(duration=100) def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2): global inverted_prompt text_params = { "iter": iter, "lr": lr, "batch_size": batch_size, "prompt_len": prompt_len, "weight_decay": 0.1, "prompt_bs": 1, "loss_weight": 1.0, "print_step": 100, "clip_model": CLIP_MODEL, "clip_pretrain": PRETRAINED_CLIP, } inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt) # ========================= # UI Helper Functions # ========================= # Store generated images for selection current_generated_images = [] def reset_gallery(): return [] def display_error_message(msg, duration=5): gr.Warning(msg, duration=duration) def display_info_message(msg, duration=5): gr.Info(msg, duration=duration) def check_evaluation(sim_radio, like_image, dislike_image): if not sim_radio or not like_image or not dislike_image: display_error_message("❌ Please fill all evaluations before changing image or submitting.") return False return True def generate_image(prompt, like_image, dislike_image): global responses_memory, current_generated_images history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()] feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()] print(feedback, like_image, dislike_image) if like_image and dislike_image and feedback: personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image) else: personalized = prompt gallery_images = [] current_generated_images = [] # Reset the stored images refined_prompts = call_gpt_refine_prompt(personalized) for i in range(NUM_IMAGES): img = infer(refined_prompts[i]) gallery_images.append(img) current_generated_images.append(img) # Store for selection yield gallery_images def on_gallery_select(evt: gr.SelectData): """Handle gallery image selection and return the selected image""" global current_generated_images if current_generated_images and evt.index < len(current_generated_images): return current_generated_images[evt.index] return None def handle_like_drag(selected_image): """Handle setting an image as liked""" return selected_image def handle_dislike_drag(selected_image): """Handle setting an image as disliked""" return selected_image def redesign(prompt, sim_radio, current_images, history_images, like_image, dislike_image): global counter, responses_memory, redesign_flag if check_evaluation(sim_radio, like_image, dislike_image): responses_memory[METHOD][counter] = { "prompt": prompt, "sim_radio": sim_radio, "response": "", "satisfied_img": f"round {counter}, liked image", "unsatisfied_img": f"round {counter}, disliked image", } history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()] # Update history images if not history_images: history_images = current_images.copy() if current_images else [] elif current_images: history_images.extend(current_images) current_images = [] examples_state = gr.update(samples=history_prompts, visible=True) prompt_state = gr.update(interactive=True) next_state = gr.update(visible=True, interactive=True) redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True) counter += 1 redesign_flag = True display_info_message(f"✅ Round {counter-1} feedback saved! You can continue redesigning or restart.") return None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state else: return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip() def save_response(prompt, sim_radio, like_image, dislike_image): global counter, responses_memory, redesign_flag, current_generated_images # Reset all global variables responses_memory[METHOD] = {} counter = 1 redesign_flag = False current_generated_images = [] # Reset UI states prompt_state = gr.update(value="", interactive=True) next_state = gr.update(visible=True, interactive=True) redesign_state = gr.update(interactive=False) submit_state = gr.update(interactive=False) sim_radio_state = gr.update(value=None) like_image_state = gr.update(value=None) dislike_image_state = gr.update(value=None) gallery_state = [] history_gallery_state = [] examples_state = gr.update(samples=[['']], visible=True) display_info_message("🔄 Session restarted! You can begin with a new prompt.") return (sim_radio_state, prompt_state, next_state, redesign_state, like_image_state, dislike_image_state, gallery_state, history_gallery_state, examples_state) # ========================= # Interface (single tab, no participant/scenario/background) # ========================= css = """ #col-container { margin: 0 auto; max-width: 700px; } #col-container2 { margin: 0 auto; max-width: 1000px; } #col-container3 { margin: 0 0 auto auto; max-width: 300px; } #button-container { display: flex; justify-content: center; gap: 10px; } #compact-compact-row { width:100%; max-width: 800px; margin: 0px auto; } #compact-row { width:100%; max-width: 1000px; margin: 0px auto; } .header-section { text-align: center; margin-bottom: 2rem; } .abstract-text { text-align: justify; line-height: 1.5; margin: 0rem 0; padding: 0 0.5rem; background-color: rgba(0, 0, 0, 0.05); border-radius: 8px; border-left: 4px solid #3498db; } .paper-link { display: inline-block; margin: 0rem 0; padding: 0rem 0rem; background-color: #3498db; color: white; text-decoration: none; border-radius: 5px; font-weight: 500; } .paper-link:hover { background-color: #2980b9; text-decoration: none; } .authors-section { text-align: center; margin: 0 0; font-style: italic; color: #666; } .authors-title { font-weight: bold; margin-bottom: 0rem; color: #333; } .logo-container { text-align: center; margin: 0.5rem 0 1rem 0; } .logo-container img { height: 60px; width: auto; max-width: 150px; display: inline-block; } .instruction-box { background: linear-gradient(135deg, #e8f4fd 0%, #f0f8ff 100%); border: 2px solid #3498db; border-radius: 12px; padding: 20px; margin: 15px 0; color: #2c3e50; } .instruction-title { font-size: 1.2em; font-weight: bold; margin-bottom: 15px; color: #2c3e50; display: flex; align-items: center; gap: 8px; } .step-list { list-style: none; padding: 0; margin: 0; } .step-item { background: rgba(52, 152, 219, 0.1); border-radius: 8px; padding: 12px 16px; margin: 8px 0; border-left: 4px solid #3498db; } .step-number { font-weight: bold; color: #3498db; margin-right: 8px; } .personalization-header { background: linear-gradient(135deg, #ff6b6b, #ee5a24); color: white; padding: 15px; border-radius: 10px 10px 0 0; margin: -10px -10px 15px -10px; text-align: center; font-weight: bold; font-size: 1.1em; } """ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo: # State variable to hold selected image selected_image = gr.State(None) with gr.Column(elem_id="col-container", elem_classes=["header-section"]): gr.HTML('