diff --git a/__pycache__/live_preview_helpers.cpython-310.pyc b/__pycache__/live_preview_helpers.cpython-310.pyc
index 422dde1553a0650d56b8ef19f963713b8562d245..297303fbe50dc663031b9e597f37ac9d947766f1 100644
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
diff --git a/__pycache__/optim_utils.cpython-310.pyc b/__pycache__/optim_utils.cpython-310.pyc
index c101d5ead6cc1076a582b7061217fc6fb2a05c53..98049bf4f0530ff3e4d326734bcc1f45f671cf74 100644
Binary files a/__pycache__/optim_utils.cpython-310.pyc and b/__pycache__/optim_utils.cpython-310.pyc differ
diff --git a/__pycache__/utils.cpython-310.pyc b/__pycache__/utils.cpython-310.pyc
index 3ce89868a99fdf552a12fe7b347ab90f7027ce4b..0c74b8014decef246c54c355c0b85f7eea56afb5 100644
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
diff --git a/app.py b/app.py
index 7011ce81485d59cc7ff92deff6c3f9939fcfac70..f4a957fe656ce8626213dce5f85baf3ab364f907 100644
--- a/app.py
+++ b/app.py
@@ -7,13 +7,11 @@ import torch
import re
import open_clip
from optim_utils import optimize_prompt
-from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache
+from utils import clean_response_gpt, setup_model, init_gpt_api, call_gpt_api, get_refine_msg, clean_cache, get_personalize_message
from utils import SCENARIOS, PROMPTS, IMAGES, OPTIONS, T2I_MODELS, INSTRUCTION, IMAGE_OPTIONS
import spaces #[uncomment to use ZeroGPU]
import transformers
import gspread
-import asyncio
-from datetime import datetime
CLIP_MODEL = "ViT-H-14"
PRETRAINED_CLIP = "laion2b_s32b_b79k"
@@ -33,7 +31,7 @@ llm_pipe = None
torch.cuda.empty_cache()
inverted_prompt = ""
-VERBAL_MSG = "Please verbally describe key differences found in the image pair."
+VERBAL_MSG = "Please verbally describe why you are satisfied or unsatisfied at the generated images."
DEFAULT_SCENARIO = "Product advertisement"
METHODS = ["Method 1", "Method 2"]
MAX_ROUND = 5
@@ -78,37 +76,6 @@ def infer(
return image
-async def infer_async(prompt):
- return infer(prompt)
-# generate a batch of images in parallel
-async def generate_batch(prompts):
- tasks = [infer_async(p) for p in prompts]
- images = await asyncio.gather(*tasks) # Run all in parallel
- return images
-
-@spaces.GPU
-def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
- print(f"loading {default_llm_model}")
- global llm_pipe
- if not llm_pipe:
- llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")
-
- messages = get_refine_msg(prmpt, num_prompts)
- terminators = [
- llm_pipe.tokenizer.eos_token_id,
- llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]
- outputs = llm_pipe(
- messages,
- max_new_tokens=max_tokens,
- eos_token_id=terminators,
- do_sample=True,
- temperature=temperature,
- top_p=top_p,
- )
- prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
- return prompt_list
-
def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
seed = random.randint(0, MAX_SEED)
client = init_gpt_api()
@@ -117,12 +84,6 @@ def call_gpt_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0
prompt_list = clean_response_gpt(outputs)
return prompt_list
-def refine_prompt(gallery_state, prompt):
- modified_prompts = call_gpt_refine_prompt(prompt)
- return modified_prompts
-
- # eval(prompt, inverted_prompt, gallery_state, clip_model, preprocess)
-
@spaces.GPU(duration=100)
def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2):
text_params = {
@@ -142,25 +103,15 @@ def invert_prompt(prompt, images, prompt_len=15, iter=1000, lr=0.1, batch_size=2
# eval(prompt, learned_prompt, optimized_images, clip_model, preprocess)
# return learned_prompt
-
-def eval(prompt, optimized_prompt, optimized_images, clip_model, preprocess):
- torch.cuda.empty_cache()
- tokenizer = open_clip.get_tokenizer(CLIP_MODEL)
- images = [preprocess(i).unsqueeze(0) for i in optimized_images]
- images = torch.concatenate(images).to(device)
-
- with torch.no_grad():
- image_feat = clip_model.encode_image(images)
- text_feat = clip_model.encode_text(tokenizer([prompt]).to(device))
- optimized_text_feat = clip_model.encode_text(tokenizer([optimized_prompt]).to(device))
-
- image_feat /= image_feat.norm(dim=-1, keepdim=True)
- text_feat /= text_feat.norm(dim=-1, keepdim=True)
- optimized_text_feat /= optimized_text_feat.norm(dim=-1, keepdim=True)
-
- similarity = text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
- similarity_optimized = optimized_text_feat.cpu().numpy() @ image_feat.cpu().numpy().T
-
+def personalize_prompt(prompt, history, feedback, like_image, dislike_image):
+ seed = random.randint(0, MAX_SEED)
+ client = init_gpt_api()
+ messages = get_personalize_message(prompt, history, feedback, like_image, dislike_image)
+ outputs = call_gpt_api(messages, client, "gpt-4o", seed, max_tokens=2000, temperature=0.7, top_p=0.9)
+ print(outputs)
+ # prompt_list = clean_response_gpt(outputs)
+ # print(prompt_list)
+ return outputs
########################################################################################################
# Button-related functions
@@ -182,8 +133,43 @@ def switch_tab(active_tab):
else:
return gr.Tabs(selected="Task A")
+def check_satisfaction(sim_radio, active_tab):
+ global counter1, counter2, current_task1, current_task2
+ method = current_task1 if active_tab == "Task A" else current_task2
+ counter = counter1 if method == METHODS[0] else counter2
+
+ fully_satisfied_option = ["Satisfied", "Very Satisfied"] # The value to trigger submit
+ enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND
+
+ return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))
+
+def check_participant(participant):
+ if participant == "":
+ display_error_message("Please fill your participant id!")
+ return False
+ return True
+
+def check_evaluation(sim_radio):
+ if not sim_radio :
+ display_error_message("β Please fill all evaluations before change image or submit.")
+ return False
+
+ return True
+
+def select_image(like_radio, images_method):
+ if like_radio == IMAGE_OPTIONS[0]:
+ return images_method[0][0]
+ elif like_radio == IMAGE_OPTIONS[1]:
+ return images_method[1][0]
+ elif like_radio == IMAGE_OPTIONS[2]:
+ return images_method[2][0]
+ elif like_radio == IMAGE_OPTIONS[3]:
+ return images_method[3][0]
+ else:
+ return None
+
def set_user(participant):
- global responses_memory
+ global responses_memory, assigned_scenarios
responses_memory[participant] = {METHODS[0]:{}, METHODS[1]:{}}
id = re.findall(r'\d+', participant)
@@ -213,25 +199,12 @@ def display_scenario(participant, choice):
res = {
scenario_content: SCENARIOS.get(choice, ""),
- prompt: PROMPTS.get(choice, ""),
- prompt1: "",
- prompt2: "",
+ prompt1: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
+ prompt2: gr.update(value=PROMPTS.get(choice, ""), interactive=False),
images_method1: initial_images1,
images_method2: initial_images2,
- gallery_state1: initial_images1,
- gallery_state2: initial_images2,
- sim_radio1: None,
- sim_radio2: None,
- dislike_radio1: None,
- like_radio1: None,
- dislike_radio2: None,
- like_radio2: None,
- like_image1: None,
- dislike_image1: None,
- like_image2: None,
- dislike_image2: None,
- response1: VERBAL_MSG,
- response2: VERBAL_MSG,
+ history_images1: [],
+ history_images2: [],
next_btn1: gr.update(interactive=False),
next_btn2: gr.update(interactive=False),
redesign_btn1: gr.update(interactive=True),
@@ -241,64 +214,34 @@ def display_scenario(participant, choice):
}
return res
-def generate_image(participant, scenario, prompt, gallery_state, active_tab):
+def generate_image(participant, scenario, prompt, active_tab, like_image, dislike_image):
if not check_participant(participant): return [], []
global current_task1, current_task2
-
method = current_task1 if active_tab == "Task A" else current_task2
+ history_prompts = [v["prompt"] for v in responses_memory[participant][method].values()]
+ feedback = [v["sim_radio"] for v in responses_memory[participant][method].values()]
+ print(history_prompts, feedback)
+ personalized_prompt = personalize_prompt(prompt, history_prompts, feedback, like_image, dislike_image)
+
+ gallery_images = []
if method == METHODS[0]:
for i in range(NUM_IMAGES):
- img = infer(prompt)
- gallery_state.append(img)
- yield gallery_state
+ img = infer(personalized_prompt)
+ gallery_images.append(img)
+ yield gallery_images
else:
- refined_prompts = refine_prompt(gallery_state, prompt)
+ refined_prompts = call_gpt_refine_prompt(personalized_prompt)
for i in range(NUM_IMAGES):
img = infer(refined_prompts[i])
- gallery_state.append(img)
- yield gallery_state
-
-def check_satisfaction(sim_radio, active_tab):
- global counter1, counter2, current_task1, current_task2
- method = current_task1 if active_tab == "Task A" else current_task2
- counter = counter1 if method == METHODS[0] else counter2
-
- fully_satisfied_option = ["Satisfied", "Very Satisfied"] # The value to trigger submit
- enable_submit = sim_radio in fully_satisfied_option or counter >= MAX_ROUND
-
- return gr.update(interactive=enable_submit), gr.update(interactive=(not enable_submit))
-
-def check_participant(participant):
- if participant == "":
- display_error_message("Please fill your participant id!")
- return False
- return True
-
-def check_evaluation(sim_radio, response):
- if not sim_radio :
- display_error_message("β Please fill all evaluations before change image or submit.")
- return False
-
- return True
-
-def select_dislike(like_radio, images_method):
- if like_radio == IMAGE_OPTIONS[0]:
- return images_method[0]
- elif like_radio == IMAGE_OPTIONS[1]:
- return images_method[1]
- elif like_radio == IMAGE_OPTIONS[2]:
- return images_method[2]
- elif like_radio == IMAGE_OPTIONS[3]:
- return images_method[3]
- else:
- return None
+ gallery_images.append(img)
+ yield gallery_images
-def redesign(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
+def redesign(participant, scenario, prompt, sim_radio, current_images, history_images, active_tab):
global counter1, counter2, responses_memory, current_task1, current_task2
method = current_task1 if active_tab == "Task A" else current_task2
- if check_evaluation(sim_radio, response) and check_participant(participant):
+ if check_evaluation(sim_radio) and check_participant(participant):
if method == METHODS[0]:
counter1 += 1
counter = counter1
@@ -309,62 +252,68 @@ def redesign(participant, scenario, prompt, sim_radio, response, images_method,
responses_memory[participant][method][counter-1] = {}
responses_memory[participant][method][counter-1]["prompt"] = prompt
responses_memory[participant][method][counter-1]["sim_radio"] = sim_radio
- responses_memory[participant][method][counter-1]["response"] = response
-
- prompt_state = gr.update(visible=True)
+ # responses_memory[participant][method][counter-1]["response"] = response
+
+ history_prompts = [[v["prompt"]] for v in responses_memory[participant][method].values()]
+ if not history_images:
+ history_images = current_images
+ elif current_images:
+ history_images.extend(current_images)
+ current_images = []
+ examples_state = gr.update(samples=history_prompts, visible=True)
+ prompt_state = gr.update(interactive=True)
next_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(visible=True, interactive=True)
redesign_state = gr.update(interactive=False) if counter >= MAX_ROUND else gr.update(interactive=True)
submit_state = gr.update(interactive=True) if counter >= MAX_ROUND else gr.update(interactive=False)
- return [], None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state
+ return None, None, None, current_images, history_images, examples_state, prompt_state, next_state, redesign_state, submit_state
else:
return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
-def show_message(selected_option):
- if selected_option:
- return "Click \"Redesign\" and revise your prompt to create images that may more closely match your expectations."
- return ""
-
-def save_response(participant, scenario, prompt, sim_radio, response, images_method, active_tab):
- global current_task1, current_task2, counter1, counter2, responses_memory, task1_success, task2_success, assigned_scenarios
+
+def save_response(participant, scenario, prompt, sim_radio, active_tab):
+ global current_task1, current_task2 # not change
+ global task1_success, task2_success, counter1, counter2, responses_memory, assigned_scenarios # will change
+
method = current_task1 if active_tab == "Task A" else current_task2
-
- if check_evaluation(sim_radio, response) and check_participant(participant):
+ if check_evaluation(sim_radio) and check_participant(participant):
counter = counter1 if method == METHODS[0] else counter2
- # image_paths = [save_image(img, "method", i) for i, img in enumerate(images_method)]
responses_memory[participant][method][counter] = {}
responses_memory[participant][method][counter]["prompt"] = prompt
responses_memory[participant][method][counter]["sim_radio"] = sim_radio
- responses_memory[participant][method][counter]["response"] = response
- prompt_state = gr.update(visible=False)
- next_state = gr.update(visible=False, interactive=False)
- submit_state = gr.update(interactive=False)
- redesign_state = gr.update(interactive=False)
-
+ # responses_memory[participant][method][counter]["response"] = response
+
try:
gc = gspread.service_account(filename='credentials.json')
sheet = gc.open("DiverseGen-phase3").sheet1
for i, entry in responses_memory[participant][method].items():
- sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"], entry["response"]])
+ sheet.append_row([participant, scenario, method, i, entry["prompt"], entry["sim_radio"]])
display_info_message("β
Your answer is saved!")
- # reset counter and update success indicator
+ # reset global variables
if method == METHODS[0]:
counter1 = 1
else:
counter2 = 1
-
if active_tab == "Task A":
task1_success = True
else:
task2_success = True
-
- tabs = switch_tab(active_tab)
+ # decide if change scenario
next_scenario = assigned_scenarios[1] if task1_success and task2_success else assigned_scenarios[0]
- return [], [], None, None, None, None, None, VERBAL_MSG, prompt_state, next_state, redesign_state, submit_state, tabs, next_scenario
+ # update buttons
+ example_state = gr.update(samples=[], visible=False)
+ prompt_state = gr.update(interactive=False)
+ next_state = gr.update(visible=False, interactive=False)
+ submit_state = gr.update(interactive=False)
+ redesign_state = gr.update(interactive=False)
+ tabs = switch_tab(active_tab)
+
+ return None, None, None, None, None, [], [], example_state, prompt_state, next_state, redesign_state, submit_state, next_scenario, tabs
+
except Exception as e:
display_error_message(f"β Error saving response: {str(e)}")
return {submit_btn1: gr.skip()} if active_tab == "Task A" else {submit_btn2: gr.skip()}
@@ -388,7 +337,7 @@ css="""
}
#col-container3 {
- margin: 0 auto;
+ margin: 0 0 auto auto;
max-width: 300px;
}
@@ -413,7 +362,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
)
scenario = gr.Dropdown(
choices=list(SCENARIOS.keys()),
- # value=DEFAULT_SCENARIO,
value=None,
label="π Scenario",
interactive=False,
@@ -421,13 +369,6 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
scenario_content = gr.Textbox(
label="π Background",
interactive=False,
- # value=SCENARIOS[DEFAULT_SCENARIO]
- )
- prompt = gr.Textbox(
- label="π¨ Prompt",
- max_lines=1,
- # value=PROMPTS[DEFAULT_SCENARIO],
- interactive=False
)
active_tab = gr.State("Task A")
instruction = gr.Markdown(INSTRUCTION)
@@ -435,26 +376,25 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
with gr.Tabs() as tabs:
with gr.TabItem("Task A", id="Task A") as task1_tab:
task1_tab.select(lambda: "Task A", outputs=[active_tab])
- with gr.Column(elem_id="col-container"):
- # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
- with gr.Row():
- prompt1 = gr.Textbox(
- label="π¨ Revise Prompt",
- max_lines=1,
- placeholder="Enter your prompt",
- # value=PROMPTS[DEFAULT_SCENARIO],
- scale=4,
- visible=False
- )
- next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
-
+ with gr.Row(elem_id="compact-row"):
+ prompt1 = gr.Textbox(
+ label="π¨ Revise Prompt",
+ max_lines=5,
+ placeholder="Enter your prompt",
+ scale=4,
+ visible=True,
+ )
+ next_btn1 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
+ with gr.Row(elem_id="compact-row"):
+ example1 = gr.Examples([['']], prompt1, label="Revised Prompt History", visible=False)
+
with gr.Row(elem_id="compact-row"):
with gr.Column(elem_id="col-container"):
- gallery_state1 = gr.State([])
- images_method1 = gr.Gallery(show_label=False, columns=[4], rows=[1], height=420, elem_id="gallery")
+ images_method1 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
+ history_images1 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
with gr.Column(elem_id="col-container3"):
- like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
- dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')
+ like_image1 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
+ dislike_image1 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
with gr.Column(elem_id="col-container2"):
gr.Markdown("### π Evaluation")
sim_radio1 = gr.Radio(
@@ -465,13 +405,13 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
)
like_radio1 = gr.Radio(
IMAGE_OPTIONS,
- label="Select the image you are most satisfied.",
+ label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
type="value",
elem_classes=["gradio-radio"]
)
dislike_radio1 = gr.Radio(
IMAGE_OPTIONS,
- label="Select the image you are most unsatisfied.",
+ label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
type="value",
elem_classes=["gradio-radio"]
)
@@ -491,26 +431,25 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
with gr.TabItem("Task B", id="Task B") as task2_tab:
task2_tab.select(lambda: "Task B", outputs=[active_tab])
- with gr.Column(elem_id="col-container"):
- # gr.Markdown("### Step 2: This is the prompt to generate images, you may modify the prompt after first round evaluation")
- with gr.Row():
- prompt2 = gr.Textbox(
- label="π¨ Revise Prompt",
- max_lines=1,
- placeholder="Enter your prompt",
- # value=PROMPTS[DEFAULT_SCENARIO],
- scale=4,
- visible=False
- )
- next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
+ with gr.Row(elem_id="compact-row"):
+ prompt2 = gr.Textbox(
+ label="π¨ Revise Prompt",
+ max_lines=5,
+ placeholder="Enter your prompt",
+ scale=4,
+ visible=True,
+ )
+ next_btn2 = gr.Button("Generate", variant="primary", scale=1, interactive=False, visible=False)
+ with gr.Row(elem_id="compact-row"):
+ example2 = gr.Examples([['']], prompt2, label="Revised Prompt History", visible=False)
with gr.Row(elem_id="compact-row"):
with gr.Column(elem_id="col-container"):
- gallery_state2 = gr.State(IMAGES[DEFAULT_SCENARIO]["ours"])
- images_method2 = gr.Gallery(height=420, show_label=False, columns=[4], rows=[1], elem_id="gallery")
+ images_method2 = gr.Gallery(label="Images", columns=[4], rows=[1], height=200, elem_id="gallery")
+ history_images2 = gr.Gallery(label="History Images", columns=[4], rows=[1], elem_id="gallery")
with gr.Column(elem_id="col-container3"):
- like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload')
- dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload')
+ like_image2 = gr.Image(label="Satisfied Image", width=200, height=200, sources='upload', type="pil")
+ dislike_image2 = gr.Image(label="Unsatisfied Image", width=200, height=200, sources='upload', type="pil")
with gr.Column(elem_id="col-container2"):
gr.Markdown("### π Evaluation")
@@ -522,13 +461,13 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
)
like_radio2 = gr.Radio(
IMAGE_OPTIONS,
- label="Select the image you are most satisfied.",
+ label="Select the image that you find MOST satisfactory. You may leave this section blank if you prefer the previous images.",
type="value",
elem_classes=["gradio-radio"]
)
dislike_radio2 = gr.Radio(
IMAGE_OPTIONS,
- label="Select the image you are most unsatisfied.",
+ label="Please choose the image that you find LEAST satisfactory. You may leave this section blank if you are more dislike previous images.",
type="value",
elem_classes=["gradio-radio"]
)
@@ -550,35 +489,37 @@ with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Inconsolata"), "
########################################################################################################
participant.change(fn=set_user, inputs=[participant], outputs=[scenario])
- scenario.change(display_scenario, inputs=[participant, scenario], outputs=[scenario_content, prompt, prompt1, prompt2, images_method1, images_method2, gallery_state1, gallery_state2, sim_radio1, sim_radio2, dislike_radio1, like_radio1, dislike_radio2, like_radio2, like_image1, dislike_image1, like_image2, dislike_image2, response1, response2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
- prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
- prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
- next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, gallery_state1, active_tab], outputs=[images_method1])
- next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, gallery_state2, active_tab], outputs=[images_method2])
+ scenario.change(display_scenario,
+ inputs=[participant, scenario],
+ outputs=[scenario_content, prompt1, prompt2, images_method1, images_method2, history_images1, history_images2, next_btn1, next_btn2, redesign_btn1, redesign_btn2, submit_btn1, submit_btn2])
+ # prompt1.change(fn=reset_gallery, inputs=[], outputs=[gallery_state1])
+ # prompt2.change(fn=reset_gallery, inputs=[], outputs=[gallery_state2])
+ next_btn1.click(fn=generate_image, inputs=[participant, scenario, prompt1, active_tab, like_image1, dislike_image1], outputs=[images_method1])
+ next_btn2.click(fn=generate_image, inputs=[participant, scenario, prompt2, active_tab, like_image2, dislike_image2], outputs=[images_method2])
sim_radio1.change(fn=check_satisfaction, inputs=[sim_radio1, active_tab], outputs=[submit_btn1, redesign_btn1])
sim_radio2.change(fn=check_satisfaction, inputs=[sim_radio2, active_tab], outputs=[submit_btn2, redesign_btn2])
- dislike_radio1.select(fn=select_dislike, inputs=[dislike_radio1, gallery_state1], outputs=[dislike_image1])
- like_radio1.select(fn=select_dislike, inputs=[like_radio1, gallery_state1], outputs=[like_image1])
- dislike_radio2.select(fn=select_dislike, inputs=[dislike_radio2, gallery_state2], outputs=[dislike_image2])
- like_radio2.select(fn=select_dislike, inputs=[like_radio2, gallery_state2], outputs=[like_image2])
+ dislike_radio1.select(fn=select_image, inputs=[dislike_radio1, images_method1], outputs=[dislike_image1])
+ like_radio1.select(fn=select_image, inputs=[like_radio1, images_method1], outputs=[like_image1])
+ dislike_radio2.select(fn=select_image, inputs=[dislike_radio2, images_method2], outputs=[dislike_image2])
+ like_radio2.select(fn=select_image, inputs=[like_radio2, images_method2], outputs=[like_image2])
redesign_btn1.click(
fn=redesign,
- inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
- outputs=[gallery_state1, sim_radio1, response1, prompt1, next_btn1, redesign_btn1, submit_btn1]
+ inputs=[participant, scenario, prompt1, sim_radio1, images_method1, history_images1, active_tab],
+ outputs=[sim_radio1, dislike_radio1, like_radio1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1]
)
redesign_btn2.click(
fn=redesign,
- inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
- outputs=[gallery_state2, sim_radio2, response2, prompt2, next_btn2, redesign_btn2, submit_btn2]
+ inputs=[participant, scenario, prompt2, sim_radio2, images_method2, history_images2, active_tab],
+ outputs=[sim_radio2, dislike_radio2, like_radio2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2]
)
submit_btn1.click(fn=save_response,
- inputs=[participant, scenario, prompt1, sim_radio1, response1, images_method1, active_tab],
- outputs=[images_method1, gallery_state1, sim_radio1, dislike_radio1, like_radio1, like_image1, dislike_image1, prompt1, response1, next_btn1, redesign_btn1, submit_btn1, tabs, scenario])
+ inputs=[participant, scenario, prompt1, sim_radio1, active_tab],
+ outputs=[sim_radio1, dislike_radio1, like_radio1, like_image1, dislike_image1, images_method1, history_images1, example1.dataset, prompt1, next_btn1, redesign_btn1, submit_btn1, scenario, tabs])
submit_btn2.click(fn=save_response,
- inputs=[participant, scenario, prompt2, sim_radio2, response2, images_method2, active_tab],
- outputs=[images_method2, gallery_state2, sim_radio2, dislike_radio2, like_radio2, like_image2, dislike_image2, prompt2, response2, next_btn2, redesign_btn2, submit_btn2, tabs, scenario])
+ inputs=[participant, scenario, prompt2, sim_radio2, active_tab],
+ outputs=[sim_radio2, dislike_radio2, like_radio2, like_image2, dislike_image2, images_method2, history_images2, example2.dataset, prompt2, next_btn2, redesign_btn2, submit_btn2, scenario, tabs])
if __name__ == "__main__":
diff --git a/images/.DS_Store b/images/.DS_Store
deleted file mode 100644
index c97704c16cc70739e72d2ae1ae9a4232c2d192b4..0000000000000000000000000000000000000000
Binary files a/images/.DS_Store and /dev/null differ
diff --git a/images/scenario1_base1.png b/images/scenario1_base1.png
index 470fdb4b944de2bf0ecb9bbccb95e82d308db5c1..362f9b9cbadb201c78631c334b9c416e8a9de8e7 100644
Binary files a/images/scenario1_base1.png and b/images/scenario1_base1.png differ
diff --git a/images/scenario1_base2.png b/images/scenario1_base2.png
index 0a8305100d042e117d6e22dca2885173a23c1111..f1bcaf7d867fd738301f7151077056088da9207c 100644
Binary files a/images/scenario1_base2.png and b/images/scenario1_base2.png differ
diff --git a/images/scenario1_base3.png b/images/scenario1_base3.png
index edadcfb1c9fbb42f7da818a99875a2582d7b30e6..7915687cabcef9b4879af35c3b67cba65bb85dad 100644
Binary files a/images/scenario1_base3.png and b/images/scenario1_base3.png differ
diff --git a/images/scenario1_base4.png b/images/scenario1_base4.png
index fe257c4a4da9fb90572a6788ca9f85710c92abfc..c3f2c45bd91e0732d66452a2d69972a21363e951 100644
Binary files a/images/scenario1_base4.png and b/images/scenario1_base4.png differ
diff --git a/images/scenario1_our1.png b/images/scenario1_our1.png
deleted file mode 100644
index 26efc98daea3c0c2d100256b92fd33345d91c203..0000000000000000000000000000000000000000
Binary files a/images/scenario1_our1.png and /dev/null differ
diff --git a/images/scenario1_our2.png b/images/scenario1_our2.png
deleted file mode 100644
index 9e19fef280421caccec9ce5ddeef5fd1fb2a6b12..0000000000000000000000000000000000000000
Binary files a/images/scenario1_our2.png and /dev/null differ
diff --git a/images/scenario1_our3.png b/images/scenario1_our3.png
deleted file mode 100644
index b59ca55cf74dc339f39985443bcf0ab9c92d4c6a..0000000000000000000000000000000000000000
Binary files a/images/scenario1_our3.png and /dev/null differ
diff --git a/images/scenario1_our4.png b/images/scenario1_our4.png
deleted file mode 100644
index e6e763437be0ccec960b4d1299f6580828b3caf5..0000000000000000000000000000000000000000
Binary files a/images/scenario1_our4.png and /dev/null differ
diff --git a/images/scenario1_ours1.png b/images/scenario1_ours1.png
new file mode 100644
index 0000000000000000000000000000000000000000..bd152a46a055c29a03ae7ecf5b5c36b1eab0c565
Binary files /dev/null and b/images/scenario1_ours1.png differ
diff --git a/images/scenario1_ours2.png b/images/scenario1_ours2.png
new file mode 100644
index 0000000000000000000000000000000000000000..7d6a399de539e6f224f3de5cdee9be122ff63097
Binary files /dev/null and b/images/scenario1_ours2.png differ
diff --git a/images/scenario1_ours3.png b/images/scenario1_ours3.png
new file mode 100644
index 0000000000000000000000000000000000000000..08b6e27ea6a5682378499035668d0f168374391f
Binary files /dev/null and b/images/scenario1_ours3.png differ
diff --git a/images/scenario1_ours4.png b/images/scenario1_ours4.png
new file mode 100644
index 0000000000000000000000000000000000000000..a6b6f20660f0a29ea56c3abf99b0a6b0b28844f9
Binary files /dev/null and b/images/scenario1_ours4.png differ
diff --git a/images/scenario2_base1.png b/images/scenario2_base1.png
index b730b33a2e650bc81e21e86760e71345743e162a..197e38831fbda404a55f402edb6ff8a565001edb 100644
Binary files a/images/scenario2_base1.png and b/images/scenario2_base1.png differ
diff --git a/images/scenario2_base2.png b/images/scenario2_base2.png
index 60a4d31e4d94e9a8824a114ec260c83f74c8ff15..8f6a3ad5826fe6706be260febe9acb5e60789bc1 100644
Binary files a/images/scenario2_base2.png and b/images/scenario2_base2.png differ
diff --git a/images/scenario2_base3.png b/images/scenario2_base3.png
index ec64fdc26e4a249e0734ead89682b46a8115d5d2..c5cb709155308cb93832130cebda3e89fa4b6e0b 100644
Binary files a/images/scenario2_base3.png and b/images/scenario2_base3.png differ
diff --git a/images/scenario2_base4.png b/images/scenario2_base4.png
index 53e33e6a85c7df3dde917f038872333c30e43cdb..7a462c09c72e6cb394a95cb4485d4e934ae0b9e0 100644
Binary files a/images/scenario2_base4.png and b/images/scenario2_base4.png differ
diff --git a/images/scenario2_our1.png b/images/scenario2_our1.png
deleted file mode 100644
index 31dd0b8a48e1ffe22ac279357b259e999a338bb7..0000000000000000000000000000000000000000
Binary files a/images/scenario2_our1.png and /dev/null differ
diff --git a/images/scenario2_our2.png b/images/scenario2_our2.png
deleted file mode 100644
index 1948548951b603c4980530b5e3bb0416079b7765..0000000000000000000000000000000000000000
Binary files a/images/scenario2_our2.png and /dev/null differ
diff --git a/images/scenario2_our3.png b/images/scenario2_our3.png
deleted file mode 100644
index 9ffcd41fd9119b185fad1bcf38ccc7ba3469fbd0..0000000000000000000000000000000000000000
Binary files a/images/scenario2_our3.png and /dev/null differ
diff --git a/images/scenario2_our4.png b/images/scenario2_our4.png
deleted file mode 100644
index 0d7c0aa77c3f52b83bba3a376fe0d2d691be1298..0000000000000000000000000000000000000000
Binary files a/images/scenario2_our4.png and /dev/null differ
diff --git a/images/scenario2_ours1.png b/images/scenario2_ours1.png
new file mode 100644
index 0000000000000000000000000000000000000000..d73afbda69d572e599a6acc0e90ef45b751a6f96
Binary files /dev/null and b/images/scenario2_ours1.png differ
diff --git a/images/scenario2_ours2.png b/images/scenario2_ours2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5861c876a8794e4c82737120a91f298be341e4df
Binary files /dev/null and b/images/scenario2_ours2.png differ
diff --git a/images/scenario2_ours3.png b/images/scenario2_ours3.png
new file mode 100644
index 0000000000000000000000000000000000000000..85bd55025775063266116ba186d73867d1b3ada2
Binary files /dev/null and b/images/scenario2_ours3.png differ
diff --git a/images/scenario2_ours4.png b/images/scenario2_ours4.png
new file mode 100644
index 0000000000000000000000000000000000000000..427fafa41c54c01da947f79b5d205b21fccb6e68
Binary files /dev/null and b/images/scenario2_ours4.png differ
diff --git a/images/scenario3_base1.png b/images/scenario3_base1.png
index 2933f02f0edfd6f49d4c8b6ed9ba47ff3f9c9c9d..3dfe7504477f121504c1635157af65a6fb642f36 100644
Binary files a/images/scenario3_base1.png and b/images/scenario3_base1.png differ
diff --git a/images/scenario3_base2.png b/images/scenario3_base2.png
index 6441d12f44ce166fdfdec9633b22d3fba7aa82d3..5270005588f1289d537c367f6d057e7443d72d4c 100644
Binary files a/images/scenario3_base2.png and b/images/scenario3_base2.png differ
diff --git a/images/scenario3_base3.png b/images/scenario3_base3.png
index e88452336f6cd27e1ea689bf8b6caa171863bb03..01c5029aa0543e3cf6f155fea84bda06d5f73630 100644
Binary files a/images/scenario3_base3.png and b/images/scenario3_base3.png differ
diff --git a/images/scenario3_base4.png b/images/scenario3_base4.png
index 0efe6763bca849070c4c0607cdb09c7e69cb9ac4..3cc4c35c744c35bfa0f3b675ed6d5946a0343b12 100644
Binary files a/images/scenario3_base4.png and b/images/scenario3_base4.png differ
diff --git a/images/scenario3_our1.png b/images/scenario3_our1.png
deleted file mode 100644
index bcf23b153d1fa8a0daefaf150a53451d4dd0d957..0000000000000000000000000000000000000000
Binary files a/images/scenario3_our1.png and /dev/null differ
diff --git a/images/scenario3_our2.png b/images/scenario3_our2.png
deleted file mode 100644
index 08913594a4b224b54bf7d5d3430d757344fbe19a..0000000000000000000000000000000000000000
Binary files a/images/scenario3_our2.png and /dev/null differ
diff --git a/images/scenario3_our3.png b/images/scenario3_our3.png
deleted file mode 100644
index 069d34ba15c1b6f57e5d9eb2ecacd4b620255cce..0000000000000000000000000000000000000000
Binary files a/images/scenario3_our3.png and /dev/null differ
diff --git a/images/scenario3_our4.png b/images/scenario3_our4.png
deleted file mode 100644
index 6d02aef41a7f8b630eeac903a6a9f93f241ec0e6..0000000000000000000000000000000000000000
Binary files a/images/scenario3_our4.png and /dev/null differ
diff --git a/images/scenario3_ours1.png b/images/scenario3_ours1.png
new file mode 100644
index 0000000000000000000000000000000000000000..04f4c13b92112961837188237a70d2d0009038b2
Binary files /dev/null and b/images/scenario3_ours1.png differ
diff --git a/images/scenario3_ours2.png b/images/scenario3_ours2.png
new file mode 100644
index 0000000000000000000000000000000000000000..08dc5316b7c0274b2ec5e2afed2f9651abeb8f44
Binary files /dev/null and b/images/scenario3_ours2.png differ
diff --git a/images/scenario3_ours3.png b/images/scenario3_ours3.png
new file mode 100644
index 0000000000000000000000000000000000000000..b14e5af8d1475c50ea24250aacc0b6a0b56e38d6
Binary files /dev/null and b/images/scenario3_ours3.png differ
diff --git a/images/scenario3_ours4.png b/images/scenario3_ours4.png
new file mode 100644
index 0000000000000000000000000000000000000000..2f490c380a6588ff5117321f642506df059c3359
Binary files /dev/null and b/images/scenario3_ours4.png differ
diff --git a/images/scenario4_base1.png b/images/scenario4_base1.png
index 76aa0370baffe965e149bcdb3486cd043988399f..fa02ab670f677878e81a8574d47fa1b7fab0fcd0 100644
Binary files a/images/scenario4_base1.png and b/images/scenario4_base1.png differ
diff --git a/images/scenario4_base2.png b/images/scenario4_base2.png
index 75cbc6bd66c0b66557516b423e541757aa3c4320..30eb5ad0469229a164d11679f56a8ca7d650010a 100644
Binary files a/images/scenario4_base2.png and b/images/scenario4_base2.png differ
diff --git a/images/scenario4_base3.png b/images/scenario4_base3.png
index c0eb30f5e1f7c69d5c570ac62f636d36c8b62759..f8fa5a39a235dc6976d6a4f82ab8da0969ffa0db 100644
Binary files a/images/scenario4_base3.png and b/images/scenario4_base3.png differ
diff --git a/images/scenario4_base4.png b/images/scenario4_base4.png
index 6cffd95b36af91a7aa5213b4ffb528b5f9ab541b..22c63a98476dfa1a245d1dd5c426fa14c1cf2007 100644
Binary files a/images/scenario4_base4.png and b/images/scenario4_base4.png differ
diff --git a/images/scenario4_our1.png b/images/scenario4_our1.png
deleted file mode 100644
index d0313104de102b81c680695c8bd033e6c7264fe0..0000000000000000000000000000000000000000
Binary files a/images/scenario4_our1.png and /dev/null differ
diff --git a/images/scenario4_our2.png b/images/scenario4_our2.png
deleted file mode 100644
index ff9a2e654382160ae347f4f172b4c253e459c285..0000000000000000000000000000000000000000
Binary files a/images/scenario4_our2.png and /dev/null differ
diff --git a/images/scenario4_our3.png b/images/scenario4_our3.png
deleted file mode 100644
index 9b587e764be9b04189f289ce27537a691a016d77..0000000000000000000000000000000000000000
Binary files a/images/scenario4_our3.png and /dev/null differ
diff --git a/images/scenario4_our4.png b/images/scenario4_our4.png
deleted file mode 100644
index ef264c8429d88a771083fadbc526202f7250cfbd..0000000000000000000000000000000000000000
Binary files a/images/scenario4_our4.png and /dev/null differ
diff --git a/images/scenario4_ours1.png b/images/scenario4_ours1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c220c0c40a5e6b762e84bc4cf2860f2d6853ce2
Binary files /dev/null and b/images/scenario4_ours1.png differ
diff --git a/images/scenario4_ours2.png b/images/scenario4_ours2.png
new file mode 100644
index 0000000000000000000000000000000000000000..7d9e147ef8d6ed6fd5ee39de39199ad4a19c85c2
Binary files /dev/null and b/images/scenario4_ours2.png differ
diff --git a/images/scenario4_ours3.png b/images/scenario4_ours3.png
new file mode 100644
index 0000000000000000000000000000000000000000..41638b80664925acf0080841034e8ea168354221
Binary files /dev/null and b/images/scenario4_ours3.png differ
diff --git a/images/scenario4_ours4.png b/images/scenario4_ours4.png
new file mode 100644
index 0000000000000000000000000000000000000000..265a461b20baaaf2c02074cc4616cf97dc458e81
Binary files /dev/null and b/images/scenario4_ours4.png differ
diff --git a/images/scenario5_base1.png b/images/scenario5_base1.png
deleted file mode 100644
index 846c158b7d64f069a5260a1cb899a00505dac67d..0000000000000000000000000000000000000000
Binary files a/images/scenario5_base1.png and /dev/null differ
diff --git a/images/scenario5_base2.png b/images/scenario5_base2.png
deleted file mode 100644
index 43a5496d87a3870e22831ee8a2239c73bf989315..0000000000000000000000000000000000000000
Binary files a/images/scenario5_base2.png and /dev/null differ
diff --git a/images/scenario5_base3.png b/images/scenario5_base3.png
deleted file mode 100644
index be2731786ed47769b4f54e55e7e94fcfe2aa662b..0000000000000000000000000000000000000000
Binary files a/images/scenario5_base3.png and /dev/null differ
diff --git a/images/scenario5_base4.png b/images/scenario5_base4.png
deleted file mode 100644
index 934cf4356516164fb958bfa2229fe3fc79062a75..0000000000000000000000000000000000000000
Binary files a/images/scenario5_base4.png and /dev/null differ
diff --git a/images/scenario5_our1.png b/images/scenario5_our1.png
deleted file mode 100644
index 4350496a9416aeb132dd5cb18b0df74301a496bf..0000000000000000000000000000000000000000
Binary files a/images/scenario5_our1.png and /dev/null differ
diff --git a/images/scenario5_our2.png b/images/scenario5_our2.png
deleted file mode 100644
index 5366797b99f2f1244cbcf8e6e54ce9bbba118f91..0000000000000000000000000000000000000000
Binary files a/images/scenario5_our2.png and /dev/null differ
diff --git a/images/scenario5_our3.png b/images/scenario5_our3.png
deleted file mode 100644
index 869b1cb973e8045c26233b7cd6cb8241a77fce06..0000000000000000000000000000000000000000
Binary files a/images/scenario5_our3.png and /dev/null differ
diff --git a/images/scenario5_our4.png b/images/scenario5_our4.png
deleted file mode 100644
index 2955127e19c4c009dc9895e531ee56a3974399e4..0000000000000000000000000000000000000000
Binary files a/images/scenario5_our4.png and /dev/null differ
diff --git a/utils.py b/utils.py
index 2ef53a920b059ba1e01d349fb89952682fbdf16f..472b7e81ab150b8bc748a6775ab0406b923a71de 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,9 @@ import torch
import os
from openai import OpenAI
import subprocess
+import spaces #[uncomment to use ZeroGPU]
+import base64
+from io import BytesIO
T2I_MODELS = {
"Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
@@ -15,35 +18,31 @@ T2I_MODELS = {
SCENARIOS = {
"Product advertisement": "You are designing an advertising campaign for a new line of coffee machines. To ensure the campaign resonates with a wider audience, you use generative models to create marketing images that showcase a variety of users interacting with the product.",
- "Tourist promotion": "You are creating a travel campaign to attract a diverse range of visitors to a specific destination. To make the promotional materials more engaging and inclusive, you use generative models to design posters that highlight a broader array of experiences.",
- "Fictional character generation": "You are creating a narrative superhero game where the player often interacts with multiple other non-player characters in the story. To test how different characters would affect the experience of gameplay, you decide to use generative models to help construct characters for (play)testing.",
- "Interior Design": "You have a one-bedroom apartment and want to arrange your bed, desk, and dresser in the best way possible. You love the color white and want to ensure your space feels bright and open. To make a decision, you need a way to visualize different furniture placements before setting everything up.",
-# "Education & accessibility": "You are a grade school teacher and the lesson of the day is about teamwork. Some of your students may have a difficult time visualizing what teamwork looks like because they are either (1) too young, (2) English is not their first language, or (3) they may have cognitive impairments that make it difficult for them to visualize concepts (e.g. aphantasia).."
+ "Tourist promotion": "You are creating a travel campaign to attract a variety of visitors to a specific destination. To make the promotional materials more engaging, you use generative models to design posters that highlight a broader array of experiences.",
+ "Fictional character generation": "You are creating a superhero video game thatβs fun and relatable to a range of users. You decide to use generative models to help visualize a new character.",
+ "Interior Design": "You are helping design the furniture layout for a model one-bedroom rental apartment. To make the apartment appealing to different potential tenants, you try to visualize different furniture placements before setting everything up.",
}
PROMPTS = {
- "Product advertisement": "Design a marketing advertisement image for a coffee machine.",
- "Tourist promotion": "Design a travel promotional poster to showcase the beauty and attractions of a tourist destination.",
- "Fictional character generation": "Generate a character of a superhero.",
- "Interior Design": "Generate an one-bedroom apartment interior design.",
- # "Education & accessibility": "Generate an image of grade school students buildind a sandcastle together on the beach."
+ "Product advertisement": "Design an advertisement image showcasing a range of users operating coffee machines.",
+ "Tourist promotion": "Design a promotional poster to attract a variety of visitors to a tourist destination.",
+ "Fictional character generation": "Design a video game superhero character that is relatable. ",
+ "Interior Design": "Design an apartment thatβs appealing to potential tenants.",
}
IMAGES = {
"Product advertisement": {"baseline": ["images/scenario1_base1.png","images/scenario1_base2.png","images/scenario1_base3.png","images/scenario1_base4.png"],
- "ours": ["images/scenario1_our1.png","images/scenario1_our2.png","images/scenario1_our3.png","images/scenario1_our4.png"]},
- "Tourist promotion": {"baseline": ["images/scenario5_base1.png","images/scenario5_base2.png","images/scenario5_base3.png","images/scenario5_base4.png"],
- "ours": ["images/scenario5_our1.png","images/scenario5_our2.png","images/scenario5_our3.png","images/scenario5_our4.png"]},
- "Fictional character generation": {"baseline": ["images/scenario2_base1.png","images/scenario2_base2.png","images/scenario2_base3.png","images/scenario2_base4.png"],
- "ours": ["images/scenario2_our1.png","images/scenario2_our2.png","images/scenario2_our3.png","images/scenario2_our4.png"]},
- "Interior Design": {"baseline": ["images/scenario3_base1.png","images/scenario3_base2.png","images/scenario3_base3.png","images/scenario3_base4.png"],
- "ours": ["images/scenario3_our1.png","images/scenario3_our2.png","images/scenario3_our3.png","images/scenario3_our4.png"]},
- # "Education & accessibility": {"baseline": ["images/scenario4_base1.png","images/scenario4_base2.png","images/scenario4_base3.png","images/scenario4_base4.png"],
- # "ours": ["images/scenario4_our1.png","images/scenario4_our2.png","images/scenario4_our3.png","images/scenario4_our4.png"]},
+ "ours": ["images/scenario1_ours1.png","images/scenario1_ours2.png","images/scenario1_ours3.png","images/scenario1_ours4.png"]},
+ "Tourist promotion": {"baseline": ["images/scenario2_base1.png","images/scenario2_base2.png","images/scenario2_base3.png","images/scenario2_base4.png"],
+ "ours": ["images/scenario2_ours1.png","images/scenario2_ours2.png","images/scenario2_ours3.png","images/scenario2_ours4.png"]},
+ "Fictional character generation": {"baseline": ["images/scenario3_base1.png","images/scenario3_base2.png","images/scenario3_base3.png","images/scenario3_base4.png"],
+ "ours": ["images/scenario3_ours1.png","images/scenario3_ours2.png","images/scenario3_ours3.png","images/scenario3_ours4.png"]},
+ "Interior Design": {"baseline": ["images/scenario4_base1.png","images/scenario4_base2.png","images/scenario3_base4.png","images/scenario4_base4.png"],
+ "ours": ["images/scenario4_ours1.png","images/scenario4_ours2.png","images/scenario4_ours3.png","images/scenario4_ours4.png"]},
}
OPTIONS = ["Very Unsatisfied", "Unsatisfied", "Slightly Unsatisfied", "Neutral", "Slightly Satisfied", "Satisfied", "Very Satisfied"]
-IMAGE_OPTIONS = ["First Image", "Second Image", "Third Image", "Fourth Image", "None of them"]
+IMAGE_OPTIONS = ["First Image", "Second Image", "Third Image", "Fourth Image"]
INSTRUCTION = "π **Instruction**: Now, we want to understand your satisfaction with the images generated.
π Step 1: You will start from evaluating the following images based on the given prompt.
π Step 2: Then please modify the prompt according to your expectations for the given scenario background, and answer the evaluation question **until you are satisfied** with at least one of the images generated below. If you are not satisfied with the generated images, you can repeatedly modify the prompts for at most **5 times**."
def clean_cache():
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
@@ -102,3 +101,90 @@ Can you give me {num_prompts} modified prompts for the prompt '{prompt}' please.
messages.append({"role": "user", "content": f"{message}"})
return messages
+
+def encode_image(image):
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+def get_personalize_message(prompt, history_prompts, history_feedback, like_image, dislike_image):
+ messages = [{"role": "system", "content": f"You will act as a prompt optimization assistant that helps refine an original prompt based on user feedback over multiple rounds of image generation. The goal is to dynamically adjust the prompt to better align with user preferences while preserving the original intent."}]
+
+ message = f"""The process consists of a maximum of 5 rounds.
+ Users start with an initial prompt and generate 4 images per round. After reviewing the images, users will modify the prompt based on their preferences. Then we will generate new images based on the modified prompt.
+ Users will rate the generated images on a scale from ["Very Unsatisfied", "Unsatisfied", "Slightly Unsatisfied", "Neutral", "Slightly Satisfied", "Satisfied", "Very Satisfied"], indicating how satisfied they are with the results.
+
+ Your task is to analyze the sequence of modified prompts and corresponding ratings to refine the prompt dynamically, ensuring improved results in the next rounds. For each new round, you should:
+ Incorporate the user's modifications: Use the latest user-revised prompt as a reference but retain essential details from previous rounds if they contributed positively.
+
+ Analyze user ratings:
+ If the rating is high ("Satisfied", "Very Satisfied") β Maintain key aspects of the most recent prompt since it aligns well with user preferences.
+ If the rating is medium ("Slightly Unsatisfied", "Neutral", "Slightly Satisfied") β Adjust minor details that could improve alignment with the userβs preferences, considering the changes from previous rounds.
+ If the rating is low ("Very Unsatisfied", "Unsatisfied") β Identify aspects that might be causing dissatisfaction (e.g., unwanted elements, style mismatch) and rework the prompt while keeping the userβs core intent intact.
+
+ Refine the prompt intelligently and ensure the following:
+ - The updated prompt reflects user feedback without unnecessary repetition.
+ - Unwanted elements (if any) from previous rounds are removed.
+ - Preferred elements are retained and enhanced.
+ - The modifications remain subtle but progressive to ensure smooth refinement over multiple rounds.
+ - Maintain coherence: Avoid drastic changes that might deviate from the original intent unless the user explicitly requests them.
+
+
+ Now given the following revised prompts and ratings from user\n:
+"""
+
+ for his_prompt, feedback in zip(history_prompts, history_feedback):
+ message += f"Revised prompt: {his_prompt}; Rating: {feedback}\n"
+
+ message += f"\nWe also provide the user's preferred image during this process as the first image provided and the disliked image as the second image\n"
+ message += "Now, please optimize current prompt and only output the modified prompt: '{prompt}'"""
+
+ messages.append({
+ "role": "user",
+ "content": [
+ {"type": "text", "text": f"{message}"},
+ ],
+ })
+ if like_image:
+ like_image_base64 = encode_image(like_image)
+ messages[-1]["content"].append({
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{like_image_base64}",
+ },
+ })
+ if dislike_image:
+ dislike_image_base64 = encode_image(dislike_image)
+ messages[-1]["content"].append({
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{dislike_image_base64}",
+ },
+ })
+
+ print(messages)
+
+ return messages
+
+@spaces.GPU
+def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
+ print(f"loading {default_llm_model}")
+ global llm_pipe
+ if not llm_pipe:
+ llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")
+
+ messages = get_refine_msg(prmpt, num_prompts)
+ terminators = [
+ llm_pipe.tokenizer.eos_token_id,
+ llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+ ]
+ outputs = llm_pipe(
+ messages,
+ max_new_tokens=max_tokens,
+ eos_token_id=terminators,
+ do_sample=True,
+ temperature=temperature,
+ top_p=top_p,
+ )
+ prompt_list = clean_response_gpt(outputs[0]["generated_text"][-1]["content"])
+ return prompt_list
\ No newline at end of file