# NOTE: removed non-code artifact ("Spaces: Sleeping") captured from the hosting UI.
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
import httpx
from fastapi.responses import JSONResponse
from io import BytesIO
import fal_client
import os
import base_generator
from openai import OpenAI
import base64
import redis
import uuid
from time import sleep
import json
from functions import combine_images_side_by_side
from poses import poses
from typing import Optional
from base64 import decodebytes
from dotenv import load_dotenv
import random
import asyncio
load_dotenv()
import ast
from PIL import Image
app = FastAPI()
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Temporary Redis instance — replace with one on AWS when going into production!
# SECURITY FIX: credentials are now read from the environment; the hard-coded
# values remain only as a development fallback and should be removed (and the
# password rotated) once REDIS_* variables are configured.
r = redis.Redis(
    host=os.getenv("REDIS_HOST", 'cute-wildcat-41430.upstash.io'),
    port=int(os.getenv("REDIS_PORT", '6379')),
    password=os.getenv("REDIS_PASSWORD", 'AaHWAAIjcDFhZDVlOGUyMDQ0ZGQ0MTZmODA4ZjdkNzc4ZDVhZjUzZHAxMA'),
    ssl=True,
    decode_responses=True
)
async def virtual_try_on(
    background_tasks: BackgroundTasks,
    num_images: int = 2,
    num_variations: int = 2,
    person_reference: str = "",
    person_image: UploadFile = File(None),
    person_face_image: UploadFile = File(None),
    garment_image_1: UploadFile = File(None),
    garment_image_2: UploadFile = File(None),
    garment_image_3: UploadFile = File(None),
    garment_image_4: UploadFile = File(None),):
    """
    Start a virtual try-on job and return its request id immediately.

    Pass at least the person_image and one garment_image. Pose is random (from
    the poses.py file). If not enough garments are provided, the system will
    generate the rest of the outfit. Variations are created after base images
    are generated; sometimes they get a bit soft, but they produce interesting
    results. Default number of base images is 2.

    Returns:
        dict: {"request_id": <uuid>} — poll the status endpoint with this id.
    """
    request_id = str(uuid.uuid4())
    r.set(request_id, "pending")
    print("request_id: ", request_id)
    model_images = []
    garment_images = []
    # Read all uploads up-front: UploadFile handles belong to this request and
    # cannot be read later from the background task.
    # BUG FIX: the original appended the *_data variables whenever the file
    # object was present, so an empty upload pushed None into the pipeline.
    person_image_data = await safe_read_file(person_image)
    if person_image_data is not None:
        model_images.append(person_image_data)
    person_face_image_data = await safe_read_file(person_face_image)
    if person_face_image_data is not None:
        model_images.append(person_face_image_data)
    # Collect garments with a loop instead of four copy-pasted branches.
    for garment in (garment_image_1, garment_image_2, garment_image_3, garment_image_4):
        garment_data = await safe_read_file(garment)
        if garment_data is not None:
            garment_images.append(garment_data)
    # Launch background task with raw bytes so the data outlives the request.
    background_tasks.add_task(
        run_virtual_tryon_pipeline,
        request_id,
        person_reference,
        model_images,
        garment_images,
        num_images,
        num_variations
    )
    return {"request_id": request_id}
async def run_virtual_tryon_pipeline(
    request_id,
    person_reference,
    model_images: list,
    garment_images: list,
    num_images: int,
    num_variations: int,
):
    """
    Background pipeline: describe garments, build a prompt, generate base
    images, optionally generate variations, and publish the final result.

    Progress and results are reported via Redis keys derived from request_id:
    <id> (status), <id>_prompt, <id>_content, <id>_variations, <id>_result,
    <id>_error.
    """
    output_images = []
    try:
        # Step 1: check incoming data (placeholder — only updates status).
        r.set(request_id, "checking incoming data...")
        # Step 2: describe each garment image with the vision model.
        r.set(request_id, "Describing garment")
        garment_descriptions = []
        for image in garment_images:
            garment_description = await describe_garment(image)
            garment_descriptions.append(garment_description)
        # Step 3: create the generation prompt (outfit completed when garments are missing).
        r.set(request_id, "Creating prompt")
        try:
            completed_outfit = await complete_outfit(garment_descriptions)
            prompt = f"{person_reference} wearing {completed_outfit['description']} in front of a white background"
            pose = random.choice(poses)
            prompt += f" {pose}"
            r.set(request_id + "_prompt", prompt)
        except Exception as e:
            r.set(request_id, "error creating prompt")
            r.set(f"{request_id}_error", str(e))
            print(f"error creating prompt: {e}")
            return
        # Step 4: create the base images.
        r.set(request_id, "Creating base images")
        r.set(request_id + "_content", "")
        # BUG FIX: defined before the try so the except block below can inspect
        # it without a NameError when create_image raises before assigning.
        base_images = None
        try:
            base_images = await create_image(garment_images, model_images, prompt, num_images, request_id)
            output_images.append(base_images)
            r.set(request_id, "base images created")
            print(str(base_images))
            r.set(request_id + "_content", str(base_images))
        except Exception as e:
            r.set(request_id, "error creating base images")
            r.set(f"{request_id}_error", str(e))
            if isinstance(base_images, asyncio.Future):
                base_images.cancel()
            return
        # Step 5: create variations of each base image.
        if num_variations > 0:
            r.set(request_id, "Creating variations")
            r.set(request_id + "_variations", "")
            try:
                for image in base_images['images']:
                    variations = await make_versions(image['url'], num_variations)
                    output_images.append(variations)
                    # NOTE(review): each iteration overwrites the key, so only the
                    # variations of the LAST base image are persisted — matches how
                    # get_result_images reads a single dict back; confirm intended.
                    r.set(request_id + "_variations", str(variations))
                r.set(request_id, "variations created")
            except Exception as e:
                r.set(request_id, "error creating variations")
                r.set(f"{request_id}_error", str(e))
                return
        else:
            r.set(request_id, "no variations created")
        # Step 6: collect all generated URLs and publish the final result.
        r.set(request_id, "Final result")
        result_images = get_result_images(request_id)
        # BUG FIX: redis-py rejects a raw list value (DataError); store JSON text.
        r.set(f"{request_id}_result", json.dumps(result_images))
        r.set(request_id, "completed")
        return {"request_id": request_id, "status": "completed", "result": result_images}
    except Exception as e:
        # Keep the last specific status message; only record the error detail.
        r.set(f"{request_id}_error", str(e))
        return {"request_id": request_id, "status": "error", "error": str(e)}
def get_result_images(request_id):
    """
    Collect every generated image URL for a request from Redis.

    Reads the base-generation payload from <id>_content and (when present) the
    variation payload from <id>_variations; both are stored as str(dict) and
    parsed back with ast.literal_eval. Returns a flat list of URL strings.
    """
    result_images = []
    # BUG FIX: the original indexed the raw Redis string with ["images"]
    # (TypeError) and then added a dict to a list. Parse first, then take
    # the 'images' lists from each payload.
    images = ast.literal_eval(r.get(request_id + "_content"))['images']
    variations = []
    variations_raw = r.get(request_id + "_variations")
    if variations_raw:
        # Missing/empty when num_variations == 0 — treat as "no variations".
        variations = ast.literal_eval(variations_raw).get('images', [])
    all_images = images + variations
    for image in all_images:
        result_images.append(image['url'])
    return result_images
# NOTE(review): dead code — this definition is immediately shadowed by the
# second `check_status` defined below in this file, so only the later one is
# live at runtime. Consider deleting this copy.
async def check_status(request_id: str):
    """Shadowed earlier variant: gather all generated image URLs for a request."""
    output_images = []
    status = r.get(request_id)
    # Parses the str(dict) payload stored under <id>_content back into a dict.
    images = ast.literal_eval( r.get(request_id + "_content"))
    for image in images['images']:
        output_images.append(image['url'])
    try:
        # "_variations" holds str(dict) when variations were generated.
        if r.get(request_id + "_variations") != "null":
            variations = ast.literal_eval(r.get(request_id + "_variations"))
            for image in variations['images']:
                output_images.append(image['url'])
    except Exception as e:
        # Best-effort: any parse/lookup failure is treated as "no variations".
        print("no variations")
    prompt = r.get(request_id + "_prompt")
    result = r.get(request_id + "_result")
    error = r.get(request_id + "_error")
    if not status:
        return {"error": "Invalid request_id"}
    #return {"request_id": request_id, "status": status, "\n images": images, "\n variations": variations, "\n prompt": prompt, "\n result": result, "\n error": error}
    return {"images": output_images }
async def check_status(request_id: str):
    """
    Report the current state of a try-on request.

    Returns the status string plus the parsed images/variations payloads, the
    prompt, the stored result and any recorded error. Unknown ids yield
    {"error": "Invalid request_id"}.
    """
    # BUG FIX: validate the id before touching the content key — the original
    # ran ast.literal_eval(None) (TypeError) for unknown ids, or for jobs that
    # had not produced any base images yet, before ever reaching this check.
    status = r.get(request_id)
    if not status:
        return {"error": "Invalid request_id"}
    output_images = []
    variations = []
    images = {}
    content = r.get(request_id + "_content")
    if content:
        images = ast.literal_eval(content)
        for image in images.get('images', []):
            output_images.append(image['url'])
    try:
        if r.get(request_id + "_variations") != "null":
            variations = ast.literal_eval(r.get(request_id + "_variations"))
            for image in variations['images']:
                output_images.append(image['url'])
    except Exception as e:
        # Best-effort: treat any parse failure as "no variations yet".
        print("no variations")
    prompt = r.get(request_id + "_prompt")
    result = r.get(request_id + "_result")
    error = r.get(request_id + "_error")
    return {"request_id": request_id, "status": status, "\n images": images, "\n variations": variations, "\n prompt": prompt, "\n result": result, "\n error": error}
# Endpoints related to base image generation
# Functions related to the virtual outfit try-on flow
async def safe_read_file(file: UploadFile):
    """Read an optional upload; return its bytes, or None when absent/empty."""
    if file is None:
        return None
    data = await file.read()
    if not data:
        return None
    return data
async def make_versions(input_image_url: str, num_images: int):
    """
    Generate `num_images` look-alike variations of an image via the
    fal-ai/instant-character model.

    Returns the raw fal result dict; re-raises any API error after logging it.
    """
    request_args = {
        "prompt": "Model posing in front of white background",
        "image_url": input_image_url,
        "num_inference_steps": 50,
        "num_images": num_images,
        "safety_tolerance": "5",
        "guidance_scale": 20,
        "image_size": "portrait_16_9",
    }
    try:
        handler = await fal_client.submit_async(
            "fal-ai/instant-character",
            arguments=request_args,
        )
        # Fetch the result before draining events so iteration cannot block it.
        result = await handler.get()
        async for event in handler.iter_events(with_logs=False):
            print(event)
        return result
    except Exception as e:
        print(f"Error in make_versions: {e}")
        raise e
# Auxiliary functions
#@app.post("/describeGarment", summary="Describe a garment or garments in the image",
#          description="Passes the garment image to openai to describe it to improve generation of base images")
#async def describe_garment(image: UploadFile = File(...)):
def get_file_extension(filedata):
    """Return the lowercase image format name (e.g. 'jpeg') for raw bytes, or None if unreadable."""
    try:
        # PIL sniffs the format from the header without decoding the pixels.
        return Image.open(BytesIO(filedata)).format.lower()
    except Exception as e:
        print(f"Error getting file extension: {e}")
        return None
def _upload_image_bytes(image: bytes, request_id: str, img_no: int) -> str:
    """Write image bytes to a temp file, upload it to fal, and return its URL.

    The temp file is always removed, even when the upload raises.
    """
    # Fall back to jpg when the format cannot be detected.
    extension = get_file_extension(image) or "jpg"
    file_path = f"/tmp/img_{request_id}_{img_no}.{extension}"
    with open(file_path, "wb") as f:
        f.write(image)
    try:
        return fal_client.upload_file(file_path)
    finally:
        # BUG FIX: the original leaked the temp file when the upload raised.
        os.remove(file_path)

async def create_image(garment_images: list,
                       model_images: list,
                       prompt: str,
                       num_images: int,
                       request_id: str,
                       ):
    """
    Generate `num_images` try-on images with fal-ai flux kontext.

    Uploads every garment and model image to fal first, then submits a single
    multi-image generation request. Returns the raw fal result dict (contains
    an 'images' list of {'url': ...} entries). Re-raises API errors after
    logging them.
    """
    print("Lets create the images!")
    # Upload all reference images (two copy-pasted loops in the original
    # collapsed into one) and collect their public URLs.
    image_urls = []
    for img_no, image in enumerate(garment_images + model_images):
        image_urls.append(_upload_image_bytes(image, request_id, img_no))
    # Generate images using flux kontext.
    try:
        handler = await fal_client.submit_async(
            "fal-ai/flux-pro/kontext/multi",
            arguments={
                "prompt": prompt,
                "guidance_scale": 18,
                "num_images": num_images,
                "safety_tolerance": "5",
                "output_format": "jpeg",
                "image_urls": image_urls,
                "aspect_ratio": "9:16"
            }
        )
        # Get the result first before iterating events.
        result = await handler.get()
        async for event in handler.iter_events(with_logs=True):
            print(event)
        return result
    except Exception as e:
        print(f"Error in create_image: {e}")
        raise e
async def describe_garment(image):
    """
    Ask GPT-4o for a short textual description of the garment(s) in an image.

    `image` is raw image bytes; returns {"description": <text>} ready to be
    spliced into the generation prompt.
    """
    print("describe garment process running")
    encoded = base64.b64encode(image).decode("utf-8")
    instruction = {
        "type": "text",
        "text": "Describe this garment or garments in the image. the format should be ready to be inserted into the sentence: a man wearing ..... in front of a white background. describe the garments as type and fit, but only overall colors and not specific design, not the rest of the images. make sure to only return the description, not the complete sentence.",
    }
    picture = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
    }
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": [instruction, picture]}],
        max_tokens=300,
    )
    return {"description": response.choices[0].message.content}
async def complete_outfit(outfit_desc):
    """
    Ask GPT-4o to fill in any missing garments so the outfit is complete.

    Args:
        outfit_desc: list of {"description": str} dicts from describe_garment.

    Returns:
        {"description": <complete outfit description>}
    """
    print("complete outfit process running")
    for description in outfit_desc:
        print(f"description: {description['description']}")
    # BUG FIX: the original concatenated descriptions with no separator,
    # producing run-together text like "blue shirtred pants".
    current_outfit = ", ".join(d['description'] for d in outfit_desc)
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    # "thie" typo in the original prompt fixed to "this".
                    {"type": "text", "text": f"If this outfit is not complete (ie, it is missing a top or bottom, or a pair of pants or a pair of shoes), complete it, making sure the model is wearing a complete outfit. for the garments not already mentioned, keep it simple to not draw too much attention to those new garments. Complete this outfit description: {current_outfit}. Return only the description of the complete outfit including the current outfit, not the complete sentence in the format: blue strap top and pink skirt and white sneakers"},
                ]
            }
        ],
        max_tokens=300
    )
    return {"description": response.choices[0].message.content}