update personalization
Browse files- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +242 -122
- utils.py +31 -1
__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
|
|
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
import random
|
@@ -11,7 +10,7 @@ import open_clip
|
|
11 |
from optim_utils import optimize_prompt
|
12 |
from utils import (
|
13 |
clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
|
14 |
-
get_refine_msg, clean_cache, get_personalize_message,
|
15 |
clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
|
16 |
INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS
|
17 |
)
|
@@ -41,6 +40,7 @@ torch.cuda.empty_cache()
|
|
41 |
METHOD = "Experimental"
|
42 |
counter = 1
|
43 |
enable_submit = False
|
|
|
44 |
responses_memory = {METHOD: {}}
|
45 |
example_data = [
|
46 |
[
|
@@ -100,7 +100,8 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
|
|
100 |
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
|
101 |
seed = random.randint(0, MAX_SEED)
|
102 |
client = init_gpt_api()
|
103 |
-
messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
|
|
|
104 |
outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
|
105 |
return outputs
|
106 |
|
@@ -121,12 +122,12 @@ def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2)
|
|
121 |
}
|
122 |
inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
|
123 |
|
124 |
-
# eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
|
125 |
-
# return learned_prompt
|
126 |
-
|
127 |
# =========================
|
128 |
# UI Helper Functions
|
129 |
# =========================
|
|
|
|
|
|
|
130 |
def reset_gallery():
|
131 |
return []
|
132 |
|
@@ -136,105 +137,106 @@ def display_error_message(msg, duration=5):
|
|
136 |
def display_info_message(msg, duration=5):
|
137 |
gr.Info(msg, duration=duration)
|
138 |
|
139 |
-
def
|
140 |
-
|
141 |
-
fully_satisfied_option = ["Satisfied", "Very Satisfied"]
|
142 |
-
if_submit = (sim_radio in fully_satisfied_option) or enable_submit or (counter > MAX_ROUND)
|
143 |
-
return gr.update(interactive=if_submit)
|
144 |
-
|
145 |
-
def select_image(like_radio, images_method):
|
146 |
-
if like_radio == IMAGE_OPTIONS[0]:
|
147 |
-
return images_method[0][0]
|
148 |
-
elif like_radio == IMAGE_OPTIONS[1]:
|
149 |
-
return images_method[1][0]
|
150 |
-
elif like_radio == IMAGE_OPTIONS[2]:
|
151 |
-
return images_method[2][0]
|
152 |
-
elif like_radio == IMAGE_OPTIONS[3]:
|
153 |
-
return images_method[3][0]
|
154 |
-
else:
|
155 |
-
return None
|
156 |
-
|
157 |
-
def check_evaluation(sim_radio):
|
158 |
-
if not sim_radio:
|
159 |
display_error_message("โ Please fill all evaluations before changing image or submitting.")
|
160 |
return False
|
161 |
return True
|
162 |
|
163 |
def generate_image(prompt, like_image, dislike_image):
|
164 |
-
global responses_memory
|
165 |
history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
|
166 |
feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
gallery_images = []
|
|
|
173 |
refined_prompts = call_gpt_refine_prompt(personalized)
|
174 |
for i in range(NUM_IMAGES):
|
175 |
img = infer(refined_prompts[i])
|
176 |
gallery_images.append(img)
|
|
|
177 |
yield gallery_images
|
178 |
|
179 |
-
def
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
responses_memory[METHOD][counter] = {
|
183 |
"prompt": prompt,
|
184 |
"sim_radio": sim_radio,
|
185 |
"response": "",
|
186 |
-
"satisfied_img": f"round {counter},
|
187 |
-
"unsatisfied_img": f"round {counter},
|
188 |
}
|
189 |
|
190 |
-
enable_submit = True if sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit else False
|
191 |
-
|
192 |
history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
|
|
|
|
|
193 |
if not history_images:
|
194 |
-
history_images = current_images
|
195 |
elif current_images:
|
196 |
history_images.extend(current_images)
|
|
|
197 |
current_images = []
|
198 |
|
199 |
examples_state = gr.update(samples=history_prompts, visible=True)
|
200 |
prompt_state = gr.update(interactive=True)
|
201 |
next_state = gr.update(visible=True, interactive=True)
|
202 |
redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
|
203 |
-
|
204 |
-
|
205 |
counter += 1
|
|
|
206 |
|
207 |
-
|
208 |
-
else:
|
209 |
-
return {submit_btn: gr.skip()}
|
210 |
|
211 |
-
|
212 |
-
global counter, enable_submit, responses_memory
|
213 |
-
|
214 |
-
if check_evaluation(sim_radio):
|
215 |
-
# Save the final round entry
|
216 |
-
responses_memory[METHOD][counter] = {
|
217 |
-
"prompt": prompt,
|
218 |
-
"sim_radio": sim_radio,
|
219 |
-
"response": "",
|
220 |
-
"satisfied_img": f"round {counter}, {like_radio}",
|
221 |
-
"unsatisfied_img": f"round {counter}, {dislike_radio}",
|
222 |
-
}
|
223 |
-
|
224 |
-
# Reset states
|
225 |
-
counter = 1
|
226 |
-
enable_submit = False
|
227 |
-
|
228 |
-
# Reset buttons
|
229 |
-
prompt_state = gr.update(interactive=False)
|
230 |
-
next_state = gr.update(visible=False, interactive=False)
|
231 |
-
submit_state = gr.update(interactive=False)
|
232 |
-
redesign_state = gr.update(interactive=False)
|
233 |
-
|
234 |
-
display_info_message("โ
Your answer is saved!")
|
235 |
-
return None, None, None, prompt_state, next_state, redesign_state, submit_state
|
236 |
else:
|
237 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
|
239 |
# =========================
|
240 |
# Interface (single tab, no participant/scenario/background)
|
@@ -256,6 +258,7 @@ css = """
|
|
256 |
#button-container {
|
257 |
display: flex;
|
258 |
justify-content: center;
|
|
|
259 |
}
|
260 |
#compact-compact-row {
|
261 |
width:100%;
|
@@ -315,9 +318,56 @@ css = """
|
|
315 |
max-width: 150px;
|
316 |
display: inline-block;
|
317 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
"""
|
319 |
|
320 |
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
|
|
|
|
|
|
|
321 |
with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
|
322 |
gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
|
323 |
gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
|
@@ -325,7 +375,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
|
|
325 |
gr.HTML("""
|
326 |
<div style="text-align: center;">
|
327 |
<a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
|
328 |
-
๐ Read the Full Paper
|
329 |
</a>
|
330 |
</div>
|
331 |
""")
|
@@ -337,13 +387,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
|
|
337 |
|
338 |
gr.Markdown("""
|
339 |
<div class="authors-section">
|
340 |
-
<a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>,
|
341 |
-
<a href="https://
|
342 |
-
<a href="https://
|
|
|
|
|
|
|
343 |
</div>
|
344 |
""", elem_classes=["authors-section"])
|
345 |
|
346 |
-
# gr.Markdown("---")
|
347 |
|
348 |
with gr.Tab(""):
|
349 |
with gr.Row(elem_id="compact-row"):
|
@@ -360,47 +412,99 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
|
|
360 |
|
361 |
with gr.Row(elem_id="compact-row"):
|
362 |
with gr.Column(elem_id="col-container"):
|
363 |
-
images_method = gr.Gallery(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
|
365 |
with gr.Column(elem_id="col-container3"):
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
|
398 |
with gr.Column(elem_id="col-container2"):
|
399 |
gr.Markdown("### ๐ Examples")
|
400 |
-
ex1 = gr.Image(label="Image 1", width=200, height=200,
|
401 |
-
ex2 = gr.Image(label="Image 2", width=200, height=200,
|
402 |
-
ex3 = gr.Image(label="Image 3", width=200, height=200,
|
403 |
-
ex4 = gr.Image(label="Image 4", width=200, height=200,
|
404 |
|
405 |
gr.Examples(
|
406 |
examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
|
@@ -410,28 +514,44 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
|
|
410 |
# =========================
|
411 |
# Wiring
|
412 |
# =========================
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
next_btn.click(
|
419 |
fn=generate_image,
|
420 |
inputs=[prompt, like_image, dislike_image],
|
421 |
outputs=[images_method]
|
422 |
-
).success(lambda: [gr.update(interactive=True), gr.update(interactive=True)
|
|
|
423 |
|
424 |
redesign_btn.click(
|
425 |
fn=redesign,
|
426 |
-
inputs=[prompt, sim_radio,
|
427 |
-
outputs=[sim_radio,
|
428 |
)
|
429 |
|
430 |
submit_btn.click(
|
431 |
fn=save_response,
|
432 |
-
inputs=[prompt, sim_radio,
|
433 |
-
outputs=[sim_radio,
|
434 |
)
|
435 |
|
436 |
if __name__ == "__main__":
|
437 |
-
demo.launch()
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import random
|
|
|
10 |
from optim_utils import optimize_prompt
|
11 |
from utils import (
|
12 |
clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
|
13 |
+
get_refine_msg, clean_cache, get_personalize_message, get_personalized_simplified,
|
14 |
clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
|
15 |
INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS
|
16 |
)
|
|
|
40 |
METHOD = "Experimental"
|
41 |
counter = 1
|
42 |
enable_submit = False
|
43 |
+
redesign_flag = False
|
44 |
responses_memory = {METHOD: {}}
|
45 |
example_data = [
|
46 |
[
|
|
|
100 |
def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
|
101 |
seed = random.randint(0, MAX_SEED)
|
102 |
client = init_gpt_api()
|
103 |
+
# messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
|
104 |
+
messages = get_personalized_simplified(prompt, like_image, dislike_image)
|
105 |
outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
|
106 |
return outputs
|
107 |
|
|
|
122 |
}
|
123 |
inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
|
124 |
|
|
|
|
|
|
|
125 |
# =========================
|
126 |
# UI Helper Functions
|
127 |
# =========================
|
128 |
+
# Store generated images for selection
|
129 |
+
current_generated_images = []
|
130 |
+
|
131 |
def reset_gallery():
|
132 |
return []
|
133 |
|
|
|
137 |
def display_info_message(msg, duration=5):
|
138 |
gr.Info(msg, duration=duration)
|
139 |
|
140 |
+
def check_evaluation(sim_radio, like_image, dislike_image):
|
141 |
+
if not sim_radio or not like_image or not dislike_image:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
display_error_message("โ Please fill all evaluations before changing image or submitting.")
|
143 |
return False
|
144 |
return True
|
145 |
|
146 |
def generate_image(prompt, like_image, dislike_image):
|
147 |
+
global responses_memory, current_generated_images
|
148 |
history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
|
149 |
feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
|
150 |
+
print(feedback, like_image, dislike_image)
|
151 |
+
if like_image and dislike_image and feedback:
|
152 |
+
personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
|
153 |
+
else:
|
154 |
+
personalized = prompt
|
155 |
gallery_images = []
|
156 |
+
current_generated_images = [] # Reset the stored images
|
157 |
refined_prompts = call_gpt_refine_prompt(personalized)
|
158 |
for i in range(NUM_IMAGES):
|
159 |
img = infer(refined_prompts[i])
|
160 |
gallery_images.append(img)
|
161 |
+
current_generated_images.append(img) # Store for selection
|
162 |
yield gallery_images
|
163 |
|
164 |
+
def on_gallery_select(evt: gr.SelectData):
|
165 |
+
"""Handle gallery image selection and return the selected image"""
|
166 |
+
global current_generated_images
|
167 |
+
if current_generated_images and evt.index < len(current_generated_images):
|
168 |
+
return current_generated_images[evt.index]
|
169 |
+
return None
|
170 |
+
|
171 |
+
def handle_like_drag(selected_image):
|
172 |
+
"""Handle setting an image as liked"""
|
173 |
+
return selected_image
|
174 |
+
|
175 |
+
def handle_dislike_drag(selected_image):
|
176 |
+
"""Handle setting an image as disliked"""
|
177 |
+
return selected_image
|
178 |
+
|
179 |
+
def redesign(prompt, sim_radio, current_images, history_images, like_image, dislike_image):
|
180 |
+
global counter, responses_memory, redesign_flag
|
181 |
+
|
182 |
+
if check_evaluation(sim_radio, like_image, dislike_image):
|
183 |
responses_memory[METHOD][counter] = {
|
184 |
"prompt": prompt,
|
185 |
"sim_radio": sim_radio,
|
186 |
"response": "",
|
187 |
+
"satisfied_img": f"round {counter}, liked image",
|
188 |
+
"unsatisfied_img": f"round {counter}, disliked image",
|
189 |
}
|
190 |
|
|
|
|
|
191 |
history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
|
192 |
+
|
193 |
+
# Update history images
|
194 |
if not history_images:
|
195 |
+
history_images = current_images.copy() if current_images else []
|
196 |
elif current_images:
|
197 |
history_images.extend(current_images)
|
198 |
+
|
199 |
current_images = []
|
200 |
|
201 |
examples_state = gr.update(samples=history_prompts, visible=True)
|
202 |
prompt_state = gr.update(interactive=True)
|
203 |
next_state = gr.update(visible=True, interactive=True)
|
204 |
redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
|
205 |
+
|
|
|
206 |
counter += 1
|
207 |
+
redesign_flag = True
|
208 |
|
209 |
+
display_info_message(f"โ
Round {counter-1} feedback saved! You can continue redesigning or restart.")
|
|
|
|
|
210 |
|
211 |
+
return None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
else:
|
213 |
+
return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
214 |
+
|
215 |
+
def save_response(prompt, sim_radio, like_image, dislike_image):
|
216 |
+
global counter, responses_memory, redesign_flag, current_generated_images
|
217 |
+
|
218 |
+
# Reset all global variables
|
219 |
+
responses_memory[METHOD] = {}
|
220 |
+
counter = 1
|
221 |
+
redesign_flag = False
|
222 |
+
current_generated_images = []
|
223 |
+
|
224 |
+
# Reset UI states
|
225 |
+
prompt_state = gr.update(value="", interactive=True)
|
226 |
+
next_state = gr.update(visible=True, interactive=True)
|
227 |
+
redesign_state = gr.update(interactive=False)
|
228 |
+
submit_state = gr.update(interactive=False)
|
229 |
+
sim_radio_state = gr.update(value=None)
|
230 |
+
like_image_state = gr.update(value=None)
|
231 |
+
dislike_image_state = gr.update(value=None)
|
232 |
+
gallery_state = []
|
233 |
+
history_gallery_state = []
|
234 |
+
examples_state = gr.update(samples=[['']], visible=True)
|
235 |
+
|
236 |
+
display_info_message("๐ Session restarted! You can begin with a new prompt.")
|
237 |
+
|
238 |
+
return (sim_radio_state, prompt_state, next_state, redesign_state,
|
239 |
+
like_image_state, dislike_image_state, gallery_state, history_gallery_state, examples_state)
|
240 |
|
241 |
# =========================
|
242 |
# Interface (single tab, no participant/scenario/background)
|
|
|
258 |
#button-container {
|
259 |
display: flex;
|
260 |
justify-content: center;
|
261 |
+
gap: 10px;
|
262 |
}
|
263 |
#compact-compact-row {
|
264 |
width:100%;
|
|
|
318 |
max-width: 150px;
|
319 |
display: inline-block;
|
320 |
}
|
321 |
+
.instruction-box {
|
322 |
+
background: linear-gradient(135deg, #e8f4fd 0%, #f0f8ff 100%);
|
323 |
+
border: 2px solid #3498db;
|
324 |
+
border-radius: 12px;
|
325 |
+
padding: 20px;
|
326 |
+
margin: 15px 0;
|
327 |
+
color: #2c3e50;
|
328 |
+
}
|
329 |
+
.instruction-title {
|
330 |
+
font-size: 1.2em;
|
331 |
+
font-weight: bold;
|
332 |
+
margin-bottom: 15px;
|
333 |
+
color: #2c3e50;
|
334 |
+
display: flex;
|
335 |
+
align-items: center;
|
336 |
+
gap: 8px;
|
337 |
+
}
|
338 |
+
.step-list {
|
339 |
+
list-style: none;
|
340 |
+
padding: 0;
|
341 |
+
margin: 0;
|
342 |
+
}
|
343 |
+
.step-item {
|
344 |
+
background: rgba(52, 152, 219, 0.1);
|
345 |
+
border-radius: 8px;
|
346 |
+
padding: 12px 16px;
|
347 |
+
margin: 8px 0;
|
348 |
+
border-left: 4px solid #3498db;
|
349 |
+
}
|
350 |
+
.step-number {
|
351 |
+
font-weight: bold;
|
352 |
+
color: #3498db;
|
353 |
+
margin-right: 8px;
|
354 |
+
}
|
355 |
+
.personalization-header {
|
356 |
+
background: linear-gradient(135deg, #ff6b6b, #ee5a24);
|
357 |
+
color: white;
|
358 |
+
padding: 15px;
|
359 |
+
border-radius: 10px 10px 0 0;
|
360 |
+
margin: -10px -10px 15px -10px;
|
361 |
+
text-align: center;
|
362 |
+
font-weight: bold;
|
363 |
+
font-size: 1.1em;
|
364 |
+
}
|
365 |
"""
|
366 |
|
367 |
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
|
368 |
+
# State variable to hold selected image
|
369 |
+
selected_image = gr.State(None)
|
370 |
+
|
371 |
with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
|
372 |
gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
|
373 |
gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
|
|
|
375 |
gr.HTML("""
|
376 |
<div style="text-align: center;">
|
377 |
<a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
|
378 |
+
๐ Read the Full Paper
|
379 |
</a>
|
380 |
</div>
|
381 |
""")
|
|
|
387 |
|
388 |
gr.Markdown("""
|
389 |
<div class="authors-section">
|
390 |
+
<a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>,
|
391 |
+
<a href="https://www.aliceqian.com/">Alice Qian Zhang</a>,
|
392 |
+
<a href="https://haiyizhu.com/">Haiyi Zhu</a>,
|
393 |
+
<a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>,
|
394 |
+
<a href="https://pliang279.github.io/">Paul Pu Liang</a>,
|
395 |
+
<a href="https://janeon.github.io/">Jane Hsieh</a>
|
396 |
</div>
|
397 |
""", elem_classes=["authors-section"])
|
398 |
|
|
|
399 |
|
400 |
with gr.Tab(""):
|
401 |
with gr.Row(elem_id="compact-row"):
|
|
|
412 |
|
413 |
with gr.Row(elem_id="compact-row"):
|
414 |
with gr.Column(elem_id="col-container"):
|
415 |
+
images_method = gr.Gallery(
|
416 |
+
label="Generated Images (Click to select, then set to Like/Dislike image)",
|
417 |
+
columns=[4],
|
418 |
+
rows=[1],
|
419 |
+
height=400,
|
420 |
+
interactive=False,
|
421 |
+
elem_id="gallery",
|
422 |
+
format="png"
|
423 |
+
)
|
424 |
|
425 |
with gr.Column(elem_id="col-container3"):
|
426 |
+
like_btn = gr.Button("๐ Set as Liked (Optional for personalization)", size="sm", variant="secondary")
|
427 |
+
like_image = gr.Image(
|
428 |
+
label="Satisfied Image",
|
429 |
+
width=150,
|
430 |
+
height=150,
|
431 |
+
interactive=False,
|
432 |
+
format="png",
|
433 |
+
type="filepath"
|
434 |
+
)
|
435 |
+
dislike_btn = gr.Button("๐ Set as Disliked (Optional for personalization)", size="sm", variant="secondary")
|
436 |
+
dislike_image = gr.Image(
|
437 |
+
label="Unsatisfied Image",
|
438 |
+
width=150,
|
439 |
+
height=150,
|
440 |
+
interactive=False,
|
441 |
+
format="png",
|
442 |
+
type="filepath"
|
443 |
+
)
|
444 |
+
|
445 |
+
with gr.Accordion("๐ฏ Advanced: Personalized Image Redesign", open=False, elem_id="col-container2"):
|
446 |
+
gr.HTML("""
|
447 |
+
<div class="instruction-box">
|
448 |
+
<div class="instruction-title">
|
449 |
+
๐ How to Use Personalized Redesign
|
450 |
+
</div>
|
451 |
+
<div class="step-list">
|
452 |
+
<div class="step-item">
|
453 |
+
<span class="step-number">1.</span>
|
454 |
+
<strong>Rate Your Satisfaction:</strong> Provide a satisfaction score for the current generated images
|
455 |
+
</div>
|
456 |
+
<div class="step-item">
|
457 |
+
<span class="step-number">2.</span>
|
458 |
+
<strong>Select Preferences:</strong> Choose your most liked and disliked images
|
459 |
+
</div>
|
460 |
+
<div class="step-item">
|
461 |
+
<span class="step-number">3.</span>
|
462 |
+
<strong>Save & Iterate:</strong> Click "Save Personalized Data" before redesgining your prompt and clicking "Generate"
|
463 |
+
</div>
|
464 |
+
<div class="step-item">
|
465 |
+
<span class="step-number">4.</span>
|
466 |
+
<strong>Restart Anytime:</strong> Use the "Restart" button to begin a fresh session
|
467 |
+
</div>
|
468 |
+
</div>
|
469 |
+
</div>
|
470 |
+
""")
|
471 |
+
|
472 |
+
with gr.Column(elem_id="col-container2"):
|
473 |
+
gr.Markdown("### ๐ Rate Current Generation")
|
474 |
+
with gr.Row():
|
475 |
+
sim_radio = gr.Radio(
|
476 |
+
OPTIONS,
|
477 |
+
label="How satisfied are you with the current generated images?",
|
478 |
+
type="value",
|
479 |
+
show_label=True,
|
480 |
+
container=True,
|
481 |
+
scale=1
|
482 |
+
)
|
483 |
+
|
484 |
+
with gr.Row(elem_id="button-container"):
|
485 |
+
with gr.Column(scale=1):
|
486 |
+
redesign_btn = gr.Button("๐พ Save Personalization Data", variant="primary", size="lg")
|
487 |
+
with gr.Column(scale=1):
|
488 |
+
submit_btn = gr.Button("๐ Restart Session", variant="secondary", size="lg")
|
489 |
+
|
490 |
+
|
491 |
+
with gr.Column(elem_id="col-container2"):
|
492 |
+
example = gr.Examples([['']], prompt, label="๐ Prompt History", visible=True)
|
493 |
+
history_images = gr.Gallery(
|
494 |
+
label="๐๏ธ Generation History",
|
495 |
+
columns=[4],
|
496 |
+
rows=[1],
|
497 |
+
elem_id="gallery",
|
498 |
+
format="png",
|
499 |
+
interactive=False,
|
500 |
+
)
|
501 |
|
502 |
with gr.Column(elem_id="col-container2"):
|
503 |
gr.Markdown("### ๐ Examples")
|
504 |
+
ex1 = gr.Image(label="Image 1", width=200, height=200, format="png", type="filepath", visible=False)
|
505 |
+
ex2 = gr.Image(label="Image 2", width=200, height=200, format="png", type="filepath", visible=False)
|
506 |
+
ex3 = gr.Image(label="Image 3", width=200, height=200, format="png", type="filepath", visible=False)
|
507 |
+
ex4 = gr.Image(label="Image 4", width=200, height=200, format="png", type="filepath", visible=False)
|
508 |
|
509 |
gr.Examples(
|
510 |
examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
|
|
|
514 |
# =========================
|
515 |
# Wiring
|
516 |
# =========================
|
517 |
+
# Handle gallery selection
|
518 |
+
images_method.select(
|
519 |
+
fn=on_gallery_select,
|
520 |
+
inputs=[],
|
521 |
+
outputs=[selected_image]
|
522 |
+
)
|
523 |
+
|
524 |
+
# Handle like/dislike button clicks
|
525 |
+
like_btn.click(
|
526 |
+
fn=handle_like_drag,
|
527 |
+
inputs=[selected_image],
|
528 |
+
outputs=[like_image]
|
529 |
+
)
|
530 |
+
|
531 |
+
dislike_btn.click(
|
532 |
+
fn=handle_dislike_drag,
|
533 |
+
inputs=[selected_image],
|
534 |
+
outputs=[dislike_image]
|
535 |
+
)
|
536 |
|
537 |
next_btn.click(
|
538 |
fn=generate_image,
|
539 |
inputs=[prompt, like_image, dislike_image],
|
540 |
outputs=[images_method]
|
541 |
+
).success(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
|
542 |
+
outputs=[next_btn, prompt, redesign_btn, submit_btn])
|
543 |
|
544 |
redesign_btn.click(
|
545 |
fn=redesign,
|
546 |
+
inputs=[prompt, sim_radio, images_method, history_images, like_image, dislike_image],
|
547 |
+
outputs=[sim_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn]
|
548 |
)
|
549 |
|
550 |
submit_btn.click(
|
551 |
fn=save_response,
|
552 |
+
inputs=[prompt, sim_radio, like_image, dislike_image],
|
553 |
+
outputs=[sim_radio, prompt, next_btn, redesign_btn, like_image, dislike_image, images_method, history_images, example.dataset]
|
554 |
)
|
555 |
|
556 |
if __name__ == "__main__":
|
557 |
+
demo.launch()
|
utils.py
CHANGED
@@ -171,7 +171,37 @@ def get_personalize_message(prompt, history_prompts, history_feedback, like_imag
|
|
171 |
"url": f"data:image/png;base64,{dislike_image_base64}",
|
172 |
},
|
173 |
})
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
return messages
|
176 |
|
177 |
@spaces.GPU
|
|
|
171 |
"url": f"data:image/png;base64,{dislike_image_base64}",
|
172 |
},
|
173 |
})
|
174 |
+
return messages
|
175 |
+
|
176 |
+
def get_personalized_simplified(prompt, like_image, dislike_image):
|
177 |
+
messages = [
|
178 |
+
{"role": "system", "content": f"You are a prompt refinement assistant. Your task is to improve a userโs text prompt based on his liked and disliked images. Your goal is to refine the prompt while maintaining its original meaning, improving clarity, specificity, and alignment with user preferences."}
|
179 |
+
]
|
180 |
+
|
181 |
+
message = f"""The first given image is user's liked image, refine prompt with style, and content user likes. The second given image is user's disliked image, refine prompt to avoid those elements or style of this image."""
|
182 |
+
|
183 |
+
messages.append({
|
184 |
+
"role": "user",
|
185 |
+
"content": [
|
186 |
+
{"type": "text", "text": f"{message}"},
|
187 |
+
],
|
188 |
+
})
|
189 |
+
if like_image:
|
190 |
+
like_image_base64 = encode_image(like_image)
|
191 |
+
messages[-1]["content"].append({
|
192 |
+
"type": "image_url",
|
193 |
+
"image_url": {
|
194 |
+
"url": f"data:image/png;base64,{like_image_base64}",
|
195 |
+
},
|
196 |
+
})
|
197 |
+
if dislike_image:
|
198 |
+
dislike_image_base64 = encode_image(dislike_image)
|
199 |
+
messages[-1]["content"].append({
|
200 |
+
"type": "image_url",
|
201 |
+
"image_url": {
|
202 |
+
"url": f"data:image/png;base64,{dislike_image_base64}",
|
203 |
+
},
|
204 |
+
})
|
205 |
return messages
|
206 |
|
207 |
@spaces.GPU
|