"""Gradio app that turns an audio clip into an outfit idea and an illustration.

Pipeline, as wired in ``infer``:

1. A remote SALMONN Space describes the audio in text.
2. A local Zephyr-7B text-generation pipeline turns that description into a
   one-sentence outfit suggestion.
3. A remote SDXL-Lightning Space renders the outfit as an image.
"""

import json
import random
import re

import gradio as gr
import numpy as np
import spaces
import torch
from gradio_client import Client, handle_file
from transformers import pipeline

MAX_SEED = np.iinfo(np.int32).max

zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
# mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Loaded once at import time and shared by every request.
pipe = pipeline(
    "text-generation",
    model=zephyr_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# System prompt for the "audio description -> outfit sentence" step.
standard_sys = """
You are an AI Art Director that specializes in translating music and audio descriptions into visually expressive fashion outfit ideas.

Your task:
- Given a description of a piece of music or sound, generate a **single outfit suggestion** that captures the mood, tempo, and emotional tone of the audio.
- Be specific. Mention the type of clothing, colors, materials, accessories, and any stylistic flourishes.
- The response must be friendly but concise (max 1-2 sentences), directly delivering the outfit description.
- **Only return the outfit in the following exact format**, within double quotes:
"A person dressed in [...]."

Do not include any explanations or extra commentary.

Examples:

Input:
"This song features a female vocalist singing a beautiful and emotional melody. The melody is accompanied by the sound of a piano playing a slow and melancholic tune. The song has a dreamy and ethereal feel to it. The lyrics of the song are about the beauty of love and the joy it brings to one's life."
Output:
"A person dressed with a flowy, pastel-colored dress paired with strappy sandals and a wide-brimmed hat, accessorized with delicate jewelry, such as dainty earrings and a necklace."

Input:
"A hard-hitting techno track with industrial beats, glitchy textures, and a driving, relentless rhythm."
Output:
"A person dressed in a black leather jacket over a mesh top, paired with chunky combat boots and silver accessories, with bold eyeliner completing the edgy, cyberpunk look."

Always output in this format and stop immediately.
"""


def _chat_completion(system_prompt, user_message, **generation_kwargs):
    """Run one system + user turn through ``pipe`` and return only the reply.

    The prompt is built with the tokenizer's chat template so the assistant
    turn is opened explicitly, and ``return_full_text=False`` makes the
    pipeline return the completion alone. The previous hand-built prompt never
    contained an ``<|assistant|>`` marker, so the regex that stripped
    "everything up to the assistant marker" could leave the whole system
    prompt in the result.

    Args:
        system_prompt: Instructions for the model.
        user_message: The user turn; coerced to ``str`` before templating.
        **generation_kwargs: Forwarded unchanged to the pipeline call.

    Returns:
        The generated reply with surrounding whitespace removed.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": str(user_message)},
    ]
    prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    outputs = pipe(prompt, return_full_text=False, **generation_kwargs)
    return outputs[0]["generated_text"].strip()


@spaces.GPU
def get_outfit_prompt(user_prompt):
    """Turn an audio description into a one-sentence outfit suggestion.

    Args:
        user_prompt: Text describing the audio (in ``infer`` this is the
            SALMONN output).

    Returns:
        The model's outfit sentence, without the prompt echoed back.
    """
    cleaned_text = _chat_completion(
        standard_sys,
        user_prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )
    print(f"SUGGESTED Musical prompt: {cleaned_text}")
    return cleaned_text


def get_salmonn(audio_in, prompt, token):
    """Ask the remote SALMONN Space to answer ``prompt`` about an audio file.

    Args:
        audio_in: Path to the audio file to upload.
        prompt: Text instruction sent alongside the audio.
        token: Hugging Face token used to authenticate the client.

    Returns:
        Whatever the ``/gradio_answer_1`` endpoint returns; ``infer`` treats
        it as the textual description of the audio.
    """
    client = Client("fffiloni/SALMONN-7B-gradio", hf_token=token)
    result = client.predict(
        speech=handle_file(audio_in),
        text_input=prompt,
        num_beams=4,
        temperature=1,
        top_p=0.9,
        api_name="/gradio_answer_1",
    )
    print(result)
    return result


def sdxl_image(suggested_outfit_prompt, token):
    """Generate an image for a prompt with the remote SDXL-Lightning Space.

    Args:
        suggested_outfit_prompt: Text prompt to render.
        token: Hugging Face token used to authenticate the client.

    Returns:
        Whatever the ``/generate_image`` endpoint returns; ``infer`` hands it
        straight to a ``gr.Image`` component.
    """
    client = Client("ByteDance/SDXL-Lightning", hf_token=token)
    result = client.predict(
        prompt=suggested_outfit_prompt,
        ckpt="4-Step",
        api_name="/generate_image",
    )
    print(result)
    return result


def extract_json(text):
    """Extract and parse the first JSON object found in ``text``.

    The whole string is tried first. If that fails, or yields something other
    than an object, every ``{`` is tried in order as the start of an object,
    so commentary before or after the JSON does not break parsing.

    Args:
        text: Raw model output expected to contain a JSON object.

    Returns:
        The parsed ``dict``, or ``{}`` when no valid JSON object is present.
    """
    try:
        # Fast path: the text is nothing but JSON.
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Fallback: decode from each "{" and stop at the first complete object.
    # raw_decode ignores trailing text, unlike a greedy "{.*}" regex which
    # swallows everything up to the last "}" in the string.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)

    print("⚠️ JSON decode failed: no valid JSON object found")
    return {}


@spaces.GPU
def get_parsed_outfit_items(outfit_sentence):
    """Split an outfit sentence into per-item product-image prompts.

    Args:
        outfit_sentence: One-sentence outfit description.

    Returns:
        A dict mapping a part label (e.g. ``"shoes"``) to an image prompt, or
        ``{}`` when the model output contains no parsable JSON object.
    """
    parser_sys = """
You are a fashion assistant AI that helps e-commerce designers turn full outfit descriptions into individual product image prompts.

Your task:
- Given an outfit description (1 sentence), break it into key labeled parts: dress, top, bottom, shoes, outerwear, jewelry, hat, accessories.
- Write one short, specific image-generation prompt per part.
- Focus on describing each item visually and clearly as it would appear in a product photo.
- Respond only in raw JSON like this:

{
  "shoes": "High-quality product image of brown leather boots, white background",
  "hat": "Studio photo of a navy beret on a stand, isolated on white"
}

Respond only with a valid JSON object.
Ensure the JSON is properly formatted with correct commas between fields.
Do not forget commas between entries.
Validate before finishing your response.
Do not include any explanations or markdown syntax.
No commentary. No extra text.
Start directly with `{` and end with `}`.
"""
    # do_sample=True is required for temperature/top_k/top_p to take effect;
    # this matches how get_outfit_prompt calls the pipeline.
    cleaned_text = _chat_completion(
        parser_sys,
        f'"{outfit_sentence}"',
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
    )
    print(f"\n🧾 Raw LLM response:\n{cleaned_text}")

    item_dict = extract_json(cleaned_text)
    print(f"\n🧩 Parsed outfit parts:\n{json.dumps(item_dict, indent=2)}")
    return item_dict


def generate_sdxl_images_list_dynamic(item_prompts, hf_token):
    """Render one SDXL image per outfit part.

    Args:
        item_prompts: Mapping of part label to image prompt.
        hf_token: Hugging Face token forwarded to ``sdxl_image``.

    Returns:
        A list of ``(image_result, part_label)`` tuples, in the mapping's
        iteration order.
    """
    images = []
    for part, prompt in item_prompts.items():
        print(f"Generating image for {part}...")
        images.append((sdxl_image(prompt, hf_token), part))
    return images


def infer(audio_in, oauth_token: gr.OAuthToken):
    """Run the audio -> description -> outfit -> image pipeline.

    This is a generator so the UI can update after each stage.

    Args:
        audio_in: File path of the uploaded audio, or ``None`` if nothing was
            provided.
        oauth_token: Token of the logged-in user, injected by Gradio from the
            type annotation.

    Yields:
        ``(outfit_sentence, salmonn_description, outfit_image)`` tuples, with
        ``None`` for the values that are not available yet.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio_in is None:
        # Fail with a readable message instead of an exception from handle_file.
        raise gr.Error("Please provide an audio file first.")

    gradio_auth_token = oauth_token.token

    salmonn_prompt = "Please describe the audio in detail."
    gr.Info("Calling SALMONN to understand audio...")
    salmonn_res = get_salmonn(audio_in, salmonn_prompt, gradio_auth_token)
    yield None, salmonn_res, None

    gr.Info("Creating an outfit suggestion based on audio understanding...")
    outfit_sentence = get_outfit_prompt(salmonn_res)
    yield outfit_sentence, salmonn_res, None

    gr.Info("Generate an image with SDXL Lightning...")
    outfit_image = sdxl_image(outfit_sentence, gradio_auth_token)

    # Disabled: per-item shopping gallery.
    # gr.Info("Get outfit parts...")
    # item_prompts = get_parsed_outfit_items(outfit_sentence)
    # gr.Info("Generate shopping gallery...")
    # images_with_labels = generate_sdxl_images_list_dynamic(item_prompts, gradio_auth_token)

    yield outfit_sentence, salmonn_res, outfit_image


demo_title = "Music to Outfit"
description = "Get an outfit idea from audio/music input"

css = """
#col-container {
    margin: 0 auto;
    max-width: 980px;
    text-align: left;
}
#inspi-prompt textarea {
    font-size: 20px;
    line-height: 24px;
    font-weight: 600;
}
/* fix examples gallery width on mobile */
div#component-11 > .gallery > .gallery-item > .container > img {
    width: auto!important;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Explicit block elements: bare text separated only by newlines is
        # collapsed onto a single line when rendered as HTML.
        gr.HTML(f"""
<h2>{demo_title}</h2>
<p>{description}</p>
""")
        with gr.Row():
            with gr.Column():
                gr.LoginButton()
                audio_in = gr.Audio(
                    label="Audio reference",
                    type="filepath",
                    elem_id="audio-in",
                )
                submit_btn = gr.Button("Make an outfit from my sound !")
                salmonn_desc = gr.Textbox(label="Salmonn audio undestanding")
            with gr.Column():
                caption = gr.Textbox(
                    label="Inspirational outfit prompt",
                    interactive=False,
                    elem_id="inspi-prompt",
                )
                result = gr.Image(
                    label="Outfit propal",
                )
                # clothes_gallery = gr.Gallery()

    submit_btn.click(
        fn=infer,
        inputs=[
            audio_in,
        ],
        outputs=[
            caption,
            salmonn_desc,
            result,
            # clothes_gallery
        ],
    )

demo.queue().launch(show_api=False, show_error=True, ssr_mode=False)