Spaces:
PAI-GEN
/
Running on Zero

xh365 committed on
Commit
489c70c
·
1 Parent(s): acce46b

update personalization

Browse files
Files changed (3) hide show
  1. __pycache__/utils.cpython-310.pyc +0 -0
  2. app.py +242 -122
  3. utils.py +31 -1
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import numpy as np
4
  import random
@@ -11,7 +10,7 @@ import open_clip
11
  from optim_utils import optimize_prompt
12
  from utils import (
13
  clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
14
- get_refine_msg, clean_cache, get_personalize_message,
15
  clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
16
  INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS
17
  )
@@ -41,6 +40,7 @@ torch.cuda.empty_cache()
41
  METHOD = "Experimental"
42
  counter = 1
43
  enable_submit = False
 
44
  responses_memory = {METHOD: {}}
45
  example_data = [
46
  [
@@ -100,7 +100,8 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
100
  def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
101
  seed = random.randint(0, MAX_SEED)
102
  client = init_gpt_api()
103
- messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
 
104
  outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
105
  return outputs
106
 
@@ -121,12 +122,12 @@ def invert_prompt(prompt, images, prompt_len=15, iter=500, lr=0.1, batch_size=2)
121
  }
122
  inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
123
 
124
- # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
125
- # return learned_prompt
126
-
127
  # =========================
128
  # UI Helper Functions
129
  # =========================
 
 
 
130
  def reset_gallery():
131
  return []
132
 
@@ -136,105 +137,106 @@ def display_error_message(msg, duration=5):
136
  def display_info_message(msg, duration=5):
137
  gr.Info(msg, duration=duration)
138
 
139
- def check_satisfaction(sim_radio):
140
- global enable_submit, counter
141
- fully_satisfied_option = ["Satisfied", "Very Satisfied"]
142
- if_submit = (sim_radio in fully_satisfied_option) or enable_submit or (counter > MAX_ROUND)
143
- return gr.update(interactive=if_submit)
144
-
145
- def select_image(like_radio, images_method):
146
- if like_radio == IMAGE_OPTIONS[0]:
147
- return images_method[0][0]
148
- elif like_radio == IMAGE_OPTIONS[1]:
149
- return images_method[1][0]
150
- elif like_radio == IMAGE_OPTIONS[2]:
151
- return images_method[2][0]
152
- elif like_radio == IMAGE_OPTIONS[3]:
153
- return images_method[3][0]
154
- else:
155
- return None
156
-
157
- def check_evaluation(sim_radio):
158
- if not sim_radio:
159
  display_error_message("❌ Please fill all evaluations before changing image or submitting.")
160
  return False
161
  return True
162
 
163
  def generate_image(prompt, like_image, dislike_image):
164
- global responses_memory
165
  history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
166
  feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
167
- personalized = prompt
168
- # personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
169
- # personalized = clean_refined_prompt_response_gpt(personalized)
170
- # if "I'm sorry, I can't assist with" in personalized:
171
- # personalized = prompt
172
  gallery_images = []
 
173
  refined_prompts = call_gpt_refine_prompt(personalized)
174
  for i in range(NUM_IMAGES):
175
  img = infer(refined_prompts[i])
176
  gallery_images.append(img)
 
177
  yield gallery_images
178
 
179
- def redesign(prompt, sim_radio, like_radio, dislike_radio, current_images, history_images, like_image, dislike_image):
180
- global counter, enable_submit, responses_memory
181
- if check_evaluation(sim_radio):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  responses_memory[METHOD][counter] = {
183
  "prompt": prompt,
184
  "sim_radio": sim_radio,
185
  "response": "",
186
- "satisfied_img": f"round {counter}, {like_radio}",
187
- "unsatisfied_img": f"round {counter}, {dislike_radio}",
188
  }
189
 
190
- enable_submit = True if sim_radio in ["Satisfied", "Very Satisfied"] or enable_submit else False
191
-
192
  history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
 
 
193
  if not history_images:
194
- history_images = current_images
195
  elif current_images:
196
  history_images.extend(current_images)
 
197
  current_images = []
198
 
199
  examples_state = gr.update(samples=history_prompts, visible=True)
200
  prompt_state = gr.update(interactive=True)
201
  next_state = gr.update(visible=True, interactive=True)
202
  redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
203
- submit_state = gr.update(interactive=True) if counter >= MAX_ROUND or enable_submit else gr.update(interactive=False)
204
-
205
  counter += 1
 
206
 
207
- return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state
208
- else:
209
- return {submit_btn: gr.skip()}
210
 
211
- def save_response(prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image):
212
- global counter, enable_submit, responses_memory
213
-
214
- if check_evaluation(sim_radio):
215
- # Save the final round entry
216
- responses_memory[METHOD][counter] = {
217
- "prompt": prompt,
218
- "sim_radio": sim_radio,
219
- "response": "",
220
- "satisfied_img": f"round {counter}, {like_radio}",
221
- "unsatisfied_img": f"round {counter}, {dislike_radio}",
222
- }
223
-
224
- # Reset states
225
- counter = 1
226
- enable_submit = False
227
-
228
- # Reset buttons
229
- prompt_state = gr.update(interactive=False)
230
- next_state = gr.update(visible=False, interactive=False)
231
- submit_state = gr.update(interactive=False)
232
- redesign_state = gr.update(interactive=False)
233
-
234
- display_info_message("✅ Your answer is saved!")
235
- return None, None, None, prompt_state, next_state, redesign_state, submit_state
236
  else:
237
- return {submit_btn: gr.skip()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  # =========================
240
  # Interface (single tab, no participant/scenario/background)
@@ -256,6 +258,7 @@ css = """
256
  #button-container {
257
  display: flex;
258
  justify-content: center;
 
259
  }
260
  #compact-compact-row {
261
  width:100%;
@@ -315,9 +318,56 @@ css = """
315
  max-width: 150px;
316
  display: inline-block;
317
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  """
319
 
320
  with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
 
 
 
321
  with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
322
  gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
323
  gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
@@ -325,7 +375,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
325
  gr.HTML("""
326
  <div style="text-align: center;">
327
  <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
328
- 📄 Read the Full Paper .
329
  </a>
330
  </div>
331
  """)
@@ -337,13 +387,15 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
337
 
338
  gr.Markdown("""
339
  <div class="authors-section">
340
- <a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>, <a href"https://www.aliceqian.com/">Alice Qian Zhang</a>,
341
- <a href="https://haiyizhu.com/">Haiyi Zhu</a>, <a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>,
342
- <a href="https://pliang279.github.io/">Paul Pu Liang</a>, <a href="https://janeon.github.io/">Jane Hsieh</a>
 
 
 
343
  </div>
344
  """, elem_classes=["authors-section"])
345
 
346
- # gr.Markdown("---")
347
 
348
  with gr.Tab(""):
349
  with gr.Row(elem_id="compact-row"):
@@ -360,47 +412,99 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
360
 
361
  with gr.Row(elem_id="compact-row"):
362
  with gr.Column(elem_id="col-container"):
363
- images_method = gr.Gallery(label="Images", columns=[4], rows=[1], height=400, elem_id="gallery", format="png")
 
 
 
 
 
 
 
 
364
 
365
  with gr.Column(elem_id="col-container3"):
366
- like_image = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
367
- dislike_image = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
368
-
369
- with gr.Column(elem_id="col-container2", visible=False):
370
- gr.Markdown("### 📝 Evaluation")
371
- sim_radio = gr.Radio(
372
- OPTIONS,
373
- label="How would you rate your satisfaction with the generated images?",
374
- type="value",
375
- elem_classes=["gradio-radio"]
376
- )
377
- like_radio = gr.Radio(
378
- IMAGE_OPTIONS,
379
- label="Select your all-time favorite image (optional).",
380
- type="value",
381
- elem_classes=["gradio-radio"]
382
- )
383
- dislike_radio = gr.Radio(
384
- IMAGE_OPTIONS,
385
- label="Select your all-time least satisfactory image (optional).",
386
- type="value",
387
- elem_classes=["gradio-radio"]
388
- )
389
-
390
- with gr.Column(elem_id="col-container2", visible=False):
391
- example = gr.Examples([['']], prompt, label="Revised Prompt History", visible=False)
392
- history_images = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery", format="png")
393
-
394
- with gr.Row(elem_id="button-container"):
395
- redesign_btn = gr.Button("🎨 Redesign", variant="primary", scale=0)
396
- submit_btn = gr.Button("✅ Submit", variant="primary", interactive=False, scale=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  with gr.Column(elem_id="col-container2"):
399
  gr.Markdown("### 🌟 Examples")
400
- ex1 = gr.Image(label="Image 1", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
401
- ex2 = gr.Image(label="Image 2", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
402
- ex3 = gr.Image(label="Image 3", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
403
- ex4 = gr.Image(label="Image 4", width=200, height=200, sources='upload', format="png", type="filepath", visible=False)
404
 
405
  gr.Examples(
406
  examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
@@ -410,28 +514,44 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
410
  # =========================
411
  # Wiring
412
  # =========================
413
- sim_radio.change(fn=check_satisfaction, inputs=[sim_radio], outputs=[submit_btn])
414
-
415
- dislike_radio.select(fn=select_image, inputs=[dislike_radio, images_method], outputs=[dislike_image])
416
- like_radio.select(fn=select_image, inputs=[like_radio, images_method], outputs=[like_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
  next_btn.click(
419
  fn=generate_image,
420
  inputs=[prompt, like_image, dislike_image],
421
  outputs=[images_method]
422
- ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True)], outputs=[next_btn, prompt])
 
423
 
424
  redesign_btn.click(
425
  fn=redesign,
426
- inputs=[prompt, sim_radio, like_radio, dislike_radio, images_method, history_images, like_image, dislike_image],
427
- outputs=[sim_radio, dislike_radio, like_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn, submit_btn]
428
  )
429
 
430
  submit_btn.click(
431
  fn=save_response,
432
- inputs=[prompt, sim_radio, like_radio, dislike_radio, like_image, dislike_image],
433
- outputs=[sim_radio, dislike_radio, like_radio, prompt, next_btn, redesign_btn, submit_btn]
434
  )
435
 
436
  if __name__ == "__main__":
437
- demo.launch()
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
10
  from optim_utils import optimize_prompt
11
  from utils import (
12
  clean_response_gpt, setup_model, init_gpt_api, call_gpt_api,
13
+ get_refine_msg, clean_cache, get_personalize_message, get_personalized_simplified,
14
  clean_refined_prompt_response_gpt, IMAGES, OPTIONS, T2I_MODELS,
15
  INSTRUCTION, IMAGE_OPTIONS, PROMPTS, SCENARIOS
16
  )
 
40
  METHOD = "Experimental"
41
  counter = 1
42
  enable_submit = False
43
+ redesign_flag = False
44
  responses_memory = {METHOD: {}}
45
  example_data = [
46
  [
 
100
  def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
101
  seed = random.randint(0, MAX_SEED)
102
  client = init_gpt_api()
103
+ # messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
104
+ messages = get_personalized_simplified(prompt, like_image, dislike_image)
105
  outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
106
  return outputs
107
 
 
122
  }
123
  inverted_prompt = optimize_prompt(clip_model, preprocess, text_params, device, target_images=images, target_prompts=prompt)
124
 
 
 
 
125
  # =========================
126
  # UI Helper Functions
127
  # =========================
128
+ # Store generated images for selection
129
+ current_generated_images = []
130
+
131
  def reset_gallery():
132
  return []
133
 
 
137
  def display_info_message(msg, duration=5):
138
  gr.Info(msg, duration=duration)
139
 
140
+ def check_evaluation(sim_radio, like_image, dislike_image):
141
+ if not sim_radio or not like_image or not dislike_image:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  display_error_message("❌ Please fill all evaluations before changing image or submitting.")
143
  return False
144
  return True
145
 
146
  def generate_image(prompt, like_image, dislike_image):
147
+ global responses_memory, current_generated_images
148
  history_prompts = [v["prompt"] for v in responses_memory[METHOD].values()]
149
  feedback = [v["sim_radio"] for v in responses_memory[METHOD].values()]
150
+ print(feedback, like_image, dislike_image)
151
+ if like_image and dislike_image and feedback:
152
+ personalized = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
153
+ else:
154
+ personalized = prompt
155
  gallery_images = []
156
+ current_generated_images = [] # Reset the stored images
157
  refined_prompts = call_gpt_refine_prompt(personalized)
158
  for i in range(NUM_IMAGES):
159
  img = infer(refined_prompts[i])
160
  gallery_images.append(img)
161
+ current_generated_images.append(img) # Store for selection
162
  yield gallery_images
163
 
164
+ def on_gallery_select(evt: gr.SelectData):
165
+ """Handle gallery image selection and return the selected image"""
166
+ global current_generated_images
167
+ if current_generated_images and evt.index < len(current_generated_images):
168
+ return current_generated_images[evt.index]
169
+ return None
170
+
171
+ def handle_like_drag(selected_image):
172
+ """Handle setting an image as liked"""
173
+ return selected_image
174
+
175
+ def handle_dislike_drag(selected_image):
176
+ """Handle setting an image as disliked"""
177
+ return selected_image
178
+
179
+ def redesign(prompt, sim_radio, current_images, history_images, like_image, dislike_image):
180
+ global counter, responses_memory, redesign_flag
181
+
182
+ if check_evaluation(sim_radio, like_image, dislike_image):
183
  responses_memory[METHOD][counter] = {
184
  "prompt": prompt,
185
  "sim_radio": sim_radio,
186
  "response": "",
187
+ "satisfied_img": f"round {counter}, liked image",
188
+ "unsatisfied_img": f"round {counter}, disliked image",
189
  }
190
 
 
 
191
  history_prompts = [[v["prompt"]] for v in responses_memory[METHOD].values()]
192
+
193
+ # Update history images
194
  if not history_images:
195
+ history_images = current_images.copy() if current_images else []
196
  elif current_images:
197
  history_images.extend(current_images)
198
+
199
  current_images = []
200
 
201
  examples_state = gr.update(samples=history_prompts, visible=True)
202
  prompt_state = gr.update(interactive=True)
203
  next_state = gr.update(visible=True, interactive=True)
204
  redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
205
+
 
206
  counter += 1
207
+ redesign_flag = True
208
 
209
+ display_info_message(f"✅ Round {counter-1} feedback saved! You can continue redesigning or restart.")
 
 
210
 
211
+ return None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  else:
213
+ return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
214
+
215
+ def save_response(prompt, sim_radio, like_image, dislike_image):
216
+ global counter, responses_memory, redesign_flag, current_generated_images
217
+
218
+ # Reset all global variables
219
+ responses_memory[METHOD] = {}
220
+ counter = 1
221
+ redesign_flag = False
222
+ current_generated_images = []
223
+
224
+ # Reset UI states
225
+ prompt_state = gr.update(value="", interactive=True)
226
+ next_state = gr.update(visible=True, interactive=True)
227
+ redesign_state = gr.update(interactive=False)
228
+ submit_state = gr.update(interactive=False)
229
+ sim_radio_state = gr.update(value=None)
230
+ like_image_state = gr.update(value=None)
231
+ dislike_image_state = gr.update(value=None)
232
+ gallery_state = []
233
+ history_gallery_state = []
234
+ examples_state = gr.update(samples=[['']], visible=True)
235
+
236
+ display_info_message("🔄 Session restarted! You can begin with a new prompt.")
237
+
238
+ return (sim_radio_state, prompt_state, next_state, redesign_state,
239
+ like_image_state, dislike_image_state, gallery_state, history_gallery_state, examples_state)
240
 
241
  # =========================
242
  # Interface (single tab, no participant/scenario/background)
 
258
  #button-container {
259
  display: flex;
260
  justify-content: center;
261
+ gap: 10px;
262
  }
263
  #compact-compact-row {
264
  width:100%;
 
318
  max-width: 150px;
319
  display: inline-block;
320
  }
321
+ .instruction-box {
322
+ background: linear-gradient(135deg, #e8f4fd 0%, #f0f8ff 100%);
323
+ border: 2px solid #3498db;
324
+ border-radius: 12px;
325
+ padding: 20px;
326
+ margin: 15px 0;
327
+ color: #2c3e50;
328
+ }
329
+ .instruction-title {
330
+ font-size: 1.2em;
331
+ font-weight: bold;
332
+ margin-bottom: 15px;
333
+ color: #2c3e50;
334
+ display: flex;
335
+ align-items: center;
336
+ gap: 8px;
337
+ }
338
+ .step-list {
339
+ list-style: none;
340
+ padding: 0;
341
+ margin: 0;
342
+ }
343
+ .step-item {
344
+ background: rgba(52, 152, 219, 0.1);
345
+ border-radius: 8px;
346
+ padding: 12px 16px;
347
+ margin: 8px 0;
348
+ border-left: 4px solid #3498db;
349
+ }
350
+ .step-number {
351
+ font-weight: bold;
352
+ color: #3498db;
353
+ margin-right: 8px;
354
+ }
355
+ .personalization-header {
356
+ background: linear-gradient(135deg, #ff6b6b, #ee5a24);
357
+ color: white;
358
+ padding: 15px;
359
+ border-radius: 10px 10px 0 0;
360
+ margin: -10px -10px 15px -10px;
361
+ text-align: center;
362
+ font-weight: bold;
363
+ font-size: 1.1em;
364
+ }
365
  """
366
 
367
  with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]), css=css) as demo:
368
+ # State variable to hold selected image
369
+ selected_image = gr.State(None)
370
+
371
  with gr.Column(elem_id="col-container", elem_classes=["header-section"]):
372
  gr.HTML('<div class="logo-container"><img src="https://huggingface.co/spaces/PAI-GEN/POET/resolve/main/images/icon.png" alt="POET Logo"></div>')
373
  gr.Markdown("### Supporting Prompting Creativity with Automated Expansion of Text-to-Image Generation")
 
375
  gr.HTML("""
376
  <div style="text-align: center;">
377
  <a href="https://arxiv.org/pdf/2504.13392" target="_blank" class="paper-link">
378
+ 📄 Read the Full Paper
379
  </a>
380
  </div>
381
  """)
 
387
 
388
  gr.Markdown("""
389
  <div class="authors-section">
390
+ <a href="https://scholar.google.com/citations?user=HXED4kIAAAAJ&hl=en">Evans Han</a>,
391
+ <a href="https://www.aliceqian.com/">Alice Qian Zhang</a>,
392
+ <a href="https://haiyizhu.com/">Haiyi Zhu</a>,
393
+ <a href="https://www.andrew.cmu.edu/user/hongs/">Hong Shen</a>,
394
+ <a href="https://pliang279.github.io/">Paul Pu Liang</a>,
395
+ <a href="https://janeon.github.io/">Jane Hsieh</a>
396
  </div>
397
  """, elem_classes=["authors-section"])
398
 
 
399
 
400
  with gr.Tab(""):
401
  with gr.Row(elem_id="compact-row"):
 
412
 
413
  with gr.Row(elem_id="compact-row"):
414
  with gr.Column(elem_id="col-container"):
415
+ images_method = gr.Gallery(
416
+ label="Generated Images (Click to select, then set to Like/Dislike image)",
417
+ columns=[4],
418
+ rows=[1],
419
+ height=400,
420
+ interactive=False,
421
+ elem_id="gallery",
422
+ format="png"
423
+ )
424
 
425
  with gr.Column(elem_id="col-container3"):
426
+ like_btn = gr.Button("👍 Set as Liked (Optional for personalization)", size="sm", variant="secondary")
427
+ like_image = gr.Image(
428
+ label="Satisfied Image",
429
+ width=150,
430
+ height=150,
431
+ interactive=False,
432
+ format="png",
433
+ type="filepath"
434
+ )
435
+ dislike_btn = gr.Button("👎 Set as Disliked (Optional for personalization)", size="sm", variant="secondary")
436
+ dislike_image = gr.Image(
437
+ label="Unsatisfied Image",
438
+ width=150,
439
+ height=150,
440
+ interactive=False,
441
+ format="png",
442
+ type="filepath"
443
+ )
444
+
445
+ with gr.Accordion("🎯 Advanced: Personalized Image Redesign", open=False, elem_id="col-container2"):
446
+ gr.HTML("""
447
+ <div class="instruction-box">
448
+ <div class="instruction-title">
449
+ 📋 How to Use Personalized Redesign
450
+ </div>
451
+ <div class="step-list">
452
+ <div class="step-item">
453
+ <span class="step-number">1.</span>
454
+ <strong>Rate Your Satisfaction:</strong> Provide a satisfaction score for the current generated images
455
+ </div>
456
+ <div class="step-item">
457
+ <span class="step-number">2.</span>
458
+ <strong>Select Preferences:</strong> Choose your most liked and disliked images
459
+ </div>
460
+ <div class="step-item">
461
+ <span class="step-number">3.</span>
462
+ <strong>Save & Iterate:</strong> Click "Save Personalized Data" before redesigning your prompt and clicking "Generate"
463
+ </div>
464
+ <div class="step-item">
465
+ <span class="step-number">4.</span>
466
+ <strong>Restart Anytime:</strong> Use the "Restart" button to begin a fresh session
467
+ </div>
468
+ </div>
469
+ </div>
470
+ """)
471
+
472
+ with gr.Column(elem_id="col-container2"):
473
+ gr.Markdown("### 📊 Rate Current Generation")
474
+ with gr.Row():
475
+ sim_radio = gr.Radio(
476
+ OPTIONS,
477
+ label="How satisfied are you with the current generated images?",
478
+ type="value",
479
+ show_label=True,
480
+ container=True,
481
+ scale=1
482
+ )
483
+
484
+ with gr.Row(elem_id="button-container"):
485
+ with gr.Column(scale=1):
486
+ redesign_btn = gr.Button("💾 Save Personalization Data", variant="primary", size="lg")
487
+ with gr.Column(scale=1):
488
+ submit_btn = gr.Button("🔄 Restart Session", variant="secondary", size="lg")
489
+
490
+
491
+ with gr.Column(elem_id="col-container2"):
492
+ example = gr.Examples([['']], prompt, label="📝 Prompt History", visible=True)
493
+ history_images = gr.Gallery(
494
+ label="🗃️ Generation History",
495
+ columns=[4],
496
+ rows=[1],
497
+ elem_id="gallery",
498
+ format="png",
499
+ interactive=False,
500
+ )
501
 
502
  with gr.Column(elem_id="col-container2"):
503
  gr.Markdown("### 🌟 Examples")
504
+ ex1 = gr.Image(label="Image 1", width=200, height=200, format="png", type="filepath", visible=False)
505
+ ex2 = gr.Image(label="Image 2", width=200, height=200, format="png", type="filepath", visible=False)
506
+ ex3 = gr.Image(label="Image 3", width=200, height=200, format="png", type="filepath", visible=False)
507
+ ex4 = gr.Image(label="Image 4", width=200, height=200, format="png", type="filepath", visible=False)
508
 
509
  gr.Examples(
510
  examples=[[ex[0], ex[1][0], ex[1][1], ex[1][2], ex[1][3]] for ex in example_data],
 
514
  # =========================
515
  # Wiring
516
  # =========================
517
+ # Handle gallery selection
518
+ images_method.select(
519
+ fn=on_gallery_select,
520
+ inputs=[],
521
+ outputs=[selected_image]
522
+ )
523
+
524
+ # Handle like/dislike button clicks
525
+ like_btn.click(
526
+ fn=handle_like_drag,
527
+ inputs=[selected_image],
528
+ outputs=[like_image]
529
+ )
530
+
531
+ dislike_btn.click(
532
+ fn=handle_dislike_drag,
533
+ inputs=[selected_image],
534
+ outputs=[dislike_image]
535
+ )
536
 
537
  next_btn.click(
538
  fn=generate_image,
539
  inputs=[prompt, like_image, dislike_image],
540
  outputs=[images_method]
541
+ ).success(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
542
+ outputs=[next_btn, prompt, redesign_btn, submit_btn])
543
 
544
  redesign_btn.click(
545
  fn=redesign,
546
+ inputs=[prompt, sim_radio, images_method, history_images, like_image, dislike_image],
547
+ outputs=[sim_radio, images_method, history_images, example.dataset, prompt, next_btn, redesign_btn]
548
  )
549
 
550
  submit_btn.click(
551
  fn=save_response,
552
+ inputs=[prompt, sim_radio, like_image, dislike_image],
553
+ outputs=[sim_radio, prompt, next_btn, redesign_btn, like_image, dislike_image, images_method, history_images, example.dataset]
554
  )
555
 
556
  if __name__ == "__main__":
557
+ demo.launch()
utils.py CHANGED
@@ -171,7 +171,37 @@ def get_personalize_message(prompt, history_prompts, history_feedback, like_imag
171
  "url": f"data:image/png;base64,{dislike_image_base64}",
172
  },
173
  })
174
- print(messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  return messages
176
 
177
  @spaces.GPU
 
171
  "url": f"data:image/png;base64,{dislike_image_base64}",
172
  },
173
  })
174
+ return messages
175
+
176
+ def get_personalized_simplified(prompt, like_image, dislike_image):
177
+ messages = [
178
+ {"role": "system", "content": f"You are a prompt refinement assistant. Your task is to improve a user’s text prompt based on his liked and disliked images. Your goal is to refine the prompt while maintaining its original meaning, improving clarity, specificity, and alignment with user preferences."}
179
+ ]
180
+
181
+ message = f"""The first given image is user's liked image, refine prompt with style, and content user likes. The second given image is user's disliked image, refine prompt to avoid those elements or style of this image."""
182
+
183
+ messages.append({
184
+ "role": "user",
185
+ "content": [
186
+ {"type": "text", "text": f"{message}"},
187
+ ],
188
+ })
189
+ if like_image:
190
+ like_image_base64 = encode_image(like_image)
191
+ messages[-1]["content"].append({
192
+ "type": "image_url",
193
+ "image_url": {
194
+ "url": f"data:image/png;base64,{like_image_base64}",
195
+ },
196
+ })
197
+ if dislike_image:
198
+ dislike_image_base64 = encode_image(dislike_image)
199
+ messages[-1]["content"].append({
200
+ "type": "image_url",
201
+ "image_url": {
202
+ "url": f"data:image/png;base64,{dislike_image_base64}",
203
+ },
204
+ })
205
  return messages
206
 
207
  @spaces.GPU