File size: 15,078 Bytes
2bc1f58
cfe4480
 
 
3546648
207e38f
3300d90
3c7e33c
0390fed
3b76ad6
 
e571768
3263fff
f88c368
4003ec5
e610a17
4003ec5
6eba959
4003ec5
 
6eba959
4003ec5
99c30e3
6eba959
9d8da4e
3b76ad6
 
 
 
 
 
 
4003ec5
 
3b76ad6
38a4b70
3b76ad6
 
e610a17
 
4003ec5
 
 
 
 
 
 
 
 
3b76ad6
f2e5577
 
3b76ad6
 
 
 
4003ec5
 
 
e610a17
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e610a17
 
e571768
 
 
4003ec5
 
 
 
 
e571768
e610a17
38a4b70
0352de5
e571768
 
4003ec5
 
 
 
 
e571768
4003ec5
e571768
4003ec5
f88c368
7fc69e4
4003ec5
 
 
 
 
 
 
 
 
 
f88c368
4003ec5
 
 
 
 
f88c368
4003ec5
f88c368
4003ec5
f88c368
e571768
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e571768
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e571768
 
4003ec5
 
 
 
e571768
4003ec5
e571768
4003ec5
e571768
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99c30e3
 
 
 
 
 
 
4003ec5
 
 
 
 
 
 
e571768
 
4003ec5
 
a4b3229
f2e5577
4003ec5
a4b3229
 
 
 
 
 
 
 
 
 
4003ec5
 
 
 
 
 
207e38f
4003ec5
9e40d74
4003ec5
48070c9
4003ec5
99c30e3
4003ec5
 
 
 
 
 
 
 
48070c9
 
 
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48070c9
 
 
 
4003ec5
48070c9
4003ec5
 
 
 
99c30e3
 
e5748f0
99c30e3
 
 
 
 
4003ec5
 
 
 
 
 
 
 
 
 
 
b405950
99c30e3
 
 
4003ec5
99c30e3
4003ec5
 
 
 
 
 
 
 
b405950
99c30e3
 
 
 
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48070c9
4003ec5
 
48070c9
0390fed
 
 
4003ec5
0390fed
4003ec5
 
0390fed
4003ec5
0390fed
3b76ad6
716b26d
0390fed
 
 
 
6cb9b6b
0390fed
 
716b26d
 
 
0390fed
 
 
 
 
 
4003ec5
716b26d
 
 
4003ec5
716b26d
4003ec5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
import httpx
from fastapi.responses import JSONResponse
from io import BytesIO
import fal_client
import os
import base_generator
from openai    import OpenAI
import base64
import redis
import uuid
from time import sleep
import json
from functions import combine_images_side_by_side
from poses import poses
from typing import Optional
from base64 import decodebytes
from dotenv import load_dotenv
import random
import asyncio
load_dotenv()
import ast
from PIL import Image

app = FastAPI()
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Temporary Upstash Redis instance — replace with a managed AWS instance before
# going into production.
# SECURITY FIX: the connection credentials were hardcoded in source. They are
# now read from the environment, with the previous literals kept as fallbacks
# so existing deployments keep working. Rotate this password and delete the
# fallbacks before production.
r = redis.Redis(
  host=os.getenv("REDIS_HOST", "cute-wildcat-41430.upstash.io"),
  port=int(os.getenv("REDIS_PORT", "6379")),
  password=os.getenv("REDIS_PASSWORD", "AaHWAAIjcDFhZDVlOGUyMDQ0ZGQ0MTZmODA4ZjdkNzc4ZDVhZjUzZHAxMA"),
  ssl=True,
  decode_responses=True  # every r.get() below expects str, not bytes
)

@app.post("/virtualTryOn", summary="Virtual try on single call", 
    description="Virtual try on single call for complete outfit try on")
async def virtual_try_on(
    background_tasks: BackgroundTasks,
    num_images: int = 2,
    num_variations: int = 2,
    person_reference: str = "",
    person_image: UploadFile = File(None), 
    person_face_image: UploadFile = File(None), 
    garment_image_1: UploadFile = File(None), 
    garment_image_2: UploadFile = File(None),
    garment_image_3: UploadFile = File(None), 
    garment_image_4: UploadFile = File(None),):
    """
    Start a virtual try-on job and immediately return its request_id.

    Pass at least the person_image and one garment_image. Pose is random (from
    the poses.py file); if not enough garments are provided, the system will
    generate the rest of the outfit. Variations are created after the base
    images are generated — sometimes they get a bit soft, but they produce
    interesting results. Defaults are 2 base images and 2 variations per base
    image (the old docstring incorrectly said 4).

    Poll /status/{request_id} or /result/{request_id} for progress and output.
    """
    request_id = str(uuid.uuid4())
    r.set(request_id, "pending")
    print("request_id: ", request_id)

    # Read every upload up-front: UploadFile streams are tied to the request
    # lifecycle, so bytes must be extracted before the background task runs.
    model_images = []
    for person_file in (person_image, person_face_image):
        if person_file is not None:
            model_images.append(await safe_read_file(person_file))

    # The four numbered garment parameters collapse into one loop here.
    garment_images = []
    for garment_file in (garment_image_1, garment_image_2,
                         garment_image_3, garment_image_4):
        if garment_file is not None:
            garment_images.append(await safe_read_file(garment_file))

    # Launch background task with the raw bytes, not the file handles.
    background_tasks.add_task(
        run_virtual_tryon_pipeline,
        request_id,
        person_reference,
        model_images,
        garment_images,
        num_images,
        num_variations
    )
    return {"request_id": request_id}


async def run_virtual_tryon_pipeline(
    request_id,
    person_reference,
    model_images: list,
    garment_images: list,
    num_images: int,
    num_variations: int,
):
    """
    Background pipeline: describe garments, build a prompt, generate base
    images, then optional variations.

    All progress and output is written to redis under keys derived from
    request_id: the bare key holds the status string, and the suffixed keys
    _prompt, _content, _variations, _result and _error hold the artifacts.
    """
    output_images = []
    try:
        #Step 1: Check incoming data
        r.set(request_id, "checking incoming data...")

        #Step 2: Describe garment
        r.set(request_id, "Describing garment")
        garment_descriptions = []
        for image in garment_images:
            garment_description = await describe_garment(image)
            garment_descriptions.append(garment_description)

        #Step 4: Create prompt
        r.set(request_id, "Creating prompt")
        try:
            completed_outfit = await complete_outfit(garment_descriptions)
            prompt = f"{person_reference} wearing {completed_outfit['description']} in front of a white background"
            pose = random.choice(poses)
            prompt += f" {pose}"
            r.set(request_id + "_prompt", prompt)
        except Exception as e:
            r.set(request_id, "error creating prompt")
            r.set(f"{request_id}_error", str(e))
            print(f"error creating prompt: {e}")
            return

        #Step 5: Create base images
        r.set(request_id, "Creating base images")
        r.set(request_id + "_content", "")
        try:
            base_images = await create_image(garment_images, model_images, prompt, num_images, request_id)
            output_images.append(base_images)
            r.set(request_id, "base images created")
            print(str(base_images))
            r.set(request_id + "_content", str(base_images))
        except Exception as e:
            # BUG FIX: the old handler referenced base_images here, which is
            # unbound when create_image raises (UnboundLocalError masking the
            # real error). create_image was awaited, so there is no pending
            # future to cancel either.
            r.set(request_id, "error creating base images")
            r.set(f"{request_id}_error", str(e))
            return

        #Step 6: Create variations
        if num_variations > 0:
            r.set(request_id, "Creating variations")
            r.set(request_id + "_variations", "")
            try:
                # BUG FIX: the old loop overwrote `variations` on every
                # iteration and appended outside the loop, keeping only the
                # variations of the LAST base image. Collect them all instead,
                # merged into the same {"images": [...]} shape downstream
                # readers already parse.
                variation_images = []
                for image in base_images['images']:
                    variations = await make_versions(image['url'], num_variations)
                    variation_images.extend(variations.get('images', []))
                all_variations = {"images": variation_images}
                output_images.append(all_variations)
                r.set(request_id + "_variations", str(all_variations))
                r.set(request_id, "variations created")
            except Exception as e:
                r.set(request_id, "error creating variations")
                r.set(f"{request_id}_error", str(e))
                return
        else:
            r.set(request_id, "no variations created")

        #Step 7: Final result
        r.set(request_id, "Final result")
        result_images = get_result_images(request_id)
        # BUG FIX: redis-py rejects list values (DataError); serialize first.
        r.set(f"{request_id}_result", str(result_images))
        r.set(request_id, "completed")
        return {"request_id": request_id, "status": "completed", "result": result_images}
    except Exception as e:
        r.set(f"{request_id}_error", str(e))
        return {"request_id": request_id, "status": "error", "error": str(e)}

def get_result_images(request_id):
    """
    Collect every generated image URL (base images + variations) for a request.

    Reads the str()-serialized result dicts back from redis with
    ast.literal_eval and returns a flat list of URL strings. Variations are
    optional: if the _variations key is missing or unparsable, only the base
    image URLs are returned.
    """
    # BUG FIX: the old code indexed the raw redis *string* with ["images"]
    # (TypeError) and then concatenated the parsed dict with a list. Parse
    # first, then combine the two "images" lists.
    images = ast.literal_eval(r.get(request_id + "_content")).get("images", [])
    try:
        variations = ast.literal_eval(r.get(request_id + "_variations")).get("images", [])
    except Exception:
        variations = []  # no variations were generated for this request
    return [image['url'] for image in images + variations]

@app.get("/result/{request_id}")
async def get_result(request_id: str):
    """
    Return the flat list of generated image URLs for a request.

    Renamed from check_status: the old name collided with the /status
    endpoint's handler and shadowed it at module level. The HTTP route is
    unchanged.
    """
    # BUG FIX: validate the request_id BEFORE reading the other keys —
    # ast.literal_eval(None) used to raise for unknown ids, so the
    # "Invalid request_id" guard below was unreachable.
    status = r.get(request_id)
    if not status:
        return {"error": "Invalid request_id"}
    output_images = []
    try:
        images = ast.literal_eval(r.get(request_id + "_content") or "")
        for image in images['images']:
            output_images.append(image['url'])
    except Exception:
        pass  # base images not generated yet (key empty or missing)
    try:
        # BUG FIX: redis-py returns None for a missing key, not the string
        # "null" — guard on truthiness before parsing.
        variations_raw = r.get(request_id + "_variations")
        if variations_raw:
            variations = ast.literal_eval(variations_raw)
            for image in variations['images']:
                output_images.append(image['url'])
    except Exception:
        print("no variations")
    return {"images": output_images }


@app.get("/status/{request_id}")
async def check_status(request_id: str):
    """
    Return a full progress snapshot for a request: status string, parsed base
    images and variations, the generated prompt, the final result list, and
    any recorded error (keys other than the status may be None/empty while
    the pipeline is still running).
    """
    # BUG FIX: validate the request_id BEFORE parsing — ast.literal_eval(None)
    # used to raise for unknown ids, making the guard below unreachable.
    status = r.get(request_id)
    if not status:
        return {"error": "Invalid request_id"}
    output_images = []
    variations = []
    images = {}
    try:
        images = ast.literal_eval(r.get(request_id + "_content") or "")
        for image in images['images']:
            output_images.append(image['url'])
    except Exception:
        pass  # base images not generated yet (key empty or missing)
    try:
        # BUG FIX: a missing redis key is None, not the string "null".
        variations_raw = r.get(request_id + "_variations")
        if variations_raw:
            variations = ast.literal_eval(variations_raw)
            for image in variations['images']:
                output_images.append(image['url'])
    except Exception as e:
        print("no variations")
    prompt = r.get(request_id + "_prompt")
    result = r.get(request_id + "_result")
    error = r.get(request_id + "_error")
    return {"request_id": request_id, "status": status, "\n images": images, "\n variations": variations, "\n prompt": prompt, "\n result": result, "\n error": error}

#endpoints related to base image generation

# Function related to virtual outfit try on

async def safe_read_file(file: UploadFile):
    """Read an upload's bytes, returning None for a missing or empty file."""
    if file is None:
        return None
    data = await file.read()
    return data or None



async def make_versions(input_image_url: str, num_images: int):
    """
    Generate num_images stylistic variations of the image at input_image_url
    using fal.ai's instant-character model.

    Returns the raw fal result dict; re-raises any fal_client error after
    logging it.
    """
    try:
        # Submit the generation request to the external fal.ai API.
        request = await fal_client.submit_async(
            "fal-ai/instant-character",
            arguments={
                "prompt": "Model posing in front of white background",
                "image_url": input_image_url,
                "num_inference_steps": 50,
                "num_images": num_images,
                "safety_tolerance": "5",
                "guidance_scale": 20,
                "image_size": "portrait_16_9",
            },
        )

        # Wait for completion first; only then is it safe to drain the
        # event stream for logging.
        outcome = await request.get()
        async for event in request.iter_events(with_logs=False):
            print(event)

        return outcome
    except Exception as e:
        print(f"Error in make_versions: {e}")
        raise e




#Auxiliary functions

#@app.post("/describeGarment", summary="Describe a garment or garments in the image",
#    description="Passes the garment image to openai to describe it to improve generation of base images")
#async def describe_garment(image: UploadFile = File(...)):

def get_file_extension(filedata):
    """
    Detect the image format of raw bytes via Pillow ("jpeg", "png", ...).

    Returns None when the bytes are not a readable image; callers fall back
    to a default extension in that case.
    """
    try:
        return Image.open(BytesIO(filedata)).format.lower()
    except Exception as e:
        print(f"Error getting file extension: {e}")
        return None

def _upload_images_to_fal(images: list, request_id: str, start_index: int = 0):
    """Write each image's bytes to a temp file, upload it to fal, return the URLs."""
    urls = []
    for offset, image in enumerate(images):
        # Fall back to jpg when the format cannot be detected from the bytes.
        extension = get_file_extension(image) or "jpg"
        # Save the bytes temporarily: fal_client.upload_file needs a path.
        file_path = f"/tmp/img_{request_id}_{start_index + offset}.{extension}"
        with open(file_path, "wb") as f:
            f.write(image)
        urls.append(fal_client.upload_file(file_path))
        os.remove(file_path)
    return urls


async def create_image(garment_images: list, 
    model_images: list,
    prompt: str,
    num_images: int,
    request_id: str,    
    ):
    """
    Generate num_images base try-on images with flux kontext, conditioned on
    the uploaded garment and model reference images plus the text prompt.

    Returns the raw fal result dict; re-raises any fal_client error after
    logging it.
    """
    print("Lets create the images!")
    # Upload all reference images (garments first, then model photos) and
    # collect their public URLs. The two previously copy-pasted loops are
    # now one helper.
    image_urls = _upload_images_to_fal(garment_images, request_id)
    image_urls += _upload_images_to_fal(model_images, request_id,
                                        start_index=len(image_urls))

    # Generate images using flux kontext
    try:
        handler = await fal_client.submit_async(
            "fal-ai/flux-pro/kontext/multi",
            arguments={
                "prompt": prompt,
                "guidance_scale": 18,
                "num_images": num_images,
                "safety_tolerance": "5",
                "output_format": "jpeg",
                "image_urls": image_urls,
                "aspect_ratio": "9:16"
            }
        )

        # Wait for the result first; only then drain the event stream.
        result = await handler.get()
        async for event in handler.iter_events(with_logs=True):
            print(event)

        return result
    except Exception as e:
        print(f"Error in create_image: {e}")
        # BUG FIX: removed the unreachable trailing return(result) that
        # followed this try/except — both branches already exit.
        raise e



async def describe_garment(image):
    """
    Ask GPT-4o to describe the garment(s) in an image.

    image: raw image bytes. Returns {"description": <model text>} suitable
    for insertion into the base-image generation prompt.
    """
    print("describe garment process running")
    # Inline the image as a base64 data URL for the vision request.
    encoded = base64.b64encode(image).decode("utf-8")

    completion = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this garment or garments in the image. the format should be ready to be inserted into the sentence: a man wearing ..... in front of a white background. describe the garments as type and fit, but only overall colors and not specific design, not the rest of the images. make sure to only return the description, not the complete sentence."},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                    },
                ],
            }
        ],
        max_tokens=300,
    )
    return {"description": completion.choices[0].message.content}


async def complete_outfit(outfit_desc):
    """
    Ask GPT-4o to fill in any missing garments so the outfit is complete.

    outfit_desc: list of {"description": str} dicts from describe_garment.
    Returns {"description": <complete outfit description>} ready for the
    generation prompt.
    """
    print("complete outfit process running")
    for description in outfit_desc:
        print(f"description: {description['description']}")
    # str.join instead of += in a loop (avoids quadratic string building).
    current_outfit = "".join(d['description'] for d in outfit_desc)

    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    # BUG FIX: corrected "thie" -> "this" in the instruction
                    # prompt sent to the model.
                    {"type": "text", "text": f"If this outfit is not complete (ie, it is missing a top or bottom, or a pair of pants or a pair of shoes), complete it, making sure the model is wearing a complete outfit. for the garments not already mentioned, keep it simple to not draw too much attention to those new garments. Complete this outfit description: {current_outfit}. Return only the description of the complete outfit including the current outfit, not the complete sentence in the format: blue strap top and pink skirt and white sneakers"},
                ]
            }
        ],
        max_tokens=300
    )
    return {"description": response.choices[0].message.content}