Stylique committed on
Commit
4a2bd0c
·
verified ·
1 Parent(s): d785382

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +71 -17
  2. loop.py +12 -5
app.py CHANGED
@@ -158,18 +158,34 @@ def process_garment(input_type, text_prompt, base_text_prompt, target_image, bas
158
  target_image_path = os.path.join(temp_dir, "target_image.jpg")
159
 
160
  # Handle different possible image formats from Gradio
 
 
 
 
 
161
  if isinstance(target_image, str):
162
  # Image is a file path
 
163
  shutil.copy(target_image, target_image_path)
164
  elif isinstance(target_image, np.ndarray):
165
  # Image is a numpy array from Gradio
166
- img = Image.fromarray(target_image)
 
 
 
 
 
 
 
167
  img.save(target_image_path)
 
168
  elif hasattr(target_image, 'save'):
169
  # Image is a PIL image
 
170
  target_image.save(target_image_path)
171
  else:
172
- return "Error: Unsupported target image format. Please try again."
 
173
 
174
  print(f"Target image saved to {target_image_path}")
175
 
@@ -317,7 +333,8 @@ def process_garment(input_type, text_prompt, base_text_prompt, target_image, bas
317
  error_details = traceback.format_exc()
318
  print(f"Error during processing: {str(e)}")
319
  print(f"Error details: {error_details}")
320
- return f"Error during processing: {str(e)}. Please check the logs for more details."
 
321
 
322
  def create_interface():
323
  """
@@ -364,15 +381,17 @@ def create_interface():
364
  with gr.Group(visible=False) as image_group:
365
  target_image = gr.Image(
366
  label="Target Garment Image",
367
- type="pil",
368
- image_mode="RGB"
 
369
  )
370
  gr.Markdown("*Upload an image of the desired garment style*")
371
 
372
  base_image = gr.Image(
373
  label="Base Garment Image (Optional)",
374
- type="pil",
375
- image_mode="RGB"
 
376
  )
377
  gr.Markdown("*Upload a base garment image (optional)*")
378
 
@@ -418,7 +437,11 @@ def create_interface():
418
  generate_btn = gr.Button("Generate 3D Garment")
419
 
420
  with gr.Column():
421
- output = gr.File(label="Generated 3D Garment")
 
 
 
 
422
 
423
  gr.Markdown("""
424
  ## Tips:
@@ -426,23 +449,54 @@ def create_interface():
426
  - For text mode: Be specific in your descriptions
427
  - For image mode: Use clear, front-facing garment images
428
  - Higher epochs = better quality but longer processing time
 
429
 
430
  Processing may take several minutes.
431
  """)
432
 
433
- # Toggle visibility based on input mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  input_type.change(
435
- fn=lambda mode: (
436
- gr.Group.update(visible=(mode == "Text")),
437
- gr.Group.update(visible=(mode == "Image"))
438
- ),
439
  inputs=[input_type],
440
- outputs=[text_group, image_group]
441
  )
442
 
443
- # Connect the button to the processing function
444
  generate_btn.click(
445
- fn=process_garment,
446
  inputs=[
447
  input_type,
448
  text_prompt,
@@ -455,7 +509,7 @@ def create_interface():
455
  clip_weight,
456
  delta_clip_weight
457
  ],
458
- outputs=[output]
459
  )
460
 
461
  return interface
 
158
  target_image_path = os.path.join(temp_dir, "target_image.jpg")
159
 
160
  # Handle different possible image formats from Gradio
161
+ if target_image is None:
162
+ return None
163
+
164
+ print(f"Target image type: {type(target_image)}")
165
+
166
  if isinstance(target_image, str):
167
  # Image is a file path
168
+ print(f"Copying image from path: {target_image}")
169
  shutil.copy(target_image, target_image_path)
170
  elif isinstance(target_image, np.ndarray):
171
  # Image is a numpy array from Gradio
172
+ print(f"Converting numpy array image with shape: {target_image.shape}")
173
+ # Make sure the array is in RGB format (convert if grayscale)
174
+ if len(target_image.shape) == 2:
175
+ target_image = np.stack([target_image] * 3, axis=2)
176
+ elif target_image.shape[2] == 4: # RGBA
177
+ target_image = target_image[:,:,:3] # Drop alpha channel
178
+
179
+ img = Image.fromarray(target_image.astype(np.uint8))
180
  img.save(target_image_path)
181
+ print(f"Saved numpy array as image to: {target_image_path}")
182
  elif hasattr(target_image, 'save'):
183
  # Image is a PIL image
184
+ print("Saving PIL image")
185
  target_image.save(target_image_path)
186
  else:
187
+ print(f"Unsupported image type: {type(target_image)}")
188
+ return None
189
 
190
  print(f"Target image saved to {target_image_path}")
191
 
 
333
  error_details = traceback.format_exc()
334
  print(f"Error during processing: {str(e)}")
335
  print(f"Error details: {error_details}")
336
+ # Return None instead of an error string to avoid file not found errors with Gradio
337
+ return None
338
 
339
  def create_interface():
340
  """
 
381
  with gr.Group(visible=False) as image_group:
382
  target_image = gr.Image(
383
  label="Target Garment Image",
384
+ sources=["upload", "webcam"],
385
+ type="numpy",
386
+ interactive=True
387
  )
388
  gr.Markdown("*Upload an image of the desired garment style*")
389
 
390
  base_image = gr.Image(
391
  label="Base Garment Image (Optional)",
392
+ sources=["upload", "webcam"],
393
+ type="numpy",
394
+ interactive=True
395
  )
396
  gr.Markdown("*Upload a base garment image (optional)*")
397
 
 
437
  generate_btn = gr.Button("Generate 3D Garment")
438
 
439
  with gr.Column():
440
+ output = gr.File(
441
+ label="Generated 3D Garment",
442
+ file_types=[".obj", ".glb", ".png", ".jpg"],
443
+ file_count="single"
444
+ )
445
 
446
  gr.Markdown("""
447
  ## Tips:
 
449
  - For text mode: Be specific in your descriptions
450
  - For image mode: Use clear, front-facing garment images
451
  - Higher epochs = better quality but longer processing time
452
+ - Output files can be downloaded by clicking on them
453
 
454
  Processing may take several minutes.
455
  """)
456
 
457
+ # Add a status output for errors and messages
458
+ status_output = gr.Markdown("Ready to generate garments. Select an input method and click 'Generate 3D Garment'.")
459
+
460
+ # Define a function to handle mode changes with clearer UI feedback
461
+ def update_mode(mode):
462
+ text_visibility = mode == "Text"
463
+ image_visibility = mode == "Image"
464
+ status_msg = f"Mode changed to {mode}. "
465
+
466
+ if text_visibility:
467
+ status_msg += "Enter garment descriptions and click Generate."
468
+ else:
469
+ status_msg += "Upload garment images and click Generate."
470
+
471
+ return (
472
+ gr.Group.update(visible=text_visibility),
473
+ gr.Group.update(visible=image_visibility),
474
+ status_msg
475
+ )
476
+
477
+ # Function to handle processing with better error feedback
478
+ def process_with_feedback(*args):
479
+ try:
480
+ result = process_garment(*args)
481
+ if result is None:
482
+ return None, "Processing failed. Please check the logs for details."
483
+ return result, "Processing completed successfully! Download your 3D garment file below."
484
+ except Exception as e:
485
+ import traceback
486
+ print(f"Error in interface: {str(e)}")
487
+ print(traceback.format_exc())
488
+ return None, f"Error: {str(e)}"
489
+
490
+ # Toggle visibility based on input mode with better feedback
491
  input_type.change(
492
+ fn=update_mode,
 
 
 
493
  inputs=[input_type],
494
+ outputs=[text_group, image_group, status_output]
495
  )
496
 
497
+ # Connect the button to the processing function with error handling
498
  generate_btn.click(
499
+ fn=process_with_feedback,
500
  inputs=[
501
  input_type,
502
  text_prompt,
 
509
  clip_weight,
510
  delta_clip_weight
511
  ],
512
+ outputs=[output, status_output]
513
  )
514
 
515
  return interface
loop.py CHANGED
@@ -85,12 +85,19 @@ def loop(cfg):
85
 
86
  fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
87
 
88
- if fashion_text or fashion_image:
 
 
89
  target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
90
- elif text_input:
91
- target_direction_embeds, delta_direction_embeds = get_text_embeddings(clip, model, cfg, device)
92
- elif image_input:
93
- target_direction_embeds, delta_direction_embeds = get_img_embeddings(model, preprocess, cfg, device)
 
 
 
 
 
94
 
95
  clip_mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
96
  clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
 
85
 
86
  fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
87
 
88
+ # Use FashionCLIP for all modes to avoid CLIP loading issues
89
+ if fashion_image:
90
+ print('Processing with fashion image embeddings')
91
  target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
92
+ elif fashion_text:
93
+ print('Processing with fashion text embeddings')
94
+ target_direction_embeds, delta_direction_embeds = get_fashion_text_embeddings(fclip, cfg, device)
95
+ elif text_input or image_input:
96
+ print('WARNING: Regular CLIP embeddings are disabled, using FashionCLIP instead')
97
+ if text_input:
98
+ target_direction_embeds, delta_direction_embeds = get_fashion_text_embeddings(fclip, cfg, device)
99
+ else:
100
+ target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
101
 
102
  clip_mean = torch.tensor([0.48154660, 0.45782750, 0.40821073], device=device)
103
  clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)