varad-simpli committed on
Commit
f87de7a
·
verified ·
1 Parent(s): 9be23cd

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +330 -0
main.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import time
4
+ import json
5
+ import tritonclient.grpc as grpcclient
6
+ from tritonclient.utils import *
7
+ import queue
8
+ from functools import partial
9
+ import random
10
+
11
+
12
# When True the script runs the concurrent load test; when False it sends a
# single sample request instead.
run_multiple_tests = False

# Serialized to result_dump_<users>_<minutes>.json by run_concurrent_tests_cont.
resp_list = []
# Serialized under the "sCode_breakup" key of test_results.json.
scode_breakup = {}
16
+
17
def np_to_server_dtype(np_dtype):
    """Translate a NumPy dtype into the server's datatype name.

    Args:
        np_dtype: The dtype to translate. It is compared with ``==`` against
            ``bool`` and the NumPy scalar types, so ``np.dtype`` instances
            (for example ``array.dtype``) are accepted.

    Returns:
        The datatype name such as ``"INT32"``, ``"FP32"`` or ``"BYTES"``,
        or ``None`` when no mapping exists for ``np_dtype``.
    """
    # Checked in order; the first entry that compares equal wins.
    fixed_width_names = (
        (bool, "BOOL"),
        (np.int8, "INT8"),
        (np.int16, "INT16"),
        (np.int32, "INT32"),
        (np.int64, "INT64"),
        (np.uint8, "UINT8"),
        (np.uint16, "UINT16"),
        (np.uint32, "UINT32"),
        (np.uint64, "UINT64"),
        (np.float16, "FP16"),
        (np.float32, "FP32"),
        (np.float64, "FP64"),
    )
    for candidate, server_name in fixed_width_names:
        if np_dtype == candidate:
            return server_name

    # Object arrays and byte-string dtypes are both reported as BYTES.
    if np_dtype == np.object_ or np_dtype.type == np.bytes_:
        return "BYTES"
    return None
45
+
46
class UserData:
    """Per-request holder for whatever the stream callback delivers."""

    def __init__(self):
        # callback() puts either a result or an error object on this queue.
        self._completed_requests = queue.Queue()
49
+
50
def callback(user_data, result, error):
    """Queue the outcome of a streamed inference on ``user_data``.

    Args:
        user_data: Object exposing a ``_completed_requests`` queue.
        result: The result to enqueue when no error is reported.
        error: The error to enqueue instead; used whenever it is truthy.
    """
    outcome = error if error else result
    user_data._completed_requests.put(outcome)
55
+
56
def prepare_tensor(name: str, data: np.ndarray):
    """Build a ``grpcclient.InferInput`` carrying ``data`` under ``name``.

    Args:
        name: Name given to the input.
        data: Array whose shape, dtype and contents populate the input.

    Returns:
        The populated ``InferInput``.
    """
    datatype = np_to_server_dtype(data.dtype)
    tensor = grpcclient.InferInput(name=name, shape=data.shape, datatype=datatype)
    tensor.set_data_from_numpy(data)
    return tensor
61
+
62
+
63
def process_and_send_request(sample_request, url="localhost:8002", model_name="flux"):
    """Send one generation request over a gRPC stream and summarize the reply.

    Args:
        sample_request: Mapping of request fields. ``prompt`` is required.
            Every other field is optional and is sent only when its value
            (or its default, when the key is absent) is not ``None``.
        url: ``host:port`` of the gRPC inference endpoint.
        model_name: Name of the model to invoke.

    Returns:
        A dict with the keys ``response_id``, ``total_time_taken``,
        ``inference_time_taken``, ``loading_lora_time``,
        ``output_image_urls``, ``error`` and ``mega_pixel``. When no image
        URLs come back, ``error`` holds the server/parsing error if there
        was one and ``"No images generated"`` otherwise.

    Raises:
        KeyError: If ``sample_request`` has no ``prompt`` entry.
    """
    inputs = _build_request_inputs(sample_request)
    outputs = [grpcclient.InferRequestedOutput(name) for name in _REQUESTED_OUTPUTS]

    user_data = UserData()
    st = time.time()
    with grpcclient.InferenceServerClient(url=url, ssl=False) as triton_client:
        triton_client.start_stream(callback=partial(callback, user_data))
        triton_client.async_stream_infer(
            model_name=model_name,
            inputs=inputs,
            outputs=outputs,
        )
        # Block until the callback has queued the reply (or the error), and
        # only then stop the clock, so the total covers the full round trip.
        response = user_data._completed_requests.get()
        et = time.time()
    print(response)

    # Error objects are recognised by their ``message`` attribute
    # (presumably tritonclient's InferenceServerException).
    if hasattr(response, 'message'):
        print(f"Server error: {response}")
        parsed = _failure_fields(str(response))
    else:
        parsed = _parse_infer_result(response)

    mega_pixel = parsed["mega_pixel"]
    results = {
        "response_id": parsed["response_id"],
        "total_time_taken": et - st,
        "inference_time_taken": parsed["inference_time"],
        "loading_lora_time": parsed["lora_time"],
        "output_image_urls": parsed["output_image_urls"],
        "error": parsed["error"],
        "mega_pixel": 0 if mega_pixel is None else mega_pixel,
    }

    print(results)

    if not results["output_image_urls"]:
        print("No images generated")
        # Keep a more specific server/parsing error when there is one.
        if not results["error"]:
            results["error"] = "No images generated"
    return results


# Output tensors requested from the model. "mega_pixel" is not requested, so
# _parse_infer_result falls back to "0" for it.
_REQUESTED_OUTPUTS = (
    "response_id",
    "time_taken",
    "load_lora",
    "output_image_urls",
    "error",
)

# (field name, NumPy dtype, default used when the key is absent), in the
# order the inputs are sent.
_OPTIONAL_INPUTS = (
    ("negative_prompt", np.object_, None),
    ("height", np.int32, None),
    ("width", np.int32, None),
    ("num_images_per_prompt", np.int32, 1),
    ("num_inference_steps", np.int32, 20),
    ("image", np.object_, None),
    ("mask_image", np.object_, None),
    ("seed", np.int64, -1),
    ("guidance_scale", np.float32, 7.5),
    ("model_type", np.object_, None),
    ("strength", np.float32, 1),
    ("scheduler", np.object_, "EULER-A"),
    ("control_images", np.object_, None),
    ("control_weightages", np.float32, None),
    ("control_modes", np.int32, None),
    ("lora_weights", np.object_, None),
    ("control_guidance_start", np.float32, None),
    ("control_guidance_end", np.float32, None),
)


def _build_request_inputs(sample_request):
    """Convert the request mapping into the list of input tensors."""
    prompt = sample_request['prompt']
    inputs = [prepare_tensor("prompt", np.array([prompt], dtype=np.object_))]
    for field, dtype, default in _OPTIONAL_INPUTS:
        value = sample_request.get(field, default)
        if value is not None:
            inputs.append(prepare_tensor(field, np.array([value], dtype=dtype)))
    return inputs


def _failure_fields(error):
    """Return the placeholder response fields used when a request fails."""
    return {
        "response_id": None,
        "inference_time": 0,
        "lora_time": 0,
        "output_image_urls": [],
        "mega_pixel": 0,
        "error": error,
    }


def _parse_infer_result(response):
    """Extract the reported fields from a non-error inference result.

    Any exception raised while reading the tensors is printed and turned
    into the failure placeholder, with the exception text as the error.
    """
    try:
        inference_time = response.as_numpy("time_taken").item()
        lora_time = response.as_numpy("load_lora").item()

        raw_response_id = response.as_numpy("response_id").item()
        response_id = raw_response_id.decode() if raw_response_id else None

        urls_tensor = response.as_numpy("output_image_urls")
        output_image_urls = urls_tensor.tolist() if urls_tensor is not None else []

        mega_pixel_tensor = response.as_numpy("mega_pixel")
        mega_pixel = mega_pixel_tensor.item().decode() if mega_pixel_tensor is not None else "0"

        error_tensor = response.as_numpy("error")
        raw_error = error_tensor.item() if error_tensor is not None else None
        error = raw_error.decode() if raw_error else None
    except Exception as e:
        print(f"Error processing response: {e}")
        return _failure_fields(str(e))

    return {
        "response_id": response_id,
        "inference_time": inference_time,
        "lora_time": lora_time,
        "output_image_urls": output_image_urls,
        "mega_pixel": mega_pixel,
        "error": error,
    }
217
+
218
def warmup_and_load_lora(warmup_json_path):
    """Replay every request stored in a warmup JSON file.

    Args:
        warmup_json_path: Path to a JSON file containing an iterable of
            request payloads, or ``None`` to skip the warmup.

    Returns:
        ``False`` when ``warmup_json_path`` is ``None``; otherwise ``True``
        after every request has been sent and the elapsed time printed.
    """
    if warmup_json_path is None:
        return False

    with open(warmup_json_path, 'r') as handle:
        warmup_requests = json.load(handle)

    started = time.time()
    for warmup_request in warmup_requests:
        process_and_send_request(warmup_request)
    elapsed = time.time() - started

    print(f"Warmup and load lora done in {elapsed:.3f} seconds")
    return True
229
+
230
def generate_jitter_window():
    """Sleep for a random whole number of seconds from a weighted window.

    A roll in 1..100 picks the window: 1-50 gives 1-5 seconds, 51-75 gives
    10-15 seconds, and anything higher gives 20-30 seconds.

    Returns:
        Always ``True``, after the sleep has finished.
    """
    roll = random.randint(1, 100)
    if roll <= 50:
        shortest, longest = 1, 5
    elif roll <= 75:
        shortest, longest = 10, 15
    else:
        shortest, longest = 20, 30

    time.sleep(random.randint(shortest, longest))
    return True
240
+
241
+
242
def predict(requests_data,percent_bifer):
    """Send one randomly chosen request after a random jitter delay.

    Args:
        requests_data: Sequence of entries, each with a ``'payload'`` key
            holding the request to send.
        percent_bifer: Accepted but not used.

    Returns:
        Whatever ``process_and_send_request`` returns for the chosen payload.
    """
    chosen_entry = random.choice(requests_data)
    chosen_payload = chosen_entry['payload']
    generate_jitter_window()
    return process_and_send_request(chosen_payload)
248
+
249
def run_single_test(requests_data,id = 0):
    """Run one ``predict`` call; ``id`` is passed through as its second argument."""
    outcome = predict(requests_data, id)
    return outcome
251
+
252
# Load-test settings passed to run_concurrent_tests_cont.
number_of_users = 1  # change here for concurrent users
duration_minutes = 2  # length of the submission window, in minutes
254
+
255
def run_concurrent_tests_cont(number_of_users, duration_minutes):
    """Keep ``number_of_users`` requests in flight for a fixed duration.

    New requests are submitted until ``duration_minutes`` have elapsed;
    requests still running at that point are allowed to finish. Every
    result dict is appended to the module-level ``resp_list``, which is
    then written to ``result_dump_<users>_<minutes>.json``.

    Args:
        number_of_users: Maximum number of concurrent requests (also the
            thread-pool size).
        duration_minutes: How long new requests keep being submitted.

    Returns:
        Tuple ``(p25, p50, p90, p99, avg)`` of the ``total_time_taken``
        values, in seconds, over all completed requests.

    Raises:
        RuntimeError: If no request completed, so no statistics exist.
    """
    from concurrent.futures import FIRST_COMPLETED, wait

    def _jsonable(value):
        # Result dicts may carry bytes (e.g. values read from BYTES
        # tensors), which json cannot encode on its own.
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return str(value)

    deadline = time.time() + duration_minutes * 60
    results = []

    with ThreadPoolExecutor(max_workers=number_of_users) as executor:
        in_flight = set()

        while time.time() < deadline:
            # Top the pool back up to the requested concurrency.
            # NOTE(review): requests_data is read from module scope and is
            # not defined anywhere in this file - confirm where it is set.
            while len(in_flight) < number_of_users:
                in_flight.add(executor.submit(run_single_test, requests_data))

            # Sleep until a request finishes or the deadline passes, rather
            # than spinning on Future.done().
            remaining = max(0.0, deadline - time.time())
            finished, in_flight = wait(
                in_flight, timeout=remaining, return_when=FIRST_COMPLETED
            )
            results.extend(future.result() for future in finished)

        # Wait for any remaining tasks to finish
        for future in in_flight:
            results.append(future.result())

    if not results:
        raise RuntimeError("No requests completed; cannot compute latency statistics")

    # Workers return result dicts; the statistics are over their latency field.
    latencies = [result["total_time_taken"] for result in results]
    p25, p50, p90, p99 = np.percentile(latencies, [25, 50, 90, 99])
    avg = sum(latencies) / len(latencies)

    resp_list.extend(results)
    with open(f"result_dump_{number_of_users}_{duration_minutes}.json", "w") as f:
        f.write(json.dumps(resp_list, indent=4, default=_jsonable))

    return p25, p50, p90, p99, avg
295
+
296
if run_multiple_tests:
    # Load-test mode: run the concurrent test, then print and save the stats.
    # NOTE(review): requests_data is not defined anywhere in this file, and
    # warmup_and_load_lora expects a JSON file path - confirm both before
    # enabling run_multiple_tests.
    p25_result , p50_results, p90_resutls, p99_results, avg = run_concurrent_tests_cont(number_of_users,duration_minutes)
    load_lora_time = warmup_and_load_lora(requests_data)

    print(f"25th Percentile: {p25_result:.3f} seconds")
    print(f"50th Percentile: {p50_results:.3f} seconds")
    print(f"90th Percentile: {p90_resutls:.3f} seconds")
    print(f"99th Percentile: {p99_results:.3f} seconds")
    print(f"Average Response Time: {avg:.3f} seconds")

    with open("test_results.json", "w") as f:
        f.write(
            json.dumps(
                {
                    "p25": p25_result,
                    "p50": p50_results,
                    "p90": p90_resutls,
                    "p99": p99_results,
                    "avg": avg,
                    "sCode_breakup": scode_breakup,
                },
                indent=4,
            )
        )
else:
    # Single-shot mode: send one hard-coded txt2img request and print the summary.
    payload = {
        "prompt": "A girl in city, 25 years old, cool, futuristic <lora:https://huggingface.co/XLabs-AI/flux-lora-collection/resolve/main/art_lora.safetensors:0.5>",
        "negative_prompt": "blurry, low quality, distorted",
        "height": 1024,
        "width": 1024,
        "num_images_per_prompt": 1,
        "num_inference_steps": 20,
        "seed": 42424243,
        "guidance_scale": 7.0,
        "model_type": "txt2img",
    }
    result = process_and_send_request(payload)
    print(result)