tuandunghcmut committed
Commit 127dcad · verified · 1 parent: 95e1e2e

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete set of changes.
Files changed (50)
  1. LLaVA/.devcontainer/Dockerfile +53 -0
  2. LLaVA/.devcontainer/devcontainer.env +2 -0
  3. LLaVA/.devcontainer/devcontainer.json +71 -0
  4. LLaVA/.devcontainer/postCreateCommand.sh +45 -0
  5. LLaVA/docs/Evaluation.md +167 -0
  6. LLaVA/scripts/convert_sqa_to_llava_base_prompt.py +334 -0
  7. LLaVA/scripts/finetune_qlora.sh +50 -0
  8. LLaVA/scripts/pretrain.sh +46 -0
  9. LLaVA/scripts/zero2.json +23 -0
  10. sglang/.github/ISSUE_TEMPLATE/2-feature-request.yml +23 -0
  11. sglang/.github/workflows/close-inactive-issues.yml +96 -0
  12. sglang/.github/workflows/execute-notebook.yml +49 -0
  13. sglang/.github/workflows/lint.yml +22 -0
  14. sglang/.github/workflows/nightly-test.yml +34 -0
  15. sglang/.github/workflows/pr-test.yml +270 -0
  16. sglang/.github/workflows/release-docker-dev.yml +35 -0
  17. sglang/.github/workflows/release-docker.yml +64 -0
  18. sglang/.github/workflows/release-pypi-kernel.yml +41 -0
  19. sglang/.github/workflows/release-pypi.yml +31 -0
  20. sglang/3rdparty/amd/profiling/PROFILING.md +425 -0
  21. sglang/3rdparty/amd/profiling/client.sh +27 -0
  22. sglang/3rdparty/amd/profiling/install_rpd.sh +10 -0
  23. sglang/3rdparty/amd/profiling/loadTracer.sh +43 -0
  24. sglang/3rdparty/amd/profiling/rpd.patch +12 -0
  25. sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch +49 -0
  26. sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch +126 -0
  27. sglang/3rdparty/amd/profiling/server.sh +20 -0
  28. sglang/3rdparty/amd/tuning/TUNING.md +118 -0
  29. sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py +377 -0
  30. sglang/benchmark/blog_v0_2/405b_sglang.sh +24 -0
  31. sglang/benchmark/blog_v0_2/405b_trt.sh +17 -0
  32. sglang/benchmark/blog_v0_2/405b_vllm.sh +24 -0
  33. sglang/benchmark/dspy/README.md +51 -0
  34. sglang/benchmark/dspy/bench_dspy_intro.py +192 -0
  35. sglang/benchmark/gsm8k/README.md +47 -0
  36. sglang/benchmark/gsm8k/bench_other.py +151 -0
  37. sglang/benchmark/gsm8k/bench_sglang.py +141 -0
  38. sglang/benchmark/hellaswag/README.md +47 -0
  39. sglang/benchmark/hellaswag/bench_other.py +118 -0
  40. sglang/benchmark/lora/launch_server.py +47 -0
  41. sglang/benchmark/lora/lora_bench.py +484 -0
  42. sglang/benchmark/mmlu/README.md +59 -0
  43. sglang/benchmark/mtbench/README.md +37 -0
  44. sglang/benchmark/mtbench/bench_other.py +111 -0
  45. sglang/benchmark/mtbench/bench_sglang.py +99 -0
  46. sglang/benchmark/multi_chain_reasoning/README.md +49 -0
  47. sglang/benchmark/multi_chain_reasoning/bench_sglang.py +140 -0
  48. sglang/benchmark/multi_turn_chat/README.md +66 -0
  49. sglang/benchmark/multi_turn_chat/bench_other.py +93 -0
  50. sglang/benchmark/multi_turn_chat/bench_sglang.py +79 -0
LLaVA/.devcontainer/Dockerfile ADDED
@@ -0,0 +1,53 @@
1
+ FROM mcr.microsoft.com/devcontainers/base:ubuntu-20.04
2
+
3
+ SHELL [ "bash", "-c" ]
4
+
5
+ # update apt and install packages
6
+ RUN apt update && \
7
+ apt install -yq \
8
+ ffmpeg \
9
+ dkms \
10
+ build-essential
11
+
12
+ # add user tools
13
+ RUN sudo apt install -yq \
14
+ jq \
15
+ jp \
16
+ tree \
17
+ tldr
18
+
19
+ # add git-lfs and install
20
+ RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \
21
+ sudo apt-get install -yq git-lfs && \
22
+ git lfs install
23
+
24
+ ############################################
25
+ # Setup user
26
+ ############################################
27
+
28
+ USER vscode
29
+
30
+ # install azcopy, a tool to copy to/from blob storage
31
+ # for more info: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-upload#upload-a-file
32
+ RUN cd /tmp && \
33
+ wget https://azcopyvnext.azureedge.net/release20230123/azcopy_linux_amd64_10.17.0.tar.gz && \
34
+ tar xvf azcopy_linux_amd64_10.17.0.tar.gz && \
35
+ mkdir -p ~/.local/bin && \
36
+ mv azcopy_linux_amd64_10.17.0/azcopy ~/.local/bin && \
37
+ chmod +x ~/.local/bin/azcopy && \
38
+ rm -rf azcopy_linux_amd64*
39
+
40
+ # Setup conda
41
+ RUN cd /tmp && \
42
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
43
+ bash ./Miniconda3-latest-Linux-x86_64.sh -b && \
44
+ rm ./Miniconda3-latest-Linux-x86_64.sh
45
+
46
+ # Install dotnet
47
+ RUN cd /tmp && \
48
+ wget https://dot.net/v1/dotnet-install.sh && \
49
+ chmod +x dotnet-install.sh && \
50
+ ./dotnet-install.sh --channel 7.0 && \
51
+ ./dotnet-install.sh --channel 3.1 && \
52
+ rm ./dotnet-install.sh
53
+
LLaVA/.devcontainer/devcontainer.env ADDED
@@ -0,0 +1,2 @@
1
+ SAMPLE_ENV_VAR1="Sample Value"
2
+ SAMPLE_ENV_VAR2=332431bf-68bf
LLaVA/.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "name": "LLaVA",
3
+ "build": {
4
+ "dockerfile": "Dockerfile",
5
+ "context": "..",
6
+ "args": {}
7
+ },
8
+ "features": {
9
+ "ghcr.io/devcontainers/features/docker-in-docker:2": {},
10
+ "ghcr.io/devcontainers/features/azure-cli:1": {},
11
+ "ghcr.io/azure/azure-dev/azd:0": {},
12
+ "ghcr.io/devcontainers/features/powershell:1": {},
13
+ "ghcr.io/devcontainers/features/common-utils:2": {},
14
+ "ghcr.io/devcontainers-contrib/features/zsh-plugins:0": {},
15
+ },
16
+ // "forwardPorts": [],
17
+ "postCreateCommand": "bash ./.devcontainer/postCreateCommand.sh",
18
+ "customizations": {
19
+ "vscode": {
20
+ "settings": {
21
+ "python.analysis.autoImportCompletions": true,
22
+ "python.analysis.autoImportUserSymbols": true,
23
+ "python.defaultInterpreterPath": "~/miniconda3/envs/llava/bin/python",
24
+ "python.formatting.provider": "yapf",
25
+ "python.linting.enabled": true,
26
+ "python.linting.flake8Enabled": true,
27
+ "isort.check": true,
28
+ "dev.containers.copyGitConfig": true,
29
+ "terminal.integrated.defaultProfile.linux": "zsh",
30
+ "terminal.integrated.profiles.linux": {
31
+ "zsh": {
32
+ "path": "/usr/bin/zsh"
33
+ },
34
+ }
35
+ },
36
+ "extensions": [
37
+ "aaron-bond.better-comments",
38
+ "eamodio.gitlens",
39
+ "EditorConfig.EditorConfig",
40
+ "foxundermoon.shell-format",
41
+ "GitHub.copilot-chat",
42
+ "GitHub.copilot-labs",
43
+ "GitHub.copilot",
44
+ "lehoanganh298.json-lines-viewer",
45
+ "mhutchie.git-graph",
46
+ "ms-azuretools.vscode-docker",
47
+ "ms-dotnettools.dotnet-interactive-vscode",
48
+ "ms-python.flake8",
49
+ "ms-python.isort",
50
+ "ms-python.python",
51
+ "ms-python.vscode-pylance",
52
+ "njpwerner.autodocstring",
53
+ "redhat.vscode-yaml",
54
+ "stkb.rewrap",
55
+ "yzhang.markdown-all-in-one",
56
+ ]
57
+ }
58
+ },
59
+ "mounts": [],
60
+ "runArgs": [
61
+ "--gpus",
62
+ "all",
63
+ // "--ipc",
64
+ // "host",
65
+ "--ulimit",
66
+ "memlock=-1",
67
+ "--env-file",
68
+ ".devcontainer/devcontainer.env"
69
+ ],
70
+ // "remoteUser": "root"
71
+ }
LLaVA/.devcontainer/postCreateCommand.sh ADDED
@@ -0,0 +1,45 @@
1
+ git config --global safe.directory '*'
2
+ git config --global core.editor "code --wait"
3
+ git config --global pager.branch false
4
+
5
+ # Set AZCOPY concurrency to auto
6
+ echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.zshrc
7
+ echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.bashrc
8
+
9
+ # Activate conda by default
10
+ echo ". /home/vscode/miniconda3/bin/activate" >> ~/.zshrc
11
+ echo ". /home/vscode/miniconda3/bin/activate" >> ~/.bashrc
12
+
13
+ # Use llava environment by default
14
+ echo "conda activate llava" >> ~/.zshrc
15
+ echo "conda activate llava" >> ~/.bashrc
16
+
17
+ # Add dotnet to PATH
18
+ echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.bashrc
19
+ echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.zshrc
20
+
21
+ # Create and activate llava environment
22
+ source /home/vscode/miniconda3/bin/activate
23
+ conda create -y -q -n llava python=3.10
24
+ conda activate llava
25
+
26
+ # Install Nvidia Cuda Compiler
27
+ conda install -y -c nvidia cuda-compiler
28
+
29
+ pip install pre-commit==3.0.2
30
+
31
+ # Install package locally
32
+ pip install --upgrade pip # enable PEP 660 support
33
+ pip install -e .
34
+
35
+ # Install additional packages for training
36
+ pip install -e ".[train]"
37
+ pip install flash-attn --no-build-isolation
38
+
39
+ # Download checkpoints to location outside of the repo
40
+ git clone https://huggingface.co/liuhaotian/llava-v1.5-7b ~/llava-v1.5-7b
41
+
42
+ # Commented because it is unlikely for users to have enough local GPU memory to load the model
43
+ # git clone https://huggingface.co/liuhaotian/llava-v1.5-13b ~/llava-v1.5-13b
44
+
45
+ echo "postCreateCommand.sh COMPLETE!"
LLaVA/docs/Evaluation.md ADDED
@@ -0,0 +1,167 @@
1
+ # Evaluation
2
+
3
+ In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not use beam search, so that the inference process stays consistent with the real-time chat demo.
4
+
5
+ Currently, we mostly utilize the official toolkit or server for the evaluation.
6
+
7
+ ## Evaluate on Custom Datasets
8
+
9
+ You can evaluate LLaVA on your custom datasets by converting your dataset to LLaVA's jsonl format and evaluating it with [`model_vqa.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa.py).
10
+
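For orientation only, here is a minimal, hypothetical sketch of producing such a jsonl question file. The field names (`question_id`, `image`, `text`) are assumptions based on the common LLaVA question-file layout and should be verified against `model_vqa.py` before use.

```python
import json

# Hypothetical example: each line of the question file is one JSON object.
# Field names are assumptions; check model_vqa.py for the exact schema it expects.
records = [
    {"question_id": 0, "image": "images/0001.jpg", "text": "What is shown in this image?"},
    {"question_id": 1, "image": "images/0002.jpg", "text": "How many people are visible?"},
]
with open("questions.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
```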
11
+ Below we provide a general guideline for evaluating datasets with some common formats.
12
+
13
+ 1. Short-answer (e.g. VQAv2, MME).
14
+
15
+ ```
16
+ <question>
17
+ Answer the question using a single word or phrase.
18
+ ```
19
+
20
+ 2. Option-only for multiple-choice (e.g. MMBench, SEED-Bench).
21
+
22
+ ```
23
+ <question>
24
+ A. <option_1>
25
+ B. <option_2>
26
+ C. <option_3>
27
+ D. <option_4>
28
+ Answer with the option's letter from the given choices directly.
29
+ ```
30
+
31
+ 3. Natural QA (e.g. LLaVA-Bench, MM-Vet).
32
+
33
+ No postprocessing is needed.
34
+
35
+ ## Scripts
36
+
37
+ Before preparing task-specific data, **you MUST first download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing)**. It contains custom annotations, scripts, and the prediction files from LLaVA v1.5. Extract it to `./playground/data/eval`; this also provides the general directory structure for all datasets.
38
+
39
+ ### VQAv2
40
+
41
+ 1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`.
42
+ 2. Multi-GPU inference.
43
+ ```Shell
44
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/vqav2.sh
45
+ ```
46
+ 3. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/830/my-submission): `./playground/data/eval/vqav2/answers_upload`.
47
+
48
+ ### GQA
49
+
50
+ 1. Download the [data](https://cs.stanford.edu/people/dorarad/gqa/download.html) and [evaluation scripts](https://cs.stanford.edu/people/dorarad/gqa/evaluate.html) following the official instructions and put them under `./playground/data/eval/gqa/data`. You may need to modify `eval.py` as shown in [this gist](https://gist.github.com/haotian-liu/db6eddc2a984b4cbcc8a7f26fd523187) due to missing assets in the GQA v1.2 release.
51
+ 2. Multi-GPU inference.
52
+ ```Shell
53
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/gqa.sh
54
+ ```
55
+
56
+ ### VisWiz
57
+
58
+ 1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`.
59
+ 2. Single-GPU inference.
60
+ ```Shell
61
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/vizwiz.sh
62
+ ```
63
+ 3. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/2185/my-submission): `./playground/data/eval/vizwiz/answers_upload`.
64
+
65
+ ### ScienceQA
66
+
67
+ 1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA).
68
+ 2. Single-GPU inference and evaluation.
69
+ ```Shell
70
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh
71
+ ```
72
+
73
+ ### TextVQA
74
+
75
+ 1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`.
76
+ 2. Single-GPU inference and evaluation.
77
+ ```Shell
78
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh
79
+ ```
80
+
81
+ ### POPE
82
+
83
+ 1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put under `./playground/data/eval/pope`.
84
+ 2. Single-GPU inference and evaluation.
85
+ ```Shell
86
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/pope.sh
87
+ ```
88
+
89
+ ### MME
90
+
91
+ 1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation).
92
+ 2. Put the downloaded images under `MME_Benchmark_release_version`.
93
+ 3. Put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`.
94
+ 4. Single-GPU inference and evaluation.
95
+ ```Shell
96
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh
97
+ ```
98
+
99
+ ### MMBench
100
+
101
+ 1. Download [`mmbench_dev_20230712.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_20230712.tsv) and put under `./playground/data/eval/mmbench`.
102
+ 2. Single-GPU inference.
103
+ ```Shell
104
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh
105
+ ```
106
+ 3. Submit the results to the [evaluation server](https://opencompass.org.cn/leaderboard-multimodal): `./playground/data/eval/mmbench/answers_upload/mmbench_dev_20230712`.
107
+
108
+ ### MMBench-CN
109
+
110
+ 1. Download [`mmbench_dev_cn_20231003.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_cn_20231003.tsv) and put under `./playground/data/eval/mmbench`.
111
+ 2. Single-GPU inference.
112
+ ```Shell
113
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh
114
+ ```
115
+ 3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`.
116
+
117
+
118
+ ### SEED-Bench
119
+
120
+ 1. Follow the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and videos. Put the images under `./playground/data/eval/seed_bench/SEED-Bench-image`.
121
+ 2. Extract the middle frame from each downloaded video and put the frames under `./playground/data/eval/seed_bench/SEED-Bench-video-image`. We provide our script `extract_video_frames.py`, modified from the official one (a minimal sketch of the middle-frame idea follows this list).
122
+ 3. Multi-GPU inference and evaluation.
123
+ ```Shell
124
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/seed.sh
125
+ ```
126
+ 4. Optionally, submit the results to the leaderboard: `./playground/data/eval/seed_bench/answers_upload` using the official jupyter notebook.
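As referenced in step 2 above, here is a minimal sketch of the middle-frame idea, assuming OpenCV (`opencv-python`) is available; the repository's own `extract_video_frames.py` remains the authoritative script.

```python
import cv2  # assumes opencv-python is installed

def extract_middle_frame(video_path: str, out_path: str) -> None:
    """Save the frame at the midpoint of the video as an image."""
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count // 2)  # seek to the middle frame
    ok, frame = cap.read()
    cap.release()
    if ok:
        cv2.imwrite(out_path, frame)

# Example usage with placeholder file names.
extract_middle_frame("example_video.mp4", "example_video_middle_frame.png")
```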
127
+
128
+ ### LLaVA-Bench-in-the-Wild
129
+
130
+ 1. Extract contents of [`llava-bench-in-the-wild`](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to `./playground/data/eval/llava-bench-in-the-wild`.
131
+ 2. Single-GPU inference and evaluation.
132
+ ```Shell
133
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh
134
+ ```
135
+
136
+ ### MM-Vet
137
+
138
+ 1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`.
139
+ 2. Single-GPU inference.
140
+ ```Shell
141
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh
142
+ ```
143
+ 3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook.
144
+
145
+ ## More Benchmarks
146
+
147
+ Below are awesome benchmarks for multimodal understanding from the research community that were not included in the initial LLaVA-1.5 release.
148
+
149
+ ### Q-Bench
150
+
151
+ 1. Download [`llvisionqa_dev.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_dev.json) (for `dev`-subset) and [`llvisionqa_test.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_test.json) (for `test`-subset). Put them under `./playground/data/eval/qbench`.
152
+ 2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llviqionqa`.
153
+ 3. Single-GPU inference (change `dev` to `test` for evaluation on test set).
154
+ ```Shell
155
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench.sh dev
156
+ ```
157
+ 4. Submit the results following the instructions [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_dev_answers.jsonl`.
158
+
159
+ ### Chinese-Q-Bench
160
+
161
+ 1. Download [`质衡-问答-验证集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E9%AA%8C%E8%AF%81%E9%9B%86.json) (for `dev`-subset) and [`质衡-问答-测试集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E6%B5%8B%E8%AF%95%E9%9B%86.json) (for `test`-subset). Put them under `./playground/data/eval/qbench`.
162
+ 2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llviqionqa`.
163
+ 3. Single-GPU inference (change `dev` to `test` for evaluation on test set).
164
+ ```Shell
165
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench_zh.sh dev
166
+ ```
167
+ 4. Submit the results following the instructions [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_zh_dev_answers.jsonl`.
LLaVA/scripts/convert_sqa_to_llava_base_prompt.py ADDED
@@ -0,0 +1,334 @@
1
+ def get_question_text(problem):
2
+ question = problem['question']
3
+ return question
4
+
5
+
6
+ def get_context_text(problem, use_caption):
7
+ txt_context = problem['hint']
8
+ img_context = problem['caption'] if use_caption else ""
9
+ context = " ".join([txt_context, img_context]).strip()
10
+ if context == "":
11
+ context = "N/A"
12
+ return context
13
+
14
+
15
+ def get_choice_text(probelm, options):
16
+ choices = probelm['choices']
17
+ choice_list = []
18
+ for i, c in enumerate(choices):
19
+ choice_list.append("({}) {}".format(options[i], c))
20
+ choice_txt = " ".join(choice_list)
21
+ #print(choice_txt)
22
+ return choice_txt
23
+
24
+
25
+ def get_answer(problem, options):
26
+ return options[problem['answer']]
27
+
28
+
29
+ def get_lecture_text(problem):
30
+ # \\n: GPT-3 can generate the lecture with more tokens.
31
+ lecture = problem['lecture'].replace("\n", "\\n")
32
+ return lecture
33
+
34
+
35
+ def get_solution_text(problem):
36
+ # \\n: GPT-3 can generate the solution with more tokens
37
+ solution = problem['solution'].replace("\n", "\\n")
38
+ return solution
39
+
40
+
41
+ def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
42
+
43
+ input_format, output_format = format.split("-")
44
+
45
+ ## Inputs
46
+ if input_format == "CQM":
47
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
48
+ elif input_format == "QCM":
49
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
50
+ # upper bound experiment
51
+ elif input_format == "QCML":
52
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
53
+ elif input_format == "QCME":
54
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
55
+ elif input_format == "QCMLE":
56
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
57
+
58
+ elif input_format == "QCLM":
59
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
60
+ elif input_format == "QCEM":
61
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
62
+ elif input_format == "QCLEM":
63
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
64
+
65
+ # Outputs
66
+ if test_example:
67
+ output = "Answer:"
68
+ elif output_format == 'A':
69
+ output = f"Answer: The answer is {answer}."
70
+
71
+ elif output_format == 'AL':
72
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
73
+ elif output_format == 'AE':
74
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
75
+ elif output_format == 'ALE':
76
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
77
+ elif output_format == 'AEL':
78
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
79
+
80
+ elif output_format == 'LA':
81
+ output = f"Answer: {lecture} The answer is {answer}."
82
+ elif output_format == 'EA':
83
+ output = f"Answer: {solution} The answer is {answer}."
84
+ elif output_format == 'LEA':
85
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
86
+ elif output_format == 'ELA':
87
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
88
+ elif output_format == 'LEPA':
89
+ output = ''
90
+ if len(lecture.strip()) > 0:
91
+ output += f"LECTURE: {lecture}\n"
92
+ if len(solution.strip()) > 0:
93
+ output += f"SOLUTION: {solution}\n"
94
+ output += '###\n'
95
+ output += f"ANSWER: {answer}."
96
+
97
+ input = input.replace(" ", " ").strip()
98
+ output = output.replace(" ", " ").strip()
99
+ if input.endswith("BECAUSE:"):
100
+ input = input.replace("BECAUSE:", "").strip()
101
+ if output.endswith("BECAUSE:"):
102
+ output = output.replace("BECAUSE:", "").strip()
103
+ return input, output
104
+
105
+
106
+ def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
107
+
108
+ input_format, output_format = format.split("-")
109
+
110
+ ## Inputs
111
+ if input_format == "CQM":
112
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
113
+ elif input_format == "QCM":
114
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
115
+ # upper bound experiment
116
+ elif input_format == "QCML":
117
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
118
+ elif input_format == "QCME":
119
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
120
+ elif input_format == "QCMLE":
121
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
122
+
123
+ elif input_format == "QCLM":
124
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
125
+ elif input_format == "QCEM":
126
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
127
+ elif input_format == "QCLEM":
128
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
129
+
130
+ # Outputs
131
+ if test_example:
132
+ output = "Answer:"
133
+ elif output_format == 'A':
134
+ output = f"Answer: The answer is {answer}."
135
+
136
+ elif output_format == 'AL':
137
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
138
+ elif output_format == 'AE':
139
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
140
+ elif output_format == 'ALE':
141
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
142
+ elif output_format == 'AEL':
143
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
144
+
145
+ elif output_format == 'LA':
146
+ output = f"Answer: {lecture} The answer is {answer}."
147
+ elif output_format == 'EA':
148
+ output = f"Answer: {solution} The answer is {answer}."
149
+ elif output_format == 'LEA':
150
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
151
+ elif output_format == 'ELA':
152
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
153
+
154
+ text = input + output
155
+ text = text.replace(" ", " ").strip()
156
+ if text.endswith("BECAUSE:"):
157
+ text = text.replace("BECAUSE:", "").strip()
158
+ return text
159
+
160
+
161
+
162
+ def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
163
+
164
+ input_format, output_format = format.split("-")
165
+
166
+ ## Inputs
167
+ if input_format == "CQM":
168
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
169
+ elif input_format == "QCM":
170
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
171
+ # upper bound experiment
172
+ elif input_format == "QCML":
173
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
174
+ elif input_format == "QCME":
175
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
176
+ elif input_format == "QCMLE":
177
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
178
+
179
+ elif input_format == "QCLM":
180
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
181
+ elif input_format == "QCEM":
182
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
183
+ elif input_format == "QCLEM":
184
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
185
+
186
+ # Outputs
187
+ if test_example:
188
+ output = "Answer:"
189
+ elif output_format == 'A':
190
+ output = f"Answer: The answer is {answer}."
191
+
192
+ elif output_format == 'AL':
193
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
194
+ elif output_format == 'AE':
195
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
196
+ elif output_format == 'ALE':
197
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
198
+ elif output_format == 'AEL':
199
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
200
+
201
+ elif output_format == 'LA':
202
+ output = f"Answer: {lecture} The answer is {answer}."
203
+ elif output_format == 'EA':
204
+ output = f"Answer: {solution} The answer is {answer}."
205
+ elif output_format == 'LEA':
206
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
207
+ elif output_format == 'ELA':
208
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
209
+
210
+ input = input.replace(" ", " ").strip()
211
+ output = output.replace(" ", " ").strip()
212
+ if output.endswith("BECAUSE:"):
213
+ output = output.replace("BECAUSE:", "").strip()
214
+
215
+ user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
216
+ assistant_prompt = {"role": "assistant", "content": f"{output}"}
217
+
218
+ return user_prompt, assistant_prompt
219
+
220
+
221
+ def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False):
222
+ examples = {}
223
+
224
+ for qid in shot_qids:
225
+ question = get_question_text(problems[qid])
226
+ context = get_context_text(problems[qid], use_caption)
227
+ choice = get_choice_text(problems[qid], options)
228
+ answer = get_answer(problems[qid], options)
229
+ lecture = get_lecture_text(problems[qid]).replace('\\n', '\n')
230
+ solution = get_solution_text(problems[qid]).replace('\\n', '\n')
231
+
232
+ train_example = create_one_example_chatbot(prompt_format,
233
+ question,
234
+ context,
235
+ choice,
236
+ answer,
237
+ lecture,
238
+ solution,
239
+ test_example=is_test)
240
+ examples[qid] = train_example
241
+ return examples
242
+
243
+
244
+ def build_prompt(problems, shot_qids, test_qid, args):
245
+
246
+ examples = []
247
+
248
+ # n-shot training examples
249
+ for qid in shot_qids:
250
+ question = get_question_text(problems[qid])
251
+ context = get_context_text(problems[qid], args.use_caption)
252
+ choice = get_choice_text(problems[qid], args.options)
253
+ answer = get_answer(problems[qid], args.options)
254
+ lecture = get_lecture_text(problems[qid])
255
+ solution = get_solution_text(problems[qid])
256
+
257
+ train_example = create_one_example(args.prompt_format,
258
+ question,
259
+ context,
260
+ choice,
261
+ answer,
262
+ lecture,
263
+ solution,
264
+ test_example=False)
265
+ examples.append(train_example)
266
+
267
+ # test example
268
+ question = get_question_text(problems[test_qid])
269
+ context = get_context_text(problems[test_qid], args.use_caption)
270
+ choice = get_choice_text(problems[test_qid], args.options)
271
+ answer = get_answer(problems[test_qid], args.options)
272
+ lecture = get_lecture_text(problems[test_qid])
273
+ solution = get_solution_text(problems[test_qid])
274
+
275
+ test_example = create_one_example(args.prompt_format,
276
+ question,
277
+ context,
278
+ choice,
279
+ answer,
280
+ lecture,
281
+ solution,
282
+ test_example=True)
283
+ examples.append(test_example)
284
+
285
+ # create the prompt input
286
+ prompt_input = '\n\n'.join(examples)
287
+
288
+ return prompt_input
289
+
290
+
291
+ def build_prompt_gpt4(problems, shot_qids, test_qid, args):
292
+
293
+ prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
294
+
295
+ # n-shot training examples
296
+ for qid in shot_qids:
297
+ question = get_question_text(problems[qid])
298
+ context = get_context_text(problems[qid], args.use_caption)
299
+ choice = get_choice_text(problems[qid], args.options)
300
+ answer = get_answer(problems[qid], args.options)
301
+ lecture = get_lecture_text(problems[qid])
302
+ solution = get_solution_text(problems[qid])
303
+
304
+ user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
305
+ question,
306
+ context,
307
+ choice,
308
+ answer,
309
+ lecture,
310
+ solution,
311
+ test_example=False)
312
+ prompt_array.append(user_prompt)
313
+ prompt_array.append(assistant_prompt)
314
+
315
+ # test example
316
+ question = get_question_text(problems[test_qid])
317
+ context = get_context_text(problems[test_qid], args.use_caption)
318
+ choice = get_choice_text(problems[test_qid], args.options)
319
+ answer = get_answer(problems[test_qid], args.options)
320
+ lecture = get_lecture_text(problems[test_qid])
321
+ solution = get_solution_text(problems[test_qid])
322
+
323
+ user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
324
+ question,
325
+ context,
326
+ choice,
327
+ answer,
328
+ lecture,
329
+ solution,
330
+ test_example=True)
331
+ prompt_array.append(user_prompt)
332
+ prompt_array.append(assistant_prompt)
333
+
334
+ return prompt_array
LLaVA/scripts/finetune_qlora.sh ADDED
@@ -0,0 +1,50 @@
1
+ #!/bin/bash
2
+
3
+ # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4
+
5
+ # Uncomment and set the following variables correspondingly to run this script:
6
+
7
+ ################## VICUNA ##################
8
+ # PROMPT_VERSION=v1
9
+ # MODEL_VERSION="vicuna-v1-3-7b"
10
+ ################## VICUNA ##################
11
+
12
+ ################## LLaMA-2 ##################
13
+ # PROMPT_VERSION="llava_llama_2"
14
+ # MODEL_VERSION="llama-2-7b-chat"
15
+ ################## LLaMA-2 ##################
16
+
17
+ deepspeed llava/train/train_mem.py \
18
+ --deepspeed ./scripts/zero2.json \
19
+ --lora_enable True \
20
+ --bits 4 \
21
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
22
+ --version $PROMPT_VERSION \
23
+ --data_path ./playground/data/llava_instruct_80k.json \
24
+ --image_folder /path/to/coco/train2017 \
25
+ --vision_tower openai/clip-vit-large-patch14 \
26
+ --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
27
+ --mm_vision_select_layer -2 \
28
+ --mm_use_im_start_end False \
29
+ --mm_use_im_patch_token False \
30
+ --bf16 True \
31
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
32
+ --num_train_epochs 1 \
33
+ --per_device_train_batch_size 16 \
34
+ --per_device_eval_batch_size 4 \
35
+ --gradient_accumulation_steps 1 \
36
+ --evaluation_strategy "no" \
37
+ --save_strategy "steps" \
38
+ --save_steps 50000 \
39
+ --save_total_limit 1 \
40
+ --learning_rate 2e-5 \
41
+ --weight_decay 0. \
42
+ --warmup_ratio 0.03 \
43
+ --lr_scheduler_type "cosine" \
44
+ --logging_steps 1 \
45
+ --tf32 True \
46
+ --model_max_length 2048 \
47
+ --gradient_checkpointing True \
48
+ --lazy_preprocess True \
49
+ --dataloader_num_workers 4 \
50
+ --report_to wandb
LLaVA/scripts/pretrain.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
2
+
3
+ # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4
+
5
+ # Uncomment and set the following variables correspondingly to run this script:
6
+
7
+ # MODEL_VERSION=vicuna-v1-3-7b
8
+ # MODEL_VERSION=llama-2-7b-chat
9
+
10
+ ########### DO NOT CHANGE ###########
11
+ ########### USE THIS FOR BOTH ###########
12
+ PROMPT_VERSION=plain
13
+ ########### DO NOT CHANGE ###########
14
+
15
+ deepspeed llava/train/train_mem.py \
16
+ --deepspeed ./scripts/zero2.json \
17
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
18
+ --version $PROMPT_VERSION \
19
+ --data_path /path/to/pretrain_data.json \
20
+ --image_folder /path/to/images \
21
+ --vision_tower openai/clip-vit-large-patch14 \
22
+ --tune_mm_mlp_adapter True \
23
+ --mm_vision_select_layer -2 \
24
+ --mm_use_im_start_end False \
25
+ --mm_use_im_patch_token False \
26
+ --bf16 True \
27
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
28
+ --num_train_epochs 1 \
29
+ --per_device_train_batch_size 16 \
30
+ --per_device_eval_batch_size 4 \
31
+ --gradient_accumulation_steps 1 \
32
+ --evaluation_strategy "no" \
33
+ --save_strategy "steps" \
34
+ --save_steps 24000 \
35
+ --save_total_limit 1 \
36
+ --learning_rate 2e-3 \
37
+ --weight_decay 0. \
38
+ --warmup_ratio 0.03 \
39
+ --lr_scheduler_type "cosine" \
40
+ --logging_steps 1 \
41
+ --tf32 True \
42
+ --model_max_length 2048 \
43
+ --gradient_checkpointing True \
44
+ --dataloader_num_workers 4 \
45
+ --lazy_preprocess True \
46
+ --report_to wandb
LLaVA/scripts/zero2.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 2,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto"
22
+ }
23
+ }
sglang/.github/ISSUE_TEMPLATE/2-feature-request.yml ADDED
@@ -0,0 +1,23 @@
1
+ name: 🚀 Feature request
2
+ description: Suggest an idea for this project
3
+ title: "[Feature] "
4
+
5
+ body:
6
+ - type: checkboxes
7
+ attributes:
8
+ label: Checklist
9
+ options:
10
+ - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
11
+ - label: 2. Please use English, otherwise it will be closed.
12
+ - type: textarea
13
+ attributes:
14
+ label: Motivation
15
+ description: |
16
+ A clear and concise description of the motivation of the feature.
17
+ validations:
18
+ required: true
19
+ - type: textarea
20
+ attributes:
21
+ label: Related resources
22
+ description: |
23
+ If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
sglang/.github/workflows/close-inactive-issues.yml ADDED
@@ -0,0 +1,96 @@
1
+ name: Close Inactive Issues
2
+
3
+ on:
4
+ schedule:
5
+ - cron: '0 0 * * *'
6
+ workflow_dispatch:
7
+
8
+ permissions:
9
+ issues: write
10
+ contents: read
11
+
12
+ jobs:
13
+ close-inactive-issues:
14
+ if: github.repository == 'sgl-project/sglang'
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - name: Check and close inactive issues
18
+ uses: actions/github-script@v6
19
+ with:
20
+ github-token: ${{secrets.GITHUB_TOKEN}}
21
+ script: |
22
+ const sixtyDaysAgo = new Date(Date.now() - 60 * 24 * 60 * 60 * 1000);
23
+
24
+ const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/');
25
+ console.log(`Owner: ${owner}, Repo: ${repo}`);
26
+
27
+ async function fetchIssues(page = 1) {
28
+ console.log(`Fetching issues for ${owner}/${repo}, page ${page}`);
29
+ return await github.rest.issues.listForRepo({
30
+ owner,
31
+ repo,
32
+ state: 'open',
33
+ sort: 'updated',
34
+ direction: 'asc',
35
+ per_page: 100,
36
+ page: page
37
+ });
38
+ }
39
+
40
+ async function processIssues() {
41
+ console.log('Starting to process issues');
42
+ console.log(`Repository: ${owner}/${repo}`);
43
+
44
+ let page = 1;
45
+ let hasMoreIssues = true;
46
+ while (hasMoreIssues) {
47
+ try {
48
+ const issues = await fetchIssues(page);
49
+ console.log(`Fetched ${issues.data.length} issues on page ${page}`);
50
+
51
+ if (issues.data.length === 0) {
52
+ hasMoreIssues = false;
53
+ break;
54
+ }
55
+
56
+ for (const issue of issues.data) {
57
+ // Skip if the issue has 'good first issue' label
58
+ if (issue.labels.some(label => label.name === 'good first issue')) {
59
+ console.log(`Skipping issue #${issue.number} as it's marked as 'good first issue'`);
60
+ continue;
61
+ }
62
+ if (new Date(issue.updated_at) < sixtyDaysAgo) {
63
+ try {
64
+ await github.rest.issues.update({
65
+ owner,
66
+ repo,
67
+ issue_number: issue.number,
68
+ state: 'closed',
69
+ labels: [...issue.labels.map(l => l.name), 'inactive']
70
+ });
71
+ await github.rest.issues.createComment({
72
+ owner,
73
+ repo,
74
+ issue_number: issue.number,
75
+ body: 'This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.'
76
+ });
77
+ console.log(`Closed issue #${issue.number} due to inactivity.`);
78
+ } catch (error) {
79
+ console.error(`Failed to close issue #${issue.number}: ${error.message}`);
80
+ }
81
+ } else {
82
+ console.log(`Issue #${issue.number} is still active. Stopping processing.`);
83
+ hasMoreIssues = false;
84
+ break;
85
+ }
86
+ }
87
+ page += 1;
88
+ } catch (error) {
89
+ console.error(`Error fetching issues on page ${page}: ${error.message}`);
90
+ hasMoreIssues = false;
91
+ }
92
+ }
93
+ console.log('Finished processing issues');
94
+ }
95
+
96
+ await processIssues();
sglang/.github/workflows/execute-notebook.yml ADDED
@@ -0,0 +1,49 @@
1
+ name: Execute Notebooks
2
+
3
+ on:
4
+ push:
5
+ branches: [ main ]
6
+ paths:
7
+ - "python/sglang/**"
8
+ - "docs/**"
9
+ pull_request:
10
+ branches: [ main ]
11
+ paths:
12
+ - "python/sglang/**"
13
+ - "docs/**"
14
+ workflow_dispatch:
15
+
16
+
17
+ concurrency:
18
+ group: execute-notebook-${{ github.ref }}
19
+ cancel-in-progress: true
20
+
21
+
22
+ jobs:
23
+ run-all-notebooks:
24
+ runs-on: 1-gpu-runner
25
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
26
+ steps:
27
+ - name: Checkout code
28
+ uses: actions/checkout@v3
29
+
30
+ - name: Set up Python
31
+ uses: actions/setup-python@v4
32
+ with:
33
+ python-version: '3.9'
34
+
35
+ - name: Install dependencies
36
+ run: |
37
+ bash scripts/ci_install_dependency.sh
38
+ pip install -r docs/requirements.txt
39
+
40
+ - name: Setup Jupyter Kernel
41
+ run: |
42
+ python -m ipykernel install --user --name python3 --display-name "Python 3"
43
+
44
+ - name: Execute notebooks
45
+ timeout-minutes: 30
46
+ run: |
47
+ cd docs
48
+ make clean
49
+ make compile
sglang/.github/workflows/lint.yml ADDED
@@ -0,0 +1,22 @@
1
+ name: Lint
2
+
3
+ on: [pull_request]
4
+
5
+ jobs:
6
+ lint:
7
+ runs-on: ubuntu-latest
8
+ steps:
9
+ - uses: actions/checkout@v2
10
+
11
+ - name: Set up Python
12
+ uses: actions/setup-python@v4
13
+ with:
14
+ python-version: '3.9'
15
+
16
+ - name: Install pre-commit hook
17
+ run: |
18
+ python -m pip install pre-commit
19
+ pre-commit install
20
+
21
+ - name: Linting
22
+ run: pre-commit run --all-files
sglang/.github/workflows/nightly-test.yml ADDED
@@ -0,0 +1,34 @@
1
+ name: Nightly Test
2
+
3
+ on:
4
+ schedule:
5
+ - cron: '0 0 * * *'
6
+ push:
7
+ branches:
8
+ - main
9
+ paths:
10
+ - "python/sglang/version.py"
11
+ workflow_dispatch:
12
+
13
+ concurrency:
14
+ group: nightly-test-${{ github.ref }}
15
+ cancel-in-progress: true
16
+
17
+ jobs:
18
+ nightly-test:
19
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
20
+ runs-on: 2-gpu-runner
21
+ steps:
22
+ - name: Checkout code
23
+ uses: actions/checkout@v3
24
+
25
+ - name: Install dependencies
26
+ run: |
27
+ bash scripts/ci_install_dependency.sh
28
+ pip install --upgrade "evalplus[vllm] @ git+https://github.com/evalplus/evalplus"
29
+
30
+ - name: Run test
31
+ timeout-minutes: 120
32
+ run: |
33
+ cd test/srt
34
+ python3 run_suite.py --suite nightly --timeout-per-file 2400
sglang/.github/workflows/pr-test.yml ADDED
@@ -0,0 +1,270 @@
1
+ name: PR Test
2
+
3
+ on:
4
+ push:
5
+ branches: [ main ]
6
+ paths:
7
+ - "python/sglang/**"
8
+ - "test/**"
9
+ pull_request:
10
+ branches: [ main ]
11
+ paths:
12
+ - "python/sglang/**"
13
+ - "test/**"
14
+ workflow_dispatch:
15
+ inputs:
16
+ version:
17
+ description: "FlashInfer version"
18
+ required: true
19
+ type: choice
20
+ default: 'release'
21
+ options:
22
+ - 'release'
23
+ - 'nightly'
24
+
25
+ concurrency:
26
+ group: pr-test-${{ github.ref }}
27
+ cancel-in-progress: true
28
+
29
+ jobs:
30
+
31
+ unit-test-frontend:
32
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
33
+ runs-on: 1-gpu-runner
34
+ steps:
35
+ - name: Checkout code
36
+ uses: actions/checkout@v3
37
+
38
+ - name: Install dependencies
39
+ env:
40
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
41
+ run: |
42
+ bash scripts/ci_install_dependency.sh
43
+
44
+ - name: Run test
45
+ timeout-minutes: 10
46
+ run: |
47
+ cd test/lang
48
+ python3 run_suite.py --suite per-commit
49
+
50
+ unit-test-backend-1-gpu:
51
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
52
+ runs-on: 1-gpu-runner
53
+ strategy:
54
+ matrix:
55
+ range: [0-6, 6-16, 16-23, 23-30, 30-100]
56
+ steps:
57
+ - name: Checkout code
58
+ uses: actions/checkout@v3
59
+
60
+ - name: Install dependencies
61
+ env:
62
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
63
+ run: |
64
+ bash scripts/ci_install_dependency.sh
65
+
66
+ - name: Run test
67
+ timeout-minutes: 25
68
+ run: |
69
+ cd test/srt
70
+ RANGE=${{ matrix.range }}
71
+ range_begin=${RANGE%-*}
72
+ range_end=${RANGE#*-}
73
+ python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}
74
+
75
+ unit-test-backend-2-gpu:
76
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
77
+ runs-on: 2-gpu-runner
78
+ steps:
79
+ - name: Checkout code
80
+ uses: actions/checkout@v3
81
+
82
+ - name: Install dependencies
83
+ env:
84
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
85
+ run: |
86
+ bash scripts/ci_install_dependency.sh
87
+
88
+ - name: Evaluate data parallelism accuracy (DP=2)
89
+ timeout-minutes: 10
90
+ run: |
91
+ cd test/srt
92
+ python3 test_data_parallelism.py
93
+
94
+ - name: Evaluate MLA accuracy (TP=2)
95
+ timeout-minutes: 10
96
+ run: |
97
+ cd test/srt
98
+ python3 test_mla.py
99
+ python3 test_mla_fp8.py
100
+ python3 test_dp_attention.py
101
+
102
+ - name: Test update weights from distributed
103
+ timeout-minutes: 10
104
+ run: |
105
+ cd test/srt
106
+ python3 test_update_weights_from_distributed.py
107
+
108
+ - name: Evaluate MoE EP accuracy (TP=2)
109
+ timeout-minutes: 10
110
+ run: |
111
+ cd test/srt
112
+ python3 test_moe_ep.py
113
+
114
+ performance-test-1-gpu-part-1:
115
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
116
+ runs-on: 1-gpu-runner
117
+ steps:
118
+ - name: Checkout code
119
+ uses: actions/checkout@v3
120
+
121
+ - name: Install dependencies
122
+ env:
123
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
124
+ run: |
125
+ bash scripts/ci_install_dependency.sh
126
+
127
+ - name: Benchmark single latency
128
+ timeout-minutes: 10
129
+ run: |
130
+ cd test/srt
131
+ python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default
132
+
133
+ - name: Benchmark online latency
134
+ timeout-minutes: 10
135
+ run: |
136
+ cd test/srt
137
+ python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_default
138
+
139
+ - name: Benchmark offline throughput
140
+ timeout-minutes: 10
141
+ run: |
142
+ cd test/srt
143
+ python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
144
+
145
+ - name: Benchmark offline throughput (Non-streaming, small batch size)
146
+ timeout-minutes: 10
147
+ run: |
148
+ cd test/srt
149
+ python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
150
+
151
+ performance-test-1-gpu-part-2:
152
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
153
+ runs-on: 1-gpu-runner
154
+ steps:
155
+ - name: Checkout code
156
+ uses: actions/checkout@v3
157
+
158
+ - name: Install dependencies
159
+ env:
160
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
161
+ run: |
162
+ bash scripts/ci_install_dependency.sh
163
+
164
+ - name: Benchmark offline throughput (w/o RadixAttention)
165
+ timeout-minutes: 10
166
+ run: |
167
+ cd test/srt
168
+ python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache
169
+
170
+ - name: Benchmark offline throughput (w/ Triton)
171
+ timeout-minutes: 10
172
+ run: |
173
+ cd test/srt
174
+ python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend
175
+
176
+ - name: Benchmark offline throughput (w/ FP8)
177
+ timeout-minutes: 10
178
+ run: |
179
+ cd test/srt
180
+ python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8
181
+
182
+ performance-test-2-gpu:
183
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
184
+ runs-on: 2-gpu-runner
185
+ steps:
186
+ - name: Checkout code
187
+ uses: actions/checkout@v3
188
+
189
+ - name: Install dependencies
190
+ env:
191
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
192
+ run: |
193
+ bash scripts/ci_install_dependency.sh
194
+
195
+ - name: Benchmark single latency (TP=2)
196
+ timeout-minutes: 10
197
+ run: |
198
+ cd test/srt
199
+ python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
200
+
201
+ - name: Benchmark offline throughput (TP=2)
202
+ timeout-minutes: 10
203
+ run: |
204
+ cd test/srt
205
+ python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default
206
+
207
+ - name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
208
+ timeout-minutes: 10
209
+ run: |
210
+ cd test/srt
211
+ python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
212
+
213
+ accuracy-test-1-gpu:
214
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
215
+ runs-on: 1-gpu-runner
216
+ steps:
217
+ - name: Checkout code
218
+ uses: actions/checkout@v3
219
+
220
+ - name: Install dependencies
221
+ env:
222
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
223
+ run: |
224
+ bash scripts/ci_install_dependency.sh
225
+
226
+ git clone https://github.com/merrymercy/human-eval.git
227
+ cd human-eval
228
+ pip install -e .
229
+
230
+ - name: Evaluate accuracy
231
+ timeout-minutes: 20
232
+ run: |
233
+ cd test/srt
234
+ python3 test_eval_accuracy_large.py
235
+
236
+
237
+ accuracy-test-2-gpu:
238
+ if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
239
+ runs-on: 2-gpu-runner
240
+ steps:
241
+ - name: Checkout code
242
+ uses: actions/checkout@v3
243
+
244
+ - name: Install dependencies
245
+ env:
246
+ FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
247
+ run: |
248
+ bash scripts/ci_install_dependency.sh
249
+
250
+ git clone https://github.com/merrymercy/human-eval.git
251
+ cd human-eval
252
+ pip install -e .
253
+
254
+ - name: Evaluate accuracy (TP=2)
255
+ timeout-minutes: 20
256
+ run: |
257
+ cd test/srt
258
+ python3 test_moe_eval_accuracy_large.py
259
+
260
+
261
+ finish:
262
+ needs: [
263
+ unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
264
+ performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
265
+ accuracy-test-1-gpu, accuracy-test-2-gpu
266
+ ]
267
+ runs-on: ubuntu-latest
268
+ steps:
269
+ - name: Finish
270
+ run: echo "This is an empty step to ensure that all jobs are completed."
sglang/.github/workflows/release-docker-dev.yml ADDED
@@ -0,0 +1,35 @@
1
+ name: Build Development Docker Image
2
+
3
+ on:
4
+ workflow_dispatch:
5
+ schedule:
6
+ - cron: '0 0 * * *'
7
+
8
+ jobs:
9
+ build-dev:
10
+ runs-on: ubuntu-22.04
11
+ steps:
12
+ - name: Checkout repository
13
+ uses: actions/checkout@v3
14
+
15
+ - name: Free disk space
16
+ uses: jlumbroso/free-disk-space@main
17
+ with:
18
+ tool-cache: false
19
+ docker-images: false
20
+ android: true
21
+ dotnet: true
22
+ haskell: true
23
+ large-packages: true
24
+ swap-storage: false
25
+
26
+ - name: Login to Docker Hub
27
+ uses: docker/login-action@v2
28
+ with:
29
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
30
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
31
+
32
+ - name: Build and Push Dev Image
33
+ run: |
34
+ docker build . -f docker/Dockerfile.dev -t lmsysorg/sglang:dev --no-cache
35
+ docker push lmsysorg/sglang:dev
sglang/.github/workflows/release-docker.yml ADDED
@@ -0,0 +1,64 @@
1
+ name: Release Docker Images
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ paths:
7
+ - "python/sglang/version.py"
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ publish:
12
+ if: github.repository == 'sgl-project/sglang'
13
+ runs-on: ubuntu-latest
14
+ environment: 'prod'
15
+ strategy:
16
+ matrix:
17
+ cuda_version: ['11.8.0', '12.1.1', '12.4.1']
18
+ build_type: ['all', 'srt']
19
+ steps:
20
+ - name: Delete huge unnecessary tools folder
21
+ run: rm -rf /opt/hostedtoolcache
22
+
23
+ - name: Checkout repository
24
+ uses: actions/checkout@v3
25
+
26
+ - name: Login to Docker Hub
27
+ uses: docker/login-action@v2
28
+ with:
29
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
30
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
31
+
32
+ - name: Build and Push
33
+ run: |
34
+ version=$(cat python/sglang/version.py | cut -d'"' -f2)
35
+
36
+ if [ "${{ matrix.cuda_version }}" = "11.8.0" ]; then
37
+ cuda_tag="cu118"
38
+ elif [ "${{ matrix.cuda_version }}" = "12.1.1" ]; then
39
+ cuda_tag="cu121"
40
+ elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then
41
+ cuda_tag="cu124"
42
+ else
43
+ echo "Unsupported CUDA version"
44
+ exit 1
45
+ fi
46
+
47
+ tag=v${version}-${cuda_tag}
48
+
49
+ if [ "${{ matrix.build_type }}" = "all" ]; then
50
+ tag_suffix=""
51
+ elif [ "${{ matrix.build_type }}" = "srt" ]; then
52
+ tag_suffix="-srt"
53
+ else
54
+ echo "Unsupported build type"
55
+ exit 1
56
+ fi
57
+
58
+ docker build . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} --build-arg BUILD_TYPE=${{ matrix.build_type }} -t lmsysorg/sglang:${tag}${tag_suffix} --no-cache
59
+ docker push lmsysorg/sglang:${tag}${tag_suffix}
60
+
61
+ if [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then
62
+ docker tag lmsysorg/sglang:${tag}${tag_suffix} lmsysorg/sglang:latest${tag_suffix}
63
+ docker push lmsysorg/sglang:latest${tag_suffix}
64
+ fi
sglang/.github/workflows/release-pypi-kernel.yml ADDED
@@ -0,0 +1,41 @@
1
+ name: Release SGLang Kernel to PyPI
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - sgl-kernel/pyproject.toml
9
+ workflow_dispatch:
10
+
11
+ concurrency:
12
+ group: release-pypi-kernel-${{ github.ref }}
13
+ cancel-in-progress: true
14
+
15
+ jobs:
16
+ build-wheels:
17
+ runs-on: ubuntu-latest
18
+ strategy:
19
+ matrix:
20
+ python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
21
+ cuda-version: ['12.1']
22
+
23
+ steps:
24
+ - uses: actions/checkout@v4
25
+
26
+ - name: Set up Python ${{ matrix.python-version }}
27
+ uses: actions/setup-python@v5
28
+ with:
29
+ python-version: ${{ matrix.python-version }}
30
+
31
+ - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}
32
+ run: |
33
+ cd sgl-kernel
34
+ chmod +x ./build.sh
35
+ ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}"
36
+
37
+ - name: Upload to pypi
38
+ working-directory: sgl-kernel
39
+ run: |
40
+ pip install twine
41
+ python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }}
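If a kernel release fails in CI, the same wheel build can be reproduced locally. This is a sketch under the assumption that you run it from the repository root and that `sgl-kernel/build.sh` takes the Python and CUDA versions as positional arguments, exactly as it is invoked in the workflow above; it stops short of uploading anything.

```bash
#!/bin/bash
# Reproduce the sgl-kernel wheel build from release-pypi-kernel.yml locally.
cd sgl-kernel
chmod +x ./build.sh
./build.sh "3.10" "12.1"   # <python-version> <cuda-version>, as in the workflow matrix

# The workflow uploads whatever lands in dist/; just list it here.
ls -lh dist/
```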
sglang/.github/workflows/release-pypi.yml ADDED
@@ -0,0 +1,31 @@
1
+ name: Release PyPI
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ paths:
7
+ - "python/sglang/version.py"
8
+ workflow_dispatch:
9
+
10
+ jobs:
11
+ publish:
12
+ if: github.repository == 'sgl-project/sglang'
13
+ runs-on: ubuntu-latest
14
+ environment: 'prod'
15
+ steps:
16
+ - name: Set up Python
17
+ uses: actions/setup-python@v4
18
+ with:
19
+ python-version: '3.9'
20
+
21
+ - name: Checkout repository
22
+ uses: actions/checkout@v3
23
+
24
+ - name: Upload to pypi
25
+ run: |
26
+ cd python
27
+ cp ../README.md ../LICENSE .
28
+ pip install build
29
+ python3 -m build
30
+ pip install twine
31
+ python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }}
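The packaging half of this workflow can be dry-run locally to validate a release candidate; the sketch below repeats the workflow's build step but replaces the upload with `twine check`, so no PyPI token is required.

```bash
#!/bin/bash
# Dry run of the packaging step in release-pypi.yml (no upload).
cd python
cp ../README.md ../LICENSE .
pip install build twine
python3 -m build
python3 -m twine check dist/*
```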
sglang/3rdparty/amd/profiling/PROFILING.md ADDED
@@ -0,0 +1,425 @@
1
+ ## Profiling SGLang Infer System with AMD GPUs
2
+ This AppNote describes the SGLang profiling techniques, code augmentation, and running steps for systems with AMD Instinct GPUs; the same procedure may also work with Nvidia GPUs.
3
+ Examples and steps are provided in detail to make the results easy to reproduce and to help localize performance problems for optimization.
4
+ Two primary methods are covered:
5
+ - [RPD](https://github.com/ROCm/rocmProfileData.git)
6
+ - [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
7
+
8
+ ### Profiling SGLang Infer System with RPD Profiler
9
+ The RPD profiler is a low-overhead, cross-platform profiler, so the same RPD code augmentation works for profiling on ROCm/AMD GPUs as well as on CUDA/Nvidia GPUs. To run RPD profiling on the SGLang repository, use the scripts and patch files included in this directory and follow the steps below:
10
+ 1. Install RPD using install_rpd.sh, with rpd.patch applied during installation; both files are in this directory.
11
+
12
+ install_rpd.sh
13
+
14
+ ```bash
15
+ # download and install RPD
16
+ apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
17
+
18
+ # install rpd module
19
+ git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
20
+ cd rocmProfileData
21
+ git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
22
+ git apply rpd.patch
23
+ make && make install
24
+ cd rocpd_python && python setup.py install && cd ..
25
+ cd rpd_tracer && make clean;make install && python setup.py install && cd ..
26
+ ```
27
+
28
+ rpd.patch
29
+
30
+ ```bash
31
+ diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
32
+ index e9d9feb..b2e9e1a 100644
33
+ --- a/rpd_tracer/Makefile
34
+ +++ b/rpd_tracer/Makefile
35
+ @@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
36
+ $(info Building with roctracer)
37
+ RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
38
+ RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
39
+ - RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
40
+ + RPD_SRCS += RoctracerDataSource.cpp
41
+ RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
42
+ endif
43
+ ```
44
+ 2. Add the loadTracer.sh file included in this directory to /sglang/python/sglang.
45
+
46
+ loadTracer.sh
47
+
48
+ ```bash
49
+ #!/bin/bash
50
+ ################################################################################
51
+ # Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
52
+ #
53
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
54
+ # of this software and associated documentation files (the "Software"), to deal
55
+ # in the Software without restriction, including without limitation the rights
56
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
57
+ # copies of the Software, and to permit persons to whom the Software is
58
+ # furnished to do so, subject to the following conditions:
59
+ #
60
+ # The above copyright notice and this permission notice shall be included in
61
+ # all copies or substantial portions of the Software.
62
+ #
63
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
64
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
65
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
66
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
67
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
68
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
69
+ # THE SOFTWARE.
70
+ ################################################################################
71
+ OUTPUT_FILE="trace.rpd"
72
+
73
+ if [ "$1" = "-o" ] ; then
74
+ OUTPUT_FILE=$2
75
+ shift
76
+ shift
77
+ fi
78
+
79
+ if [ -e ${OUTPUT_FILE} ] ; then
80
+ rm ${OUTPUT_FILE}
81
+ fi
82
+
83
+ python3 -m rocpd.schema --create ${OUTPUT_FILE}
84
+ if [ $? != 0 ] ; then
85
+ echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
86
+ exit
87
+ fi
88
+
89
+ export RPDT_FILENAME=${OUTPUT_FILE}
90
+ export RPDT_AUTOSTART=0
91
+ LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
92
+ ```
93
+ 3. Apply the patch provided in this directory with "git apply rpd_profile_server_enable.patch" if the main purpose of profiling is to collect GPU kernel information along with limited CPU activity information.
94
+
95
+ #### Common Notes 1
96
+ Please note that although we run TP=8 in the example, the patch file purposely logs RPD profiling on only 2 ranks (i.e., tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can load a JSON file of at most 8GB for visualization. With 2 ranks logged in RPD profiling, we can still check for issues across ranks (e.g., load imbalance or NCCL issues), and at the same time we can log a relatively longer time window before the JSON file generated from the RPD file hits the 8GB size limit.
97
+
98
+ rpd_profile_server_enable.patch
99
+
100
+ ```bash
101
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
102
+ index 62d1ff9..9021c01 100644
103
+ --- a/python/sglang/srt/managers/scheduler.py
104
+ +++ b/python/sglang/srt/managers/scheduler.py
105
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
106
+ suppress_other_loggers,
107
+ )
108
+ from sglang.utils import get_exception_traceback
109
+ +from rpdTracerControl import rpdTracerControl
110
+ +rpdTracerControl.skipCreate()
111
+
112
+ logger = logging.getLogger(__name__)
113
+
114
+ @@ -245,6 +247,7 @@ class Scheduler:
115
+ ],
116
+ with_stack=True,
117
+ )
118
+ + self.rpd = rpdTracerControl()
119
+
120
+ @torch.inference_mode()
121
+ def event_loop(self):
122
+ @@ -1027,15 +1030,24 @@ class Scheduler:
123
+ def start_profile(self) -> None:
124
+ if self.profiler is None:
125
+ raise RuntimeError("Profiler is not enabled.")
126
+ - self.profiler.start()
127
+ + #self.profiler.start() #block pytorch profiler for rpd profiler enabling
128
+ + if self.tp_rank == 0 or self.tp_rank == 1:
129
+ + self.rpd.start()
130
+ + self.rpd.rangePush("", "rpd profile range", "")
131
+ + logger.info("rpd is enabled")
132
+
133
+ def stop_profile(self) -> None:
134
+ if self.profiler is None:
135
+ raise RuntimeError("Profiler is not enabled.")
136
+ - self.profiler.stop()
137
+ - self.profiler.export_chrome_trace(
138
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
139
+ - )
140
+ + #self.profiler.stop()
141
+ + #self.profiler.export_chrome_trace(
142
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
143
+ + #)
144
+ + if self.tp_rank ==0 or self.tp_rank ==1:
145
+ + self.rpd.rangePop()
146
+ + self.rpd.stop()
147
+ + self.rpd.flush()
148
+ + logger.info("rpd is done")
149
+ logger.info("Profiler is done")
150
+ ```
151
+
152
+ #### Advanced Debugging with RPD Profiler
153
+ Sometimes we want to use the RPD profiler to capture more CPU and Python activities in order to debug challenging issues (e.g., the root cause of load imbalance across GPU processes, or the root cause of bubbles). Only in such cases do we need to apply the patch with "git apply rpd_profile_server_enable_wCPU_activities.patch", which modifies 3 files.
154
+
155
+ rpd_profile_server_enable_wCPU_activities.patch
156
+
157
+ ```bash
158
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
159
+ index 62d1ff9..2edb427 100644
160
+ --- a/python/sglang/srt/managers/scheduler.py
161
+ +++ b/python/sglang/srt/managers/scheduler.py
162
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
163
+ suppress_other_loggers,
164
+ )
165
+ from sglang.utils import get_exception_traceback
166
+ +from rpdTracerControl import rpdTracerControl
167
+ +rpdTracerControl.skipCreate()
168
+
169
+ logger = logging.getLogger(__name__)
170
+
171
+ @@ -245,6 +247,7 @@ class Scheduler:
172
+ ],
173
+ with_stack=True,
174
+ )
175
+ + self.rpd = rpdTracerControl()
176
+
177
+ @torch.inference_mode()
178
+ def event_loop(self):
179
+ @@ -1027,15 +1030,26 @@ class Scheduler:
180
+ def start_profile(self) -> None:
181
+ if self.profiler is None:
182
+ raise RuntimeError("Profiler is not enabled.")
183
+ - self.profiler.start()
184
+ + #self.profiler.start()
185
+ + logger.info("torch profiler is disabled")
186
+ + if self.tp_rank == 0 or self.tp_rank == 1:
187
+ + self.rpd.setPythonTrace(True)
188
+ + self.rpd.start()
189
+ + self.rpd.rangePush("", "scheduler", "")
190
+ + logger.info("rpd is enabled inside scheduler profiling")
191
+
192
+ def stop_profile(self) -> None:
193
+ if self.profiler is None:
194
+ raise RuntimeError("Profiler is not enabled.")
195
+ - self.profiler.stop()
196
+ - self.profiler.export_chrome_trace(
197
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
198
+ - )
199
+ + #self.profiler.stop()
200
+ + #self.profiler.export_chrome_trace(
201
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
202
+ + #)
203
+ + if self.tp_rank ==0 or self.tp_rank ==1:
204
+ + self.rpd.rangePop()
205
+ + self.rpd.stop()
206
+ + self.rpd.flush()
207
+ + logger.info("rpd is done inside scheduler")
208
+ logger.info("Profiler is done")
209
+
210
+
211
+ diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
212
+ index 2621ccd..181df85 100644
213
+ --- a/python/sglang/srt/managers/tokenizer_manager.py
214
+ +++ b/python/sglang/srt/managers/tokenizer_manager.py
215
+ @@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
216
+ from sglang.srt.server_args import PortArgs, ServerArgs
217
+ from sglang.srt.utils import is_generation_model, is_multimodal_model
218
+
219
+ +from rpdTracerControl import rpdTracerControl
220
+ +rpdTracerControl.skipCreate()
221
+ +
222
+ +
223
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
224
+
225
+ logger = logging.getLogger(__name__)
226
+ @@ -514,10 +518,20 @@ class TokenizerManager:
227
+ self.send_to_scheduler.send_pyobj(req)
228
+
229
+ def start_profile(self):
230
+ + rpd = rpdTracerControl()
231
+ + rpd.setPythonTrace(True)
232
+ + rpd.start()
233
+ + rpd.rangePush("", "tokenizer_manager", "")
234
+ + logger.info("tokenizer_manager rpd profiling started!")
235
+ req = ProfileReq.START_PROFILE
236
+ self.send_to_scheduler.send_pyobj(req)
237
+
238
+ def stop_profile(self):
239
+ + rpd = rpdTracerControl()
240
+ + rpd.rangePop()
241
+ + rpd.stop()
242
+ + rpd.flush()
243
+ + logger.info("rpd profiling is done inside tokenizer_manager!")
244
+ req = ProfileReq.STOP_PROFILE
245
+ self.send_to_scheduler.send_pyobj(req)
246
+
247
+ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
248
+ index 7111c93..2bd722c 100644
249
+ --- a/python/sglang/srt/server.py
250
+ +++ b/python/sglang/srt/server.py
251
+ @@ -30,6 +30,8 @@ import threading
252
+ import time
253
+ from http import HTTPStatus
254
+ from typing import Dict, List, Optional, Union
255
+ +from rpdTracerControl import rpdTracerControl
256
+ +rpdTracerControl.skipCreate()
257
+
258
+ # Fix a bug of Python threading
259
+ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
260
+ @@ -152,6 +154,11 @@ async def flush_cache():
261
+ @app.post("/start_profile")
262
+ async def start_profile():
263
+ """Start profiling."""
264
+ + rpd = rpdTracerControl()
265
+ + rpd.setPythonTrace(True)
266
+ + rpd.start()
267
+ + rpd.rangePush("", "server rpd profile range", "")
268
+ + logger.info("rpd profiling started in server.py!")
269
+ tokenizer_manager.start_profile()
270
+ return Response(
271
+ content="Start profiling.\n",
272
+ @@ -164,6 +171,11 @@ async def start_profile():
273
+ async def stop_profile():
274
+ """Stop profiling."""
275
+ tokenizer_manager.stop_profile()
276
+ + rpd = rpdTracerControl()
277
+ + rpd.rangePop()
278
+ + rpd.stop()
279
+ + rpd.flush()
280
+ + logger.info("rpd profiling is done in server.py!")
281
+ return Response(
282
+ content="Stop profiling. This will take some time.\n",
283
+ status_code=200,
284
+ ```
285
+
286
+ 4. As an example for grok1 profiling, create a dummy_grok1 directory containing the config.json shown below, and copy this directory to the path given to "--model-path" if you want to use the example server.sh file provided.
287
+ ```bash
288
+ cat ../dummy_grok1/config.json
289
+ {
290
+ "architectures": [
291
+ "Grok1ModelForCausalLM"
292
+ ],
293
+ "embedding_multiplier_scale": 78.38367176906169,
294
+ "output_multiplier_scale": 0.5773502691896257,
295
+ "vocab_size": 131072,
296
+ "hidden_size": 6144,
297
+ "intermediate_size": 32768,
298
+ "max_position_embeddings": 8192,
299
+ "num_experts_per_tok": 2,
300
+ "num_local_experts": 8,
301
+ "num_attention_heads": 48,
302
+ "num_hidden_layers": 64,
303
+ "num_key_value_heads": 8,
304
+ "head_dim": 128,
305
+ "rms_norm_eps": 1e-05,
306
+ "rope_theta": 10000.0,
307
+ "model_type": "mixtral",
308
+ "torch_dtype": "bfloat16"
309
+ }
310
+ ```
311
+ 5. Launch the server with the RPD-enabled script ./server.sh in one terminal inside the Docker container.
312
+
313
+ #### Common Notes 2
314
+ - Remember to change model-path to the correct path
315
+ - loadTracer.sh is needed to conduct profiling
316
+ - SGLANG_TORCH_PROFILER_DIR is used for default torch profiler
317
+ - Do not use loadTracer.sh if you are using the torch profiler; simply use python3 -m sglang.launch_server.
318
+
319
+
320
+ server.sh
321
+
322
+ ```bash
323
+ #!/bin/bash
324
+
325
+ # export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
326
+ export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
327
+
328
+ # Get the current timestamp
329
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
330
+
331
+ # Define the log file with a timestamp
332
+ LOGFILE="sglang_server_log_$TIMESTAMP.json"
333
+
334
+ # Run the Python command and save the output to the log file
335
+ loadTracer.sh python3 -m sglang.launch_server \
336
+ --model-path /sgl-workspace/sglang/dummy_grok1 \
337
+ --tokenizer-path Xenova/grok-1-tokenizer \
338
+ --load-format dummy \
339
+ --quant fp8 \
340
+ --tp 8 \
341
+ --port 30000 \
342
+ --disable-radix-cache 2>&1 | tee "$LOGFILE"
343
+ ```
344
+ 6. Open another terminal attached to the same Docker container, and run the RPD-enabled ./client.sh after the "The server is fired up and is ready to roll!" message appears in the server terminal.
345
+
346
+ #### Common Notes 3
347
+ - Use curl http://localhost:30000/start_profile & curl http://localhost:30000/stop_profile to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
348
+ - Please don't use RPD profiler together with PyTorch profiler to avoid interference.
349
+ - The rocmProfileData/tools/rpd2tracing.py script is used to generate the JSON file from the RPD file.
350
+
351
+ client.sh
352
+
353
+ ```bash
354
+ #!/bin/bash
355
+
356
+ # Start profiling via API
357
+ curl http://localhost:30000/start_profile -H "Content-Type: application/json"
358
+
359
+ # Benchmark serving using sglang with random dataset and tokenizer
360
+ # Define the log file with a timestamp
361
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
362
+ LOGFILE="sglang_client_log_$TIMESTAMP.json"
363
+
364
+ # Run the benchmark with specified parameters and save logs
365
+ python3 -m sglang.bench_serving \
366
+ --backend sglang \
367
+ --tokenizer Xenova/grok-1-tokenizer \
368
+ --dataset-name random \
369
+ --random-input 1024\
370
+ --random-output 1024 \
371
+ --num-prompts 120 \
372
+ --request-rate 8 \
373
+ --output-file online.jsonl 2>&1 | tee "$LOGFILE"
374
+
375
+ # Stop profiling via API
376
+ curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
377
+
378
+ # Convert tracing file to csv & json
379
+ sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
380
+ python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
381
+ ```
382
+ 7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 9GB.
383
+
384
+ ### Profiling SGLang Infer System with PyTorch Profiler
385
+
386
+ Please follow the steps below:
387
+
388
+ 1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks to be recorded in profiling.
389
+
390
+ torch_profiler.patch
391
+ ```bash
392
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
393
+ index 62d1ff9..6ecd78c 100644
394
+ --- a/python/sglang/srt/managers/scheduler.py
395
+ +++ b/python/sglang/srt/managers/scheduler.py
396
+ @@ -240,7 +240,6 @@ class Scheduler:
397
+ )
398
+ self.profiler = torch.profiler.profile(
399
+ activities=[
400
+ - torch.profiler.ProfilerActivity.CPU,
401
+ torch.profiler.ProfilerActivity.CUDA,
402
+ ],
403
+ with_stack=True,
404
+ @@ -1033,9 +1032,11 @@ class Scheduler:
405
+ if self.profiler is None:
406
+ raise RuntimeError("Profiler is not enabled.")
407
+ self.profiler.stop()
408
+ - self.profiler.export_chrome_trace(
409
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
410
+ - )
411
+ + if self.tp_rank == 0:
412
+ + with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
413
+ + print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
414
+ + print("Profiling stats done.")
415
+ +
416
+ logger.info("Profiler is done")
417
+ ```
418
+
419
+ 2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided.
420
+
421
+ 3. Modify the included server.sh by removing "loadTracer.sh" before the python command, then launch the script ./server.sh in one terminal inside the Docker container.
422
+
423
+ 4. Similar to step 6 in the RPD profiling section, but remove the last 2 lines of client.sh, which convert the RPD file into CSV and JSON files. Run the modified client.sh for PyTorch profiling.
424
+ -------
425
+ - [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
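Before converting a large trace.rpd into JSON, it can be useful to inspect it directly in SQLite to see what dominates the capture; the sketch below assumes only the sqlite3 CLI and the `top` summary view that client.sh above already queries.

```bash
#!/bin/bash
# Peek at an RPD capture without generating the (potentially huge) trace.json.
TRACE=trace.rpd

# Print the first rows of the `top` summary view as CSV.
sqlite3 "$TRACE" ".mode csv" ".header on" "select * from top limit 20;"
```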
sglang/3rdparty/amd/profiling/client.sh ADDED
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+
3
+ # Start profiling via API
4
+ curl http://localhost:30000/start_profile -H "Content-Type: application/json"
5
+
6
+ # Benchmark serving using sglang with random dataset and tokenizer
7
+ # Define the log file with a timestamp
8
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
9
+ LOGFILE="sglang_client_log_$TIMESTAMP.json"
10
+
11
+ # Run the benchmark with specified parameters and save logs
12
+ python3 -m sglang.bench_serving \
13
+ --backend sglang \
14
+ --tokenizer Xenova/grok-1-tokenizer \
15
+ --dataset-name random \
16
+ --random-input 1024\
17
+ --random-output 1024 \
18
+ --num-prompts 240 \
19
+ --request-rate 8 \
20
+ --output-file online.jsonl 2>&1 | tee "$LOGFILE"
21
+
22
+ # Stop profiling via API
23
+ curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
24
+
25
+ # Convert tracing file to csv & json
26
+ sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
27
+ python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
sglang/3rdparty/amd/profiling/install_rpd.sh ADDED
@@ -0,0 +1,10 @@
1
+ # download and install RPD
2
+ apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
3
+
4
+ # install rpd module
5
+ git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
6
+ cd rocmProfileData
7
+ git apply rpd.patch
8
+ make && make install
9
+ cd rocpd_python && python setup.py install && cd ..
10
+ cd rpd_tracer && make clean;make install && python setup.py install && cd ..
sglang/3rdparty/amd/profiling/loadTracer.sh ADDED
@@ -0,0 +1,43 @@
1
+ #!/bin/bash
2
+ ################################################################################
3
+ # Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in
13
+ # all copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ # THE SOFTWARE.
22
+ ################################################################################
23
+ OUTPUT_FILE="trace.rpd"
24
+
25
+ if [ "$1" = "-o" ] ; then
26
+ OUTPUT_FILE=$2
27
+ shift
28
+ shift
29
+ fi
30
+
31
+ if [ -e ${OUTPUT_FILE} ] ; then
32
+ rm ${OUTPUT_FILE}
33
+ fi
34
+
35
+ python3 -m rocpd.schema --create ${OUTPUT_FILE}
36
+ if [ $? != 0 ] ; then
37
+ echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
38
+ exit
39
+ fi
40
+
41
+ export RPDT_FILENAME=${OUTPUT_FILE}
42
+ export RPDT_AUTOSTART=0
43
+ LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
sglang/3rdparty/amd/profiling/rpd.patch ADDED
@@ -0,0 +1,12 @@
1
+ diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
2
+ index e9d9feb..b2e9e1a 100644
3
+ --- a/rpd_tracer/Makefile
4
+ +++ b/rpd_tracer/Makefile
5
+ @@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
6
+ $(info Building with roctracer)
7
+ RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
8
+ RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
9
+ - RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
10
+ + RPD_SRCS += RoctracerDataSource.cpp
11
+ RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
12
+ endif
sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch ADDED
@@ -0,0 +1,49 @@
1
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
2
+ index 62d1ff9..9021c01 100644
3
+ --- a/python/sglang/srt/managers/scheduler.py
4
+ +++ b/python/sglang/srt/managers/scheduler.py
5
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
6
+ suppress_other_loggers,
7
+ )
8
+ from sglang.utils import get_exception_traceback
9
+ +from rpdTracerControl import rpdTracerControl
10
+ +rpdTracerControl.skipCreate()
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ @@ -245,6 +247,7 @@ class Scheduler:
15
+ ],
16
+ with_stack=True,
17
+ )
18
+ + self.rpd = rpdTracerControl()
19
+
20
+ @torch.inference_mode()
21
+ def event_loop(self):
22
+ @@ -1027,15 +1030,24 @@ class Scheduler:
23
+ def start_profile(self) -> None:
24
+ if self.profiler is None:
25
+ raise RuntimeError("Profiler is not enabled.")
26
+ - self.profiler.start()
27
+ + #self.profiler.start() #block pytorch profiler for rpd profiler enabling
28
+ + if self.tp_rank == 0 or self.tp_rank == 1:
29
+ + self.rpd.start()
30
+ + self.rpd.rangePush("", "rpd profile range", "")
31
+ + logger.info("rpd is enabled")
32
+
33
+ def stop_profile(self) -> None:
34
+ if self.profiler is None:
35
+ raise RuntimeError("Profiler is not enabled.")
36
+ - self.profiler.stop()
37
+ - self.profiler.export_chrome_trace(
38
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
39
+ - )
40
+ + #self.profiler.stop()
41
+ + #self.profiler.export_chrome_trace(
42
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
43
+ + #)
44
+ + if self.tp_rank ==0 or self.tp_rank ==1:
45
+ + self.rpd.rangePop()
46
+ + self.rpd.stop()
47
+ + self.rpd.flush()
48
+ + logger.info("rpd is done")
49
+ logger.info("Profiler is done")
sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch ADDED
@@ -0,0 +1,126 @@
1
+ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
2
+ index 62d1ff9..2edb427 100644
3
+ --- a/python/sglang/srt/managers/scheduler.py
4
+ +++ b/python/sglang/srt/managers/scheduler.py
5
+ @@ -71,6 +71,8 @@ from sglang.srt.utils import (
6
+ suppress_other_loggers,
7
+ )
8
+ from sglang.utils import get_exception_traceback
9
+ +from rpdTracerControl import rpdTracerControl
10
+ +rpdTracerControl.skipCreate()
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ @@ -245,6 +247,7 @@ class Scheduler:
15
+ ],
16
+ with_stack=True,
17
+ )
18
+ + self.rpd = rpdTracerControl()
19
+
20
+ @torch.inference_mode()
21
+ def event_loop(self):
22
+ @@ -1027,15 +1030,26 @@ class Scheduler:
23
+ def start_profile(self) -> None:
24
+ if self.profiler is None:
25
+ raise RuntimeError("Profiler is not enabled.")
26
+ - self.profiler.start()
27
+ + #self.profiler.start()
28
+ + logger.info("torch profiler is disabled")
29
+ + if self.tp_rank == 0 or self.tp_rank == 1:
30
+ + self.rpd.setPythonTrace(True)
31
+ + self.rpd.start()
32
+ + self.rpd.rangePush("", "scheduler", "")
33
+ + logger.info("rpd is enabled inside scheduler profiling")
34
+
35
+ def stop_profile(self) -> None:
36
+ if self.profiler is None:
37
+ raise RuntimeError("Profiler is not enabled.")
38
+ - self.profiler.stop()
39
+ - self.profiler.export_chrome_trace(
40
+ - self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
41
+ - )
42
+ + #self.profiler.stop()
43
+ + #self.profiler.export_chrome_trace(
44
+ + # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
45
+ + #)
46
+ + if self.tp_rank ==0 or self.tp_rank ==1:
47
+ + self.rpd.rangePop()
48
+ + self.rpd.stop()
49
+ + self.rpd.flush()
50
+ + logger.info("rpd is done inside scheduler")
51
+ logger.info("Profiler is done")
52
+
53
+
54
+ diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
55
+ index 2621ccd..181df85 100644
56
+ --- a/python/sglang/srt/managers/tokenizer_manager.py
57
+ +++ b/python/sglang/srt/managers/tokenizer_manager.py
58
+ @@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
59
+ from sglang.srt.server_args import PortArgs, ServerArgs
60
+ from sglang.srt.utils import is_generation_model, is_multimodal_model
61
+
62
+ +from rpdTracerControl import rpdTracerControl
63
+ +rpdTracerControl.skipCreate()
64
+ +
65
+ +
66
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
67
+
68
+ logger = logging.getLogger(__name__)
69
+ @@ -514,10 +518,20 @@ class TokenizerManager:
70
+ self.send_to_scheduler.send_pyobj(req)
71
+
72
+ def start_profile(self):
73
+ + rpd = rpdTracerControl()
74
+ + rpd.setPythonTrace(True)
75
+ + rpd.start()
76
+ + rpd.rangePush("", "tokenizer_manager", "")
77
+ + logger.info("tokenizer_manager rpd profiling started!")
78
+ req = ProfileReq.START_PROFILE
79
+ self.send_to_scheduler.send_pyobj(req)
80
+
81
+ def stop_profile(self):
82
+ + rpd = rpdTracerControl()
83
+ + rpd.rangePop()
84
+ + rpd.stop()
85
+ + rpd.flush()
86
+ + logger.info("rpd profiling is done inside tokenizer_manager!")
87
+ req = ProfileReq.STOP_PROFILE
88
+ self.send_to_scheduler.send_pyobj(req)
89
+
90
+ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
91
+ index 7111c93..2bd722c 100644
92
+ --- a/python/sglang/srt/server.py
93
+ +++ b/python/sglang/srt/server.py
94
+ @@ -30,6 +30,8 @@ import threading
95
+ import time
96
+ from http import HTTPStatus
97
+ from typing import Dict, List, Optional, Union
98
+ +from rpdTracerControl import rpdTracerControl
99
+ +rpdTracerControl.skipCreate()
100
+
101
+ # Fix a bug of Python threading
102
+ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
103
+ @@ -152,6 +154,11 @@ async def flush_cache():
104
+ @app.post("/start_profile")
105
+ async def start_profile():
106
+ """Start profiling."""
107
+ + rpd = rpdTracerControl()
108
+ + rpd.setPythonTrace(True)
109
+ + rpd.start()
110
+ + rpd.rangePush("", "server rpd profile range", "")
111
+ + logger.info("rpd profiling started in server.py!")
112
+ tokenizer_manager.start_profile()
113
+ return Response(
114
+ content="Start profiling.\n",
115
+ @@ -164,6 +171,11 @@ async def start_profile():
116
+ async def stop_profile():
117
+ """Stop profiling."""
118
+ tokenizer_manager.stop_profile()
119
+ + rpd = rpdTracerControl()
120
+ + rpd.rangePop()
121
+ + rpd.stop()
122
+ + rpd.flush()
123
+ + logger.info("rpd profiling is done in server.py!")
124
+ return Response(
125
+ content="Stop profiling. This will take some time.\n",
126
+ status_code=200,
sglang/3rdparty/amd/profiling/server.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash
2
+
3
+ # export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
4
+ export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
5
+
6
+ # Get the current timestamp
7
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
8
+
9
+ # Define the log file with a timestamp
10
+ LOGFILE="sglang_server_log_$TIMESTAMP.json"
11
+
12
+ # Run the Python command and save the output to the log file
13
+ loadTracer.sh python3 -m sglang.launch_server \
14
+ --model-path /sgl-workspace/sglang/dummy_grok1 \
15
+ --tokenizer-path Xenova/grok-1-tokenizer \
16
+ --load-format dummy \
17
+ --quant fp8 \
18
+ --tp 8 \
19
+ --port 30000 \
20
+ --disable-radix-cache 2>&1 | tee "$LOGFILE"
sglang/3rdparty/amd/tuning/TUNING.md ADDED
@@ -0,0 +1,118 @@
1
+ ## Tuning SGLang Infer System with AMD GPUs
2
+ This AppNote describes the SGLang performance tuning technical, code harness and running steps for systems with AMD Instinct GPUs.
3
+ Harness code, examples and steps are provided in detail, to facilitate easy reproduce & use to tune performance towards workloads.
4
+ Three primary runtime areas are covered:
5
+
6
+ ## 1. Triton Kernels
7
+ To maximize Triton kernel efficiency, several strategies can be employed:
8
+
9
+ ### Key Environment Variables:
10
+ - **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
11
+ - **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
12
+ - **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
13
+ - **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
14
+ - **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
15
+ ```python
16
+ @triton.autotune(configs=[
17
+ triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
18
+ triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
19
+ triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
20
+ triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
21
+ triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
22
+ triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
23
+ triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
24
+ triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
25
+ triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
26
+ ], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
27
+ @triton.jit
28
+ def _triton_kernel_funtion():
29
+ ...
30
+ ```
31
+ ## 2. Torch Tunable Operations
32
+ **TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
33
+
34
+ ### Key Environment Variables:
35
+ 1. **PYTORCH_TUNABLEOP_ENABLED**:
36
+ - Default: `0`
37
+ - Set to `1` to enable TunableOp.
38
+
39
+ 2. **PYTORCH_TUNABLEOP_TUNING**:
40
+ - Default: `1`
41
+ - Set to `0` to disable tuning. If a tuned entry is not found, it will run the tuning step and record the entry when PYTORCH_TUNABLEOP_ENABLED is enabled.
42
+
43
+ 3. **PYTORCH_TUNABLEOP_VERBOSE**:
44
+ - Default: `0`
45
+ - Set to `1` to enable verbose output for TunableOp.
46
+
47
+ ### Usage Example:
48
+ To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
49
+
50
+ ```bash
51
+ #Tuning
52
+ PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
53
+
54
+ #Inference with tuning op
55
+ PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
56
+
57
+ #Print out the log
58
+ PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
59
+
60
+ ```
61
+ ## 3. Torch Compilation
62
+
63
+
64
+ The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
65
+
66
+ To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
67
+
68
+ ### Key Configurations:
69
+ 1. **Max Autotune**:
70
+ - Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
71
+
72
+ 2. **Fine-Grained Control**:
73
+ - Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
74
+ - Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune.pointwise = True`.
75
+
76
+ 3. **Backend Selection**:
77
+ - Use `torch._inductor.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
78
+
79
+ 4. **Freezing for Inference**:
80
+ - Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
81
+
82
+ 5. **Debugging**:
83
+ - Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
84
+
85
+ ### Example Code Block:
86
+ ```bash
87
+ #Gemm Tuning
88
+ TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
89
+
90
+ #Specify your backend to TRITON for Gemm Tuning
91
+ TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
92
+
93
+ #Inference with large improvement on AMD GPU
94
+ TORCHINDUCTOR_FREEZING=1 your_script.sh
95
+ ```
96
+ ## 4. Fused MOE kernel
97
+ To maximize moe kernel efficiency, need to use below scripts to find out the best launch configuration
98
+
99
+ ### Key parameters:
100
+ - **--model**: what moe model type to do tuning, it will automatically decide the size of d_model, model_intermediate_size, num_layers
101
+ - **--tp-size**: simulate the whole model run configuration to set the dimension size using tp correctly
102
+ - **--batch**: M dimension size of moe kernel, for prefill moe kernel the value is batch*input_len, for decode moe kernel the value is batch
103
+ - **--dtype**: computation type
104
+
105
+ ```bash
106
+ #Tuning
107
+ #for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
108
+ #so we can tune decode moe use below command
109
+ python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
110
+ # and use this command to tune prefill moe
111
+ python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
112
+ ```
113
+
114
+ ## Reference
115
+
116
+ For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
117
+
118
+ [ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)
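In practice the switches from sections 1-3 are combined on a single launch command; the lines below are one hedged example of such a combination for an inference run, where `your_script.sh` is a placeholder for the actual launch command and every variable is one of those introduced above.

```bash
#!/bin/bash
# Example combination of the tuning switches described in sections 1-3.
export OPTIMIZE_EPILOGUE=1                               # Triton epilogue optimization (section 1)
export PYTORCH_TUNABLEOP_ENABLED=1                       # use TunableOp entries (section 2)
export PYTORCH_TUNABLEOP_TUNING=0                        # ...without re-tuning at runtime
export TORCHINDUCTOR_MAX_AUTOTUNE=1                      # Inductor max-autotune (section 3)
export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON   # restrict GEMM backends to Triton
export TORCHINDUCTOR_FREEZING=1                          # constant folding for inference

./your_script.sh                                         # placeholder for the actual launch command
```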
sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py ADDED
@@ -0,0 +1,377 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+ from tqdm import tqdm
11
+ from transformers import AutoConfig
12
+
13
+ from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name
14
+
15
+ padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
16
+
17
+
18
+ def main(model, tp_size, dtype: str, batches):
19
+ method = fused_moe
20
+
21
+ for bs in batches:
22
+ run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
23
+
24
+
25
+ def prune_configs(M, N, K, configs):
26
+ pruned_configs = []
27
+ elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
28
+ elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
29
+
30
+ mfma = 16 if M < 32 or N < 32 else 32
31
+
32
+ # TODO (zhanglx): figure out the boundary between large and small gemms
33
+ large_gemm = False
34
+ if M >= 2048 and N >= 2048:
35
+ large_gemm = True
36
+
37
+ for config in configs:
38
+ BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
39
+ BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
40
+ BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
41
+ num_warps = config.get("num_warps")
42
+ matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
43
+ # kpack = config.get("kpack")
44
+ if matrix_instr_nonkdim > mfma:
45
+ continue
46
+ if mfma == 4 and BLOCK_SIZE_K < 64:
47
+ continue
48
+ # some layouts could not work properly in case
49
+ # number elements per thread is less 1
50
+ if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
51
+ continue
52
+ SPLIT_K = 1 # config.get("SPLIT_K")
53
+ GROUP_M = config.get("GROUP_SIZE_M")
54
+ if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
55
+ continue
56
+ if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
57
+ continue
58
+ if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
59
+ continue
60
+ # Skip BLOCK_SIZE that is too large compare to M/N
61
+ # unless BLOCK_SIZE is already small enough
62
+ if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
63
+ continue
64
+ if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
65
+ continue
66
+ # skip large split_k when not necessary
67
+ if SPLIT_K != 1 and not need_split_k(M, N, K):
68
+ continue
69
+ # skip split_k that leads to EVEN_K = false
70
+ leap = SPLIT_K * BLOCK_SIZE_K
71
+ modv = K % leap
72
+ if modv != 0:
73
+ continue
74
+ # skip large GROUP_M
75
+ if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
76
+ continue
77
+ # out of shared memory resource
78
+ # TODO (zhanglx): This does not consider the LDS usage in the epilogue
79
+ LDS = (
80
+ BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
81
+ + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
82
+ )
83
+ if LDS > 65536:
84
+ continue
85
+ # Skip small block sizes and num_warps for large gemm
86
+ # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
87
+ if large_gemm:
88
+ if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
89
+ continue
90
+ if BLOCK_SIZE_K < 64:
91
+ continue
92
+ if num_warps < 4:
93
+ continue
94
+
95
+ pruned_configs.append(config)
96
+
97
+ return pruned_configs
98
+
99
+
100
+ def union_of_list_of_dicts(l1, l2):
101
+ result = []
102
+ temp_list = l1.copy()
103
+ temp_list.extend(l2)
104
+ for myDict in temp_list:
105
+ if myDict not in result:
106
+ result.append(myDict)
107
+
108
+ return result
109
+
110
+
111
+ def run_grid(bs, model, method, tp_size, dtype: str):
112
+
113
+ config = AutoConfig.from_pretrained(model)
114
+
115
+ top_k = config.num_experts_per_tok
116
+ d_model = config.hidden_size
117
+ model_intermediate_size = config.intermediate_size
118
+ num_layers = config.num_hidden_layers
119
+ hidden_states_dtype = config.torch_dtype
120
+
121
+ if config.num_experts_per_tok:
122
+ if config.architectures[0] == "Grok1ModelForCausalLM":
123
+ num_total_experts = config.num_experts
124
+ else:
125
+ num_total_experts = config.num_local_experts
126
+ else:
127
+ raise ValueError(f"Unsupported Mixtral model {model}")
128
+
129
+ # tp_size = 2
130
+ num_warmup_calls = 10
131
+ num_calls = 30
132
+
133
+ num_warmup_trials = 1
134
+ num_trials = 1
135
+
136
+ full_configs = []
137
+
138
+ block_m_range = [16, 32, 64, 128, 256]
139
+ block_n_range = [16, 32, 64, 128, 256]
140
+ block_k_range = [32, 64, 128, 256] # MUST >= 32
141
+ num_warps_range = [1, 2, 4, 8]
142
+ group_m_range = [1, 4, 8, 16, 32]
143
+ # For now we see better perf with num_stages=0 for all gemm configs we care
144
+ # But keep this explicit so that we do not forget we may need to set it to
145
+ # other values in the future
146
+ num_stage_range = [2]
147
+ waves_per_eu_range = [0, 1, 2, 4, 8]
148
+ # Remove 32 because of triton compiling error
149
+ matrix_instr_nonkdim_range = [16]
150
+ kpack_range = [1, 2]
151
+
152
+ for block_size_m in block_m_range:
153
+ for block_size_n in block_n_range:
154
+ for block_size_k in block_k_range:
155
+ for group_size_m in group_m_range:
156
+ for num_warps in num_warps_range:
157
+ for num_stages in num_stage_range:
158
+ for waves_per_eu in waves_per_eu_range:
159
+ for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
160
+ for kpack in kpack_range:
161
+ full_configs.append(
162
+ {
163
+ "BLOCK_SIZE_M": block_size_m,
164
+ "BLOCK_SIZE_N": block_size_n,
165
+ "BLOCK_SIZE_K": block_size_k,
166
+ "GROUP_SIZE_M": group_size_m,
167
+ "num_warps": num_warps,
168
+ "num_stages": num_stages,
169
+ "waves_per_eu": waves_per_eu,
170
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
171
+ "kpack": kpack,
172
+ }
173
+ )
174
+
175
+ M1 = bs * 2
176
+ N1 = model_intermediate_size * 2 // tp_size
177
+ K1 = d_model
178
+ prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
179
+
180
+ M2 = bs * 2
181
+ N2 = d_model
182
+ K2 = model_intermediate_size // tp_size
183
+ prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
184
+
185
+ configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
186
+
187
+ print(
188
+ f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
189
+ {len(prune_configs_2)=} | {len(configs)=}"
190
+ )
191
+
192
+ best_config = None
193
+ best_time_us = 1e20
194
+
195
+ print(f"{tp_size=} {bs=}")
196
+
197
+ for config in tqdm(configs):
198
+ # warmup
199
+ try:
200
+ print(config)
201
+ for _ in range(num_warmup_trials):
202
+ run_timing(
203
+ num_calls=num_warmup_calls,
204
+ bs=bs,
205
+ d_model=d_model,
206
+ num_total_experts=num_total_experts,
207
+ top_k=top_k,
208
+ tp_size=tp_size,
209
+ model_intermediate_size=model_intermediate_size,
210
+ method=method,
211
+ config=config,
212
+ dtype=dtype,
213
+ hidden_states_dtype=hidden_states_dtype,
214
+ )
215
+ except triton.runtime.autotuner.OutOfResources:
216
+ continue
217
+
218
+ # trial
219
+ for _ in range(num_trials):
220
+ kernel_dur_ms = run_timing(
221
+ num_calls=num_calls,
222
+ bs=bs,
223
+ d_model=d_model,
224
+ num_total_experts=num_total_experts,
225
+ top_k=top_k,
226
+ tp_size=tp_size,
227
+ model_intermediate_size=model_intermediate_size,
228
+ method=method,
229
+ config=config,
230
+ dtype=dtype,
231
+ hidden_states_dtype=hidden_states_dtype,
232
+ )
233
+
234
+ kernel_dur_us = 1000 * kernel_dur_ms
235
+ model_dur_ms = kernel_dur_ms * num_layers
236
+
237
+ if kernel_dur_us < best_time_us:
238
+ best_config = config
239
+ best_time_us = kernel_dur_us
240
+
241
+ tqdm.write(
242
+ f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
243
+ f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
244
+ f"{d_model=} {model_intermediate_size=} {num_layers=}"
245
+ )
246
+
247
+ print("best_time_us", best_time_us)
248
+ print("best_config", best_config)
249
+
250
+ # holds Dict[str, Dict[str, int]]
251
+ filename = get_config_file_name(
252
+ num_total_experts,
253
+ model_intermediate_size // tp_size,
254
+ "float8" if dtype == "float8" else None,
255
+ )
256
+ print(f"writing config to file {filename}")
257
+ existing_content = {}
258
+ if os.path.exists(filename):
259
+ with open(filename, "r") as f:
260
+ existing_content = json.load(f)
261
+ existing_content[str(bs)] = best_config
262
+ with open(filename, "w") as f:
263
+ json.dump(existing_content, f, indent=4)
264
+ f.write("\n")
265
+
266
+
267
+ def run_timing(
268
+ num_calls: int,
269
+ bs: int,
270
+ d_model: int,
271
+ num_total_experts: int,
272
+ top_k: int,
273
+ tp_size: int,
274
+ model_intermediate_size: int,
275
+ method,
276
+ config,
277
+ dtype: str,
278
+ hidden_states_dtype,
279
+ ) -> float:
280
+ shard_intermediate_size = model_intermediate_size // tp_size
281
+
282
+ hidden_states = torch.rand(
283
+ (bs, d_model),
284
+ device="cuda:0",
285
+ dtype=hidden_states_dtype,
286
+ )
287
+
288
+ w1 = torch.rand(
289
+ (num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
290
+ device=hidden_states.device,
291
+ dtype=hidden_states.dtype,
292
+ )
293
+
294
+ w2 = torch.rand(
295
+ (num_total_experts, d_model, shard_intermediate_size + padding_size),
296
+ device=hidden_states.device,
297
+ dtype=hidden_states.dtype,
298
+ )
299
+
300
+ w1_scale = None
301
+ w2_scale = None
302
+ a1_scale = None
303
+ a2_scale = None
304
+
305
+ if dtype == "float8":
306
+ w1 = w1.to(torch.float8_e4m3fnuz)
307
+ w2 = w2.to(torch.float8_e4m3fnuz)
308
+ w1_scale = torch.ones(
309
+ num_total_experts, device=hidden_states.device, dtype=torch.float32
310
+ )
311
+ w2_scale = torch.ones(
312
+ num_total_experts, device=hidden_states.device, dtype=torch.float32
313
+ )
314
+ a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
315
+ a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
316
+
317
+ gating_output = F.softmax(
318
+ torch.rand(
319
+ (num_calls, bs, num_total_experts),
320
+ device=hidden_states.device,
321
+ dtype=torch.float32,
322
+ ),
323
+ dim=-1,
324
+ )
325
+
326
+ ##################################
327
+
328
+ start_event = torch.cuda.Event(enable_timing=True)
329
+ end_event = torch.cuda.Event(enable_timing=True)
330
+
331
+ start_event.record()
332
+ for i in range(num_calls):
333
+ hidden_states = method(
334
+ hidden_states=hidden_states,
335
+ w1=w1,
336
+ w2=w2,
337
+ w1_scale=w1_scale,
338
+ w2_scale=w2_scale,
339
+ a1_scale=a1_scale,
340
+ a2_scale=a2_scale,
341
+ gating_output=gating_output[0],
342
+ topk=top_k,
343
+ renormalize=True,
344
+ inplace=True,
345
+ override_config=config,
346
+ use_fp8=dtype == "float8",
347
+ )
348
+
349
+ end_event.record()
350
+ end_event.synchronize()
351
+
352
+ dur_ms = start_event.elapsed_time(end_event) / num_calls
353
+ return dur_ms
354
+
355
+
356
+ if __name__ == "__main__":
357
+ parser = argparse.ArgumentParser(
358
+ prog="benchmark_mixtral_moe",
359
+ description="Benchmark and tune the fused_moe kernel",
360
+ )
361
+ parser.add_argument(
362
+ "--dtype",
363
+ type=str,
364
+ default="auto",
365
+ choices=["float8", "float16", "bfloat16"],
366
+ help="Data type used for fused_moe kernel computations",
367
+ )
368
+ parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
369
+
370
+ parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
371
+ parser.add_argument("-b", "--batches", type=str)
372
+
373
+ args = parser.parse_args()
374
+
375
+ batches = args.batches.split(",")
376
+
377
+ sys.exit(main(args.model, args.tp_size, args.dtype, batches))
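Since `--batches` is parsed as a comma-separated list (`args.batches.split(",")`), the decode and prefill shapes from the grok1 example in TUNING.md can be tuned in a single invocation; the command below is a sketch that reuses that example's settings.

```bash
# Tune both the decode (32) and prefill (32*1024 = 32768) MoE batch sizes in one run.
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batches "32,32768"
```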
sglang/benchmark/blog_v0_2/405b_sglang.sh ADDED
@@ -0,0 +1,24 @@
1
+ # Create dummy weights:
2
+ # 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
3
+ # 2. Get `config.json` from ./config.md
4
+ # 3. Download the tokenizer
5
+ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
6
+ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
7
+
8
+ # Launch sglang
9
+ # python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
10
+
11
+ # offline
12
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
13
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12
14
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13
15
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14
16
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15
17
+ python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21
18
+
19
+ # online
20
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31
21
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32
22
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33
23
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34
24
+ python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35
sglang/benchmark/blog_v0_2/405b_trt.sh ADDED
@@ -0,0 +1,17 @@
1
+ # Launch trtllm
2
+ # https://github.com/sgl-project/tensorrt-demo
3
+
4
+ # offline
5
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11
6
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12
7
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13
8
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14
9
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15
10
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21
11
+
12
+ # online
13
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31
14
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32
15
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33
16
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34
17
+ python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35
sglang/benchmark/blog_v0_2/405b_vllm.sh ADDED
@@ -0,0 +1,24 @@
1
+ # Create dummy weights:
2
+ # 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
3
+ # 2. Get `config.json` from ./config.md
4
+ # 3. Download the tokenizer
5
+ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
6
+ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
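+ # For example, the setup above can be scripted roughly as follows (paths are illustrative, adjust to your environment):
+ # mkdir -p ~/llama-3.1-405b-fp8-dummy && cd ~/llama-3.1-405b-fp8-dummy
+ # cp /path/to/config.json .   # config.json content is provided in ./config.md
+ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
+ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json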
7
+
8
+ # Launch vllm
9
+ # python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000
10
+
11
+ # offline
12
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11
13
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12
14
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13
15
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14
16
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15
17
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21
18
+
19
+ # online
20
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31
21
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32
22
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33
23
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34
24
+ python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35
sglang/benchmark/dspy/README.md ADDED
@@ -0,0 +1,51 @@
1
+ ## Install
2
+
3
+ ```
4
+ pip3 install dspy-ai
5
+ ```
6
+
7
+ Turn off the cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10:
8
+ ```
9
+ cache_turn_on = False
10
+ ```
11
+
12
+ Alternatively, set the environment variable:
13
+
14
+ ```
15
+ export DSP_CACHEBOOL=false
16
+ ```
17
+
18
+ ## Benchmark SGLang
19
+ ```
20
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
21
+ ```
22
+
23
+ ```
24
+ python3 bench_dspy_intro.py --backend sglang
25
+ ```
26
+
27
+
28
+ ## Benchmark TGI
29
+ ```
30
+ docker run --name tgi --rm -ti --gpus all --network host \
31
+ -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
32
+ ghcr.io/huggingface/text-generation-inference:1.3.0 \
33
+ --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
34
+ --max-input-length 2048 --max-total-tokens 4096 \
35
+ --port 24000
36
+ ```
37
+
38
+ ```
39
+ python3 bench_dspy_intro.py --backend tgi
40
+ ```
41
+
42
+
43
+
44
+ ## Benchmark vLLM
45
+ ```
46
+ python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
47
+ ```
48
+
49
+ ```
50
+ python3 bench_dspy_intro.py --backend vllm
51
+ ```
sglang/benchmark/dspy/bench_dspy_intro.py ADDED
@@ -0,0 +1,192 @@
1
+ """
2
+ Adapted from
3
+ https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
4
+ """
5
+
6
+ import argparse
7
+
8
+ import dspy
9
+ from dspy.datasets import HotPotQA
10
+
11
+
12
+ class BasicQA(dspy.Signature):
13
+ """Answer questions with short factoid answers."""
14
+
15
+ question = dspy.InputField()
16
+ answer = dspy.OutputField(desc="often between 1 and 5 words")
17
+
18
+
19
+ class GenerateAnswer(dspy.Signature):
20
+ """Answer questions with short factoid answers."""
21
+
22
+ context = dspy.InputField(desc="may contain relevant facts")
23
+ question = dspy.InputField()
24
+ answer = dspy.OutputField(desc="often between 1 and 5 words")
25
+
26
+
27
+ class RAG(dspy.Module):
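+ """Retrieve top-k passages for the question, then answer with chain-of-thought over the retrieved context."""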
28
+ def __init__(self, num_passages=3):
29
+ super().__init__()
30
+
31
+ self.retrieve = dspy.Retrieve(k=num_passages)
32
+ self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
33
+
34
+ def forward(self, question):
35
+ context = self.retrieve(question).passages
36
+ prediction = self.generate_answer(context=context, question=question)
37
+ return dspy.Prediction(context=context, answer=prediction.answer)
38
+
39
+
40
+ def main(args):
41
+ # lm = dspy.OpenAI(model='gpt-3.5-turbo')
42
+ if args.backend == "tgi":
43
+ lm = dspy.HFClientTGI(
44
+ model="meta-llama/Llama-2-7b-chat-hf",
45
+ port=args.port,
46
+ url="http://localhost",
47
+ )
48
+ elif args.backend == "sglang":
49
+ lm = dspy.HFClientSGLang(
50
+ model="meta-llama/Llama-2-7b-chat-hf",
51
+ port=args.port,
52
+ url="http://localhost",
53
+ )
54
+ elif args.backend == "vllm":
55
+ lm = dspy.HFClientVLLM(
56
+ model="meta-llama/Llama-2-7b-chat-hf",
57
+ port=args.port,
58
+ url="http://localhost",
59
+ )
60
+ else:
61
+ raise ValueError(f"Invalid backend: {args.backend}")
62
+
63
+ colbertv2_wiki17_abstracts = dspy.ColBERTv2(
64
+ url="http://20.102.90.50:2017/wiki17_abstracts"
65
+ )
66
+ dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
67
+
68
+ # Load the dataset.
69
+ dataset = HotPotQA(
70
+ train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
71
+ )
72
+
73
+ # Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
74
+ trainset = [x.with_inputs("question") for x in dataset.train]
75
+ devset = [x.with_inputs("question") for x in dataset.dev]
76
+
77
+ print(len(trainset), len(devset))
78
+
79
+ train_example = trainset[0]
80
+ print(f"Question: {train_example.question}")
81
+ print(f"Answer: {train_example.answer}")
82
+
83
+ dev_example = devset[18]
84
+ print(f"Question: {dev_example.question}")
85
+ print(f"Answer: {dev_example.answer}")
86
+ print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
87
+
88
+ print(
89
+ f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
90
+ )
91
+ print(
92
+ f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
93
+ )
94
+
95
+ # Define the predictor.
96
+ generate_answer = dspy.Predict(BasicQA)
97
+
98
+ # Call the predictor on a particular input.
99
+ pred = generate_answer(question=dev_example.question)
100
+
101
+ # Print the input and the prediction.
102
+ print(f"Question: {dev_example.question}")
103
+ print(f"Predicted Answer: {pred.answer}")
104
+
105
+ lm.inspect_history(n=1)
106
+
107
+ # Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
108
+ generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
109
+
110
+ # Call the predictor on the same input.
111
+ pred = generate_answer_with_chain_of_thought(question=dev_example.question)
112
+
113
+ # Print the input, the chain of thought, and the prediction.
114
+ print(f"Question: {dev_example.question}")
115
+ print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
116
+ print(f"Predicted Answer: {pred.answer}")
117
+
118
+ retrieve = dspy.Retrieve(k=3)
119
+ topK_passages = retrieve(dev_example.question).passages
120
+
121
+ print(
122
+ f"Top {retrieve.k} passages for question: {dev_example.question} \n",
123
+ "-" * 30,
124
+ "\n",
125
+ )
126
+
127
+ for idx, passage in enumerate(topK_passages):
128
+ print(f"{idx+1}]", passage, "\n")
129
+
130
+ retrieve("When was the first FIFA World Cup held?").passages[0]
131
+
132
+ from dspy.teleprompt import BootstrapFewShot
133
+
134
+ # Validation logic: check that the predicted answer is correct.
135
+ # Also check that the retrieved context does actually contain that answer.
136
+ def validate_context_and_answer(example, pred, trace=None):
137
+ answer_EM = dspy.evaluate.answer_exact_match(example, pred)
138
+ answer_PM = dspy.evaluate.answer_passage_match(example, pred)
139
+ return answer_EM and answer_PM
140
+
141
+ # Set up a basic teleprompter, which will compile our RAG program.
142
+ teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
143
+
144
+ # Compile!
145
+ compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
146
+
147
+ # Ask any question you like to this simple RAG program.
148
+ my_question = "What castle did David Gregory inherit?"
149
+
150
+ # Get the prediction. This contains `pred.context` and `pred.answer`.
151
+ pred = compiled_rag(my_question)
152
+
153
+ # Print the contexts and the answer.
154
+ print(f"Question: {my_question}")
155
+ print(f"Predicted Answer: {pred.answer}")
156
+ print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
157
+
158
+ from dspy.evaluate.evaluate import Evaluate
159
+
160
+ # Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
161
+ evaluate_on_hotpotqa = Evaluate(
162
+ devset=devset,
163
+ num_threads=args.num_threads,
164
+ display_progress=True,
165
+ display_table=5,
166
+ )
167
+
168
+ # Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
169
+ metric = dspy.evaluate.answer_exact_match
170
+ evaluate_on_hotpotqa(compiled_rag, metric=metric)
171
+
172
+
173
+ if __name__ == "__main__":
174
+ parser = argparse.ArgumentParser()
175
+ parser.add_argument("--port", type=int)
176
+ parser.add_argument("--num-threads", type=int, default=32)
177
+ parser.add_argument("--dev-size", type=int, default=150)
178
+ parser.add_argument(
179
+ "--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
180
+ )
181
+ args = parser.parse_args()
182
+
183
+ if args.port is None:
184
+ default_port = {
185
+ "vllm": 21000,
186
+ "lightllm": 22000,
187
+ "tgi": 24000,
188
+ "sglang": 30000,
189
+ }
190
+ args.port = default_port.get(args.backend, None)
191
+
192
+ main(args)
sglang/benchmark/gsm8k/README.md ADDED
@@ -0,0 +1,47 @@
1
+ ## Run benchmark
2
+
3
+ ### Benchmark sglang
4
+ ```
5
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
6
+ ```
7
+
8
+ ```
9
+ python3 bench_sglang.py --num-questions 200
10
+ ```
11
+
12
+
13
+ ### Benchmark vllm
14
+ ```
15
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
16
+ ```
17
+
18
+ ```
19
+ python3 bench_other.py --num-questions 200 --backend vllm
20
+ ```
21
+
22
+
23
+ ### Benchmark lightllm
24
+ ```
25
+ # A10G
26
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
27
+ ```
28
+
29
+ ```
30
+ python3 bench_other.py --num-questions 200 --backend lightllm
31
+ ```
32
+
33
+
34
+ ### Benchmark guidance
35
+ ```
36
+ python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
37
+ ```
38
+
39
+
40
+ ### Benchmark lmql
41
+ ```
42
+ CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
43
+ ```
44
+
45
+ ```
46
+ python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
47
+ ```
sglang/benchmark/gsm8k/bench_other.py ADDED
@@ -0,0 +1,151 @@
1
+ import argparse
2
+ import ast
3
+ import asyncio
4
+ import json
5
+ import re
6
+ import time
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
13
+ from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
14
+
15
+ INVALID = -9999999
16
+
17
+
18
+ def get_one_example(lines, i, include_answer):
19
+ ret = "Question: " + lines[i]["question"] + "\nAnswer:"
20
+ if include_answer:
21
+ ret += " " + lines[i]["answer"]
22
+ return ret
23
+
24
+
25
+ def get_few_shot_examples(lines, k):
26
+ ret = ""
27
+ for i in range(k):
28
+ ret += get_one_example(lines, i, True) + "\n\n"
29
+ return ret
30
+
31
+
32
+ def get_answer_value(answer_str):
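+ # Extract the last number in the answer string as the predicted value; return INVALID if none is found.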
33
+ answer_str = answer_str.replace(",", "")
34
+ numbers = re.findall(r"\d+", answer_str)
35
+ if len(numbers) < 1:
36
+ return INVALID
37
+ try:
38
+ return ast.literal_eval(numbers[-1])
39
+ except SyntaxError:
40
+ return INVALID
41
+
42
+
43
+ def main(args):
44
+ # Select backend
45
+ call_generate = get_call_generate(args)
46
+
47
+ # Read data
48
+ url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
49
+ filename = download_and_cache_file(url)
50
+ lines = list(read_jsonl(filename))
51
+
52
+ # Construct prompts
53
+ num_questions = args.num_questions
54
+ num_shots = args.num_shots
55
+ few_shot_examples = get_few_shot_examples(lines, num_shots)
56
+
57
+ questions = []
58
+ labels = []
59
+ for i in range(len(lines[:num_questions])):
60
+ questions.append(get_one_example(lines, i, False))
61
+ labels.append(get_answer_value(lines[i]["answer"]))
62
+ assert all(l != INVALID for l in labels)
63
+
64
+ states = [None] * len(labels)
65
+
66
+ # Run requests
67
+ if args.backend != "lmql":
68
+ # Use thread pool
69
+ def get_one_answer(i):
70
+ answer = call_generate(
71
+ prompt=few_shot_examples + questions[i],
72
+ temperature=0,
73
+ max_tokens=256,
74
+ stop=["Question", "Assistant:", "<|separator|>"],
75
+ )
76
+ states[i] = answer
77
+
78
+ tic = time.time()
79
+ if args.parallel == 1:
80
+ for i in tqdm(range(len(questions))):
81
+ get_one_answer(i)
82
+ else:
83
+ with ThreadPoolExecutor(args.parallel) as executor:
84
+ list(
85
+ tqdm(
86
+ executor.map(get_one_answer, list(range(len(questions)))),
87
+ total=len(questions),
88
+ )
89
+ )
90
+
91
+ else:
92
+ # Use asyncio
93
+ async def batched_call(batch_size):
94
+ for i in range(0, len(questions), batch_size):
95
+ tasks = []
96
+ for q in questions[i : i + batch_size]:
97
+ tasks.append(
98
+ call_generate(
99
+ few_shot_examples + q,
100
+ temperature=0,
101
+ max_tokens=256,
102
+ stop="Question",
103
+ )
104
+ )
105
+ rets = await asyncio.gather(*tasks)
106
+ for j in range(len(rets)):
107
+ states[i + j] = rets[j]
108
+
109
+ tic = time.time()
110
+ asyncio.run(batched_call(batch_size=args.parallel))
111
+ latency = time.time() - tic
112
+
113
+ preds = []
114
+ for i in range(len(states)):
115
+ preds.append(get_answer_value(states[i]))
116
+
117
+ # Compute accuracy
118
+ acc = np.mean(np.array(preds) == np.array(labels))
119
+ invalid = np.mean(np.array(preds) == INVALID)
120
+
121
+ # Print results
122
+ print(f"Accuracy: {acc:.3f}")
123
+ print(f"Invalid: {invalid:.3f}")
124
+ print(f"Latency: {latency:.3f} s")
125
+
126
+ # Dump results
127
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
128
+
129
+ with open(args.result_file, "a") as fout:
130
+ value = {
131
+ "task": "gsm8k",
132
+ "backend": args.backend,
133
+ "num_gpus": 1,
134
+ "latency": round(latency, 3),
135
+ "accuracy": round(acc, 3),
136
+ "num_requests": args.num_questions,
137
+ "other": {
138
+ "num_questions": args.num_questions,
139
+ "parallel": args.parallel,
140
+ },
141
+ }
142
+ fout.write(json.dumps(value) + "\n")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ parser = argparse.ArgumentParser()
147
+ parser.add_argument("--num-shots", type=int, default=5)
148
+ parser.add_argument("--data-path", type=str, default="test.jsonl")
149
+ parser.add_argument("--num-questions", type=int, default=200)
150
+ args = add_common_other_args_and_parse(parser)
151
+ main(args)
sglang/benchmark/gsm8k/bench_sglang.py ADDED
@@ -0,0 +1,141 @@
1
+ import argparse
2
+ import ast
3
+ import json
4
+ import re
5
+ import time
6
+
7
+ import numpy as np
8
+
9
+ from sglang.api import set_default_backend
10
+ from sglang.test.test_utils import (
11
+ add_common_sglang_args_and_parse,
12
+ select_sglang_backend,
13
+ )
14
+ from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
15
+
16
+ INVALID = -9999999
17
+
18
+
19
+ def get_one_example(lines, i, include_answer):
20
+ ret = "Question: " + lines[i]["question"] + "\nAnswer:"
21
+ if include_answer:
22
+ ret += " " + lines[i]["answer"]
23
+ return ret
24
+
25
+
26
+ def get_few_shot_examples(lines, k):
27
+ ret = ""
28
+ for i in range(k):
29
+ ret += get_one_example(lines, i, True) + "\n\n"
30
+ return ret
31
+
32
+
33
+ def get_answer_value(answer_str):
34
+ answer_str = answer_str.replace(",", "")
35
+ numbers = re.findall(r"\d+", answer_str)
36
+ if len(numbers) < 1:
37
+ return INVALID
38
+ try:
39
+ return ast.literal_eval(numbers[-1])
40
+ except SyntaxError:
41
+ return INVALID
42
+
43
+
44
+ def main(args):
45
+ # Select backend
46
+ set_default_backend(select_sglang_backend(args))
47
+
48
+ # Read data
49
+ url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
50
+ filename = download_and_cache_file(url)
51
+ lines = list(read_jsonl(filename))
52
+
53
+ # Construct prompts
54
+ num_questions = args.num_questions
55
+ num_shots = args.num_shots
56
+ few_shot_examples = get_few_shot_examples(lines, num_shots)
57
+
58
+ questions = []
59
+ labels = []
60
+ for i in range(len(lines[:num_questions])):
61
+ questions.append(get_one_example(lines, i, False))
62
+ labels.append(get_answer_value(lines[i]["answer"]))
63
+ assert all(l != INVALID for l in labels)
64
+ arguments = [{"question": q} for q in questions]
65
+
66
+ #####################################
67
+ ######### SGL Program Begin #########
68
+ #####################################
69
+
70
+ import sglang as sgl
71
+
72
+ @sgl.function
73
+ def few_shot_gsm8k(s, question):
74
+ s += few_shot_examples + question
75
+ s += sgl.gen(
76
+ "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
77
+ )
78
+
79
+ #####################################
80
+ ########## SGL Program End ##########
81
+ #####################################
82
+
83
+ # Run requests
84
+ tic = time.time()
85
+ states = few_shot_gsm8k.run_batch(
86
+ arguments,
87
+ temperature=0,
88
+ num_threads=args.parallel,
89
+ progress_bar=True,
90
+ )
91
+ latency = time.time() - tic
92
+
93
+ preds = []
94
+ for i in range(len(states)):
95
+ preds.append(get_answer_value(states[i]["answer"]))
96
+
97
+ # print(f"{preds=}")
98
+ # print(f"{labels=}")
99
+
100
+ # Compute accuracy
101
+ acc = np.mean(np.array(preds) == np.array(labels))
102
+ invalid = np.mean(np.array(preds) == INVALID)
103
+
104
+ # Compute speed
105
+ num_output_tokens = sum(
106
+ s.get_meta_info("answer")["completion_tokens"] for s in states
107
+ )
108
+ output_throughput = num_output_tokens / latency
109
+
110
+ # Print results
111
+ print(f"Accuracy: {acc:.3f}")
112
+ print(f"Invalid: {invalid:.3f}")
113
+ print(f"Latency: {latency:.3f} s")
114
+ print(f"Output throughput: {output_throughput:.3f} token/s")
115
+
116
+ # Dump results
117
+ dump_state_text(f"tmp_output_{args.backend}.txt", states)
118
+
119
+ with open(args.result_file, "a") as fout:
120
+ value = {
121
+ "task": "gsm8k",
122
+ "backend": args.backend,
123
+ "num_gpus": 1,
124
+ "latency": round(latency, 3),
125
+ "accuracy": round(acc, 3),
126
+ "num_requests": args.num_questions,
127
+ "other": {
128
+ "num_questions": args.num_questions,
129
+ "parallel": args.parallel,
130
+ },
131
+ }
132
+ fout.write(json.dumps(value) + "\n")
133
+
134
+
135
+ if __name__ == "__main__":
136
+ parser = argparse.ArgumentParser()
137
+ parser.add_argument("--num-shots", type=int, default=5)
138
+ parser.add_argument("--data-path", type=str, default="test.jsonl")
139
+ parser.add_argument("--num-questions", type=int, default=200)
140
+ args = add_common_sglang_args_and_parse(parser)
141
+ main(args)
sglang/benchmark/hellaswag/README.md ADDED
@@ -0,0 +1,47 @@
1
+ ## Run benchmark
2
+
3
+ ### Benchmark sglang
4
+ ```
5
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
6
+ ```
7
+
8
+ ```
9
+ python3 bench_sglang.py --num-questions 200
10
+ ```
11
+
12
+
13
+ ### Benchmark vllm
14
+ ```
15
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
16
+ ```
17
+
18
+ ```
19
+ python3 bench_other.py --num-questions 200 --backend vllm
20
+ ```
21
+
22
+
23
+ ### Benchmark lightllm
24
+ ```
25
+ # A10G
26
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
27
+ ```
28
+
29
+ ```
30
+ python3 bench_other.py --num-questions 200 --backend lightllm
31
+ ```
32
+
33
+
34
+ ### Benchmark guidance
35
+ ```
36
+ CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
37
+ ```
38
+
39
+
40
+ ### Benchmark lmql
41
+ ```
42
+ lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
43
+ ```
44
+
45
+ ```
46
+ python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
47
+ ```
sglang/benchmark/hellaswag/bench_other.py ADDED
@@ -0,0 +1,118 @@
1
+ import argparse
2
+ import asyncio
3
+ import json
4
+ import time
5
+ from concurrent.futures import ThreadPoolExecutor
6
+
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+
10
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
11
+ from sglang.utils import download_and_cache_file, read_jsonl
12
+
13
+
14
+ def get_one_example(lines, i, include_answer):
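+ # Format one HellaSwag example as "activity_label: context"; the gold ending is appended only for few-shot examples.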
15
+ ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
16
+ if include_answer:
17
+ ret += lines[i]["endings"][lines[i]["label"]]
18
+ return ret
19
+
20
+
21
+ def get_few_shot_examples(lines, k):
22
+ ret = ""
23
+ for i in range(k):
24
+ ret += get_one_example(lines, i, True) + "\n\n"
25
+ return ret
26
+
27
+
28
+ def main(args):
29
+ # Select backend
30
+ call_select = get_call_select(args)
31
+
32
+ # Read data
33
+ url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
34
+ filename = download_and_cache_file(url)
35
+ lines = list(read_jsonl(filename))
36
+
37
+ # Construct prompts
38
+ num_questions = args.num_questions
39
+ num_shots = args.num_shots
40
+ few_shot_examples = get_few_shot_examples(lines, num_shots)
41
+
42
+ questions = []
43
+ choices = []
44
+ labels = []
45
+ for i in range(len(lines[:num_questions])):
46
+ questions.append(get_one_example(lines, i, False))
47
+ choices.append(lines[i]["endings"])
48
+ labels.append(lines[i]["label"])
49
+
50
+ preds = [None] * len(labels)
51
+
52
+ # Run requests
53
+ if args.backend != "lmql":
54
+ # Use thread pool
55
+ def get_one_answer(i):
56
+ preds[i] = call_select(
57
+ context=few_shot_examples + questions[i], choices=choices[i]
58
+ )
59
+
60
+ tic = time.time()
61
+ if args.parallel == 1:
62
+ for i in tqdm(range(len(questions))):
63
+ get_one_answer(i)
64
+ else:
65
+ with ThreadPoolExecutor(args.parallel) as executor:
66
+ list(
67
+ tqdm(
68
+ executor.map(get_one_answer, list(range(len(questions)))),
69
+ total=len(questions),
70
+ )
71
+ )
72
+ else:
73
+ # Use asyncio
74
+ async def batched_call(batch_size):
75
+ for i in range(0, len(questions), batch_size):
76
+ tasks = []
77
+ for q, c in zip(
78
+ questions[i : i + batch_size], choices[i : i + batch_size]
79
+ ):
80
+ tasks.append(call_select(context=few_shot_examples + q, choices=c))
81
+ rets = await asyncio.gather(*tasks)
82
+ for j in range(len(rets)):
83
+ preds[i + j] = rets[j]
84
+
85
+ tic = time.time()
86
+ asyncio.run(batched_call(batch_size=args.parallel))
87
+
88
+ latency = time.time() - tic
89
+
90
+ # Compute accuracy
91
+ acc = np.mean(np.array(preds) == np.array(labels))
92
+ print(f"Latency: {latency:.3f}")
93
+ print(f"Accuracy: {acc:.3f}")
94
+
95
+ # Write results
96
+ with open(args.result_file, "a") as fout:
97
+ value = {
98
+ "task": "hellaswag",
99
+ "backend": args.backend,
100
+ "num_gpus": 1,
101
+ "latency": round(latency, 3),
102
+ "accuracy": round(acc, 3),
103
+ "num_requests": args.num_questions,
104
+ "other": {
105
+ "num_questions": args.num_questions,
106
+ "parallel": args.parallel,
107
+ },
108
+ }
109
+ fout.write(json.dumps(value) + "\n")
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser()
114
+ parser.add_argument("--num-shots", type=int, default=20)
115
+ parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
116
+ parser.add_argument("--num-questions", type=int, default=200)
117
+ args = add_common_other_args_and_parse(parser)
118
+ main(args)
sglang/benchmark/lora/launch_server.py ADDED
@@ -0,0 +1,47 @@
1
+ import argparse
2
+ import os
3
+
4
+ NUM_LORAS = 8
5
+ LORA_PATH = {
6
+ "base": "mistralai/Mistral-7B-Instruct-v0.3",
7
+ "lora": "/home/ying/test_lora",
8
+ }
9
+
10
+
11
+ def launch_server(args):
12
+ base_path = LORA_PATH["base"]
13
+ lora_path = LORA_PATH["lora"]
14
+
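+ # Either serve only the base model, or additionally register NUM_LORAS copies of the same
+ # adapter under the names lora0, lora1, ... so that benchmark requests can target any of them.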
15
+ if args.base_only:
16
+ cmd = f"python3 -m sglang.launch_server --model {base_path} "
17
+ else:
18
+ cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
19
+ for i in range(NUM_LORAS):
20
+ lora_name = f"lora{i}"
21
+ cmd += f"{lora_name}={lora_path} "
22
+ cmd += f"--disable-radix --disable-cuda-graph "
23
+ cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
24
+ cmd += f"--max-running-requests {args.max_running_requests}"
25
+ print(cmd)
26
+ os.system(cmd)
27
+
28
+
29
+ if __name__ == "__main__":
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument(
32
+ "--base-only",
33
+ action="store_true",
34
+ )
35
+ parser.add_argument(
36
+ "--max-loras-per-batch",
37
+ type=int,
38
+ default=8,
39
+ )
40
+ parser.add_argument(
41
+ "--max-running-requests",
42
+ type=int,
43
+ default=8,
44
+ )
45
+ args = parser.parse_args()
46
+
47
+ launch_server(args)
sglang/benchmark/lora/lora_bench.py ADDED
@@ -0,0 +1,484 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ import argparse
16
+ import asyncio
17
+ import json
18
+ import os
19
+ import random
20
+ import resource
21
+ import sys
22
+ import time
23
+ import traceback
24
+ import warnings
25
+ from argparse import ArgumentParser
26
+ from dataclasses import dataclass, field
27
+ from datetime import datetime
28
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
29
+
30
+ import aiohttp
31
+ import numpy as np
32
+ import requests
33
+ from launch_server import LORA_PATH, NUM_LORAS
34
+ from tqdm.asyncio import tqdm
35
+ from transformers import (
36
+ AutoTokenizer,
37
+ PreTrainedTokenizer,
38
+ PreTrainedTokenizerBase,
39
+ PreTrainedTokenizerFast,
40
+ )
41
+
42
+ from sglang.bench_serving import (
43
+ AIOHTTP_TIMEOUT,
44
+ SHAREGPT_URL,
45
+ BenchmarkMetrics,
46
+ RequestFuncInput,
47
+ RequestFuncOutput,
48
+ calculate_metrics,
49
+ check_chat_template,
50
+ get_model,
51
+ get_request,
52
+ get_tokenizer,
53
+ parse_request_rate_range,
54
+ remove_prefix,
55
+ sample_random_requests,
56
+ )
57
+
58
+ global args
59
+
60
+
61
+ # set ignore_eos True by default
62
+ async def async_request_openai_completions(
63
+ request_func_input: RequestFuncInput,
64
+ pbar: Optional[tqdm] = None,
65
+ ) -> RequestFuncOutput:
66
+ api_url = request_func_input.api_url
67
+ # assert api_url.endswith(
68
+ # "completions"
69
+ # ), "OpenAI Completions API URL must end with 'completions'."
70
+
71
+ prompt = request_func_input.prompt
72
+
73
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
74
+ # payload = {
75
+ # "model": request_func_input.model,
76
+ # "prompt": prompt,
77
+ # "temperature": 0.0,
78
+ # "best_of": 1,
79
+ # "max_tokens": request_func_input.output_len,
80
+ # "stream": not args.disable_stream,
81
+ # "ignore_eos": not args.disable_ignore_eos,
82
+ # **request_func_input.extra_request_body,
83
+ # }
84
+ # headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
85
+ if args.base_only:
86
+ payload = {
87
+ "text": prompt,
88
+ "sampling_params": {"max_new_tokens": request_func_input.output_len},
89
+ }
90
+ else:
91
+ payload = {
92
+ "text": prompt,
93
+ "sampling_params": {"max_new_tokens": request_func_input.output_len},
94
+ "lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
95
+ }
96
+ headers = {"Authorization": ""}
97
+
98
+ output = RequestFuncOutput()
99
+ output.prompt_len = request_func_input.prompt_len
100
+
101
+ generated_text = ""
102
+ ttft = 0.0
103
+ st = time.perf_counter()
104
+ most_recent_timestamp = st
105
+ try:
106
+ async with session.post(
107
+ url=api_url, json=payload, headers=headers
108
+ ) as response:
109
+ if response.status == 200:
110
+ async for chunk_bytes in response.content:
111
+ chunk_bytes = chunk_bytes.strip()
112
+ if not chunk_bytes:
113
+ continue
114
+
115
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
116
+ latency = time.perf_counter() - st
117
+ if chunk == "[DONE]":
118
+ pass
119
+ else:
120
+ data = json.loads(chunk)
121
+
122
+ # NOTE: Some completion API might have a last
123
+ # usage summary response without a token so we
124
+ # want to check a token was generated
125
+ if data["text"]:
126
+ # if data["choices"][0]["text"]:
127
+ timestamp = time.perf_counter()
128
+ # First token
129
+ if ttft == 0.0:
130
+ ttft = time.perf_counter() - st
131
+ output.ttft = ttft
132
+
133
+ # Decoding phase
134
+ else:
135
+ output.itl.append(timestamp - most_recent_timestamp)
136
+
137
+ most_recent_timestamp = timestamp
138
+ # generated_text += data["choices"][0]["text"]
139
+ generated_text += data["text"]
140
+
141
+ output.generated_text = generated_text
142
+ output.success = True
143
+ output.latency = latency
144
+ output.output_len = request_func_input.output_len
145
+ else:
146
+ output.error = response.reason or ""
147
+ output.success = False
148
+ except Exception:
149
+ output.success = False
150
+ exc_info = sys.exc_info()
151
+ output.error = "".join(traceback.format_exception(*exc_info))
152
+
153
+ if pbar:
154
+ pbar.update(1)
155
+ return output
156
+
157
+
158
+ ASYNC_REQUEST_FUNCS = {
159
+ "sglang": async_request_openai_completions,
160
+ }
161
+
162
+
163
+ async def benchmark(
164
+ backend: str,
165
+ api_url: str,
166
+ model_id: str,
167
+ tokenizer: PreTrainedTokenizerBase,
168
+ input_requests: List[Tuple[str, int, int]],
169
+ request_rate: float,
170
+ disable_tqdm: bool,
171
+ extra_request_body: Dict[str, Any],
172
+ ):
173
+ if backend in ASYNC_REQUEST_FUNCS:
174
+ request_func = ASYNC_REQUEST_FUNCS[backend]
175
+ else:
176
+ raise ValueError(f"Unknown backend: {backend}")
177
+
178
+ print("Starting initial single prompt test run...")
179
+ test_prompt, test_prompt_len, test_output_len = input_requests[0]
180
+ test_input = RequestFuncInput(
181
+ model=model_id,
182
+ prompt=test_prompt,
183
+ api_url=api_url,
184
+ prompt_len=test_prompt_len,
185
+ output_len=test_output_len,
186
+ extra_request_body=extra_request_body,
187
+ )
188
+ test_output = await request_func(request_func_input=test_input)
189
+ if not test_output.success:
190
+ raise ValueError(
191
+ "Initial test run failed - Please make sure benchmark arguments "
192
+ f"are correctly specified. Error: {test_output.error}"
193
+ )
194
+ else:
195
+ print("Initial test run completed. Starting main benchmark run...")
196
+
197
+ pbar = None if disable_tqdm else tqdm(total=len(input_requests))
198
+
199
+ benchmark_start_time = time.perf_counter()
200
+ tasks: List[asyncio.Task] = []
201
+ async for request in get_request(input_requests, request_rate):
202
+ prompt, prompt_len, output_len = request
203
+ request_func_input = RequestFuncInput(
204
+ model=model_id,
205
+ prompt=prompt,
206
+ api_url=api_url,
207
+ prompt_len=prompt_len,
208
+ output_len=output_len,
209
+ extra_request_body=extra_request_body,
210
+ )
211
+ tasks.append(
212
+ asyncio.create_task(
213
+ request_func(request_func_input=request_func_input, pbar=pbar)
214
+ )
215
+ )
216
+ outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
217
+
218
+ if pbar is not None:
219
+ pbar.close()
220
+
221
+ benchmark_duration = time.perf_counter() - benchmark_start_time
222
+
223
+ metrics, output_lens = calculate_metrics(
224
+ input_requests=input_requests,
225
+ outputs=outputs,
226
+ dur_s=benchmark_duration,
227
+ tokenizer=tokenizer,
228
+ backend=backend,
229
+ )
230
+
231
+ print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
232
+ print("{:<40} {:<10}".format("Backend:", backend))
233
+ print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
234
+ print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
235
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
236
+ print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
237
+ print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
238
+ print(
239
+ "{:<40} {:<10}".format(
240
+ "Total generated tokens (retokenized):", metrics.total_output_retokenized
241
+ )
242
+ )
243
+ print(
244
+ "{:<40} {:<10.2f}".format(
245
+ "Request throughput (req/s):", metrics.request_throughput
246
+ )
247
+ )
248
+ print(
249
+ "{:<40} {:<10.2f}".format(
250
+ "Input token throughput (tok/s):", metrics.input_throughput
251
+ )
252
+ )
253
+ print(
254
+ "{:<40} {:<10.2f}".format(
255
+ "Output token throughput (tok/s):", metrics.output_throughput
256
+ )
257
+ )
258
+ print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
259
+ print(
260
+ "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
261
+ )
262
+ print(
263
+ "{:<40} {:<10.2f}".format(
264
+ "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
265
+ )
266
+ )
267
+ print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
268
+ print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
269
+ print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
270
+ print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
271
+ print(
272
+ "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
273
+ )
274
+ print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
275
+ print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
276
+ print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
277
+ print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
278
+ print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
279
+ print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
280
+ print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
281
+ print("=" * 50)
282
+
283
+ if (
284
+ metrics.median_ttft_ms is not None
285
+ and metrics.mean_itl_ms is not None
286
+ and metrics.output_throughput is not None
287
+ ):
288
+ result = {
289
+ "backend": args.backend,
290
+ "request_rate": request_rate,
291
+ "total_input_tokens": metrics.total_input,
292
+ "total_output_tokens": metrics.total_output,
293
+ "total_output_tokens_retokenized": metrics.total_output_retokenized,
294
+ "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
295
+ "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
296
+ "median_ttft_ms": metrics.median_ttft_ms,
297
+ "median_itl_ms": metrics.median_itl_ms,
298
+ "output_throughput": metrics.output_throughput,
299
+ "random_input_len": args.random_input_len,
300
+ "random_output_len": args.random_output_len,
301
+ "random_range_ratio": args.random_range_ratio,
302
+ "duration": benchmark_duration,
303
+ "completed": metrics.completed,
304
+ }
305
+ else:
306
+ print(f"Error running benchmark for request rate: {request_rate}")
307
+ print("-" * 30)
308
+
309
+ # Determine output file name
310
+ if args.output_file:
311
+ output_file_name = args.output_file
312
+ else:
313
+ now = datetime.now().strftime("%m%d")
314
+ output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
315
+
316
+ # Append results to a JSONL file
317
+ with open(output_file_name, "a") as file:
318
+ file.write(json.dumps(result) + "\n")
319
+
320
+ result = {
321
+ "duration": benchmark_duration,
322
+ "completed": metrics.completed,
323
+ "total_input_tokens": metrics.total_input,
324
+ "total_output_tokens": metrics.total_output,
325
+ "total_output_tokens_retokenized": metrics.total_output_retokenized,
326
+ "request_throughput": metrics.request_throughput,
327
+ "input_throughput": metrics.input_throughput,
328
+ "output_throughput": metrics.output_throughput,
329
+ "mean_ttft_ms": metrics.mean_ttft_ms,
330
+ "median_ttft_ms": metrics.median_ttft_ms,
331
+ "std_ttft_ms": metrics.std_ttft_ms,
332
+ "p99_ttft_ms": metrics.p99_ttft_ms,
333
+ "mean_tpot_ms": metrics.mean_tpot_ms,
334
+ "median_tpot_ms": metrics.median_tpot_ms,
335
+ "std_tpot_ms": metrics.std_tpot_ms,
336
+ "p99_tpot_ms": metrics.p99_tpot_ms,
337
+ "mean_itl_ms": metrics.mean_itl_ms,
338
+ "median_itl_ms": metrics.median_itl_ms,
339
+ "std_itl_ms": metrics.std_itl_ms,
340
+ "p99_itl_ms": metrics.p99_itl_ms,
341
+ "input_lens": [output.prompt_len for output in outputs],
342
+ "output_lens": output_lens,
343
+ "ttfts": [output.ttft for output in outputs],
344
+ "itls": [output.itl for output in outputs],
345
+ "generated_texts": [output.generated_text for output in outputs],
346
+ "errors": [output.error for output in outputs],
347
+ "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
348
+ "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
349
+ }
350
+ return result
351
+
352
+
353
+ def run_benchmark(args_: argparse.Namespace):
354
+ global args
355
+ args = args_
356
+
357
+ # Set global environments
358
+ set_ulimit()
359
+ random.seed(args.seed)
360
+ np.random.seed(args.seed)
361
+
362
+ # Set url
363
+ if args.port is None:
364
+ args.port = {
365
+ "sglang": 30000,
366
+ }.get(args.backend, 30000)
367
+
368
+ # api_url = (
369
+ # f"{args.base_url}/v1/completions"
370
+ # if args.base_url
371
+ # else f"http://{args.host}:{args.port}/v1/completions"
372
+ # )
373
+ api_url = (
374
+ f"{args.base_url}/generate"
375
+ if args.base_url
376
+ else f"http://{args.host}:{args.port}/generate"
377
+ )
378
+
379
+ print(f"{args}\n")
380
+
381
+ # Read dataset
382
+ backend = args.backend
383
+ model_id = args.model = LORA_PATH["base"]
384
+ tokenizer_id = args.model
385
+
386
+ tokenizer = get_tokenizer(tokenizer_id)
387
+
388
+ input_requests = sample_random_requests(
389
+ input_len=args.random_input_len,
390
+ output_len=args.random_output_len,
391
+ num_prompts=args.num_prompts,
392
+ range_ratio=args.random_range_ratio,
393
+ tokenizer=tokenizer,
394
+ dataset_path="",
395
+ )
396
+
397
+ return asyncio.run(
398
+ benchmark(
399
+ backend=backend,
400
+ api_url=api_url,
401
+ model_id=model_id,
402
+ tokenizer=tokenizer,
403
+ input_requests=input_requests,
404
+ request_rate=args.request_rate,
405
+ disable_tqdm=False,
406
+ extra_request_body={},
407
+ )
408
+ )
409
+
410
+
411
+ def set_ulimit(target_soft_limit=65535):
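+ # Raise the soft limit on open file descriptors so that many concurrent benchmark connections do not fail.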
412
+ resource_type = resource.RLIMIT_NOFILE
413
+ current_soft, current_hard = resource.getrlimit(resource_type)
414
+
415
+ if current_soft < target_soft_limit:
416
+ try:
417
+ resource.setrlimit(resource_type, (target_soft_limit, current_hard))
418
+ except ValueError as e:
419
+ print(f"Fail to set RLIMIT_NOFILE: {e}")
420
+
421
+
422
+ if __name__ == "__main__":
423
+ parser = ArgumentParser(description="Benchmark the online lora serving throughput.")
424
+ parser.add_argument(
425
+ "--backend",
426
+ type=str,
427
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
428
+ default="sglang",
429
+ help="Which LLM inference engine backend to benchmark against.",
430
+ )
431
+ parser.add_argument(
432
+ "--base-url",
433
+ type=str,
434
+ default=None,
435
+ help="Server or API base url if not using http host and port.",
436
+ )
437
+ parser.add_argument(
438
+ "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
439
+ )
440
+ parser.add_argument(
441
+ "--port",
442
+ type=int,
443
+ help="If not set, the default port for the selected backend is used.",
444
+ )
445
+ parser.add_argument(
446
+ "--num-prompts",
447
+ type=int,
448
+ default=50,
449
+ help="Number of prompts to process. Default is 50.",
450
+ )
451
+ parser.add_argument(
452
+ "--random-input-len",
453
+ type=int,
454
+ default=1024,
455
+ help="Number of input tokens per request, used only for random dataset.",
456
+ )
457
+ parser.add_argument(
458
+ "--random-output-len",
459
+ type=int,
460
+ default=128,
461
+ help="Number of output tokens per request, used only for random dataset.",
462
+ )
463
+ parser.add_argument(
464
+ "--random-range-ratio",
465
+ type=float,
466
+ default=0.0,
467
+ help="Range of sampled ratio of input/output length, "
468
+ "used only for random dataset.",
469
+ )
470
+ parser.add_argument(
471
+ "--request-rate",
472
+ type=float,
473
+ default=float("inf"),
474
+ help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
475
+ "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.",
476
+ )
477
+ parser.add_argument(
478
+ "--base-only",
479
+ action="store_true",
480
+ )
481
+ parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
482
+ parser.add_argument("--seed", type=int, default=1, help="The random seed.")
483
+ args = parser.parse_args()
484
+ run_benchmark(args)
sglang/benchmark/mmlu/README.md ADDED
@@ -0,0 +1,59 @@
1
+ ## Download data
2
+ ```
3
+ bash download_data.sh
4
+ ```
5
+
6
+ ## Run benchmark
7
+
8
+ ### Benchmark sglang
9
+ ```
10
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
11
+ ```
12
+
13
+ ```
14
+ python3 bench_sglang.py --nsub 10
15
+ ```
16
+
17
+ ```
18
+ # OpenAI models
19
+ python3 bench_sglang.py --backend gpt-3.5-turbo --parallel 8
20
+ ```
21
+
22
+ ### Benchmark vllm
23
+ ```
24
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
25
+ ```
26
+
27
+ ```
28
+ python3 bench_other.py --nsub 10 --backend vllm
29
+ ```
30
+
31
+
32
+ ### Benchmark lightllm
33
+ ```
34
+ # A10G
35
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
36
+
37
+ # V100
38
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000
39
+ ```
40
+
41
+ ```
42
+ python3 bench_other.py --nsub 10 --backend lightllm
43
+ ```
44
+
45
+
46
+ ### Benchmark guidance
47
+ ```
48
+ python3 bench_other.py --nsub 10 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
49
+ ```
50
+
51
+
52
+ ### Benchmark lmql
53
+ ```
54
+ CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
55
+ ```
56
+
57
+ ```
58
+ python3 bench_other.py --nsub 10 --backend lmql --parallel 2
59
+ ```
sglang/benchmark/mtbench/README.md ADDED
@@ -0,0 +1,37 @@
1
+ ## Download Dataset
2
+
3
+ ```sh
4
+ wget -O question.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl
5
+ ```
6
+
7
+ ## Run benchmark
8
+
9
+ ### Benchmark sglang
10
+ ```
11
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
12
+ ```
13
+
14
+ ```
15
+ python3 bench_sglang.py --num-questions 80
16
+ ```
17
+
18
+
19
+ ### Benchmark vllm
20
+ ```
21
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
22
+ ```
23
+
24
+ ```
25
+ python3 bench_other.py --num-questions 80 --backend vllm
26
+ ```
27
+
28
+
29
+ ### Benchmark lightllm
30
+ ```
31
+ # A10G
32
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
33
+ ```
34
+
35
+ ```
36
+ python3 bench_other.py --num-questions 80 --backend lightllm
37
+ ```
sglang/benchmark/mtbench/bench_other.py ADDED
@@ -0,0 +1,111 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+ import uuid
6
+ from concurrent.futures import ThreadPoolExecutor
7
+
8
+ from fastchat.model import get_conversation_template
9
+ from tqdm import tqdm
10
+
11
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
12
+
13
+
14
+ def load_questions(filename):
15
+ questions = []
16
+ with open(filename, "r") as fin:
17
+ for line in fin:
18
+ obj = json.loads(line)
19
+ questions.append(obj)
20
+ return questions
21
+
22
+
23
+ def write_answers(filename, model_id, questions, answers):
24
+ with open(os.path.expanduser(filename), "w") as fout:
25
+ for i in range(len(answers)):
26
+ ans_json = {
27
+ "question_id": questions[i]["question_id"],
28
+ "answer_id": uuid.uuid4().hex,
29
+ "model_id": model_id,
30
+ "choices": {
31
+ "index": 0,
32
+ "turns": [answers[i][0], answers[i][1]],
33
+ },
34
+ "tstamp": time.time(),
35
+ }
36
+ fout.write(json.dumps(ans_json) + "\n")
37
+
38
+
39
+ def main(args):
40
+ questions = load_questions(args.question_file)
41
+ questions = (questions * 10)[: args.num_questions]
42
+ max_tokens = 256
43
+ model_id = "llama-2-chat"
44
+
45
+ conv_main = get_conversation_template(model_id)
46
+
47
+ # Select backend
48
+ call_generate = get_call_generate(args)
49
+
50
+ answers = [None] * len(questions)
51
+
52
+ def get_answer(i):
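+ # Answer both MT-Bench turns sequentially, feeding each generated answer back into the conversation template.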
53
+ conv = conv_main.copy()
54
+ cur_answers = []
55
+ for j in range(2):
56
+ q = questions[i]["turns"][j]
57
+ conv.append_message(conv.roles[0], q)
58
+ conv.append_message(conv.roles[1], None)
59
+
60
+ prompt = conv.get_prompt()
61
+ output = call_generate(prompt, temperature=0, max_tokens=max_tokens).strip()
62
+
63
+ cur_answers.append(output)
64
+ conv.update_last_message(output)
65
+
66
+ answers[i] = cur_answers
67
+
68
+ # Run requests
69
+ tic = time.time()
70
+ if args.parallel == 1:
71
+ for i in tqdm(range(len(questions))):
72
+ get_answer(i)
73
+ else:
74
+ with ThreadPoolExecutor(args.parallel) as executor:
75
+ list(
76
+ tqdm(
77
+ executor.map(get_answer, list(range(len(questions)))),
78
+ total=len(questions),
79
+ )
80
+ )
81
+
82
+ latency = time.time() - tic
83
+
84
+ print(f"#questions: {len(questions)}, Latency: {latency:.2f}")
85
+
86
+ # Write results
87
+ answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
88
+ write_answers(answer_file, model_id, questions, answers)
89
+
90
+ with open(args.result_file, "a") as fout:
91
+ value = {
92
+ "task": "mtbench",
93
+ "backend": args.backend,
94
+ "num_gpus": 1,
95
+ "latency": round(latency, 3),
96
+ "num_requests": args.num_questions,
97
+ "other": {
98
+ "num_questions": args.num_questions,
99
+ "parallel": args.parallel,
100
+ },
101
+ }
102
+ fout.write(json.dumps(value) + "\n")
103
+
104
+
105
+ if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument("--question-file", type=str, default="question.jsonl")
108
+ parser.add_argument("--answer-file", type=str, default=None)
109
+ parser.add_argument("--num-questions", type=int, default=80)
110
+ args = add_common_other_args_and_parse(parser)
111
+ main(args)
sglang/benchmark/mtbench/bench_sglang.py ADDED
@@ -0,0 +1,99 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+ import uuid
6
+
7
+ import sglang as sgl
8
+ from sglang.test.test_utils import (
9
+ add_common_sglang_args_and_parse,
10
+ select_sglang_backend,
11
+ )
12
+
13
+
14
+ def load_questions(filename):
15
+ questions = []
16
+ with open(filename, "r") as fin:
17
+ for line in fin:
18
+ obj = json.loads(line)
19
+ questions.append(obj)
20
+ return questions
21
+
22
+
23
+ def write_answers(filename, model_id, questions, answers):
24
+ with open(os.path.expanduser(filename), "w") as fout:
25
+ for i in range(len(answers)):
26
+ ans_json = {
27
+ "question_id": questions[i]["question_id"],
28
+ "answer_id": uuid.uuid4().hex,
29
+ "model_id": model_id,
30
+ "choices": {
31
+ "index": 0,
32
+ "turns": [answers[i][0], answers[i][1]],
33
+ },
34
+ "tstamp": time.time(),
35
+ }
36
+ fout.write(json.dumps(ans_json) + "\n")
37
+
38
+
39
+ @sgl.function
40
+ def answer_mt_bench(s, question_1, question_2):
41
+ s += sgl.system()
42
+ s += sgl.user(question_1)
43
+ s += sgl.assistant(sgl.gen("answer_1"))
44
+ s += sgl.user(question_2)
45
+ s += sgl.assistant(sgl.gen("answer_2"))
46
+
47
+
48
+ def main(args):
49
+ # Construct prompts
50
+ questions = load_questions(args.question_file)[: args.num_questions]
51
+ arguments = [
52
+ {"question_1": q["turns"][0], "question_2": q["turns"][1]} for q in questions
53
+ ]
54
+
55
+ # Select backend
56
+ backend = select_sglang_backend(args)
57
+ sgl.set_default_backend(backend)
58
+
59
+ # Run requests
60
+ tic = time.time()
61
+ rets = answer_mt_bench.run_batch(
62
+ arguments,
63
+ temperature=0,
64
+ max_new_tokens=256,
65
+ num_threads=args.parallel,
66
+ progress_bar=True,
67
+ )
68
+ answers = [[s["answer_1"], s["answer_2"]] for s in rets]
69
+ latency = time.time() - tic
70
+
71
+ print(f"#questions: {len(questions)}, Latency: {latency:.2f}")
72
+
73
+ # Write results
74
+ model_id = backend.model_info["model_path"]
75
+ answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
76
+ write_answers(answer_file, model_id, questions, answers)
77
+
78
+ with open(args.result_file, "a") as fout:
79
+ value = {
80
+ "task": "mtbench",
81
+ "backend": args.backend,
82
+ "num_gpus": 1,
83
+ "latency": round(latency, 3),
84
+ "num_requests": args.num_questions,
85
+ "other": {
86
+ "num_questions": args.num_questions,
87
+ "parallel": args.parallel,
88
+ },
89
+ }
90
+ fout.write(json.dumps(value) + "\n")
91
+
92
+
93
+ if __name__ == "__main__":
94
+ parser = argparse.ArgumentParser()
95
+ parser.add_argument("--question-file", type=str, default="question.jsonl")
96
+ parser.add_argument("--answer-file", type=str, default=None)
97
+ parser.add_argument("--num-questions", type=int, default=80)
98
+ args = add_common_sglang_args_and_parse(parser)
99
+ main(args)
sglang/benchmark/multi_chain_reasoning/README.md ADDED
@@ -0,0 +1,49 @@
1
+ ## Download data
2
+ ```
3
+ wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
4
+ ```
5
+
6
+ ## Run benchmark
7
+
8
+ ### Benchmark sglang
9
+ ```
10
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --schedule-conservativeness 1.3
11
+ ```
12
+
13
+ ```
14
+ python3 bench_sglang.py --num-questions 64
15
+ python3 bench_sglang.py --num-questions 32 --parallel 1
16
+ ```
17
+
18
+
19
+ ### Benchmark vllm
20
+ ```
21
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
22
+ ```
23
+
24
+ ```
25
+ python3 bench_other.py --num-questions 64 --backend vllm
26
+ ```
27
+
28
+
29
+ ### Benchmark lightllm
30
+ ```
31
+ # A10G
32
+ python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
33
+ ```
34
+
35
+ ```
36
+ python3 bench_other.py --num-questions 64 --backend lightllm
37
+ ```
38
+
39
+
40
+ ### Benchmark guidance
41
+ ```
42
+ python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
43
+ ```
44
+
45
+ ### Benchmark lmql
46
+
47
+ ```
48
+ python3 bench_other.py --num-questions 64 --backend lmql --parallel 1
49
+ ```
sglang/benchmark/multi_chain_reasoning/bench_sglang.py ADDED
@@ -0,0 +1,140 @@
+ import argparse
+ import ast
+ import json
+ import re
+ import time
+
+ import numpy as np
+
+ from sglang.test.test_utils import (
+     add_common_sglang_args_and_parse,
+     select_sglang_backend,
+ )
+ from sglang.utils import dump_state_text, read_jsonl
+
+ INVALID = -9999999
+
+
+ def get_answer_value(answer_str):
+     answer_str = answer_str.replace(",", "")
+     numbers = re.findall(r"\d+", answer_str)
+     if len(numbers) < 1:
+         return INVALID
+     try:
+         return ast.literal_eval(numbers[-1])
+     except SyntaxError:
+         return INVALID
+
+
+ prompt_lib = [
+     "Let us think step by step.",
+     "Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
+     "It's important to proceed step by step, ensuring accuracy at each stage.",
+     "Take a deep breath and break this down.",
+     "A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
+     "I am extremely good at math.",
+ ]
+
+
+ def main(args):
+     lines = read_jsonl(args.data_path)
+
+     # Construct prompts
+     # k = args.num_shot
+     # few_shot_examples = get_few_shot_examples(lines, k)
+
+     questions = []
+     labels = []
+     for i in range(len(lines[: args.num_questions])):
+         questions.append(lines[i]["question"])
+         labels.append(get_answer_value(lines[i]["answer"]))
+     assert all(l != INVALID for l in labels)
+     arguments = [{"question": q} for q in questions]
+
+     num_chains = args.num_chains
+
+     #####################################
+     ######### SGL Program Begin #########
+     #####################################
+
+     import sglang as sgl
+
+     @sgl.function
+     def multi_chain_gsm8k(s, question):
+         s += "Question: " + question + "\n"
+         # s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
+         #     temperature=0)
+         # return
+
+         forks = s.fork(num_chains)
+         for i in range(num_chains):
+             forks[i] += (
+                 "Answer: "
+                 + prompt_lib[i % num_chains]
+                 + sgl.gen("chain", max_tokens=256, temperature=0.3, stop="Question")
+             )
+         forks.join()
+
+         s += "Answer: To answer this question, here are some possible solutions. "
+         s += "After considering all of them, I will do a majority vote.\n\n"
+         for i in range(num_chains):
+             s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
+         s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
+         s += sgl.gen("answer", max_tokens=16)
+
+     #####################################
+     ########## SGL Program End ##########
+     #####################################
+
+     # Select backend
+     backend = select_sglang_backend(args)
+
+     # Run requests
+     tic = time.time()
+     states = multi_chain_gsm8k.run_batch(
+         arguments,
+         temperature=0,
+         backend=backend,
+         num_threads=args.parallel,
+         progress_bar=True,
+     )
+     latency = time.time() - tic
+
+     preds = []
+     for i in range(len(states)):
+         preds.append(get_answer_value(states[i]["answer"]))
+
+     # Compute accuracy
+     acc = np.mean(np.array(preds) == np.array(labels))
+     invalid = np.mean(np.array(preds) == INVALID)
+     print(f"Latency: {latency:.3f}")
+     print(f"Invalid: {invalid:.3f}")
+     print(f"Accuracy: {acc:.3f}")
+
+     # Write results
+     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+
+     with open(args.result_file, "a") as fout:
+         value = {
+             "task": "multi_chain_gsm8k",
+             "backend": args.backend,
+             "num_gpus": 1,
+             "latency": round(latency, 3),
+             "accuracy": round(acc, 3),
+             "num_requests": args.num_questions,
+             "other": {
+                 "num_questions": args.num_questions,
+                 "parallel": args.parallel,
+             },
+         }
+         fout.write(json.dumps(value) + "\n")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--num-shot", type=int, default=0)
+     parser.add_argument("--num-chains", type=int, default=5)
+     parser.add_argument("--data-path", type=str, default="test.jsonl")
+     parser.add_argument("--num-questions", type=int, default=50)
+     args = add_common_sglang_args_and_parse(parser)
+     main(args)
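The SGL program above delegates the final majority vote to the model itself. As a sanity check, the vote can also be done programmatically. The sketch below is a standalone illustration of that aggregation step, assuming you have collected the per-chain completions as plain strings (for example by parsing the `Solution i:` blocks out of the dumped `tmp_output_<backend>.txt`); it is not part of the benchmark script, and the `chains` list is dummy data:

```python
# A minimal sketch of an explicit majority vote over reasoning chains.
import re
from collections import Counter

INVALID = -9999999


def last_int(text: str) -> int:
    # Same heuristic as get_answer_value: take the last integer in the string.
    nums = re.findall(r"\d+", text.replace(",", ""))
    return int(nums[-1]) if nums else INVALID


chains = [
    "Each box holds 6 eggs, so 3 boxes hold 18 eggs. The answer is 18.",
    "3 * 6 = 18, so the answer is 18.",
    "I think the answer is 16.",
]

votes = Counter(last_int(c) for c in chains if last_int(c) != INVALID)
answer, count = votes.most_common(1)[0]
print(f"majority answer: {answer} ({count}/{len(chains)} chains agree)")
```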
sglang/benchmark/multi_turn_chat/README.md ADDED
@@ -0,0 +1,66 @@
+ ### Benchmark sglang
+
+ Run Llama-7B
+
+ ```
+ python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
+ ```
+
+ Run Mixtral-8x7B
+ (If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.)
+
+ ```
+ python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
+ ```
+
+ Benchmark (short output)
+
+ ```
+ python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
+ ```
+
+ Benchmark (long output)
+
+ ```
+ python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
+ ```
+
+ ### Benchmark vLLM
+
+ Run Llama-7B
+
+ ```
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
+ ```
+
+ Run Mixtral-8x7B
+
+ ```
+ python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
+ ```
+
+ Benchmark (short output)
+
+ ```
+ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
+ ```
+
+ Benchmark (long output)
+
+ ```
+ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
+ ```
+
+ ### Benchmark guidance
+
+ Benchmark Llama-7B (short output)
+
+ ```
+ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
+ ```
+
+ Benchmark Llama-7B (long output)
+
+ ```
+ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long
+ ```
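Before kicking off a long run, it can help to confirm the launched sglang server actually answers requests. A minimal smoke-test sketch using the `sglang` client pointed at the endpoint started above; the port and prompt are examples, and this is not part of the benchmark scripts:

```python
# Quick smoke test for the server launched with `sglang.launch_server` above.
import sglang as sgl

# Point the client at the local server on port 30000 (adjust if you changed --port).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))


@sgl.function
def ping(s):
    s += "Say hello in one short sentence.\n"
    s += sgl.gen("reply", max_tokens=16)


state = ping.run()
print(state["reply"])
```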
sglang/benchmark/multi_turn_chat/bench_other.py ADDED
@@ -0,0 +1,93 @@
+ import json
+ import time
+ from argparse import ArgumentParser
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+
+ from data_gen import gen_arguments
+ from tqdm import tqdm
+ from vllm.transformers_utils.tokenizer import get_tokenizer
+
+ from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
+ from sglang.utils import dump_state_text
+
+
+ def multi_turns(generate, qas):
+     s = ""
+     for qa in qas:
+         s += qa["prompt"]
+         s += generate(s, max_tokens=qa["new_tokens"])
+
+     return s
+
+
+ def main(args):
+     print(args)
+
+     tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+
+     multi_qas = gen_arguments(args, tokenizer)
+
+     states = [None] * args.num_qa
+
+     call_generate = partial(get_call_generate(args), temperature=0)
+
+     def get_one_answer(i):
+         states[i] = multi_turns(generate=call_generate, **multi_qas[i])
+
+     tic = time.time()
+     if args.parallel == 1:
+         for i in tqdm(range(len(multi_qas))):
+             get_one_answer(i)
+     else:
+         with ThreadPoolExecutor(args.parallel) as executor:
+             rets = list(
+                 tqdm(
+                     executor.map(get_one_answer, list(range(len(multi_qas)))),
+                     total=len(multi_qas),
+                 )
+             )
+             for _ in rets:
+                 pass
+
+     latency = time.time() - tic
+
+     # Compute accuracy
+     print(f"Latency: {latency:.3f}")
+
+     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+
+     with open(args.result_file, "a") as fout:
+         value = {
+             "task": "multi_turn_chat",
+             "backend": args.backend,
+             "num_gpus": 1,
+             "latency": round(latency, 3),
+             "num_requests": args.num_qa,
+             "num_turns": args.turns,
+             "other": {
+                 "parallel": args.parallel,
+                 "output_mode": "long" if args.long else "short",
+             },
+         }
+         fout.write(json.dumps(value) + "\n")
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--turns", type=int, default=4)
+     parser.add_argument("--num-qa", type=int, default=20)
+     parser.add_argument("--min-len-q", type=int, default=256)
+     parser.add_argument("--max-len-q", type=int, default=512)
+     parser.add_argument("--min-len-a", type=int, default=4)
+     parser.add_argument("--max-len-a", type=int, default=8)
+     parser.add_argument("--tokenizer", type=str, required=True)
+     parser.add_argument("--trust-remote-code", action="store_true")
+     parser.add_argument("--long", action="store_true")
+     args = add_common_other_args_and_parse(parser)
+
+     if args.long:
+         args.min_len_a = 256
+         args.max_len_a = 512
+         args.num_qa = 20
+     main(args)
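`get_call_generate(args)` returns a backend-specific completion function; all `multi_turns` needs is a callable that takes the running prompt plus `max_tokens` (with `temperature` bound via `partial`) and returns the generated text. A minimal sketch of that contract with a dummy stand-in backend, useful for dry-running the loop without any server; the dummy function and sample turns are illustrative only:

```python
# A minimal sketch of the interface expected by multi_turns().
def dummy_generate(prompt: str, max_tokens: int = 16, temperature: float = 0.0) -> str:
    # A real implementation would call the serving backend here and return its completion.
    return f" [assistant reply of up to {max_tokens} tokens]\n"


def multi_turns(generate, qas):
    # Same accumulation loop as in bench_other.py above.
    s = ""
    for qa in qas:
        s += qa["prompt"]
        s += generate(s, max_tokens=qa["new_tokens"])
    return s


qas = [
    {"prompt": "User: What is the capital of France?\nAssistant:", "new_tokens": 8},
    {"prompt": "User: And of Italy?\nAssistant:", "new_tokens": 8},
]
print(multi_turns(dummy_generate, qas))
```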
sglang/benchmark/multi_turn_chat/bench_sglang.py ADDED
@@ -0,0 +1,79 @@
+ import json
+ import time
+ from argparse import ArgumentParser
+
+ from data_gen import gen_arguments
+ from vllm.transformers_utils.tokenizer import get_tokenizer
+
+ import sglang as sgl
+ from sglang.test.test_utils import (
+     add_common_sglang_args_and_parse,
+     select_sglang_backend,
+ )
+ from sglang.utils import dump_state_text
+
+
+ @sgl.function
+ def multi_turns(s, qas):
+     for qa in qas:
+         s += qa["prompt"]
+         s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
+
+
+ def main(args):
+     tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+
+     multi_qas = gen_arguments(args, tokenizer)
+
+     backend = select_sglang_backend(args)
+
+     tic = time.time()
+     states = multi_turns.run_batch(
+         multi_qas,
+         temperature=0,
+         backend=backend,
+         num_threads=args.parallel,
+         progress_bar=True,
+     )
+     latency = time.time() - tic
+
+     print(f"Latency: {latency:.3f}")
+
+     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+
+     with open(args.result_file, "a") as fout:
+         value = {
+             "task": "multi_turn_chat",
+             "backend": args.backend,
+             "num_gpus": 1,
+             "latency": round(latency, 3),
+             "num_requests": args.num_qa,
+             "num_turns": args.turns,
+             "other": {
+                 "parallel": args.parallel,
+                 "output_mode": "long" if args.long else "short",
+             },
+         }
+         fout.write(json.dumps(value) + "\n")
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--turns", type=int, default=4)
+     parser.add_argument("--num-qa", type=int, default=20)
+     parser.add_argument("--min-len-q", type=int, default=256)
+     parser.add_argument("--max-len-q", type=int, default=512)
+     parser.add_argument("--min-len-a", type=int, default=4)
+     parser.add_argument("--max-len-a", type=int, default=8)
+     parser.add_argument("--tokenizer", type=str, required=True)
+     parser.add_argument("--trust-remote-code", action="store_true")
+     parser.add_argument("--long", action="store_true")
+     args = add_common_sglang_args_and_parse(parser)
+
+     if args.long:
+         args.min_len_a = 256
+         args.max_len_a = 512
+         args.num_qa = 20
+
+     print(args)
+     main(args)
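Both multi-turn scripts consume `data_gen.gen_arguments(args, tokenizer)`, which is not part of this diff. From how `multi_turns` unpacks each element (`**multi_qas[i]` and the `qa["prompt"]` / `qa["new_tokens"]` fields), each element is expected to be a dict with a `qas` list of turns. A hypothetical stub with that shape, handy for exercising the harness without the real generator; everything beyond those field names is an assumption:

```python
# A hypothetical stand-in for data_gen.gen_arguments(); the real module is not
# included in this diff, so only the "qas", "prompt", and "new_tokens" fields
# (which the scripts above rely on) are grounded -- the rest is illustrative.
import random


def fake_gen_arguments(num_qa: int = 4, turns: int = 2,
                       min_len_a: int = 4, max_len_a: int = 8):
    multi_qas = []
    for i in range(num_qa):
        qas = []
        for t in range(turns):
            qas.append({
                "prompt": f"User {i}, turn {t}: please continue.\nAssistant:",
                "new_tokens": random.randint(min_len_a, max_len_a),
            })
        multi_qas.append({"qas": qas})
    return multi_qas


print(fake_gen_arguments()[0])
```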