xh365 committed on
Commit
64dd181
·
1 Parent(s): 3b5c7c9

update refine policy

Browse files
This view is limited to 50 files because the commit contains too many changes. See raw diff
Files changed (50) hide show
  1. __pycache__/live_preview_helpers.cpython-310.pyc +0 -0
  2. __pycache__/optim_utils.cpython-310.pyc +0 -0
  3. __pycache__/utils.cpython-310.pyc +0 -0
  4. app.py +153 -212
  5. images/.DS_Store +0 -0
  6. images/scenario1_base1.png +0 -0
  7. images/scenario1_base2.png +0 -0
  8. images/scenario1_base3.png +0 -0
  9. images/scenario1_base4.png +0 -0
  10. images/scenario1_our1.png +0 -0
  11. images/scenario1_our2.png +0 -0
  12. images/scenario1_our3.png +0 -0
  13. images/scenario1_our4.png +0 -0
  14. images/scenario1_ours1.png +0 -0
  15. images/scenario1_ours2.png +0 -0
  16. images/scenario1_ours3.png +0 -0
  17. images/scenario1_ours4.png +0 -0
  18. images/scenario2_base1.png +0 -0
  19. images/scenario2_base2.png +0 -0
  20. images/scenario2_base3.png +0 -0
  21. images/scenario2_base4.png +0 -0
  22. images/scenario2_our1.png +0 -0
  23. images/scenario2_our2.png +0 -0
  24. images/scenario2_our3.png +0 -0
  25. images/scenario2_our4.png +0 -0
  26. images/scenario2_ours1.png +0 -0
  27. images/scenario2_ours2.png +0 -0
  28. images/scenario2_ours3.png +0 -0
  29. images/scenario2_ours4.png +0 -0
  30. images/scenario3_base1.png +0 -0
  31. images/scenario3_base2.png +0 -0
  32. images/scenario3_base3.png +0 -0
  33. images/scenario3_base4.png +0 -0
  34. images/scenario3_our1.png +0 -0
  35. images/scenario3_our2.png +0 -0
  36. images/scenario3_our3.png +0 -0
  37. images/scenario3_our4.png +0 -0
  38. images/scenario3_ours1.png +0 -0
  39. images/scenario3_ours2.png +0 -0
  40. images/scenario3_ours3.png +0 -0
  41. images/scenario3_ours4.png +0 -0
  42. images/scenario4_base1.png +0 -0
  43. images/scenario4_base2.png +0 -0
  44. images/scenario4_base3.png +0 -0
  45. images/scenario4_base4.png +0 -0
  46. images/scenario4_our1.png +0 -0
  47. images/scenario4_our2.png +0 -0
  48. images/scenario4_our3.png +0 -0
  49. images/scenario4_our4.png +0 -0
  50. images/scenario4_ours1.png +0 -0
__pycache__/live_preview_helpers.cpython-310.pyc CHANGED
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
 
__pycache__/optim_utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/optim_utils.cpython-310.pyc and b/__pycache__/optim_utils.cpython-310.pyc differ
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -7,13 +7,11 @@ import torch
7
  import re
8
  import open_clip
9
  from optim_utils import optimize_prompt
10
- from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache
11
  from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
12
  import spaces #[uncomment to use ZeroGPU]
13
  import transformers
14
  import gspread
15
- import asyncio
16
- from datetime import datetime
17
 
18
  CLIP_MODEL = "ViT-H-14"
19
  PRETRAINED_CLIP = "laion2b_s32b_b79k"
@@ -33,7 +31,7 @@ llm_pipe = None
33
  torch.cuda.empty_cache()
34
  inverted_prompt = ""
35
 
36
- VERBAL_MSG = "Please verbally describe key differences found in the image pair."
37
  DEFAULT_SCENARIO = "Product advertisement"
38
  METHODS = ["Method 1", "Method 2"]
39
  MAX_ROUND = 5
@@ -78,37 +76,6 @@ def infer(
78
 
79
  return image
80
 
81
- async def infer_async(prompt):
82
- return infer(prompt)
83
- # generate a batch of images in parallel
84
- async def generate_batch(prompts):
85
- tasks = [infer_async(p) for p in prompts]
86
- images = await asyncio.gather(*tasks) # Run all in parallel
87
- return images
88
-
89
- @spaces.GPU
90
- def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
91
- print(f"loading {default_llm_model}")
92
- global llm_pipe
93
- if not llm_pipe:
94
- llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")
95
-
96
- messages = get_refine_msg(prmpt, num_prompts)
97
- terminators = [
98
- llm_pipe.tokenizer.eos_token_id,
99
- llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
100
- ]
101
- outputs = llm_pipe(
102
- messages,
103
- max_new_tokens=max_tokens,
104
- eos_token_id=terminators,
105
- do_sample=True,
106
- temperature=temperature,
107
- top_p=top_p,
108
- )
109
- prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
110
- return prompt_list
111
-
112
  def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
113
  seed = random.randint(0, MAX_SEED)
114
  client = init_gpt_api()
@@ -117,12 +84,6 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
117
  prompt_list = clean_response_gpt(outputs)
118
  return prompt_list
119
 
120
- def refine_prompt(gallery_state, prompt):
121
- modified_prompts = call_gpt_refine_prompt(prompt)
122
- return modified_prompts
123
-
124
- # eval(prompt, inverted_prompt, gallery_state, clip_model, preprocess)
125
-
126
  @spaces.GPU(duration=100)
127
  def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
128
  text_params = {
@@ -142,25 +103,15 @@ def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2
142
  # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
143
  # return learned_prompt
144
 
145
-
146
- def eval(prompt, optimized_prompt, optimized_images, clip_model, preprocess):
147
- torch.cuda.empty_cache()
148
- tokenizer = open_clip.get_tokenizer(CLIP_MODEL)
149
- images = [preprocess(i).unsqueeze(0) for i in optimized_images]
150
- images = torch.concatenate(images).to(device)
151
-
152
- with torch.no_grad():
153
- image_feat = clip_model.encode_image(images)
154
- text_feat = clip_model.encode_text(tokenizer([prompt]).to(device))
155
- optimized_text_feat = clip_model.encode_text(tokenizer([optimized_prompt]).to(device))
156
-
157
- image_feat /= image_feat.norm(dim=-1, keepdim=True)
158
- text_feat /= text_feat.norm(dim=-1, keepdim=True)
159
- optimized_text_feat /= optimized_text_feat.norm(dim=-1, keepdim=True)
160
-
161
- similarity = text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
162
- similarity_optimized = optimized_text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
163
-
164
 
165
  ########################################################################################################
166
  # Button-related functions
@@ -182,8 +133,43 @@ def switch_tab(active_tab):
182
  else:
183
  return gr.Tabs(selected="Task A")
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  def set_user(participant):
186
- global responses_memory
187
  responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
188
 
189
  id = re.findall(r'\d+', participant)
@@ -213,25 +199,12 @@ def display_scenario(participant, choice):
213
 
214
  res = {
215
  scenario_content: SCENARIOS.get(choice, ""),
216
- prompt: PROMPTS.get(choice, ""),
217
- prompt1: "",
218
- prompt2: "",
219
  images_method1: initial_images1,
220
  images_method2: initial_images2,
221
- gallery_state1: initial_images1,
222
- gallery_state2: initial_images2,
223
- sim_radio1: None,
224
- sim_radio2: None,
225
- dislike_radio1: None,
226
- like_radio1: None,
227
- dislike_radio2: None,
228
- like_radio2: None,
229
- like_image1: None,
230
- dislike_image1: None,
231
- like_image2: None,
232
- dislike_image2: None,
233
- response1: VERBAL_MSG,
234
- response2: VERBAL_MSG,
235
  next_btn1: gr.update(interactive=False),
236
  next_btn2: gr.update(interactive=False),
237
  redesign_btn1: gr.update(interactive=True),
@@ -241,64 +214,34 @@ def display_scenario(participant, choice):
241
  }
242
  return res
243
 
244
- def generate_image(participant, scenario, prompt, gallery_state, active_tab):
245
  if not check_participant(participant): return [], []
246
  global current_task1, current_task2
247
-
248
  method = current_task1 if active_tab == "Task A" else current_task2
249
 
 
 
 
 
 
 
250
  if method == METHODS[0]:
251
  for i in range(NUM_IMAGES):
252
- img = infer(prompt)
253
- gallery_state.append(img)
254
- yield gallery_state
255
  else:
256
- refined_prompts = refine_prompt(gallery_state, prompt)
257
  for i in range(NUM_IMAGES):
258
  img = infer(refined_prompts[i])
259
- gallery_state.append(img)
260
- yield gallery_state
261
-
262
- def check_satisfaction(sim_radio, active_tab):
263
- global counter1, counter2, current_task1, current_task2
264
- method = current_task1 if active_tab == "Task A" else current_task2
265
- counter = counter1 if method == METHODS[0] else counter2
266
-
267
- fully_satisfied_option = ["Satisfied", "Very Satisfied"] # The value to trigger submit
268
- enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND
269
-
270
- return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))
271
-
272
- def check_participant(participant):
273
- if participant == "":
274
- display_error_message("Please fill your participant id!")
275
- return False
276
- return True
277
-
278
- def check_evaluation(sim_radio, response):
279
- if not sim_radio :
280
- display_error_message("❌ Please fill all evaluations before change image or submit.")
281
- return False
282
-
283
- return True
284
-
285
- def select_dislike(like_radio, images_method):
286
- if like_radio == IMAGE_OPTIONS[0]:
287
- return images_method[0]
288
- elif like_radio == IMAGE_OPTIONS[1]:
289
- return images_method[1]
290
- elif like_radio == IMAGE_OPTIONS[2]:
291
- return images_method[2]
292
- elif like_radio == IMAGE_OPTIONS[3]:
293
- return images_method[3]
294
- else:
295
- return None
296
 
297
- def redesign(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
298
  global counter1, counter2, responses_memory, current_task1, current_task2
299
  method = current_task1 if active_tab == "Task A" else current_task2
300
 
301
- if check_evaluation(sim_radio, response) and check_participant(participant):
302
  if method == METHODS[0]:
303
  counter1 += 1
304
  counter = counter1
@@ -309,62 +252,68 @@ def redesign(participant, scenario, prompt, sim_radio, response, images_method,
309
  responses_memory[participant][method][counter-1] = {}
310
  responses_memory[participant][method][counter-1]["prompt"] = prompt
311
  responses_memory[participant][method][counter-1]["sim_radio"] = sim_radio
312
- responses_memory[participant][method][counter-1]["response"] = response
313
-
314
- prompt_state = gr.update(visible=True)
 
 
 
 
 
 
 
315
  next_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(visible=True, interactive=True)
316
  redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
317
  submit_state = gr.update(interactive=True) if counter >= MAX_ROUND else gr.update(interactive=False)
318
 
319
- return [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state
320
  else:
321
  return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
322
 
323
- def show_message(selected_option):
324
- if selected_option:
325
- return "Click \"Redesign\" and revise your prompt to create images that may more closely match your expectations."
326
- return ""
327
-
328
- def save_response(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
329
- global current_task1, current_task2, counter1, counter2, responses_memory, task1_success, task2_success, assigned_scenarios
330
  method = current_task1 if active_tab == "Task A" else current_task2
331
-
332
- if check_evaluation(sim_radio, response) and check_participant(participant):
333
  counter = counter1 if method == METHODS[0] else counter2
334
- # image_paths = [save_image(img, "method", i) for i, img in enumerate(images_method)]
335
 
336
  responses_memory[participant][method][counter] = {}
337
  responses_memory[participant][method][counter]["prompt"] = prompt
338
  responses_memory[participant][method][counter]["sim_radio"] = sim_radio
339
- responses_memory[participant][method][counter]["response"] = response
340
- prompt_state = gr.update(visible=False)
341
- next_state = gr.update(visible=False, interactive=False)
342
- submit_state = gr.update(interactive=False)
343
- redesign_state = gr.update(interactive=False)
344
-
345
  try:
346
  gc = gspread.service_account(filename='credentials.json')
347
  sheet = gc.open("DiverseGen-phase3").sheet1
348
 
349
  for i, entry in responses_memory[participant][method].items():
350
- sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"], entry["response"]])
351
 
352
  display_info_message("✅ Your answer is saved!")
353
 
354
- # reset counter and update success indicator
355
  if method == METHODS[0]:
356
  counter1 = 1
357
  else:
358
  counter2 = 1
359
-
360
  if active_tab == "Task A":
361
  task1_success = True
362
  else:
363
  task2_success = True
364
-
365
- tabs = switch_tab(active_tab)
366
  next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
367
- return [], [], None, None, None, None, None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state, tabs, next_scenario
 
 
 
 
 
 
 
 
 
368
  except Exception as e:
369
  display_error_message(f"❌ Error saving response: {str(e)}")
370
  return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
@@ -388,7 +337,7 @@ css="""
388
  }
389
 
390
  #col-container3 {
391
- margin: 0 auto;
392
  max-width: 300px;
393
  }
394
 
@@ -413,7 +362,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
413
  )
414
  scenario = gr.Dropdown(
415
  choices=list(SCENARIOS.keys()),
416
- # value=DEFAULT_SCENARIO,
417
  value=None,
418
  label="📌 Scenario",
419
  interactive=False,
@@ -421,13 +369,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
421
  scenario_content = gr.Textbox(
422
  label="📖 Background",
423
  interactive=False,
424
- # value=SCENARIOS[DEFAULT_SCENARIO]
425
- )
426
- prompt = gr.Textbox(
427
- label="🎨 Prompt",
428
- max_lines=1,
429
- # value=PROMPTS[DEFAULT_SCENARIO],
430
- interactive=False
431
  )
432
  active_tab = gr.State("Task A")
433
  instruction = gr.Markdown(INSTRUCTION)
@@ -435,26 +376,25 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
435
  with gr.Tabs() as tabs:
436
  with gr.TabItem("Task A", id="Task A") as task1_tab:
437
  task1_tab.select(lambda: "Task A", outputs=[active_tab])
438
- with gr.Column(elem_id="col-container"):
439
- # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
440
- with gr.Row():
441
- prompt1 = gr.Textbox(
442
- label="🎨 Revise Prompt",
443
- max_lines=1,
444
- placeholder="Enter your prompt",
445
- # value=PROMPTS[DEFAULT_SCENARIO],
446
- scale=4,
447
- visible=False
448
- )
449
- next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
450
-
451
  with gr.Row(elem_id="compact-row"):
452
  with gr.Column(elem_id="col-container"):
453
- gallery_state1 = gr.State([])
454
- images_method1 = gr.Gallery(show_label=False, columns=[4], rows=[1], height=420, elem_id="gallery")
455
  with gr.Column(elem_id="col-container3"):
456
- like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
457
- dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')
458
  with gr.Column(elem_id="col-container2"):
459
  gr.Markdown("### 📝 Evaluation")
460
  sim_radio1 = gr.Radio(
@@ -465,13 +405,13 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
465
  )
466
  like_radio1 = gr.Radio(
467
  IMAGE_OPTIONS,
468
- label="Select the image you are most satisfied.",
469
  type="value",
470
  elem_classes=["gradio-radio"]
471
  )
472
  dislike_radio1 = gr.Radio(
473
  IMAGE_OPTIONS,
474
- label="Select the image you are most unsatisfied.",
475
  type="value",
476
  elem_classes=["gradio-radio"]
477
  )
@@ -491,26 +431,25 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
491
 
492
  with gr.TabItem("Task B", id="Task B") as task2_tab:
493
  task2_tab.select(lambda: "Task B", outputs=[active_tab])
494
- with gr.Column(elem_id="col-container"):
495
- # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
496
- with gr.Row():
497
- prompt2 = gr.Textbox(
498
- label="🎨 Revise Prompt",
499
- max_lines=1,
500
- placeholder="Enter your prompt",
501
- # value=PROMPTS[DEFAULT_SCENARIO],
502
- scale=4,
503
- visible=False
504
- )
505
- next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
506
 
507
  with gr.Row(elem_id="compact-row"):
508
  with gr.Column(elem_id="col-container"):
509
- gallery_state2 = gr.State(IMAGES[DEFAULT_SCENARIO]["ours"])
510
- images_method2 = gr.Gallery(height=420, show_label=False, columns=[4], rows=[1], elem_id="gallery")
511
  with gr.Column(elem_id="col-container3"):
512
- like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
513
- dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')
514
 
515
  with gr.Column(elem_id="col-container2"):
516
  gr.Markdown("### 📝 Evaluation")
@@ -522,13 +461,13 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
522
  )
523
  like_radio2 = gr.Radio(
524
  IMAGE_OPTIONS,
525
- label="Select the image you are most satisfied.",
526
  type="value",
527
  elem_classes=["gradio-radio"]
528
  )
529
  dislike_radio2 = gr.Radio(
530
  IMAGE_OPTIONS,
531
- label="Select the image you are most unsatisfied.",
532
  type="value",
533
  elem_classes=["gradio-radio"]
534
  )
@@ -550,35 +489,37 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
550
  ########################################################################################################
551
 
552
  participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
553
- scenario.change(display_scenario, inputs=[participant, scenario], outputs=[scenario_content, prompt, prompt1, prompt2, images_method1, images_method2, gallery_state1, gallery_state2, sim_radio1, sim_radio2, dislike_radio1, like_radio1, dislike_radio2, like_radio2, like_image1, dislike_image1, like_image2, dislike_image2, response1, response2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
554
- prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
555
- prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
556
- next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, gallery_state1, active_tab], outputs=[images_method1])
557
- next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, gallery_state2, active_tab], outputs=[images_method2])
 
 
558
  sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
559
  sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
560
- dislike_radio1.select(fn=select_dislike, inputs=[dislike_radio1, gallery_state1], outputs=[dislike_image1])
561
- like_radio1.select(fn=select_dislike, inputs=[like_radio1, gallery_state1], outputs=[like_image1])
562
- dislike_radio2.select(fn=select_dislike, inputs=[dislike_radio2, gallery_state2], outputs=[dislike_image2])
563
- like_radio2.select(fn=select_dislike, inputs=[like_radio2, gallery_state2], outputs=[like_image2])
564
 
565
  redesign_btn1.click(
566
  fn=redesign,
567
- inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
568
- outputs=[gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1]
569
  )
570
  redesign_btn2.click(
571
  fn=redesign,
572
- inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
573
- outputs=[gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2]
574
  )
575
  submit_btn1.click(fn=save_response,
576
- inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
577
- outputs=[images_method1, gallery_state1, sim_radio1, dislike_radio1, like_radio1, like_image1, dislike_image1, prompt1, response1, next_btn1, redesign_btn1, submit_btn1, tabs, scenario])
578
 
579
  submit_btn2.click(fn=save_response,
580
- inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
581
- outputs=[images_method2, gallery_state2, sim_radio2, dislike_radio2, like_radio2, like_image2, dislike_image2, prompt2, response2, next_btn2, redesign_btn2, submit_btn2, tabs, scenario])
582
 
583
 
584
  if __name__ == "__main__":
 
7
  import re
8
  import open_clip
9
  from optim_utils import optimize_prompt
10
+ from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message
11
  from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
12
  import spaces #[uncomment to use ZeroGPU]
13
  import transformers
14
  import gspread
 
 
15
 
16
  CLIP_MODEL = "ViT-H-14"
17
  PRETRAINED_CLIP = "laion2b_s32b_b79k"
 
31
  torch.cuda.empty_cache()
32
  inverted_prompt = ""
33
 
34
+ VERBAL_MSG = "Please verbally describe why you are satisfied or unsatisfied at the generated images."
35
  DEFAULT_SCENARIO = "Product advertisement"
36
  METHODS = ["Method 1", "Method 2"]
37
  MAX_ROUND = 5
 
76
 
77
  return image
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
80
  seed = random.randint(0, MAX_SEED)
81
  client = init_gpt_api()
 
84
  prompt_list = clean_response_gpt(outputs)
85
  return prompt_list
86
 
 
 
 
 
 
 
87
  @spaces.GPU(duration=100)
88
  def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
89
  text_params = {
 
103
  # eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
104
  # return learned_prompt
105
 
106
+ def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
107
+ seed = random.randint(0, MAX_SEED)
108
+ client = init_gpt_api()
109
+ messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
110
+ outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
111
+ print(outputs)
112
+ # prompt_list = clean_response_gpt(outputs)
113
+ # print(prompt_list)
114
+ return outputs
 
 
 
 
 
 
 
 
 
 
115
 
116
  ########################################################################################################
117
  # Button-related functions
 
133
  else:
134
  return gr.Tabs(selected="Task A")
135
 
136
+ def check_satisfaction(sim_radio, active_tab):
137
+ global counter1, counter2, current_task1, current_task2
138
+ method = current_task1 if active_tab == "Task A" else current_task2
139
+ counter = counter1 if method == METHODS[0] else counter2
140
+
141
+ fully_satisfied_option = ["Satisfied", "Very Satisfied"] # The value to trigger submit
142
+ enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND
143
+
144
+ return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))
145
+
146
+ def check_participant(participant):
147
+ if participant == "":
148
+ display_error_message("Please fill your participant id!")
149
+ return False
150
+ return True
151
+
152
+ def check_evaluation(sim_radio):
153
+ if not sim_radio :
154
+ display_error_message("❌ Please fill all evaluations before change image or submit.")
155
+ return False
156
+
157
+ return True
158
+
159
+ def select_image(like_radio, images_method):
160
+ if like_radio == IMAGE_OPTIONS[0]:
161
+ return images_method[0][0]
162
+ elif like_radio == IMAGE_OPTIONS[1]:
163
+ return images_method[1][0]
164
+ elif like_radio == IMAGE_OPTIONS[2]:
165
+ return images_method[2][0]
166
+ elif like_radio == IMAGE_OPTIONS[3]:
167
+ return images_method[3][0]
168
+ else:
169
+ return None
170
+
171
  def set_user(participant):
172
+ global responses_memory, assigned_scenarios
173
  responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
174
 
175
  id = re.findall(r'\d+', participant)
 
199
 
200
  res = {
201
  scenario_content: SCENARIOS.get(choice, ""),
202
+ prompt1: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
203
+ prompt2: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
 
204
  images_method1: initial_images1,
205
  images_method2: initial_images2,
206
+ history_images1: [],
207
+ history_images2: [],
 
 
 
 
 
 
 
 
 
 
 
 
208
  next_btn1: gr.update(interactive=False),
209
  next_btn2: gr.update(interactive=False),
210
  redesign_btn1: gr.update(interactive=True),
 
214
  }
215
  return res
216
 
217
+ def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
218
  if not check_participant(participant): return [], []
219
  global current_task1, current_task2
 
220
  method = current_task1 if active_tab == "Task A" else current_task2
221
 
222
+ history_prompts = [v["prompt"] for v in responses_memory[participant][method].values()]
223
+ feedback = [v["sim_radio"] for v in responses_memory[participant][method].values()]
224
+ print(history_prompts, feedback)
225
+ personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
226
+
227
+ gallery_images = []
228
  if method == METHODS[0]:
229
  for i in range(NUM_IMAGES):
230
+ img = infer(personalized_prompt)
231
+ gallery_images.append(img)
232
+ yield gallery_images
233
  else:
234
+ refined_prompts = call_gpt_refine_prompt(personalized_prompt)
235
  for i in range(NUM_IMAGES):
236
  img = infer(refined_prompts[i])
237
+ gallery_images.append(img)
238
+ yield gallery_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
+ def redesign(participant, scenario, prompt, sim_radio, current_images, history_images, active_tab):
241
  global counter1, counter2, responses_memory, current_task1, current_task2
242
  method = current_task1 if active_tab == "Task A" else current_task2
243
 
244
+ if check_evaluation(sim_radio) and check_participant(participant):
245
  if method == METHODS[0]:
246
  counter1 += 1
247
  counter = counter1
 
252
  responses_memory[participant][method][counter-1] = {}
253
  responses_memory[participant][method][counter-1]["prompt"] = prompt
254
  responses_memory[participant][method][counter-1]["sim_radio"] = sim_radio
255
+ # responses_memory[participant][method][counter-1]["response"] = response
256
+
257
+ history_prompts = [[v["prompt"]] for v in responses_memory[participant][method].values()]
258
+ if not history_images:
259
+ history_images = current_images
260
+ elif current_images:
261
+ history_images.extend(current_images)
262
+ current_images = []
263
+ examples_state = gr.update(samples=history_prompts, visible=True)
264
+ prompt_state = gr.update(interactive=True)
265
  next_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(visible=True, interactive=True)
266
  redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
267
  submit_state = gr.update(interactive=True) if counter >= MAX_ROUND else gr.update(interactive=False)
268
 
269
+ return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state
270
  else:
271
  return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
272
 
273
+
274
+ def save_response(participant, scenario, prompt, sim_radio, active_tab):
275
+ global current_task1, current_task2 # not change
276
+ global task1_success, task2_success, counter1, counter2, responses_memory, assigned_scenarios # will change
277
+
 
 
278
  method = current_task1 if active_tab == "Task A" else current_task2
279
+ if check_evaluation(sim_radio) and check_participant(participant):
 
280
  counter = counter1 if method == METHODS[0] else counter2
 
281
 
282
  responses_memory[participant][method][counter] = {}
283
  responses_memory[participant][method][counter]["prompt"] = prompt
284
  responses_memory[participant][method][counter]["sim_radio"] = sim_radio
285
+ # responses_memory[participant][method][counter]["response"] = response
286
+
 
 
 
 
287
  try:
288
  gc = gspread.service_account(filename='credentials.json')
289
  sheet = gc.open("DiverseGen-phase3").sheet1
290
 
291
  for i, entry in responses_memory[participant][method].items():
292
+ sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"]])
293
 
294
  display_info_message("✅ Your answer is saved!")
295
 
296
+ # reset global variables
297
  if method == METHODS[0]:
298
  counter1 = 1
299
  else:
300
  counter2 = 1
 
301
  if active_tab == "Task A":
302
  task1_success = True
303
  else:
304
  task2_success = True
305
+ # decide if change scenario
 
306
  next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
307
+ # update buttons
308
+ example_state = gr.update(samples=[], visible=False)
309
+ prompt_state = gr.update(interactive=False)
310
+ next_state = gr.update(visible=False, interactive=False)
311
+ submit_state = gr.update(interactive=False)
312
+ redesign_state = gr.update(interactive=False)
313
+ tabs = switch_tab(active_tab)
314
+
315
+ return None, None, None, None, None, [], [], example_state, prompt_state, next_state, redesign_state, submit_state, next_scenario, tabs
316
+
317
  except Exception as e:
318
  display_error_message(f"❌ Error saving response: {str(e)}")
319
  return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
 
337
  }
338
 
339
  #col-container3 {
340
+ margin: 0 0 auto auto;
341
  max-width: 300px;
342
  }
343
 
 
362
  )
363
  scenario = gr.Dropdown(
364
  choices=list(SCENARIOS.keys()),
 
365
  value=None,
366
  label="📌 Scenario",
367
  interactive=False,
 
369
  scenario_content = gr.Textbox(
370
  label="📖 Background",
371
  interactive=False,
 
 
 
 
 
 
 
372
  )
373
  active_tab = gr.State("Task A")
374
  instruction = gr.Markdown(INSTRUCTION)
 
376
  with gr.Tabs() as tabs:
377
  with gr.TabItem("Task A", id="Task A") as task1_tab:
378
  task1_tab.select(lambda: "Task A", outputs=[active_tab])
379
+ with gr.Row(elem_id="compact-row"):
380
+ prompt1 = gr.Textbox(
381
+ label="🎨 Revise Prompt",
382
+ max_lines=5,
383
+ placeholder="Enter your prompt",
384
+ scale=4,
385
+ visible=True,
386
+ )
387
+ next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
388
+ with gr.Row(elem_id="compact-row"):
389
+ example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)
390
+
 
391
  with gr.Row(elem_id="compact-row"):
392
  with gr.Column(elem_id="col-container"):
393
+ images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
394
+ history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
395
  with gr.Column(elem_id="col-container3"):
396
+ like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
397
+ dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
398
  with gr.Column(elem_id="col-container2"):
399
  gr.Markdown("### 📝 Evaluation")
400
  sim_radio1 = gr.Radio(
 
405
  )
406
  like_radio1 = gr.Radio(
407
  IMAGE_OPTIONS,
408
+ label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
409
  type="value",
410
  elem_classes=["gradio-radio"]
411
  )
412
  dislike_radio1 = gr.Radio(
413
  IMAGE_OPTIONS,
414
+ label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
415
  type="value",
416
  elem_classes=["gradio-radio"]
417
  )
 
431
 
432
  with gr.TabItem("Task B", id="Task B") as task2_tab:
433
  task2_tab.select(lambda: "Task B", outputs=[active_tab])
434
+ with gr.Row(elem_id="compact-row"):
435
+ prompt2 = gr.Textbox(
436
+ label="🎨 Revise Prompt",
437
+ max_lines=5,
438
+ placeholder="Enter your prompt",
439
+ scale=4,
440
+ visible=True,
441
+ )
442
+ next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
443
+ with gr.Row(elem_id="compact-row"):
444
+ example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)
 
445
 
446
  with gr.Row(elem_id="compact-row"):
447
  with gr.Column(elem_id="col-container"):
448
+ images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
449
+ history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
450
  with gr.Column(elem_id="col-container3"):
451
+ like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
452
+ dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
453
 
454
  with gr.Column(elem_id="col-container2"):
455
  gr.Markdown("### 📝 Evaluation")
 
461
  )
462
  like_radio2 = gr.Radio(
463
  IMAGE_OPTIONS,
464
+ label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
465
  type="value",
466
  elem_classes=["gradio-radio"]
467
  )
468
  dislike_radio2 = gr.Radio(
469
  IMAGE_OPTIONS,
470
+ label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
471
  type="value",
472
  elem_classes=["gradio-radio"]
473
  )
 
489
  ########################################################################################################
490
 
491
  participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
492
+ scenario.change(display_scenario,
493
+ inputs=[participant, scenario],
494
+ outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, history_images1, history_images2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
495
+ # prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
496
+ # prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
497
+ next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
498
+ next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2], outputs=[images_method2])
499
  sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
500
  sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
501
+ dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
502
+ like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
503
+ dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
504
+ like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])
505
 
506
  redesign_btn1.click(
507
  fn=redesign,
508
+ inputs=[participant, scenario, prompt1, sim_radio1, images_method1, history_images1, active_tab],
509
+ outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1]
510
  )
511
  redesign_btn2.click(
512
  fn=redesign,
513
+ inputs=[participant, scenario, prompt2, sim_radio2, images_method2, history_images2, active_tab],
514
+ outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2]
515
  )
516
  submit_btn1.click(fn=save_response,
517
+ inputs=[participant, scenario, prompt1, sim_radio1, active_tab],
518
+ outputs=[sim_radio1, dislike_radio1, like_radio1, like_image1, dislike_image1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1, scenario, tabs])
519
 
520
  submit_btn2.click(fn=save_response,
521
+ inputs=[participant, scenario, prompt2, sim_radio2, active_tab],
522
+ outputs=[sim_radio2, dislike_radio2, like_radio2, like_image2, dislike_image2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2, scenario, tabs])
523
 
524
 
525
  if __name__ == "__main__":
images/.DS_Store DELETED
Binary file (6.15 kB)
 
images/scenario1_base1.png CHANGED
images/scenario1_base2.png CHANGED
images/scenario1_base3.png CHANGED
images/scenario1_base4.png CHANGED
images/scenario1_our1.png DELETED
Binary file (867 kB)
 
images/scenario1_our2.png DELETED
Binary file (579 kB)
 
images/scenario1_our3.png DELETED
Binary file (54.6 kB)
 
images/scenario1_our4.png DELETED
Binary file (32.5 kB)
 
images/scenario1_ours1.png ADDED
images/scenario1_ours2.png ADDED
images/scenario1_ours3.png ADDED
images/scenario1_ours4.png ADDED
images/scenario2_base1.png CHANGED
images/scenario2_base2.png CHANGED
images/scenario2_base3.png CHANGED
images/scenario2_base4.png CHANGED
images/scenario2_our1.png DELETED
Binary file (48.6 kB)
 
images/scenario2_our2.png DELETED
Binary file (59.1 kB)
 
images/scenario2_our3.png DELETED
Binary file (44.1 kB)
 
images/scenario2_our4.png DELETED
Binary file (99.4 kB)
 
images/scenario2_ours1.png ADDED
images/scenario2_ours2.png ADDED
images/scenario2_ours3.png ADDED
images/scenario2_ours4.png ADDED
images/scenario3_base1.png CHANGED
images/scenario3_base2.png CHANGED
images/scenario3_base3.png CHANGED
images/scenario3_base4.png CHANGED
images/scenario3_our1.png DELETED
Binary file (55.3 kB)
 
images/scenario3_our2.png DELETED
Binary file (67.8 kB)
 
images/scenario3_our3.png DELETED
Binary file (151 kB)
 
images/scenario3_our4.png DELETED
Binary file (73.9 kB)
 
images/scenario3_ours1.png ADDED
images/scenario3_ours2.png ADDED
images/scenario3_ours3.png ADDED
images/scenario3_ours4.png ADDED
images/scenario4_base1.png CHANGED
images/scenario4_base2.png CHANGED
images/scenario4_base3.png CHANGED
images/scenario4_base4.png CHANGED
images/scenario4_our1.png DELETED
Binary file (99.1 kB)
 
images/scenario4_our2.png DELETED
Binary file (104 kB)
 
images/scenario4_our3.png DELETED
Binary file (88.1 kB)
 
images/scenario4_our4.png DELETED
Binary file (94.9 kB)
 
images/scenario4_ours1.png ADDED