# music-to-outfit / app.py
# Last commit a56420e by fffiloni (verified): "Update app.py"
import gradio as gr
import spaces
import json
import re
import random
import numpy as np
from gradio_client import Client, handle_file
# Largest signed 32-bit integer (2**31 - 1), usable as an upper bound for seeds.
MAX_SEED = np.iinfo("int32").max
import re
import torch
from transformers import pipeline
# Chat model shared by get_outfit_prompt and get_parsed_outfit_items.
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
# Disabled alternative: mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Loaded once at import time in bfloat16; device_map="auto" leaves weight
# placement to transformers/accelerate.
pipe = pipeline(
    "text-generation",
    model=zephyr_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# System prompt for get_outfit_prompt: few-shot instructions that turn an audio
# description into a single quoted outfit sentence.
# Both examples follow the mandated "A person dressed in [...]." format so the
# model is not shown a contradicting pattern ("dressed with").
# Plain (non-f) string: it has no placeholders.
standard_sys = """
You are an AI Art Director that specializes in translating music and audio descriptions into visually expressive fashion outfit ideas.
Your task:
- Given a description of a piece of music or sound, generate a **single outfit suggestion** that captures the mood, tempo, and emotional tone of the audio.
- Be specific. Mention the type of clothing, colors, materials, accessories, and any stylistic flourishes.
- The response must be friendly but concise (max 1-2 sentences), directly delivering the outfit description.
- **Only return the outfit in the following exact format**, within double quotes:
"A person dressed in [...]."
Do not include any explanations or extra commentary.
Examples:
Input:
"This song features a female vocalist singing a beautiful and emotional melody. The melody is accompanied by the sound of a piano playing a slow and melancholic tune. The song has a dreamy and ethereal feel to it. The lyrics of the song are about the beauty of love and the joy it brings to one's life."
Output:
"A person dressed in a flowy, pastel-colored dress paired with strappy sandals and a wide-brimmed hat, accessorized with delicate jewelry, such as dainty earrings and a necklace."
Input:
"A hard-hitting techno track with industrial beats, glitchy textures, and a driving, relentless rhythm."
Output:
"A person dressed in a black leather jacket over a mesh top, paired with chunky combat boots and silver accessories, with bold eyeliner completing the edgy, cyberpunk look."
Always output in this format and stop immediately.
"""
@spaces.GPU
def get_outfit_prompt(user_prompt):
    """Turn an audio description into one outfit suggestion with Zephyr.

    Args:
        user_prompt: Free-text description of the audio (in ``infer`` this is
            the SALMONN answer).

    Returns:
        The model reply with the echoed prompt removed and leading newlines
        stripped. Per ``standard_sys`` it should look like
        ``"A person dressed in [...]."`` (including the quotes).
    """
    # Zephyr chat layout. The trailing <|assistant|> tag cues the reply and is
    # also the end marker the cleanup regex below depends on: without it, the
    # system/user text leaks into the result whenever the model does not emit
    # the tag on its own.
    prompt = (
        f"<|system|>\n{standard_sys}</s>\n"
        f"<|user|>\n{user_prompt}</s>\n"
        "<|assistant|>\n"
    )
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    # The text-generation pipeline returns the prompt plus the completion by
    # default (return_full_text=True); drop everything from <|system|> up to
    # the first <|assistant|> tag, leaving only the reply.
    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
    cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
    print(f"SUGGESTED Musical prompt: {cleaned_text}")
    return cleaned_text.lstrip("\n")
def get_salmonn(audio_in, prompt, token):
    """Ask the SALMONN-7B Gradio Space a question about an audio file.

    Args:
        audio_in: Path to the audio file (the UI's ``gr.Audio`` uses
            ``type="filepath"``).
        prompt: Instruction sent as the endpoint's ``text_input``.
        token: Hugging Face token handed to the Gradio client.

    Returns:
        The raw result of the ``/gradio_answer_1`` endpoint, which is also
        printed.
    """
    salmonn = Client("fffiloni/SALMONN-7B-gradio", hf_token=token)
    request = {
        "speech": handle_file(audio_in),
        "text_input": prompt,
        "num_beams": 4,
        "temperature": 1,
        "top_p": 0.9,
        "api_name": "/gradio_answer_1",
    }
    answer = salmonn.predict(**request)
    print(answer)
    return answer
def sdxl_image(suggested_outfit_prompt, token):
    """Render an image for a text prompt through the SDXL-Lightning Space.

    Args:
        suggested_outfit_prompt: Text prompt to render.
        token: Hugging Face token handed to the Gradio client.

    Returns:
        The raw result of the ``/generate_image`` endpoint (4-step
        checkpoint), which is also printed.
    """
    lightning = Client("ByteDance/SDXL-Lightning", hf_token=token)
    request = {
        "prompt": suggested_outfit_prompt,
        "ckpt": "4-Step",
        "api_name": "/generate_image",
    }
    image = lightning.predict(**request)
    print(image)
    return image
def extract_json(text):
    """
    Extract the first JSON object found in a string and parse it.

    The whole string is tried first. If that is not a JSON object, the text is
    scanned from each ``{`` and the first position that decodes to a JSON
    object wins. Surrounding prose, markdown fences, or trailing extra objects
    are therefore tolerated.

    Args:
        text: LLM output (``str``; ``bytes`` is decoded as UTF-8).

    Returns:
        The parsed ``dict``, or ``{}`` if the input is empty or contains no
        decodable JSON object. Top-level arrays or scalars are not returned,
        because callers iterate the result with ``.items()``.
    """
    if isinstance(text, (bytes, bytearray)):
        text = text.decode("utf-8", errors="replace")
    if not text:
        return {}

    # Fast path: the whole reply is valid JSON.
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Fallback: raw_decode parses one value starting at an index and ignores
    # whatever follows it. Unlike a greedy "{...}" regex, this does not break
    # when the reply contains several objects or stray braces after the JSON.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass  # not a valid object at this brace; try the next one
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)

    print("⚠️ JSON decode failed: no valid JSON object found in text")
    return {}
@spaces.GPU
def get_parsed_outfit_items(outfit_sentence):
    """Break an outfit sentence into per-item product-image prompts.

    Args:
        outfit_sentence: One outfit description, for example the output of
            ``get_outfit_prompt``.

    Returns:
        The JSON object parsed from the model reply by ``extract_json``. The
        system prompt asks for keys such as ``"shoes"`` or ``"hat"`` mapped to
        prompts, but the actual keys are whatever the model produced. Returns
        ``{}`` when no JSON object can be parsed.
    """
    parser_sys = """
You are a fashion assistant AI that helps e-commerce designers turn full outfit descriptions into individual product image prompts.
Your task:
- Given an outfit description (1 sentence), break it into key labeled parts: dress, top, bottom, shoes, outerwear, jewelry, hat, accessories.
- Write one short, specific image-generation prompt per part.
- Focus on describing each item visually and clearly as it would appear in a product photo.
- Respond only in raw JSON like this:
{
"shoes": "High-quality product image of brown leather boots, white background",
"hat": "Studio photo of a navy beret on a stand, isolated on white"
}
Respond only with a valid JSON object. Ensure the JSON is properly formatted with correct commas between fields.
Do not forget commas between entries. Validate before finishing your response.
Do not include any explanations or markdown syntax. No commentary. No extra text.
Start directly with `{` and end with `}`.
"""
    # get_outfit_prompt's system prompt makes the model wrap its answer in
    # double quotes; strip them so quoting the sentence below does not double
    # them up.
    sentence = str(outfit_sentence or "").strip().strip('"')
    # Zephyr chat layout. The trailing <|assistant|> tag cues the reply and is
    # the end marker the cleanup regex relies on. Without it, the system
    # prompt (which contains an example JSON object) can leak into
    # cleaned_text and confuse the JSON extraction.
    prompt = (
        f"<|system|>\n{parser_sys}</s>\n"
        f"<|user|>\n\"{sentence}\"</s>\n"
        "<|assistant|>\n"
    )
    # NOTE(review): temperature/top_k/top_p only take effect when sampling is
    # enabled; this call does not pass do_sample=True (unlike
    # get_outfit_prompt). Confirm whether greedy decoding is intended here.
    outputs = pipe(prompt, max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.9)
    # Remove the echoed prompt: everything from <|system|> to the first
    # <|assistant|> tag.
    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
    cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
    print(f"\n🧾 Raw LLM response:\n{cleaned_text}")
    item_dict = extract_json(cleaned_text)
    print(f"\n🧩 Parsed outfit parts:\n{json.dumps(item_dict, indent=2)}")
    return item_dict
def generate_sdxl_images_list_dynamic(item_prompts, hf_token):
    """Render one SDXL image per outfit part, in the mapping's iteration order.

    Args:
        item_prompts: Mapping of part label to image prompt (e.g. the dict
            returned by ``get_parsed_outfit_items``).
        hf_token: Hugging Face token forwarded to ``sdxl_image``.

    Returns:
        A list of ``(image_result, part_label)`` tuples, presumably meant as
        (image, caption) pairs for a ``gr.Gallery``.
    """
    def _render(label, item_prompt):
        # Progress log goes out before each (slow) remote render call.
        print(f"Generating image for {label}...")
        return sdxl_image(item_prompt, hf_token), label

    return [_render(label, item_prompt) for label, item_prompt in item_prompts.items()]
def infer(audio_in, oauth_token: gr.OAuthToken):
    """Run audio -> description -> outfit sentence -> image, streaming progress.

    Gradio injects ``oauth_token`` from the login button. Each yield is a
    ``(outfit_sentence, salmonn_description, outfit_image)`` tuple for the
    caption, SALMONN textbox and image outputs. Fields that are not computed
    yet are ``None``.
    """
    hf_token = oauth_token.token

    gr.Info("Calling SALMONN to understand audio...")
    audio_description = get_salmonn(audio_in, "Please describe the audio in detail.", hf_token)
    yield None, audio_description, None

    gr.Info("Creating an outfit suggestion based on audio understanding...")
    suggestion = get_outfit_prompt(audio_description)
    yield suggestion, audio_description, None

    gr.Info("Generate an image with SDXL Lightning...")
    rendered_outfit = sdxl_image(suggestion, hf_token)

    # Disabled per-item shopping gallery (kept for reference):
    # gr.Info("Get outfit parts...")
    # item_prompts = get_parsed_outfit_items(suggestion)
    # gr.Info("Generate shopping gallery...")
    # images_with_labels = generate_sdxl_images_list_dynamic(item_prompts, hf_token)

    yield suggestion, audio_description, rendered_outfit
# Page title and subtitle, interpolated into the header HTML of the Blocks UI.
demo_title = "Music to Outfit"
description = "Get an outfit idea from audio/music input"
# Custom CSS passed to gr.Blocks(css=...):
# - #col-container: the main column (elem_id="col-container"), centered with
#   auto margins and capped at 980px wide.
# - #inspi-prompt textarea: larger, bold text in the outfit-prompt textbox
#   (elem_id="inspi-prompt").
# - div#component-11 ...: targets an auto-generated Gradio component id.
#   NOTE(review): this UI defines no examples gallery, so this rule may match
#   an unrelated component or nothing at all. Confirm before relying on it.
css = """
#col-container {
margin: 0 auto;
max-width: 980px;
text-align: left;
}
#inspi-prompt textarea {
font-size: 20px;
line-height: 24px;
font-weight: 600;
}
/* fix examples gallery width on mobile */
div#component-11 > .gallery > .gallery-item > .container > img {
width: auto!important;
}
"""
# Gradio UI: login + audio input on the left; the outfit prompt and generated
# image on the right. The SALMONN description is shown under the submit button.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(f"""
<h2 style="text-align: center;">{demo_title}</h2>
<p style="text-align: center;">{description}</p>
""")
        with gr.Row():
            with gr.Column():
                # Supplies the gr.OAuthToken that infer() forwards to the remote Spaces.
                gr.LoginButton()
                audio_in = gr.Audio(
                    label="Audio reference",
                    type="filepath",
                    elem_id="audio-in",
                )
                submit_btn = gr.Button("Make an outfit from my sound !")
                salmonn_desc = gr.Textbox(label="Salmonn audio understanding")
            with gr.Column():
                caption = gr.Textbox(
                    label="Inspirational outfit prompt",
                    interactive=False,
                    elem_id="inspi-prompt",
                )
                result = gr.Image(
                    label="Outfit proposal",
                )
                # clothes_gallery = gr.Gallery()  # disabled shopping gallery

    # Output order must match the tuples yielded by infer():
    # (outfit sentence, SALMONN description, image).
    submit_btn.click(
        fn=infer,
        inputs=[
            audio_in,
        ],
        outputs=[
            caption,
            salmonn_desc,
            result,
            # clothes_gallery,
        ],
    )

demo.queue().launch(show_api=False, show_error=True, ssr_mode=False)