Harsh-7300 commited on
Commit
968edaf
·
verified ·
1 Parent(s): 399b9e6

Upload 23 files

Browse files
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import random
4
+ import gradio as gr
5
+ from gradio_client import Client
6
+
7
+ machine_number = 0
8
+ model = os.path.join(os.path.dirname(__file__), "models/simon_online/Simon_0.png")
9
+ url = os.environ['OA_IP_ADDRESS']
10
+ print("API:", url)
11
+ client = Client(url)
12
+
13
+ MODEL_MAP = {
14
+ "AI Model Simon_0": 'models/simon_online/Simon_0.png',
15
+ "AI Model Xuanxuan_0": 'models/xiaoxuan_online/Xuanxuan_0.png',
16
+ "AI Model Yifeng_0": 'models/yifeng_online/Yifeng_0.png'
17
+ }
18
+
19
+
20
+ def add_waterprint(img):
21
+ h, w, _ = img.shape
22
+ img = cv2.putText(img, 'AI VTON', (int(0.3 * w), h - 20), cv2.FONT_HERSHEY_PLAIN, 2,
23
+ (128, 128, 128), 2, cv2.LINE_AA)
24
+
25
+ return img
26
+
27
+
28
+ def get_tryon_result(model_name, garment1, garment2, seed=1234):
29
+ # _model = "AI Model " + model_name.split("\\")[-1].split(".")[0] # windows
30
+ _model = "AI Model " + model_name.split("/")[-1].split(".")[0] # linux
31
+ print("Use Model:", _model)
32
+ seed = random.randint(0, 1222222222)
33
+ result = client.predict(
34
+ model_name,
35
+ garment1,
36
+ garment2,
37
+ api_name="/get_tryon_result",
38
+ fn_index=seed
39
+ )
40
+ final_img = remove_watermark2(result)
41
+ return final_img
42
+
43
+
44
+ def remove_watermark2(path):
45
+ img = cv2.imread(path)
46
+ img_ = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
47
+ h, w, _ = img.shape
48
+
49
+ y_start = max(h - 45, 0)
50
+ y_end = h
51
+ x_start = max(int(0.3 * w), 0)
52
+ x_end = w
53
+
54
+ img_[y_start:y_end, x_start:x_end, :] = [255, 255, 255]
55
+
56
+ return img_
57
+
58
+
59
+ with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 400px !important} ") as demo:
60
+ # gr.Markdown("# Outfit Anyone v0.9")
61
+
62
+ with gr.Row():
63
+ with gr.Column():
64
+ init_image = gr.Image(sources='clipboard', type="filepath", label="model", value=model)
65
+ example = gr.Examples(inputs=init_image,
66
+ examples_per_page=4,
67
+ examples=[
68
+ os.path.join(os.path.dirname(__file__), MODEL_MAP.get('AI Model Simon_0')),
69
+ os.path.join(os.path.dirname(__file__),
70
+ MODEL_MAP.get('AI Model Xuanxuan_0')),
71
+ os.path.join(os.path.dirname(__file__), MODEL_MAP.get('AI Model Yifeng_0')),
72
+ ])
73
+ with gr.Column():
74
+ gr.HTML(
75
+ """
76
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
77
+ <div>
78
+ </div>
79
+ </div>
80
+ """)
81
+ with gr.Row():
82
+ garment_top = gr.Image(sources='upload', type="filepath", label="top garment")
83
+ example_top = gr.Examples(inputs=garment_top,
84
+ examples_per_page=5,
85
+ examples=[os.path.join(os.path.dirname(__file__), "garments/top222.JPG"),
86
+ os.path.join(os.path.dirname(__file__), "garments/top5.png"),
87
+ os.path.join(os.path.dirname(__file__), "garments/top333.png"),
88
+ os.path.join(os.path.dirname(__file__), "garments/dress1.png"),
89
+ os.path.join(os.path.dirname(__file__), "garments/dress2.png"),
90
+ ])
91
+ garment_down = gr.Image(sources='upload', type="filepath", label="lower garment")
92
+ example_down = gr.Examples(inputs=garment_down,
93
+ examples_per_page=5,
94
+ examples=[os.path.join(os.path.dirname(__file__), "garments/bottom1.png"),
95
+ os.path.join(os.path.dirname(__file__), "garments/bottom2.PNG"),
96
+ os.path.join(os.path.dirname(__file__), "garments/bottom3.JPG"),
97
+ os.path.join(os.path.dirname(__file__), "garments/bottom4.PNG"),
98
+ os.path.join(os.path.dirname(__file__), "garments/bottom5.png"),
99
+ ])
100
+
101
+ run_button = gr.Button(value="Run")
102
+ with gr.Column():
103
+ gallery = gr.Image()
104
+
105
+ run_button.click(fn=get_tryon_result,
106
+ inputs=[
107
+ init_image,
108
+ garment_top,
109
+ garment_down,
110
+ ],
111
+ outputs=[gallery],
112
+ concurrency_limit=4)
113
+
114
+
115
+ if __name__ == "__main__":
116
+ demo.queue(max_size=10)
117
+ demo.launch(share=False, server_name='0.0.0.0',server_port=7860)
examples/basemodel.png ADDED
examples/garment1.jpg ADDED
examples/garment1.png ADDED
examples/garment2.jpg ADDED
examples/garment2.png ADDED
examples/garment3.png ADDED
examples/result1.png ADDED
examples/result2.png ADDED
examples/result3.png ADDED
garments/bottom1.png ADDED
garments/bottom2.PNG ADDED
garments/bottom3.JPG ADDED
garments/bottom4.PNG ADDED
garments/bottom5.png ADDED
garments/dress1.png ADDED
garments/dress2.png ADDED
garments/top111.png ADDED
garments/top222.JPG ADDED
garments/top3.JPG ADDED
garments/top333.png ADDED
garments/top4.png ADDED
garments/top5.png ADDED