Spaces:
Running
Running
| import os | |
| import base64 | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import requests | |
| import replicate | |
| from flask import Flask, request | |
| import gradio as gr | |
| import openai | |
| from openai import OpenAI | |
| from dotenv import load_dotenv, find_dotenv | |
| import json | |
# Load variables from a .env file (if find_dotenv locates one) into the process
# environment, then read the API credentials from it.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path=dotenv_path)

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# NOTE(review): REPLICATE_API_TOKEN is not referenced elsewhere in the visible
# code; presumably the replicate library reads it from the environment itself.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")

# Shared OpenAI client used by call_openai.
client = OpenAI(api_key=OPENAI_API_KEY)
# Instruction sent to GPT-4o together with the moodboard image. The reply is
# spliced into the image-generation prompts as "... which <reply> ...".
_MOODBOARD_PROMPT = (
    "You are a product designer. I've attached a moodboard here. "
    "In one sentence, what do all of these elements have in common? "
    "Answer from a design language perspective, if you were telling another "
    "designer to create something similar, including any repeating colors and "
    "materials and shapes and textures. This is for a single product, so "
    "respond as though you're applying them to a single object. Reply with a "
    "completion to the following (don't include these words please, just the "
    "rest): [A render of an object which] [your response]. Do NOT include "
    "'A render of an object which' in your response."
)


def _encode_jpeg_base64(pil_image):
    """Return *pil_image* JPEG-encoded as a base64 ``str``.

    JPEG cannot store alpha channels or palettes, so any mode other than RGB
    or L (e.g. RGBA from a transparent PNG) is converted to RGB first;
    otherwise ``Image.save`` raises ``OSError`` for those modes.
    """
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def call_openai(pil_image):
    """Ask GPT-4o to describe a moodboard's shared design language.

    Args:
        pil_image: The moodboard as a PIL image.

    Returns:
        One sentence (whitespace-stripped) meant to complete
        "A render of an object which ...".

    Raises:
        gr.Error: If OpenAI rejects the image, the request fails for any
            other reason, or the model returns no text.
    """
    image_data = _encode_jpeg_base64(pil_image)
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": _MOODBOARD_PROMPT},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "data:image/jpeg;base64," + image_data,
                            },
                        },
                    ],
                }
            ],
            max_tokens=300,
        )
    except openai.BadRequestError as e:
        # Keep the API's own explanation in the server log; the user gets a hint.
        print(f"OpenAI rejected the moodboard ({type(e).__name__}): {e}")
        raise gr.Error(
            "Please retry with a different moodboard file (below 20 MB in size "
            "and in one of the following formats: ['png', 'jpeg', 'gif', 'webp'])"
        ) from e
    except Exception as e:
        print(f"OpenAI request failed ({type(e).__name__}): {e}")
        raise gr.Error("Unknown Error") from e

    content = response.choices[0].message.content
    # The SDK types message.content as optional; callers concatenate this
    # string into prompts, so None/empty must not escape from here.
    if not content:
        raise gr.Error(
            "Could not describe this moodboard; please retry or try a different one"
        )
    return content.strip()
# TODO: Build a better prompt generator by adding another LLM step that combines
# the user prompt with the moodboard description. (For a jacket, a prompt like
# 'high quality render of yellow jacket, its fabric is a pattern of cosmic ...'
# worked well.)
# TODO: Possibly run that step 4 separate times to get more diverse renders.
# TODO: Add the word "simple" to the prompt (presumably just before the object name).
# Seconds to wait for a generated image to download before giving up; without
# a timeout, requests.get can block the Gradio worker indefinitely.
_DOWNLOAD_TIMEOUT_S = 60

# Replicate model identifiers used by image_classifier.
_SD3_MODEL = "stability-ai/stable-diffusion-3"
_SDXL_MODEL = "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc"


def _download_image(url):
    """Fetch *url* and decode the response body as a PIL image.

    Network/HTTP errors (via ``raise_for_status``) and undecodable bodies
    propagate to the caller, which turns them into ``gr.Error``.
    """
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_S)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def image_classifier(moodboard, prompt):
    """Generate four product renders styled after a moodboard.

    GPT-4o (via call_openai) describes the moodboard's design language. That
    description plus *prompt* drives two Stable Diffusion 3 renders: one with
    the model's default settings, one with aspect_ratio "3:2" and cfg 6. It
    also drives up to two SDXL renders.

    Args:
        moodboard: Moodboard image array from the Gradio image input, or None
            if nothing was uploaded.
        prompt: The object to render, e.g. "jacket". None is treated as "".

    Returns:
        A list of exactly four PIL images, ordered
        [SDXL #2, SDXL #1, SD3 3:2, SD3 default]. Missing SDXL renders are
        replaced by 768x768 gray placeholders at the start of the list.

    Raises:
        gr.Error: If no moodboard is given, or any generation or download
            step fails.
    """
    if moodboard is None:
        raise gr.Error("Please upload a moodboard to control image generation style")

    pil_image = Image.fromarray(moodboard.astype("uint8"))
    style = call_openai(pil_image)
    subject = (prompt or "").strip()

    sd3_input = {
        "prompt": "high quality render of a " + subject + " which " + style
        + ", minimalist and simple mockup on a white background",
        "output_format": "jpg",
    }
    try:
        output = replicate.run(_SD3_MODEL, input=sd3_input)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
    try:
        img1 = _download_image(output[0])
    except Exception as e:
        raise gr.Error(f"Image download failed: {e}") from e

    # Second SD3 render: same prompt, wider framing and an explicit CFG scale.
    # Built as a copy so the first request's dict is left untouched.
    sd3_wide_input = {**sd3_input, "aspect_ratio": "3:2", "cfg": 6}
    try:
        output = replicate.run(_SD3_MODEL, input=sd3_wide_input)
        img2 = _download_image(output[0])
    except Exception as e:
        raise gr.Error(f"Second image download failed: {e}") from e

    # SDXL renders from the same OpenAI description, centered on white.
    sdxl_input = {
        "width": 768,
        "height": 768,
        "prompt": "centered high quality render of a " + subject + " which " + style
        + " centered on a plain white background",
        "negative_prompt": "worst quality, low quality, illustration, 2d, painting, cartoons, sketch, logo, buttons, markings, text, wires, complex, screws, nails, construction",
        "refine": "expert_ensemble_refiner",
        "apply_watermark": False,
        "num_inference_steps": 25,
        "num_outputs": 2,
        "guidance_scale": 8.5,
    }
    try:
        output = replicate.run(_SDXL_MODEL, input=sdxl_input)
        sdxl_images = [_download_image(url) for url in list(output)[:2]]
    except Exception as e:
        raise gr.Error(f"SDXL image generation failed: {e}") from e

    images = [img1, img2, *sdxl_images]
    # Pad to the four Gradio output slots with gray placeholders.
    images += [Image.new("RGB", (768, 768), "gray") for _ in range(4 - len(images))]
    images.reverse()
    return images
# Gradio UI: a moodboard image and a text prompt in, four rendered images out
# (image_classifier always returns exactly four images).
demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", "text"],
    outputs=["image"] * 4,
)
demo.launch(share=True)