Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- sglang/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py +130 -0
- sglang/benchmark/benchmark_vllm_060/README.md +89 -0
- sglang/benchmark/blog_v0_2/README.md +164 -0
- sglang/benchmark/blog_v0_2/config.md +100 -0
- sglang/benchmark/deepseek_v3/README.md +123 -0
- sglang/benchmark/generative_agents/README.md +38 -0
- sglang/benchmark/generative_agents/agent_functions.py +300 -0
- sglang/benchmark/generative_agents/bench_other.py +80 -0
- sglang/benchmark/generative_agents/bench_sglang.py +74 -0
- sglang/benchmark/hellaswag/bench_sglang.py +106 -0
- sglang/benchmark/json_decode_regex/README.md +60 -0
- sglang/benchmark/json_decode_regex/bench_other.py +98 -0
- sglang/benchmark/json_decode_regex/bench_sglang.py +100 -0
- sglang/benchmark/json_decode_regex/build_dataset.py +58 -0
- sglang/benchmark/json_jump_forward/README.md +88 -0
- sglang/benchmark/json_jump_forward/bench_other.py +288 -0
- sglang/benchmark/json_jump_forward/bench_sglang.py +143 -0
- sglang/benchmark/json_jump_forward/build_dataset.py +58 -0
- sglang/benchmark/json_jump_forward/dataset.txt +50 -0
- sglang/benchmark/json_schema/README.md +15 -0
- sglang/benchmark/json_schema/bench_sglang.py +146 -0
- sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py +405 -0
- sglang/benchmark/kernels/fused_moe_triton/README.md +49 -0
- sglang/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py +231 -0
- sglang/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py +345 -0
- sglang/benchmark/line_retrieval/README.md +37 -0
- sglang/benchmark/line_retrieval/bench_sglang.py +149 -0
- sglang/benchmark/line_retrieval/gen_data.py +139 -0
- sglang/benchmark/llava_bench/README.md +61 -0
- sglang/benchmark/llava_bench/bench_hf_llava_bench.sh +9 -0
- sglang/benchmark/llava_bench/bench_hf_mme.sh +9 -0
- sglang/benchmark/llava_bench/bench_sglang.py +96 -0
- sglang/benchmark/llava_bench/bench_sglang_mme.sh +2 -0
- sglang/benchmark/llava_bench/download_images.py +20 -0
- sglang/benchmark/llava_bench/questions.jsonl +60 -0
- sglang/benchmark/llm_judge/README.md +33 -0
- sglang/benchmark/llm_judge/articles.jsonl +0 -0
- sglang/benchmark/llm_judge/bench_other.py +151 -0
- sglang/benchmark/llm_judge/bench_sglang.py +97 -0
- sglang/benchmark/long_json_decode/README.md +33 -0
- sglang/benchmark/long_json_decode/bench_other.py +89 -0
- sglang/benchmark/long_json_decode/bench_sglang.py +81 -0
- sglang/benchmark/long_json_decode/build_dataset.py +27 -0
- sglang/benchmark/mmlu/bench_other.py +173 -0
- sglang/benchmark/mmlu/bench_sglang.py +174 -0
- sglang/benchmark/mmlu/download_data.sh +2 -0
- sglang/benchmark/multi_chain_reasoning/bench_other.py +186 -0
- sglang/benchmark/multi_document_qa/README.md +47 -0
- sglang/benchmark/multi_document_qa/bench_other.py +114 -0
- sglang/benchmark/multi_document_qa/bench_sglang.py +93 -0
sglang/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
|
2 |
+
#
|
3 |
+
# Launch a server:
|
4 |
+
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning
|
5 |
+
|
6 |
+
import random
|
7 |
+
import string
|
8 |
+
import time
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
|
13 |
+
import sglang as sgl
|
14 |
+
from sglang import set_default_backend
|
15 |
+
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
16 |
+
|
17 |
+
|
18 |
+
def generate_random_string(token_length: int) -> str:
|
19 |
+
random_string = "".join(
|
20 |
+
random.choices(string.ascii_letters + string.digits, k=token_length * 100)
|
21 |
+
)
|
22 |
+
tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
|
23 |
+
:token_length
|
24 |
+
]
|
25 |
+
|
26 |
+
if len(tokenized_output) < token_length:
|
27 |
+
tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
|
28 |
+
token_length - len(tokenized_output)
|
29 |
+
)
|
30 |
+
|
31 |
+
decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
|
32 |
+
return decoded_string
|
33 |
+
|
34 |
+
|
35 |
+
def generate_unique_prefix(base_text, index):
|
36 |
+
return str(index) + base_text[len(str(index)) :]
|
37 |
+
|
38 |
+
|
39 |
+
@sgl.function
|
40 |
+
def text_qa(s, question, gen_len):
|
41 |
+
s += "Q: " + question + "\n"
|
42 |
+
s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)
|
43 |
+
|
44 |
+
|
45 |
+
def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
|
46 |
+
base_prefix = generate_random_string(prefix_length)
|
47 |
+
|
48 |
+
tot_input_len = 0
|
49 |
+
all_prompts = []
|
50 |
+
for i in tqdm(range(num_prefix), desc="prepare prompts"):
|
51 |
+
unique_prefix = generate_unique_prefix(base_prefix, i)
|
52 |
+
prompt_list = []
|
53 |
+
for j in range(num_samples_per_prefix):
|
54 |
+
suffix = generate_random_string(suffix_length)
|
55 |
+
prompt = unique_prefix + suffix
|
56 |
+
prompt_list.append(prompt)
|
57 |
+
tot_input_len += len(tokenizer.encode(prompt))
|
58 |
+
all_prompts.append(prompt_list)
|
59 |
+
return all_prompts, tot_input_len
|
60 |
+
|
61 |
+
|
62 |
+
def test_batch_by_batch(all_prompts, gen_len):
|
63 |
+
backend.flush_cache()
|
64 |
+
|
65 |
+
tot_time = 0
|
66 |
+
for i in range(len(all_prompts)):
|
67 |
+
tic = time.time()
|
68 |
+
text_qa.run_batch(
|
69 |
+
list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
|
70 |
+
)
|
71 |
+
tot_time += time.time() - tic
|
72 |
+
|
73 |
+
return tot_time
|
74 |
+
|
75 |
+
|
76 |
+
def test_batch_by_batch_with_hint(all_prompts, gen_len):
|
77 |
+
backend.flush_cache()
|
78 |
+
|
79 |
+
tot_time = 0
|
80 |
+
for i in range(len(all_prompts)):
|
81 |
+
tic = time.time()
|
82 |
+
# Send a hint to cache the prefix
|
83 |
+
text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
|
84 |
+
# Send the batch
|
85 |
+
text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))
|
86 |
+
|
87 |
+
tot_time += time.time() - tic
|
88 |
+
|
89 |
+
return tot_time
|
90 |
+
|
91 |
+
|
92 |
+
def test_send_all(all_prompts, gen_len):
|
93 |
+
backend.flush_cache()
|
94 |
+
|
95 |
+
all_prompts = [x for prompt_list in all_prompts for x in prompt_list]
|
96 |
+
|
97 |
+
tic = time.time()
|
98 |
+
text_qa.run_batch(
|
99 |
+
list(zip(all_prompts, [gen_len] * len(all_prompts))),
|
100 |
+
)
|
101 |
+
tot_time = time.time() - tic
|
102 |
+
|
103 |
+
return tot_time
|
104 |
+
|
105 |
+
|
106 |
+
if __name__ == "__main__":
|
107 |
+
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
108 |
+
backend = RuntimeEndpoint("http://127.0.0.1:30000")
|
109 |
+
set_default_backend(backend)
|
110 |
+
|
111 |
+
random.seed(0)
|
112 |
+
num_prefix = 10
|
113 |
+
num_samples_per_prefix = 32
|
114 |
+
prefix_length = 1024
|
115 |
+
suffix_length = 128
|
116 |
+
gen_len = 1
|
117 |
+
all_prompts, tot_input_len = prepare_prompts(
|
118 |
+
num_prefix, num_samples_per_prefix, prefix_length, suffix_length
|
119 |
+
)
|
120 |
+
|
121 |
+
print(f"Total input token length: {tot_input_len}\n")
|
122 |
+
|
123 |
+
cost = test_batch_by_batch(all_prompts, gen_len)
|
124 |
+
print(f"Latency of test_batch_by_batch : {cost:.4f} s\n")
|
125 |
+
|
126 |
+
cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
|
127 |
+
print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")
|
128 |
+
|
129 |
+
cost = test_send_all(all_prompts, gen_len)
|
130 |
+
print(f"Latency of test_send_all : {cost:.4f} s\n")
|
sglang/benchmark/benchmark_vllm_060/README.md
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0
|
2 |
+
|
3 |
+
In short, with multi step enabled, in online scenarios that we benchmarked, the Median TTFT of vLLM is **3 times** that of SGLang, and the Median ITL is **10 times** that of SGLang. Lower Median TTFT and ITL are better. vLLM's multi-step optimization did not improve throughput while ensuring lower Median TTFT and ITL. Also, under maximum throughput benchmark, if vLLM does not set gpu util to 0.95 separately and uses the default configuration instead, its maximum throughput is **lower** than that of SGLang.
|
4 |
+
|
5 |
+
## Online benchmark results
|
6 |
+
|
7 |
+
### Llama 3.1 8B Instruct 1 x A100 80G
|
8 |
+
|
9 |
+
| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|
10 |
+
|------|-------------|--------|--------------------|-------------|-------------|------------|
|
11 |
+
| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** |
|
12 |
+
| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** |
|
13 |
+
| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** |
|
14 |
+
| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** |
|
15 |
+
|
16 |
+
### Llama 3.1 70B Insruct 4 x H100 80G
|
17 |
+
|
18 |
+
| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|
19 |
+
|------|-------------|--------|--------------------|-------------|-------------|------------|
|
20 |
+
| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** |
|
21 |
+
| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** |
|
22 |
+
| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** |
|
23 |
+
| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** |
|
24 |
+
|
25 |
+
## Offline benchmark results
|
26 |
+
|
27 |
+
### Llama 3.1 8B Instruct 1 x A100 80G
|
28 |
+
|
29 |
+
| RPS | Num Prompts | Engine | Request throughput | Output token throughput |
|
30 |
+
|------|-------------|--------|--------------------|-------------------------|
|
31 |
+
| inf | 5000 | SGLang | 22.03 | **4281.51** |
|
32 |
+
| inf | 5000 | vLLM | 21.27 | **4132.37** |
|
33 |
+
|
34 |
+
### Llama 3.1 70B Insruct 4 x H100 80G
|
35 |
+
|
36 |
+
| RPS | Num Prompts | Engine | Request throughput | Output token throughput |
|
37 |
+
|------|-------------|--------|--------------------|-------------------------|
|
38 |
+
| inf | 5000 | SGLang | 19.84 | **3856.01** |
|
39 |
+
| inf | 5000 | vLLM | 19.04 | **3700.64** |
|
40 |
+
|
41 |
+
## Installation
|
42 |
+
|
43 |
+
```bash
|
44 |
+
# install sglang v0.3.0
|
45 |
+
pip install --upgrade pip
|
46 |
+
pip install "sglang[all]"==0.3.0
|
47 |
+
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
|
48 |
+
|
49 |
+
# install vllm v0.6.0
|
50 |
+
pip install vllm==0.6.0
|
51 |
+
```
|
52 |
+
|
53 |
+
## Notes
|
54 |
+
|
55 |
+
We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176, and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. The `gpu_memory_utilization` of vLLM is by default 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set it to 0.88 at TP 4.
|
56 |
+
|
57 |
+
## Online benchmarks
|
58 |
+
|
59 |
+
```bash
|
60 |
+
# Llama 3.1 8B Instruct on 1 x A100
|
61 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
62 |
+
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
|
63 |
+
|
64 |
+
# Llama 3.1 70B Instruct on 4 x H100
|
65 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4
|
66 |
+
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
|
67 |
+
|
68 |
+
# bench serving
|
69 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4
|
70 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8
|
71 |
+
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4
|
72 |
+
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8
|
73 |
+
```
|
74 |
+
|
75 |
+
## Offline benchmarks
|
76 |
+
|
77 |
+
```bash
|
78 |
+
# Llama 3.1 8B Instruct on 1 x A100
|
79 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
80 |
+
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
|
81 |
+
|
82 |
+
# Llama 3.1 70B Instruct on 4 x H100
|
83 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88
|
84 |
+
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
|
85 |
+
|
86 |
+
# bench serving
|
87 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000
|
88 |
+
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000
|
89 |
+
```
|
sglang/benchmark/blog_v0_2/README.md
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How to reproduce the benchmark results of SGLang
|
2 |
+
|
3 |
+
## Prerequisite
|
4 |
+
|
5 |
+
### Install the latest SGLang
|
6 |
+
|
7 |
+
```bash
|
8 |
+
git clone https://github.com/sgl-project/sglang.git
|
9 |
+
cd sglang
|
10 |
+
git checkout v0.2.7
|
11 |
+
|
12 |
+
pip install --upgrade pip
|
13 |
+
pip install -e "python[all]"
|
14 |
+
|
15 |
+
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
|
16 |
+
```
|
17 |
+
|
18 |
+
### Set up ulimit and HF_TOKEN
|
19 |
+
|
20 |
+
```bash
|
21 |
+
ulimit -n 65535
|
22 |
+
# Change the token to a real and usable one, with access permissions for the Llama 3 models.
|
23 |
+
export HF_TOKEN=hf_token
|
24 |
+
```
|
25 |
+
|
26 |
+
### Launch the server
|
27 |
+
|
28 |
+
```bash
|
29 |
+
# Meta-Llama-3.1-8B-Instruct
|
30 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
|
31 |
+
|
32 |
+
# Meta-Llama-3.1-70B-Instruct
|
33 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 8
|
34 |
+
|
35 |
+
# Meta-Llama-3-70B-Instruct-FP8
|
36 |
+
python -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8
|
37 |
+
```
|
38 |
+
|
39 |
+
## Benchmark
|
40 |
+
|
41 |
+
### Hardware Requirements
|
42 |
+
|
43 |
+
- 8B models: Single NVIDIA A100 80GB GPU
|
44 |
+
- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8
|
45 |
+
- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8
|
46 |
+
|
47 |
+
Please ensure you have the appropriate hardware before running the benchmarks.
|
48 |
+
|
49 |
+
#### Offline benchmark
|
50 |
+
|
51 |
+
```bash
|
52 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl
|
53 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl
|
54 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl
|
55 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline.jsonl
|
56 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl
|
57 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl
|
58 |
+
cat offline.jsonl | cut -d':' -f12 | cut -d',' -f1
|
59 |
+
```
|
60 |
+
|
61 |
+
#### Online benchmark
|
62 |
+
|
63 |
+
```bash
|
64 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl
|
65 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl
|
66 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl
|
67 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl
|
68 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl
|
69 |
+
cat online.jsonl | cut -d':' -f9 | cut -d',' -f1
|
70 |
+
```
|
71 |
+
|
72 |
+
## Other
|
73 |
+
|
74 |
+
We tried using vLLM 0.5.3.post1, but it often crashes under high loads, and it seems to have similar or worse performance compared to vLLM 0.5.2 from our partial benchmarking, so we are using the older version, vLLM 0.5.2.
|
75 |
+
|
76 |
+
Preparation for TensorRT LLM can refer to https://github.com/sgl-project/tensorrt-demo. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. The instance count for preprocessing and postprocessing in Triton Server is 16.
|
77 |
+
|
78 |
+
```bash
|
79 |
+
# vLLM
|
80 |
+
pip install vllm==0.5.2
|
81 |
+
pip install jsonschema==4.21.1
|
82 |
+
|
83 |
+
# Meta-Llama-3-8B-Instruct
|
84 |
+
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests
|
85 |
+
|
86 |
+
# meta-llama/Meta-Llama-3-70B-Instruct
|
87 |
+
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor 8
|
88 |
+
|
89 |
+
# neuralmagic/Meta-Llama-3-70B-Instruct-FP8
|
90 |
+
python -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor 8
|
91 |
+
```
|
92 |
+
|
93 |
+
```bash
|
94 |
+
wget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py
|
95 |
+
```
|
96 |
+
|
97 |
+
```bash
|
98 |
+
# vLLM Offline
|
99 |
+
|
100 |
+
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl
|
101 |
+
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl
|
102 |
+
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl
|
103 |
+
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl
|
104 |
+
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl
|
105 |
+
python3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl
|
106 |
+
cat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1
|
107 |
+
```
|
108 |
+
|
109 |
+
```bash
|
110 |
+
# vLLM Online
|
111 |
+
|
112 |
+
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl
|
113 |
+
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl
|
114 |
+
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl
|
115 |
+
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl
|
116 |
+
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl
|
117 |
+
cat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1
|
118 |
+
```
|
119 |
+
|
120 |
+
```bash
|
121 |
+
# TensorRT LLM Offline 8B
|
122 |
+
|
123 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl
|
124 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl
|
125 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl
|
126 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl
|
127 |
+
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct
|
128 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl
|
129 |
+
cat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1
|
130 |
+
```
|
131 |
+
|
132 |
+
```bash
|
133 |
+
# TensorRT LLM Online 8B
|
134 |
+
|
135 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl
|
136 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl
|
137 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl
|
138 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl
|
139 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl
|
140 |
+
cat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1
|
141 |
+
```
|
142 |
+
|
143 |
+
```bash
|
144 |
+
# TensorRT LLM Offline 70B
|
145 |
+
|
146 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl
|
147 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl
|
148 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl
|
149 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_70b.jsonl
|
150 |
+
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct
|
151 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl
|
152 |
+
cat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1
|
153 |
+
```
|
154 |
+
|
155 |
+
```bash
|
156 |
+
# TensorRT LLM Online 70B
|
157 |
+
|
158 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl
|
159 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl
|
160 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl
|
161 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl
|
162 |
+
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl
|
163 |
+
cat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1
|
164 |
+
```
|
sglang/benchmark/blog_v0_2/config.md
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### used for TensorRT LLM
|
2 |
+
|
3 |
+
```
|
4 |
+
{
|
5 |
+
"architecture": "LlamaForCausalLM",
|
6 |
+
"dtype": "float16",
|
7 |
+
"logits_dtype": "float32",
|
8 |
+
"vocab_size": 128256,
|
9 |
+
"max_position_embeddings": 8192,
|
10 |
+
"hidden_size": 16384,
|
11 |
+
"num_hidden_layers": 126,
|
12 |
+
"num_attention_heads": 128,
|
13 |
+
"num_key_value_heads": 16,
|
14 |
+
"head_size": 128,
|
15 |
+
"qk_layernorm": false,
|
16 |
+
"hidden_act": "silu",
|
17 |
+
"intermediate_size": 53248,
|
18 |
+
"norm_epsilon": 1e-05,
|
19 |
+
"position_embedding_type": "rope_gpt_neox",
|
20 |
+
"use_parallel_embedding": false,
|
21 |
+
"embedding_sharding_dim": 0,
|
22 |
+
"share_embedding_table": false,
|
23 |
+
"mapping": {
|
24 |
+
"world_size": 8,
|
25 |
+
"tp_size": 8,
|
26 |
+
"pp_size": 1,
|
27 |
+
"gpus_per_node": 8
|
28 |
+
},
|
29 |
+
"quantization": {
|
30 |
+
"quant_algo": "FP8",
|
31 |
+
"kv_cache_quant_algo": null,
|
32 |
+
"group_size": 128,
|
33 |
+
"smoothquant_val": null,
|
34 |
+
"has_zero_point": false,
|
35 |
+
"pre_quant_scale": false,
|
36 |
+
"exclude_modules": [
|
37 |
+
"lm_head"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
"kv_dtype": "float16",
|
41 |
+
"rotary_scaling": null,
|
42 |
+
"residual_mlp": false,
|
43 |
+
"moe_normalization_mode": null,
|
44 |
+
"rotary_base": 500000.0,
|
45 |
+
"moe_num_experts": 0,
|
46 |
+
"moe_top_k": 0,
|
47 |
+
"moe_tp_mode": 2,
|
48 |
+
"attn_bias": false,
|
49 |
+
"disable_weight_only_quant_plugin": false,
|
50 |
+
"mlp_bias": false
|
51 |
+
}
|
52 |
+
```
|
53 |
+
|
54 |
+
### used for vLLM and SGLang
|
55 |
+
|
56 |
+
```
|
57 |
+
{
|
58 |
+
"_name_or_path": "dummy_fp8",
|
59 |
+
"architectures": [
|
60 |
+
"LlamaForCausalLM"
|
61 |
+
],
|
62 |
+
"attention_bias": false,
|
63 |
+
"attention_dropout": 0.0,
|
64 |
+
"bos_token_id": 128000,
|
65 |
+
"eos_token_id": 128009,
|
66 |
+
"hidden_act": "silu",
|
67 |
+
"hidden_size": 16384,
|
68 |
+
"initializer_range": 0.02,
|
69 |
+
"intermediate_size": 53248,
|
70 |
+
"mlp_bias": false,
|
71 |
+
"model_type": "llama",
|
72 |
+
"num_attention_heads": 128,
|
73 |
+
"num_hidden_layers": 126,
|
74 |
+
"num_key_value_heads": 8,
|
75 |
+
"pretraining_tp": 1,
|
76 |
+
"quantization_config": {
|
77 |
+
"activation_scheme": "static",
|
78 |
+
"ignored_layers": [
|
79 |
+
"lm_head"
|
80 |
+
],
|
81 |
+
"quant_method": "fp8"
|
82 |
+
},
|
83 |
+
"rope_scaling": {
|
84 |
+
"factor": 8.0,
|
85 |
+
"low_freq_factor": 1.0,
|
86 |
+
"high_freq_factor": 4.0,
|
87 |
+
"original_max_position_embeddings": 8192,
|
88 |
+
"rope_type": "llama3"
|
89 |
+
},
|
90 |
+
"max_position_embeddings": 131072,
|
91 |
+
"rms_norm_eps": 1e-05,
|
92 |
+
"rope_scaling": null,
|
93 |
+
"rope_theta": 500000.0,
|
94 |
+
"tie_word_embeddings": false,
|
95 |
+
"torch_dtype": "bfloat16",
|
96 |
+
"transformers_version": "4.41.1",
|
97 |
+
"use_cache": true,
|
98 |
+
"vocab_size": 128256
|
99 |
+
}
|
100 |
+
```
|
sglang/benchmark/deepseek_v3/README.md
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DeepSeek V3 Support
|
2 |
+
|
3 |
+
The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended).
|
4 |
+
|
5 |
+
Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources.
|
6 |
+
|
7 |
+
## Hardware Recommendation
|
8 |
+
- 8 x NVIDIA H200 GPUs
|
9 |
+
|
10 |
+
If you do not have GPUs with large enough memory, please try multi-node tensor parallelism. There is an example serving with [2 H20 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) below.
|
11 |
+
|
12 |
+
## Installation & Launch
|
13 |
+
|
14 |
+
If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded.
|
15 |
+
|
16 |
+
### Using Docker (Recommended)
|
17 |
+
```bash
|
18 |
+
# Pull latest image
|
19 |
+
# https://hub.docker.com/r/lmsysorg/sglang/tags
|
20 |
+
docker pull lmsysorg/sglang:latest
|
21 |
+
|
22 |
+
# Launch
|
23 |
+
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host lmsysorg/sglang:latest \
|
24 |
+
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000
|
25 |
+
```
|
26 |
+
|
27 |
+
For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
|
28 |
+
|
29 |
+
### Using pip
|
30 |
+
```bash
|
31 |
+
# Installation
|
32 |
+
pip install "sglang[all]>=0.4.1.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer
|
33 |
+
|
34 |
+
# Launch
|
35 |
+
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
36 |
+
```
|
37 |
+
|
38 |
+
For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
|
39 |
+
|
40 |
+
### Example with OpenAI API
|
41 |
+
|
42 |
+
```python3
|
43 |
+
import openai
|
44 |
+
client = openai.Client(
|
45 |
+
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
46 |
+
|
47 |
+
# Chat completion
|
48 |
+
response = client.chat.completions.create(
|
49 |
+
model="default",
|
50 |
+
messages=[
|
51 |
+
{"role": "system", "content": "You are a helpful AI assistant"},
|
52 |
+
{"role": "user", "content": "List 3 countries and their capitals."},
|
53 |
+
],
|
54 |
+
temperature=0,
|
55 |
+
max_tokens=64,
|
56 |
+
)
|
57 |
+
print(response)
|
58 |
+
```
|
59 |
+
### Example serving with 2 H20*8
|
60 |
+
For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`.
|
61 |
+
|
62 |
+
```bash
|
63 |
+
# node 1
|
64 |
+
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
|
65 |
+
|
66 |
+
# node 2
|
67 |
+
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
|
68 |
+
```
|
69 |
+
|
70 |
+
If you have two H100 nodes, the usage is similar to the aforementioned H20.
|
71 |
+
|
72 |
+
### Example serving with Docker two H200*8 nodes
|
73 |
+
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`.
|
74 |
+
A single H200 with 8 devices can run DeepSeek V3, the dual H200 setup is just to demonstrate multi-node usage.
|
75 |
+
|
76 |
+
```bash
|
77 |
+
# node 1
|
78 |
+
docker run --gpus all \
|
79 |
+
--shm-size 32g \
|
80 |
+
--network=host \
|
81 |
+
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
82 |
+
--name sglang_multinode1 \
|
83 |
+
-it \
|
84 |
+
--rm \
|
85 |
+
--env "HF_TOKEN=$HF_TOKEN" \
|
86 |
+
--ipc=host \
|
87 |
+
lmsysorg/sglang:latest \
|
88 |
+
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000
|
89 |
+
```
|
90 |
+
|
91 |
+
```bash
|
92 |
+
# node 2
|
93 |
+
docker run --gpus all \
|
94 |
+
--shm-size 32g \
|
95 |
+
--network=host \
|
96 |
+
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
97 |
+
--name sglang_multinode2 \
|
98 |
+
-it \
|
99 |
+
--rm \
|
100 |
+
--env "HF_TOKEN=$HF_TOKEN" \
|
101 |
+
--ipc=host \
|
102 |
+
lmsysorg/sglang:latest \
|
103 |
+
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000
|
104 |
+
```
|
105 |
+
|
106 |
+
To ensure functionality, we include a test from a client Docker container.
|
107 |
+
```bash
|
108 |
+
docker run --gpus all \
|
109 |
+
--shm-size 32g \
|
110 |
+
--network=host \
|
111 |
+
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
112 |
+
--name sglang_multinode_client \
|
113 |
+
-it \
|
114 |
+
--rm \
|
115 |
+
--env "HF_TOKEN=$HF_TOKEN" \
|
116 |
+
--ipc=host \
|
117 |
+
lmsysorg/sglang:latest \
|
118 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file "deepseekv3_multinode.jsonl"
|
119 |
+
```
|
120 |
+
|
121 |
+
## DeepSeek V3 Optimization Plan
|
122 |
+
|
123 |
+
https://github.com/sgl-project/sglang/issues/2591
|
sglang/benchmark/generative_agents/README.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download the dataset
|
2 |
+
|
3 |
+
```
|
4 |
+
wget -O agent_calls.jsonl https://drive.google.com/uc?export=download&id=19qLpD45e9JGTKF2cUjJJegwzSUEZEKht
|
5 |
+
```
|
6 |
+
|
7 |
+
## Run benchmark
|
8 |
+
|
9 |
+
Ensure that this benchmark is run in a serial manner (using --parallel 1) to preserve any potential dependencies between requests.
|
10 |
+
|
11 |
+
### Benchmark sglang
|
12 |
+
```
|
13 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
14 |
+
```
|
15 |
+
|
16 |
+
```
|
17 |
+
python3 bench_sglang.py --num-events 1000 --parallel 1
|
18 |
+
```
|
19 |
+
|
20 |
+
### Benchmark vllm
|
21 |
+
```
|
22 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
23 |
+
```
|
24 |
+
|
25 |
+
```
|
26 |
+
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
|
27 |
+
```
|
28 |
+
|
29 |
+
### Benchmark guidance
|
30 |
+
```
|
31 |
+
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
32 |
+
```
|
33 |
+
|
34 |
+
### Benchmark lmql
|
35 |
+
|
36 |
+
```
|
37 |
+
python3 bench_other.py --num-events 1000 --backend lmql --parallel 1
|
38 |
+
```
|
sglang/benchmark/generative_agents/agent_functions.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sglang as sgl
|
2 |
+
|
3 |
+
# here are the top five agent functions contributing ~70% LLM calls
|
4 |
+
# reference: https://github.com/joonspk-research/generative_agents/
|
5 |
+
|
6 |
+
|
7 |
+
@sgl.function
|
8 |
+
def poignancy_event(s, persona_name, persona_iss, event):
|
9 |
+
s += "Here is a brief description of " + persona_name + ".\n"
|
10 |
+
s += persona_iss + "\n"
|
11 |
+
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
|
12 |
+
s += persona_name + ".\n\n"
|
13 |
+
s += "Event: " + event
|
14 |
+
s += "Rate (return a number between 1 to 10):"
|
15 |
+
s += sgl.gen(name="Rate", max_tokens=2)
|
16 |
+
|
17 |
+
|
18 |
+
def poignancy_event_prompt(persona_name, persona_iss, event):
|
19 |
+
# return prompt and max_tokens
|
20 |
+
s = ""
|
21 |
+
s += "Here is a brief description of " + persona_name + ".\n"
|
22 |
+
s += persona_iss + "\n"
|
23 |
+
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
|
24 |
+
s += persona_name + ".\n\n"
|
25 |
+
s += "Event: " + event
|
26 |
+
s += "Rate (return a number between 1 to 10):"
|
27 |
+
return {"prompt": s, "max_tokens": 2, "stop": None}
|
28 |
+
|
29 |
+
|
30 |
+
@sgl.function
|
31 |
+
def generate_event_triple(s, persona_name, action):
|
32 |
+
s += """Task: Turn the input into (subject, predicate, object).
|
33 |
+
Input: Sam Johnson is eating breakfast.
|
34 |
+
Output: (Dolores Murphy, eat, breakfast)
|
35 |
+
---
|
36 |
+
Input: Joon Park is brewing coffee.
|
37 |
+
Output: (Joon Park, brew, coffee)
|
38 |
+
---
|
39 |
+
Input: Jane Cook is sleeping.
|
40 |
+
Output: (Jane Cook, is, sleep)
|
41 |
+
---
|
42 |
+
Input: Michael Bernstein is writing email on a computer.
|
43 |
+
Output: (Michael Bernstein, write, email)
|
44 |
+
---
|
45 |
+
Input: Percy Liang is teaching students in a classroom.
|
46 |
+
Output: (Percy Liang, teach, students)
|
47 |
+
---
|
48 |
+
Input: Merrie Morris is running on a treadmill.
|
49 |
+
Output: (Merrie Morris, run, treadmill)
|
50 |
+
---"""
|
51 |
+
s += persona_name + "is" + action + ".\n"
|
52 |
+
s += "(" + persona_name + ","
|
53 |
+
s += sgl.gen(name="Triple", max_tokens=20, stop=")")
|
54 |
+
|
55 |
+
|
56 |
+
def generate_event_triple_prompt(persona_name, action):
|
57 |
+
s = ""
|
58 |
+
s += """Task: Turn the input into (subject, predicate, object).
|
59 |
+
Input: Sam Johnson is eating breakfast.
|
60 |
+
Output: (Dolores Murphy, eat, breakfast)
|
61 |
+
---
|
62 |
+
Input: Joon Park is brewing coffee.
|
63 |
+
Output: (Joon Park, brew, coffee)
|
64 |
+
---
|
65 |
+
Input: Jane Cook is sleeping.
|
66 |
+
Output: (Jane Cook, is, sleep)
|
67 |
+
---
|
68 |
+
Input: Michael Bernstein is writing email on a computer.
|
69 |
+
Output: (Michael Bernstein, write, email)
|
70 |
+
---
|
71 |
+
Input: Percy Liang is teaching students in a classroom.
|
72 |
+
Output: (Percy Liang, teach, students)
|
73 |
+
---
|
74 |
+
Input: Merrie Morris is running on a treadmill.
|
75 |
+
Output: (Merrie Morris, run, treadmill)
|
76 |
+
---"""
|
77 |
+
s += persona_name + "is" + action + ".\n"
|
78 |
+
s += "(" + persona_name + ","
|
79 |
+
return {"prompt": s, "max_tokens": 20, "stop": ")"}
|
80 |
+
|
81 |
+
|
82 |
+
@sgl.function
|
83 |
+
def generate_pronunciatio(s, action):
|
84 |
+
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
|
85 |
+
s += "Action description: " + action + ".\n"
|
86 |
+
s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)
|
87 |
+
|
88 |
+
|
89 |
+
def generate_pronunciatio_prompt(action):
|
90 |
+
s = ""
|
91 |
+
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
|
92 |
+
s += "Action description: " + action + ".\n"
|
93 |
+
s += "Emoji:"
|
94 |
+
return {"prompt": s, "max_tokens": 6, "stop": None}
|
95 |
+
|
96 |
+
|
97 |
+
@sgl.function
|
98 |
+
def action_location_sector(
|
99 |
+
s,
|
100 |
+
persona_name,
|
101 |
+
living_sector,
|
102 |
+
living_sector_areas,
|
103 |
+
current_sector,
|
104 |
+
current_sector_areas,
|
105 |
+
daily_plan,
|
106 |
+
sector_options,
|
107 |
+
current_action,
|
108 |
+
next_action,
|
109 |
+
):
|
110 |
+
s += """Task -- choose an appropriate area from the area options for a task at hand.
|
111 |
+
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
112 |
+
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
113 |
+
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
114 |
+
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
115 |
+
* Must be one of the "Area options," verbatim.
|
116 |
+
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
|
117 |
+
---
|
118 |
+
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
|
119 |
+
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
|
120 |
+
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
121 |
+
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
122 |
+
* Must be one of the "Area options," verbatim.
|
123 |
+
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
124 |
+
---"""
|
125 |
+
s += (
|
126 |
+
persona_name
|
127 |
+
+ " lives in "
|
128 |
+
+ living_sector
|
129 |
+
+ " that has "
|
130 |
+
+ living_sector_areas
|
131 |
+
+ ".\n"
|
132 |
+
)
|
133 |
+
s += (
|
134 |
+
persona_name
|
135 |
+
+ " is currently in "
|
136 |
+
+ current_sector
|
137 |
+
+ " that has "
|
138 |
+
+ current_sector_areas
|
139 |
+
+ ".\n"
|
140 |
+
)
|
141 |
+
s += daily_plan + ".\n"
|
142 |
+
s += "Area options: " + sector_options + ".\n"
|
143 |
+
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
144 |
+
* Must be one of the "Area options," verbatim.\n"""
|
145 |
+
s += (
|
146 |
+
persona_name
|
147 |
+
+ " is "
|
148 |
+
+ current_action
|
149 |
+
+ ". For "
|
150 |
+
+ next_action
|
151 |
+
+ ", "
|
152 |
+
+ persona_name
|
153 |
+
+ " should go to the following area: {"
|
154 |
+
)
|
155 |
+
s += sgl.gen(name="Location", max_tokens=10, stop="}")
|
156 |
+
|
157 |
+
|
158 |
+
def action_location_sector_prompt(
|
159 |
+
persona_name,
|
160 |
+
living_sector,
|
161 |
+
living_sector_areas,
|
162 |
+
current_sector,
|
163 |
+
current_sector_areas,
|
164 |
+
daily_plan,
|
165 |
+
sector_options,
|
166 |
+
current_action,
|
167 |
+
next_action,
|
168 |
+
):
|
169 |
+
s = ""
|
170 |
+
s += """Task -- choose an appropriate area from the area options for a task at hand.
|
171 |
+
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
172 |
+
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
|
173 |
+
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
174 |
+
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
175 |
+
* Must be one of the "Area options," verbatim.
|
176 |
+
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
|
177 |
+
---
|
178 |
+
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
|
179 |
+
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
|
180 |
+
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
|
181 |
+
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
182 |
+
* Must be one of the "Area options," verbatim.
|
183 |
+
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
184 |
+
---"""
|
185 |
+
s += (
|
186 |
+
persona_name
|
187 |
+
+ " lives in "
|
188 |
+
+ living_sector
|
189 |
+
+ " that has "
|
190 |
+
+ living_sector_areas
|
191 |
+
+ ".\n"
|
192 |
+
)
|
193 |
+
s += (
|
194 |
+
persona_name
|
195 |
+
+ " is currently in "
|
196 |
+
+ current_sector
|
197 |
+
+ " that has "
|
198 |
+
+ current_sector_areas
|
199 |
+
+ ".\n"
|
200 |
+
)
|
201 |
+
s += daily_plan + ".\n"
|
202 |
+
s += "Area options: " + sector_options + ".\n"
|
203 |
+
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
204 |
+
* Must be one of the "Area options," verbatim.\n"""
|
205 |
+
s += (
|
206 |
+
persona_name
|
207 |
+
+ " is "
|
208 |
+
+ current_action
|
209 |
+
+ ". For "
|
210 |
+
+ next_action
|
211 |
+
+ ", "
|
212 |
+
+ persona_name
|
213 |
+
+ " should go to the following area: {"
|
214 |
+
)
|
215 |
+
return {"prompt": s, "max_tokens": 10, "stop": "}"}
|
216 |
+
|
217 |
+
|
218 |
+
@sgl.function
|
219 |
+
def action_location_object(
|
220 |
+
s, persona_name, target_sector, target_sector_areas, current_action, next_action
|
221 |
+
):
|
222 |
+
s += """
|
223 |
+
Jane Anderson is in kitchen in Jane Anderson's house.
|
224 |
+
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
225 |
+
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
226 |
+
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
|
227 |
+
Answer: {kitchen}
|
228 |
+
---
|
229 |
+
Tom Watson is in common room in Tom Watson's apartment.
|
230 |
+
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
|
231 |
+
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
232 |
+
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
233 |
+
Answer: {cafe}
|
234 |
+
---"""
|
235 |
+
s += (
|
236 |
+
persona_name
|
237 |
+
+ " is going to "
|
238 |
+
+ target_sector
|
239 |
+
+ " that has the following areas: {"
|
240 |
+
+ target_sector_areas
|
241 |
+
+ "}\n"
|
242 |
+
)
|
243 |
+
s += """* Stay in the current area if the activity can be done there.
|
244 |
+
* NEVER go into other people's rooms unless necessary."""
|
245 |
+
s += (
|
246 |
+
persona_name
|
247 |
+
+ " is "
|
248 |
+
+ current_action
|
249 |
+
+ ". For "
|
250 |
+
+ next_action
|
251 |
+
+ ", "
|
252 |
+
+ persona_name
|
253 |
+
+ "should go to the following area in "
|
254 |
+
+ target_sector
|
255 |
+
)
|
256 |
+
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
257 |
+
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
|
258 |
+
|
259 |
+
|
260 |
+
def action_location_object_prompt(
|
261 |
+
persona_name, target_sector, target_sector_areas, current_action, next_action
|
262 |
+
):
|
263 |
+
s = ""
|
264 |
+
s += """
|
265 |
+
Jane Anderson is in kitchen in Jane Anderson's house.
|
266 |
+
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
267 |
+
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
268 |
+
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
|
269 |
+
Answer: {kitchen}
|
270 |
+
---
|
271 |
+
Tom Watson is in common room in Tom Watson's apartment.
|
272 |
+
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
|
273 |
+
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
|
274 |
+
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
275 |
+
Answer: {cafe}
|
276 |
+
---"""
|
277 |
+
s += (
|
278 |
+
persona_name
|
279 |
+
+ " is going to "
|
280 |
+
+ target_sector
|
281 |
+
+ " that has the following areas: {"
|
282 |
+
+ target_sector_areas
|
283 |
+
+ "}\n"
|
284 |
+
)
|
285 |
+
s += """* Stay in the current area if the activity can be done there.
|
286 |
+
* NEVER go into other people's rooms unless necessary."""
|
287 |
+
s += (
|
288 |
+
persona_name
|
289 |
+
+ " is "
|
290 |
+
+ current_action
|
291 |
+
+ ". For "
|
292 |
+
+ next_action
|
293 |
+
+ ", "
|
294 |
+
+ persona_name
|
295 |
+
+ "should go to the following area in "
|
296 |
+
+ target_sector
|
297 |
+
)
|
298 |
+
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
299 |
+
s += "Answer: {"
|
300 |
+
return {"prompt": s, "max_tokens": 5, "stop": "}"}
|
sglang/benchmark/generative_agents/bench_other.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
from agent_functions import (
|
6 |
+
action_location_object_prompt,
|
7 |
+
action_location_sector_prompt,
|
8 |
+
generate_event_triple_prompt,
|
9 |
+
generate_pronunciatio_prompt,
|
10 |
+
poignancy_event_prompt,
|
11 |
+
)
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
15 |
+
from sglang.utils import dump_state_text, read_jsonl
|
16 |
+
|
17 |
+
|
18 |
+
def main(args):
|
19 |
+
lines = read_jsonl(args.data_path)[: args.num_events]
|
20 |
+
mapping = {
|
21 |
+
"poignancy_event": poignancy_event_prompt,
|
22 |
+
"generate_event_triple": generate_event_triple_prompt,
|
23 |
+
"generate_pronunciatio": generate_pronunciatio_prompt,
|
24 |
+
"action_location_sector": action_location_sector_prompt,
|
25 |
+
"action_location_object": action_location_object_prompt,
|
26 |
+
}
|
27 |
+
|
28 |
+
arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
|
29 |
+
states = []
|
30 |
+
|
31 |
+
# Select backend
|
32 |
+
call_generate = get_call_generate(args)
|
33 |
+
|
34 |
+
def get_one_answer(arg):
|
35 |
+
answer = call_generate(**arg, temperature=0)
|
36 |
+
states.append(answer)
|
37 |
+
|
38 |
+
async def get_one_answer_async(arg):
|
39 |
+
answer = await call_generate(**arg, temperature=0)
|
40 |
+
states.append(answer)
|
41 |
+
|
42 |
+
tic = time.time()
|
43 |
+
# we always sequentially execute agent calls to maintain its dependency
|
44 |
+
if args.backend != "lmql":
|
45 |
+
for arg in tqdm(arguments):
|
46 |
+
get_one_answer(arg)
|
47 |
+
else:
|
48 |
+
import asyncio
|
49 |
+
|
50 |
+
loop = asyncio.get_event_loop()
|
51 |
+
for arg in tqdm(arguments):
|
52 |
+
loop.run_until_complete(get_one_answer_async(arg))
|
53 |
+
latency = time.time() - tic
|
54 |
+
|
55 |
+
print(f"Latency: {latency:.3f}")
|
56 |
+
|
57 |
+
# Write results
|
58 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
59 |
+
|
60 |
+
with open(args.result_file, "a") as fout:
|
61 |
+
value = {
|
62 |
+
"task": "Generative Agents",
|
63 |
+
"backend": args.backend,
|
64 |
+
"num_gpus": 1,
|
65 |
+
"latency": round(latency, 3),
|
66 |
+
# to pack weighted functions as a single agent
|
67 |
+
"num_requests": len(arguments) / len(mapping),
|
68 |
+
"other": {
|
69 |
+
"parallel": args.parallel,
|
70 |
+
},
|
71 |
+
}
|
72 |
+
fout.write(json.dumps(value) + "\n")
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
parser = argparse.ArgumentParser()
|
77 |
+
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
|
78 |
+
parser.add_argument("--num-events", type=int, default=10)
|
79 |
+
args = add_common_other_args_and_parse(parser)
|
80 |
+
main(args)
|
sglang/benchmark/generative_agents/bench_sglang.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
from agent_functions import (
|
6 |
+
action_location_object,
|
7 |
+
action_location_sector,
|
8 |
+
generate_event_triple,
|
9 |
+
generate_pronunciatio,
|
10 |
+
poignancy_event,
|
11 |
+
)
|
12 |
+
|
13 |
+
import sglang as sgl
|
14 |
+
from sglang.test.test_utils import (
|
15 |
+
add_common_sglang_args_and_parse,
|
16 |
+
select_sglang_backend,
|
17 |
+
)
|
18 |
+
from sglang.utils import dump_state_text, read_jsonl
|
19 |
+
|
20 |
+
|
21 |
+
def main(args):
|
22 |
+
lines = read_jsonl(args.data_path)[: args.num_events]
|
23 |
+
mapping = {
|
24 |
+
"poignancy_event": poignancy_event,
|
25 |
+
"generate_event_triple": generate_event_triple,
|
26 |
+
"generate_pronunciatio": generate_pronunciatio,
|
27 |
+
"action_location_sector": action_location_sector,
|
28 |
+
"action_location_object": action_location_object,
|
29 |
+
}
|
30 |
+
arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]
|
31 |
+
|
32 |
+
# Select backend
|
33 |
+
backend = select_sglang_backend(args)
|
34 |
+
sgl.set_default_backend(backend)
|
35 |
+
|
36 |
+
states = []
|
37 |
+
# Run requests
|
38 |
+
tic = time.time()
|
39 |
+
for a in arguments:
|
40 |
+
# only a single key in the dict
|
41 |
+
for func, arg in a.items():
|
42 |
+
result = func.run(**arg)
|
43 |
+
result.sync()
|
44 |
+
states.append(result)
|
45 |
+
latency = time.time() - tic
|
46 |
+
|
47 |
+
# Compute accuracy
|
48 |
+
print(f"Latency: {latency:.3f}")
|
49 |
+
|
50 |
+
# Write results
|
51 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
52 |
+
|
53 |
+
with open(args.result_file, "a") as fout:
|
54 |
+
value = {
|
55 |
+
"task": "Generative Agents",
|
56 |
+
"backend": args.backend,
|
57 |
+
"num_gpus": 1,
|
58 |
+
"latency": round(latency, 3),
|
59 |
+
# to pack weighted functions as a single agent
|
60 |
+
"num_requests": len(arguments) / len(mapping),
|
61 |
+
"other": {
|
62 |
+
"num_events": args.num_events,
|
63 |
+
"parallel": args.parallel,
|
64 |
+
},
|
65 |
+
}
|
66 |
+
fout.write(json.dumps(value) + "\n")
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
parser = argparse.ArgumentParser()
|
71 |
+
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
|
72 |
+
parser.add_argument("--num-events", type=int, default=10)
|
73 |
+
args = add_common_sglang_args_and_parse(parser)
|
74 |
+
main(args)
|
sglang/benchmark/hellaswag/bench_sglang.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from sglang.api import set_default_backend
|
8 |
+
from sglang.test.test_utils import (
|
9 |
+
add_common_sglang_args_and_parse,
|
10 |
+
select_sglang_backend,
|
11 |
+
)
|
12 |
+
from sglang.utils import download_and_cache_file, read_jsonl
|
13 |
+
|
14 |
+
|
15 |
+
def get_one_example(lines, i, include_answer):
|
16 |
+
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
17 |
+
if include_answer:
|
18 |
+
ret += lines[i]["endings"][lines[i]["label"]]
|
19 |
+
return ret
|
20 |
+
|
21 |
+
|
22 |
+
def get_few_shot_examples(lines, k):
|
23 |
+
ret = ""
|
24 |
+
for i in range(k):
|
25 |
+
ret += get_one_example(lines, i, True) + "\n\n"
|
26 |
+
return ret
|
27 |
+
|
28 |
+
|
29 |
+
def main(args):
|
30 |
+
# Select backend
|
31 |
+
set_default_backend(select_sglang_backend(args))
|
32 |
+
|
33 |
+
# Read data
|
34 |
+
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
35 |
+
filename = download_and_cache_file(url)
|
36 |
+
lines = list(read_jsonl(filename))
|
37 |
+
|
38 |
+
# Construct prompts
|
39 |
+
num_questions = args.num_questions
|
40 |
+
num_shots = args.num_shots
|
41 |
+
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
42 |
+
|
43 |
+
questions = []
|
44 |
+
choices = []
|
45 |
+
labels = []
|
46 |
+
for i in range(len(lines[:num_questions])):
|
47 |
+
questions.append(get_one_example(lines, i, False))
|
48 |
+
choices.append(lines[i]["endings"])
|
49 |
+
labels.append(lines[i]["label"])
|
50 |
+
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
|
51 |
+
|
52 |
+
#####################################
|
53 |
+
######### SGL Program Begin #########
|
54 |
+
#####################################
|
55 |
+
|
56 |
+
import sglang as sgl
|
57 |
+
|
58 |
+
@sgl.function
|
59 |
+
def few_shot_hellaswag(s, question, choices):
|
60 |
+
s += few_shot_examples + question
|
61 |
+
s += sgl.select("answer", choices=choices)
|
62 |
+
|
63 |
+
#####################################
|
64 |
+
########## SGL Program End ##########
|
65 |
+
#####################################
|
66 |
+
|
67 |
+
# Run requests
|
68 |
+
tic = time.time()
|
69 |
+
rets = few_shot_hellaswag.run_batch(
|
70 |
+
arguments,
|
71 |
+
temperature=0,
|
72 |
+
num_threads=args.parallel,
|
73 |
+
progress_bar=True,
|
74 |
+
)
|
75 |
+
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
76 |
+
latency = time.time() - tic
|
77 |
+
|
78 |
+
# Compute accuracy
|
79 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
80 |
+
print(f"Latency: {latency:.3f}")
|
81 |
+
print(f"Accuracy: {acc:.3f}")
|
82 |
+
|
83 |
+
# Write results
|
84 |
+
with open(args.result_file, "a") as fout:
|
85 |
+
value = {
|
86 |
+
"task": "hellaswag",
|
87 |
+
"backend": args.backend,
|
88 |
+
"num_gpus": 1,
|
89 |
+
"latency": round(latency, 3),
|
90 |
+
"accuracy": round(acc, 3),
|
91 |
+
"num_requests": args.num_questions,
|
92 |
+
"other": {
|
93 |
+
"num_questions": args.num_questions,
|
94 |
+
"parallel": args.parallel,
|
95 |
+
},
|
96 |
+
}
|
97 |
+
fout.write(json.dumps(value) + "\n")
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == "__main__":
|
101 |
+
parser = argparse.ArgumentParser()
|
102 |
+
parser.add_argument("--num-shots", type=int, default=20)
|
103 |
+
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
|
104 |
+
parser.add_argument("--num-questions", type=int, default=200)
|
105 |
+
args = add_common_sglang_args_and_parse(parser)
|
106 |
+
main(args)
|
sglang/benchmark/json_decode_regex/README.md
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Build dataset
|
4 |
+
```
|
5 |
+
pip install wikipedia
|
6 |
+
python3 build_dataset.py
|
7 |
+
```
|
8 |
+
|
9 |
+
### Dependencies
|
10 |
+
|
11 |
+
```
|
12 |
+
llama_cpp_python 0.2.19
|
13 |
+
guidance 0.1.10
|
14 |
+
vllm 0.2.5
|
15 |
+
outlines 0.0.22
|
16 |
+
```
|
17 |
+
|
18 |
+
### Benchmark sglang
|
19 |
+
|
20 |
+
Run Llama-7B
|
21 |
+
|
22 |
+
```
|
23 |
+
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
24 |
+
```
|
25 |
+
|
26 |
+
Run Mixtral-8x7B
|
27 |
+
|
28 |
+
```
|
29 |
+
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
|
30 |
+
```
|
31 |
+
|
32 |
+
Benchmark
|
33 |
+
|
34 |
+
```
|
35 |
+
python3 bench_sglang.py --num-questions 10
|
36 |
+
```
|
37 |
+
|
38 |
+
|
39 |
+
### Benchmark Outlines + vLLM
|
40 |
+
|
41 |
+
Run Llama-7B
|
42 |
+
|
43 |
+
```
|
44 |
+
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
45 |
+
```
|
46 |
+
|
47 |
+
Benchmark
|
48 |
+
|
49 |
+
```
|
50 |
+
python3 bench_other.py --backend outlines --num-questions 10
|
51 |
+
```
|
52 |
+
|
53 |
+
|
54 |
+
### Benchmark guidance
|
55 |
+
|
56 |
+
Run Llama-7B and benchmark
|
57 |
+
|
58 |
+
```
|
59 |
+
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
60 |
+
```
|
sglang/benchmark/json_decode_regex/bench_other.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ThreadPoolExecutor
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
10 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
11 |
+
from sglang.utils import dump_state_text, read_jsonl
|
12 |
+
|
13 |
+
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
14 |
+
|
15 |
+
|
16 |
+
# fmt: off
|
17 |
+
def json_decode(document, generate):
|
18 |
+
s = "Please extract the information of a city from the following wikipedia page.\n"
|
19 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
20 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
21 |
+
s += "{\n"
|
22 |
+
s += ' "name": '
|
23 |
+
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
24 |
+
s += ' "country": '
|
25 |
+
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
26 |
+
s += ' "latitude": '
|
27 |
+
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
28 |
+
s += ' "population": '
|
29 |
+
s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
30 |
+
s += ' "top 3 landmarks": '
|
31 |
+
s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
|
32 |
+
s += "}\n"
|
33 |
+
|
34 |
+
return s
|
35 |
+
# fmt: on
|
36 |
+
|
37 |
+
|
38 |
+
def main(args):
|
39 |
+
lines = read_jsonl(args.data_path)
|
40 |
+
arguments = []
|
41 |
+
for i in range(len(lines[: args.num_questions])):
|
42 |
+
arguments.append(
|
43 |
+
{
|
44 |
+
"document": lines[i]["document"],
|
45 |
+
}
|
46 |
+
)
|
47 |
+
states = [None] * len(arguments)
|
48 |
+
|
49 |
+
# Select backend
|
50 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
51 |
+
|
52 |
+
# Run requests
|
53 |
+
def get_one_answer(i):
|
54 |
+
states[i] = json_decode(generate=call_generate, **arguments[i])
|
55 |
+
|
56 |
+
tic = time.time()
|
57 |
+
if args.parallel == 1:
|
58 |
+
for i in tqdm(range(len(arguments))):
|
59 |
+
get_one_answer(i)
|
60 |
+
else:
|
61 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
62 |
+
rets = list(
|
63 |
+
tqdm(
|
64 |
+
executor.map(get_one_answer, list(range(len(arguments)))),
|
65 |
+
total=len(arguments),
|
66 |
+
)
|
67 |
+
)
|
68 |
+
for _ in rets:
|
69 |
+
pass
|
70 |
+
|
71 |
+
latency = time.time() - tic
|
72 |
+
|
73 |
+
# Compute accuracy
|
74 |
+
print(f"Latency: {latency:.3f}")
|
75 |
+
|
76 |
+
# Write results
|
77 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
78 |
+
|
79 |
+
with open(args.result_file, "a") as fout:
|
80 |
+
value = {
|
81 |
+
"task": "json_decode_regex",
|
82 |
+
"backend": args.backend,
|
83 |
+
"num_gpus": 1,
|
84 |
+
"latency": round(latency, 3),
|
85 |
+
"num_requests": args.num_questions,
|
86 |
+
"other": {
|
87 |
+
"parallel": args.parallel,
|
88 |
+
},
|
89 |
+
}
|
90 |
+
fout.write(json.dumps(value) + "\n")
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
parser = argparse.ArgumentParser()
|
95 |
+
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
96 |
+
parser.add_argument("--num-questions", type=int, default=20)
|
97 |
+
args = add_common_other_args_and_parse(parser)
|
98 |
+
main(args)
|
sglang/benchmark/json_decode_regex/bench_sglang.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
7 |
+
from sglang.test.test_utils import (
|
8 |
+
add_common_sglang_args_and_parse,
|
9 |
+
select_sglang_backend,
|
10 |
+
)
|
11 |
+
from sglang.utils import dump_state_text, read_jsonl
|
12 |
+
|
13 |
+
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
14 |
+
|
15 |
+
# fmt: off
|
16 |
+
@sgl.function
|
17 |
+
def json_warm_up(s):
|
18 |
+
s += "The information about Hogwarts is in the following JSON format.\n"
|
19 |
+
with s.var_scope("json_output"):
|
20 |
+
s += "{\n"
|
21 |
+
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
22 |
+
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
23 |
+
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
24 |
+
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
25 |
+
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
26 |
+
s += "}\n"
|
27 |
+
print(f'The warmp up json result is:\n{s["json_output"]}')
|
28 |
+
# fmt: on
|
29 |
+
|
30 |
+
# fmt: off
|
31 |
+
@sgl.function
|
32 |
+
def json_decode(s, document):
|
33 |
+
s += "Please extract the information of a city from the following wikipedia page.\n"
|
34 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
35 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
36 |
+
with s.var_scope("json_output"):
|
37 |
+
s += "{\n"
|
38 |
+
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
39 |
+
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
40 |
+
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
41 |
+
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
42 |
+
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
43 |
+
s += "}\n"
|
44 |
+
# fmt: on
|
45 |
+
|
46 |
+
|
47 |
+
def main(args):
|
48 |
+
lines = read_jsonl(args.data_path)
|
49 |
+
arguments = []
|
50 |
+
for i in range(len(lines[: args.num_questions])):
|
51 |
+
arguments.append(
|
52 |
+
{
|
53 |
+
"document": lines[i]["document"],
|
54 |
+
}
|
55 |
+
)
|
56 |
+
|
57 |
+
# Select backend
|
58 |
+
backend = select_sglang_backend(args)
|
59 |
+
sgl.set_default_backend(backend)
|
60 |
+
|
61 |
+
# Warm up
|
62 |
+
json_warm_up.run().sync()
|
63 |
+
|
64 |
+
# Run requests
|
65 |
+
tic = time.time()
|
66 |
+
states = json_decode.run_batch(
|
67 |
+
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
68 |
+
)
|
69 |
+
latency = time.time() - tic
|
70 |
+
|
71 |
+
# Compute accuracy
|
72 |
+
print(f"Latency: {latency:.3f}")
|
73 |
+
|
74 |
+
# Write results
|
75 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
76 |
+
|
77 |
+
with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
|
78 |
+
for state in states:
|
79 |
+
fout.write(state["json_output"] + "\n")
|
80 |
+
|
81 |
+
with open(args.result_file, "a") as fout:
|
82 |
+
value = {
|
83 |
+
"task": "json_decode_regex",
|
84 |
+
"backend": args.backend,
|
85 |
+
"num_gpus": 1,
|
86 |
+
"latency": round(latency, 3),
|
87 |
+
"num_requests": args.num_questions,
|
88 |
+
"other": {
|
89 |
+
"parallel": args.parallel,
|
90 |
+
},
|
91 |
+
}
|
92 |
+
fout.write(json.dumps(value) + "\n")
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
parser = argparse.ArgumentParser()
|
97 |
+
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
98 |
+
parser.add_argument("--num-questions", type=int, default=20)
|
99 |
+
args = add_common_sglang_args_and_parse(parser)
|
100 |
+
main(args)
|
sglang/benchmark/json_decode_regex/build_dataset.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import transformers
|
4 |
+
import wikipedia
|
5 |
+
|
6 |
+
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
7 |
+
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
8 |
+
city_names = [
|
9 |
+
"los angles",
|
10 |
+
"london",
|
11 |
+
"tokyo",
|
12 |
+
"beijing",
|
13 |
+
"singapore",
|
14 |
+
"paris",
|
15 |
+
"dubai",
|
16 |
+
"sydney",
|
17 |
+
"moscow",
|
18 |
+
"rome",
|
19 |
+
"toronto",
|
20 |
+
"rio de janeiro",
|
21 |
+
"istanbul",
|
22 |
+
"berlin",
|
23 |
+
"auckland",
|
24 |
+
"buenos aires",
|
25 |
+
"mexico city",
|
26 |
+
"mumbai",
|
27 |
+
"seoul",
|
28 |
+
"bangkok",
|
29 |
+
"cairo",
|
30 |
+
"athens",
|
31 |
+
"jerusalem",
|
32 |
+
]
|
33 |
+
|
34 |
+
|
35 |
+
def get_content(city_name):
|
36 |
+
content = str(wikipedia.page(city_name).content)
|
37 |
+
content = content.replace("\n\n", "\n")
|
38 |
+
|
39 |
+
tokens = t.encode(content)
|
40 |
+
|
41 |
+
expected_tokens = 3000
|
42 |
+
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
43 |
+
truncate_content = content[:truncate_len]
|
44 |
+
truncate_tokens = t.encode(truncate_content)
|
45 |
+
|
46 |
+
# Count token
|
47 |
+
print(
|
48 |
+
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
49 |
+
)
|
50 |
+
|
51 |
+
return truncate_content
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
with open("questions.jsonl", "w") as fout:
|
56 |
+
for city_name in city_names:
|
57 |
+
truncate_content = get_content(city_name)
|
58 |
+
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
sglang/benchmark/json_jump_forward/README.md
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Dependencies
|
4 |
+
|
5 |
+
```
|
6 |
+
llama_cpp_python 0.2.38
|
7 |
+
guidance 0.1.10
|
8 |
+
vllm 0.2.7
|
9 |
+
outlines 0.0.25
|
10 |
+
```
|
11 |
+
|
12 |
+
### Build dataset
|
13 |
+
|
14 |
+
When benchmarking long document information retrieval, run the following command to build the dataset:
|
15 |
+
|
16 |
+
```bash
|
17 |
+
pip install wikipedia
|
18 |
+
python3 build_dataset.py
|
19 |
+
```
|
20 |
+
|
21 |
+
### Benchmark sglang
|
22 |
+
|
23 |
+
Run Llama-7B
|
24 |
+
|
25 |
+
```bash
|
26 |
+
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
27 |
+
```
|
28 |
+
|
29 |
+
Benchmark Character Generation
|
30 |
+
|
31 |
+
```bash
|
32 |
+
python3 bench_sglang.py --mode character
|
33 |
+
```
|
34 |
+
|
35 |
+
Benchmark City Information Retrieval
|
36 |
+
|
37 |
+
```bash
|
38 |
+
python3 bench_sglang.py --mode city
|
39 |
+
```
|
40 |
+
|
41 |
+
|
42 |
+
### Benchmark Outlines + vLLM
|
43 |
+
|
44 |
+
Run Llama-7B
|
45 |
+
|
46 |
+
```bash
|
47 |
+
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
48 |
+
```
|
49 |
+
|
50 |
+
Benchmark Character Generation
|
51 |
+
|
52 |
+
```bash
|
53 |
+
python3 bench_other.py --mode character --backend outlines
|
54 |
+
```
|
55 |
+
|
56 |
+
Benchmark City Information Retrieval
|
57 |
+
|
58 |
+
```bash
|
59 |
+
python3 bench_other.py --mode city --backend outlines
|
60 |
+
```
|
61 |
+
|
62 |
+
### Benchmark guidance
|
63 |
+
|
64 |
+
Run Llama-7B and benchmark character generation
|
65 |
+
|
66 |
+
```bash
|
67 |
+
python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
68 |
+
```
|
69 |
+
|
70 |
+
Run Llama-7B and benchmark city information retrieval
|
71 |
+
|
72 |
+
```bash
|
73 |
+
python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
74 |
+
```
|
75 |
+
|
76 |
+
### Benchmark lmql
|
77 |
+
|
78 |
+
Run Llama-7B and benchmark character generation
|
79 |
+
|
80 |
+
```
|
81 |
+
python3 bench_other.py --mode character --backend lmql --parallel 1
|
82 |
+
```
|
83 |
+
|
84 |
+
Run Llama-7B and benchmark city information retrieval
|
85 |
+
|
86 |
+
```
|
87 |
+
python3 bench_other.py --mode city --backend lmql --parallel 1
|
88 |
+
```
|
sglang/benchmark/json_jump_forward/bench_other.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ThreadPoolExecutor
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import guidance
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
11 |
+
from sglang.utils import dump_state_text, read_jsonl
|
12 |
+
|
13 |
+
# there are some FSM bugs with json regex converted from pydantic model
|
14 |
+
# here use a string regex instead
|
15 |
+
# regex_string = build_regex_from_object(HarryPoterRole)
|
16 |
+
character_regex = (
|
17 |
+
r"""\{\n"""
|
18 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
19 |
+
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
20 |
+
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
21 |
+
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
22 |
+
+ r""" "wand": \{\n"""
|
23 |
+
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
24 |
+
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
25 |
+
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
26 |
+
+ r""" \},\n"""
|
27 |
+
+ r""" "alive": "(Alive|Deceased)",\n"""
|
28 |
+
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
29 |
+
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
30 |
+
+ r"""\}"""
|
31 |
+
)
|
32 |
+
|
33 |
+
city_regex = (
|
34 |
+
r"""\{\n"""
|
35 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
36 |
+
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
37 |
+
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
38 |
+
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
39 |
+
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
40 |
+
+ r"""\}"""
|
41 |
+
)
|
42 |
+
|
43 |
+
# fmt: off
|
44 |
+
def character_gen(name, generate):
|
45 |
+
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
46 |
+
s += generate(s, max_tokens=256, regex=character_regex)
|
47 |
+
return s
|
48 |
+
# fmt: on
|
49 |
+
|
50 |
+
# fmt: off
|
51 |
+
def city_gen(document, generate):
|
52 |
+
s = "Please extract the information of a city from the following wikipedia page.\n"
|
53 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
54 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
55 |
+
s += generate(s, max_tokens=256, regex=city_regex)
|
56 |
+
return s
|
57 |
+
# fmt: on
|
58 |
+
|
59 |
+
|
60 |
+
@guidance
|
61 |
+
def character_maker(lm, name):
|
62 |
+
regex_str_no_quote = r"[\w\d\s]+"
|
63 |
+
regex_float = r"[0-9]+\.[0-9]+"
|
64 |
+
lm += f"""\
|
65 |
+
{name} is a character in Harry Potter. Please fill in the following information about this character.
|
66 |
+
{{
|
67 |
+
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
68 |
+
"house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
|
69 |
+
"blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
|
70 |
+
"occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
|
71 |
+
"wand": {{
|
72 |
+
"wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
|
73 |
+
"core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
|
74 |
+
"length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
|
75 |
+
}},
|
76 |
+
"alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
|
77 |
+
"patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
|
78 |
+
"bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
|
79 |
+
}}
|
80 |
+
"""
|
81 |
+
|
82 |
+
return lm
|
83 |
+
|
84 |
+
|
85 |
+
async def call_generate_lmql(
|
86 |
+
prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
|
87 |
+
):
|
88 |
+
assert model is not None
|
89 |
+
import lmql
|
90 |
+
|
91 |
+
@lmql.query(model=model)
|
92 |
+
async def program(question, max_tokens, regex):
|
93 |
+
'''lmql
|
94 |
+
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
|
95 |
+
return ANSWER
|
96 |
+
'''
|
97 |
+
|
98 |
+
return await program(
|
99 |
+
question=prompt,
|
100 |
+
temperature=temperature,
|
101 |
+
max_tokens=max_tokens,
|
102 |
+
max_len=max_len,
|
103 |
+
regex=regex,
|
104 |
+
**kwargs,
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
@guidance
|
109 |
+
def city_maker(lm, document):
|
110 |
+
regex_str_no_quote = r"[\w\d\s]+"
|
111 |
+
regex_float = r"[0-9]+\.[0-9]+"
|
112 |
+
lm += f"""\
|
113 |
+
Please extract the information of a city from the following wikipedia page.
|
114 |
+
Page begin.
|
115 |
+
{document}
|
116 |
+
Page end.
|
117 |
+
Here is the name, country, and symbol of the city in JSON format.
|
118 |
+
{{
|
119 |
+
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
|
120 |
+
"country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
|
121 |
+
"latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
|
122 |
+
"population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
|
123 |
+
"top 3 landmarks": [
|
124 |
+
"{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
|
125 |
+
]
|
126 |
+
}}
|
127 |
+
"""
|
128 |
+
|
129 |
+
return lm
|
130 |
+
|
131 |
+
|
132 |
+
def bench_character(args):
|
133 |
+
arguments = []
|
134 |
+
with open(args.data_path, "r") as f:
|
135 |
+
for line in f:
|
136 |
+
arguments.append({"name": line.strip()})
|
137 |
+
arguments = arguments[: args.num_jsons]
|
138 |
+
|
139 |
+
states = [None] * len(arguments)
|
140 |
+
|
141 |
+
# Select backend
|
142 |
+
if args.backend == "outlines":
|
143 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
144 |
+
|
145 |
+
def get_one_answer(i):
|
146 |
+
states[i] = character_gen(**arguments[i], generate=call_generate)
|
147 |
+
|
148 |
+
elif args.backend == "guidance":
|
149 |
+
model = guidance.models.LlamaCpp(
|
150 |
+
args.model_path,
|
151 |
+
n_gpu_layers=-1,
|
152 |
+
n_ctx=args.n_ctx,
|
153 |
+
)
|
154 |
+
|
155 |
+
def get_one_answer(i):
|
156 |
+
lm = model + character_maker(**arguments[i])
|
157 |
+
states[i] = lm
|
158 |
+
|
159 |
+
elif args.backend == "lmql":
|
160 |
+
import asyncio
|
161 |
+
|
162 |
+
import lmql
|
163 |
+
|
164 |
+
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
165 |
+
call_generate = partial(
|
166 |
+
call_generate_lmql,
|
167 |
+
model=model,
|
168 |
+
max_tokens=256,
|
169 |
+
regex=character_regex,
|
170 |
+
)
|
171 |
+
|
172 |
+
async def get_one_answer_async(i):
|
173 |
+
states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
|
174 |
+
|
175 |
+
else:
|
176 |
+
raise ValueError(f"Invalid backend: {args.backend}")
|
177 |
+
|
178 |
+
tic = time.time()
|
179 |
+
|
180 |
+
if args.backend != "lmql":
|
181 |
+
if args.parallel == 1:
|
182 |
+
for i in tqdm(range(len(arguments))):
|
183 |
+
get_one_answer(i)
|
184 |
+
else:
|
185 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
186 |
+
rets = list(
|
187 |
+
tqdm(
|
188 |
+
executor.map(get_one_answer, list(range(len(arguments)))),
|
189 |
+
total=len(arguments),
|
190 |
+
)
|
191 |
+
)
|
192 |
+
for _ in rets:
|
193 |
+
pass
|
194 |
+
else:
|
195 |
+
batches = []
|
196 |
+
for i in range(0, len(arguments), args.parallel):
|
197 |
+
batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
|
198 |
+
loop = asyncio.get_event_loop()
|
199 |
+
|
200 |
+
for bt in tqdm(batches):
|
201 |
+
loop.run_until_complete(
|
202 |
+
asyncio.gather(*[get_one_answer_async(i) for i in bt])
|
203 |
+
)
|
204 |
+
|
205 |
+
latency = time.time() - tic
|
206 |
+
|
207 |
+
return states, latency
|
208 |
+
|
209 |
+
|
210 |
+
def bench_city_doc(args):
|
211 |
+
arguments = []
|
212 |
+
for line in read_jsonl(args.data_path):
|
213 |
+
arguments.append({"document": line["document"]})
|
214 |
+
arguments = arguments[: args.num_jsons]
|
215 |
+
|
216 |
+
states = [None] * len(arguments)
|
217 |
+
|
218 |
+
# Select backend
|
219 |
+
if args.backend == "outlines":
|
220 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
221 |
+
|
222 |
+
def get_one_answer(i):
|
223 |
+
states[i] = city_gen(**arguments[i], generate=call_generate)
|
224 |
+
|
225 |
+
elif args.backend == "guidance":
|
226 |
+
model = guidance.models.LlamaCpp(
|
227 |
+
args.model_path,
|
228 |
+
n_gpu_layers=-1,
|
229 |
+
n_ctx=args.n_ctx,
|
230 |
+
)
|
231 |
+
|
232 |
+
def get_one_answer(i):
|
233 |
+
lm = model + city_maker(**arguments[i])
|
234 |
+
states[i] = lm
|
235 |
+
|
236 |
+
else:
|
237 |
+
raise ValueError(f"Invalid backend: {args.backend}")
|
238 |
+
|
239 |
+
tic = time.time()
|
240 |
+
if args.parallel == 1:
|
241 |
+
for i in tqdm(range(len(arguments))):
|
242 |
+
get_one_answer(i)
|
243 |
+
else:
|
244 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
245 |
+
rets = executor.map(get_one_answer, list(range(len(arguments))))
|
246 |
+
for _ in rets:
|
247 |
+
pass
|
248 |
+
|
249 |
+
latency = time.time() - tic
|
250 |
+
|
251 |
+
return states, latency
|
252 |
+
|
253 |
+
|
254 |
+
def main(args):
|
255 |
+
if args.mode == "character":
|
256 |
+
args.data_path = "dataset.txt"
|
257 |
+
states, latency = bench_character(args)
|
258 |
+
elif args.mode == "city":
|
259 |
+
args.data_path = "questions.jsonl"
|
260 |
+
states, latency = bench_city_doc(args)
|
261 |
+
|
262 |
+
# Compute accuracy
|
263 |
+
print(f"Latency: {latency:.3f}")
|
264 |
+
|
265 |
+
# Write results
|
266 |
+
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
267 |
+
|
268 |
+
with open(args.result_file, "a") as fout:
|
269 |
+
value = {
|
270 |
+
"task": "json_jump_forward",
|
271 |
+
"backend": args.backend,
|
272 |
+
"latency": round(latency, 3),
|
273 |
+
"num_jsons": args.num_jsons,
|
274 |
+
"mode": args.mode,
|
275 |
+
"parallel": args.parallel,
|
276 |
+
}
|
277 |
+
fout.write(json.dumps(value) + "\n")
|
278 |
+
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
parser = argparse.ArgumentParser()
|
282 |
+
parser.add_argument("--data-path", type=str)
|
283 |
+
parser.add_argument("--num-jsons", type=int, default=50)
|
284 |
+
parser.add_argument(
|
285 |
+
"--mode", type=str, default="character", choices=["character", "city"]
|
286 |
+
)
|
287 |
+
args = add_common_other_args_and_parse(parser)
|
288 |
+
main(args)
|
sglang/benchmark/json_jump_forward/bench_sglang.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang.test.test_utils import (
|
7 |
+
add_common_sglang_args_and_parse,
|
8 |
+
select_sglang_backend,
|
9 |
+
)
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
# there are some FSM bugs with json regex converted from pydantic model
|
13 |
+
# here use a string regex instead
|
14 |
+
# regex_string = build_regex_from_object(HarryPoterRole)
|
15 |
+
character_regex = (
|
16 |
+
r"""\{\n"""
|
17 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
18 |
+
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
19 |
+
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
20 |
+
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
21 |
+
+ r""" "wand": \{\n"""
|
22 |
+
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
23 |
+
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
24 |
+
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
25 |
+
+ r""" \},\n"""
|
26 |
+
+ r""" "alive": "(Alive|Deceased)",\n"""
|
27 |
+
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
28 |
+
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
29 |
+
+ r"""\}"""
|
30 |
+
)
|
31 |
+
|
32 |
+
city_regex = (
|
33 |
+
r"""\{\n"""
|
34 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
35 |
+
+ r""" "country": "[\w\d\s]{1,16}",\n"""
|
36 |
+
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
|
37 |
+
+ r""" "population": [-+]?[0-9]{1,9},\n"""
|
38 |
+
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
|
39 |
+
+ r"""\}"""
|
40 |
+
)
|
41 |
+
|
42 |
+
# fmt: off
|
43 |
+
@sgl.function
|
44 |
+
def character_gen(s, name):
|
45 |
+
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
46 |
+
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
|
47 |
+
# fmt: on
|
48 |
+
|
49 |
+
# fmt: off
|
50 |
+
@sgl.function
|
51 |
+
def city_gen(s, document):
|
52 |
+
s += "Please extract the information of a city from the following wikipedia page.\n"
|
53 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
54 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
55 |
+
s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
|
56 |
+
# fmt: on
|
57 |
+
|
58 |
+
|
59 |
+
def bench_city_doc(args):
|
60 |
+
arguments = []
|
61 |
+
for line in read_jsonl(args.data_path):
|
62 |
+
arguments.append({"document": line["document"]})
|
63 |
+
arguments = arguments[: args.num_jsons]
|
64 |
+
|
65 |
+
# Select backend
|
66 |
+
backend = select_sglang_backend(args)
|
67 |
+
sgl.set_default_backend(backend)
|
68 |
+
|
69 |
+
# Run requests
|
70 |
+
tic = time.time()
|
71 |
+
states = city_gen.run_batch(
|
72 |
+
arguments,
|
73 |
+
temperature=0,
|
74 |
+
num_threads=args.parallel,
|
75 |
+
progress_bar=True,
|
76 |
+
)
|
77 |
+
latency = time.time() - tic
|
78 |
+
|
79 |
+
return states, latency
|
80 |
+
|
81 |
+
|
82 |
+
def bench_character(args):
|
83 |
+
arguments = []
|
84 |
+
with open(args.data_path, "r") as f:
|
85 |
+
for line in f:
|
86 |
+
arguments.append({"name": line.strip()})
|
87 |
+
arguments = arguments[: args.num_jsons]
|
88 |
+
|
89 |
+
# Select backend
|
90 |
+
backend = select_sglang_backend(args)
|
91 |
+
sgl.set_default_backend(backend)
|
92 |
+
|
93 |
+
# Run requests
|
94 |
+
tic = time.time()
|
95 |
+
states = character_gen.run_batch(
|
96 |
+
arguments,
|
97 |
+
temperature=0,
|
98 |
+
num_threads=args.parallel,
|
99 |
+
progress_bar=True,
|
100 |
+
)
|
101 |
+
latency = time.time() - tic
|
102 |
+
|
103 |
+
return states, latency
|
104 |
+
|
105 |
+
|
106 |
+
def main(args):
|
107 |
+
if args.mode == "character":
|
108 |
+
args.data_path = "dataset.txt"
|
109 |
+
states, latency = bench_character(args)
|
110 |
+
elif args.mode == "city":
|
111 |
+
args.data_path = "questions.jsonl"
|
112 |
+
states, latency = bench_city_doc(args)
|
113 |
+
|
114 |
+
# Compute accuracy
|
115 |
+
print(f"Latency: {latency:.3f}")
|
116 |
+
|
117 |
+
# Write results
|
118 |
+
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
|
119 |
+
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
|
120 |
+
for state in states:
|
121 |
+
fout.write(state["json_output"] + "\n")
|
122 |
+
|
123 |
+
with open(args.result_file, "a") as fout:
|
124 |
+
value = {
|
125 |
+
"task": "json_jump_forward",
|
126 |
+
"backend": args.backend,
|
127 |
+
"latency": round(latency, 3),
|
128 |
+
"num_jsons": args.num_jsons,
|
129 |
+
"mode": args.mode,
|
130 |
+
"parallel": args.parallel,
|
131 |
+
}
|
132 |
+
fout.write(json.dumps(value) + "\n")
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == "__main__":
|
136 |
+
parser = argparse.ArgumentParser()
|
137 |
+
parser.add_argument("--data-path", type=str)
|
138 |
+
parser.add_argument("--num-jsons", type=int, default=50)
|
139 |
+
parser.add_argument(
|
140 |
+
"--mode", type=str, default="character", choices=["character", "city"]
|
141 |
+
)
|
142 |
+
args = add_common_sglang_args_and_parse(parser)
|
143 |
+
main(args)
|
sglang/benchmark/json_jump_forward/build_dataset.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import transformers
|
4 |
+
import wikipedia
|
5 |
+
|
6 |
+
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
7 |
+
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
8 |
+
city_names = [
|
9 |
+
"los angles",
|
10 |
+
"london",
|
11 |
+
"tokyo",
|
12 |
+
"beijing",
|
13 |
+
"singapore",
|
14 |
+
"paris",
|
15 |
+
"dubai",
|
16 |
+
"sydney",
|
17 |
+
"moscow",
|
18 |
+
"rome",
|
19 |
+
"toronto",
|
20 |
+
"rio de janeiro",
|
21 |
+
"istanbul",
|
22 |
+
"berlin",
|
23 |
+
"auckland",
|
24 |
+
"buenos aires",
|
25 |
+
"mexico city",
|
26 |
+
"mumbai",
|
27 |
+
"seoul",
|
28 |
+
"bangkok",
|
29 |
+
"cairo",
|
30 |
+
"athens",
|
31 |
+
"jerusalem",
|
32 |
+
]
|
33 |
+
|
34 |
+
|
35 |
+
def get_content(city_name):
|
36 |
+
content = str(wikipedia.page(city_name).content)
|
37 |
+
content = content.replace("\n\n", "\n")
|
38 |
+
|
39 |
+
tokens = t.encode(content)
|
40 |
+
|
41 |
+
expected_tokens = 3000
|
42 |
+
truncate_len = int((expected_tokens / len(tokens)) * len(content))
|
43 |
+
truncate_content = content[:truncate_len]
|
44 |
+
truncate_tokens = t.encode(truncate_content)
|
45 |
+
|
46 |
+
# Count token
|
47 |
+
print(
|
48 |
+
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
49 |
+
)
|
50 |
+
|
51 |
+
return truncate_content
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
with open("questions.jsonl", "w") as fout:
|
56 |
+
for city_name in city_names:
|
57 |
+
truncate_content = get_content(city_name)
|
58 |
+
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
sglang/benchmark/json_jump_forward/dataset.txt
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Harry Potter
|
2 |
+
Hermione Granger
|
3 |
+
Ron Weasley
|
4 |
+
Albus Dumbledore
|
5 |
+
Severus Snape
|
6 |
+
Rubeus Hagrid
|
7 |
+
Draco Malfoy
|
8 |
+
Ginny Weasley
|
9 |
+
Fred Weasley
|
10 |
+
George Weasley
|
11 |
+
Percy Weasley
|
12 |
+
Sirius Black
|
13 |
+
Remus Lupin
|
14 |
+
Neville Longbottom
|
15 |
+
Luna Lovegood
|
16 |
+
Cedric Diggory
|
17 |
+
Cho Chang
|
18 |
+
Lord Voldemort
|
19 |
+
Minerva McGonagall
|
20 |
+
Filius Flitwick
|
21 |
+
Dolores Umbridge
|
22 |
+
Bellatrix Lestrange
|
23 |
+
Lucius Malfoy
|
24 |
+
Molly Weasley
|
25 |
+
Arthur Weasley
|
26 |
+
Nymphadora Tonks
|
27 |
+
Dobby
|
28 |
+
Moaning Myrtle
|
29 |
+
Peter Pettigrew
|
30 |
+
Alastor 'Mad-Eye' Moody
|
31 |
+
Horace Slughorn
|
32 |
+
Vernon Dursley
|
33 |
+
Petunia Dursley
|
34 |
+
Dudley Dursley
|
35 |
+
Argus Filch
|
36 |
+
Sybill Trelawney
|
37 |
+
Gilderoy Lockhart
|
38 |
+
Fleur Delacour
|
39 |
+
Viktor Krum
|
40 |
+
Bill Weasley
|
41 |
+
Oliver Wood
|
42 |
+
Cornelius Fudge
|
43 |
+
Barty Crouch Sr.
|
44 |
+
Barty Crouch Jr.
|
45 |
+
Kingsley Shacklebolt
|
46 |
+
Quirinus Quirrell
|
47 |
+
Nearly Headless Nick
|
48 |
+
Aunt Marge
|
49 |
+
Griphook
|
50 |
+
Ludo Bagman
|
sglang/benchmark/json_schema/README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Benchmark sglang
|
4 |
+
|
5 |
+
Run Llama-8b
|
6 |
+
|
7 |
+
```bash
|
8 |
+
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
|
9 |
+
```
|
10 |
+
|
11 |
+
Benchmark
|
12 |
+
|
13 |
+
```bash
|
14 |
+
python3 bench_sglang.py
|
15 |
+
```
|
sglang/benchmark/json_schema/bench_sglang.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from typing import List, Tuple
|
5 |
+
|
6 |
+
import jsonschema
|
7 |
+
from datasets import load_dataset
|
8 |
+
|
9 |
+
import sglang as sgl
|
10 |
+
from sglang.global_config import global_config
|
11 |
+
from sglang.srt.hf_transformers_utils import get_tokenizer
|
12 |
+
from sglang.test.test_utils import (
|
13 |
+
add_common_sglang_args_and_parse,
|
14 |
+
select_sglang_backend,
|
15 |
+
)
|
16 |
+
from sglang.utils import dump_state_text
|
17 |
+
|
18 |
+
|
19 |
+
@sgl.function
|
20 |
+
def schema_gen(s, message: Tuple[str, str], json_schema: str):
|
21 |
+
system, user = message
|
22 |
+
s += sgl.system(system)
|
23 |
+
s += sgl.user(user)
|
24 |
+
s += sgl.assistant(
|
25 |
+
sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema)
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
def contains_formats(schema, formats: List[str]):
|
30 |
+
if isinstance(schema, dict):
|
31 |
+
if schema.get("format", None) in formats:
|
32 |
+
return True
|
33 |
+
for value in schema.values():
|
34 |
+
if contains_formats(value, formats):
|
35 |
+
return True
|
36 |
+
elif isinstance(schema, list):
|
37 |
+
for item in schema:
|
38 |
+
if contains_formats(item, formats):
|
39 |
+
return True
|
40 |
+
return False
|
41 |
+
|
42 |
+
|
43 |
+
def convert_dataset(path: str):
|
44 |
+
raw_dataset = load_dataset(path)
|
45 |
+
dataset = []
|
46 |
+
for data in raw_dataset["train"]:
|
47 |
+
messages = data["prompt"]
|
48 |
+
schema = data["schema"]
|
49 |
+
obj = json.loads(schema)
|
50 |
+
|
51 |
+
# skip some corrupted examples
|
52 |
+
if obj.get("type", None) is None:
|
53 |
+
continue
|
54 |
+
|
55 |
+
# skip schema with format "email"
|
56 |
+
# which is not supported by outlines for now
|
57 |
+
if contains_formats(obj, ["email"]):
|
58 |
+
continue
|
59 |
+
|
60 |
+
system = messages[0]
|
61 |
+
user = messages[1]
|
62 |
+
assert system["role"] == "system", "invalid role"
|
63 |
+
assert user["role"] == "user", "invalid role"
|
64 |
+
assert len(messages) == 2, "invalid message length"
|
65 |
+
message = json.dumps(system["content"]), json.dumps(user["content"])
|
66 |
+
dataset.append(
|
67 |
+
{
|
68 |
+
"message": message,
|
69 |
+
"json_schema": schema,
|
70 |
+
}
|
71 |
+
)
|
72 |
+
|
73 |
+
return dataset
|
74 |
+
|
75 |
+
|
76 |
+
def bench_schema(args):
|
77 |
+
arguments = convert_dataset(args.data_path)
|
78 |
+
|
79 |
+
if args.num_jsons < 0 or args.num_jsons > len(arguments):
|
80 |
+
args.num_jsons = len(arguments)
|
81 |
+
arguments = arguments[: args.num_jsons]
|
82 |
+
|
83 |
+
# Select backend
|
84 |
+
backend = select_sglang_backend(args)
|
85 |
+
sgl.set_default_backend(backend)
|
86 |
+
|
87 |
+
# Run requests
|
88 |
+
tic = time.time()
|
89 |
+
states = schema_gen.run_batch(
|
90 |
+
arguments,
|
91 |
+
temperature=0,
|
92 |
+
num_threads=args.parallel,
|
93 |
+
progress_bar=True,
|
94 |
+
)
|
95 |
+
latency = time.time() - tic
|
96 |
+
|
97 |
+
# Check if the outputs are valid
|
98 |
+
indexs = []
|
99 |
+
for i, state in enumerate(states):
|
100 |
+
try:
|
101 |
+
schema = json.loads(arguments[i]["json_schema"])
|
102 |
+
obj = json.loads(state["json_output"])
|
103 |
+
assert jsonschema.validate(obj, schema) is None
|
104 |
+
except Exception as e:
|
105 |
+
print(e)
|
106 |
+
indexs.append(i)
|
107 |
+
|
108 |
+
return states, latency
|
109 |
+
|
110 |
+
|
111 |
+
def main(args):
|
112 |
+
states, latency = bench_schema(args)
|
113 |
+
|
114 |
+
# Compute accuracy
|
115 |
+
tokenizer = get_tokenizer(
|
116 |
+
global_config.default_backend.get_server_info()["tokenizer_path"]
|
117 |
+
)
|
118 |
+
output_jsons = [state["json_output"] for state in states]
|
119 |
+
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
|
120 |
+
print(f"Latency: {latency:.3f}")
|
121 |
+
print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")
|
122 |
+
print(f"#output tokens: {num_output_tokens}")
|
123 |
+
|
124 |
+
# Write results
|
125 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
126 |
+
with open(f"{args.backend}.jsonl", "w") as fout:
|
127 |
+
for state in states:
|
128 |
+
fout.write(state["json_output"] + "\n")
|
129 |
+
|
130 |
+
with open(args.result_file, "a") as fout:
|
131 |
+
value = {
|
132 |
+
"task": "json_schema",
|
133 |
+
"backend": args.backend,
|
134 |
+
"latency": round(latency, 3),
|
135 |
+
"num_jsons": args.num_jsons,
|
136 |
+
"parallel": args.parallel,
|
137 |
+
}
|
138 |
+
fout.write(json.dumps(value) + "\n")
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == "__main__":
|
142 |
+
parser = argparse.ArgumentParser()
|
143 |
+
parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval")
|
144 |
+
parser.add_argument("--num-jsons", type=int, default=-1)
|
145 |
+
args = add_common_sglang_args_and_parse(parser)
|
146 |
+
main(args)
|
sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import math
|
3 |
+
|
4 |
+
import cudnn
|
5 |
+
import torch
|
6 |
+
import torch.utils.benchmark as benchmark
|
7 |
+
import triton
|
8 |
+
import triton.language as tl
|
9 |
+
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
10 |
+
|
11 |
+
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
12 |
+
from sglang.srt.utils import should_use_tensor_core
|
13 |
+
|
14 |
+
|
15 |
+
def benchmark_forward(
|
16 |
+
fn,
|
17 |
+
*inputs,
|
18 |
+
repeats=10,
|
19 |
+
amp=False,
|
20 |
+
amp_dtype=torch.float16,
|
21 |
+
**kwinputs,
|
22 |
+
):
|
23 |
+
def amp_wrapper(*inputs, **kwinputs):
|
24 |
+
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
25 |
+
fn(*inputs, **kwinputs)
|
26 |
+
|
27 |
+
t = benchmark.Timer(
|
28 |
+
stmt="fn_amp(*inputs, **kwinputs)",
|
29 |
+
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
|
30 |
+
num_threads=torch.get_num_threads(),
|
31 |
+
)
|
32 |
+
m = t.timeit(repeats)
|
33 |
+
return t, m
|
34 |
+
|
35 |
+
|
36 |
+
def time_fwd(func, *args, **kwargs):
|
37 |
+
time_f = benchmark_forward(func, *args, **kwargs)
|
38 |
+
return time_f[1].mean * 1e6
|
39 |
+
|
40 |
+
|
41 |
+
def decode_attention_sglang(
|
42 |
+
q,
|
43 |
+
kv_data,
|
44 |
+
batch_size,
|
45 |
+
kv_len,
|
46 |
+
head_num_q,
|
47 |
+
head_num_kv,
|
48 |
+
head_dim,
|
49 |
+
num_kv_splits,
|
50 |
+
warmup=10,
|
51 |
+
):
|
52 |
+
|
53 |
+
k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
|
54 |
+
v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
|
55 |
+
o = torch.empty_like(q)
|
56 |
+
total_tokens = batch_size * kv_len
|
57 |
+
req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
|
58 |
+
b_req_idx = torch.arange(0, batch_size).to(0).int()
|
59 |
+
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
|
60 |
+
max_len_in_batch = kv_len
|
61 |
+
sm_scale = 1.0 / (head_dim**0.5)
|
62 |
+
|
63 |
+
attn_logits = torch.empty(
|
64 |
+
(batch_size, head_num_q, num_kv_splits, head_dim + 1),
|
65 |
+
dtype=torch.float32,
|
66 |
+
device="cuda",
|
67 |
+
)
|
68 |
+
|
69 |
+
for _ in range(warmup):
|
70 |
+
decode_attention_fwd(
|
71 |
+
q,
|
72 |
+
k_buffer,
|
73 |
+
v_buffer,
|
74 |
+
o,
|
75 |
+
req_to_token,
|
76 |
+
b_req_idx,
|
77 |
+
b_seq_len,
|
78 |
+
attn_logits,
|
79 |
+
num_kv_splits,
|
80 |
+
sm_scale,
|
81 |
+
)
|
82 |
+
|
83 |
+
f = time_fwd(
|
84 |
+
decode_attention_fwd,
|
85 |
+
q,
|
86 |
+
k_buffer,
|
87 |
+
v_buffer,
|
88 |
+
o,
|
89 |
+
req_to_token,
|
90 |
+
b_req_idx,
|
91 |
+
b_seq_len,
|
92 |
+
attn_logits,
|
93 |
+
num_kv_splits,
|
94 |
+
sm_scale,
|
95 |
+
)
|
96 |
+
|
97 |
+
return f, o
|
98 |
+
|
99 |
+
|
100 |
+
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
|
101 |
+
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
|
102 |
+
use_tensor_cores = should_use_tensor_core(
|
103 |
+
kv_cache_dtype=dtype,
|
104 |
+
num_attention_heads=head_num_q,
|
105 |
+
num_kv_heads=head_num_kv,
|
106 |
+
)
|
107 |
+
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
108 |
+
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
109 |
+
)
|
110 |
+
|
111 |
+
class FlashinferAttention(torch.autograd.Function):
|
112 |
+
@staticmethod
|
113 |
+
def forward(
|
114 |
+
ctx,
|
115 |
+
q,
|
116 |
+
kv_data,
|
117 |
+
batch_size,
|
118 |
+
kv_len,
|
119 |
+
head_num_q,
|
120 |
+
head_num_kv,
|
121 |
+
head_dim,
|
122 |
+
dtype,
|
123 |
+
warmup=10,
|
124 |
+
):
|
125 |
+
total_tokens = batch_size * kv_len
|
126 |
+
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
|
127 |
+
kv_indices = torch.arange(0, total_tokens).to(0).int()
|
128 |
+
kv_last_page_len = torch.full(
|
129 |
+
(batch_size,), 1, dtype=torch.int32, device="cuda"
|
130 |
+
)
|
131 |
+
|
132 |
+
flashinfer_decode_wrapper.end_forward()
|
133 |
+
flashinfer_decode_wrapper.begin_forward(
|
134 |
+
kv_indptr,
|
135 |
+
kv_indices,
|
136 |
+
kv_last_page_len,
|
137 |
+
head_num_q,
|
138 |
+
head_num_kv,
|
139 |
+
head_dim,
|
140 |
+
1,
|
141 |
+
pos_encoding_mode="NONE",
|
142 |
+
data_type=dtype,
|
143 |
+
)
|
144 |
+
|
145 |
+
for _ in range(warmup):
|
146 |
+
o = flashinfer_decode_wrapper.forward(
|
147 |
+
q.contiguous().view(-1, head_num_q, head_dim), kv_data
|
148 |
+
)
|
149 |
+
|
150 |
+
f = time_fwd(
|
151 |
+
flashinfer_decode_wrapper.forward,
|
152 |
+
q.contiguous().view(-1, head_num_q, head_dim),
|
153 |
+
kv_data,
|
154 |
+
)
|
155 |
+
|
156 |
+
return f, o
|
157 |
+
|
158 |
+
return FlashinferAttention
|
159 |
+
|
160 |
+
|
161 |
+
def convert_to_cudnn_type(torch_type):
|
162 |
+
if torch_type == torch.float16:
|
163 |
+
return cudnn.data_type.HALF
|
164 |
+
elif torch_type == torch.bfloat16:
|
165 |
+
return cudnn.data_type.BFLOAT16
|
166 |
+
elif torch_type == torch.float32:
|
167 |
+
return cudnn.data_type.FLOAT
|
168 |
+
elif torch_type == torch.int32:
|
169 |
+
return cudnn.data_type.INT32
|
170 |
+
elif torch_type == torch.int64:
|
171 |
+
return cudnn.data_type.INT64
|
172 |
+
else:
|
173 |
+
raise ValueError("Unsupported tensor data type.")
|
174 |
+
|
175 |
+
|
176 |
+
def decode_attention_cudnn(
|
177 |
+
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
|
178 |
+
):
|
179 |
+
# Prepare data: continuous q,k,v
|
180 |
+
dims_q = (batch_size, head_num_q, 1, head_dim)
|
181 |
+
strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
|
182 |
+
q_gpu = q.as_strided(dims_q, strides_q)
|
183 |
+
o_gpu = (
|
184 |
+
torch.empty(batch_size * head_num_q * head_dim)
|
185 |
+
.half()
|
186 |
+
.cuda()
|
187 |
+
.as_strided(dims_q, strides_q)
|
188 |
+
)
|
189 |
+
|
190 |
+
dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
|
191 |
+
strides_kv = (
|
192 |
+
kv_len * head_num_kv * head_dim,
|
193 |
+
head_dim,
|
194 |
+
head_num_kv * head_dim,
|
195 |
+
1,
|
196 |
+
)
|
197 |
+
k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
|
198 |
+
v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
|
199 |
+
|
200 |
+
seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
|
201 |
+
seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
|
202 |
+
attn_scale = 1.0 / (head_dim**0.5)
|
203 |
+
|
204 |
+
# Prepare data: paged k,v
|
205 |
+
block_size = 1
|
206 |
+
blocks_per_batch = math.ceil(kv_len / block_size)
|
207 |
+
# [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
|
208 |
+
container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
209 |
+
container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
|
210 |
+
page_table_k_gpu = (
|
211 |
+
torch.linspace(
|
212 |
+
0,
|
213 |
+
batch_size * blocks_per_batch - 1,
|
214 |
+
batch_size * blocks_per_batch,
|
215 |
+
device="cuda",
|
216 |
+
dtype=torch.int32,
|
217 |
+
)
|
218 |
+
.reshape(blocks_per_batch, 1, batch_size, 1)
|
219 |
+
.transpose(0, 2)
|
220 |
+
)
|
221 |
+
page_table_v_gpu = page_table_k_gpu.clone()
|
222 |
+
|
223 |
+
graph = cudnn.pygraph(
|
224 |
+
io_data_type=convert_to_cudnn_type(dtype),
|
225 |
+
intermediate_data_type=cudnn.data_type.FLOAT,
|
226 |
+
compute_data_type=cudnn.data_type.FLOAT,
|
227 |
+
)
|
228 |
+
|
229 |
+
q = graph.tensor_like(q_gpu)
|
230 |
+
container_k = graph.tensor_like(container_k_gpu)
|
231 |
+
container_v = graph.tensor_like(container_v_gpu)
|
232 |
+
page_table_k = graph.tensor_like(page_table_k_gpu)
|
233 |
+
page_table_v = graph.tensor_like(page_table_v_gpu)
|
234 |
+
|
235 |
+
seq_len_q = graph.tensor_like(seq_len_q_gpu)
|
236 |
+
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
|
237 |
+
|
238 |
+
o, _ = graph.sdpa(
|
239 |
+
name="sdpa",
|
240 |
+
q=q,
|
241 |
+
k=container_k, # Container K: non contiguous container with K blocks
|
242 |
+
v=container_v, # Container V: non contiguous container with V blocks
|
243 |
+
is_inference=True,
|
244 |
+
attn_scale=attn_scale,
|
245 |
+
use_causal_mask=False,
|
246 |
+
use_padding_mask=True,
|
247 |
+
seq_len_q=seq_len_q,
|
248 |
+
seq_len_kv=seq_len_kv,
|
249 |
+
paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
|
250 |
+
paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
|
251 |
+
paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
|
252 |
+
)
|
253 |
+
|
254 |
+
o.set_output(True).set_dim(dims_q).set_stride(strides_q)
|
255 |
+
|
256 |
+
graph.validate()
|
257 |
+
graph.build_operation_graph()
|
258 |
+
graph.create_execution_plans([cudnn.heur_mode.A])
|
259 |
+
graph.check_support()
|
260 |
+
graph.build_plans()
|
261 |
+
|
262 |
+
workspace = torch.empty(
|
263 |
+
graph.get_workspace_size(), device="cuda", dtype=torch.uint8
|
264 |
+
)
|
265 |
+
|
266 |
+
variant_pack = {
|
267 |
+
q: q_gpu,
|
268 |
+
container_k: container_k_gpu,
|
269 |
+
container_v: container_v_gpu,
|
270 |
+
page_table_k: page_table_k_gpu,
|
271 |
+
page_table_v: page_table_v_gpu,
|
272 |
+
seq_len_q: seq_len_q_gpu,
|
273 |
+
seq_len_kv: seq_len_kv_gpu,
|
274 |
+
o: o_gpu,
|
275 |
+
}
|
276 |
+
|
277 |
+
for _ in range(warmup):
|
278 |
+
graph.execute(variant_pack, workspace)
|
279 |
+
|
280 |
+
f = time_fwd(
|
281 |
+
graph.execute,
|
282 |
+
variant_pack,
|
283 |
+
workspace,
|
284 |
+
)
|
285 |
+
|
286 |
+
return f, o_gpu.squeeze(dim=2)
|
287 |
+
|
288 |
+
|
289 |
+
def calculate_diff():
|
290 |
+
|
291 |
+
dtype = torch.float16
|
292 |
+
batch_size = 64
|
293 |
+
kv_len = 4096
|
294 |
+
head_num_q = 64
|
295 |
+
head_num_kv = 8
|
296 |
+
head_dim = 128
|
297 |
+
|
298 |
+
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
|
299 |
+
kv_data = (
|
300 |
+
torch.randn(
|
301 |
+
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
302 |
+
),
|
303 |
+
torch.randn(
|
304 |
+
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
|
305 |
+
),
|
306 |
+
)
|
307 |
+
|
308 |
+
_, output_sglang = decode_attention_sglang(
|
309 |
+
q,
|
310 |
+
kv_data,
|
311 |
+
batch_size,
|
312 |
+
kv_len,
|
313 |
+
head_num_q,
|
314 |
+
head_num_kv,
|
315 |
+
head_dim,
|
316 |
+
num_kv_splits=8,
|
317 |
+
)
|
318 |
+
|
319 |
+
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
|
320 |
+
_, output_flashinfer = attn_flashinfer(
|
321 |
+
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
322 |
+
)
|
323 |
+
|
324 |
+
_, output_cudnn = decode_attention_cudnn(
|
325 |
+
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
326 |
+
)
|
327 |
+
|
328 |
+
print(f"SGLang output={output_sglang}")
|
329 |
+
print(f"FlashInfer output={output_flashinfer}")
|
330 |
+
print(f"cuDNN output={output_cudnn}")
|
331 |
+
if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
|
332 |
+
print("✅ SGLang[Triton] and FlashInfer match")
|
333 |
+
else:
|
334 |
+
print("❌ SGLang[Triton] and FlashInfer differ")
|
335 |
+
|
336 |
+
if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
|
337 |
+
print("✅ SGLang[Triton] and cuDNN match")
|
338 |
+
else:
|
339 |
+
print("❌ SGLang[Triton] and cuDNN differ")
|
340 |
+
|
341 |
+
|
342 |
+
if __name__ == "__main__":
|
343 |
+
calculate_diff()
|
344 |
+
|
345 |
+
head_dim = 128
|
346 |
+
dtype = torch.float16
|
347 |
+
batch_size_range = [2**i for i in range(0, 8, 2)]
|
348 |
+
kv_len_range = [2**i for i in range(6, 13, 1)]
|
349 |
+
configs = list(itertools.product(batch_size_range, kv_len_range))
|
350 |
+
|
351 |
+
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
|
352 |
+
attn_flashinfer = decode_attention_flashinfer(
|
353 |
+
dtype, head_num_q, head_num_kv
|
354 |
+
).apply
|
355 |
+
for batch_size, kv_len in configs:
|
356 |
+
q = torch.randn(
|
357 |
+
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
|
358 |
+
)
|
359 |
+
kv_data = (
|
360 |
+
torch.randn(
|
361 |
+
batch_size * kv_len,
|
362 |
+
head_num_kv,
|
363 |
+
head_dim,
|
364 |
+
dtype=dtype,
|
365 |
+
device="cuda",
|
366 |
+
),
|
367 |
+
torch.randn(
|
368 |
+
batch_size * kv_len,
|
369 |
+
head_num_kv,
|
370 |
+
head_dim,
|
371 |
+
dtype=dtype,
|
372 |
+
device="cuda",
|
373 |
+
),
|
374 |
+
)
|
375 |
+
us_cudnn, output_cudnn = decode_attention_cudnn(
|
376 |
+
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
377 |
+
)
|
378 |
+
us_sglang, output_sglang = decode_attention_sglang(
|
379 |
+
q,
|
380 |
+
kv_data,
|
381 |
+
batch_size,
|
382 |
+
kv_len,
|
383 |
+
head_num_q,
|
384 |
+
head_num_kv,
|
385 |
+
head_dim,
|
386 |
+
num_kv_splits=8,
|
387 |
+
)
|
388 |
+
us_flashinfer, _ = attn_flashinfer(
|
389 |
+
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
|
390 |
+
)
|
391 |
+
print(
|
392 |
+
head_num_q,
|
393 |
+
" ",
|
394 |
+
head_num_kv,
|
395 |
+
" ",
|
396 |
+
batch_size,
|
397 |
+
" ",
|
398 |
+
kv_len,
|
399 |
+
" ",
|
400 |
+
us_cudnn,
|
401 |
+
" ",
|
402 |
+
us_sglang,
|
403 |
+
" ",
|
404 |
+
us_flashinfer,
|
405 |
+
)
|
sglang/benchmark/kernels/fused_moe_triton/README.md
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Benchmark Kernels
|
2 |
+
|
3 |
+
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
|
4 |
+
|
5 |
+
### Tuning Tool
|
6 |
+
|
7 |
+
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
|
8 |
+
|
9 |
+
Example usage:
|
10 |
+
```bash
|
11 |
+
# Tune Qwen2-57B with FP8 and TP=4
|
12 |
+
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
13 |
+
--model Qwen/Qwen2-57B-A14B-Instruct \
|
14 |
+
--tp-size 4 \
|
15 |
+
--dtype fp8_w8a8 \
|
16 |
+
--tune
|
17 |
+
|
18 |
+
# Tune Mixtral-8x7B with default settings
|
19 |
+
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
20 |
+
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
|
21 |
+
--tune
|
22 |
+
```
|
23 |
+
|
24 |
+
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.
|
25 |
+
|
26 |
+
### Performance Comparison Tool
|
27 |
+
|
28 |
+
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
|
29 |
+
|
30 |
+
Example usage:
|
31 |
+
```bash
|
32 |
+
# Compare with default settings (Mixtral model)
|
33 |
+
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
|
34 |
+
|
35 |
+
# Compare with FP8 mode for Qwen2-57B
|
36 |
+
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
37 |
+
--model Qwen/Qwen2-57B-A14B-Instruct \
|
38 |
+
--use-fp8
|
39 |
+
|
40 |
+
# Compare with custom TP size
|
41 |
+
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
|
42 |
+
--tp-size 4
|
43 |
+
```
|
44 |
+
|
45 |
+
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
|
46 |
+
|
47 |
+
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
|
48 |
+
|
49 |
+
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
|
sglang/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
8 |
+
from torch import nn
|
9 |
+
from vllm import _custom_ops as vllm_ops
|
10 |
+
|
11 |
+
|
12 |
+
class HuggingFaceRMSNorm(nn.Module):
|
13 |
+
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
14 |
+
super().__init__()
|
15 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
16 |
+
self.variance_epsilon = eps
|
17 |
+
|
18 |
+
def forward(
|
19 |
+
self,
|
20 |
+
x: torch.Tensor,
|
21 |
+
residual: Optional[torch.Tensor] = None,
|
22 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
23 |
+
orig_dtype = x.dtype
|
24 |
+
x = x.to(torch.float32)
|
25 |
+
if residual is not None:
|
26 |
+
x = x + residual.to(torch.float32)
|
27 |
+
residual = x.to(orig_dtype)
|
28 |
+
|
29 |
+
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
30 |
+
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
31 |
+
x = x.to(orig_dtype) * self.weight
|
32 |
+
if residual is None:
|
33 |
+
return x
|
34 |
+
else:
|
35 |
+
return x, residual
|
36 |
+
|
37 |
+
|
38 |
+
def rmsnorm_naive(
|
39 |
+
x: torch.Tensor,
|
40 |
+
weight: torch.Tensor,
|
41 |
+
residual: Optional[torch.Tensor] = None,
|
42 |
+
eps: float = 1e-6,
|
43 |
+
):
|
44 |
+
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
45 |
+
naive_norm.weight = nn.Parameter(weight)
|
46 |
+
naive_norm = naive_norm.to(x.device)
|
47 |
+
|
48 |
+
orig_shape = x.shape
|
49 |
+
x = x.view(-1, x.shape[-1])
|
50 |
+
if residual is not None:
|
51 |
+
residual = residual.view(-1, residual.shape[-1])
|
52 |
+
|
53 |
+
output = naive_norm(x, residual)
|
54 |
+
|
55 |
+
if isinstance(output, tuple):
|
56 |
+
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
57 |
+
else:
|
58 |
+
output = output.view(orig_shape)
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
62 |
+
def rmsnorm_flashinfer(
|
63 |
+
x: torch.Tensor,
|
64 |
+
weight: torch.Tensor,
|
65 |
+
residual: Optional[torch.Tensor] = None,
|
66 |
+
eps: float = 1e-6,
|
67 |
+
):
|
68 |
+
orig_shape = x.shape
|
69 |
+
x = x.view(-1, x.shape[-1])
|
70 |
+
if residual is not None:
|
71 |
+
residual = residual.view(-1, residual.shape[-1])
|
72 |
+
|
73 |
+
if residual is not None:
|
74 |
+
fused_add_rmsnorm(x, residual, weight, eps)
|
75 |
+
output = (x, residual)
|
76 |
+
else:
|
77 |
+
output = rmsnorm(x, weight, eps)
|
78 |
+
|
79 |
+
if isinstance(output, tuple):
|
80 |
+
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
81 |
+
else:
|
82 |
+
output = output.view(orig_shape)
|
83 |
+
return output
|
84 |
+
|
85 |
+
|
86 |
+
def rmsnorm_vllm(
|
87 |
+
x: torch.Tensor,
|
88 |
+
weight: torch.Tensor,
|
89 |
+
residual: Optional[torch.Tensor] = None,
|
90 |
+
eps: float = 1e-6,
|
91 |
+
):
|
92 |
+
orig_shape = x.shape
|
93 |
+
x = x.view(-1, x.shape[-1])
|
94 |
+
if residual is not None:
|
95 |
+
residual = residual.view(-1, residual.shape[-1])
|
96 |
+
|
97 |
+
if residual is not None:
|
98 |
+
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
99 |
+
output = (x, residual)
|
100 |
+
else:
|
101 |
+
out = torch.empty_like(x)
|
102 |
+
vllm_ops.rms_norm(out, x, weight, eps)
|
103 |
+
output = out
|
104 |
+
|
105 |
+
if isinstance(output, tuple):
|
106 |
+
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
107 |
+
else:
|
108 |
+
output = output.view(orig_shape)
|
109 |
+
return output
|
110 |
+
|
111 |
+
|
112 |
+
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
113 |
+
dtype = torch.bfloat16
|
114 |
+
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
115 |
+
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
116 |
+
residual = torch.randn_like(x) if use_residual else None
|
117 |
+
|
118 |
+
output_naive = rmsnorm_naive(
|
119 |
+
x.clone(), weight, residual.clone() if residual is not None else None
|
120 |
+
)
|
121 |
+
output_flashinfer = rmsnorm_flashinfer(
|
122 |
+
x.clone(), weight, residual.clone() if residual is not None else None
|
123 |
+
)
|
124 |
+
output_vllm = rmsnorm_vllm(
|
125 |
+
x.clone(), weight, residual.clone() if residual is not None else None
|
126 |
+
)
|
127 |
+
|
128 |
+
if use_residual:
|
129 |
+
output_naive = output_naive[0]
|
130 |
+
output_flashinfer = output_flashinfer[0]
|
131 |
+
output_vllm = output_vllm[0]
|
132 |
+
|
133 |
+
print(f"Naive output={output_naive}")
|
134 |
+
print(f"FlashInfer output={output_flashinfer}")
|
135 |
+
print(f"VLLM output={output_vllm}")
|
136 |
+
|
137 |
+
if torch.allclose(
|
138 |
+
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
139 |
+
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
140 |
+
print("✅ All implementations match")
|
141 |
+
else:
|
142 |
+
print("❌ Implementations differ")
|
143 |
+
|
144 |
+
|
145 |
+
batch_size_range = [2**i for i in range(0, 7, 2)]
|
146 |
+
seq_length_range = [2**i for i in range(6, 11, 1)]
|
147 |
+
head_num_range = [32, 48]
|
148 |
+
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
149 |
+
|
150 |
+
|
151 |
+
def get_benchmark(use_residual):
|
152 |
+
@triton.testing.perf_report(
|
153 |
+
triton.testing.Benchmark(
|
154 |
+
x_names=["head_num", "batch_size", "seq_len"],
|
155 |
+
x_vals=[list(_) for _ in configs],
|
156 |
+
line_arg="provider",
|
157 |
+
line_vals=["huggingface", "flashinfer", "vllm"],
|
158 |
+
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
159 |
+
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
160 |
+
ylabel="us",
|
161 |
+
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
|
162 |
+
args={},
|
163 |
+
)
|
164 |
+
)
|
165 |
+
def benchmark(head_num, batch_size, seq_len, provider):
|
166 |
+
dtype = torch.bfloat16
|
167 |
+
hidden_size = head_num * 128 # assuming head_dim = 128
|
168 |
+
|
169 |
+
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
170 |
+
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
171 |
+
residual = torch.randn_like(x) if use_residual else None
|
172 |
+
|
173 |
+
quantiles = [0.5, 0.2, 0.8]
|
174 |
+
|
175 |
+
if provider == "huggingface":
|
176 |
+
ms, min_ms, max_ms = triton.testing.do_bench(
|
177 |
+
lambda: rmsnorm_naive(
|
178 |
+
x.clone(),
|
179 |
+
weight,
|
180 |
+
residual.clone() if residual is not None else None,
|
181 |
+
),
|
182 |
+
quantiles=quantiles,
|
183 |
+
)
|
184 |
+
elif provider == "flashinfer":
|
185 |
+
ms, min_ms, max_ms = triton.testing.do_bench(
|
186 |
+
lambda: rmsnorm_flashinfer(
|
187 |
+
x.clone(),
|
188 |
+
weight,
|
189 |
+
residual.clone() if residual is not None else None,
|
190 |
+
),
|
191 |
+
quantiles=quantiles,
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
ms, min_ms, max_ms = triton.testing.do_bench(
|
195 |
+
lambda: rmsnorm_vllm(
|
196 |
+
x.clone(),
|
197 |
+
weight,
|
198 |
+
residual.clone() if residual is not None else None,
|
199 |
+
),
|
200 |
+
quantiles=quantiles,
|
201 |
+
)
|
202 |
+
|
203 |
+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
204 |
+
|
205 |
+
return benchmark
|
206 |
+
|
207 |
+
|
208 |
+
if __name__ == "__main__":
|
209 |
+
import argparse
|
210 |
+
|
211 |
+
parser = argparse.ArgumentParser()
|
212 |
+
parser.add_argument(
|
213 |
+
"--use_residual", action="store_true", help="Whether to use residual connection"
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--save_path",
|
217 |
+
type=str,
|
218 |
+
default="./configs/benchmark_ops/rmsnorm/",
|
219 |
+
help="Path to save rmsnorm benchmark results",
|
220 |
+
)
|
221 |
+
args = parser.parse_args()
|
222 |
+
|
223 |
+
# Run correctness test
|
224 |
+
calculate_diff(
|
225 |
+
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
|
226 |
+
)
|
227 |
+
|
228 |
+
# Get the benchmark function with proper use_residual setting
|
229 |
+
benchmark = get_benchmark(args.use_residual)
|
230 |
+
# Run performance benchmark
|
231 |
+
benchmark.run(print_data=True, save_path=args.save_path)
|
sglang/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pytest
|
7 |
+
import torch
|
8 |
+
import triton
|
9 |
+
import triton.language as tl
|
10 |
+
|
11 |
+
|
12 |
+
@triton.jit
|
13 |
+
def write_req_to_token_pool_triton(
|
14 |
+
req_to_token_ptr, # [max_batch, max_context_len]
|
15 |
+
req_pool_indices,
|
16 |
+
pre_lens,
|
17 |
+
seq_lens,
|
18 |
+
extend_lens,
|
19 |
+
out_cache_loc,
|
20 |
+
req_to_token_ptr_stride: tl.constexpr,
|
21 |
+
):
|
22 |
+
BLOCK_SIZE: tl.constexpr = 512
|
23 |
+
pid = tl.program_id(0)
|
24 |
+
|
25 |
+
req_pool_index = tl.load(req_pool_indices + pid)
|
26 |
+
pre_len = tl.load(pre_lens + pid)
|
27 |
+
seq_len = tl.load(seq_lens + pid)
|
28 |
+
|
29 |
+
# TODO: optimize this?
|
30 |
+
cumsum_start = 0
|
31 |
+
for i in range(pid):
|
32 |
+
cumsum_start += tl.load(extend_lens + i)
|
33 |
+
|
34 |
+
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
|
35 |
+
for i in range(num_loop):
|
36 |
+
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
37 |
+
mask = offset < (seq_len - pre_len)
|
38 |
+
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
|
39 |
+
tl.store(
|
40 |
+
req_to_token_ptr
|
41 |
+
+ req_pool_index * req_to_token_ptr_stride
|
42 |
+
+ offset
|
43 |
+
+ pre_len,
|
44 |
+
value,
|
45 |
+
mask=mask,
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
@triton.jit
|
50 |
+
def write_req_to_token_pool_triton_optimize(
|
51 |
+
req_to_token_ptr, # [max_batch, max_context_len]
|
52 |
+
req_pool_indices,
|
53 |
+
pre_lens,
|
54 |
+
seq_lens,
|
55 |
+
extend_lens,
|
56 |
+
out_cache_loc,
|
57 |
+
req_to_token_ptr_stride: tl.constexpr,
|
58 |
+
BLOCK_SIZE: tl.constexpr,
|
59 |
+
):
|
60 |
+
pid_batch = tl.program_id(0)
|
61 |
+
pid_token = tl.program_id(1)
|
62 |
+
|
63 |
+
req_pool_index = tl.load(req_pool_indices + pid_batch)
|
64 |
+
pre_len = tl.load(pre_lens + pid_batch)
|
65 |
+
seq_len = tl.load(seq_lens + pid_batch)
|
66 |
+
extend_len = seq_len - pre_len
|
67 |
+
|
68 |
+
cumsum_start = 0
|
69 |
+
for i in range(pid_batch):
|
70 |
+
cumsum_start += tl.load(extend_lens + i)
|
71 |
+
|
72 |
+
token_start = pid_token * BLOCK_SIZE
|
73 |
+
|
74 |
+
offset = tl.arange(0, BLOCK_SIZE)
|
75 |
+
actual_offset = token_start + offset
|
76 |
+
mask = actual_offset < extend_len
|
77 |
+
|
78 |
+
src_ptr = out_cache_loc + cumsum_start + actual_offset
|
79 |
+
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
80 |
+
value = tl.load(src_ptr, mask=mask)
|
81 |
+
dst_ptr = (
|
82 |
+
req_to_token_ptr
|
83 |
+
+ req_pool_index * req_to_token_ptr_stride
|
84 |
+
+ actual_offset
|
85 |
+
+ pre_len
|
86 |
+
)
|
87 |
+
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
88 |
+
|
89 |
+
tl.store(dst_ptr, value, mask=mask)
|
90 |
+
|
91 |
+
|
92 |
+
def write_req_to_token_pool_reference(
|
93 |
+
req_to_token: torch.Tensor,
|
94 |
+
req_pool_indices: torch.Tensor,
|
95 |
+
pre_lens: torch.Tensor,
|
96 |
+
seq_lens: torch.Tensor,
|
97 |
+
extend_lens: torch.Tensor,
|
98 |
+
out_cache_loc: torch.Tensor,
|
99 |
+
) -> None:
|
100 |
+
"""Reference implementation using PyTorch"""
|
101 |
+
for i in range(len(req_pool_indices)):
|
102 |
+
req_pool_idx = req_pool_indices[i].item()
|
103 |
+
pre_len = pre_lens[i].item()
|
104 |
+
seq_len = seq_lens[i].item()
|
105 |
+
extend_len = extend_lens[i].item()
|
106 |
+
|
107 |
+
cumsum_start = sum(extend_lens[:i].tolist())
|
108 |
+
|
109 |
+
# Copy values from out_cache_loc to req_to_token
|
110 |
+
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
|
111 |
+
cumsum_start : cumsum_start + extend_len
|
112 |
+
]
|
113 |
+
|
114 |
+
|
115 |
+
def test_write_req_to_token_pool():
|
116 |
+
max_batch = 4097
|
117 |
+
max_context_len = 6148
|
118 |
+
batch_size = 1
|
119 |
+
extend_len = 14
|
120 |
+
|
121 |
+
# Initialize input tensors
|
122 |
+
req_to_token = torch.zeros(
|
123 |
+
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
124 |
+
)
|
125 |
+
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
|
126 |
+
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
|
127 |
+
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
|
128 |
+
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
|
129 |
+
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
|
130 |
+
|
131 |
+
# Create copies for reference implementation
|
132 |
+
req_to_token_ref = req_to_token.clone()
|
133 |
+
req_to_token_opt = req_to_token.clone()
|
134 |
+
|
135 |
+
# Run original triton kernel
|
136 |
+
write_req_to_token_pool_triton[(batch_size,)](
|
137 |
+
req_to_token,
|
138 |
+
req_pool_indices,
|
139 |
+
pre_lens,
|
140 |
+
seq_lens,
|
141 |
+
extend_lens,
|
142 |
+
out_cache_loc,
|
143 |
+
max_context_len,
|
144 |
+
)
|
145 |
+
|
146 |
+
# Run optimized triton kernel
|
147 |
+
def grid(batch_size, extend_len):
|
148 |
+
num_token_blocks = triton.cdiv(extend_len, 512)
|
149 |
+
return (batch_size, num_token_blocks)
|
150 |
+
|
151 |
+
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
|
152 |
+
req_to_token_opt,
|
153 |
+
req_pool_indices,
|
154 |
+
pre_lens,
|
155 |
+
seq_lens,
|
156 |
+
extend_lens,
|
157 |
+
out_cache_loc,
|
158 |
+
max_context_len,
|
159 |
+
BLOCK_SIZE=512,
|
160 |
+
)
|
161 |
+
|
162 |
+
# Run reference implementation
|
163 |
+
write_req_to_token_pool_reference(
|
164 |
+
req_to_token_ref,
|
165 |
+
req_pool_indices,
|
166 |
+
pre_lens,
|
167 |
+
seq_lens,
|
168 |
+
extend_lens,
|
169 |
+
out_cache_loc,
|
170 |
+
)
|
171 |
+
|
172 |
+
# Compare results
|
173 |
+
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
174 |
+
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
175 |
+
|
176 |
+
# Test case 2: batch size > 1
|
177 |
+
batch_size = 3
|
178 |
+
extend_lens_list = [14, 20, 30]
|
179 |
+
total_extend_len = sum(extend_lens_list)
|
180 |
+
|
181 |
+
req_to_token = torch.zeros(
|
182 |
+
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
183 |
+
)
|
184 |
+
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
|
185 |
+
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
|
186 |
+
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
|
187 |
+
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
188 |
+
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
189 |
+
|
190 |
+
req_to_token_ref = req_to_token.clone()
|
191 |
+
req_to_token_opt = req_to_token.clone()
|
192 |
+
|
193 |
+
# Run original triton kernel
|
194 |
+
write_req_to_token_pool_triton[(batch_size,)](
|
195 |
+
req_to_token,
|
196 |
+
req_pool_indices,
|
197 |
+
pre_lens,
|
198 |
+
seq_lens,
|
199 |
+
extend_lens,
|
200 |
+
out_cache_loc,
|
201 |
+
max_context_len,
|
202 |
+
)
|
203 |
+
|
204 |
+
# Run optimized triton kernel
|
205 |
+
max_extend_len = max(extend_lens_list)
|
206 |
+
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
|
207 |
+
req_to_token_opt,
|
208 |
+
req_pool_indices,
|
209 |
+
pre_lens,
|
210 |
+
seq_lens,
|
211 |
+
extend_lens,
|
212 |
+
out_cache_loc,
|
213 |
+
max_context_len,
|
214 |
+
BLOCK_SIZE=512,
|
215 |
+
)
|
216 |
+
|
217 |
+
# Run reference implementation
|
218 |
+
write_req_to_token_pool_reference(
|
219 |
+
req_to_token_ref,
|
220 |
+
req_pool_indices,
|
221 |
+
pre_lens,
|
222 |
+
seq_lens,
|
223 |
+
extend_lens,
|
224 |
+
out_cache_loc,
|
225 |
+
)
|
226 |
+
|
227 |
+
# Compare results
|
228 |
+
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
229 |
+
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
230 |
+
|
231 |
+
|
232 |
+
def get_benchmark():
|
233 |
+
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
|
234 |
+
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
235 |
+
configs = list(itertools.product(batch_sizes, extend_lens))
|
236 |
+
|
237 |
+
@triton.testing.perf_report(
|
238 |
+
triton.testing.Benchmark(
|
239 |
+
x_names=["batch_size", "extend_len"],
|
240 |
+
x_vals=configs,
|
241 |
+
line_arg="provider",
|
242 |
+
line_vals=["reference", "triton", "triton_optimize"],
|
243 |
+
line_names=["PyTorch", "Triton", "Triton Optimized"],
|
244 |
+
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
245 |
+
ylabel="us",
|
246 |
+
plot_name="write-req-to-token-pool-performance",
|
247 |
+
args={},
|
248 |
+
)
|
249 |
+
)
|
250 |
+
def benchmark(batch_size, extend_len, provider):
|
251 |
+
max_batch = 256
|
252 |
+
max_context_len = 16384
|
253 |
+
|
254 |
+
extend_lens_list = [extend_len] * batch_size
|
255 |
+
total_extend_len = sum(extend_lens_list)
|
256 |
+
|
257 |
+
req_to_token = torch.zeros(
|
258 |
+
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
259 |
+
)
|
260 |
+
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
|
261 |
+
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
|
262 |
+
seq_lens = pre_lens + extend_len
|
263 |
+
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
264 |
+
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
265 |
+
|
266 |
+
quantiles = [0.5, 0.2, 0.8]
|
267 |
+
|
268 |
+
if provider == "reference":
|
269 |
+
ms, min_ms, max_ms = triton.testing.do_bench(
|
270 |
+
lambda: write_req_to_token_pool_reference(
|
271 |
+
req_to_token.clone(),
|
272 |
+
req_pool_indices,
|
273 |
+
pre_lens,
|
274 |
+
seq_lens,
|
275 |
+
extend_lens,
|
276 |
+
out_cache_loc,
|
277 |
+
),
|
278 |
+
quantiles=quantiles,
|
279 |
+
)
|
280 |
+
elif provider == "triton":
|
281 |
+
ms, min_ms, max_ms = triton.testing.do_bench(
|
282 |
+
lambda: write_req_to_token_pool_triton[(batch_size,)](
|
283 |
+
req_to_token.clone(),
|
284 |
+
req_pool_indices,
|
285 |
+
pre_lens,
|
286 |
+
seq_lens,
|
287 |
+
extend_lens,
|
288 |
+
out_cache_loc,
|
289 |
+
max_context_len,
|
290 |
+
),
|
291 |
+
quantiles=quantiles,
|
292 |
+
)
|
293 |
+
else:
|
294 |
+
|
295 |
+
def run_optimized():
|
296 |
+
block_size = 128 if extend_len <= 1024 else 512
|
297 |
+
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
|
298 |
+
write_req_to_token_pool_triton_optimize[grid_config](
|
299 |
+
req_to_token.clone(),
|
300 |
+
req_pool_indices,
|
301 |
+
pre_lens,
|
302 |
+
seq_lens,
|
303 |
+
extend_lens,
|
304 |
+
out_cache_loc,
|
305 |
+
max_context_len,
|
306 |
+
BLOCK_SIZE=block_size,
|
307 |
+
)
|
308 |
+
|
309 |
+
ms, min_ms, max_ms = triton.testing.do_bench(
|
310 |
+
run_optimized, quantiles=quantiles
|
311 |
+
)
|
312 |
+
|
313 |
+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
314 |
+
|
315 |
+
return benchmark
|
316 |
+
|
317 |
+
|
318 |
+
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
|
319 |
+
"""Run benchmark and save results"""
|
320 |
+
|
321 |
+
# Ensure save path exists
|
322 |
+
os.makedirs(save_path, exist_ok=True)
|
323 |
+
|
324 |
+
# Run correctness test
|
325 |
+
test_write_req_to_token_pool()
|
326 |
+
print("Correctness test passed!")
|
327 |
+
|
328 |
+
# Run performance test
|
329 |
+
benchmark = get_benchmark()
|
330 |
+
benchmark.run(print_data=True, save_path=save_path)
|
331 |
+
|
332 |
+
|
333 |
+
if __name__ == "__main__":
|
334 |
+
import argparse
|
335 |
+
|
336 |
+
parser = argparse.ArgumentParser()
|
337 |
+
parser.add_argument(
|
338 |
+
"--save_path",
|
339 |
+
type=str,
|
340 |
+
default="./configs/benchmark_ops/write_req_to_token_pool/",
|
341 |
+
help="Path to save benchmark results",
|
342 |
+
)
|
343 |
+
args = parser.parse_args()
|
344 |
+
|
345 |
+
run_benchmark(args.save_path)
|
sglang/benchmark/line_retrieval/README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download data
|
2 |
+
|
3 |
+
```
|
4 |
+
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
|
5 |
+
python3 gen_data.py --number 1000
|
6 |
+
```
|
7 |
+
|
8 |
+
## Run benchmark
|
9 |
+
|
10 |
+
### Benchmark sglang
|
11 |
+
```
|
12 |
+
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
|
13 |
+
```
|
14 |
+
|
15 |
+
```
|
16 |
+
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
|
17 |
+
```
|
18 |
+
|
19 |
+
|
20 |
+
###
|
21 |
+
|
22 |
+
```
|
23 |
+
# original
|
24 |
+
Accuracy: 0.940, latency: 332.83 s
|
25 |
+
|
26 |
+
# parallel encoding (no_adjust, offset = 1000)
|
27 |
+
Accuracy: 0.760, latency: 238.46 s
|
28 |
+
|
29 |
+
# parallel encoding (no_adjust, offset = 3000)
|
30 |
+
Accuracy: 0.760, latency: 238.46 s
|
31 |
+
|
32 |
+
# parallel encoding (no_adjust, offset = 0)
|
33 |
+
Accuracy: 0.520, latency: 238.46 s
|
34 |
+
|
35 |
+
# parallel encoding (adjust_cache)
|
36 |
+
Accuracy: 0.460, latency: 257.66 s
|
37 |
+
```
|
sglang/benchmark/line_retrieval/bench_sglang.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import sglang as sgl
|
9 |
+
from sglang.test.test_utils import (
|
10 |
+
add_common_sglang_args_and_parse,
|
11 |
+
select_sglang_backend,
|
12 |
+
)
|
13 |
+
from sglang.utils import dump_state_text
|
14 |
+
|
15 |
+
|
16 |
+
@sgl.function
|
17 |
+
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
|
18 |
+
s += prefix + "\n"
|
19 |
+
|
20 |
+
contexts = [body_0, body_1, body_2, body_3]
|
21 |
+
position_ids_offset = [i * 1000 for i in range(len(contexts))]
|
22 |
+
forks = s.fork(len(contexts), position_ids_offset)
|
23 |
+
forks += lambda i: contexts[i] + "\n"
|
24 |
+
forks.join(mode="concate_and_append")
|
25 |
+
|
26 |
+
s += "\n" + suffix
|
27 |
+
s += sgl.gen("answer", max_tokens=16)
|
28 |
+
|
29 |
+
|
30 |
+
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
31 |
+
arguments = []
|
32 |
+
labels = []
|
33 |
+
sum_src_indices = []
|
34 |
+
sum_dst_indices = []
|
35 |
+
|
36 |
+
for i in range(len(src_indices)):
|
37 |
+
for j in range(len(dst_percents)):
|
38 |
+
src_index = src_indices[i]
|
39 |
+
dst_percent = dst_percents[j]
|
40 |
+
|
41 |
+
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
|
42 |
+
query_indices = [
|
43 |
+
q
|
44 |
+
for q in query_indices
|
45 |
+
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
|
46 |
+
]
|
47 |
+
dst_index = query_indices[
|
48 |
+
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
|
49 |
+
]
|
50 |
+
label = line_obj["values"][dst_index]
|
51 |
+
|
52 |
+
body = line_obj["lines"][: src_index + 1]
|
53 |
+
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
|
54 |
+
body_part_len = len(body) // 4
|
55 |
+
|
56 |
+
arguments.append(
|
57 |
+
{
|
58 |
+
"prefix": line_obj["prefix"],
|
59 |
+
"body_0": "\n".join(body[:body_part_len]),
|
60 |
+
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
|
61 |
+
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
|
62 |
+
"body_3": "\n".join(body[3 * body_part_len :]),
|
63 |
+
"suffix": suffix,
|
64 |
+
}
|
65 |
+
)
|
66 |
+
labels.append(label)
|
67 |
+
sum_src_indices.append(src_index)
|
68 |
+
sum_dst_indices.append(dst_index)
|
69 |
+
|
70 |
+
# Select backend
|
71 |
+
backend = select_sglang_backend(args)
|
72 |
+
|
73 |
+
tic = time.time()
|
74 |
+
states = line_retrieval.run_batch(
|
75 |
+
arguments,
|
76 |
+
temperature=0,
|
77 |
+
backend=backend,
|
78 |
+
num_threads=args.parallel,
|
79 |
+
progress_bar=True,
|
80 |
+
)
|
81 |
+
latency = time.time() - tic
|
82 |
+
|
83 |
+
corrects = []
|
84 |
+
for i in range(len(arguments)):
|
85 |
+
output = states[i]["answer"]
|
86 |
+
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
|
87 |
+
label = labels[i]
|
88 |
+
|
89 |
+
# Try all numbers
|
90 |
+
findall = re.findall("\d+", output)
|
91 |
+
if not findall:
|
92 |
+
response_number = output
|
93 |
+
else:
|
94 |
+
for response_number in findall:
|
95 |
+
if response_number == label:
|
96 |
+
break
|
97 |
+
|
98 |
+
correct = response_number == label
|
99 |
+
corrects.append(correct)
|
100 |
+
|
101 |
+
# Log results
|
102 |
+
summary = (
|
103 |
+
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
|
104 |
+
f"Prompt len: {prompt_len}, "
|
105 |
+
f"Correct: {correct}, "
|
106 |
+
f"Label: {label}, Predicted: {response_number}, "
|
107 |
+
)
|
108 |
+
print(summary)
|
109 |
+
|
110 |
+
accuracy = np.mean(corrects)
|
111 |
+
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
|
112 |
+
|
113 |
+
# Write results
|
114 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
115 |
+
|
116 |
+
with open(args.result_file, "a") as fout:
|
117 |
+
value = {
|
118 |
+
"task": "line_retrieval",
|
119 |
+
"backend": args.backend,
|
120 |
+
"num_gpus": 1,
|
121 |
+
"latency": round(latency, 3),
|
122 |
+
"num_requests": len(arguments),
|
123 |
+
"other": {
|
124 |
+
"num_questions": len(arguments),
|
125 |
+
"parallel": args.parallel,
|
126 |
+
},
|
127 |
+
}
|
128 |
+
fout.write(json.dumps(value) + "\n")
|
129 |
+
|
130 |
+
|
131 |
+
def main(args):
|
132 |
+
line_obj = json.load(open(args.data_path, "r"))
|
133 |
+
|
134 |
+
num_hoops = args.num_hoops
|
135 |
+
for src_index in args.src_index:
|
136 |
+
src_indices = [src_index]
|
137 |
+
num_queries = args.num_queries_per_src
|
138 |
+
dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
|
139 |
+
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
|
140 |
+
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
parser = argparse.ArgumentParser()
|
144 |
+
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
|
145 |
+
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
|
146 |
+
parser.add_argument("--num-queries-per-src", type=int, default=10)
|
147 |
+
parser.add_argument("--num-hoops", type=int, default=1)
|
148 |
+
args = add_common_sglang_args_and_parse(parser)
|
149 |
+
main(args)
|
sglang/benchmark/line_retrieval/gen_data.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Generate line data for line retrieval task.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 gen_data.py --number 1000
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import json
|
10 |
+
from collections import defaultdict
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def generate_lines(random_words, num_lines, redirect_ratio):
|
17 |
+
prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
|
18 |
+
suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
|
19 |
+
|
20 |
+
# Raw lines
|
21 |
+
visited_indices = set([None])
|
22 |
+
visited_values = set([None])
|
23 |
+
|
24 |
+
lines = []
|
25 |
+
redirects = []
|
26 |
+
indices = []
|
27 |
+
values = []
|
28 |
+
for i in tqdm(range(num_lines)):
|
29 |
+
line_index = None
|
30 |
+
while line_index in visited_indices:
|
31 |
+
line_index = "-".join(np.random.choice(random_words, size=(2,)))
|
32 |
+
visited_indices.add(line_index)
|
33 |
+
|
34 |
+
line_value = np.random.randint(low=0, high=999999)
|
35 |
+
line_value = f"{line_value:06}"
|
36 |
+
|
37 |
+
line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
|
38 |
+
lines.append(line)
|
39 |
+
redirects.append(None)
|
40 |
+
indices.append(line_index)
|
41 |
+
values.append(line_value)
|
42 |
+
|
43 |
+
# Add redirect
|
44 |
+
if redirect_ratio > 0:
|
45 |
+
num_redirect_lines = int(len(lines) * redirect_ratio)
|
46 |
+
redirect_indices = np.random.choice(
|
47 |
+
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
|
48 |
+
)
|
49 |
+
for i in redirect_indices:
|
50 |
+
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
51 |
+
lines[i] = (
|
52 |
+
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
53 |
+
)
|
54 |
+
redirects[i] = target_idx
|
55 |
+
|
56 |
+
# Build links and find sources
|
57 |
+
links = [[] for _ in range(num_lines)]
|
58 |
+
contains_ring = set()
|
59 |
+
for i in range(num_lines):
|
60 |
+
if redirects[i] is None:
|
61 |
+
continue
|
62 |
+
|
63 |
+
tmp_link = []
|
64 |
+
cur = i
|
65 |
+
visited = set()
|
66 |
+
while redirects[cur] is not None:
|
67 |
+
visited.add(cur)
|
68 |
+
tmp_link.append(redirects[cur])
|
69 |
+
cur = redirects[cur]
|
70 |
+
|
71 |
+
if cur in visited:
|
72 |
+
contains_ring.add(i)
|
73 |
+
tmp_link = None
|
74 |
+
break
|
75 |
+
values[i] = values[cur]
|
76 |
+
links[i] = tmp_link
|
77 |
+
|
78 |
+
# Group by num_links
|
79 |
+
group_by_num_hoops = defaultdict(list)
|
80 |
+
for i in range(num_lines):
|
81 |
+
if i in contains_ring:
|
82 |
+
continue
|
83 |
+
group_by_num_hoops[len(links[i]) + 1].append(i)
|
84 |
+
|
85 |
+
keys = sorted(list(group_by_num_hoops.keys()))
|
86 |
+
for num_links in keys:
|
87 |
+
print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
|
88 |
+
|
89 |
+
# Append few-shot examples
|
90 |
+
hoop1_candidates = list(group_by_num_hoops[1])
|
91 |
+
hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
|
92 |
+
hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
|
93 |
+
hoop2_candidates = list(group_by_num_hoops[2])
|
94 |
+
hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
|
95 |
+
hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
|
96 |
+
|
97 |
+
i = hoop1_candidates[5]
|
98 |
+
suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
|
99 |
+
if len(hoop2_candidates):
|
100 |
+
i = hoop2_candidates[0]
|
101 |
+
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
102 |
+
i = hoop2_candidates[1]
|
103 |
+
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
104 |
+
else:
|
105 |
+
i = hoop1_candidates[1]
|
106 |
+
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
107 |
+
i = hoop1_candidates[10]
|
108 |
+
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
109 |
+
|
110 |
+
obj = {
|
111 |
+
"prefix": prefix,
|
112 |
+
"suffix": suffix,
|
113 |
+
"lines": lines,
|
114 |
+
"indices": indices,
|
115 |
+
"values": values,
|
116 |
+
"links": links,
|
117 |
+
"group_by_num_hoops": group_by_num_hoops,
|
118 |
+
"contains_ring": sorted(list(contains_ring)),
|
119 |
+
}
|
120 |
+
return obj
|
121 |
+
|
122 |
+
|
123 |
+
if __name__ == "__main__":
|
124 |
+
parser = argparse.ArgumentParser()
|
125 |
+
parser.add_argument("--number", type=int)
|
126 |
+
parser.add_argument("--redirect-ratio", type=float, default=0.0)
|
127 |
+
args = parser.parse_args()
|
128 |
+
|
129 |
+
num_lines = args.number
|
130 |
+
|
131 |
+
random_words_filename = "random_words.json"
|
132 |
+
random_words = json.load(open(random_words_filename, "r"))
|
133 |
+
|
134 |
+
np.random.seed(42)
|
135 |
+
obj = generate_lines(random_words, num_lines, args.redirect_ratio)
|
136 |
+
|
137 |
+
fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
|
138 |
+
with open(fout, "w") as fout:
|
139 |
+
json.dump(obj, fout, indent=2)
|
sglang/benchmark/llava_bench/README.md
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download benchmark images
|
2 |
+
|
3 |
+
```
|
4 |
+
python3 download_images.py
|
5 |
+
```
|
6 |
+
|
7 |
+
image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild
|
8 |
+
|
9 |
+
### Other Dependency
|
10 |
+
```
|
11 |
+
pip3 install "sglang[all]"
|
12 |
+
pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
|
13 |
+
```
|
14 |
+
|
15 |
+
## Run benchmark
|
16 |
+
|
17 |
+
### Benchmark sglang
|
18 |
+
Launch a server
|
19 |
+
```
|
20 |
+
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
21 |
+
```
|
22 |
+
|
23 |
+
Run benchmark
|
24 |
+
```
|
25 |
+
# Run with local models
|
26 |
+
python3 bench_sglang.py --num-questions 60
|
27 |
+
|
28 |
+
# Run with OpenAI models
|
29 |
+
python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
|
30 |
+
```
|
31 |
+
|
32 |
+
### Bench LLaVA original code
|
33 |
+
```
|
34 |
+
git clone [email protected]:haotian-liu/LLaVA.git
|
35 |
+
cd LLaVA
|
36 |
+
git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
|
37 |
+
pip3 install -e .
|
38 |
+
|
39 |
+
cd ~/sglang/benchmark/llava_bench
|
40 |
+
CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
|
41 |
+
```
|
42 |
+
|
43 |
+
|
44 |
+
### Benchmark llama.cpp
|
45 |
+
|
46 |
+
```
|
47 |
+
# Install
|
48 |
+
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
|
49 |
+
pip install sse_starlette starlette_context pydantic_settings
|
50 |
+
|
51 |
+
# Download weights
|
52 |
+
mkdir -p ~/model_weights/llava-v1.5-7b/
|
53 |
+
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
|
54 |
+
wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
|
55 |
+
```
|
56 |
+
|
57 |
+
```
|
58 |
+
python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000
|
59 |
+
|
60 |
+
OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
|
61 |
+
```
|
sglang/benchmark/llava_bench/bench_hf_llava_bench.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
python -m llava.eval.model_vqa \
|
4 |
+
--model-path liuhaotian/llava-v1.5-7b \
|
5 |
+
--question-file ./questions.jsonl \
|
6 |
+
--image-folder ./images \
|
7 |
+
--answers-file ./answers_hf.jsonl \
|
8 |
+
--temperature 0 \
|
9 |
+
--conv-mode vicuna_v1
|
sglang/benchmark/llava_bench/bench_hf_mme.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
python -m llava.eval.model_vqa_loader \
|
4 |
+
--model-path liuhaotian/llava-v1.5-7b \
|
5 |
+
--question-file ./mme_pack/llava_mme_bench_replace.jsonl \
|
6 |
+
--image-folder ./mme_pack/MME_Benchmark_release_version \
|
7 |
+
--answers-file ./answers_hf_mme.jsonl \
|
8 |
+
--temperature 0 \
|
9 |
+
--conv-mode vicuna_v1
|
sglang/benchmark/llava_bench/bench_sglang.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
|
6 |
+
import tqdm
|
7 |
+
|
8 |
+
import sglang as sgl
|
9 |
+
from sglang.test.test_utils import (
|
10 |
+
add_common_sglang_args_and_parse,
|
11 |
+
select_sglang_backend,
|
12 |
+
)
|
13 |
+
from sglang.utils import dump_state_text, read_jsonl
|
14 |
+
|
15 |
+
|
16 |
+
@sgl.function
|
17 |
+
def image_qa(s, image_file, question):
|
18 |
+
s += sgl.user(sgl.image(image_file) + question)
|
19 |
+
s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))
|
20 |
+
|
21 |
+
|
22 |
+
def main(args):
|
23 |
+
lines = list(read_jsonl(args.question_file))[: args.num_questions]
|
24 |
+
arguments = [
|
25 |
+
{
|
26 |
+
"image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
|
27 |
+
"question": l["text"],
|
28 |
+
}
|
29 |
+
for l in lines
|
30 |
+
]
|
31 |
+
# arguments = [
|
32 |
+
# {"image_file":
|
33 |
+
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
|
34 |
+
# "question": l["text"]} for l in lines
|
35 |
+
# ]
|
36 |
+
|
37 |
+
states = [None] * len(lines)
|
38 |
+
|
39 |
+
# Select backend
|
40 |
+
backend = select_sglang_backend(args)
|
41 |
+
sgl.set_default_backend(backend)
|
42 |
+
|
43 |
+
# Run requests
|
44 |
+
tic = time.time()
|
45 |
+
if args.parallel == 1:
|
46 |
+
for i in tqdm.tqdm(range(len(lines))):
|
47 |
+
image_file = arguments[i]["image_file"]
|
48 |
+
question = arguments[i]["question"]
|
49 |
+
ret = image_qa.run(image_file=image_file, question=question, temperature=0)
|
50 |
+
states[i] = ret
|
51 |
+
else:
|
52 |
+
states = image_qa.run_batch(
|
53 |
+
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
54 |
+
)
|
55 |
+
latency = time.time() - tic
|
56 |
+
|
57 |
+
print(f"Latency: {latency:.3f}")
|
58 |
+
|
59 |
+
# Write results
|
60 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
61 |
+
|
62 |
+
print(f"Write output to {args.answer_file}")
|
63 |
+
with open(args.answer_file, "w") as fout:
|
64 |
+
for i in range(len(lines)):
|
65 |
+
value = {
|
66 |
+
"question_id": lines[i]["question_id"],
|
67 |
+
"prompt": lines[i]["text"],
|
68 |
+
"text": states[i]["answer"].strip(),
|
69 |
+
"model_id": backend.model_info["model_path"],
|
70 |
+
"answer_id": i,
|
71 |
+
"metadata": {},
|
72 |
+
}
|
73 |
+
fout.write(json.dumps(value) + "\n")
|
74 |
+
|
75 |
+
with open(args.result_file, "a") as fout:
|
76 |
+
value = {
|
77 |
+
"task": "llava_bench",
|
78 |
+
"backend": args.backend,
|
79 |
+
"num_gpus": 1,
|
80 |
+
"latency": round(latency, 3),
|
81 |
+
"num_requests": len(lines),
|
82 |
+
"parallel": args.parallel,
|
83 |
+
}
|
84 |
+
fout.write(json.dumps(value) + "\n")
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
parser = argparse.ArgumentParser()
|
89 |
+
parser.add_argument("--question-file", type=str, default="questions.jsonl")
|
90 |
+
parser.add_argument("--answer-file", type=str, default="answers.jsonl")
|
91 |
+
parser.add_argument("--image-folder", type=str, default="./images")
|
92 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
93 |
+
parser.add_argument("--num-questions", type=int, default=None)
|
94 |
+
parser.add_argument("--max-tokens", type=int, default=768)
|
95 |
+
args = add_common_sglang_args_and_parse(parser)
|
96 |
+
main(args)
|
sglang/benchmark/llava_bench/bench_sglang_mme.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
MME_FOLDER=./mme_pack
|
2 |
+
python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4
|
sglang/benchmark/llava_bench/download_images.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Create the 'images' directory if it doesn't exist
|
4 |
+
if not os.path.exists("images"):
|
5 |
+
os.makedirs("images")
|
6 |
+
|
7 |
+
# Base URL
|
8 |
+
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
|
9 |
+
|
10 |
+
# Loop through image numbers
|
11 |
+
for i in range(1, 25):
|
12 |
+
# Format the image number with leading zeros
|
13 |
+
image_number = str(i).zfill(3)
|
14 |
+
image_url = base_url + image_number + ".jpg"
|
15 |
+
image_path = "images/" + image_number + ".jpg"
|
16 |
+
|
17 |
+
# Download the image using wget
|
18 |
+
os.system(f"wget -O {image_path} {image_url}")
|
19 |
+
|
20 |
+
print("Download complete.")
|
sglang/benchmark/llava_bench/questions.jsonl
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"image": "001.jpg", "text": "What is the name of this famous sight in the photo?", "category": "conv", "question_id": 0}
|
2 |
+
{"image": "001.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 1}
|
3 |
+
{"image": "001.jpg", "text": "What are the possible reasons of the formation of this sight?", "category": "complex", "question_id": 2}
|
4 |
+
{"image": "001.jpg", "text": "Compose an engaging travel blog post about a recent trip to this place, highlighting cultural experiences and must-see attractions, including both the attraction seen in the photo and other must-see attractions as well.", "category": "complex", "question_id": 3}
|
5 |
+
{"image": "002.jpg", "text": "What type of fruit is this?", "category": "conv", "question_id": 4}
|
6 |
+
{"image": "002.jpg", "text": "How many uncut fruits are in the image?", "category": "conv", "question_id": 5}
|
7 |
+
{"image": "002.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 6}
|
8 |
+
{"image": "002.jpg", "text": "Imagine the fragrance of the fruits in the image. How would you describe this to someone who has never had this fruit before?", "category": "complex", "question_id": 7}
|
9 |
+
{"image": "003.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 8}
|
10 |
+
{"image": "003.jpg", "text": "What might be the intended effect of this painting?", "category": "complex", "question_id": 9}
|
11 |
+
{"image": "003.jpg", "text": "Discuss how this creative twist on a classic work of art might be interpreted differently by various audiences.", "category": "complex", "question_id": 10}
|
12 |
+
{"image": "004.jpg", "text": "What is the name of the man in the photo?", "category": "conv", "question_id": 11}
|
13 |
+
{"image": "004.jpg", "text": "Which iconic movie scene is being parodied in the meme?", "category": "conv", "question_id": 12}
|
14 |
+
{"image": "004.jpg", "text": "How does this meme reflect or comment on Elon Musk's public image, personality, or actions?", "category": "complex", "question_id": 13}
|
15 |
+
{"image": "005.jpg", "text": "Please explain the meme in detail.", "category": "detail", "question_id": 14}
|
16 |
+
{"image": "005.jpg", "text": "In what other ways might someone express the same sentiment that this meme is expressing?", "category": "complex", "question_id": 15}
|
17 |
+
{"image": "006.jpg", "text": "Do you know who paint this?", "category": "conv", "question_id": 16}
|
18 |
+
{"image": "006.jpg", "text": "Describe this painting in detail.", "category": "detail", "question_id": 17}
|
19 |
+
{"image": "006.jpg", "text": "Discuss the historical impact and the significance of this painting in the art world.", "category": "complex", "question_id": 18}
|
20 |
+
{"image": "007.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 19}
|
21 |
+
{"image": "007.jpg", "text": "What's the best weather, season, time of the day of visiting this place? Is the time when this photo was taken a good time to visit this place?", "category": "complex", "question_id": 20}
|
22 |
+
{"image": "008.jpg", "text": "What is the name of the character in the image?", "category": "conv", "question_id": 21}
|
23 |
+
{"image": "008.jpg", "text": "What's the personality of this character? Explain what elements or aspects of the character's design may have contributed to its popularity.", "category": "complex", "question_id": 22}
|
24 |
+
{"image": "009.jpg", "text": "What are the things I should be cautious about when I visit here?", "category": "complex", "question_id": 23}
|
25 |
+
{"image": "009.jpg", "text": "If you were a photographer looking to capture this location's essence, what time of day and weather conditions would you choose? Describe the reasons behind your choice.", "category": "complex", "question_id": 24}
|
26 |
+
{"image": "010.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 25}
|
27 |
+
{"image": "010.jpg", "text": "What is unusual about this image?", "category": "complex", "question_id": 26}
|
28 |
+
{"image": "011.jpg", "text": "What fruit is in the left part of the fridge?", "category": "conv", "question_id": 27}
|
29 |
+
{"image": "011.jpg", "text": "What is the brand of the yogurt flavored with blueberry?", "category": "conv", "question_id": 28}
|
30 |
+
{"image": "011.jpg", "text": "Is there any strawberry-flavored yogurt in the fridge?", "category": "conv", "question_id": 29}
|
31 |
+
{"image": "011.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 30}
|
32 |
+
{"image": "011.jpg", "text": "What are the meals that I can cook with these?", "category": "complex", "question_id": 31}
|
33 |
+
{"image": "012.jpg", "text": "How many coffee mugs are in the set?", "category": "conv", "question_id": 32}
|
34 |
+
{"image": "012.jpg", "text": "Write an attractive product description for this.", "category": "complex", "question_id": 33}
|
35 |
+
{"image": "013.jpg", "text": "Show the detailed recipe for this dish.", "category": "complex", "question_id": 34}
|
36 |
+
{"image": "014.jpg", "text": "Can you explain this meme in detail?", "category": "complex", "question_id": 35}
|
37 |
+
{"image": "015.jpg", "text": "What are the two machine learning concepts mentioned in the meme?", "category": "conv", "question_id": 36}
|
38 |
+
{"image": "015.jpg", "text": "Give a detailed description of this meme.", "category": "detail", "question_id": 37}
|
39 |
+
{"image": "015.jpg", "text": "Can you explain why this is funny. Think about it step-by-step.", "category": "complex", "question_id": 38}
|
40 |
+
{"image": "016.jpg", "text": "Give a detailed description of this image. Describe it panel by panel.", "category": "detail", "question_id": 39}
|
41 |
+
{"image": "016.jpg", "text": "What is funny about this image? Describe it panel by panel.", "category": "complex", "question_id": 40}
|
42 |
+
{"image": "017.jpg", "text": "What material appears to make up the creature?", "category": "conv", "question_id": 41}
|
43 |
+
{"image": "017.jpg", "text": "This is the logo of LLaVA, Large Language and Vision Assistant, based on the LLaMA architecture. Please explain this logo in detail, and how do you think of its design.", "category": "complex", "question_id": 42}
|
44 |
+
{"image": "018.jpg", "text": "What are the animals in the painting and what are they doing?", "category": "conv", "question_id": 43}
|
45 |
+
{"image": "018.jpg", "text": "Write a fairy tale based on this painting.", "category": "complex", "question_id": 44}
|
46 |
+
{"image": "019.jpg", "text": "Describe this sketch in detail.", "category": "detail", "question_id": 45}
|
47 |
+
{"image": "019.jpg", "text": "Write brief HTML/JS to turn this mock-up into a colorful website, where the jokes are replaced by two real jokes.", "category": "complex", "question_id": 46}
|
48 |
+
{"image": "020.jpg", "text": "Describe this sketch in detail.", "category": "detail", "question_id": 47}
|
49 |
+
{"image": "020.jpg", "text": "Write brief HTML/JS to turn this mock-up into a colorful and interactive website, where the joke is replaced by a real joke.", "category": "complex", "question_id": 48}
|
50 |
+
{"image": "021.jpg", "text": "What's the ending of this movie?", "category": "conv", "question_id": 49}
|
51 |
+
{"image": "021.jpg", "text": "What is the significance of this scene in the context of the movie?", "category": "complex", "question_id": 50}
|
52 |
+
{"image": "022.jpg", "text": "What's the name of the restaurant serving these dishes?", "category": "conv", "question_id": 51}
|
53 |
+
{"image": "022.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 52}
|
54 |
+
{"image": "022.jpg", "text": "If someone were to recommend a new flavor or topping to the dish, describe the reason for this change and how it might alter the overall taste.", "category": "complex", "question_id": 53}
|
55 |
+
{"image": "023.jpg", "text": "What brand is featured in this advertisement?", "category": "conv", "question_id": 54}
|
56 |
+
{"image": "023.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 55}
|
57 |
+
{"image": "023.jpg", "text": "Show me a detailed recipe for cooking this at home.", "category": "complex", "question_id": 56}
|
58 |
+
{"image": "024.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 57}
|
59 |
+
{"image": "024.jpg", "text": "What is the problem this city might be facing? What are some possible solutions?", "category": "complex", "question_id": 58}
|
60 |
+
{"image": "024.jpg", "text": "Explain all the cues that indicate the current traffic conditions.", "category": "complex", "question_id": 59}
|
sglang/benchmark/llm_judge/README.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Benchmark sglang
|
4 |
+
```
|
5 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
6 |
+
```
|
7 |
+
|
8 |
+
```
|
9 |
+
python3 bench_sglang.py --num-questions 25 --parallel 8
|
10 |
+
python3 bench_sglang.py --num-questions 16 --parallel 1
|
11 |
+
```
|
12 |
+
|
13 |
+
|
14 |
+
### Benchmark vllm
|
15 |
+
```
|
16 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
17 |
+
```
|
18 |
+
|
19 |
+
```
|
20 |
+
python3 bench_other.py --backend vllm --num-questions 25
|
21 |
+
```
|
22 |
+
|
23 |
+
|
24 |
+
### Benchmark guidance
|
25 |
+
```
|
26 |
+
python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
27 |
+
```
|
28 |
+
|
29 |
+
### Benchmark lmql
|
30 |
+
|
31 |
+
```
|
32 |
+
python3 bench_other.py --backend lmql --num-questions 25 --parallel 1
|
33 |
+
```
|
sglang/benchmark/llm_judge/articles.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sglang/benchmark/llm_judge/bench_other.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ThreadPoolExecutor
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
13 |
+
|
14 |
+
dimension_prompts = [
|
15 |
+
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
16 |
+
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
17 |
+
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
18 |
+
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
19 |
+
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
20 |
+
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
def multi_dimension_judge(article, generate):
|
25 |
+
s = system_prompt
|
26 |
+
s += "\n```\n" + article + "\n```\n\n"
|
27 |
+
|
28 |
+
judges = []
|
29 |
+
for i in range(len(dimension_prompts)):
|
30 |
+
comp = generate(
|
31 |
+
s
|
32 |
+
+ "USER: Please judge the quality based on the following metric. "
|
33 |
+
+ dimension_prompts[i]
|
34 |
+
+ " Please provide a single-paragraph judgement. "
|
35 |
+
+ "Focus on the provided metric and do not say other things. "
|
36 |
+
'End your judgement paragraph with the word "END"\nJUDGE:',
|
37 |
+
max_tokens=256,
|
38 |
+
stop="END",
|
39 |
+
)
|
40 |
+
judges.append(comp)
|
41 |
+
|
42 |
+
s += "I will judge the quality based on the following metrics.\n"
|
43 |
+
for i in range(len(dimension_prompts)):
|
44 |
+
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
|
45 |
+
|
46 |
+
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
47 |
+
s += generate(s, max_tokens=2, stop=None)
|
48 |
+
|
49 |
+
return s
|
50 |
+
|
51 |
+
|
52 |
+
async def multi_dimension_judge_async(article, generate):
|
53 |
+
s = system_prompt
|
54 |
+
s += "\n```\n" + article + "\n```\n\n"
|
55 |
+
|
56 |
+
judges = []
|
57 |
+
for i in range(len(dimension_prompts)):
|
58 |
+
comp = await generate(
|
59 |
+
s
|
60 |
+
+ "USER: Please judge the quality based on the following metric. "
|
61 |
+
+ dimension_prompts[i]
|
62 |
+
+ " Please provide a single-paragraph judgement. "
|
63 |
+
+ "Focus on the provided metric and do not say other things. "
|
64 |
+
'End your judgement paragraph with the word "END"\nJUDGE:',
|
65 |
+
max_tokens=256,
|
66 |
+
stop="END",
|
67 |
+
)
|
68 |
+
judges.append(comp)
|
69 |
+
|
70 |
+
s += "I will judge the quality based on the following metrics.\n"
|
71 |
+
for i in range(len(dimension_prompts)):
|
72 |
+
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
|
73 |
+
|
74 |
+
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
75 |
+
s += await generate(s, max_tokens=2, stop=None)
|
76 |
+
|
77 |
+
return s
|
78 |
+
|
79 |
+
|
80 |
+
def main(args):
|
81 |
+
lines = read_jsonl(args.data_path)[: args.num_questions]
|
82 |
+
states = [None] * len(lines)
|
83 |
+
|
84 |
+
# Select backend
|
85 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
86 |
+
|
87 |
+
# Run requests
|
88 |
+
tic = time.time()
|
89 |
+
|
90 |
+
if args.backend != "lmql":
|
91 |
+
|
92 |
+
def get_one_answer(i):
|
93 |
+
states[i] = multi_dimension_judge(lines[i], call_generate)
|
94 |
+
|
95 |
+
if args.parallel == 1:
|
96 |
+
for i in tqdm(range(len(lines))):
|
97 |
+
get_one_answer(i)
|
98 |
+
else:
|
99 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
100 |
+
list(
|
101 |
+
tqdm(
|
102 |
+
executor.map(get_one_answer, list(range(len(lines)))),
|
103 |
+
total=len(lines),
|
104 |
+
)
|
105 |
+
)
|
106 |
+
|
107 |
+
else:
|
108 |
+
import asyncio
|
109 |
+
|
110 |
+
async def get_one_answer_async(i):
|
111 |
+
states[i] = await multi_dimension_judge_async(lines[i], call_generate)
|
112 |
+
|
113 |
+
batches = []
|
114 |
+
for i in range(0, len(lines), args.parallel):
|
115 |
+
batches.append(list(range(i, min(i + args.parallel, len(lines)))))
|
116 |
+
|
117 |
+
loop = asyncio.get_event_loop()
|
118 |
+
for bt in tqdm(batches):
|
119 |
+
loop.run_until_complete(
|
120 |
+
asyncio.gather(*[get_one_answer_async(i) for i in bt])
|
121 |
+
)
|
122 |
+
|
123 |
+
latency = time.time() - tic
|
124 |
+
|
125 |
+
# Compute accuracy
|
126 |
+
print(f"Latency: {latency:.3f}")
|
127 |
+
|
128 |
+
# Write results
|
129 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
130 |
+
|
131 |
+
with open(args.result_file, "a") as fout:
|
132 |
+
value = {
|
133 |
+
"task": "llm_judge",
|
134 |
+
"backend": args.backend,
|
135 |
+
"num_gpus": 1,
|
136 |
+
"latency": round(latency, 3),
|
137 |
+
"num_requests": args.num_questions,
|
138 |
+
"other": {
|
139 |
+
"num_questions": args.num_questions,
|
140 |
+
"parallel": args.parallel,
|
141 |
+
},
|
142 |
+
}
|
143 |
+
fout.write(json.dumps(value) + "\n")
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
parser = argparse.ArgumentParser()
|
148 |
+
parser.add_argument("--data-path", type=str, default="articles.jsonl")
|
149 |
+
parser.add_argument("--num-questions", type=int, default=20)
|
150 |
+
args = add_common_other_args_and_parse(parser)
|
151 |
+
main(args)
|
sglang/benchmark/llm_judge/bench_sglang.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang.test.test_utils import (
|
7 |
+
add_common_sglang_args_and_parse,
|
8 |
+
select_sglang_backend,
|
9 |
+
)
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
13 |
+
|
14 |
+
dimension_prompts = [
|
15 |
+
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
16 |
+
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
17 |
+
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
18 |
+
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
19 |
+
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
20 |
+
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
@sgl.function
|
25 |
+
def multi_dimension_judge(s, article):
|
26 |
+
s += system_prompt
|
27 |
+
s += "\n```\n" + article + "\n```\n\n"
|
28 |
+
|
29 |
+
forks = s.fork(len(dimension_prompts))
|
30 |
+
for i in range(len(dimension_prompts)):
|
31 |
+
forks[i] += (
|
32 |
+
"USER: Please judge the quality based on the following metric. "
|
33 |
+
+ dimension_prompts[i]
|
34 |
+
+ " Please provide a single-paragraph judgement. "
|
35 |
+
+ "Focus on the provided metric and do not say other things. "
|
36 |
+
'End your judgement paragraph with the word "END"\nJUDGE:'
|
37 |
+
)
|
38 |
+
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
|
39 |
+
forks.join()
|
40 |
+
|
41 |
+
s += "I will judge the quality based on the following metrics.\n"
|
42 |
+
for i in range(len(dimension_prompts)):
|
43 |
+
s += (
|
44 |
+
dimension_prompts[i].split(":")[0]
|
45 |
+
+ ": "
|
46 |
+
+ forks[i]["judgement"].strip()
|
47 |
+
+ "\n"
|
48 |
+
)
|
49 |
+
|
50 |
+
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
51 |
+
s += sgl.gen("score", max_tokens=2)
|
52 |
+
|
53 |
+
|
54 |
+
def main(args):
|
55 |
+
lines = read_jsonl(args.data_path)[: args.num_questions]
|
56 |
+
arguments = [{"article": l} for l in lines]
|
57 |
+
|
58 |
+
# Select backend
|
59 |
+
backend = select_sglang_backend(args)
|
60 |
+
|
61 |
+
# Run requests
|
62 |
+
tic = time.time()
|
63 |
+
states = multi_dimension_judge.run_batch(
|
64 |
+
arguments,
|
65 |
+
temperature=0,
|
66 |
+
backend=backend,
|
67 |
+
num_threads=args.parallel,
|
68 |
+
progress_bar=True,
|
69 |
+
)
|
70 |
+
latency = time.time() - tic
|
71 |
+
|
72 |
+
print(f"Latency: {latency:.3f}")
|
73 |
+
|
74 |
+
# Write results
|
75 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
76 |
+
|
77 |
+
with open(args.result_file, "a") as fout:
|
78 |
+
value = {
|
79 |
+
"task": "llm_judge",
|
80 |
+
"backend": args.backend,
|
81 |
+
"num_gpus": 1,
|
82 |
+
"latency": round(latency, 3),
|
83 |
+
"num_requests": args.num_questions,
|
84 |
+
"other": {
|
85 |
+
"num_questions": args.num_questions,
|
86 |
+
"parallel": args.parallel,
|
87 |
+
},
|
88 |
+
}
|
89 |
+
fout.write(json.dumps(value) + "\n")
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
parser = argparse.ArgumentParser()
|
94 |
+
parser.add_argument("--data-path", type=str, default="articles.jsonl")
|
95 |
+
parser.add_argument("--num-questions", type=int, default=20)
|
96 |
+
args = add_common_sglang_args_and_parse(parser)
|
97 |
+
main(args)
|
sglang/benchmark/long_json_decode/README.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Benchmark sglang
|
4 |
+
```
|
5 |
+
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
|
6 |
+
```
|
7 |
+
|
8 |
+
```
|
9 |
+
python3 bench_sglang.py --num-questions 5 --parallel 1
|
10 |
+
```
|
11 |
+
|
12 |
+
|
13 |
+
### Benchmark vllm
|
14 |
+
```
|
15 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
|
16 |
+
```
|
17 |
+
|
18 |
+
```
|
19 |
+
python3 bench_other.py --backend vllm --num-questions 5
|
20 |
+
```
|
21 |
+
|
22 |
+
|
23 |
+
### Benchmark guidance
|
24 |
+
```
|
25 |
+
python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
### Build dataset
|
30 |
+
```
|
31 |
+
pip install wikipedia
|
32 |
+
python3 build_dataset.py
|
33 |
+
```
|
sglang/benchmark/long_json_decode/bench_other.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ThreadPoolExecutor
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
|
13 |
+
def json_decode(document, generate):
|
14 |
+
s = "Please extract the information of a city from the following wikipedia page.\n"
|
15 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
16 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
17 |
+
s += "{\n"
|
18 |
+
s += ' "name": "'
|
19 |
+
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
20 |
+
s += ' "country": "'
|
21 |
+
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
22 |
+
s += ' "air port code": "'
|
23 |
+
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
24 |
+
s += ' "top 3 landmarks": "'
|
25 |
+
s += generate(s, max_tokens=24, stop='"') + '",\n'
|
26 |
+
s += "}\n"
|
27 |
+
return s
|
28 |
+
|
29 |
+
|
30 |
+
def main(args):
|
31 |
+
lines = read_jsonl(args.data_path)
|
32 |
+
arguments = []
|
33 |
+
for i in range(len(lines[: args.num_questions])):
|
34 |
+
arguments.append(
|
35 |
+
{
|
36 |
+
"document": lines[i]["document"],
|
37 |
+
}
|
38 |
+
)
|
39 |
+
states = [None] * len(arguments)
|
40 |
+
|
41 |
+
# Select backend
|
42 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
43 |
+
|
44 |
+
# Run requests
|
45 |
+
def get_one_answer(i):
|
46 |
+
states[i] = json_decode(generate=call_generate, **arguments[i])
|
47 |
+
|
48 |
+
tic = time.time()
|
49 |
+
if args.parallel == 1:
|
50 |
+
for i in tqdm(range(len(arguments))):
|
51 |
+
get_one_answer(i)
|
52 |
+
else:
|
53 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
54 |
+
list(
|
55 |
+
tqdm(
|
56 |
+
executor.map(get_one_answer, list(range(len(arguments)))),
|
57 |
+
total=len(arguments),
|
58 |
+
)
|
59 |
+
)
|
60 |
+
|
61 |
+
latency = time.time() - tic
|
62 |
+
|
63 |
+
# Compute accuracy
|
64 |
+
print(f"Latency: {latency:.3f}")
|
65 |
+
|
66 |
+
# Write results
|
67 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
68 |
+
|
69 |
+
with open(args.result_file, "a") as fout:
|
70 |
+
value = {
|
71 |
+
"task": "long_json_decode",
|
72 |
+
"backend": args.backend,
|
73 |
+
"num_gpus": 1,
|
74 |
+
"latency": round(latency, 3),
|
75 |
+
"num_requests": args.num_questions,
|
76 |
+
"other": {
|
77 |
+
"num_questions": args.num_questions,
|
78 |
+
"parallel": args.parallel,
|
79 |
+
},
|
80 |
+
}
|
81 |
+
fout.write(json.dumps(value) + "\n")
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
parser = argparse.ArgumentParser()
|
86 |
+
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
87 |
+
parser.add_argument("--num-questions", type=int, default=100)
|
88 |
+
args = add_common_other_args_and_parse(parser)
|
89 |
+
main(args)
|
sglang/benchmark/long_json_decode/bench_sglang.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang.test.test_utils import (
|
7 |
+
add_common_sglang_args_and_parse,
|
8 |
+
select_sglang_backend,
|
9 |
+
)
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
|
13 |
+
@sgl.function
|
14 |
+
def json_decode(s, document):
|
15 |
+
s += "Please extract the information of a city from the following wikipedia page.\n"
|
16 |
+
s += "Page begin.\n" + document + "Page end.\n"
|
17 |
+
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
18 |
+
s += "{\n"
|
19 |
+
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
|
20 |
+
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
|
21 |
+
s += (
|
22 |
+
' "air port code": "'
|
23 |
+
+ sgl.gen("air port code", max_tokens=8, stop='"')
|
24 |
+
+ '",\n'
|
25 |
+
)
|
26 |
+
s += (
|
27 |
+
' "top 3 landmarks": "'
|
28 |
+
+ sgl.gen("landmarks", max_tokens=24, stop='"')
|
29 |
+
+ '",\n'
|
30 |
+
)
|
31 |
+
s += "}\n"
|
32 |
+
|
33 |
+
|
34 |
+
def main(args):
|
35 |
+
lines = read_jsonl(args.data_path)
|
36 |
+
arguments = []
|
37 |
+
for i in range(len(lines[: args.num_questions])):
|
38 |
+
arguments.append(
|
39 |
+
{
|
40 |
+
"document": lines[i]["document"],
|
41 |
+
}
|
42 |
+
)
|
43 |
+
|
44 |
+
# Select backend
|
45 |
+
backend = select_sglang_backend(args)
|
46 |
+
sgl.set_default_backend(backend)
|
47 |
+
|
48 |
+
# Run requests
|
49 |
+
tic = time.time()
|
50 |
+
states = json_decode.run_batch(
|
51 |
+
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
52 |
+
)
|
53 |
+
latency = time.time() - tic
|
54 |
+
|
55 |
+
# Compute accuracy
|
56 |
+
print(f"Latency: {latency:.3f}")
|
57 |
+
|
58 |
+
# Write results
|
59 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
60 |
+
|
61 |
+
with open(args.result_file, "a") as fout:
|
62 |
+
value = {
|
63 |
+
"task": "long_json_decode",
|
64 |
+
"backend": args.backend,
|
65 |
+
"num_gpus": 1,
|
66 |
+
"latency": round(latency, 3),
|
67 |
+
"num_requests": args.num_questions,
|
68 |
+
"other": {
|
69 |
+
"num_questions": args.num_questions,
|
70 |
+
"parallel": args.parallel,
|
71 |
+
},
|
72 |
+
}
|
73 |
+
fout.write(json.dumps(value) + "\n")
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
parser = argparse.ArgumentParser()
|
78 |
+
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
79 |
+
parser.add_argument("--num-questions", type=int, default=10)
|
80 |
+
args = add_common_sglang_args_and_parse(parser)
|
81 |
+
main(args)
|
sglang/benchmark/long_json_decode/build_dataset.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import transformers
|
4 |
+
import wikipedia
|
5 |
+
|
6 |
+
name = "meta-llama/Llama-2-7b-chat-hf"
|
7 |
+
t = transformers.AutoTokenizer.from_pretrained(name)
|
8 |
+
city_names = ["los angles", "london", "tokyo", "beijing", "singapore"]
|
9 |
+
|
10 |
+
|
11 |
+
for city_name in city_names:
|
12 |
+
content = str(wikipedia.page(city_name).content)
|
13 |
+
content = content.replace("\n\n", "\n")
|
14 |
+
|
15 |
+
tokens = t.encode(content)
|
16 |
+
|
17 |
+
truncate_len = int((10000 / len(tokens)) * len(content))
|
18 |
+
truncate_content = content[:truncate_len]
|
19 |
+
truncate_tokens = t.encode(truncate_content)
|
20 |
+
|
21 |
+
# Count token
|
22 |
+
print(
|
23 |
+
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
24 |
+
)
|
25 |
+
|
26 |
+
with open("questions.jsonl", "a") as fout:
|
27 |
+
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
sglang/benchmark/mmlu/bench_other.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
from concurrent.futures import ThreadPoolExecutor
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import tiktoken
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
14 |
+
|
15 |
+
choices = ["A", "B", "C", "D"]
|
16 |
+
|
17 |
+
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
18 |
+
|
19 |
+
|
20 |
+
def format_subject(subject):
|
21 |
+
l = subject.split("_")
|
22 |
+
s = ""
|
23 |
+
for entry in l:
|
24 |
+
s += " " + entry
|
25 |
+
return s
|
26 |
+
|
27 |
+
|
28 |
+
def format_example(df, idx, include_answer=True):
|
29 |
+
prompt = df.iloc[idx, 0]
|
30 |
+
k = df.shape[1] - 2
|
31 |
+
for j in range(k):
|
32 |
+
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
|
33 |
+
prompt += "\nAnswer:"
|
34 |
+
if include_answer:
|
35 |
+
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
36 |
+
return prompt
|
37 |
+
|
38 |
+
|
39 |
+
def gen_prompt(train_df, subject, k=-1):
|
40 |
+
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
|
41 |
+
format_subject(subject)
|
42 |
+
)
|
43 |
+
if k == -1:
|
44 |
+
k = train_df.shape[0]
|
45 |
+
for i in range(k):
|
46 |
+
prompt += format_example(train_df, i)
|
47 |
+
return prompt
|
48 |
+
|
49 |
+
|
50 |
+
def evaluate(args, subject, dev_df, test_df, call_generate):
|
51 |
+
prompts = []
|
52 |
+
labels = []
|
53 |
+
|
54 |
+
# Construct prompts
|
55 |
+
k = args.ntrain
|
56 |
+
train_prompt = gen_prompt(dev_df, subject, k)
|
57 |
+
while len(tokenizer.encode(train_prompt)) > 1536:
|
58 |
+
k -= 1
|
59 |
+
train_prompt = gen_prompt(dev_df, subject, k)
|
60 |
+
|
61 |
+
for i in range(test_df.shape[0]):
|
62 |
+
prompt_end = format_example(test_df, i, include_answer=False)
|
63 |
+
prompt = train_prompt + prompt_end
|
64 |
+
prompts.append(prompt)
|
65 |
+
|
66 |
+
label = test_df.iloc[i, test_df.shape[1] - 1]
|
67 |
+
labels.append(label)
|
68 |
+
|
69 |
+
preds = [None] * len(prompts)
|
70 |
+
max_tokens = 1
|
71 |
+
|
72 |
+
# Run requests
|
73 |
+
if args.backend != "lmql":
|
74 |
+
# Use thread pool
|
75 |
+
def get_one_answer(i):
|
76 |
+
pred = call_generate(prompts[i], temperature=0, max_tokens=max_tokens)
|
77 |
+
preds[i] = pred.strip()[0]
|
78 |
+
|
79 |
+
tic = time.time()
|
80 |
+
if args.parallel == 1:
|
81 |
+
for i in range(len(prompts)):
|
82 |
+
get_one_answer(i)
|
83 |
+
else:
|
84 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
85 |
+
executor.map(get_one_answer, list(range(len(prompts))))
|
86 |
+
else:
|
87 |
+
# Use asyncio
|
88 |
+
async def batched_call(batch_size):
|
89 |
+
for i in range(0, len(prompts), batch_size):
|
90 |
+
tasks = []
|
91 |
+
for p in prompts[i : i + batch_size]:
|
92 |
+
tasks.append(call_generate(p, temperature=0, max_tokens=max_tokens))
|
93 |
+
rets = await asyncio.gather(*tasks)
|
94 |
+
for j in range(len(rets)):
|
95 |
+
preds[i + j] = rets[j].strip()[0]
|
96 |
+
|
97 |
+
tic = time.time()
|
98 |
+
asyncio.run(batched_call(batch_size=args.parallel))
|
99 |
+
latency = time.time() - tic
|
100 |
+
|
101 |
+
# Compute accuracy
|
102 |
+
cors = [pred == label for pred, label in zip(preds, labels)]
|
103 |
+
acc = np.mean(cors)
|
104 |
+
cors = np.array(cors)
|
105 |
+
|
106 |
+
print(
|
107 |
+
"Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
|
108 |
+
acc, latency, len(prompts), subject
|
109 |
+
)
|
110 |
+
)
|
111 |
+
|
112 |
+
return cors, acc, latency
|
113 |
+
|
114 |
+
|
115 |
+
def main(args):
|
116 |
+
subjects = sorted(
|
117 |
+
[
|
118 |
+
f.split("_test.csv")[0]
|
119 |
+
for f in os.listdir(os.path.join(args.data_dir, "test"))
|
120 |
+
if "_test.csv" in f
|
121 |
+
]
|
122 |
+
)
|
123 |
+
|
124 |
+
all_cors = []
|
125 |
+
all_latencies = []
|
126 |
+
num_requests = 0
|
127 |
+
|
128 |
+
# Select backend
|
129 |
+
call_generate = get_call_generate(args)
|
130 |
+
|
131 |
+
for subject in tqdm(subjects[: args.nsub]):
|
132 |
+
dev_df = pd.read_csv(
|
133 |
+
os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
|
134 |
+
)[: args.ntrain]
|
135 |
+
test_df = pd.read_csv(
|
136 |
+
os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
|
137 |
+
)
|
138 |
+
|
139 |
+
cors, acc, latency = evaluate(args, subject, dev_df, test_df, call_generate)
|
140 |
+
all_cors.append(cors)
|
141 |
+
all_latencies.append(latency)
|
142 |
+
num_requests += len(test_df)
|
143 |
+
|
144 |
+
total_latency = np.sum(all_latencies)
|
145 |
+
print("Total latency: {:.3f}".format(total_latency))
|
146 |
+
|
147 |
+
weighted_acc = np.mean(np.concatenate(all_cors))
|
148 |
+
print("Average accuracy: {:.3f}".format(weighted_acc))
|
149 |
+
|
150 |
+
# Write results
|
151 |
+
with open(args.result_file, "a") as fout:
|
152 |
+
value = {
|
153 |
+
"task": "mmlu",
|
154 |
+
"backend": args.backend,
|
155 |
+
"num_gpus": 1,
|
156 |
+
"latency": round(total_latency, 3),
|
157 |
+
"accuracy": round(weighted_acc, 3),
|
158 |
+
"num_requests": num_requests,
|
159 |
+
"other": {
|
160 |
+
"nsub": args.nsub,
|
161 |
+
"parallel": args.parallel,
|
162 |
+
},
|
163 |
+
}
|
164 |
+
fout.write(json.dumps(value) + "\n")
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
parser = argparse.ArgumentParser()
|
169 |
+
parser.add_argument("--ntrain", type=int, default=5)
|
170 |
+
parser.add_argument("--data_dir", type=str, default="data")
|
171 |
+
parser.add_argument("--nsub", type=int, default=60)
|
172 |
+
args = add_common_other_args_and_parse(parser)
|
173 |
+
main(args)
|
sglang/benchmark/mmlu/bench_sglang.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import tiktoken
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from sglang.test.test_utils import (
|
12 |
+
add_common_sglang_args_and_parse,
|
13 |
+
select_sglang_backend,
|
14 |
+
)
|
15 |
+
|
16 |
+
choices = ["A", "B", "C", "D"]
|
17 |
+
|
18 |
+
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
19 |
+
|
20 |
+
|
21 |
+
def format_subject(subject):
|
22 |
+
l = subject.split("_")
|
23 |
+
s = ""
|
24 |
+
for entry in l:
|
25 |
+
s += " " + entry
|
26 |
+
return s
|
27 |
+
|
28 |
+
|
29 |
+
def format_example(df, idx, include_answer=True):
|
30 |
+
prompt = df.iloc[idx, 0]
|
31 |
+
k = df.shape[1] - 2
|
32 |
+
for j in range(k):
|
33 |
+
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
|
34 |
+
prompt += "\nAnswer:"
|
35 |
+
if include_answer:
|
36 |
+
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
37 |
+
return prompt
|
38 |
+
|
39 |
+
|
40 |
+
def gen_prompt(train_df, subject, k=-1):
|
41 |
+
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
|
42 |
+
format_subject(subject)
|
43 |
+
)
|
44 |
+
if k == -1:
|
45 |
+
k = train_df.shape[0]
|
46 |
+
for i in range(k):
|
47 |
+
prompt += format_example(train_df, i)
|
48 |
+
return prompt
|
49 |
+
|
50 |
+
|
51 |
+
def main(args):
|
52 |
+
subjects = sorted(
|
53 |
+
[
|
54 |
+
f.split("_test.csv")[0]
|
55 |
+
for f in os.listdir(os.path.join(args.data_dir, "test"))
|
56 |
+
if "_test.csv" in f
|
57 |
+
]
|
58 |
+
)
|
59 |
+
|
60 |
+
# Build prompts
|
61 |
+
arguments = []
|
62 |
+
labels = []
|
63 |
+
num_questions = []
|
64 |
+
|
65 |
+
for subject in subjects[: args.nsub]:
|
66 |
+
dev_df = pd.read_csv(
|
67 |
+
os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
|
68 |
+
)[: args.ntrain]
|
69 |
+
test_df = pd.read_csv(
|
70 |
+
os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
|
71 |
+
)
|
72 |
+
num_questions.append(test_df.shape[0])
|
73 |
+
|
74 |
+
k = args.ntrain
|
75 |
+
few_shot_examples = gen_prompt(dev_df, subject, k)
|
76 |
+
while len(tokenizer.encode(few_shot_examples)) > 1536:
|
77 |
+
k -= 1
|
78 |
+
few_shot_examples = gen_prompt(dev_df, subject, k)
|
79 |
+
|
80 |
+
for i in range(test_df.shape[0]):
|
81 |
+
prompt_end = format_example(test_df, i, include_answer=False)
|
82 |
+
|
83 |
+
arguments.append(
|
84 |
+
{
|
85 |
+
"examples": few_shot_examples,
|
86 |
+
"question": prompt_end,
|
87 |
+
}
|
88 |
+
)
|
89 |
+
|
90 |
+
label = test_df.iloc[i, test_df.shape[1] - 1]
|
91 |
+
labels.append(label)
|
92 |
+
|
93 |
+
#####################################
|
94 |
+
######### SGL Program Begin #########
|
95 |
+
#####################################
|
96 |
+
|
97 |
+
import sglang as sgl
|
98 |
+
|
99 |
+
if args.backend.startswith("gpt-"):
|
100 |
+
|
101 |
+
@sgl.function
|
102 |
+
def few_shot_mmlu(s, examples, question):
|
103 |
+
s += sgl.user(examples + question)
|
104 |
+
s += sgl.assistant(sgl.gen("answer"))
|
105 |
+
|
106 |
+
else:
|
107 |
+
|
108 |
+
@sgl.function
|
109 |
+
def few_shot_mmlu(s, examples, question):
|
110 |
+
s += examples + question + sgl.gen("answer")
|
111 |
+
|
112 |
+
#####################################
|
113 |
+
########## SGL Program End ##########
|
114 |
+
#####################################
|
115 |
+
|
116 |
+
# Select backend
|
117 |
+
backend = select_sglang_backend(args)
|
118 |
+
|
119 |
+
# Run
|
120 |
+
tic = time.time()
|
121 |
+
states = few_shot_mmlu.run_batch(
|
122 |
+
arguments,
|
123 |
+
temperature=0,
|
124 |
+
max_new_tokens=1,
|
125 |
+
backend=backend,
|
126 |
+
num_threads=args.parallel,
|
127 |
+
progress_bar=True,
|
128 |
+
)
|
129 |
+
preds = [
|
130 |
+
s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" for s in states
|
131 |
+
]
|
132 |
+
latency = time.time() - tic
|
133 |
+
|
134 |
+
# Compute accuracy
|
135 |
+
cors = [pred == label for pred, label in zip(preds, labels)]
|
136 |
+
|
137 |
+
pt = 0
|
138 |
+
for subject, num_qs in zip(subjects[: args.nsub], num_questions):
|
139 |
+
print(
|
140 |
+
f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}"
|
141 |
+
)
|
142 |
+
pt += num_qs
|
143 |
+
assert pt == len(cors)
|
144 |
+
weighted_acc = np.mean(cors)
|
145 |
+
|
146 |
+
# Print results
|
147 |
+
print("Total latency: {:.3f}".format(latency))
|
148 |
+
print("Average accuracy: {:.3f}".format(weighted_acc))
|
149 |
+
|
150 |
+
# Write results
|
151 |
+
with open(args.result_file, "a") as fout:
|
152 |
+
value = {
|
153 |
+
"task": "mmlu",
|
154 |
+
"backend": args.backend,
|
155 |
+
"num_gpus": 1,
|
156 |
+
"latency": round(latency, 3),
|
157 |
+
"accuracy": round(weighted_acc, 3),
|
158 |
+
"num_requests": len(arguments),
|
159 |
+
"other": {
|
160 |
+
"nsub": args.nsub,
|
161 |
+
"parallel": args.parallel,
|
162 |
+
},
|
163 |
+
}
|
164 |
+
fout.write(json.dumps(value) + "\n")
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
parser = argparse.ArgumentParser()
|
169 |
+
parser.add_argument("--ntrain", "-k", type=int, default=5)
|
170 |
+
parser.add_argument("--data_dir", "-d", type=str, default="data")
|
171 |
+
parser.add_argument("--save_dir", "-s", type=str, default="results")
|
172 |
+
parser.add_argument("--nsub", type=int, default=60)
|
173 |
+
args = add_common_sglang_args_and_parse(parser)
|
174 |
+
main(args)
|
sglang/benchmark/mmlu/download_data.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
|
2 |
+
tar xf data.tar
|
sglang/benchmark/multi_chain_reasoning/bench_other.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import ast
|
3 |
+
import asyncio
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
import time
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
13 |
+
from sglang.utils import dump_state_text, read_jsonl
|
14 |
+
|
15 |
+
INVALID = -9999999
|
16 |
+
|
17 |
+
|
18 |
+
def get_answer_value(answer_str):
|
19 |
+
answer_str = answer_str.replace(",", "")
|
20 |
+
numbers = re.findall(r"\d+", answer_str)
|
21 |
+
if len(numbers) < 1:
|
22 |
+
return INVALID
|
23 |
+
try:
|
24 |
+
return ast.literal_eval(numbers[-1])
|
25 |
+
except SyntaxError:
|
26 |
+
return INVALID
|
27 |
+
|
28 |
+
|
29 |
+
prompt_lib = [
|
30 |
+
"Let us think step by step.",
|
31 |
+
"Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
|
32 |
+
"It's important to proceed step by step, ensuring accuracy at each stage.",
|
33 |
+
"Take a deep breath and break this down.",
|
34 |
+
"A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
|
35 |
+
"I am extremely good at math.",
|
36 |
+
]
|
37 |
+
|
38 |
+
|
39 |
+
def multi_chain_gsm8k(question, num_chains, call_generate):
|
40 |
+
s = "Question: " + question + "\n"
|
41 |
+
# s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
|
42 |
+
# stop="Question", temperature=0)
|
43 |
+
# return s
|
44 |
+
|
45 |
+
comps = []
|
46 |
+
for i in range(num_chains):
|
47 |
+
comps.append(
|
48 |
+
call_generate(
|
49 |
+
s + "Answer: " + prompt_lib[i % num_chains],
|
50 |
+
max_tokens=256,
|
51 |
+
temperature=0.3,
|
52 |
+
stop="Question",
|
53 |
+
)
|
54 |
+
)
|
55 |
+
|
56 |
+
s += "Answer: To answer this question, here are some possible solutions. "
|
57 |
+
s += "After considering all of them, I will do a majority vote.\n\n"
|
58 |
+
for i in range(num_chains):
|
59 |
+
s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
|
60 |
+
s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
61 |
+
s += call_generate(s, max_tokens=16, temperature=0, stop=None)
|
62 |
+
return s
|
63 |
+
|
64 |
+
|
65 |
+
async def multi_chain_gsm8k_async(question, num_chains, call_generate):
|
66 |
+
s = "Question: " + question + "\n"
|
67 |
+
# s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
|
68 |
+
# stop="Question", temperature=0)
|
69 |
+
# return s
|
70 |
+
|
71 |
+
comps = []
|
72 |
+
for i in range(num_chains):
|
73 |
+
comps.append(
|
74 |
+
await call_generate(
|
75 |
+
s + "Answer: " + prompt_lib[i % num_chains],
|
76 |
+
max_tokens=256,
|
77 |
+
temperature=0.3,
|
78 |
+
stop="Question",
|
79 |
+
)
|
80 |
+
)
|
81 |
+
|
82 |
+
s += "Answer: To answer this question, here are some possible solutions. "
|
83 |
+
s += "After considering all of them, I will do a majority vote.\n\n"
|
84 |
+
for i in range(num_chains):
|
85 |
+
s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
|
86 |
+
s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
87 |
+
s += await call_generate(s, max_tokens=16, temperature=0, stop=None)
|
88 |
+
return s
|
89 |
+
|
90 |
+
|
91 |
+
def main(args):
|
92 |
+
lines = read_jsonl(args.data_path)
|
93 |
+
|
94 |
+
# Construct prompts
|
95 |
+
k = args.num_shot
|
96 |
+
|
97 |
+
questions = []
|
98 |
+
labels = []
|
99 |
+
for i in range(len(lines[: args.num_questions])):
|
100 |
+
questions.append(lines[i]["question"])
|
101 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
102 |
+
assert all(l != INVALID for l in labels)
|
103 |
+
|
104 |
+
states = [None] * len(labels)
|
105 |
+
|
106 |
+
# Select backend
|
107 |
+
call_generate = get_call_generate(args)
|
108 |
+
|
109 |
+
# Run requests
|
110 |
+
if args.backend != "lmql":
|
111 |
+
# Use thread pool
|
112 |
+
def get_one_answer(i):
|
113 |
+
answer = multi_chain_gsm8k(questions[i], args.num_chains, call_generate)
|
114 |
+
states[i] = answer
|
115 |
+
|
116 |
+
tic = time.time()
|
117 |
+
if args.parallel == 1:
|
118 |
+
for i in tqdm(range(len(questions))):
|
119 |
+
get_one_answer(i)
|
120 |
+
else:
|
121 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
122 |
+
list(
|
123 |
+
tqdm(
|
124 |
+
executor.map(get_one_answer, list(range(len(questions)))),
|
125 |
+
total=len(questions),
|
126 |
+
)
|
127 |
+
)
|
128 |
+
|
129 |
+
else:
|
130 |
+
# Use asyncio
|
131 |
+
async def get_one_answer_asyncio(i):
|
132 |
+
answer = await multi_chain_gsm8k_async(
|
133 |
+
questions[i], args.num_chains, call_generate
|
134 |
+
)
|
135 |
+
states[i] = answer
|
136 |
+
|
137 |
+
tic = time.time()
|
138 |
+
loop = asyncio.get_event_loop()
|
139 |
+
batches = [
|
140 |
+
list(range(i, min(i + args.parallel, len(questions))))
|
141 |
+
for i in range(0, len(questions), args.parallel)
|
142 |
+
]
|
143 |
+
for bt in tqdm(batches):
|
144 |
+
tasks = [get_one_answer_asyncio(k) for k in bt]
|
145 |
+
loop.run_until_complete(asyncio.gather(*tasks))
|
146 |
+
|
147 |
+
latency = time.time() - tic
|
148 |
+
|
149 |
+
preds = []
|
150 |
+
for i in range(len(states)):
|
151 |
+
preds.append(get_answer_value(states[i]))
|
152 |
+
|
153 |
+
# Compute accuracy
|
154 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
155 |
+
invalid = np.mean(np.array(preds) == INVALID)
|
156 |
+
print(f"Latency: {latency:.3f}")
|
157 |
+
print(f"Invalid: {invalid:.3f}")
|
158 |
+
print(f"Accuracy: {acc:.3f}")
|
159 |
+
|
160 |
+
# Write results
|
161 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
162 |
+
|
163 |
+
with open(args.result_file, "a") as fout:
|
164 |
+
value = {
|
165 |
+
"task": "multi_chain_gsm8k",
|
166 |
+
"backend": args.backend,
|
167 |
+
"num_gpus": 1,
|
168 |
+
"latency": round(latency, 3),
|
169 |
+
"accuracy": round(acc, 3),
|
170 |
+
"num_requests": args.num_questions,
|
171 |
+
"other": {
|
172 |
+
"num_questions": args.num_questions,
|
173 |
+
"parallel": args.parallel,
|
174 |
+
},
|
175 |
+
}
|
176 |
+
fout.write(json.dumps(value) + "\n")
|
177 |
+
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
parser = argparse.ArgumentParser()
|
181 |
+
parser.add_argument("--num-shot", type=int, default=0)
|
182 |
+
parser.add_argument("--num-chains", type=int, default=5)
|
183 |
+
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
184 |
+
parser.add_argument("--num-questions", type=int, default=50)
|
185 |
+
args = add_common_other_args_and_parse(parser)
|
186 |
+
main(args)
|
sglang/benchmark/multi_document_qa/README.md
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Benchmark sglang
|
4 |
+
```
|
5 |
+
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
|
6 |
+
```
|
7 |
+
|
8 |
+
```
|
9 |
+
python3 bench_sglang.py --num-questions 10 --parallel 1
|
10 |
+
```
|
11 |
+
|
12 |
+
|
13 |
+
### Benchmark vllm
|
14 |
+
```
|
15 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
|
16 |
+
```
|
17 |
+
|
18 |
+
```
|
19 |
+
python3 bench_other.py --backend vllm --num-questions 64
|
20 |
+
```
|
21 |
+
|
22 |
+
|
23 |
+
### Benchmark guidance
|
24 |
+
```
|
25 |
+
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
### Build dataset
|
31 |
+
|
32 |
+
```
|
33 |
+
pip install PyPDF2
|
34 |
+
python3 build_dataset.py
|
35 |
+
```
|
36 |
+
|
37 |
+
```python
|
38 |
+
import PyPDF2
|
39 |
+
|
40 |
+
with open('llama2.pdf', 'rb') as file:
|
41 |
+
reader = PyPDF2.PdfReader(file)
|
42 |
+
text = ''
|
43 |
+
for page_num in range(len(reader.pages)):
|
44 |
+
text += reader.pages[page_num].extract_text()
|
45 |
+
with open('output.txt', 'w') as text_file:
|
46 |
+
text_file.write(text)
|
47 |
+
```
|
sglang/benchmark/multi_document_qa/bench_other.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ThreadPoolExecutor
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
USER_PREFIX = "[INST] "
|
13 |
+
USER_SUFFIX = " [/INST]"
|
14 |
+
ASSISTANT_PREFIX = ""
|
15 |
+
ASSISTANT_SUFFIX = " </s><s>"
|
16 |
+
|
17 |
+
|
18 |
+
def multi_document_qa(docs, question, generate):
|
19 |
+
s = USER_PREFIX
|
20 |
+
s += "Pleaes answer a question according to given documents.\n"
|
21 |
+
s += "Question:" + question + "Documents begin.\n"
|
22 |
+
|
23 |
+
s += "".join(docs)
|
24 |
+
|
25 |
+
s += "\nDocuments end."
|
26 |
+
s += (
|
27 |
+
"\n\nBased on the above documents, please answer this question:\n"
|
28 |
+
+ question
|
29 |
+
+ "\nAnswer in three words or fewer."
|
30 |
+
)
|
31 |
+
s += USER_SUFFIX
|
32 |
+
s += ASSISTANT_PREFIX
|
33 |
+
answer = generate(s, max_tokens=16, stop=None)
|
34 |
+
return answer
|
35 |
+
|
36 |
+
|
37 |
+
def main(args):
|
38 |
+
lines = read_jsonl(args.data_path)
|
39 |
+
l = lines[0]
|
40 |
+
arguments = []
|
41 |
+
labels = []
|
42 |
+
|
43 |
+
num_docs = 10
|
44 |
+
if args.backend == "guidance":
|
45 |
+
num_docs = 7 # due to OOM
|
46 |
+
|
47 |
+
for i in range(len(l["questions"][: args.num_questions])):
|
48 |
+
arguments.append(
|
49 |
+
{
|
50 |
+
"docs": l["documents"][:num_docs],
|
51 |
+
"question": l["questions"][i],
|
52 |
+
}
|
53 |
+
)
|
54 |
+
labels.append(l["answers"][i])
|
55 |
+
states = [None] * len(arguments)
|
56 |
+
|
57 |
+
# Select backend
|
58 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
59 |
+
|
60 |
+
# Run requests
|
61 |
+
def get_one_answer(i):
|
62 |
+
states[i] = multi_document_qa(generate=call_generate, **arguments[i])
|
63 |
+
|
64 |
+
tic = time.time()
|
65 |
+
if args.parallel == 1:
|
66 |
+
for i in tqdm(range(len(labels))):
|
67 |
+
get_one_answer(i)
|
68 |
+
else:
|
69 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
70 |
+
list(
|
71 |
+
tqdm(
|
72 |
+
executor.map(get_one_answer, list(range(len(labels)))),
|
73 |
+
total=len(labels),
|
74 |
+
)
|
75 |
+
)
|
76 |
+
|
77 |
+
latency = time.time() - tic
|
78 |
+
|
79 |
+
# Compute accuracy
|
80 |
+
print(states)
|
81 |
+
correct = 0
|
82 |
+
for s, label in zip(states, labels):
|
83 |
+
answer = s.lower()
|
84 |
+
if all(x in answer for x in label.lower().split(" ")):
|
85 |
+
correct += 1
|
86 |
+
accuracy = correct / len(labels)
|
87 |
+
print(f"Accuracy: {accuracy:.3f}")
|
88 |
+
print(f"Latency: {latency:.3f}")
|
89 |
+
|
90 |
+
# Write results
|
91 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
92 |
+
|
93 |
+
with open(args.result_file, "a") as fout:
|
94 |
+
value = {
|
95 |
+
"task": "multi_document_qa",
|
96 |
+
"backend": args.backend,
|
97 |
+
"num_gpus": 1,
|
98 |
+
"latency": round(latency, 3),
|
99 |
+
"num_requests": args.num_questions,
|
100 |
+
"accuracy": accuracy,
|
101 |
+
"other": {
|
102 |
+
"num_questions": args.num_questions,
|
103 |
+
"parallel": args.parallel,
|
104 |
+
},
|
105 |
+
}
|
106 |
+
fout.write(json.dumps(value) + "\n")
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
112 |
+
parser.add_argument("--num-questions", type=int, default=100)
|
113 |
+
args = add_common_other_args_and_parse(parser)
|
114 |
+
main(args)
|
sglang/benchmark/multi_document_qa/bench_sglang.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang.test.test_utils import (
|
7 |
+
add_common_sglang_args_and_parse,
|
8 |
+
select_sglang_backend,
|
9 |
+
)
|
10 |
+
from sglang.utils import dump_state_text, read_jsonl
|
11 |
+
|
12 |
+
|
13 |
+
@sgl.function
|
14 |
+
def multi_document_qa(s, docs, question):
|
15 |
+
s += sgl.user_begin()
|
16 |
+
s += "Pleaes answer a question according to given documents.\n"
|
17 |
+
s += "Question:" + question + "Documents begin.\n"
|
18 |
+
|
19 |
+
forks = s.fork(len(docs))
|
20 |
+
forks += lambda i: docs[i]
|
21 |
+
forks.join("concate_and_append")
|
22 |
+
|
23 |
+
s += "\nDocuments end."
|
24 |
+
s += (
|
25 |
+
"\n\nBased on the above documents, please answer this question:\n"
|
26 |
+
+ question
|
27 |
+
+ "\nAnswer in three words or fewer."
|
28 |
+
)
|
29 |
+
s += sgl.user_end()
|
30 |
+
s += sgl.assistant(sgl.gen("answer", max_tokens=16))
|
31 |
+
|
32 |
+
|
33 |
+
def main(args):
|
34 |
+
lines = read_jsonl(args.data_path)
|
35 |
+
l = lines[0]
|
36 |
+
arguments = []
|
37 |
+
labels = []
|
38 |
+
for i in range(len(l["questions"][: args.num_questions])):
|
39 |
+
arguments.append(
|
40 |
+
{
|
41 |
+
"docs": l["documents"][:10],
|
42 |
+
"question": l["questions"][i],
|
43 |
+
}
|
44 |
+
)
|
45 |
+
labels.append(l["answers"][i])
|
46 |
+
|
47 |
+
# Select backend
|
48 |
+
backend = select_sglang_backend(args)
|
49 |
+
sgl.set_default_backend(backend)
|
50 |
+
|
51 |
+
# Run requests
|
52 |
+
tic = time.time()
|
53 |
+
states = multi_document_qa.run_batch(
|
54 |
+
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
55 |
+
)
|
56 |
+
latency = time.time() - tic
|
57 |
+
|
58 |
+
# Compute accuracy
|
59 |
+
print([s["answer"] for s in states])
|
60 |
+
correct = 0
|
61 |
+
for s, label in zip(states, labels):
|
62 |
+
answer = s["answer"].lower()
|
63 |
+
if all(x in answer for x in label.lower().split(" ")):
|
64 |
+
correct += 1
|
65 |
+
accuracy = correct / len(labels)
|
66 |
+
print(f"Accuracy: {accuracy:.3f}")
|
67 |
+
print(f"Latency: {latency:.3f}")
|
68 |
+
|
69 |
+
# Write results
|
70 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
71 |
+
|
72 |
+
with open(args.result_file, "a") as fout:
|
73 |
+
value = {
|
74 |
+
"task": "multi_document_qa",
|
75 |
+
"backend": args.backend,
|
76 |
+
"num_gpus": 1,
|
77 |
+
"latency": round(latency, 3),
|
78 |
+
"num_requests": args.num_questions,
|
79 |
+
"accuracy": accuracy,
|
80 |
+
"other": {
|
81 |
+
"num_questions": args.num_questions,
|
82 |
+
"parallel": args.parallel,
|
83 |
+
},
|
84 |
+
}
|
85 |
+
fout.write(json.dumps(value) + "\n")
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == "__main__":
|
89 |
+
parser = argparse.ArgumentParser()
|
90 |
+
parser.add_argument("--data-path", type=str, default="questions.jsonl")
|
91 |
+
parser.add_argument("--num-questions", type=int, default=100)
|
92 |
+
args = add_common_sglang_args_and_parse(parser)
|
93 |
+
main(args)
|