tuandunghcmut committed
Commit 6c1556e (verified) · 1 Parent(s): 127dcad

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. sglang/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py +130 -0
  2. sglang/benchmark/benchmark_vllm_060/README.md +89 -0
  3. sglang/benchmark/blog_v0_2/README.md +164 -0
  4. sglang/benchmark/blog_v0_2/config.md +100 -0
  5. sglang/benchmark/deepseek_v3/README.md +123 -0
  6. sglang/benchmark/generative_agents/README.md +38 -0
  7. sglang/benchmark/generative_agents/agent_functions.py +300 -0
  8. sglang/benchmark/generative_agents/bench_other.py +80 -0
  9. sglang/benchmark/generative_agents/bench_sglang.py +74 -0
  10. sglang/benchmark/hellaswag/bench_sglang.py +106 -0
  11. sglang/benchmark/json_decode_regex/README.md +60 -0
  12. sglang/benchmark/json_decode_regex/bench_other.py +98 -0
  13. sglang/benchmark/json_decode_regex/bench_sglang.py +100 -0
  14. sglang/benchmark/json_decode_regex/build_dataset.py +58 -0
  15. sglang/benchmark/json_jump_forward/README.md +88 -0
  16. sglang/benchmark/json_jump_forward/bench_other.py +288 -0
  17. sglang/benchmark/json_jump_forward/bench_sglang.py +143 -0
  18. sglang/benchmark/json_jump_forward/build_dataset.py +58 -0
  19. sglang/benchmark/json_jump_forward/dataset.txt +50 -0
  20. sglang/benchmark/json_schema/README.md +15 -0
  21. sglang/benchmark/json_schema/bench_sglang.py +146 -0
  22. sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py +405 -0
  23. sglang/benchmark/kernels/fused_moe_triton/README.md +49 -0
  24. sglang/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py +231 -0
  25. sglang/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py +345 -0
  26. sglang/benchmark/line_retrieval/README.md +37 -0
  27. sglang/benchmark/line_retrieval/bench_sglang.py +149 -0
  28. sglang/benchmark/line_retrieval/gen_data.py +139 -0
  29. sglang/benchmark/llava_bench/README.md +61 -0
  30. sglang/benchmark/llava_bench/bench_hf_llava_bench.sh +9 -0
  31. sglang/benchmark/llava_bench/bench_hf_mme.sh +9 -0
  32. sglang/benchmark/llava_bench/bench_sglang.py +96 -0
  33. sglang/benchmark/llava_bench/bench_sglang_mme.sh +2 -0
  34. sglang/benchmark/llava_bench/download_images.py +20 -0
  35. sglang/benchmark/llava_bench/questions.jsonl +60 -0
  36. sglang/benchmark/llm_judge/README.md +33 -0
  37. sglang/benchmark/llm_judge/articles.jsonl +0 -0
  38. sglang/benchmark/llm_judge/bench_other.py +151 -0
  39. sglang/benchmark/llm_judge/bench_sglang.py +97 -0
  40. sglang/benchmark/long_json_decode/README.md +33 -0
  41. sglang/benchmark/long_json_decode/bench_other.py +89 -0
  42. sglang/benchmark/long_json_decode/bench_sglang.py +81 -0
  43. sglang/benchmark/long_json_decode/build_dataset.py +27 -0
  44. sglang/benchmark/mmlu/bench_other.py +173 -0
  45. sglang/benchmark/mmlu/bench_sglang.py +174 -0
  46. sglang/benchmark/mmlu/download_data.sh +2 -0
  47. sglang/benchmark/multi_chain_reasoning/bench_other.py +186 -0
  48. sglang/benchmark/multi_document_qa/README.md +47 -0
  49. sglang/benchmark/multi_document_qa/bench_other.py +114 -0
  50. sglang/benchmark/multi_document_qa/bench_sglang.py +93 -0
sglang/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py ADDED
@@ -0,0 +1,130 @@
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
#
# Launch a server:
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning

import random
import string
import time

from tqdm import tqdm
from transformers import AutoTokenizer

import sglang as sgl
from sglang import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint


def generate_random_string(token_length: int) -> str:
    random_string = "".join(
        random.choices(string.ascii_letters + string.digits, k=token_length * 100)
    )
    tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
        :token_length
    ]

    if len(tokenized_output) < token_length:
        tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
            token_length - len(tokenized_output)
        )

    decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
    return decoded_string


def generate_unique_prefix(base_text, index):
    return str(index) + base_text[len(str(index)) :]


@sgl.function
def text_qa(s, question, gen_len):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)


def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
    base_prefix = generate_random_string(prefix_length)

    tot_input_len = 0
    all_prompts = []
    for i in tqdm(range(num_prefix), desc="prepare prompts"):
        unique_prefix = generate_unique_prefix(base_prefix, i)
        prompt_list = []
        for j in range(num_samples_per_prefix):
            suffix = generate_random_string(suffix_length)
            prompt = unique_prefix + suffix
            prompt_list.append(prompt)
            tot_input_len += len(tokenizer.encode(prompt))
        all_prompts.append(prompt_list)
    return all_prompts, tot_input_len


def test_batch_by_batch(all_prompts, gen_len):
    backend.flush_cache()

    tot_time = 0
    for i in range(len(all_prompts)):
        tic = time.time()
        text_qa.run_batch(
            list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
        )
        tot_time += time.time() - tic

    return tot_time


def test_batch_by_batch_with_hint(all_prompts, gen_len):
    backend.flush_cache()

    tot_time = 0
    for i in range(len(all_prompts)):
        tic = time.time()
        # Send a hint to cache the prefix
        text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
        # Send the batch
        text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))

        tot_time += time.time() - tic

    return tot_time


def test_send_all(all_prompts, gen_len):
    backend.flush_cache()

    all_prompts = [x for prompt_list in all_prompts for x in prompt_list]

    tic = time.time()
    text_qa.run_batch(
        list(zip(all_prompts, [gen_len] * len(all_prompts))),
    )
    tot_time = time.time() - tic

    return tot_time


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    backend = RuntimeEndpoint("http://127.0.0.1:30000")
    set_default_backend(backend)

    random.seed(0)
    num_prefix = 10
    num_samples_per_prefix = 32
    prefix_length = 1024
    suffix_length = 128
    gen_len = 1
    all_prompts, tot_input_len = prepare_prompts(
        num_prefix, num_samples_per_prefix, prefix_length, suffix_length
    )

    print(f"Total input token length: {tot_input_len}\n")

    cost = test_batch_by_batch(all_prompts, gen_len)
    print(f"Latency of test_batch_by_batch          : {cost:.4f} s\n")

    cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
    print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")

    cost = test_send_all(all_prompts, gen_len)
    print(f"Latency of test_send_all                : {cost:.4f} s\n")
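For reference, this script takes no command-line arguments and talks to the hard-coded endpoint `http://127.0.0.1:30000`; a minimal run, assuming the server from the header comment, looks like this (usage sketch, not part of the file):

```bash
# Terminal 1: launch the server (same command as in the header comment)
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning

# Terminal 2: run the benchmark; it connects to http://127.0.0.1:30000
python3 bench_in_batch_prefix.py
```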
sglang/benchmark/benchmark_vllm_060/README.md ADDED
@@ -0,0 +1,89 @@
## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0

In short, with multi-step scheduling enabled, the Median TTFT of vLLM is **3 times** that of SGLang and the Median ITL is **10 times** that of SGLang in the online scenarios we benchmarked (lower Median TTFT and ITL are better). vLLM's multi-step optimization did not improve throughput while keeping Median TTFT and ITL low. Also, in the maximum-throughput benchmark, unless vLLM's GPU memory utilization is separately raised to 0.95 instead of the default, its maximum throughput is **lower** than that of SGLang.

## Online benchmark results

### Llama 3.1 8B Instruct 1 x A100 80G

| RPS  | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|------|-------------|--------|--------------------|-------------|-------------|------------|
| 4    | 1200        | SGLang | 1564.17            | **31.98**   | 13.17       | **11.93**  |
| 4    | 1200        | vLLM   | 1691.97            | **100.48**  | 14.14       | **129.32** |
| 8    | 2400        | SGLang | 2175.02            | **35.68**   | 17.85       | **14.41**  |
| 8    | 2400        | vLLM   | 2137.16            | **120.39**  | 17.09       | **158.63** |

### Llama 3.1 70B Instruct 4 x H100 80G

| RPS  | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|------|-------------|--------|--------------------|-------------|-------------|------------|
| 4    | 1200        | SGLang | 3005.24            | **53.94**   | 25.03       | **21.67**  |
| 4    | 1200        | vLLM   | 2915.60            | **179.15**  | 23.58       | **231.23** |
| 8    | 2400        | SGLang | 4064.98            | **58.11**   | 33.07       | **24.45**  |
| 8    | 2400        | vLLM   | 3752.38            | **207.12**  | 29.15       | **275.32** |

## Offline benchmark results

### Llama 3.1 8B Instruct 1 x A100 80G

| RPS  | Num Prompts | Engine | Request throughput | Output token throughput |
|------|-------------|--------|--------------------|-------------------------|
| inf  | 5000        | SGLang | 22.03              | **4281.51**             |
| inf  | 5000        | vLLM   | 21.27              | **4132.37**             |

### Llama 3.1 70B Instruct 4 x H100 80G

| RPS  | Num Prompts | Engine | Request throughput | Output token throughput |
|------|-------------|--------|--------------------|-------------------------|
| inf  | 5000        | SGLang | 19.84              | **3856.01**             |
| inf  | 5000        | vLLM   | 19.04              | **3700.64**             |

## Installation

```bash
# install sglang v0.3.0
pip install --upgrade pip
pip install "sglang[all]"==0.3.0
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/

# install vllm v0.6.0
pip install vllm==0.6.0
```

## Notes

We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176 and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. vLLM's `gpu_memory_utilization` defaults to 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set SGLang's `--mem-frac` to 0.88 at TP 4 (see the 70B launch command in the offline benchmarks below).

## Online benchmarks

```bash
# Llama 3.1 8B Instruct on 1 x A100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096

# Llama 3.1 70B Instruct on 4 x H100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096

# bench serving
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8
```

## Offline benchmarks

```bash
# Llama 3.1 8B Instruct on 1 x A100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096

# Llama 3.1 70B Instruct on 4 x H100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096

# bench serving
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000
```
sglang/benchmark/blog_v0_2/README.md ADDED
@@ -0,0 +1,164 @@
# How to reproduce the benchmark results of SGLang

## Prerequisite

### Install the latest SGLang

```bash
git clone https://github.com/sgl-project/sglang.git
cd sglang
git checkout v0.2.7

pip install --upgrade pip
pip install -e "python[all]"

pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
```

### Set up ulimit and HF_TOKEN

```bash
ulimit -n 65535
# Change the token to a real and usable one, with access permissions for the Llama 3 models.
export HF_TOKEN=hf_token
```

### Launch the server

```bash
# Meta-Llama-3.1-8B-Instruct
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache

# Meta-Llama-3.1-70B-Instruct
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 8

# Meta-Llama-3-70B-Instruct-FP8
python -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8
```

## Benchmark

### Hardware Requirements

- 8B models: Single NVIDIA A100 80GB GPU
- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8
- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8

Please ensure you have the appropriate hardware before running the benchmarks.

#### Offline benchmark

```bash
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl
cat offline.jsonl | cut -d':' -f12 | cut -d',' -f1
```
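The `cat ... | cut ...` line above pulls a single metric out of each result record by counting colon-separated fields. A sturdier way to inspect the results is to parse the JSONL directly; this is only a sketch and assumes each line of `offline.jsonl` is a flat JSON object (it is not part of the original instructions):

```bash
python3 - <<'EOF'
import json

# Print every numeric field of each benchmark record, so no field positions
# need to be counted by hand.
with open("offline.jsonl") as f:
    for line in f:
        record = json.loads(line)
        print({k: v for k, v in record.items() if isinstance(v, (int, float))})
EOF
```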

#### Online benchmark

```bash
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl
cat online.jsonl | cut -d':' -f9 | cut -d',' -f1
```

## Other

We tried vLLM 0.5.3.post1, but it often crashes under high load and showed similar or worse performance than vLLM 0.5.2 in our partial benchmarking, so we use the older version, vLLM 0.5.2.

For TensorRT-LLM setup, refer to https://github.com/sgl-project/tensorrt-demo. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. The instance count for preprocessing and postprocessing in Triton Server is 16.

```bash
# vLLM
pip install vllm==0.5.2
pip install jsonschema==4.21.1

# Meta-Llama-3-8B-Instruct
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests

# meta-llama/Meta-Llama-3-70B-Instruct
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor 8

# neuralmagic/Meta-Llama-3-70B-Instruct-FP8
python -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor 8
```

```bash
# Download the standalone benchmark script used for the vLLM and TensorRT-LLM runs below
wget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py
```

```bash
# vLLM Offline

python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl
cat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1
```

```bash
# vLLM Online

python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl
cat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1
```

```bash
# TensorRT LLM Offline 8B

python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl
cat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1
```

```bash
# TensorRT LLM Online 8B

python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl
cat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1
```

```bash
# TensorRT LLM Offline 70B

python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl
cat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1
```

```bash
# TensorRT LLM Online 70B

python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl
cat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1
```
sglang/benchmark/blog_v0_2/config.md ADDED
@@ -0,0 +1,100 @@
### used for TensorRT LLM

```
{
  "architecture": "LlamaForCausalLM",
  "dtype": "float16",
  "logits_dtype": "float32",
  "vocab_size": 128256,
  "max_position_embeddings": 8192,
  "hidden_size": 16384,
  "num_hidden_layers": 126,
  "num_attention_heads": 128,
  "num_key_value_heads": 16,
  "head_size": 128,
  "qk_layernorm": false,
  "hidden_act": "silu",
  "intermediate_size": 53248,
  "norm_epsilon": 1e-05,
  "position_embedding_type": "rope_gpt_neox",
  "use_parallel_embedding": false,
  "embedding_sharding_dim": 0,
  "share_embedding_table": false,
  "mapping": {
    "world_size": 8,
    "tp_size": 8,
    "pp_size": 1,
    "gpus_per_node": 8
  },
  "quantization": {
    "quant_algo": "FP8",
    "kv_cache_quant_algo": null,
    "group_size": 128,
    "smoothquant_val": null,
    "has_zero_point": false,
    "pre_quant_scale": false,
    "exclude_modules": [
      "lm_head"
    ]
  },
  "kv_dtype": "float16",
  "rotary_scaling": null,
  "residual_mlp": false,
  "moe_normalization_mode": null,
  "rotary_base": 500000.0,
  "moe_num_experts": 0,
  "moe_top_k": 0,
  "moe_tp_mode": 2,
  "attn_bias": false,
  "disable_weight_only_quant_plugin": false,
  "mlp_bias": false
}
```

### used for vLLM and SGLang

```
{
  "_name_or_path": "dummy_fp8",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 16384,
  "initializer_range": 0.02,
  "intermediate_size": 53248,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 128,
  "num_hidden_layers": 126,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "quantization_config": {
    "activation_scheme": "static",
    "ignored_layers": [
      "lm_head"
    ],
    "quant_method": "fp8"
  },
  "rope_scaling": {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "max_position_embeddings": 131072,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.41.1",
  "use_cache": true,
  "vocab_size": 128256
}
```
sglang/benchmark/deepseek_v3/README.md ADDED
@@ -0,0 +1,123 @@
# DeepSeek V3 Support

The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended).

Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources.

## Hardware Recommendation
- 8 x NVIDIA H200 GPUs

If you do not have GPUs with large enough memory, please try multi-node tensor parallelism. There is an example of serving with [2 H20 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) below.

## Installation & Launch

If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded.

### Using Docker (Recommended)
```bash
# Pull latest image
# https://hub.docker.com/r/lmsysorg/sglang/tags
docker pull lmsysorg/sglang:latest

# Launch
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host lmsysorg/sglang:latest \
    python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000
```

For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput, as in the example below.

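For reference, a sketch of the same Docker launch with DP attention enabled (only the flag named above is added; everything else is unchanged):

```bash
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host lmsysorg/sglang:latest \
    python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000 --enable-dp-attention
```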
### Using pip
```bash
# Installation
pip install "sglang[all]>=0.4.1.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer

# Launch
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
```

For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.

### Example with OpenAI API

```python3
import openai
client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response)
```
### Example: serving with two H20 nodes (8 GPUs each)
For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`.

```bash
# node 1
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code

# node 2
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
```

If you have two H100 nodes, the usage is similar to the H20 example above.

### Example: serving with Docker on two H200 nodes (8 GPUs each)
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communication with `--dist-init-addr 192.168.114.10:20000`.
A single node of 8 H200 GPUs can run DeepSeek V3; the dual-H200 setup here is just to demonstrate multi-node usage.

```bash
# node 1
docker run --gpus all \
    --shm-size 32g \
    --network=host \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --name sglang_multinode1 \
    -it \
    --rm \
    --env "HF_TOKEN=$HF_TOKEN" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000
```

```bash
# node 2
docker run --gpus all \
    --shm-size 32g \
    --network=host \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --name sglang_multinode2 \
    -it \
    --rm \
    --env "HF_TOKEN=$HF_TOKEN" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000
```

To verify the deployment, we include a test from a client Docker container.
```bash
docker run --gpus all \
    --shm-size 32g \
    --network=host \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --name sglang_multinode_client \
    -it \
    --rm \
    --env "HF_TOKEN=$HF_TOKEN" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file "deepseekv3_multinode.jsonl"
```

## DeepSeek V3 Optimization Plan

https://github.com/sgl-project/sglang/issues/2591
sglang/benchmark/generative_agents/README.md ADDED
@@ -0,0 +1,38 @@
## Download the dataset

```
wget -O agent_calls.jsonl "https://drive.google.com/uc?export=download&id=19qLpD45e9JGTKF2cUjJJegwzSUEZEKht"
```

## Run benchmark

Run this benchmark serially (using `--parallel 1`) to preserve any potential dependencies between requests.

### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

```
python3 bench_sglang.py --num-events 1000 --parallel 1
```

### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

```
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
```

### Benchmark guidance
```
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```

### Benchmark lmql

```
python3 bench_other.py --num-events 1000 --backend lmql --parallel 1
```
sglang/benchmark/generative_agents/agent_functions.py ADDED
@@ -0,0 +1,300 @@
import sglang as sgl

# here are the top five agent functions contributing ~70% LLM calls
# reference: https://github.com/joonspk-research/generative_agents/


@sgl.function
def poignancy_event(s, persona_name, persona_iss, event):
    s += "Here is a brief description of " + persona_name + ".\n"
    s += persona_iss + "\n"
    s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
    s += persona_name + ".\n\n"
    s += "Event: " + event
    s += "Rate (return a number between 1 to 10):"
    s += sgl.gen(name="Rate", max_tokens=2)


def poignancy_event_prompt(persona_name, persona_iss, event):
    # return prompt and max_tokens
    s = ""
    s += "Here is a brief description of " + persona_name + ".\n"
    s += persona_iss + "\n"
    s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
    s += persona_name + ".\n\n"
    s += "Event: " + event
    s += "Rate (return a number between 1 to 10):"
    return {"prompt": s, "max_tokens": 2, "stop": None}


@sgl.function
def generate_event_triple(s, persona_name, action):
    s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
    s += persona_name + "is" + action + ".\n"
    s += "(" + persona_name + ","
    s += sgl.gen(name="Triple", max_tokens=20, stop=")")


def generate_event_triple_prompt(persona_name, action):
    s = ""
    s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
    s += persona_name + "is" + action + ".\n"
    s += "(" + persona_name + ","
    return {"prompt": s, "max_tokens": 20, "stop": ")"}


@sgl.function
def generate_pronunciatio(s, action):
    s += "Convert an action description to an emoji (important: use two or less emojis).\n"
    s += "Action description: " + action + ".\n"
    s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)


def generate_pronunciatio_prompt(action):
    s = ""
    s += "Convert an action description to an emoji (important: use two or less emojis).\n"
    s += "Action description: " + action + ".\n"
    s += "Emoji:"
    return {"prompt": s, "max_tokens": 6, "stop": None}


@sgl.function
def action_location_sector(
    s,
    persona_name,
    living_sector,
    living_sector_areas,
    current_sector,
    current_sector_areas,
    daily_plan,
    sector_options,
    current_action,
    next_action,
):
    s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
    s += (
        persona_name
        + " lives in "
        + living_sector
        + " that has "
        + living_sector_areas
        + ".\n"
    )
    s += (
        persona_name
        + " is currently in "
        + current_sector
        + " that has "
        + current_sector_areas
        + ".\n"
    )
    s += daily_plan + ".\n"
    s += "Area options: " + sector_options + ".\n"
    s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
    s += (
        persona_name
        + " is "
        + current_action
        + ". For "
        + next_action
        + ", "
        + persona_name
        + " should go to the following area: {"
    )
    s += sgl.gen(name="Location", max_tokens=10, stop="}")


def action_location_sector_prompt(
    persona_name,
    living_sector,
    living_sector_areas,
    current_sector,
    current_sector_areas,
    daily_plan,
    sector_options,
    current_action,
    next_action,
):
    s = ""
    s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
    s += (
        persona_name
        + " lives in "
        + living_sector
        + " that has "
        + living_sector_areas
        + ".\n"
    )
    s += (
        persona_name
        + " is currently in "
        + current_sector
        + " that has "
        + current_sector_areas
        + ".\n"
    )
    s += daily_plan + ".\n"
    s += "Area options: " + sector_options + ".\n"
    s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
    s += (
        persona_name
        + " is "
        + current_action
        + ". For "
        + next_action
        + ", "
        + persona_name
        + " should go to the following area: {"
    )
    return {"prompt": s, "max_tokens": 10, "stop": "}"}


@sgl.function
def action_location_object(
    s, persona_name, target_sector, target_sector_areas, current_action, next_action
):
    s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
    s += (
        persona_name
        + " is going to "
        + target_sector
        + " that has the following areas: {"
        + target_sector_areas
        + "}\n"
    )
    s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
    s += (
        persona_name
        + " is "
        + current_action
        + ". For "
        + next_action
        + ", "
        + persona_name
        + "should go to the following area in "
        + target_sector
    )
    s += " (MUST pick one of {" + target_sector_areas + "}):\n"
    s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")


def action_location_object_prompt(
    persona_name, target_sector, target_sector_areas, current_action, next_action
):
    s = ""
    s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
    s += (
        persona_name
        + " is going to "
        + target_sector
        + " that has the following areas: {"
        + target_sector_areas
        + "}\n"
    )
    s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
    s += (
        persona_name
        + " is "
        + current_action
        + ". For "
        + next_action
        + ", "
        + persona_name
        + "should go to the following area in "
        + target_sector
    )
    s += " (MUST pick one of {" + target_sector_areas + "}):\n"
    s += "Answer: {"
    return {"prompt": s, "max_tokens": 5, "stop": "}"}
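For orientation, a minimal sketch of how one of these paired functions is used: the `@sgl.function` versions run against an sglang backend, while the `*_prompt` helpers only build the raw prompt dict consumed by `bench_other.py`. The endpoint URL and the persona values below are made-up examples, not part of the benchmark.

```python
import sglang as sgl
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from agent_functions import poignancy_event, poignancy_event_prompt

# Assumes a server launched as in the README, e.g. on port 30000.
sgl.set_default_backend(RuntimeEndpoint("http://127.0.0.1:30000"))

state = poignancy_event.run(
    persona_name="Sam Kim",
    persona_iss="Sam Kim is a college student who enjoys hiking.",
    event="Sam Kim got accepted into his dream graduate program.",
)
print(state["Rate"])  # the value generated by sgl.gen(name="Rate")

# The *_prompt variant returns the plain prompt payload instead of calling a backend.
print(poignancy_event_prompt("Sam Kim", "Sam Kim is a college student.", "Sam Kim lost his keys."))
```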
sglang/benchmark/generative_agents/bench_other.py ADDED
@@ -0,0 +1,80 @@
import argparse
import json
import time

from agent_functions import (
    action_location_object_prompt,
    action_location_sector_prompt,
    generate_event_triple_prompt,
    generate_pronunciatio_prompt,
    poignancy_event_prompt,
)
from tqdm import tqdm

from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl


def main(args):
    lines = read_jsonl(args.data_path)[: args.num_events]
    mapping = {
        "poignancy_event": poignancy_event_prompt,
        "generate_event_triple": generate_event_triple_prompt,
        "generate_pronunciatio": generate_pronunciatio_prompt,
        "action_location_sector": action_location_sector_prompt,
        "action_location_object": action_location_object_prompt,
    }

    arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
    states = []

    # Select backend
    call_generate = get_call_generate(args)

    def get_one_answer(arg):
        answer = call_generate(**arg, temperature=0)
        states.append(answer)

    async def get_one_answer_async(arg):
        answer = await call_generate(**arg, temperature=0)
        states.append(answer)

    tic = time.time()
    # we always sequentially execute agent calls to maintain its dependency
    if args.backend != "lmql":
        for arg in tqdm(arguments):
            get_one_answer(arg)
    else:
        import asyncio

        loop = asyncio.get_event_loop()
        for arg in tqdm(arguments):
            loop.run_until_complete(get_one_answer_async(arg))
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "Generative Agents",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # to pack weighted functions as a single agent
            "num_requests": len(arguments) / len(mapping),
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
    parser.add_argument("--num-events", type=int, default=10)
    args = add_common_other_args_and_parse(parser)
    main(args)
sglang/benchmark/generative_agents/bench_sglang.py ADDED
@@ -0,0 +1,74 @@
import argparse
import json
import time

from agent_functions import (
    action_location_object,
    action_location_sector,
    generate_event_triple,
    generate_pronunciatio,
    poignancy_event,
)

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl


def main(args):
    lines = read_jsonl(args.data_path)[: args.num_events]
    mapping = {
        "poignancy_event": poignancy_event,
        "generate_event_triple": generate_event_triple,
        "generate_pronunciatio": generate_pronunciatio,
        "action_location_sector": action_location_sector,
        "action_location_object": action_location_object,
    }
    arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    states = []
    # Run requests
    tic = time.time()
    for a in arguments:
        # only a single key in the dict
        for func, arg in a.items():
            result = func.run(**arg)
            result.sync()
            states.append(result)
    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "Generative Agents",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # to pack weighted functions as a single agent
            "num_requests": len(arguments) / len(mapping),
            "other": {
                "num_events": args.num_events,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
    parser.add_argument("--num-events", type=int, default=10)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
sglang/benchmark/hellaswag/bench_sglang.py ADDED
@@ -0,0 +1,106 @@
import argparse
import json
import time

import numpy as np

from sglang.api import set_default_backend
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import download_and_cache_file, read_jsonl


def get_one_example(lines, i, include_answer):
    ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
    if include_answer:
        ret += lines[i]["endings"][lines[i]["label"]]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def main(args):
    # Select backend
    set_default_backend(select_sglang_backend(args))

    # Read data
    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    questions = []
    choices = []
    labels = []
    for i in range(len(lines[:num_questions])):
        questions.append(get_one_example(lines, i, False))
        choices.append(lines[i]["endings"])
        labels.append(lines[i]["label"])
    arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_hellaswag(s, question, choices):
        s += few_shot_examples + question
        s += sgl.select("answer", choices=choices)

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Run requests
    tic = time.time()
    rets = few_shot_hellaswag.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
    latency = time.time() - tic

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    print(f"Latency: {latency:.3f}")
    print(f"Accuracy: {acc:.3f}")

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "hellaswag",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "accuracy": round(acc, 3),
            "num_requests": args.num_questions,
            "other": {
                "num_questions": args.num_questions,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shots", type=int, default=20)
    parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
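For reference, a typical invocation of this script, sketched under the assumption that an sglang server is already running on the default endpoint expected by `select_sglang_backend` (the `--parallel` flag comes from `add_common_sglang_args_and_parse`):

```bash
python3 bench_sglang.py --num-questions 200 --parallel 16
```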
sglang/benchmark/json_decode_regex/README.md ADDED
@@ -0,0 +1,60 @@
1
+ ## Run benchmark
2
+
3
+ ### Build dataset
4
+ ```
5
+ pip install wikipedia
6
+ python3 build_dataset.py
7
+ ```
8
+
9
+ ### Dependencies
10
+
11
+ ```
12
+ llama_cpp_python 0.2.19
13
+ guidance 0.1.10
14
+ vllm 0.2.5
15
+ outlines 0.0.22
16
+ ```
17
+
18
+ ### Benchmark sglang
19
+
20
+ Run Llama-7B
21
+
22
+ ```
23
+ python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
24
+ ```
25
+
26
+ Run Mixtral-8x7B
27
+
28
+ ```
29
+ python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
30
+ ```
31
+
32
+ Benchmark
33
+
34
+ ```
35
+ python3 bench_sglang.py --num-questions 10
36
+ ```
37
+
38
+
39
+ ### Benchmark Outlines + vLLM
40
+
41
+ Run Llama-7B
42
+
43
+ ```
44
+ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
45
+ ```
46
+
47
+ Benchmark
48
+
49
+ ```
50
+ python3 bench_other.py --backend outlines --num-questions 10
51
+ ```
52
+
53
+
54
+ ### Benchmark guidance
55
+
56
+ Run Llama-7B and benchmark
57
+
58
+ ```
59
+ python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
60
+ ```
sglang/benchmark/json_decode_regex/bench_other.py ADDED
@@ -0,0 +1,98 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from functools import partial
6
+
7
+ from tqdm import tqdm
8
+
9
+ from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
10
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
11
+ from sglang.utils import dump_state_text, read_jsonl
12
+
13
+ REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
14
+
15
+
16
+ # fmt: off
17
+ def json_decode(document, generate):
18
+ s = "Please extract the information of a city from the following wikipedia page.\n"
19
+ s += "Page begin.\n" + document + "Page end.\n"
20
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
21
+ s += "{\n"
22
+ s += ' "name": '
23
+ s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
24
+ s += ' "country": '
25
+ s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
26
+ s += ' "latitude": '
27
+ s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
28
+ s += ' "population": '
29
+ s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
30
+ s += ' "top 3 landmarks": '
31
+ s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
32
+ s += "}\n"
33
+
34
+ return s
35
+ # fmt: on
36
+
37
+
38
+ def main(args):
39
+ lines = read_jsonl(args.data_path)
40
+ arguments = []
41
+ for i in range(len(lines[: args.num_questions])):
42
+ arguments.append(
43
+ {
44
+ "document": lines[i]["document"],
45
+ }
46
+ )
47
+ states = [None] * len(arguments)
48
+
49
+ # Select backend
50
+ call_generate = partial(get_call_generate(args), temperature=0)
51
+
52
+ # Run requests
53
+ def get_one_answer(i):
54
+ states[i] = json_decode(generate=call_generate, **arguments[i])
55
+
56
+ tic = time.time()
57
+ if args.parallel == 1:
58
+ for i in tqdm(range(len(arguments))):
59
+ get_one_answer(i)
60
+ else:
61
+ with ThreadPoolExecutor(args.parallel) as executor:
62
+ rets = list(
63
+ tqdm(
64
+ executor.map(get_one_answer, list(range(len(arguments)))),
65
+ total=len(arguments),
66
+ )
67
+ )
68
+ for _ in rets:
69
+ pass
70
+
71
+ latency = time.time() - tic
72
+
73
+ # Report latency
74
+ print(f"Latency: {latency:.3f}")
75
+
76
+ # Write results
77
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
78
+
79
+ with open(args.result_file, "a") as fout:
80
+ value = {
81
+ "task": "json_decode_regex",
82
+ "backend": args.backend,
83
+ "num_gpus": 1,
84
+ "latency": round(latency, 3),
85
+ "num_requests": args.num_questions,
86
+ "other": {
87
+ "parallel": args.parallel,
88
+ },
89
+ }
90
+ fout.write(json.dumps(value) + "\n")
91
+
92
+
93
+ if __name__ == "__main__":
94
+ parser = argparse.ArgumentParser()
95
+ parser.add_argument("--data-path", type=str, default="questions.jsonl")
96
+ parser.add_argument("--num-questions", type=int, default=20)
97
+ args = add_common_other_args_and_parse(parser)
98
+ main(args)
sglang/benchmark/json_decode_regex/bench_sglang.py ADDED
@@ -0,0 +1,100 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ import sglang as sgl
6
+ from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
7
+ from sglang.test.test_utils import (
8
+ add_common_sglang_args_and_parse,
9
+ select_sglang_backend,
10
+ )
11
+ from sglang.utils import dump_state_text, read_jsonl
12
+
13
+ REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
14
+
15
+ # fmt: off
16
+ @sgl.function
17
+ def json_warm_up(s):
18
+ s += "The information about Hogwarts is in the following JSON format.\n"
19
+ with s.var_scope("json_output"):
20
+ s += "{\n"
21
+ s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
22
+ s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
23
+ s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
24
+ s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
25
+ s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
26
+ s += "}\n"
27
+ print(f'The warm-up json result is:\n{s["json_output"]}')
28
+ # fmt: on
29
+
30
+ # fmt: off
31
+ @sgl.function
32
+ def json_decode(s, document):
33
+ s += "Please extract the information of a city from the following wikipedia page.\n"
34
+ s += "Page begin.\n" + document + "Page end.\n"
35
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
36
+ with s.var_scope("json_output"):
37
+ s += "{\n"
38
+ s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
39
+ s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
40
+ s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
41
+ s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
42
+ s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
43
+ s += "}\n"
44
+ # fmt: on
45
+
46
+
47
+ def main(args):
48
+ lines = read_jsonl(args.data_path)
49
+ arguments = []
50
+ for i in range(len(lines[: args.num_questions])):
51
+ arguments.append(
52
+ {
53
+ "document": lines[i]["document"],
54
+ }
55
+ )
56
+
57
+ # Select backend
58
+ backend = select_sglang_backend(args)
59
+ sgl.set_default_backend(backend)
60
+
61
+ # Warm up
62
+ json_warm_up.run().sync()
63
+
64
+ # Run requests
65
+ tic = time.time()
66
+ states = json_decode.run_batch(
67
+ arguments, temperature=0, num_threads=args.parallel, progress_bar=True
68
+ )
69
+ latency = time.time() - tic
70
+
71
+ # Report latency
72
+ print(f"Latency: {latency:.3f}")
73
+
74
+ # Write results
75
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
76
+
77
+ with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
78
+ for state in states:
79
+ fout.write(state["json_output"] + "\n")
80
+
81
+ with open(args.result_file, "a") as fout:
82
+ value = {
83
+ "task": "json_decode_regex",
84
+ "backend": args.backend,
85
+ "num_gpus": 1,
86
+ "latency": round(latency, 3),
87
+ "num_requests": args.num_questions,
88
+ "other": {
89
+ "parallel": args.parallel,
90
+ },
91
+ }
92
+ fout.write(json.dumps(value) + "\n")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ parser = argparse.ArgumentParser()
97
+ parser.add_argument("--data-path", type=str, default="questions.jsonl")
98
+ parser.add_argument("--num-questions", type=int, default=20)
99
+ args = add_common_sglang_args_and_parse(parser)
100
+ main(args)
sglang/benchmark/json_decode_regex/build_dataset.py ADDED
@@ -0,0 +1,58 @@
1
+ import json
2
+
3
+ import transformers
4
+ import wikipedia
5
+
6
+ model_path = "meta-llama/Llama-2-7b-chat-hf"
7
+ t = transformers.AutoTokenizer.from_pretrained(model_path)
8
+ city_names = [
9
+ "los angeles",
10
+ "london",
11
+ "tokyo",
12
+ "beijing",
13
+ "singapore",
14
+ "paris",
15
+ "dubai",
16
+ "sydney",
17
+ "moscow",
18
+ "rome",
19
+ "toronto",
20
+ "rio de janeiro",
21
+ "istanbul",
22
+ "berlin",
23
+ "auckland",
24
+ "buenos aires",
25
+ "mexico city",
26
+ "mumbai",
27
+ "seoul",
28
+ "bangkok",
29
+ "cairo",
30
+ "athens",
31
+ "jerusalem",
32
+ ]
33
+
34
+
35
+ def get_content(city_name):
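+ # Download the city's Wikipedia page and truncate it to roughly 3000 tokens by cutting characters proportionally.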
36
+ content = str(wikipedia.page(city_name).content)
37
+ content = content.replace("\n\n", "\n")
38
+
39
+ tokens = t.encode(content)
40
+
41
+ expected_tokens = 3000
42
+ truncate_len = int((expected_tokens / len(tokens)) * len(content))
43
+ truncate_content = content[:truncate_len]
44
+ truncate_tokens = t.encode(truncate_content)
45
+
46
+ # Count token
47
+ print(
48
+ f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
49
+ )
50
+
51
+ return truncate_content
52
+
53
+
54
+ if __name__ == "__main__":
55
+ with open("questions.jsonl", "w") as fout:
56
+ for city_name in city_names:
57
+ truncate_content = get_content(city_name)
58
+ fout.write(json.dumps({"document": truncate_content}) + "\n")
sglang/benchmark/json_jump_forward/README.md ADDED
@@ -0,0 +1,88 @@
1
+ ## Run benchmark
2
+
3
+ ### Dependencies
4
+
5
+ ```
6
+ llama_cpp_python 0.2.38
7
+ guidance 0.1.10
8
+ vllm 0.2.7
9
+ outlines 0.0.25
10
+ ```
11
+
12
+ ### Build dataset
13
+
14
+ When benchmarking long document information retrieval, run the following command to build the dataset:
15
+
16
+ ```bash
17
+ pip install wikipedia
18
+ python3 build_dataset.py
19
+ ```
20
+
21
+ ### Benchmark sglang
22
+
23
+ Run Llama-7B
24
+
25
+ ```bash
26
+ python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
27
+ ```
28
+
29
+ Benchmark Character Generation
30
+
31
+ ```bash
32
+ python3 bench_sglang.py --mode character
33
+ ```
34
+
35
+ Benchmark City Information Retrieval
36
+
37
+ ```bash
38
+ python3 bench_sglang.py --mode city
39
+ ```
40
+
41
+
42
+ ### Benchmark Outlines + vLLM
43
+
44
+ Run Llama-7B
45
+
46
+ ```bash
47
+ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
48
+ ```
49
+
50
+ Benchmark Character Generation
51
+
52
+ ```bash
53
+ python3 bench_other.py --mode character --backend outlines
54
+ ```
55
+
56
+ Benchmark City Information Retrieval
57
+
58
+ ```bash
59
+ python3 bench_other.py --mode city --backend outlines
60
+ ```
61
+
62
+ ### Benchmark guidance
63
+
64
+ Run Llama-7B and benchmark character generation
65
+
66
+ ```bash
67
+ python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
68
+ ```
69
+
70
+ Run Llama-7B and benchmark city information retrieval
71
+
72
+ ```bash
73
+ python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
74
+ ```
75
+
76
+ ### Benchmark lmql
77
+
78
+ Run Llama-7B and benchmark character generation
79
+
80
+ ```bash
81
+ python3 bench_other.py --mode character --backend lmql --parallel 1
82
+ ```
83
+
84
+ Run Llama-7B and benchmark city information retrieval
85
+
86
+ ```bash
87
+ python3 bench_other.py --mode city --backend lmql --parallel 1
88
+ ```
sglang/benchmark/json_jump_forward/bench_other.py ADDED
@@ -0,0 +1,288 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from functools import partial
6
+
7
+ import guidance
8
+ from tqdm import tqdm
9
+
10
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
11
+ from sglang.utils import dump_state_text, read_jsonl
12
+
13
+ # there are some FSM bugs with json regex converted from pydantic model
14
+ # here use a string regex instead
15
+ # regex_string = build_regex_from_object(HarryPotterRole)
16
+ character_regex = (
17
+ r"""\{\n"""
18
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
19
+ + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
20
+ + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
21
+ + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
22
+ + r""" "wand": \{\n"""
23
+ + r""" "wood": "[\w\d\s]{1,16}",\n"""
24
+ + r""" "core": "[\w\d\s]{1,16}",\n"""
25
+ + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
26
+ + r""" \},\n"""
27
+ + r""" "alive": "(Alive|Deceased)",\n"""
28
+ + r""" "patronus": "[\w\d\s]{1,16}",\n"""
29
+ + r""" "bogart": "[\w\d\s]{1,16}"\n"""
30
+ + r"""\}"""
31
+ )
32
+
33
+ city_regex = (
34
+ r"""\{\n"""
35
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
36
+ + r""" "country": "[\w\d\s]{1,16}",\n"""
37
+ + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
38
+ + r""" "population": [-+]?[0-9]{1,9},\n"""
39
+ + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
40
+ + r"""\}"""
41
+ )
42
+
43
+ # fmt: off
44
+ def character_gen(name, generate):
45
+ s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
46
+ s += generate(s, max_tokens=256, regex=character_regex)
47
+ return s
48
+ # fmt: on
49
+
50
+ # fmt: off
51
+ def city_gen(document, generate):
52
+ s = "Please extract the information of a city from the following wikipedia page.\n"
53
+ s += "Page begin.\n" + document + "Page end.\n"
54
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
55
+ s += generate(s, max_tokens=256, regex=city_regex)
56
+ return s
57
+ # fmt: on
58
+
59
+
60
+ @guidance
61
+ def character_maker(lm, name):
62
+ regex_str_no_quote = r"[\w\d\s]+"
63
+ regex_float = r"[0-9]+\.[0-9]+"
64
+ lm += f"""\
65
+ {name} is a character in Harry Potter. Please fill in the following information about this character.
66
+ {{
67
+ "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
68
+ "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
69
+ "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
70
+ "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
71
+ "wand": {{
72
+ "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
73
+ "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
74
+ "length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
75
+ }},
76
+ "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
77
+ "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
78
+ "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
79
+ }}
80
+ """
81
+
82
+ return lm
83
+
84
+
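+ # lmql backend: generate ANSWER constrained by the given regex and token budget.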
85
+ async def call_generate_lmql(
86
+ prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
87
+ ):
88
+ assert model is not None
89
+ import lmql
90
+
91
+ @lmql.query(model=model)
92
+ async def program(question, max_tokens, regex):
93
+ '''lmql
94
+ """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
95
+ return ANSWER
96
+ '''
97
+
98
+ return await program(
99
+ question=prompt,
100
+ temperature=temperature,
101
+ max_tokens=max_tokens,
102
+ max_len=max_len,
103
+ regex=regex,
104
+ **kwargs,
105
+ )
106
+
107
+
108
+ @guidance
109
+ def city_maker(lm, document):
110
+ regex_str_no_quote = r"[\w\d\s]+"
111
+ regex_float = r"[0-9]+\.[0-9]+"
112
+ lm += f"""\
113
+ Please extract the information of a city from the following wikipedia page.
114
+ Page begin.
115
+ {document}
116
+ Page end.
117
+ Here is the name, country, and symbol of the city in JSON format.
118
+ {{
119
+ "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
120
+ "country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
121
+ "latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
122
+ "population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
123
+ "top 3 landmarks": [
124
+ "{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
125
+ ]
126
+ }}
127
+ """
128
+
129
+ return lm
130
+
131
+
132
+ def bench_character(args):
133
+ arguments = []
134
+ with open(args.data_path, "r") as f:
135
+ for line in f:
136
+ arguments.append({"name": line.strip()})
137
+ arguments = arguments[: args.num_jsons]
138
+
139
+ states = [None] * len(arguments)
140
+
141
+ # Select backend
142
+ if args.backend == "outlines":
143
+ call_generate = partial(get_call_generate(args), temperature=0)
144
+
145
+ def get_one_answer(i):
146
+ states[i] = character_gen(**arguments[i], generate=call_generate)
147
+
148
+ elif args.backend == "guidance":
149
+ model = guidance.models.LlamaCpp(
150
+ args.model_path,
151
+ n_gpu_layers=-1,
152
+ n_ctx=args.n_ctx,
153
+ )
154
+
155
+ def get_one_answer(i):
156
+ lm = model + character_maker(**arguments[i])
157
+ states[i] = lm
158
+
159
+ elif args.backend == "lmql":
160
+ import asyncio
161
+
162
+ import lmql
163
+
164
+ model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
165
+ call_generate = partial(
166
+ call_generate_lmql,
167
+ model=model,
168
+ max_tokens=256,
169
+ regex=character_regex,
170
+ )
171
+
172
+ async def get_one_answer_async(i):
173
+ states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
174
+
175
+ else:
176
+ raise ValueError(f"Invalid backend: {args.backend}")
177
+
178
+ tic = time.time()
179
+
180
+ if args.backend != "lmql":
181
+ if args.parallel == 1:
182
+ for i in tqdm(range(len(arguments))):
183
+ get_one_answer(i)
184
+ else:
185
+ with ThreadPoolExecutor(args.parallel) as executor:
186
+ rets = list(
187
+ tqdm(
188
+ executor.map(get_one_answer, list(range(len(arguments)))),
189
+ total=len(arguments),
190
+ )
191
+ )
192
+ for _ in rets:
193
+ pass
194
+ else:
195
+ batches = []
196
+ for i in range(0, len(arguments), args.parallel):
197
+ batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
198
+ loop = asyncio.get_event_loop()
199
+
200
+ for bt in tqdm(batches):
201
+ loop.run_until_complete(
202
+ asyncio.gather(*[get_one_answer_async(i) for i in bt])
203
+ )
204
+
205
+ latency = time.time() - tic
206
+
207
+ return states, latency
208
+
209
+
210
+ def bench_city_doc(args):
211
+ arguments = []
212
+ for line in read_jsonl(args.data_path):
213
+ arguments.append({"document": line["document"]})
214
+ arguments = arguments[: args.num_jsons]
215
+
216
+ states = [None] * len(arguments)
217
+
218
+ # Select backend
219
+ if args.backend == "outlines":
220
+ call_generate = partial(get_call_generate(args), temperature=0)
221
+
222
+ def get_one_answer(i):
223
+ states[i] = city_gen(**arguments[i], generate=call_generate)
224
+
225
+ elif args.backend == "guidance":
226
+ model = guidance.models.LlamaCpp(
227
+ args.model_path,
228
+ n_gpu_layers=-1,
229
+ n_ctx=args.n_ctx,
230
+ )
231
+
232
+ def get_one_answer(i):
233
+ lm = model + city_maker(**arguments[i])
234
+ states[i] = lm
235
+
236
+ else:
237
+ raise ValueError(f"Invalid backend: {args.backend}")
238
+
239
+ tic = time.time()
240
+ if args.parallel == 1:
241
+ for i in tqdm(range(len(arguments))):
242
+ get_one_answer(i)
243
+ else:
244
+ with ThreadPoolExecutor(args.parallel) as executor:
245
+ rets = executor.map(get_one_answer, list(range(len(arguments))))
246
+ for _ in rets:
247
+ pass
248
+
249
+ latency = time.time() - tic
250
+
251
+ return states, latency
252
+
253
+
254
+ def main(args):
255
+ if args.mode == "character":
256
+ args.data_path = "dataset.txt"
257
+ states, latency = bench_character(args)
258
+ elif args.mode == "city":
259
+ args.data_path = "questions.jsonl"
260
+ states, latency = bench_city_doc(args)
261
+
262
+ # Report latency
263
+ print(f"Latency: {latency:.3f}")
264
+
265
+ # Write results
266
+ dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
267
+
268
+ with open(args.result_file, "a") as fout:
269
+ value = {
270
+ "task": "json_jump_forward",
271
+ "backend": args.backend,
272
+ "latency": round(latency, 3),
273
+ "num_jsons": args.num_jsons,
274
+ "mode": args.mode,
275
+ "parallel": args.parallel,
276
+ }
277
+ fout.write(json.dumps(value) + "\n")
278
+
279
+
280
+ if __name__ == "__main__":
281
+ parser = argparse.ArgumentParser()
282
+ parser.add_argument("--data-path", type=str)
283
+ parser.add_argument("--num-jsons", type=int, default=50)
284
+ parser.add_argument(
285
+ "--mode", type=str, default="character", choices=["character", "city"]
286
+ )
287
+ args = add_common_other_args_and_parse(parser)
288
+ main(args)
sglang/benchmark/json_jump_forward/bench_sglang.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ import sglang as sgl
6
+ from sglang.test.test_utils import (
7
+ add_common_sglang_args_and_parse,
8
+ select_sglang_backend,
9
+ )
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+ # there are some FSM bugs with json regex converted from pydantic model
13
+ # here use a string regex instead
14
+ # regex_string = build_regex_from_object(HarryPotterRole)
15
+ character_regex = (
16
+ r"""\{\n"""
17
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
18
+ + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
19
+ + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
20
+ + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
21
+ + r""" "wand": \{\n"""
22
+ + r""" "wood": "[\w\d\s]{1,16}",\n"""
23
+ + r""" "core": "[\w\d\s]{1,16}",\n"""
24
+ + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
25
+ + r""" \},\n"""
26
+ + r""" "alive": "(Alive|Deceased)",\n"""
27
+ + r""" "patronus": "[\w\d\s]{1,16}",\n"""
28
+ + r""" "bogart": "[\w\d\s]{1,16}"\n"""
29
+ + r"""\}"""
30
+ )
31
+
32
+ city_regex = (
33
+ r"""\{\n"""
34
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
35
+ + r""" "country": "[\w\d\s]{1,16}",\n"""
36
+ + r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
37
+ + r""" "population": [-+]?[0-9]{1,9},\n"""
38
+ + r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
39
+ + r"""\}"""
40
+ )
41
+
42
+ # fmt: off
43
+ @sgl.function
44
+ def character_gen(s, name):
45
+ s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
46
+ s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
47
+ # fmt: on
48
+
49
+ # fmt: off
50
+ @sgl.function
51
+ def city_gen(s, document):
52
+ s += "Please extract the information of a city from the following wikipedia page.\n"
53
+ s += "Page begin.\n" + document + "Page end.\n"
54
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
55
+ s += sgl.gen("json_output", max_tokens=256, regex=city_regex)
56
+ # fmt: on
57
+
58
+
59
+ def bench_city_doc(args):
60
+ arguments = []
61
+ for line in read_jsonl(args.data_path):
62
+ arguments.append({"document": line["document"]})
63
+ arguments = arguments[: args.num_jsons]
64
+
65
+ # Select backend
66
+ backend = select_sglang_backend(args)
67
+ sgl.set_default_backend(backend)
68
+
69
+ # Run requests
70
+ tic = time.time()
71
+ states = city_gen.run_batch(
72
+ arguments,
73
+ temperature=0,
74
+ num_threads=args.parallel,
75
+ progress_bar=True,
76
+ )
77
+ latency = time.time() - tic
78
+
79
+ return states, latency
80
+
81
+
82
+ def bench_character(args):
83
+ arguments = []
84
+ with open(args.data_path, "r") as f:
85
+ for line in f:
86
+ arguments.append({"name": line.strip()})
87
+ arguments = arguments[: args.num_jsons]
88
+
89
+ # Select backend
90
+ backend = select_sglang_backend(args)
91
+ sgl.set_default_backend(backend)
92
+
93
+ # Run requests
94
+ tic = time.time()
95
+ states = character_gen.run_batch(
96
+ arguments,
97
+ temperature=0,
98
+ num_threads=args.parallel,
99
+ progress_bar=True,
100
+ )
101
+ latency = time.time() - tic
102
+
103
+ return states, latency
104
+
105
+
106
+ def main(args):
107
+ if args.mode == "character":
108
+ args.data_path = "dataset.txt"
109
+ states, latency = bench_character(args)
110
+ elif args.mode == "city":
111
+ args.data_path = "questions.jsonl"
112
+ states, latency = bench_city_doc(args)
113
+
114
+ # Report latency
115
+ print(f"Latency: {latency:.3f}")
116
+
117
+ # Write results
118
+ dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
119
+ with open(f"{args.backend}_{args.mode}.json", "w") as fout:
120
+ for state in states:
121
+ fout.write(state["json_output"] + "\n")
122
+
123
+ with open(args.result_file, "a") as fout:
124
+ value = {
125
+ "task": "json_jump_forward",
126
+ "backend": args.backend,
127
+ "latency": round(latency, 3),
128
+ "num_jsons": args.num_jsons,
129
+ "mode": args.mode,
130
+ "parallel": args.parallel,
131
+ }
132
+ fout.write(json.dumps(value) + "\n")
133
+
134
+
135
+ if __name__ == "__main__":
136
+ parser = argparse.ArgumentParser()
137
+ parser.add_argument("--data-path", type=str)
138
+ parser.add_argument("--num-jsons", type=int, default=50)
139
+ parser.add_argument(
140
+ "--mode", type=str, default="character", choices=["character", "city"]
141
+ )
142
+ args = add_common_sglang_args_and_parse(parser)
143
+ main(args)
sglang/benchmark/json_jump_forward/build_dataset.py ADDED
@@ -0,0 +1,58 @@
1
+ import json
2
+
3
+ import transformers
4
+ import wikipedia
5
+
6
+ model_path = "meta-llama/Llama-2-7b-chat-hf"
7
+ t = transformers.AutoTokenizer.from_pretrained(model_path)
8
+ city_names = [
9
+ "los angeles",
10
+ "london",
11
+ "tokyo",
12
+ "beijing",
13
+ "singapore",
14
+ "paris",
15
+ "dubai",
16
+ "sydney",
17
+ "moscow",
18
+ "rome",
19
+ "toronto",
20
+ "rio de janeiro",
21
+ "istanbul",
22
+ "berlin",
23
+ "auckland",
24
+ "buenos aires",
25
+ "mexico city",
26
+ "mumbai",
27
+ "seoul",
28
+ "bangkok",
29
+ "cairo",
30
+ "athens",
31
+ "jerusalem",
32
+ ]
33
+
34
+
35
+ def get_content(city_name):
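+ # Download the city's Wikipedia page and truncate it to roughly 3000 tokens by cutting characters proportionally.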
36
+ content = str(wikipedia.page(city_name).content)
37
+ content = content.replace("\n\n", "\n")
38
+
39
+ tokens = t.encode(content)
40
+
41
+ expected_tokens = 3000
42
+ truncate_len = int((expected_tokens / len(tokens)) * len(content))
43
+ truncate_content = content[:truncate_len]
44
+ truncate_tokens = t.encode(truncate_content)
45
+
46
+ # Count token
47
+ print(
48
+ f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
49
+ )
50
+
51
+ return truncate_content
52
+
53
+
54
+ if __name__ == "__main__":
55
+ with open("questions.jsonl", "w") as fout:
56
+ for city_name in city_names:
57
+ truncate_content = get_content(city_name)
58
+ fout.write(json.dumps({"document": truncate_content}) + "\n")
sglang/benchmark/json_jump_forward/dataset.txt ADDED
@@ -0,0 +1,50 @@
1
+ Harry Potter
2
+ Hermione Granger
3
+ Ron Weasley
4
+ Albus Dumbledore
5
+ Severus Snape
6
+ Rubeus Hagrid
7
+ Draco Malfoy
8
+ Ginny Weasley
9
+ Fred Weasley
10
+ George Weasley
11
+ Percy Weasley
12
+ Sirius Black
13
+ Remus Lupin
14
+ Neville Longbottom
15
+ Luna Lovegood
16
+ Cedric Diggory
17
+ Cho Chang
18
+ Lord Voldemort
19
+ Minerva McGonagall
20
+ Filius Flitwick
21
+ Dolores Umbridge
22
+ Bellatrix Lestrange
23
+ Lucius Malfoy
24
+ Molly Weasley
25
+ Arthur Weasley
26
+ Nymphadora Tonks
27
+ Dobby
28
+ Moaning Myrtle
29
+ Peter Pettigrew
30
+ Alastor 'Mad-Eye' Moody
31
+ Horace Slughorn
32
+ Vernon Dursley
33
+ Petunia Dursley
34
+ Dudley Dursley
35
+ Argus Filch
36
+ Sybill Trelawney
37
+ Gilderoy Lockhart
38
+ Fleur Delacour
39
+ Viktor Krum
40
+ Bill Weasley
41
+ Oliver Wood
42
+ Cornelius Fudge
43
+ Barty Crouch Sr.
44
+ Barty Crouch Jr.
45
+ Kingsley Shacklebolt
46
+ Quirinus Quirrell
47
+ Nearly Headless Nick
48
+ Aunt Marge
49
+ Griphook
50
+ Ludo Bagman
sglang/benchmark/json_schema/README.md ADDED
@@ -0,0 +1,15 @@
1
+ ## Run benchmark
2
+
3
+ ### Benchmark sglang
4
+
5
+ Run Llama-3.1-8B-Instruct
6
+
7
+ ```bash
8
+ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
9
+ ```
10
+
11
+ Benchmark
12
+
13
+ ```bash
14
+ python3 bench_sglang.py
15
+ ```
sglang/benchmark/json_schema/bench_sglang.py ADDED
@@ -0,0 +1,146 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from typing import List, Tuple
5
+
6
+ import jsonschema
7
+ from datasets import load_dataset
8
+
9
+ import sglang as sgl
10
+ from sglang.global_config import global_config
11
+ from sglang.srt.hf_transformers_utils import get_tokenizer
12
+ from sglang.test.test_utils import (
13
+ add_common_sglang_args_and_parse,
14
+ select_sglang_backend,
15
+ )
16
+ from sglang.utils import dump_state_text
17
+
18
+
19
+ @sgl.function
20
+ def schema_gen(s, message: Tuple[str, str], json_schema: str):
21
+ system, user = message
22
+ s += sgl.system(system)
23
+ s += sgl.user(user)
24
+ s += sgl.assistant(
25
+ sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema)
26
+ )
27
+
28
+
29
+ def contains_formats(schema, formats: List[str]):
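+ # Recursively check whether any subschema (in dicts or lists) declares one of the given "format" values.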
30
+ if isinstance(schema, dict):
31
+ if schema.get("format", None) in formats:
32
+ return True
33
+ for value in schema.values():
34
+ if contains_formats(value, formats):
35
+ return True
36
+ elif isinstance(schema, list):
37
+ for item in schema:
38
+ if contains_formats(item, formats):
39
+ return True
40
+ return False
41
+
42
+
43
+ def convert_dataset(path: str):
44
+ raw_dataset = load_dataset(path)
45
+ dataset = []
46
+ for data in raw_dataset["train"]:
47
+ messages = data["prompt"]
48
+ schema = data["schema"]
49
+ obj = json.loads(schema)
50
+
51
+ # skip some corrupted examples
52
+ if obj.get("type", None) is None:
53
+ continue
54
+
55
+ # skip schema with format "email"
56
+ # which is not supported by outlines for now
57
+ if contains_formats(obj, ["email"]):
58
+ continue
59
+
60
+ system = messages[0]
61
+ user = messages[1]
62
+ assert system["role"] == "system", "invalid role"
63
+ assert user["role"] == "user", "invalid role"
64
+ assert len(messages) == 2, "invalid message length"
65
+ message = json.dumps(system["content"]), json.dumps(user["content"])
66
+ dataset.append(
67
+ {
68
+ "message": message,
69
+ "json_schema": schema,
70
+ }
71
+ )
72
+
73
+ return dataset
74
+
75
+
76
+ def bench_schema(args):
77
+ arguments = convert_dataset(args.data_path)
78
+
79
+ if args.num_jsons < 0 or args.num_jsons > len(arguments):
80
+ args.num_jsons = len(arguments)
81
+ arguments = arguments[: args.num_jsons]
82
+
83
+ # Select backend
84
+ backend = select_sglang_backend(args)
85
+ sgl.set_default_backend(backend)
86
+
87
+ # Run requests
88
+ tic = time.time()
89
+ states = schema_gen.run_batch(
90
+ arguments,
91
+ temperature=0,
92
+ num_threads=args.parallel,
93
+ progress_bar=True,
94
+ )
95
+ latency = time.time() - tic
96
+
97
+ # Check if the outputs are valid
98
+ indexs = []
99
+ for i, state in enumerate(states):
100
+ try:
101
+ schema = json.loads(arguments[i]["json_schema"])
102
+ obj = json.loads(state["json_output"])
103
+ assert jsonschema.validate(obj, schema) is None
104
+ except Exception as e:
105
+ print(e)
106
+ indexs.append(i)
107
+
108
+ return states, latency
109
+
110
+
111
+ def main(args):
112
+ states, latency = bench_schema(args)
113
+
114
+ # Compute output token statistics
115
+ tokenizer = get_tokenizer(
116
+ global_config.default_backend.get_server_info()["tokenizer_path"]
117
+ )
118
+ output_jsons = [state["json_output"] for state in states]
119
+ num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
120
+ print(f"Latency: {latency:.3f}")
121
+ print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")
122
+ print(f"#output tokens: {num_output_tokens}")
123
+
124
+ # Write results
125
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
126
+ with open(f"{args.backend}.jsonl", "w") as fout:
127
+ for state in states:
128
+ fout.write(state["json_output"] + "\n")
129
+
130
+ with open(args.result_file, "a") as fout:
131
+ value = {
132
+ "task": "json_schema",
133
+ "backend": args.backend,
134
+ "latency": round(latency, 3),
135
+ "num_jsons": args.num_jsons,
136
+ "parallel": args.parallel,
137
+ }
138
+ fout.write(json.dumps(value) + "\n")
139
+
140
+
141
+ if __name__ == "__main__":
142
+ parser = argparse.ArgumentParser()
143
+ parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval")
144
+ parser.add_argument("--num-jsons", type=int, default=-1)
145
+ args = add_common_sglang_args_and_parse(parser)
146
+ main(args)
sglang/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py ADDED
@@ -0,0 +1,405 @@
1
+ import itertools
2
+ import math
3
+
4
+ import cudnn
5
+ import torch
6
+ import torch.utils.benchmark as benchmark
7
+ import triton
8
+ import triton.language as tl
9
+ from flashinfer import BatchDecodeWithPagedKVCacheWrapper
10
+
11
+ from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
12
+ from sglang.srt.utils import should_use_tensor_core
13
+
14
+
15
+ def benchmark_forward(
16
+ fn,
17
+ *inputs,
18
+ repeats=10,
19
+ amp=False,
20
+ amp_dtype=torch.float16,
21
+ **kwinputs,
22
+ ):
23
+ def amp_wrapper(*inputs, **kwinputs):
24
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
25
+ fn(*inputs, **kwinputs)
26
+
27
+ t = benchmark.Timer(
28
+ stmt="fn_amp(*inputs, **kwinputs)",
29
+ globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
30
+ num_threads=torch.get_num_threads(),
31
+ )
32
+ m = t.timeit(repeats)
33
+ return t, m
34
+
35
+
36
+ def time_fwd(func, *args, **kwargs):
37
+ time_f = benchmark_forward(func, *args, **kwargs)
38
+ return time_f[1].mean * 1e6
39
+
40
+
41
+ def decode_attention_sglang(
42
+ q,
43
+ kv_data,
44
+ batch_size,
45
+ kv_len,
46
+ head_num_q,
47
+ head_num_kv,
48
+ head_dim,
49
+ num_kv_splits,
50
+ warmup=10,
51
+ ):
52
+
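+ # Time SGLang's Triton decode attention kernel; the KV cache is laid out contiguously, one block of kv_len tokens per request.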
53
+ k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
54
+ v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
55
+ o = torch.empty_like(q)
56
+ total_tokens = batch_size * kv_len
57
+ req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
58
+ b_req_idx = torch.arange(0, batch_size).to(0).int()
59
+ b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
60
+ max_len_in_batch = kv_len
61
+ sm_scale = 1.0 / (head_dim**0.5)
62
+
63
+ attn_logits = torch.empty(
64
+ (batch_size, head_num_q, num_kv_splits, head_dim + 1),
65
+ dtype=torch.float32,
66
+ device="cuda",
67
+ )
68
+
69
+ for _ in range(warmup):
70
+ decode_attention_fwd(
71
+ q,
72
+ k_buffer,
73
+ v_buffer,
74
+ o,
75
+ req_to_token,
76
+ b_req_idx,
77
+ b_seq_len,
78
+ attn_logits,
79
+ num_kv_splits,
80
+ sm_scale,
81
+ )
82
+
83
+ f = time_fwd(
84
+ decode_attention_fwd,
85
+ q,
86
+ k_buffer,
87
+ v_buffer,
88
+ o,
89
+ req_to_token,
90
+ b_req_idx,
91
+ b_seq_len,
92
+ attn_logits,
93
+ num_kv_splits,
94
+ sm_scale,
95
+ )
96
+
97
+ return f, o
98
+
99
+
100
+ def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
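+ # Build a FlashInfer paged-KV decode wrapper and expose it as an autograd.Function so it can be timed like the other backends.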
101
+ workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
102
+ use_tensor_cores = should_use_tensor_core(
103
+ kv_cache_dtype=dtype,
104
+ num_attention_heads=head_num_q,
105
+ num_kv_heads=head_num_kv,
106
+ )
107
+ flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
108
+ workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
109
+ )
110
+
111
+ class FlashinferAttention(torch.autograd.Function):
112
+ @staticmethod
113
+ def forward(
114
+ ctx,
115
+ q,
116
+ kv_data,
117
+ batch_size,
118
+ kv_len,
119
+ head_num_q,
120
+ head_num_kv,
121
+ head_dim,
122
+ dtype,
123
+ warmup=10,
124
+ ):
125
+ total_tokens = batch_size * kv_len
126
+ kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
127
+ kv_indices = torch.arange(0, total_tokens).to(0).int()
128
+ kv_last_page_len = torch.full(
129
+ (batch_size,), 1, dtype=torch.int32, device="cuda"
130
+ )
131
+
132
+ flashinfer_decode_wrapper.end_forward()
133
+ flashinfer_decode_wrapper.begin_forward(
134
+ kv_indptr,
135
+ kv_indices,
136
+ kv_last_page_len,
137
+ head_num_q,
138
+ head_num_kv,
139
+ head_dim,
140
+ 1,
141
+ pos_encoding_mode="NONE",
142
+ data_type=dtype,
143
+ )
144
+
145
+ for _ in range(warmup):
146
+ o = flashinfer_decode_wrapper.forward(
147
+ q.contiguous().view(-1, head_num_q, head_dim), kv_data
148
+ )
149
+
150
+ f = time_fwd(
151
+ flashinfer_decode_wrapper.forward,
152
+ q.contiguous().view(-1, head_num_q, head_dim),
153
+ kv_data,
154
+ )
155
+
156
+ return f, o
157
+
158
+ return FlashinferAttention
159
+
160
+
161
+ def convert_to_cudnn_type(torch_type):
162
+ if torch_type == torch.float16:
163
+ return cudnn.data_type.HALF
164
+ elif torch_type == torch.bfloat16:
165
+ return cudnn.data_type.BFLOAT16
166
+ elif torch_type == torch.float32:
167
+ return cudnn.data_type.FLOAT
168
+ elif torch_type == torch.int32:
169
+ return cudnn.data_type.INT32
170
+ elif torch_type == torch.int64:
171
+ return cudnn.data_type.INT64
172
+ else:
173
+ raise ValueError("Unsupported tensor data type.")
174
+
175
+
176
+ def decode_attention_cudnn(
177
+ q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
178
+ ):
179
+ # Prepare data: continuous q,k,v
180
+ dims_q = (batch_size, head_num_q, 1, head_dim)
181
+ strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
182
+ q_gpu = q.as_strided(dims_q, strides_q)
183
+ o_gpu = (
184
+ torch.empty(batch_size * head_num_q * head_dim)
185
+ .half()
186
+ .cuda()
187
+ .as_strided(dims_q, strides_q)
188
+ )
189
+
190
+ dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
191
+ strides_kv = (
192
+ kv_len * head_num_kv * head_dim,
193
+ head_dim,
194
+ head_num_kv * head_dim,
195
+ 1,
196
+ )
197
+ k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
198
+ v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
199
+
200
+ seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
201
+ seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
202
+ attn_scale = 1.0 / (head_dim**0.5)
203
+
204
+ # Prepare data: paged k,v
205
+ block_size = 1
206
+ blocks_per_batch = math.ceil(kv_len / block_size)
207
+ # [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
208
+ container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
209
+ container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
210
+ page_table_k_gpu = (
211
+ torch.linspace(
212
+ 0,
213
+ batch_size * blocks_per_batch - 1,
214
+ batch_size * blocks_per_batch,
215
+ device="cuda",
216
+ dtype=torch.int32,
217
+ )
218
+ .reshape(blocks_per_batch, 1, batch_size, 1)
219
+ .transpose(0, 2)
220
+ )
221
+ page_table_v_gpu = page_table_k_gpu.clone()
222
+
223
+ graph = cudnn.pygraph(
224
+ io_data_type=convert_to_cudnn_type(dtype),
225
+ intermediate_data_type=cudnn.data_type.FLOAT,
226
+ compute_data_type=cudnn.data_type.FLOAT,
227
+ )
228
+
229
+ q = graph.tensor_like(q_gpu)
230
+ container_k = graph.tensor_like(container_k_gpu)
231
+ container_v = graph.tensor_like(container_v_gpu)
232
+ page_table_k = graph.tensor_like(page_table_k_gpu)
233
+ page_table_v = graph.tensor_like(page_table_v_gpu)
234
+
235
+ seq_len_q = graph.tensor_like(seq_len_q_gpu)
236
+ seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
237
+
238
+ o, _ = graph.sdpa(
239
+ name="sdpa",
240
+ q=q,
241
+ k=container_k, # Container K: non contiguous container with K blocks
242
+ v=container_v, # Container V: non contiguous container with V blocks
243
+ is_inference=True,
244
+ attn_scale=attn_scale,
245
+ use_causal_mask=False,
246
+ use_padding_mask=True,
247
+ seq_len_q=seq_len_q,
248
+ seq_len_kv=seq_len_kv,
249
+ paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
250
+ paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
251
+ paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
252
+ )
253
+
254
+ o.set_output(True).set_dim(dims_q).set_stride(strides_q)
255
+
256
+ graph.validate()
257
+ graph.build_operation_graph()
258
+ graph.create_execution_plans([cudnn.heur_mode.A])
259
+ graph.check_support()
260
+ graph.build_plans()
261
+
262
+ workspace = torch.empty(
263
+ graph.get_workspace_size(), device="cuda", dtype=torch.uint8
264
+ )
265
+
266
+ variant_pack = {
267
+ q: q_gpu,
268
+ container_k: container_k_gpu,
269
+ container_v: container_v_gpu,
270
+ page_table_k: page_table_k_gpu,
271
+ page_table_v: page_table_v_gpu,
272
+ seq_len_q: seq_len_q_gpu,
273
+ seq_len_kv: seq_len_kv_gpu,
274
+ o: o_gpu,
275
+ }
276
+
277
+ for _ in range(warmup):
278
+ graph.execute(variant_pack, workspace)
279
+
280
+ f = time_fwd(
281
+ graph.execute,
282
+ variant_pack,
283
+ workspace,
284
+ )
285
+
286
+ return f, o_gpu.squeeze(dim=2)
287
+
288
+
289
+ def calculate_diff():
290
+
291
+ dtype = torch.float16
292
+ batch_size = 64
293
+ kv_len = 4096
294
+ head_num_q = 64
295
+ head_num_kv = 8
296
+ head_dim = 128
297
+
298
+ q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
299
+ kv_data = (
300
+ torch.randn(
301
+ batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
302
+ ),
303
+ torch.randn(
304
+ batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
305
+ ),
306
+ )
307
+
308
+ _, output_sglang = decode_attention_sglang(
309
+ q,
310
+ kv_data,
311
+ batch_size,
312
+ kv_len,
313
+ head_num_q,
314
+ head_num_kv,
315
+ head_dim,
316
+ num_kv_splits=8,
317
+ )
318
+
319
+ attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
320
+ _, output_flashinfer = attn_flashinfer(
321
+ q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
322
+ )
323
+
324
+ _, output_cudnn = decode_attention_cudnn(
325
+ q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
326
+ )
327
+
328
+ print(f"SGLang output={output_sglang}")
329
+ print(f"FlashInfer output={output_flashinfer}")
330
+ print(f"cuDNN output={output_cudnn}")
331
+ if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
332
+ print("✅ SGLang[Triton] and FlashInfer match")
333
+ else:
334
+ print("❌ SGLang[Triton] and FlashInfer differ")
335
+
336
+ if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
337
+ print("✅ SGLang[Triton] and cuDNN match")
338
+ else:
339
+ print("❌ SGLang[Triton] and cuDNN differ")
340
+
341
+
342
+ if __name__ == "__main__":
343
+ calculate_diff()
344
+
345
+ head_dim = 128
346
+ dtype = torch.float16
347
+ batch_size_range = [2**i for i in range(0, 8, 2)]
348
+ kv_len_range = [2**i for i in range(6, 13, 1)]
349
+ configs = list(itertools.product(batch_size_range, kv_len_range))
350
+
351
+ for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
352
+ attn_flashinfer = decode_attention_flashinfer(
353
+ dtype, head_num_q, head_num_kv
354
+ ).apply
355
+ for batch_size, kv_len in configs:
356
+ q = torch.randn(
357
+ batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
358
+ )
359
+ kv_data = (
360
+ torch.randn(
361
+ batch_size * kv_len,
362
+ head_num_kv,
363
+ head_dim,
364
+ dtype=dtype,
365
+ device="cuda",
366
+ ),
367
+ torch.randn(
368
+ batch_size * kv_len,
369
+ head_num_kv,
370
+ head_dim,
371
+ dtype=dtype,
372
+ device="cuda",
373
+ ),
374
+ )
375
+ us_cudnn, output_cudnn = decode_attention_cudnn(
376
+ q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
377
+ )
378
+ us_sglang, output_sglang = decode_attention_sglang(
379
+ q,
380
+ kv_data,
381
+ batch_size,
382
+ kv_len,
383
+ head_num_q,
384
+ head_num_kv,
385
+ head_dim,
386
+ num_kv_splits=8,
387
+ )
388
+ us_flashinfer, _ = attn_flashinfer(
389
+ q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
390
+ )
391
+ print(
392
+ head_num_q,
393
+ " ",
394
+ head_num_kv,
395
+ " ",
396
+ batch_size,
397
+ " ",
398
+ kv_len,
399
+ " ",
400
+ us_cudnn,
401
+ " ",
402
+ us_sglang,
403
+ " ",
404
+ us_flashinfer,
405
+ )
sglang/benchmark/kernels/fused_moe_triton/README.md ADDED
@@ -0,0 +1,49 @@
1
+ ## Benchmark Kernels
2
+
3
+ This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
4
+
5
+ ### Tuning Tool
6
+
7
+ - `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
8
+
9
+ Example usage:
10
+ ```bash
11
+ # Tune Qwen2-57B with FP8 and TP=4
12
+ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
13
+ --model Qwen/Qwen2-57B-A14B-Instruct \
14
+ --tp-size 4 \
15
+ --dtype fp8_w8a8 \
16
+ --tune
17
+
18
+ # Tune Mixtral-8x7B with default settings
19
+ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
20
+ --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
21
+ --tune
22
+ ```
23
+
24
+ After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/` to use it in `sglang`.
25
+
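+ For example, using the example filename above (adjust both the filename and the destination to your actual setup):
+
+ ```bash
+ cp "E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json" \
+ sglang/srt/layers/fused_moe_triton/configs/
+ ```
+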
26
+ ### Performance Comparison Tool
27
+
28
+ - `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
29
+
30
+ Example usage:
31
+ ```bash
32
+ # Compare with default settings (Mixtral model)
33
+ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
34
+
35
+ # Compare with FP8 mode for Qwen2-57B
36
+ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
37
+ --model Qwen/Qwen2-57B-A14B-Instruct \
38
+ --use-fp8
39
+
40
+ # Compare with custom TP size
41
+ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
42
+ --tp-size 4
43
+ ```
44
+
45
+ The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
46
+
47
+ - `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
48
+
49
+ Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`.
sglang/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py ADDED
@@ -0,0 +1,231 @@
1
+ import itertools
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+ from flashinfer.norm import fused_add_rmsnorm, rmsnorm
8
+ from torch import nn
9
+ from vllm import _custom_ops as vllm_ops
10
+
11
+
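+ # Reference RMSNorm in plain PyTorch (HuggingFace-style): accumulate in float32 and optionally fuse a residual add.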
12
+ class HuggingFaceRMSNorm(nn.Module):
13
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
14
+ super().__init__()
15
+ self.weight = nn.Parameter(torch.ones(hidden_size))
16
+ self.variance_epsilon = eps
17
+
18
+ def forward(
19
+ self,
20
+ x: torch.Tensor,
21
+ residual: Optional[torch.Tensor] = None,
22
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
23
+ orig_dtype = x.dtype
24
+ x = x.to(torch.float32)
25
+ if residual is not None:
26
+ x = x + residual.to(torch.float32)
27
+ residual = x.to(orig_dtype)
28
+
29
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
30
+ x = x * torch.rsqrt(variance + self.variance_epsilon)
31
+ x = x.to(orig_dtype) * self.weight
32
+ if residual is None:
33
+ return x
34
+ else:
35
+ return x, residual
36
+
37
+
38
+ def rmsnorm_naive(
39
+ x: torch.Tensor,
40
+ weight: torch.Tensor,
41
+ residual: Optional[torch.Tensor] = None,
42
+ eps: float = 1e-6,
43
+ ):
44
+ naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
45
+ naive_norm.weight = nn.Parameter(weight)
46
+ naive_norm = naive_norm.to(x.device)
47
+
48
+ orig_shape = x.shape
49
+ x = x.view(-1, x.shape[-1])
50
+ if residual is not None:
51
+ residual = residual.view(-1, residual.shape[-1])
52
+
53
+ output = naive_norm(x, residual)
54
+
55
+ if isinstance(output, tuple):
56
+ output = (output[0].view(orig_shape), output[1].view(orig_shape))
57
+ else:
58
+ output = output.view(orig_shape)
59
+ return output
60
+
61
+
62
+ def rmsnorm_flashinfer(
63
+ x: torch.Tensor,
64
+ weight: torch.Tensor,
65
+ residual: Optional[torch.Tensor] = None,
66
+ eps: float = 1e-6,
67
+ ):
68
+ orig_shape = x.shape
69
+ x = x.view(-1, x.shape[-1])
70
+ if residual is not None:
71
+ residual = residual.view(-1, residual.shape[-1])
72
+
73
+ if residual is not None:
74
+ fused_add_rmsnorm(x, residual, weight, eps)
75
+ output = (x, residual)
76
+ else:
77
+ output = rmsnorm(x, weight, eps)
78
+
79
+ if isinstance(output, tuple):
80
+ output = (output[0].view(orig_shape), output[1].view(orig_shape))
81
+ else:
82
+ output = output.view(orig_shape)
83
+ return output
84
+
85
+
86
+ def rmsnorm_vllm(
87
+ x: torch.Tensor,
88
+ weight: torch.Tensor,
89
+ residual: Optional[torch.Tensor] = None,
90
+ eps: float = 1e-6,
91
+ ):
92
+ orig_shape = x.shape
93
+ x = x.view(-1, x.shape[-1])
94
+ if residual is not None:
95
+ residual = residual.view(-1, residual.shape[-1])
96
+
97
+ if residual is not None:
98
+ vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
99
+ output = (x, residual)
100
+ else:
101
+ out = torch.empty_like(x)
102
+ vllm_ops.rms_norm(out, x, weight, eps)
103
+ output = out
104
+
105
+ if isinstance(output, tuple):
106
+ output = (output[0].view(orig_shape), output[1].view(orig_shape))
107
+ else:
108
+ output = output.view(orig_shape)
109
+ return output
110
+
111
+
112
+ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
113
+ dtype = torch.bfloat16
114
+ x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
115
+ weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
116
+ residual = torch.randn_like(x) if use_residual else None
117
+
118
+ output_naive = rmsnorm_naive(
119
+ x.clone(), weight, residual.clone() if residual is not None else None
120
+ )
121
+ output_flashinfer = rmsnorm_flashinfer(
122
+ x.clone(), weight, residual.clone() if residual is not None else None
123
+ )
124
+ output_vllm = rmsnorm_vllm(
125
+ x.clone(), weight, residual.clone() if residual is not None else None
126
+ )
127
+
128
+ if use_residual:
129
+ output_naive = output_naive[0]
130
+ output_flashinfer = output_flashinfer[0]
131
+ output_vllm = output_vllm[0]
132
+
133
+ print(f"Naive output={output_naive}")
134
+ print(f"FlashInfer output={output_flashinfer}")
135
+ print(f"VLLM output={output_vllm}")
136
+
137
+ if torch.allclose(
138
+ output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
139
+ ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
140
+ print("✅ All implementations match")
141
+ else:
142
+ print("❌ Implementations differ")
143
+
144
+
145
+ batch_size_range = [2**i for i in range(0, 7, 2)]
146
+ seq_length_range = [2**i for i in range(6, 11, 1)]
147
+ head_num_range = [32, 48]
148
+ configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
149
+
150
+
151
+ def get_benchmark(use_residual):
152
+ @triton.testing.perf_report(
153
+ triton.testing.Benchmark(
154
+ x_names=["head_num", "batch_size", "seq_len"],
155
+ x_vals=[list(_) for _ in configs],
156
+ line_arg="provider",
157
+ line_vals=["huggingface", "flashinfer", "vllm"],
158
+ line_names=["HuggingFace", "FlashInfer", "vLLM"],
159
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
160
+ ylabel="us",
161
+ plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
162
+ args={},
163
+ )
164
+ )
165
+ def benchmark(head_num, batch_size, seq_len, provider):
166
+ dtype = torch.bfloat16
167
+ hidden_size = head_num * 128 # assuming head_dim = 128
168
+
169
+ x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
170
+ weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
171
+ residual = torch.randn_like(x) if use_residual else None
172
+
173
+ quantiles = [0.5, 0.2, 0.8]
174
+
175
+ if provider == "huggingface":
176
+ ms, min_ms, max_ms = triton.testing.do_bench(
177
+ lambda: rmsnorm_naive(
178
+ x.clone(),
179
+ weight,
180
+ residual.clone() if residual is not None else None,
181
+ ),
182
+ quantiles=quantiles,
183
+ )
184
+ elif provider == "flashinfer":
185
+ ms, min_ms, max_ms = triton.testing.do_bench(
186
+ lambda: rmsnorm_flashinfer(
187
+ x.clone(),
188
+ weight,
189
+ residual.clone() if residual is not None else None,
190
+ ),
191
+ quantiles=quantiles,
192
+ )
193
+ else:
194
+ ms, min_ms, max_ms = triton.testing.do_bench(
195
+ lambda: rmsnorm_vllm(
196
+ x.clone(),
197
+ weight,
198
+ residual.clone() if residual is not None else None,
199
+ ),
200
+ quantiles=quantiles,
201
+ )
202
+
203
+ return 1000 * ms, 1000 * max_ms, 1000 * min_ms
204
+
205
+ return benchmark
206
+
207
+
208
+ if __name__ == "__main__":
209
+ import argparse
210
+
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument(
213
+ "--use_residual", action="store_true", help="Whether to use residual connection"
214
+ )
215
+ parser.add_argument(
216
+ "--save_path",
217
+ type=str,
218
+ default="./configs/benchmark_ops/rmsnorm/",
219
+ help="Path to save rmsnorm benchmark results",
220
+ )
221
+ args = parser.parse_args()
222
+
223
+ # Run correctness test
224
+ calculate_diff(
225
+ batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
226
+ )
227
+
228
+ # Get the benchmark function with proper use_residual setting
229
+ benchmark = get_benchmark(args.use_residual)
230
+ # Run performance benchmark
231
+ benchmark.run(print_data=True, save_path=args.save_path)
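For reference, the three backends benchmarked above should agree (within the `atol=1e-2, rtol=1e-2` tolerance used in `calculate_diff`) with a plain PyTorch RMSNorm. A minimal sketch of that computation is shown below; the function name is hypothetical and `eps=1e-6` is assumed to match the benchmark's default:

```python
import torch


def rmsnorm_reference(x, weight, residual=None, eps=1e-6):
    # Optional fused residual add: the residual is added before normalization
    # and the summed activations are returned alongside the normalized output,
    # mirroring the fused kernels benchmarked above (assumption).
    if residual is not None:
        x = x + residual
        residual = x
    # RMSNorm: scale by the reciprocal root-mean-square over the hidden dimension.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return (out, residual) if residual is not None else out
```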
sglang/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py ADDED
@@ -0,0 +1,345 @@
1
+ import itertools
2
+ import os
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import pytest
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+
12
+ @triton.jit
13
+ def write_req_to_token_pool_triton(
14
+ req_to_token_ptr, # [max_batch, max_context_len]
15
+ req_pool_indices,
16
+ pre_lens,
17
+ seq_lens,
18
+ extend_lens,
19
+ out_cache_loc,
20
+ req_to_token_ptr_stride: tl.constexpr,
21
+ ):
22
+ BLOCK_SIZE: tl.constexpr = 512
23
+ pid = tl.program_id(0)
24
+
25
+ req_pool_index = tl.load(req_pool_indices + pid)
26
+ pre_len = tl.load(pre_lens + pid)
27
+ seq_len = tl.load(seq_lens + pid)
28
+
29
+ # TODO: optimize this?
30
+ cumsum_start = 0
31
+ for i in range(pid):
32
+ cumsum_start += tl.load(extend_lens + i)
33
+
34
+ num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
35
+ for i in range(num_loop):
36
+ offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
37
+ mask = offset < (seq_len - pre_len)
38
+ value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
39
+ tl.store(
40
+ req_to_token_ptr
41
+ + req_pool_index * req_to_token_ptr_stride
42
+ + offset
43
+ + pre_len,
44
+ value,
45
+ mask=mask,
46
+ )
47
+
48
+
49
+ @triton.jit
50
+ def write_req_to_token_pool_triton_optimize(
51
+ req_to_token_ptr, # [max_batch, max_context_len]
52
+ req_pool_indices,
53
+ pre_lens,
54
+ seq_lens,
55
+ extend_lens,
56
+ out_cache_loc,
57
+ req_to_token_ptr_stride: tl.constexpr,
58
+ BLOCK_SIZE: tl.constexpr,
59
+ ):
60
+ pid_batch = tl.program_id(0)
61
+ pid_token = tl.program_id(1)
62
+
63
+ req_pool_index = tl.load(req_pool_indices + pid_batch)
64
+ pre_len = tl.load(pre_lens + pid_batch)
65
+ seq_len = tl.load(seq_lens + pid_batch)
66
+ extend_len = seq_len - pre_len
67
+
68
+ cumsum_start = 0
69
+ for i in range(pid_batch):
70
+ cumsum_start += tl.load(extend_lens + i)
71
+
72
+ token_start = pid_token * BLOCK_SIZE
73
+
74
+ offset = tl.arange(0, BLOCK_SIZE)
75
+ actual_offset = token_start + offset
76
+ mask = actual_offset < extend_len
77
+
78
+ src_ptr = out_cache_loc + cumsum_start + actual_offset
79
+ src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
80
+ value = tl.load(src_ptr, mask=mask)
81
+ dst_ptr = (
82
+ req_to_token_ptr
83
+ + req_pool_index * req_to_token_ptr_stride
84
+ + actual_offset
85
+ + pre_len
86
+ )
87
+ dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
88
+
89
+ tl.store(dst_ptr, value, mask=mask)
90
+
91
+
92
+ def write_req_to_token_pool_reference(
93
+ req_to_token: torch.Tensor,
94
+ req_pool_indices: torch.Tensor,
95
+ pre_lens: torch.Tensor,
96
+ seq_lens: torch.Tensor,
97
+ extend_lens: torch.Tensor,
98
+ out_cache_loc: torch.Tensor,
99
+ ) -> None:
100
+ """Reference implementation using PyTorch"""
101
+ for i in range(len(req_pool_indices)):
102
+ req_pool_idx = req_pool_indices[i].item()
103
+ pre_len = pre_lens[i].item()
104
+ seq_len = seq_lens[i].item()
105
+ extend_len = extend_lens[i].item()
106
+
107
+ cumsum_start = sum(extend_lens[:i].tolist())
108
+
109
+ # Copy values from out_cache_loc to req_to_token
110
+ req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
111
+ cumsum_start : cumsum_start + extend_len
112
+ ]
113
+
114
+
115
+ def test_write_req_to_token_pool():
116
+ max_batch = 4097
117
+ max_context_len = 6148
118
+ batch_size = 1
119
+ extend_len = 14
120
+
121
+ # Initialize input tensors
122
+ req_to_token = torch.zeros(
123
+ (max_batch, max_context_len), dtype=torch.int32, device="cuda"
124
+ )
125
+ req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
126
+ pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
127
+ seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
128
+ extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
129
+ out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
130
+
131
+ # Create copies for reference implementation
132
+ req_to_token_ref = req_to_token.clone()
133
+ req_to_token_opt = req_to_token.clone()
134
+
135
+ # Run original triton kernel
136
+ write_req_to_token_pool_triton[(batch_size,)](
137
+ req_to_token,
138
+ req_pool_indices,
139
+ pre_lens,
140
+ seq_lens,
141
+ extend_lens,
142
+ out_cache_loc,
143
+ max_context_len,
144
+ )
145
+
146
+ # Run optimized triton kernel
147
+ def grid(batch_size, extend_len):
148
+ num_token_blocks = triton.cdiv(extend_len, 512)
149
+ return (batch_size, num_token_blocks)
150
+
151
+ write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
152
+ req_to_token_opt,
153
+ req_pool_indices,
154
+ pre_lens,
155
+ seq_lens,
156
+ extend_lens,
157
+ out_cache_loc,
158
+ max_context_len,
159
+ BLOCK_SIZE=512,
160
+ )
161
+
162
+ # Run reference implementation
163
+ write_req_to_token_pool_reference(
164
+ req_to_token_ref,
165
+ req_pool_indices,
166
+ pre_lens,
167
+ seq_lens,
168
+ extend_lens,
169
+ out_cache_loc,
170
+ )
171
+
172
+ # Compare results
173
+ torch.testing.assert_close(req_to_token, req_to_token_ref)
174
+ torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
175
+
176
+ # Test case 2: batch size > 1
177
+ batch_size = 3
178
+ extend_lens_list = [14, 20, 30]
179
+ total_extend_len = sum(extend_lens_list)
180
+
181
+ req_to_token = torch.zeros(
182
+ (max_batch, max_context_len), dtype=torch.int32, device="cuda"
183
+ )
184
+ req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
185
+ pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
186
+ seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
187
+ extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
188
+ out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
189
+
190
+ req_to_token_ref = req_to_token.clone()
191
+ req_to_token_opt = req_to_token.clone()
192
+
193
+ # Run original triton kernel
194
+ write_req_to_token_pool_triton[(batch_size,)](
195
+ req_to_token,
196
+ req_pool_indices,
197
+ pre_lens,
198
+ seq_lens,
199
+ extend_lens,
200
+ out_cache_loc,
201
+ max_context_len,
202
+ )
203
+
204
+ # Run optimized triton kernel
205
+ max_extend_len = max(extend_lens_list)
206
+ write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
207
+ req_to_token_opt,
208
+ req_pool_indices,
209
+ pre_lens,
210
+ seq_lens,
211
+ extend_lens,
212
+ out_cache_loc,
213
+ max_context_len,
214
+ BLOCK_SIZE=512,
215
+ )
216
+
217
+ # Run reference implementation
218
+ write_req_to_token_pool_reference(
219
+ req_to_token_ref,
220
+ req_pool_indices,
221
+ pre_lens,
222
+ seq_lens,
223
+ extend_lens,
224
+ out_cache_loc,
225
+ )
226
+
227
+ # Compare results
228
+ torch.testing.assert_close(req_to_token, req_to_token_ref)
229
+ torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
230
+
231
+
232
+ def get_benchmark():
233
+ batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
234
+ extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
235
+ configs = list(itertools.product(batch_sizes, extend_lens))
236
+
237
+ @triton.testing.perf_report(
238
+ triton.testing.Benchmark(
239
+ x_names=["batch_size", "extend_len"],
240
+ x_vals=configs,
241
+ line_arg="provider",
242
+ line_vals=["reference", "triton", "triton_optimize"],
243
+ line_names=["PyTorch", "Triton", "Triton Optimized"],
244
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
245
+ ylabel="us",
246
+ plot_name="write-req-to-token-pool-performance",
247
+ args={},
248
+ )
249
+ )
250
+ def benchmark(batch_size, extend_len, provider):
251
+ max_batch = 256
252
+ max_context_len = 16384
253
+
254
+ extend_lens_list = [extend_len] * batch_size
255
+ total_extend_len = sum(extend_lens_list)
256
+
257
+ req_to_token = torch.zeros(
258
+ (max_batch, max_context_len), dtype=torch.int32, device="cuda"
259
+ )
260
+ req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
261
+ pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
262
+ seq_lens = pre_lens + extend_len
263
+ extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
264
+ out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
265
+
266
+ quantiles = [0.5, 0.2, 0.8]
267
+
268
+ if provider == "reference":
269
+ ms, min_ms, max_ms = triton.testing.do_bench(
270
+ lambda: write_req_to_token_pool_reference(
271
+ req_to_token.clone(),
272
+ req_pool_indices,
273
+ pre_lens,
274
+ seq_lens,
275
+ extend_lens,
276
+ out_cache_loc,
277
+ ),
278
+ quantiles=quantiles,
279
+ )
280
+ elif provider == "triton":
281
+ ms, min_ms, max_ms = triton.testing.do_bench(
282
+ lambda: write_req_to_token_pool_triton[(batch_size,)](
283
+ req_to_token.clone(),
284
+ req_pool_indices,
285
+ pre_lens,
286
+ seq_lens,
287
+ extend_lens,
288
+ out_cache_loc,
289
+ max_context_len,
290
+ ),
291
+ quantiles=quantiles,
292
+ )
293
+ else:
294
+
295
+ def run_optimized():
296
+ block_size = 128 if extend_len <= 1024 else 512
297
+ grid_config = (batch_size, triton.cdiv(extend_len, block_size))
298
+ write_req_to_token_pool_triton_optimize[grid_config](
299
+ req_to_token.clone(),
300
+ req_pool_indices,
301
+ pre_lens,
302
+ seq_lens,
303
+ extend_lens,
304
+ out_cache_loc,
305
+ max_context_len,
306
+ BLOCK_SIZE=block_size,
307
+ )
308
+
309
+ ms, min_ms, max_ms = triton.testing.do_bench(
310
+ run_optimized, quantiles=quantiles
311
+ )
312
+
313
+ return 1000 * ms, 1000 * max_ms, 1000 * min_ms
314
+
315
+ return benchmark
316
+
317
+
318
+ def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
319
+ """Run benchmark and save results"""
320
+
321
+ # Ensure save path exists
322
+ os.makedirs(save_path, exist_ok=True)
323
+
324
+ # Run correctness test
325
+ test_write_req_to_token_pool()
326
+ print("Correctness test passed!")
327
+
328
+ # Run performance test
329
+ benchmark = get_benchmark()
330
+ benchmark.run(print_data=True, save_path=save_path)
331
+
332
+
333
+ if __name__ == "__main__":
334
+ import argparse
335
+
336
+ parser = argparse.ArgumentParser()
337
+ parser.add_argument(
338
+ "--save_path",
339
+ type=str,
340
+ default="./configs/benchmark_ops/write_req_to_token_pool/",
341
+ help="Path to save benchmark results",
342
+ )
343
+ args = parser.parse_args()
344
+
345
+ run_benchmark(args.save_path)
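The kernels above all implement the same operation: for each request in the batch, copy its newly allocated slots from `out_cache_loc` into that request's row of the `req_to_token` table, starting at `pre_len`. The script can be run directly; it first executes the correctness test and then the benchmark sweep. A typical invocation (the save path shown is simply the argparse default):

```
python3 benchmark_write_req_to_token_pool_triton.py --save_path ./configs/benchmark_ops/write_req_to_token_pool/
```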
sglang/benchmark/line_retrieval/README.md ADDED
@@ -0,0 +1,37 @@
1
+ ## Download data
2
+
3
+ ```
4
+ wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
5
+ python3 gen_data.py --number 1000
6
+ ```
7
+
8
+ ## Run benchmark
9
+
10
+ ### Benchmark sglang
11
+ ```
12
+ python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
13
+ ```
14
+
15
+ ```
16
+ python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
17
+ ```
18
+
19
+
20
+ ### Reference results
21
+
22
+ ```
23
+ # original
24
+ Accuracy: 0.940, latency: 332.83 s
25
+
26
+ # parallel encoding (no_adjust, offset = 1000)
27
+ Accuracy: 0.760, latency: 238.46 s
28
+
29
+ # parallel encoding (no_adjust, offset = 3000)
30
+ Accuracy: 0.760, latency: 238.46 s
31
+
32
+ # parallel encoding (no_adjust, offset = 0)
33
+ Accuracy: 0.520, latency: 238.46 s
34
+
35
+ # parallel encoding (adjust_cache)
36
+ Accuracy: 0.460, latency: 257.66 s
37
+ ```
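For reference, each generated line pairs a random two-word index with a six-digit REGISTER_CONTENT value, and redirected lines (only produced when `--redirect-ratio` is above its default of 0.0) point at another index instead; see `gen_data.py` below. The records here are illustrative only, with made-up indices and values:

```
Line ocean-lantern: The REGISTER_CONTENT is 402193.
Line copper-valley: The REGISTER_CONTENT is the same as Line ocean-lantern.
```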
sglang/benchmark/line_retrieval/bench_sglang.py ADDED
@@ -0,0 +1,149 @@
1
+ import argparse
2
+ import json
3
+ import re
4
+ import time
5
+
6
+ import numpy as np
7
+
8
+ import sglang as sgl
9
+ from sglang.test.test_utils import (
10
+ add_common_sglang_args_and_parse,
11
+ select_sglang_backend,
12
+ )
13
+ from sglang.utils import dump_state_text
14
+
15
+
16
+ @sgl.function
17
+ def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
18
+ s += prefix + "\n"
19
+
20
+ contexts = [body_0, body_1, body_2, body_3]
21
+ position_ids_offset = [i * 1000 for i in range(len(contexts))]
22
+ forks = s.fork(len(contexts), position_ids_offset)
23
+ forks += lambda i: contexts[i] + "\n"
24
+ forks.join(mode="concate_and_append")
25
+
26
+ s += "\n" + suffix
27
+ s += sgl.gen("answer", max_tokens=16)
28
+
29
+
30
+ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
31
+ arguments = []
32
+ labels = []
33
+ sum_src_indices = []
34
+ sum_dst_indices = []
35
+
36
+ for i in range(len(src_indices)):
37
+ for j in range(len(dst_percents)):
38
+ src_index = src_indices[i]
39
+ dst_percent = dst_percents[j]
40
+
41
+ query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
42
+ query_indices = [
43
+ q
44
+ for q in query_indices
45
+ if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
46
+ ]
47
+ dst_index = query_indices[
48
+ min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
49
+ ]
50
+ label = line_obj["values"][dst_index]
51
+
52
+ body = line_obj["lines"][: src_index + 1]
53
+ suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
54
+ body_part_len = len(body) // 4
55
+
56
+ arguments.append(
57
+ {
58
+ "prefix": line_obj["prefix"],
59
+ "body_0": "\n".join(body[:body_part_len]),
60
+ "body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
61
+ "body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
62
+ "body_3": "\n".join(body[3 * body_part_len :]),
63
+ "suffix": suffix,
64
+ }
65
+ )
66
+ labels.append(label)
67
+ sum_src_indices.append(src_index)
68
+ sum_dst_indices.append(dst_index)
69
+
70
+ # Select backend
71
+ backend = select_sglang_backend(args)
72
+
73
+ tic = time.time()
74
+ states = line_retrieval.run_batch(
75
+ arguments,
76
+ temperature=0,
77
+ backend=backend,
78
+ num_threads=args.parallel,
79
+ progress_bar=True,
80
+ )
81
+ latency = time.time() - tic
82
+
83
+ corrects = []
84
+ for i in range(len(arguments)):
85
+ output = states[i]["answer"]
86
+ prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
87
+ label = labels[i]
88
+
89
+ # Try all numbers
90
+ findall = re.findall(r"\d+", output)
91
+ if not findall:
92
+ response_number = output
93
+ else:
94
+ for response_number in findall:
95
+ if response_number == label:
96
+ break
97
+
98
+ correct = response_number == label
99
+ corrects.append(correct)
100
+
101
+ # Log results
102
+ summary = (
103
+ f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
104
+ f"Prompt len: {prompt_len}, "
105
+ f"Correct: {correct}, "
106
+ f"Label: {label}, Predicted: {response_number}, "
107
+ )
108
+ print(summary)
109
+
110
+ accuracy = np.mean(corrects)
111
+ print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
112
+
113
+ # Write results
114
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
115
+
116
+ with open(args.result_file, "a") as fout:
117
+ value = {
118
+ "task": "line_retrieval",
119
+ "backend": args.backend,
120
+ "num_gpus": 1,
121
+ "latency": round(latency, 3),
122
+ "num_requests": len(arguments),
123
+ "other": {
124
+ "num_questions": len(arguments),
125
+ "parallel": args.parallel,
126
+ },
127
+ }
128
+ fout.write(json.dumps(value) + "\n")
129
+
130
+
131
+ def main(args):
132
+ line_obj = json.load(open(args.data_path, "r"))
133
+
134
+ num_hoops = args.num_hoops
135
+ for src_index in args.src_index:
136
+ src_indices = [src_index]
137
+ num_queries = args.num_queries_per_src
138
+ dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
139
+ eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ parser = argparse.ArgumentParser()
144
+ parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
145
+ parser.add_argument("--src-index", type=int, nargs="+", default=[100])
146
+ parser.add_argument("--num-queries-per-src", type=int, default=10)
147
+ parser.add_argument("--num-hoops", type=int, default=1)
148
+ args = add_common_sglang_args_and_parse(parser)
149
+ main(args)
sglang/benchmark/line_retrieval/gen_data.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ Generate line data for line retrieval task.
3
+
4
+ Usage:
5
+ python3 gen_data.py --number 1000
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ from collections import defaultdict
11
+
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+
16
+ def generate_lines(random_words, num_lines, redirect_ratio):
17
+ prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
18
+ suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
19
+
20
+ # Raw lines
21
+ visited_indices = set([None])
22
+ visited_values = set([None])
23
+
24
+ lines = []
25
+ redirects = []
26
+ indices = []
27
+ values = []
28
+ for i in tqdm(range(num_lines)):
29
+ line_index = None
30
+ while line_index in visited_indices:
31
+ line_index = "-".join(np.random.choice(random_words, size=(2,)))
32
+ visited_indices.add(line_index)
33
+
34
+ line_value = np.random.randint(low=0, high=999999)
35
+ line_value = f"{line_value:06}"
36
+
37
+ line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
38
+ lines.append(line)
39
+ redirects.append(None)
40
+ indices.append(line_index)
41
+ values.append(line_value)
42
+
43
+ # Add redirect
44
+ if redirect_ratio > 0:
45
+ num_redirect_lines = int(len(lines) * redirect_ratio)
46
+ redirect_indices = np.random.choice(
47
+ np.arange(len(lines)), size=(num_redirect_lines,), replace=False
48
+ )
49
+ for i in redirect_indices:
50
+ target_idx = np.random.choice(min(i * 2 + 100, num_lines))
51
+ lines[i] = (
52
+ f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
53
+ )
54
+ redirects[i] = target_idx
55
+
56
+ # Build links and find sources
57
+ links = [[] for _ in range(num_lines)]
58
+ contains_ring = set()
59
+ for i in range(num_lines):
60
+ if redirects[i] is None:
61
+ continue
62
+
63
+ tmp_link = []
64
+ cur = i
65
+ visited = set()
66
+ while redirects[cur] is not None:
67
+ visited.add(cur)
68
+ tmp_link.append(redirects[cur])
69
+ cur = redirects[cur]
70
+
71
+ if cur in visited:
72
+ contains_ring.add(i)
73
+ tmp_link = None
74
+ break
75
+ values[i] = values[cur]
76
+ links[i] = tmp_link
77
+
78
+ # Group by num_links
79
+ group_by_num_hoops = defaultdict(list)
80
+ for i in range(num_lines):
81
+ if i in contains_ring:
82
+ continue
83
+ group_by_num_hoops[len(links[i]) + 1].append(i)
84
+
85
+ keys = sorted(list(group_by_num_hoops.keys()))
86
+ for num_links in keys:
87
+ print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
88
+
89
+ # Append few-shot examples
90
+ hoop1_candidates = list(group_by_num_hoops[1])
91
+ hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
92
+ hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
93
+ hoop2_candidates = list(group_by_num_hoops[2])
94
+ hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
95
+ hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
96
+
97
+ i = hoop1_candidates[5]
98
+ suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
99
+ if len(hoop2_candidates):
100
+ i = hoop2_candidates[0]
101
+ suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
102
+ i = hoop2_candidates[1]
103
+ suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
104
+ else:
105
+ i = hoop1_candidates[1]
106
+ suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
107
+ i = hoop1_candidates[10]
108
+ suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
109
+
110
+ obj = {
111
+ "prefix": prefix,
112
+ "suffix": suffix,
113
+ "lines": lines,
114
+ "indices": indices,
115
+ "values": values,
116
+ "links": links,
117
+ "group_by_num_hoops": group_by_num_hoops,
118
+ "contains_ring": sorted(list(contains_ring)),
119
+ }
120
+ return obj
121
+
122
+
123
+ if __name__ == "__main__":
124
+ parser = argparse.ArgumentParser()
125
+ parser.add_argument("--number", type=int)
126
+ parser.add_argument("--redirect-ratio", type=float, default=0.0)
127
+ args = parser.parse_args()
128
+
129
+ num_lines = args.number
130
+
131
+ random_words_filename = "random_words.json"
132
+ random_words = json.load(open(random_words_filename, "r"))
133
+
134
+ np.random.seed(42)
135
+ obj = generate_lines(random_words, num_lines, args.redirect_ratio)
136
+
137
+ fout = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
138
+ with open(fout, "w") as fout:
139
+ json.dump(obj, fout, indent=2)
sglang/benchmark/llava_bench/README.md ADDED
@@ -0,0 +1,61 @@
1
+ ## Download benchmark images
2
+
3
+ ```
4
+ python3 download_images.py
5
+ ```
6
+
7
+ Image benchmark source: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild
8
+
9
+ ### Other dependencies
10
+ ```
11
+ pip3 install "sglang[all]"
12
+ pip3 install "torch>=2.1.2" "transformers>=4.36" pillow
13
+ ```
14
+
15
+ ## Run benchmark
16
+
17
+ ### Benchmark sglang
18
+ Launch a server
19
+ ```
20
+ python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
21
+ ```
22
+
23
+ Run benchmark
24
+ ```
25
+ # Run with local models
26
+ python3 bench_sglang.py --num-questions 60
27
+
28
+ # Run with OpenAI models
29
+ python3 bench_sglang.py --num-questions 60 --backend gpt-4-vision-preview
30
+ ```
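Answers are written to `answers.jsonl` (configurable via `--answer-file`), one JSON object per question. A record looks roughly like the following; the `text` and `model_id` values here are placeholders for illustration:

```
{"question_id": 0, "prompt": "What is the name of this famous sight in the photo?", "text": "...", "model_id": "liuhaotian/llava-v1.6-vicuna-7b", "answer_id": 0, "metadata": {}}
```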
31
+
32
+ ### Benchmark the original LLaVA code
33
+ ```
34
+ git clone git@github.com:haotian-liu/LLaVA.git
35
+ cd LLaVA
36
+ git reset --hard 9a26bd1435b4ac42c282757f2c16d34226575e96
37
+ pip3 install -e .
38
+
39
+ cd ~/sglang/benchmark/llava_bench
40
+ CUDA_VISIBLE_DEVICES=0 bash bench_hf_llava_bench.sh
41
+ ```
42
+
43
+
44
+ ### Benchmark llama.cpp
45
+
46
+ ```
47
+ # Install
48
+ CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
49
+ pip install sse_starlette starlette_context pydantic_settings
50
+
51
+ # Download weights
52
+ mkdir -p ~/model_weights/llava-v1.5-7b/
53
+ wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/ggml-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf
54
+ wget https://huggingface.co/mys/ggml_llava-v1.5-7b/resolve/main/mmproj-model-f16.gguf -O ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf
55
+ ```
56
+
57
+ ```
58
+ python3 -m llama_cpp.server --model ~/model_weights/llava-v1.5-7b/ggml-model-f16.gguf --clip_model_path ~/model_weights/llava-v1.5-7b/mmproj-model-f16.gguf --chat_format llava-1-5 --port 23000
59
+
60
+ OPENAI_BASE_URL=http://localhost:23000/v1 python3 bench_sglang.py --backend gpt-4-vision-preview --num-q 1
61
+ ```
sglang/benchmark/llava_bench/bench_hf_llava_bench.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/bash
2
+
3
+ python -m llava.eval.model_vqa \
4
+ --model-path liuhaotian/llava-v1.5-7b \
5
+ --question-file ./questions.jsonl \
6
+ --image-folder ./images \
7
+ --answers-file ./answers_hf.jsonl \
8
+ --temperature 0 \
9
+ --conv-mode vicuna_v1
sglang/benchmark/llava_bench/bench_hf_mme.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/bin/bash
2
+
3
+ python -m llava.eval.model_vqa_loader \
4
+ --model-path liuhaotian/llava-v1.5-7b \
5
+ --question-file ./mme_pack/llava_mme_bench_replace.jsonl \
6
+ --image-folder ./mme_pack/MME_Benchmark_release_version \
7
+ --answers-file ./answers_hf_mme.jsonl \
8
+ --temperature 0 \
9
+ --conv-mode vicuna_v1
sglang/benchmark/llava_bench/bench_sglang.py ADDED
@@ -0,0 +1,96 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import tqdm
7
+
8
+ import sglang as sgl
9
+ from sglang.test.test_utils import (
10
+ add_common_sglang_args_and_parse,
11
+ select_sglang_backend,
12
+ )
13
+ from sglang.utils import dump_state_text, read_jsonl
14
+
15
+
16
+ @sgl.function
17
+ def image_qa(s, image_file, question):
18
+ s += sgl.user(sgl.image(image_file) + question)
19
+ s += sgl.assistant(sgl.gen("answer", max_tokens=args.max_tokens))
20
+
21
+
22
+ def main(args):
23
+ lines = list(read_jsonl(args.question_file))[: args.num_questions]
24
+ arguments = [
25
+ {
26
+ "image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
27
+ "question": l["text"],
28
+ }
29
+ for l in lines
30
+ ]
31
+ # arguments = [
32
+ # {"image_file":
33
+ # Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
34
+ # "question": l["text"]} for l in lines
35
+ # ]
36
+
37
+ states = [None] * len(lines)
38
+
39
+ # Select backend
40
+ backend = select_sglang_backend(args)
41
+ sgl.set_default_backend(backend)
42
+
43
+ # Run requests
44
+ tic = time.time()
45
+ if args.parallel == 1:
46
+ for i in tqdm.tqdm(range(len(lines))):
47
+ image_file = arguments[i]["image_file"]
48
+ question = arguments[i]["question"]
49
+ ret = image_qa.run(image_file=image_file, question=question, temperature=0)
50
+ states[i] = ret
51
+ else:
52
+ states = image_qa.run_batch(
53
+ arguments, temperature=0, num_threads=args.parallel, progress_bar=True
54
+ )
55
+ latency = time.time() - tic
56
+
57
+ print(f"Latency: {latency:.3f}")
58
+
59
+ # Write results
60
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
61
+
62
+ print(f"Write output to {args.answer_file}")
63
+ with open(args.answer_file, "w") as fout:
64
+ for i in range(len(lines)):
65
+ value = {
66
+ "question_id": lines[i]["question_id"],
67
+ "prompt": lines[i]["text"],
68
+ "text": states[i]["answer"].strip(),
69
+ "model_id": backend.model_info["model_path"],
70
+ "answer_id": i,
71
+ "metadata": {},
72
+ }
73
+ fout.write(json.dumps(value) + "\n")
74
+
75
+ with open(args.result_file, "a") as fout:
76
+ value = {
77
+ "task": "llava_bench",
78
+ "backend": args.backend,
79
+ "num_gpus": 1,
80
+ "latency": round(latency, 3),
81
+ "num_requests": len(lines),
82
+ "parallel": args.parallel,
83
+ }
84
+ fout.write(json.dumps(value) + "\n")
85
+
86
+
87
+ if __name__ == "__main__":
88
+ parser = argparse.ArgumentParser()
89
+ parser.add_argument("--question-file", type=str, default="questions.jsonl")
90
+ parser.add_argument("--answer-file", type=str, default="answers.jsonl")
91
+ parser.add_argument("--image-folder", type=str, default="./images")
92
+ parser.add_argument("--temperature", type=float, default=0.0)
93
+ parser.add_argument("--num-questions", type=int, default=None)
94
+ parser.add_argument("--max-tokens", type=int, default=768)
95
+ args = add_common_sglang_args_and_parse(parser)
96
+ main(args)
sglang/benchmark/llava_bench/bench_sglang_mme.sh ADDED
@@ -0,0 +1,2 @@
1
+ MME_FOLDER=./mme_pack
2
+ python3 bench_sglang.py --num-questions 5000 --question-file $MME_FOLDER/llava_mme_bench_replace.jsonl --answer-file answer_mme.jsonl --image-folder $MME_FOLDER/MME_Benchmark_release_version --max-tokens 4
sglang/benchmark/llava_bench/download_images.py ADDED
@@ -0,0 +1,20 @@
1
+ import os
2
+
3
+ # Create the 'images' directory if it doesn't exist
4
+ if not os.path.exists("images"):
5
+ os.makedirs("images")
6
+
7
+ # Base URL
8
+ base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
9
+
10
+ # Loop through image numbers
11
+ for i in range(1, 25):
12
+ # Format the image number with leading zeros
13
+ image_number = str(i).zfill(3)
14
+ image_url = base_url + image_number + ".jpg"
15
+ image_path = "images/" + image_number + ".jpg"
16
+
17
+ # Download the image using wget
18
+ os.system(f"wget -O {image_path} {image_url}")
19
+
20
+ print("Download complete.")
sglang/benchmark/llava_bench/questions.jsonl ADDED
@@ -0,0 +1,60 @@
1
+ {"image": "001.jpg", "text": "What is the name of this famous sight in the photo?", "category": "conv", "question_id": 0}
2
+ {"image": "001.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 1}
3
+ {"image": "001.jpg", "text": "What are the possible reasons of the formation of this sight?", "category": "complex", "question_id": 2}
4
+ {"image": "001.jpg", "text": "Compose an engaging travel blog post about a recent trip to this place, highlighting cultural experiences and must-see attractions, including both the attraction seen in the photo and other must-see attractions as well.", "category": "complex", "question_id": 3}
5
+ {"image": "002.jpg", "text": "What type of fruit is this?", "category": "conv", "question_id": 4}
6
+ {"image": "002.jpg", "text": "How many uncut fruits are in the image?", "category": "conv", "question_id": 5}
7
+ {"image": "002.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 6}
8
+ {"image": "002.jpg", "text": "Imagine the fragrance of the fruits in the image. How would you describe this to someone who has never had this fruit before?", "category": "complex", "question_id": 7}
9
+ {"image": "003.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 8}
10
+ {"image": "003.jpg", "text": "What might be the intended effect of this painting?", "category": "complex", "question_id": 9}
11
+ {"image": "003.jpg", "text": "Discuss how this creative twist on a classic work of art might be interpreted differently by various audiences.", "category": "complex", "question_id": 10}
12
+ {"image": "004.jpg", "text": "What is the name of the man in the photo?", "category": "conv", "question_id": 11}
13
+ {"image": "004.jpg", "text": "Which iconic movie scene is being parodied in the meme?", "category": "conv", "question_id": 12}
14
+ {"image": "004.jpg", "text": "How does this meme reflect or comment on Elon Musk's public image, personality, or actions?", "category": "complex", "question_id": 13}
15
+ {"image": "005.jpg", "text": "Please explain the meme in detail.", "category": "detail", "question_id": 14}
16
+ {"image": "005.jpg", "text": "In what other ways might someone express the same sentiment that this meme is expressing?", "category": "complex", "question_id": 15}
17
+ {"image": "006.jpg", "text": "Do you know who paint this?", "category": "conv", "question_id": 16}
18
+ {"image": "006.jpg", "text": "Describe this painting in detail.", "category": "detail", "question_id": 17}
19
+ {"image": "006.jpg", "text": "Discuss the historical impact and the significance of this painting in the art world.", "category": "complex", "question_id": 18}
20
+ {"image": "007.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 19}
21
+ {"image": "007.jpg", "text": "What's the best weather, season, time of the day of visiting this place? Is the time when this photo was taken a good time to visit this place?", "category": "complex", "question_id": 20}
22
+ {"image": "008.jpg", "text": "What is the name of the character in the image?", "category": "conv", "question_id": 21}
23
+ {"image": "008.jpg", "text": "What's the personality of this character? Explain what elements or aspects of the character's design may have contributed to its popularity.", "category": "complex", "question_id": 22}
24
+ {"image": "009.jpg", "text": "What are the things I should be cautious about when I visit here?", "category": "complex", "question_id": 23}
25
+ {"image": "009.jpg", "text": "If you were a photographer looking to capture this location's essence, what time of day and weather conditions would you choose? Describe the reasons behind your choice.", "category": "complex", "question_id": 24}
26
+ {"image": "010.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 25}
27
+ {"image": "010.jpg", "text": "What is unusual about this image?", "category": "complex", "question_id": 26}
28
+ {"image": "011.jpg", "text": "What fruit is in the left part of the fridge?", "category": "conv", "question_id": 27}
29
+ {"image": "011.jpg", "text": "What is the brand of the yogurt flavored with blueberry?", "category": "conv", "question_id": 28}
30
+ {"image": "011.jpg", "text": "Is there any strawberry-flavored yogurt in the fridge?", "category": "conv", "question_id": 29}
31
+ {"image": "011.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 30}
32
+ {"image": "011.jpg", "text": "What are the meals that I can cook with these?", "category": "complex", "question_id": 31}
33
+ {"image": "012.jpg", "text": "How many coffee mugs are in the set?", "category": "conv", "question_id": 32}
34
+ {"image": "012.jpg", "text": "Write an attractive product description for this.", "category": "complex", "question_id": 33}
35
+ {"image": "013.jpg", "text": "Show the detailed recipe for this dish.", "category": "complex", "question_id": 34}
36
+ {"image": "014.jpg", "text": "Can you explain this meme in detail?", "category": "complex", "question_id": 35}
37
+ {"image": "015.jpg", "text": "What are the two machine learning concepts mentioned in the meme?", "category": "conv", "question_id": 36}
38
+ {"image": "015.jpg", "text": "Give a detailed description of this meme.", "category": "detail", "question_id": 37}
39
+ {"image": "015.jpg", "text": "Can you explain why this is funny. Think about it step-by-step.", "category": "complex", "question_id": 38}
40
+ {"image": "016.jpg", "text": "Give a detailed description of this image. Describe it panel by panel.", "category": "detail", "question_id": 39}
41
+ {"image": "016.jpg", "text": "What is funny about this image? Describe it panel by panel.", "category": "complex", "question_id": 40}
42
+ {"image": "017.jpg", "text": "What material appears to make up the creature?", "category": "conv", "question_id": 41}
43
+ {"image": "017.jpg", "text": "This is the logo of LLaVA, Large Language and Vision Assistant, based on the LLaMA architecture. Please explain this logo in detail, and how do you think of its design.", "category": "complex", "question_id": 42}
44
+ {"image": "018.jpg", "text": "What are the animals in the painting and what are they doing?", "category": "conv", "question_id": 43}
45
+ {"image": "018.jpg", "text": "Write a fairy tale based on this painting.", "category": "complex", "question_id": 44}
46
+ {"image": "019.jpg", "text": "Describe this sketch in detail.", "category": "detail", "question_id": 45}
47
+ {"image": "019.jpg", "text": "Write brief HTML/JS to turn this mock-up into a colorful website, where the jokes are replaced by two real jokes.", "category": "complex", "question_id": 46}
48
+ {"image": "020.jpg", "text": "Describe this sketch in detail.", "category": "detail", "question_id": 47}
49
+ {"image": "020.jpg", "text": "Write brief HTML/JS to turn this mock-up into a colorful and interactive website, where the joke is replaced by a real joke.", "category": "complex", "question_id": 48}
50
+ {"image": "021.jpg", "text": "What's the ending of this movie?", "category": "conv", "question_id": 49}
51
+ {"image": "021.jpg", "text": "What is the significance of this scene in the context of the movie?", "category": "complex", "question_id": 50}
52
+ {"image": "022.jpg", "text": "What's the name of the restaurant serving these dishes?", "category": "conv", "question_id": 51}
53
+ {"image": "022.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 52}
54
+ {"image": "022.jpg", "text": "If someone were to recommend a new flavor or topping to the dish, describe the reason for this change and how it might alter the overall taste.", "category": "complex", "question_id": 53}
55
+ {"image": "023.jpg", "text": "What brand is featured in this advertisement?", "category": "conv", "question_id": 54}
56
+ {"image": "023.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 55}
57
+ {"image": "023.jpg", "text": "Show me a detailed recipe for cooking this at home.", "category": "complex", "question_id": 56}
58
+ {"image": "024.jpg", "text": "Describe this photo in detail.", "category": "detail", "question_id": 57}
59
+ {"image": "024.jpg", "text": "What is the problem this city might be facing? What are some possible solutions?", "category": "complex", "question_id": 58}
60
+ {"image": "024.jpg", "text": "Explain all the cues that indicate the current traffic conditions.", "category": "complex", "question_id": 59}
sglang/benchmark/llm_judge/README.md ADDED
@@ -0,0 +1,33 @@
1
+ ## Run benchmark
2
+
3
+ ### Benchmark sglang
4
+ ```
5
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
6
+ ```
7
+
8
+ ```
9
+ python3 bench_sglang.py --num-questions 25 --parallel 8
10
+ python3 bench_sglang.py --num-questions 16 --parallel 1
11
+ ```
12
+
13
+
14
+ ### Benchmark vllm
15
+ ```
16
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
17
+ ```
18
+
19
+ ```
20
+ python3 bench_other.py --backend vllm --num-questions 25
21
+ ```
22
+
23
+
24
+ ### Benchmark guidance
25
+ ```
26
+ python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
27
+ ```
28
+
29
+ ### Benchmark lmql
30
+
31
+ ```
32
+ python3 bench_other.py --backend lmql --num-questions 25 --parallel 1
33
+ ```
sglang/benchmark/llm_judge/articles.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
sglang/benchmark/llm_judge/bench_other.py ADDED
@@ -0,0 +1,151 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from functools import partial
6
+
7
+ from tqdm import tqdm
8
+
9
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+ system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
13
+
14
+ dimension_prompts = [
15
+ "Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
16
+ "Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
17
+ "Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
18
+ "Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
19
+ "Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
20
+ "Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
21
+ ]
22
+
23
+
24
+ def multi_dimension_judge(article, generate):
25
+ s = system_prompt
26
+ s += "\n```\n" + article + "\n```\n\n"
27
+
28
+ judges = []
29
+ for i in range(len(dimension_prompts)):
30
+ comp = generate(
31
+ s
32
+ + "USER: Please judge the quality based on the following metric. "
33
+ + dimension_prompts[i]
34
+ + " Please provide a single-paragraph judgement. "
35
+ + "Focus on the provided metric and do not say other things. "
36
+ 'End your judgement paragraph with the word "END"\nJUDGE:',
37
+ max_tokens=256,
38
+ stop="END",
39
+ )
40
+ judges.append(comp)
41
+
42
+ s += "I will judge the quality based on the following metrics.\n"
43
+ for i in range(len(dimension_prompts)):
44
+ s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
45
+
46
+ s += "In summary, on a scale of 1 to 10, I would give the article a score of"
47
+ s += generate(s, max_tokens=2, stop=None)
48
+
49
+ return s
50
+
51
+
52
+ async def multi_dimension_judge_async(article, generate):
53
+ s = system_prompt
54
+ s += "\n```\n" + article + "\n```\n\n"
55
+
56
+ judges = []
57
+ for i in range(len(dimension_prompts)):
58
+ comp = await generate(
59
+ s
60
+ + "USER: Please judge the quality based on the following metric. "
61
+ + dimension_prompts[i]
62
+ + " Please provide a single-paragraph judgement. "
63
+ + "Focus on the provided metric and do not say other things. "
64
+ 'End your judgement paragraph with the word "END"\nJUDGE:',
65
+ max_tokens=256,
66
+ stop="END",
67
+ )
68
+ judges.append(comp)
69
+
70
+ s += "I will judge the quality based on the following metrics.\n"
71
+ for i in range(len(dimension_prompts)):
72
+ s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
73
+
74
+ s += "In summary, on a scale of 1 to 10, I would give the article a score of"
75
+ s += await generate(s, max_tokens=2, stop=None)
76
+
77
+ return s
78
+
79
+
80
+ def main(args):
81
+ lines = read_jsonl(args.data_path)[: args.num_questions]
82
+ states = [None] * len(lines)
83
+
84
+ # Select backend
85
+ call_generate = partial(get_call_generate(args), temperature=0)
86
+
87
+ # Run requests
88
+ tic = time.time()
89
+
90
+ if args.backend != "lmql":
91
+
92
+ def get_one_answer(i):
93
+ states[i] = multi_dimension_judge(lines[i], call_generate)
94
+
95
+ if args.parallel == 1:
96
+ for i in tqdm(range(len(lines))):
97
+ get_one_answer(i)
98
+ else:
99
+ with ThreadPoolExecutor(args.parallel) as executor:
100
+ list(
101
+ tqdm(
102
+ executor.map(get_one_answer, list(range(len(lines)))),
103
+ total=len(lines),
104
+ )
105
+ )
106
+
107
+ else:
108
+ import asyncio
109
+
110
+ async def get_one_answer_async(i):
111
+ states[i] = await multi_dimension_judge_async(lines[i], call_generate)
112
+
113
+ batches = []
114
+ for i in range(0, len(lines), args.parallel):
115
+ batches.append(list(range(i, min(i + args.parallel, len(lines)))))
116
+
117
+ loop = asyncio.get_event_loop()
118
+ for bt in tqdm(batches):
119
+ loop.run_until_complete(
120
+ asyncio.gather(*[get_one_answer_async(i) for i in bt])
121
+ )
122
+
123
+ latency = time.time() - tic
124
+
125
+ # Compute accuracy
126
+ print(f"Latency: {latency:.3f}")
127
+
128
+ # Write results
129
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
130
+
131
+ with open(args.result_file, "a") as fout:
132
+ value = {
133
+ "task": "llm_judge",
134
+ "backend": args.backend,
135
+ "num_gpus": 1,
136
+ "latency": round(latency, 3),
137
+ "num_requests": args.num_questions,
138
+ "other": {
139
+ "num_questions": args.num_questions,
140
+ "parallel": args.parallel,
141
+ },
142
+ }
143
+ fout.write(json.dumps(value) + "\n")
144
+
145
+
146
+ if __name__ == "__main__":
147
+ parser = argparse.ArgumentParser()
148
+ parser.add_argument("--data-path", type=str, default="articles.jsonl")
149
+ parser.add_argument("--num-questions", type=int, default=20)
150
+ args = add_common_other_args_and_parse(parser)
151
+ main(args)
sglang/benchmark/llm_judge/bench_sglang.py ADDED
@@ -0,0 +1,97 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ import sglang as sgl
6
+ from sglang.test.test_utils import (
7
+ add_common_sglang_args_and_parse,
8
+ select_sglang_backend,
9
+ )
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+ system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
13
+
14
+ dimension_prompts = [
15
+ "Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
16
+ "Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
17
+ "Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
18
+ "Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
19
+ "Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
20
+ "Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
21
+ ]
22
+
23
+
24
+ @sgl.function
25
+ def multi_dimension_judge(s, article):
26
+ s += system_prompt
27
+ s += "\n```\n" + article + "\n```\n\n"
28
+
29
+ forks = s.fork(len(dimension_prompts))
30
+ for i in range(len(dimension_prompts)):
31
+ forks[i] += (
32
+ "USER: Please judge the quality based on the following metric. "
33
+ + dimension_prompts[i]
34
+ + " Please provide a single-paragraph judgement. "
35
+ + "Focus on the provided metric and do not say other things. "
36
+ 'End your judgement paragraph with the word "END"\nJUDGE:'
37
+ )
38
+ forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
39
+ forks.join()
40
+
41
+ s += "I will judge the quality based on the following metrics.\n"
42
+ for i in range(len(dimension_prompts)):
43
+ s += (
44
+ dimension_prompts[i].split(":")[0]
45
+ + ": "
46
+ + forks[i]["judgement"].strip()
47
+ + "\n"
48
+ )
49
+
50
+ s += "In summary, on a scale of 1 to 10, I would give the article a score of"
51
+ s += sgl.gen("score", max_tokens=2)
52
+
53
+
54
+ def main(args):
55
+ lines = read_jsonl(args.data_path)[: args.num_questions]
56
+ arguments = [{"article": l} for l in lines]
57
+
58
+ # Select backend
59
+ backend = select_sglang_backend(args)
60
+
61
+ # Run requests
62
+ tic = time.time()
63
+ states = multi_dimension_judge.run_batch(
64
+ arguments,
65
+ temperature=0,
66
+ backend=backend,
67
+ num_threads=args.parallel,
68
+ progress_bar=True,
69
+ )
70
+ latency = time.time() - tic
71
+
72
+ print(f"Latency: {latency:.3f}")
73
+
74
+ # Write results
75
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
76
+
77
+ with open(args.result_file, "a") as fout:
78
+ value = {
79
+ "task": "llm_judge",
80
+ "backend": args.backend,
81
+ "num_gpus": 1,
82
+ "latency": round(latency, 3),
83
+ "num_requests": args.num_questions,
84
+ "other": {
85
+ "num_questions": args.num_questions,
86
+ "parallel": args.parallel,
87
+ },
88
+ }
89
+ fout.write(json.dumps(value) + "\n")
90
+
91
+
92
+ if __name__ == "__main__":
93
+ parser = argparse.ArgumentParser()
94
+ parser.add_argument("--data-path", type=str, default="articles.jsonl")
95
+ parser.add_argument("--num-questions", type=int, default=20)
96
+ args = add_common_sglang_args_and_parse(parser)
97
+ main(args)
sglang/benchmark/long_json_decode/README.md ADDED
@@ -0,0 +1,33 @@
1
+ ## Run benchmark
2
+
3
+ ### Benchmark sglang
4
+ ```
5
+ python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
6
+ ```
7
+
8
+ ```
9
+ python3 bench_sglang.py --num-questions 5 --parallel 1
10
+ ```
11
+
12
+
13
+ ### Benchmark vllm
14
+ ```
15
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
16
+ ```
17
+
18
+ ```
19
+ python3 bench_other.py --backend vllm --num-questions 5
20
+ ```
21
+
22
+
23
+ ### Benchmark guidance
24
+ ```
25
+ python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
26
+ ```
27
+
28
+
29
+ ### Build dataset
30
+ ```
31
+ pip install wikipedia
32
+ python3 build_dataset.py
33
+ ```
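The decode programs in this benchmark fill in a fixed JSON skeleton with fields generated one at a time (name, country, airport code, and top 3 landmarks). The extracted object therefore looks roughly like the sketch below; the field values are illustrative, and the trailing comma matches the skeleton emitted by the scripts:

```
{
  "name": "Tokyo",
  "country": "Japan",
  "air port code": "HND",
  "top 3 landmarks": "Tokyo Tower, Senso-ji, Shibuya Crossing",
}
```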
sglang/benchmark/long_json_decode/bench_other.py ADDED
@@ -0,0 +1,89 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from functools import partial
6
+
7
+ from tqdm import tqdm
8
+
9
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+
13
+ def json_decode(document, generate):
14
+ s = "Please extract the information of a city from the following wikipedia page.\n"
15
+ s += "Page begin.\n" + document + "Page end.\n"
16
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
17
+ s += "{\n"
18
+ s += ' "name": "'
19
+ s += generate(s, max_tokens=8, stop='"') + '",\n'
20
+ s += ' "country": "'
21
+ s += generate(s, max_tokens=8, stop='"') + '",\n'
22
+ s += ' "air port code": "'
23
+ s += generate(s, max_tokens=8, stop='"') + '",\n'
24
+ s += ' "top 3 landmarks": "'
25
+ s += generate(s, max_tokens=24, stop='"') + '",\n'
26
+ s += "}\n"
27
+ return s
28
+
29
+
30
+ def main(args):
31
+ lines = read_jsonl(args.data_path)
32
+ arguments = []
33
+ for i in range(len(lines[: args.num_questions])):
34
+ arguments.append(
35
+ {
36
+ "document": lines[i]["document"],
37
+ }
38
+ )
39
+ states = [None] * len(arguments)
40
+
41
+ # Select backend
42
+ call_generate = partial(get_call_generate(args), temperature=0)
43
+
44
+ # Run requests
45
+ def get_one_answer(i):
46
+ states[i] = json_decode(generate=call_generate, **arguments[i])
47
+
48
+ tic = time.time()
49
+ if args.parallel == 1:
50
+ for i in tqdm(range(len(arguments))):
51
+ get_one_answer(i)
52
+ else:
53
+ with ThreadPoolExecutor(args.parallel) as executor:
54
+ list(
55
+ tqdm(
56
+ executor.map(get_one_answer, list(range(len(arguments)))),
57
+ total=len(arguments),
58
+ )
59
+ )
60
+
61
+ latency = time.time() - tic
62
+
63
+ # Compute accuracy
64
+ print(f"Latency: {latency:.3f}")
65
+
66
+ # Write results
67
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
68
+
69
+ with open(args.result_file, "a") as fout:
70
+ value = {
71
+ "task": "long_json_decode",
72
+ "backend": args.backend,
73
+ "num_gpus": 1,
74
+ "latency": round(latency, 3),
75
+ "num_requests": args.num_questions,
76
+ "other": {
77
+ "num_questions": args.num_questions,
78
+ "parallel": args.parallel,
79
+ },
80
+ }
81
+ fout.write(json.dumps(value) + "\n")
82
+
83
+
84
+ if __name__ == "__main__":
85
+ parser = argparse.ArgumentParser()
86
+ parser.add_argument("--data-path", type=str, default="questions.jsonl")
87
+ parser.add_argument("--num-questions", type=int, default=100)
88
+ args = add_common_other_args_and_parse(parser)
89
+ main(args)
sglang/benchmark/long_json_decode/bench_sglang.py ADDED
@@ -0,0 +1,81 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ import sglang as sgl
6
+ from sglang.test.test_utils import (
7
+ add_common_sglang_args_and_parse,
8
+ select_sglang_backend,
9
+ )
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+
13
+ @sgl.function
14
+ def json_decode(s, document):
15
+ s += "Please extract the information of a city from the following wikipedia page.\n"
16
+ s += "Page begin.\n" + document + "Page end.\n"
17
+ s += "Here is the name, country, and symbol of the city in JSON format.\n"
18
+ s += "{\n"
19
+ s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
20
+ s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
21
+ s += (
22
+ ' "air port code": "'
23
+ + sgl.gen("air port code", max_tokens=8, stop='"')
24
+ + '",\n'
25
+ )
26
+ s += (
27
+ ' "top 3 landmarks": "'
28
+ + sgl.gen("landmarks", max_tokens=24, stop='"')
29
+ + '",\n'
30
+ )
31
+ s += "}\n"
32
+
33
+
34
+ def main(args):
35
+ lines = read_jsonl(args.data_path)
36
+ arguments = []
37
+ for i in range(len(lines[: args.num_questions])):
38
+ arguments.append(
39
+ {
40
+ "document": lines[i]["document"],
41
+ }
42
+ )
43
+
44
+ # Select backend
45
+ backend = select_sglang_backend(args)
46
+ sgl.set_default_backend(backend)
47
+
48
+ # Run requests
49
+ tic = time.time()
50
+ states = json_decode.run_batch(
51
+ arguments, temperature=0, num_threads=args.parallel, progress_bar=True
52
+ )
53
+ latency = time.time() - tic
54
+
55
+ # Compute accuracy
56
+ print(f"Latency: {latency:.3f}")
57
+
58
+ # Write results
59
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
60
+
61
+ with open(args.result_file, "a") as fout:
62
+ value = {
63
+ "task": "long_json_decode",
64
+ "backend": args.backend,
65
+ "num_gpus": 1,
66
+ "latency": round(latency, 3),
67
+ "num_requests": args.num_questions,
68
+ "other": {
69
+ "num_questions": args.num_questions,
70
+ "parallel": args.parallel,
71
+ },
72
+ }
73
+ fout.write(json.dumps(value) + "\n")
74
+
75
+
76
+ if __name__ == "__main__":
77
+ parser = argparse.ArgumentParser()
78
+ parser.add_argument("--data-path", type=str, default="questions.jsonl")
79
+ parser.add_argument("--num-questions", type=int, default=10)
80
+ args = add_common_sglang_args_and_parse(parser)
81
+ main(args)
sglang/benchmark/long_json_decode/build_dataset.py ADDED
@@ -0,0 +1,27 @@
1
+ import json
2
+
3
+ import transformers
4
+ import wikipedia
5
+
6
+ name = "meta-llama/Llama-2-7b-chat-hf"
7
+ t = transformers.AutoTokenizer.from_pretrained(name)
8
+ city_names = ["los angeles", "london", "tokyo", "beijing", "singapore"]
9
+
10
+
11
+ for city_name in city_names:
12
+ content = str(wikipedia.page(city_name).content)
13
+ content = content.replace("\n\n", "\n")
14
+
15
+ tokens = t.encode(content)
16
+
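+ # Approximate a ~10k-token budget by cutting characters proportionally (assumes tokens are spread evenly over the text).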
17
+ truncate_len = int((10000 / len(tokens)) * len(content))
18
+ truncate_content = content[:truncate_len]
19
+ truncate_tokens = t.encode(truncate_content)
20
+
21
+ # Count token
22
+ print(
23
+ f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
24
+ )
25
+
26
+ with open("questions.jsonl", "a") as fout:
27
+ fout.write(json.dumps({"document": truncate_content}) + "\n")
sglang/benchmark/mmlu/bench_other.py ADDED
@@ -0,0 +1,173 @@
1
+ import argparse
2
+ import asyncio
3
+ import json
4
+ import os
5
+ import time
6
+ from concurrent.futures import ThreadPoolExecutor
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import tiktoken
11
+ from tqdm import tqdm
12
+
13
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
14
+
15
+ choices = ["A", "B", "C", "D"]
16
+
17
+ tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
18
+
19
+
20
+ def format_subject(subject):
21
+ l = subject.split("_")
22
+ s = ""
23
+ for entry in l:
24
+ s += " " + entry
25
+ return s
26
+
27
+
28
+ def format_example(df, idx, include_answer=True):
29
+ prompt = df.iloc[idx, 0]
30
+ k = df.shape[1] - 2
31
+ for j in range(k):
32
+ prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
33
+ prompt += "\nAnswer:"
34
+ if include_answer:
35
+ prompt += " {}\n\n".format(df.iloc[idx, k + 1])
36
+ return prompt
37
+
38
+
39
+ def gen_prompt(train_df, subject, k=-1):
40
+ prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
41
+ format_subject(subject)
42
+ )
43
+ if k == -1:
44
+ k = train_df.shape[0]
45
+ for i in range(k):
46
+ prompt += format_example(train_df, i)
47
+ return prompt
48
+
49
+
50
+ def evaluate(args, subject, dev_df, test_df, call_generate):
51
+ prompts = []
52
+ labels = []
53
+
54
+ # Construct prompts
55
+ k = args.ntrain
56
+ train_prompt = gen_prompt(dev_df, subject, k)
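+ # Drop few-shot examples until the prompt fits within a 1536-token budget.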
57
+ while len(tokenizer.encode(train_prompt)) > 1536:
58
+ k -= 1
59
+ train_prompt = gen_prompt(dev_df, subject, k)
60
+
61
+ for i in range(test_df.shape[0]):
62
+ prompt_end = format_example(test_df, i, include_answer=False)
63
+ prompt = train_prompt + prompt_end
64
+ prompts.append(prompt)
65
+
66
+ label = test_df.iloc[i, test_df.shape[1] - 1]
67
+ labels.append(label)
68
+
69
+ preds = [None] * len(prompts)
70
+ max_tokens = 1
71
+
72
+ # Run requests
73
+ if args.backend != "lmql":
74
+ # Use thread pool
75
+ def get_one_answer(i):
76
+ pred = call_generate(prompts[i], temperature=0, max_tokens=max_tokens)
77
+ preds[i] = pred.strip()[0]
78
+
79
+ tic = time.time()
80
+ if args.parallel == 1:
81
+ for i in range(len(prompts)):
82
+ get_one_answer(i)
83
+ else:
84
+ with ThreadPoolExecutor(args.parallel) as executor:
85
+ executor.map(get_one_answer, list(range(len(prompts))))
86
+ else:
87
+ # Use asyncio
88
+ async def batched_call(batch_size):
89
+ for i in range(0, len(prompts), batch_size):
90
+ tasks = []
91
+ for p in prompts[i : i + batch_size]:
92
+ tasks.append(call_generate(p, temperature=0, max_tokens=max_tokens))
93
+ rets = await asyncio.gather(*tasks)
94
+ for j in range(len(rets)):
95
+ preds[i + j] = rets[j].strip()[0]
96
+
97
+ tic = time.time()
98
+ asyncio.run(batched_call(batch_size=args.parallel))
99
+ latency = time.time() - tic
100
+
101
+ # Compute accuracy
102
+ cors = [pred == label for pred, label in zip(preds, labels)]
103
+ acc = np.mean(cors)
104
+ cors = np.array(cors)
105
+
106
+ print(
107
+ "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
108
+ acc, latency, len(prompts), subject
109
+ )
110
+ )
111
+
112
+ return cors, acc, latency
113
+
114
+
115
+ def main(args):
116
+ subjects = sorted(
117
+ [
118
+ f.split("_test.csv")[0]
119
+ for f in os.listdir(os.path.join(args.data_dir, "test"))
120
+ if "_test.csv" in f
121
+ ]
122
+ )
123
+
124
+ all_cors = []
125
+ all_latencies = []
126
+ num_requests = 0
127
+
128
+ # Select backend
129
+ call_generate = get_call_generate(args)
130
+
131
+ for subject in tqdm(subjects[: args.nsub]):
132
+ dev_df = pd.read_csv(
133
+ os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
134
+ )[: args.ntrain]
135
+ test_df = pd.read_csv(
136
+ os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
137
+ )
138
+
139
+ cors, acc, latency = evaluate(args, subject, dev_df, test_df, call_generate)
140
+ all_cors.append(cors)
141
+ all_latencies.append(latency)
142
+ num_requests += len(test_df)
143
+
144
+ total_latency = np.sum(all_latencies)
145
+ print("Total latency: {:.3f}".format(total_latency))
146
+
147
+ weighted_acc = np.mean(np.concatenate(all_cors))
148
+ print("Average accuracy: {:.3f}".format(weighted_acc))
149
+
150
+ # Write results
151
+ with open(args.result_file, "a") as fout:
152
+ value = {
153
+ "task": "mmlu",
154
+ "backend": args.backend,
155
+ "num_gpus": 1,
156
+ "latency": round(total_latency, 3),
157
+ "accuracy": round(weighted_acc, 3),
158
+ "num_requests": num_requests,
159
+ "other": {
160
+ "nsub": args.nsub,
161
+ "parallel": args.parallel,
162
+ },
163
+ }
164
+ fout.write(json.dumps(value) + "\n")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ parser = argparse.ArgumentParser()
169
+ parser.add_argument("--ntrain", type=int, default=5)
170
+ parser.add_argument("--data_dir", type=str, default="data")
171
+ parser.add_argument("--nsub", type=int, default=60)
172
+ args = add_common_other_args_and_parse(parser)
173
+ main(args)
sglang/benchmark/mmlu/bench_sglang.py ADDED
@@ -0,0 +1,174 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import tiktoken
9
+ from tqdm import tqdm
10
+
11
+ from sglang.test.test_utils import (
12
+ add_common_sglang_args_and_parse,
13
+ select_sglang_backend,
14
+ )
15
+
16
+ choices = ["A", "B", "C", "D"]
17
+
18
+ tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
19
+
20
+
21
+ def format_subject(subject):
22
+ l = subject.split("_")
23
+ s = ""
24
+ for entry in l:
25
+ s += " " + entry
26
+ return s
27
+
28
+
29
+ def format_example(df, idx, include_answer=True):
30
+ prompt = df.iloc[idx, 0]
31
+ k = df.shape[1] - 2
32
+ for j in range(k):
33
+ prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
34
+ prompt += "\nAnswer:"
35
+ if include_answer:
36
+ prompt += " {}\n\n".format(df.iloc[idx, k + 1])
37
+ return prompt
38
+
39
+
40
+ def gen_prompt(train_df, subject, k=-1):
41
+ prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
42
+ format_subject(subject)
43
+ )
44
+ if k == -1:
45
+ k = train_df.shape[0]
46
+ for i in range(k):
47
+ prompt += format_example(train_df, i)
48
+ return prompt
49
+
50
+
51
+ def main(args):
52
+ subjects = sorted(
53
+ [
54
+ f.split("_test.csv")[0]
55
+ for f in os.listdir(os.path.join(args.data_dir, "test"))
56
+ if "_test.csv" in f
57
+ ]
58
+ )
59
+
60
+ # Build prompts
61
+ arguments = []
62
+ labels = []
63
+ num_questions = []
64
+
65
+ for subject in subjects[: args.nsub]:
66
+ dev_df = pd.read_csv(
67
+ os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
68
+ )[: args.ntrain]
69
+ test_df = pd.read_csv(
70
+ os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
71
+ )
72
+ num_questions.append(test_df.shape[0])
73
+
74
+ k = args.ntrain
75
+ few_shot_examples = gen_prompt(dev_df, subject, k)
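+ # Same rule as bench_other.py: shrink the few-shot block until it fits within 1536 tokens.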
76
+ while len(tokenizer.encode(few_shot_examples)) > 1536:
77
+ k -= 1
78
+ few_shot_examples = gen_prompt(dev_df, subject, k)
79
+
80
+ for i in range(test_df.shape[0]):
81
+ prompt_end = format_example(test_df, i, include_answer=False)
82
+
83
+ arguments.append(
84
+ {
85
+ "examples": few_shot_examples,
86
+ "question": prompt_end,
87
+ }
88
+ )
89
+
90
+ label = test_df.iloc[i, test_df.shape[1] - 1]
91
+ labels.append(label)
92
+
93
+ #####################################
94
+ ######### SGL Program Begin #########
95
+ #####################################
96
+
97
+ import sglang as sgl
98
+
99
+ if args.backend.startswith("gpt-"):
100
+
101
+ @sgl.function
102
+ def few_shot_mmlu(s, examples, question):
103
+ s += sgl.user(examples + question)
104
+ s += sgl.assistant(sgl.gen("answer"))
105
+
106
+ else:
107
+
108
+ @sgl.function
109
+ def few_shot_mmlu(s, examples, question):
110
+ s += examples + question + sgl.gen("answer")
111
+
112
+ #####################################
113
+ ########## SGL Program End ##########
114
+ #####################################
115
+
116
+ # Select backend
117
+ backend = select_sglang_backend(args)
118
+
119
+ # Run
120
+ tic = time.time()
121
+ states = few_shot_mmlu.run_batch(
122
+ arguments,
123
+ temperature=0,
124
+ max_new_tokens=1,
125
+ backend=backend,
126
+ num_threads=args.parallel,
127
+ progress_bar=True,
128
+ )
129
+ preds = [
130
+ s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" for s in states
131
+ ]
132
+ latency = time.time() - tic
133
+
134
+ # Compute accuracy
135
+ cors = [pred == label for pred, label in zip(preds, labels)]
136
+
137
+ pt = 0
138
+ for subject, num_qs in zip(subjects[: args.nsub], num_questions):
139
+ print(
140
+ f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}"
141
+ )
142
+ pt += num_qs
143
+ assert pt == len(cors)
144
+ weighted_acc = np.mean(cors)
145
+
146
+ # Print results
147
+ print("Total latency: {:.3f}".format(latency))
148
+ print("Average accuracy: {:.3f}".format(weighted_acc))
149
+
150
+ # Write results
151
+ with open(args.result_file, "a") as fout:
152
+ value = {
153
+ "task": "mmlu",
154
+ "backend": args.backend,
155
+ "num_gpus": 1,
156
+ "latency": round(latency, 3),
157
+ "accuracy": round(weighted_acc, 3),
158
+ "num_requests": len(arguments),
159
+ "other": {
160
+ "nsub": args.nsub,
161
+ "parallel": args.parallel,
162
+ },
163
+ }
164
+ fout.write(json.dumps(value) + "\n")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ parser = argparse.ArgumentParser()
169
+ parser.add_argument("--ntrain", "-k", type=int, default=5)
170
+ parser.add_argument("--data_dir", "-d", type=str, default="data")
171
+ parser.add_argument("--save_dir", "-s", type=str, default="results")
172
+ parser.add_argument("--nsub", type=int, default=60)
173
+ args = add_common_sglang_args_and_parse(parser)
174
+ main(args)
sglang/benchmark/mmlu/download_data.sh ADDED
@@ -0,0 +1,2 @@
1
+ wget https://people.eecs.berkeley.edu/~hendrycks/data.tar
2
+ tar xf data.tar
sglang/benchmark/multi_chain_reasoning/bench_other.py ADDED
@@ -0,0 +1,186 @@
1
+ import argparse
2
+ import ast
3
+ import asyncio
4
+ import json
5
+ import re
6
+ import time
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
13
+ from sglang.utils import dump_state_text, read_jsonl
14
+
15
+ INVALID = -9999999
16
+
17
+
18
+ def get_answer_value(answer_str):
19
+ answer_str = answer_str.replace(",", "")
20
+ numbers = re.findall(r"\d+", answer_str)
21
+ if len(numbers) < 1:
22
+ return INVALID
23
+ try:
24
+ return ast.literal_eval(numbers[-1])
25
+ except SyntaxError:
26
+ return INVALID
27
+
28
+
29
+ prompt_lib = [
30
+ "Let us think step by step.",
31
+ "Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
32
+ "It's important to proceed step by step, ensuring accuracy at each stage.",
33
+ "Take a deep breath and break this down.",
34
+ "A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
35
+ "I am extremely good at math.",
36
+ ]
37
+
38
+
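+ # Sample several reasoning chains with different prompts, then ask the model to majority-vote a final answer.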
39
+ def multi_chain_gsm8k(question, num_chains, call_generate):
40
+ s = "Question: " + question + "\n"
41
+ # s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
42
+ # stop="Question", temperature=0)
43
+ # return s
44
+
45
+ comps = []
46
+ for i in range(num_chains):
47
+ comps.append(
48
+ call_generate(
49
+ s + "Answer: " + prompt_lib[i % num_chains],
50
+ max_tokens=256,
51
+ temperature=0.3,
52
+ stop="Question",
53
+ )
54
+ )
55
+
56
+ s += "Answer: To answer this question, here are some possible solutions. "
57
+ s += "After considering all of them, I will do a majority vote.\n\n"
58
+ for i in range(num_chains):
59
+ s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
60
+ s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
61
+ s += call_generate(s, max_tokens=16, temperature=0, stop=None)
62
+ return s
63
+
64
+
65
+ async def multi_chain_gsm8k_async(question, num_chains, call_generate):
66
+ s = "Question: " + question + "\n"
67
+ # s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
68
+ # stop="Question", temperature=0)
69
+ # return s
70
+
71
+ comps = []
72
+ for i in range(num_chains):
73
+ comps.append(
74
+ await call_generate(
75
+ s + "Answer: " + prompt_lib[i % num_chains],
76
+ max_tokens=256,
77
+ temperature=0.3,
78
+ stop="Question",
79
+ )
80
+ )
81
+
82
+ s += "Answer: To answer this question, here are some possible solutions. "
83
+ s += "After considering all of them, I will do a majority vote.\n\n"
84
+ for i in range(num_chains):
85
+ s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
86
+ s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
87
+ s += await call_generate(s, max_tokens=16, temperature=0, stop=None)
88
+ return s
89
+
90
+
91
+ def main(args):
92
+ lines = read_jsonl(args.data_path)
93
+
94
+ # Construct prompts
95
+ k = args.num_shot
96
+
97
+ questions = []
98
+ labels = []
99
+ for i in range(len(lines[: args.num_questions])):
100
+ questions.append(lines[i]["question"])
101
+ labels.append(get_answer_value(lines[i]["answer"]))
102
+ assert all(l != INVALID for l in labels)
103
+
104
+ states = [None] * len(labels)
105
+
106
+ # Select backend
107
+ call_generate = get_call_generate(args)
108
+
109
+ # Run requests
110
+ if args.backend != "lmql":
111
+ # Use thread pool
112
+ def get_one_answer(i):
113
+ answer = multi_chain_gsm8k(questions[i], args.num_chains, call_generate)
114
+ states[i] = answer
115
+
116
+ tic = time.time()
117
+ if args.parallel == 1:
118
+ for i in tqdm(range(len(questions))):
119
+ get_one_answer(i)
120
+ else:
121
+ with ThreadPoolExecutor(args.parallel) as executor:
122
+ list(
123
+ tqdm(
124
+ executor.map(get_one_answer, list(range(len(questions)))),
125
+ total=len(questions),
126
+ )
127
+ )
128
+
129
+ else:
130
+ # Use asyncio
131
+ async def get_one_answer_asyncio(i):
132
+ answer = await multi_chain_gsm8k_async(
133
+ questions[i], args.num_chains, call_generate
134
+ )
135
+ states[i] = answer
136
+
137
+ tic = time.time()
138
+ loop = asyncio.get_event_loop()
139
+ batches = [
140
+ list(range(i, min(i + args.parallel, len(questions))))
141
+ for i in range(0, len(questions), args.parallel)
142
+ ]
143
+ for bt in tqdm(batches):
144
+ tasks = [get_one_answer_asyncio(k) for k in bt]
145
+ loop.run_until_complete(asyncio.gather(*tasks))
146
+
147
+ latency = time.time() - tic
148
+
149
+ preds = []
150
+ for i in range(len(states)):
151
+ preds.append(get_answer_value(states[i]))
152
+
153
+ # Compute accuracy
154
+ acc = np.mean(np.array(preds) == np.array(labels))
155
+ invalid = np.mean(np.array(preds) == INVALID)
156
+ print(f"Latency: {latency:.3f}")
157
+ print(f"Invalid: {invalid:.3f}")
158
+ print(f"Accuracy: {acc:.3f}")
159
+
160
+ # Write results
161
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
162
+
163
+ with open(args.result_file, "a") as fout:
164
+ value = {
165
+ "task": "multi_chain_gsm8k",
166
+ "backend": args.backend,
167
+ "num_gpus": 1,
168
+ "latency": round(latency, 3),
169
+ "accuracy": round(acc, 3),
170
+ "num_requests": args.num_questions,
171
+ "other": {
172
+ "num_questions": args.num_questions,
173
+ "parallel": args.parallel,
174
+ },
175
+ }
176
+ fout.write(json.dumps(value) + "\n")
177
+
178
+
179
+ if __name__ == "__main__":
180
+ parser = argparse.ArgumentParser()
181
+ parser.add_argument("--num-shot", type=int, default=0)
182
+ parser.add_argument("--num-chains", type=int, default=5)
183
+ parser.add_argument("--data-path", type=str, default="test.jsonl")
184
+ parser.add_argument("--num-questions", type=int, default=50)
185
+ args = add_common_other_args_and_parse(parser)
186
+ main(args)
sglang/benchmark/multi_document_qa/README.md ADDED
@@ -0,0 +1,47 @@
1
+ ## Run benchmark
2
+
3
+ ### Benchmark sglang
4
+ ```
5
+ python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
6
+ ```
7
+
8
+ ```
9
+ python3 bench_sglang.py --num-questions 10 --parallel 1
10
+ ```
11
+
12
+
13
+ ### Benchmark vllm
14
+ ```
15
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu 0.97
16
+ ```
17
+
18
+ ```
19
+ python3 bench_other.py --backend vllm --num-questions 64
20
+ ```
21
+
22
+
23
+ ### Benchmark guidance
24
+ ```
25
+ python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
26
+ ```
27
+
28
+
29
+
30
+ ### Build dataset
31
+
32
+ ```
33
+ pip install PyPDF2
34
+ python3 build_dataset.py
35
+ ```
36
+
37
+ ```python
38
+ import PyPDF2
39
+
40
+ with open('llama2.pdf', 'rb') as file:
41
+ reader = PyPDF2.PdfReader(file)
42
+ text = ''
43
+ for page_num in range(len(reader.pages)):
44
+ text += reader.pages[page_num].extract_text()
45
+ with open('output.txt', 'w') as text_file:
46
+ text_file.write(text)
47
+ ```
sglang/benchmark/multi_document_qa/bench_other.py ADDED
@@ -0,0 +1,114 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from functools import partial
6
+
7
+ from tqdm import tqdm
8
+
9
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
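+ # Llama-2-instruct chat markers, applied by hand since the baseline backends take raw completion prompts.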
12
+ USER_PREFIX = "[INST] "
13
+ USER_SUFFIX = " [/INST]"
14
+ ASSISTANT_PREFIX = ""
15
+ ASSISTANT_SUFFIX = " </s><s>"
16
+
17
+
18
+ def multi_document_qa(docs, question, generate):
19
+ s = USER_PREFIX
20
+ s += "Pleaes answer a question according to given documents.\n"
21
+ s += "Question:" + question + "Documents begin.\n"
22
+
23
+ s += "".join(docs)
24
+
25
+ s += "\nDocuments end."
26
+ s += (
27
+ "\n\nBased on the above documents, please answer this question:\n"
28
+ + question
29
+ + "\nAnswer in three words or fewer."
30
+ )
31
+ s += USER_SUFFIX
32
+ s += ASSISTANT_PREFIX
33
+ answer = generate(s, max_tokens=16, stop=None)
34
+ return answer
35
+
36
+
37
+ def main(args):
38
+ lines = read_jsonl(args.data_path)
39
+ l = lines[0]
40
+ arguments = []
41
+ labels = []
42
+
43
+ num_docs = 10
44
+ if args.backend == "guidance":
45
+ num_docs = 7 # due to OOM
46
+
47
+ for i in range(len(l["questions"][: args.num_questions])):
48
+ arguments.append(
49
+ {
50
+ "docs": l["documents"][:num_docs],
51
+ "question": l["questions"][i],
52
+ }
53
+ )
54
+ labels.append(l["answers"][i])
55
+ states = [None] * len(arguments)
56
+
57
+ # Select backend
58
+ call_generate = partial(get_call_generate(args), temperature=0)
59
+
60
+ # Run requests
61
+ def get_one_answer(i):
62
+ states[i] = multi_document_qa(generate=call_generate, **arguments[i])
63
+
64
+ tic = time.time()
65
+ if args.parallel == 1:
66
+ for i in tqdm(range(len(labels))):
67
+ get_one_answer(i)
68
+ else:
69
+ with ThreadPoolExecutor(args.parallel) as executor:
70
+ list(
71
+ tqdm(
72
+ executor.map(get_one_answer, list(range(len(labels)))),
73
+ total=len(labels),
74
+ )
75
+ )
76
+
77
+ latency = time.time() - tic
78
+
79
+ # Compute accuracy
80
+ print(states)
81
+ correct = 0
82
+ for s, label in zip(states, labels):
83
+ answer = s.lower()
84
+ if all(x in answer for x in label.lower().split(" ")):
85
+ correct += 1
86
+ accuracy = correct / len(labels)
87
+ print(f"Accuracy: {accuracy:.3f}")
88
+ print(f"Latency: {latency:.3f}")
89
+
90
+ # Write results
91
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
92
+
93
+ with open(args.result_file, "a") as fout:
94
+ value = {
95
+ "task": "multi_document_qa",
96
+ "backend": args.backend,
97
+ "num_gpus": 1,
98
+ "latency": round(latency, 3),
99
+ "num_requests": args.num_questions,
100
+ "accuracy": accuracy,
101
+ "other": {
102
+ "num_questions": args.num_questions,
103
+ "parallel": args.parallel,
104
+ },
105
+ }
106
+ fout.write(json.dumps(value) + "\n")
107
+
108
+
109
+ if __name__ == "__main__":
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--data-path", type=str, default="questions.jsonl")
112
+ parser.add_argument("--num-questions", type=int, default=100)
113
+ args = add_common_other_args_and_parse(parser)
114
+ main(args)
sglang/benchmark/multi_document_qa/bench_sglang.py ADDED
@@ -0,0 +1,93 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+
5
+ import sglang as sgl
6
+ from sglang.test.test_utils import (
7
+ add_common_sglang_args_and_parse,
8
+ select_sglang_backend,
9
+ )
10
+ from sglang.utils import dump_state_text, read_jsonl
11
+
12
+
13
+ @sgl.function
14
+ def multi_document_qa(s, docs, question):
15
+ s += sgl.user_begin()
16
+ s += "Pleaes answer a question according to given documents.\n"
17
+ s += "Question:" + question + "Documents begin.\n"
18
+
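+ # Fork one branch per document so the documents are appended via parallel branches, then joined by concatenation.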
19
+ forks = s.fork(len(docs))
20
+ forks += lambda i: docs[i]
21
+ forks.join("concate_and_append")
22
+
23
+ s += "\nDocuments end."
24
+ s += (
25
+ "\n\nBased on the above documents, please answer this question:\n"
26
+ + question
27
+ + "\nAnswer in three words or fewer."
28
+ )
29
+ s += sgl.user_end()
30
+ s += sgl.assistant(sgl.gen("answer", max_tokens=16))
31
+
32
+
33
+ def main(args):
34
+ lines = read_jsonl(args.data_path)
35
+ l = lines[0]
36
+ arguments = []
37
+ labels = []
38
+ for i in range(len(l["questions"][: args.num_questions])):
39
+ arguments.append(
40
+ {
41
+ "docs": l["documents"][:10],
42
+ "question": l["questions"][i],
43
+ }
44
+ )
45
+ labels.append(l["answers"][i])
46
+
47
+ # Select backend
48
+ backend = select_sglang_backend(args)
49
+ sgl.set_default_backend(backend)
50
+
51
+ # Run requests
52
+ tic = time.time()
53
+ states = multi_document_qa.run_batch(
54
+ arguments, temperature=0, num_threads=args.parallel, progress_bar=True
55
+ )
56
+ latency = time.time() - tic
57
+
58
+ # Compute accuracy
59
+ print([s["answer"] for s in states])
60
+ correct = 0
61
+ for s, label in zip(states, labels):
62
+ answer = s["answer"].lower()
63
+ if all(x in answer for x in label.lower().split(" ")):
64
+ correct += 1
65
+ accuracy = correct / len(labels)
66
+ print(f"Accuracy: {accuracy:.3f}")
67
+ print(f"Latency: {latency:.3f}")
68
+
69
+ # Write results
70
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
71
+
72
+ with open(args.result_file, "a") as fout:
73
+ value = {
74
+ "task": "multi_document_qa",
75
+ "backend": args.backend,
76
+ "num_gpus": 1,
77
+ "latency": round(latency, 3),
78
+ "num_requests": args.num_questions,
79
+ "accuracy": accuracy,
80
+ "other": {
81
+ "num_questions": args.num_questions,
82
+ "parallel": args.parallel,
83
+ },
84
+ }
85
+ fout.write(json.dumps(value) + "\n")
86
+
87
+
88
+ if __name__ == "__main__":
89
+ parser = argparse.ArgumentParser()
90
+ parser.add_argument("--data-path", type=str, default="questions.jsonl")
91
+ parser.add_argument("--num-questions", type=int, default=100)
92
+ args = add_common_sglang_args_and_parse(parser)
93
+ main(args)