POET / utils.py
xh365's picture
update verbal message
d751ed8
raw
history blame
12.1 kB
import re
from diffusers import DiffusionPipeline, FluxPipeline
from live_preview_helpers import FLUXPipelineWithIntermediateOutputs
import torch
import os
from openai import OpenAI
import subprocess
import spaces #[uncomment to use ZeroGPU]
import base64
from io import BytesIO
# Display name -> Hugging Face repo id of each selectable text-to-image model.
T2I_MODELS = {
    "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
    "SDXL-Turbo": "stabilityai/sdxl-turbo",
    "Stable Diffusion v3.5-medium": "stabilityai/stable-diffusion-3.5-medium",  # Default
    "Flux.1-dev": "black-forest-labs/FLUX.1-dev",
}

# Scenario name -> background story shown to the participant.
# NOTE: the insertion order of this dict also defines the scenarioN numbering used by IMAGES.
SCENARIOS = {
    "Product advertisement": "You are designing an advertising campaign for a new line of coffee machines. To ensure the campaign resonates with a wider audience, you use generative models to create marketing images that showcase a variety of users interacting with the product.",
    "Tourist promotion": "You are creating a travel campaign to attract a variety of visitors to a specific destination. To make the promotional materials more engaging, you use generative models to design posters that highlight a broader array of experiences.",
    "Fictional character generation": "You are creating a superhero video game that’s fun and relatable to a range of users. You decide to use generative models to help visualize a new character.",
    "Interior Design": "You are helping design the furniture layout for a model one-bedroom rental apartment. To make the apartment appealing to different potential tenants, you try to visualize different furniture placements before setting everything up.",
}

# Scenario name -> initial text-to-image prompt for that scenario.
PROMPTS = {
    "Product advertisement": "Design an advertisement image showcasing a range of users operating coffee machines.",
    "Tourist promotion": "Design a promotional poster to attract a variety of visitors to a tourist destination.",
    "Fictional character generation": "Design a video game superhero character that is relatable. ",
    "Interior Design": "Design an apartment that’s appealing to potential tenants.",
}

# Scenario name -> {"baseline": [...], "ours": [...]}, four image paths per variant,
# following the naming scheme images/scenario<N>_<base|ours><1..4>.png, where N is
# the scenario's 1-based position in SCENARIOS.
IMAGES = {
    scenario_name: {
        variant: [f"images/scenario{scenario_no}_{file_tag}{image_no}.png" for image_no in range(1, 5)]
        for variant, file_tag in (("baseline", "base"), ("ours", "ours"))
    }
    for scenario_no, scenario_name in enumerate(SCENARIOS, start=1)
}

# Seven-point satisfaction scale, from worst to best.
OPTIONS = [
    "Very Unsatisfied",
    "Unsatisfied",
    "Slightly Unsatisfied",
    "Neutral",
    "Slightly Satisfied",
    "Satisfied",
    "Very Satisfied",
]

# Labels for picking one of the four generated images.
IMAGE_OPTIONS = ["First Image", "Second Image", "Third Image", "Fourth Image"]

# Markdown/HTML instruction text shown to the participant (adjacent literals are concatenated).
INSTRUCTION = (
    "📌 **Instruction**: Now, we want to understand your satisfaction with the images generated. <br /> "
    "📌 Step 1: You will start from evaluating the following images based on the given prompt. <br /> "
    "📌 Step 2: Then please modify the prompt according to your expectations for the given scenario background, "
    "and answer the evaluation question **until you are satisfied** with at least one of the images generated below. "
    "If you are not satisfied with the generated images, you can repeatedly modify the prompts for at most **5 times**."
)
def clean_cache(offload_dir="/data-nvme/zerogpu-offload"):
    """Best-effort wipe of the ZeroGPU offload directory, then free cached CUDA memory.

    Removes every non-hidden entry directly inside ``offload_dir`` (the same set
    the former shell command ``rm -rf <dir>/*`` targeted) without spawning a shell.
    Failures are ignored, just as the unchecked ``rm -rf`` exit status was.
    Symlinks are unlinked, never followed, so their targets survive.

    Args:
        offload_dir: Directory whose contents are deleted. A missing directory is
            a no-op.
    """
    import glob
    import shutil

    # Escape the directory so glob metacharacters in the path are taken literally;
    # like the shell's `*`, glob.glob skips dot-files.
    for entry in glob.glob(os.path.join(glob.escape(offload_dir), "*")):
        if os.path.isdir(entry) and not os.path.islink(entry):
            # ignore_errors keeps deleting what it can, mirroring `rm -rf`.
            shutil.rmtree(entry, ignore_errors=True)
        else:
            try:
                os.remove(entry)
            except OSError:
                # Deliberate best-effort: leftover offload files are harmless.
                pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def setup_model(t2i_model_repo, torch_dtype, device):
    """Load the text-to-image pipeline for ``t2i_model_repo`` and move it to ``device``.

    Args:
        t2i_model_repo: Hugging Face repo id. Must be one of the Stable Diffusion
            repos listed below or ``"black-forest-labs/FLUX.1-dev"``.
        torch_dtype: dtype forwarded to ``from_pretrained``.
        device: target passed to the pipeline's ``.to()``.

    Returns:
        The loaded pipeline.

    Raises:
        ValueError: if ``t2i_model_repo`` is not a supported repo id. Previously
            this case crashed with an ``UnboundLocalError`` on ``pipe``.
    """
    diffusion_repos = {
        "stabilityai/sdxl-turbo",
        "stabilityai/stable-diffusion-3.5-medium",
        "stabilityai/stable-diffusion-2-1",
    }
    if t2i_model_repo in diffusion_repos:
        pipeline_cls = DiffusionPipeline
    elif t2i_model_repo == "black-forest-labs/FLUX.1-dev":
        # Subclass from live_preview_helpers; presumably it exposes intermediate
        # denoising outputs for live previews — confirm in that module.
        pipeline_cls = FLUXPipelineWithIntermediateOutputs
    else:
        raise ValueError(f"Unsupported text-to-image model repo: {t2i_model_repo!r}")
    pipe = pipeline_cls.from_pretrained(t2i_model_repo, torch_dtype=torch_dtype).to(device)
    torch.cuda.empty_cache()
    return pipe
def init_gpt_api():
    """Create an OpenAI client using the key in the ``OPENAI_API_KEY`` environment variable.

    If the variable is unset, ``None`` is passed as ``api_key`` and the client's
    own handling of a missing key applies.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    client = OpenAI(api_key=api_key)
    return client
def call_gpt_api(messages, client, model, seed, max_tokens, temperature, top_p):
    """Send one chat-completion request and return the text of the first choice.

    Args:
        messages: Chat messages forwarded unchanged to the API.
        client: Object exposing ``chat.completions.create`` (e.g. from ``init_gpt_api``).
        model, seed, max_tokens, temperature, top_p: Forwarded as the
            same-named keyword arguments.

    Returns:
        ``choices[0].message.content`` of the response.
    """
    request = dict(
        model=model,
        messages=messages,
        seed=seed,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    response = client.chat.completions.create(**request)
    first_choice = response.choices[0]
    return first_choice.message.content
def _strip_wrapping_quotes(text: str) -> str:
    """Trim whitespace and one layer of quotes that wrap a numbered list item."""
    text = text.strip()
    # Single quotes only count as a matched pair, so a trailing apostrophe
    # (e.g. "students'") is kept.
    if len(text) >= 2 and text[0] == text[-1] == "'":
        return text[1:-1].strip()
    # Double quotes are dropped independently at each end, as before.
    if text.startswith('"'):
        text = text[1:]
    if text.endswith('"'):
        text = text[:-1]
    return text.strip()


def clean_response_gpt(res: str):
    """Extract the items of a numbered list (``1. ...``, ``2. ...``) from an LLM reply.

    Each item runs from after ``<number>.<whitespace>`` to the end of its line.
    LF and CRLF line endings are both accepted. Surrounding whitespace is
    removed, as are wrapping double quotes or a matched pair of single quotes.
    The single-quote case matters because the refinement request in this module
    shows its examples in single quotes.

    Returns:
        list[str]: the cleaned items in order of appearance, or ``[]`` if none.
    """
    items = re.findall(r'\d+\.\s(.*?)(?=\r?\n|$)', res)
    return [_strip_wrapping_quotes(item) for item in items]
def clean_refined_prompt_response_gpt(res: str):
# Using regex to extract the refined prompt
match = re.search(r"\*\*Refined Prompt:\*\*\n\n(.+)", res, re.DOTALL)
if match:
refined_prompt = match.group(1).strip()
else:
refined_prompt = res.strip() # Fallback: Use full text if no match found
return refined_prompt
def get_refine_msg(prompt, num_prompts):
    """Build the chat messages asking an LLM for ``num_prompts`` variations of ``prompt``.

    Returns:
        A two-element list:
        - a system message fixing the assistant's role and the number of prompts;
        - a user message with the refinement rules, a worked example and the
          prompt to refine.
    """
    system_msg = {
        "role": "system",
        "content": f"You are a helpful, respectful and precise assistant. You will be asked to generate {num_prompts} refined prompts. Only respond with those refined prompts",
    }
    instructions = f"""Given a prompt, modify the prompt for me to explore variations in subject attributes, actions, and contextual details, while retaining the semantic consistency of the original description.
Follow the following refinement instruction:
1. Subject: refine broad terms into specific subsets, focusing on but not restricted on ethinity, gender, age of human.
2. Object: modify the brand, color of object(s) only if it's not specified in the prompt.
3. Setting: add details to the background environment, such as change of temporal or spatial details (e.g., day to night, indoor to outdoor).
4. Action: add more details to the action or specify the object or goal of the action.
For example, given this prompt: a person is drinking a coffee in a coffee shop, the refined prompts could be:
'an elderly woman is drinking a coffee in a coffee shop' (subject adjective)
'an asian young woman is drinking a coffee in a coffee shop' (subject adjective)
'a young woman is drinking a hot coffee with her left hand in a coffee shop' (action details)
'a woman is drinking a coffee in an outdoor coffee shop in the garden' (setting details)
If there is no human in the sentence, you do not need to add person intentionally.
If you use adjectives, they should be visual. So don't use something like 'interesting'. Please also vary the number of modifications but do not change the number of subjects/objects that have been specified in the prompt. Remember not to change the predefined concepts that have been specified in the prompt. e.g. don't change a boy to several boys.
Can you give me {num_prompts} modified prompts for the prompt '{prompt}' please."""
    user_msg = {"role": "user", "content": instructions}
    return [system_msg, user_msg]
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as base64 text (UTF-8 decoded)."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def get_personalize_message(prompt, history_prompts, history_feedback, like_image, dislike_image):
    """Build a multimodal chat request asking the LLM to personalise ``prompt``.

    Args:
        prompt: The user's current prompt, quoted in the request.
        history_prompts: Earlier prompts. Paired with ``history_feedback`` via
            ``zip``, so extra entries in the longer list are ignored.
        history_feedback: Rating given to each earlier prompt.
        like_image: Path of the preferred image, or a falsy value to omit it.
        dislike_image: Path of the disliked image, or a falsy value to omit it.

    Returns:
        A two-element list: a system message, and a user message whose
        ``content`` is a list with these parts, in order:
        - the instruction text;
        - a clarifying note, only when exactly one image is given;
        - the preferred image, if given;
        - the disliked image, if given.
        Images are embedded as ``data:image/png;base64`` URLs whatever their
        actual file type.
    """
    system_msg = {"role": "system", "content": "You are a prompt refinement assistant. Your task is to improve a user’s text prompt based on their prompt revision history, satisfaction ratings, and preferences inferred from selected images. Your goal is to refine the prompt while maintaining its original meaning, improving clarity, specificity, and alignment with user preferences."}
    message = """The refinement should preserve the core meaning of the current prompt while improving its clarity, specificity, and style based on user feedback.
### **Input Format:**
1. **Prompt History**: A list of previously revised prompts and their corresponding satisfaction ratings.
2. **Rating Scale**: Very Unsatisfied, Unsatisfied, Slightly Unsatisfied, Neutral, Slightly Satisfied, Satisfied, Very Satisfied
3. **User-Selected Image Preferences**:
- **Preferred Image**: The image the user found most satisfactory.
- **Disliked Image**: The image the user found least satisfactory.
*Note: These images are for reference only and should be used to infer stylistic preferences rather than directly modifying prompt content.*
4. **Current Prompt**: The latest prompt from the user, which requires refinement.
### **Refinement Guidelines:**
- Identify and retain/expand patterns/elements in past revisions and correlate them with satisfaction ratings.
- You may expand current prompt in details and incorporate information from retained pattens in past revisions.
- Avoid or adjust features that led to lower ratings.
- Improve clarity, specificity, and descriptive quality while ensuring the prompt remains faithful to its current prompt's meaning.
- The preferred image reflects desirable attributes; the disliked image indicates elements to avoid. Use these for reference but **do not describe them.**
- Output only the refined prompt, no explanations, disclaimers, or formatting.
The first provided image is the user's preferred image, and the second is the disliked image.
Now, refine the following current prompt based on the given user history and preferences:\n"""
    history = "".join(
        f"{his_prompt}: {feedback}\n"
        for his_prompt, feedback in zip(history_prompts, history_feedback)
    )
    message += "Prompt History\n" + history + f"Current Prompt: '{prompt}'\n Refined Prompt:"

    content = [{"type": "text", "text": message}]

    # The instructions above identify images by position (first = preferred,
    # second = disliked). With only one image attached that convention breaks:
    # a lone disliked image would be read as the preferred one. Say explicitly
    # which image is present.
    if dislike_image and not like_image:
        content.append({"type": "text", "text": "Note: no preferred image was provided, so the only attached image is the disliked image."})
    elif like_image and not dislike_image:
        content.append({"type": "text", "text": "Note: no disliked image was provided, so the only attached image is the preferred image."})

    # Preferred image first, then disliked, matching the positional convention.
    for image_path in (like_image, dislike_image):
        if image_path:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{encode_image(image_path)}",
                },
            })

    return [system_msg, {"role": "user", "content": content}]
@spaces.GPU
def call_llm_refine_prompt(prompt, num_prompts=5, max_tokens=1000, temperature=0.7, top_p=0.9):
    """Generate ``num_prompts`` refined variants of ``prompt`` with a local chat LLM.

    The ``transformers`` text-generation pipeline is built on first use and cached
    in the module-level ``llm_pipe`` for later calls.

    Args:
        prompt: Prompt to refine.
        num_prompts: Number of variants requested from the model.
        max_tokens: Passed as ``max_new_tokens``.
        temperature, top_p: Sampling parameters (sampling is always enabled).

    Returns:
        list[str]: Items of the numbered list parsed from the reply by
        ``clean_response_gpt``.
    """
    # NOTE(review): `transformers`, `default_llm_model`, `torch_dtype` and the
    # initial value of `llm_pipe` are not defined in the visible part of this
    # module — confirm they are imported/defined elsewhere.
    global llm_pipe
    if not llm_pipe:
        # Print only when a load actually happens, not on every cached call.
        print(f"loading {default_llm_model}")
        llm_pipe = transformers.pipeline("text-generation", model=default_llm_model, model_kwargs={"torch_dtype": torch_dtype}, device_map="auto")
    # Fixed: this previously passed the undefined name `prmpt`, raising NameError.
    messages = get_refine_msg(prompt, num_prompts)
    # Stop on the tokenizer's EOS or "<|eot_id|>" (presumably the Llama-3
    # end-of-turn token — confirm for the configured model).
    terminators = [
        llm_pipe.tokenizer.eos_token_id,
        llm_pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = llm_pipe(
        messages,
        max_new_tokens=max_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # With chat-format input, "generated_text" holds the message list; its last
    # entry is read as the model's reply.
    reply = outputs[0]["generated_text"][-1]["content"]
    return clean_response_gpt(reply)