Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LLaVA/.devcontainer/Dockerfile +53 -0
- LLaVA/.devcontainer/devcontainer.env +2 -0
- LLaVA/.devcontainer/devcontainer.json +71 -0
- LLaVA/.devcontainer/postCreateCommand.sh +45 -0
- LLaVA/docs/Evaluation.md +167 -0
- LLaVA/scripts/convert_sqa_to_llava_base_prompt.py +334 -0
- LLaVA/scripts/finetune_qlora.sh +50 -0
- LLaVA/scripts/pretrain.sh +46 -0
- LLaVA/scripts/zero2.json +23 -0
- sglang/.github/ISSUE_TEMPLATE/2-feature-request.yml +23 -0
- sglang/.github/workflows/close-inactive-issues.yml +96 -0
- sglang/.github/workflows/execute-notebook.yml +49 -0
- sglang/.github/workflows/lint.yml +22 -0
- sglang/.github/workflows/nightly-test.yml +34 -0
- sglang/.github/workflows/pr-test.yml +270 -0
- sglang/.github/workflows/release-docker-dev.yml +35 -0
- sglang/.github/workflows/release-docker.yml +64 -0
- sglang/.github/workflows/release-pypi-kernel.yml +41 -0
- sglang/.github/workflows/release-pypi.yml +31 -0
- sglang/3rdparty/amd/profiling/PROFILING.md +425 -0
- sglang/3rdparty/amd/profiling/client.sh +27 -0
- sglang/3rdparty/amd/profiling/install_rpd.sh +10 -0
- sglang/3rdparty/amd/profiling/loadTracer.sh +43 -0
- sglang/3rdparty/amd/profiling/rpd.patch +12 -0
- sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch +49 -0
- sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch +126 -0
- sglang/3rdparty/amd/profiling/server.sh +20 -0
- sglang/3rdparty/amd/tuning/TUNING.md +118 -0
- sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py +377 -0
- sglang/benchmark/blog_v0_2/405b_sglang.sh +24 -0
- sglang/benchmark/blog_v0_2/405b_trt.sh +17 -0
- sglang/benchmark/blog_v0_2/405b_vllm.sh +24 -0
- sglang/benchmark/dspy/README.md +51 -0
- sglang/benchmark/dspy/bench_dspy_intro.py +192 -0
- sglang/benchmark/gsm8k/README.md +47 -0
- sglang/benchmark/gsm8k/bench_other.py +151 -0
- sglang/benchmark/gsm8k/bench_sglang.py +141 -0
- sglang/benchmark/hellaswag/README.md +47 -0
- sglang/benchmark/hellaswag/bench_other.py +118 -0
- sglang/benchmark/lora/launch_server.py +47 -0
- sglang/benchmark/lora/lora_bench.py +484 -0
- sglang/benchmark/mmlu/README.md +59 -0
- sglang/benchmark/mtbench/README.md +37 -0
- sglang/benchmark/mtbench/bench_other.py +111 -0
- sglang/benchmark/mtbench/bench_sglang.py +99 -0
- sglang/benchmark/multi_chain_reasoning/README.md +49 -0
- sglang/benchmark/multi_chain_reasoning/bench_sglang.py +140 -0
- sglang/benchmark/multi_turn_chat/README.md +66 -0
- sglang/benchmark/multi_turn_chat/bench_other.py +93 -0
- sglang/benchmark/multi_turn_chat/bench_sglang.py +79 -0
LLaVA/.devcontainer/Dockerfile
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM mcr.microsoft.com/devcontainers/base:ubuntu-20.04
|
2 |
+
|
3 |
+
SHELL [ "bash", "-c" ]
|
4 |
+
|
5 |
+
# update apt and install packages
|
6 |
+
RUN apt update && \
|
7 |
+
apt install -yq \
|
8 |
+
ffmpeg \
|
9 |
+
dkms \
|
10 |
+
build-essential
|
11 |
+
|
12 |
+
# add user tools
|
13 |
+
RUN sudo apt install -yq \
|
14 |
+
jq \
|
15 |
+
jp \
|
16 |
+
tree \
|
17 |
+
tldr
|
18 |
+
|
19 |
+
# add git-lfs and install
|
20 |
+
RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \
|
21 |
+
sudo apt-get install -yq git-lfs && \
|
22 |
+
git lfs install
|
23 |
+
|
24 |
+
############################################
|
25 |
+
# Setup user
|
26 |
+
############################################
|
27 |
+
|
28 |
+
USER vscode
|
29 |
+
|
30 |
+
# install azcopy, a tool to copy to/from blob storage
|
31 |
+
# for more info: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-upload#upload-a-file
|
32 |
+
RUN cd /tmp && \
|
33 |
+
wget https://azcopyvnext.azureedge.net/release20230123/azcopy_linux_amd64_10.17.0.tar.gz && \
|
34 |
+
tar xvf azcopy_linux_amd64_10.17.0.tar.gz && \
|
35 |
+
mkdir -p ~/.local/bin && \
|
36 |
+
mv azcopy_linux_amd64_10.17.0/azcopy ~/.local/bin && \
|
37 |
+
chmod +x ~/.local/bin/azcopy && \
|
38 |
+
rm -rf azcopy_linux_amd64*
|
39 |
+
|
40 |
+
# Setup conda
|
41 |
+
RUN cd /tmp && \
|
42 |
+
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
43 |
+
bash ./Miniconda3-latest-Linux-x86_64.sh -b && \
|
44 |
+
rm ./Miniconda3-latest-Linux-x86_64.sh
|
45 |
+
|
46 |
+
# Install dotnet
|
47 |
+
RUN cd /tmp && \
|
48 |
+
wget https://dot.net/v1/dotnet-install.sh && \
|
49 |
+
chmod +x dotnet-install.sh && \
|
50 |
+
./dotnet-install.sh --channel 7.0 && \
|
51 |
+
./dotnet-install.sh --channel 3.1 && \
|
52 |
+
rm ./dotnet-install.sh
|
53 |
+
|
LLaVA/.devcontainer/devcontainer.env
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
SAMPLE_ENV_VAR1="Sample Value"
|
2 |
+
SAMPLE_ENV_VAR2=332431bf-68bf
|
LLaVA/.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "LLaVA",
|
3 |
+
"build": {
|
4 |
+
"dockerfile": "Dockerfile",
|
5 |
+
"context": "..",
|
6 |
+
"args": {}
|
7 |
+
},
|
8 |
+
"features": {
|
9 |
+
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
10 |
+
"ghcr.io/devcontainers/features/azure-cli:1": {},
|
11 |
+
"ghcr.io/azure/azure-dev/azd:0": {},
|
12 |
+
"ghcr.io/devcontainers/features/powershell:1": {},
|
13 |
+
"ghcr.io/devcontainers/features/common-utils:2": {},
|
14 |
+
"ghcr.io/devcontainers-contrib/features/zsh-plugins:0": {},
|
15 |
+
},
|
16 |
+
// "forwardPorts": [],
|
17 |
+
"postCreateCommand": "bash ./.devcontainer/postCreateCommand.sh",
|
18 |
+
"customizations": {
|
19 |
+
"vscode": {
|
20 |
+
"settings": {
|
21 |
+
"python.analysis.autoImportCompletions": true,
|
22 |
+
"python.analysis.autoImportUserSymbols": true,
|
23 |
+
"python.defaultInterpreterPath": "~/miniconda3/envs/llava/bin/python",
|
24 |
+
"python.formatting.provider": "yapf",
|
25 |
+
"python.linting.enabled": true,
|
26 |
+
"python.linting.flake8Enabled": true,
|
27 |
+
"isort.check": true,
|
28 |
+
"dev.containers.copyGitConfig": true,
|
29 |
+
"terminal.integrated.defaultProfile.linux": "zsh",
|
30 |
+
"terminal.integrated.profiles.linux": {
|
31 |
+
"zsh": {
|
32 |
+
"path": "/usr/bin/zsh"
|
33 |
+
},
|
34 |
+
}
|
35 |
+
},
|
36 |
+
"extensions": [
|
37 |
+
"aaron-bond.better-comments",
|
38 |
+
"eamodio.gitlens",
|
39 |
+
"EditorConfig.EditorConfig",
|
40 |
+
"foxundermoon.shell-format",
|
41 |
+
"GitHub.copilot-chat",
|
42 |
+
"GitHub.copilot-labs",
|
43 |
+
"GitHub.copilot",
|
44 |
+
"lehoanganh298.json-lines-viewer",
|
45 |
+
"mhutchie.git-graph",
|
46 |
+
"ms-azuretools.vscode-docker",
|
47 |
+
"ms-dotnettools.dotnet-interactive-vscode",
|
48 |
+
"ms-python.flake8",
|
49 |
+
"ms-python.isort",
|
50 |
+
"ms-python.python",
|
51 |
+
"ms-python.vscode-pylance",
|
52 |
+
"njpwerner.autodocstring",
|
53 |
+
"redhat.vscode-yaml",
|
54 |
+
"stkb.rewrap",
|
55 |
+
"yzhang.markdown-all-in-one",
|
56 |
+
]
|
57 |
+
}
|
58 |
+
},
|
59 |
+
"mounts": [],
|
60 |
+
"runArgs": [
|
61 |
+
"--gpus",
|
62 |
+
"all",
|
63 |
+
// "--ipc",
|
64 |
+
// "host",
|
65 |
+
"--ulimit",
|
66 |
+
"memlock=-1",
|
67 |
+
"--env-file",
|
68 |
+
".devcontainer/devcontainer.env"
|
69 |
+
],
|
70 |
+
// "remoteUser": "root"
|
71 |
+
}
|
LLaVA/.devcontainer/postCreateCommand.sh
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git config --global safe.directory '*'
|
2 |
+
git config --global core.editor "code --wait"
|
3 |
+
git config --global pager.branch false
|
4 |
+
|
5 |
+
# Set AZCOPY concurrency to auto
|
6 |
+
echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.zshrc
|
7 |
+
echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.bashrc
|
8 |
+
|
9 |
+
# Activate conda by default
|
10 |
+
echo ". /home/vscode/miniconda3/bin/activate" >> ~/.zshrc
|
11 |
+
echo ". /home/vscode/miniconda3/bin/activate" >> ~/.bashrc
|
12 |
+
|
13 |
+
# Use llava environment by default
|
14 |
+
echo "conda activate llava" >> ~/.zshrc
|
15 |
+
echo "conda activate llava" >> ~/.bashrc
|
16 |
+
|
17 |
+
# Add dotnet to PATH
|
18 |
+
echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.bashrc
|
19 |
+
echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.zshrc
|
20 |
+
|
21 |
+
# Create and activate llava environment
|
22 |
+
source /home/vscode/miniconda3/bin/activate
|
23 |
+
conda create -y -q -n llava python=3.10
|
24 |
+
conda activate llava
|
25 |
+
|
26 |
+
# Install Nvidia Cuda Compiler
|
27 |
+
conda install -y -c nvidia cuda-compiler
|
28 |
+
|
29 |
+
pip install pre-commit==3.0.2
|
30 |
+
|
31 |
+
# Install package locally
|
32 |
+
pip install --upgrade pip # enable PEP 660 support
|
33 |
+
pip install -e .
|
34 |
+
|
35 |
+
# Install additional packages for training
|
36 |
+
pip install -e ".[train]"
|
37 |
+
pip install flash-attn --no-build-isolation
|
38 |
+
|
39 |
+
# Download checkpoints to location outside of the repo
|
40 |
+
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b ~/llava-v1.5-7b
|
41 |
+
|
42 |
+
# Commented because it is unlikely for users to have enough local GPU memory to load the model
|
43 |
+
# git clone https://huggingface.co/liuhaotian/llava-v1.5-13b ~/llava-v1.5-13b
|
44 |
+
|
45 |
+
echo "postCreateCommand.sh COMPLETE!"
|
LLaVA/docs/Evaluation.md
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluation
|
2 |
+
|
3 |
+
In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs.
|
4 |
+
|
5 |
+
Currently, we mostly utilize the official toolkit or server for the evaluation.
|
6 |
+
|
7 |
+
## Evaluate on Custom Datasets
|
8 |
+
|
9 |
+
You can evaluate LLaVA on your custom datasets by converting your dataset to LLaVA's jsonl format, and evaluate using [`model_vqa.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa.py).
|
10 |
+
|
11 |
+
Below we provide a general guideline for evaluating datasets with some common formats.
|
12 |
+
|
13 |
+
1. Short-answer (e.g. VQAv2, MME).
|
14 |
+
|
15 |
+
```
|
16 |
+
<question>
|
17 |
+
Answer the question using a single word or phrase.
|
18 |
+
```
|
19 |
+
|
20 |
+
2. Option-only for multiple-choice (e.g. MMBench, SEED-Bench).
|
21 |
+
|
22 |
+
```
|
23 |
+
<question>
|
24 |
+
A. <option_1>
|
25 |
+
B. <option_2>
|
26 |
+
C. <option_3>
|
27 |
+
D. <option_4>
|
28 |
+
Answer with the option's letter from the given choices directly.
|
29 |
+
```
|
30 |
+
|
31 |
+
3. Natural QA (e.g. LLaVA-Bench, MM-Vet).
|
32 |
+
|
33 |
+
No postprocessing is needed.
|
34 |
+
|
35 |
+
## Scripts
|
36 |
+
|
37 |
+
Before preparing task-specific data, **you MUST first download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing)**. It contains custom annotations, scripts, and the prediction files with LLaVA v1.5. Extract to `./playground/data/eval`. This also provides a general structure for all datasets.
|
38 |
+
|
39 |
+
### VQAv2
|
40 |
+
|
41 |
+
1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`.
|
42 |
+
2. Multi-GPU inference.
|
43 |
+
```Shell
|
44 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/vqav2.sh
|
45 |
+
```
|
46 |
+
3. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/830/my-submission): `./playground/data/eval/vqav2/answers_upload`.
|
47 |
+
|
48 |
+
### GQA
|
49 |
+
|
50 |
+
1. Download the [data](https://cs.stanford.edu/people/dorarad/gqa/download.html) and [evaluation scripts](https://cs.stanford.edu/people/dorarad/gqa/evaluate.html) following the official instructions and put under `./playground/data/eval/gqa/data`. You may need to modify `eval.py` as [this](https://gist.github.com/haotian-liu/db6eddc2a984b4cbcc8a7f26fd523187) due to the missing assets in the GQA v1.2 release.
|
51 |
+
2. Multi-GPU inference.
|
52 |
+
```Shell
|
53 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/gqa.sh
|
54 |
+
```
|
55 |
+
|
56 |
+
### VisWiz
|
57 |
+
|
58 |
+
1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`.
|
59 |
+
2. Single-GPU inference.
|
60 |
+
```Shell
|
61 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/vizwiz.sh
|
62 |
+
```
|
63 |
+
3. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/2185/my-submission): `./playground/data/eval/vizwiz/answers_upload`.
|
64 |
+
|
65 |
+
### ScienceQA
|
66 |
+
|
67 |
+
1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA).
|
68 |
+
2. Single-GPU inference and evaluate.
|
69 |
+
```Shell
|
70 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh
|
71 |
+
```
|
72 |
+
|
73 |
+
### TextVQA
|
74 |
+
|
75 |
+
1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`.
|
76 |
+
2. Single-GPU inference and evaluate.
|
77 |
+
```Shell
|
78 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh
|
79 |
+
```
|
80 |
+
|
81 |
+
### POPE
|
82 |
+
|
83 |
+
1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put under `./playground/data/eval/pope`.
|
84 |
+
2. Single-GPU inference and evaluate.
|
85 |
+
```Shell
|
86 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/pope.sh
|
87 |
+
```
|
88 |
+
|
89 |
+
### MME
|
90 |
+
|
91 |
+
1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation).
|
92 |
+
2. Downloaded images to `MME_Benchmark_release_version`.
|
93 |
+
3. put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`.
|
94 |
+
4. Single-GPU inference and evaluate.
|
95 |
+
```Shell
|
96 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh
|
97 |
+
```
|
98 |
+
|
99 |
+
### MMBench
|
100 |
+
|
101 |
+
1. Download [`mmbench_dev_20230712.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_20230712.tsv) and put under `./playground/data/eval/mmbench`.
|
102 |
+
2. Single-GPU inference.
|
103 |
+
```Shell
|
104 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh
|
105 |
+
```
|
106 |
+
3. Submit the results to the [evaluation server](https://opencompass.org.cn/leaderboard-multimodal): `./playground/data/eval/mmbench/answers_upload/mmbench_dev_20230712`.
|
107 |
+
|
108 |
+
### MMBench-CN
|
109 |
+
|
110 |
+
1. Download [`mmbench_dev_cn_20231003.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_cn_20231003.tsv) and put under `./playground/data/eval/mmbench`.
|
111 |
+
2. Single-GPU inference.
|
112 |
+
```Shell
|
113 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh
|
114 |
+
```
|
115 |
+
3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`.
|
116 |
+
|
117 |
+
|
118 |
+
### SEED-Bench
|
119 |
+
|
120 |
+
1. Following the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and the videos. Put images under `./playground/data/eval/seed_bench/SEED-Bench-image`.
|
121 |
+
2. Extract the video frame in the middle from the downloaded videos, and put them under `./playground/data/eval/seed_bench/SEED-Bench-video-image`. We provide our script `extract_video_frames.py` modified from the official one.
|
122 |
+
3. Multiple-GPU inference and evaluate.
|
123 |
+
```Shell
|
124 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/seed.sh
|
125 |
+
```
|
126 |
+
4. Optionally, submit the results to the leaderboard: `./playground/data/eval/seed_bench/answers_upload` using the official jupyter notebook.
|
127 |
+
|
128 |
+
### LLaVA-Bench-in-the-Wild
|
129 |
+
|
130 |
+
1. Extract contents of [`llava-bench-in-the-wild`](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to `./playground/data/eval/llava-bench-in-the-wild`.
|
131 |
+
2. Single-GPU inference and evaluate.
|
132 |
+
```Shell
|
133 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh
|
134 |
+
```
|
135 |
+
|
136 |
+
### MM-Vet
|
137 |
+
|
138 |
+
1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`.
|
139 |
+
2. Single-GPU inference.
|
140 |
+
```Shell
|
141 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh
|
142 |
+
```
|
143 |
+
3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook.
|
144 |
+
|
145 |
+
## More Benchmarks
|
146 |
+
|
147 |
+
Below are awesome benchmarks for multimodal understanding from the research community, that are not initially included in the LLaVA-1.5 release.
|
148 |
+
|
149 |
+
### Q-Bench
|
150 |
+
|
151 |
+
1. Download [`llvisionqa_dev.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_dev.json) (for `dev`-subset) and [`llvisionqa_test.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_test.json) (for `test`-subset). Put them under `./playground/data/eval/qbench`.
|
152 |
+
2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llviqionqa`.
|
153 |
+
3. Single-GPU inference (change `dev` to `test` for evaluation on test set).
|
154 |
+
```Shell
|
155 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench.sh dev
|
156 |
+
```
|
157 |
+
4. Submit the results by instruction [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_dev_answers.jsonl`.
|
158 |
+
|
159 |
+
### Chinese-Q-Bench
|
160 |
+
|
161 |
+
1. Download [`质衡-问答-验证集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E9%AA%8C%E8%AF%81%E9%9B%86.json) (for `dev`-subset) and [`质衡-问答-测试集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E6%B5%8B%E8%AF%95%E9%9B%86.json) (for `test`-subset). Put them under `./playground/data/eval/qbench`.
|
162 |
+
2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llviqionqa`.
|
163 |
+
3. Single-GPU inference (change `dev` to `test` for evaluation on test set).
|
164 |
+
```Shell
|
165 |
+
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench_zh.sh dev
|
166 |
+
```
|
167 |
+
4. Submit the results by instruction [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_zh_dev_answers.jsonl`.
|
LLaVA/scripts/convert_sqa_to_llava_base_prompt.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_question_text(problem):
|
2 |
+
question = problem['question']
|
3 |
+
return question
|
4 |
+
|
5 |
+
|
6 |
+
def get_context_text(problem, use_caption):
|
7 |
+
txt_context = problem['hint']
|
8 |
+
img_context = problem['caption'] if use_caption else ""
|
9 |
+
context = " ".join([txt_context, img_context]).strip()
|
10 |
+
if context == "":
|
11 |
+
context = "N/A"
|
12 |
+
return context
|
13 |
+
|
14 |
+
|
15 |
+
def get_choice_text(probelm, options):
|
16 |
+
choices = probelm['choices']
|
17 |
+
choice_list = []
|
18 |
+
for i, c in enumerate(choices):
|
19 |
+
choice_list.append("({}) {}".format(options[i], c))
|
20 |
+
choice_txt = " ".join(choice_list)
|
21 |
+
#print(choice_txt)
|
22 |
+
return choice_txt
|
23 |
+
|
24 |
+
|
25 |
+
def get_answer(problem, options):
|
26 |
+
return options[problem['answer']]
|
27 |
+
|
28 |
+
|
29 |
+
def get_lecture_text(problem):
|
30 |
+
# \\n: GPT-3 can generate the lecture with more tokens.
|
31 |
+
lecture = problem['lecture'].replace("\n", "\\n")
|
32 |
+
return lecture
|
33 |
+
|
34 |
+
|
35 |
+
def get_solution_text(problem):
|
36 |
+
# \\n: GPT-3 can generate the solution with more tokens
|
37 |
+
solution = problem['solution'].replace("\n", "\\n")
|
38 |
+
return solution
|
39 |
+
|
40 |
+
|
41 |
+
def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
|
42 |
+
|
43 |
+
input_format, output_format = format.split("-")
|
44 |
+
|
45 |
+
## Inputs
|
46 |
+
if input_format == "CQM":
|
47 |
+
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
|
48 |
+
elif input_format == "QCM":
|
49 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
|
50 |
+
# upper bound experiment
|
51 |
+
elif input_format == "QCML":
|
52 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
|
53 |
+
elif input_format == "QCME":
|
54 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
|
55 |
+
elif input_format == "QCMLE":
|
56 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
|
57 |
+
|
58 |
+
elif input_format == "QCLM":
|
59 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
|
60 |
+
elif input_format == "QCEM":
|
61 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
|
62 |
+
elif input_format == "QCLEM":
|
63 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
|
64 |
+
|
65 |
+
# Outputs
|
66 |
+
if test_example:
|
67 |
+
output = "Answer:"
|
68 |
+
elif output_format == 'A':
|
69 |
+
output = f"Answer: The answer is {answer}."
|
70 |
+
|
71 |
+
elif output_format == 'AL':
|
72 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
|
73 |
+
elif output_format == 'AE':
|
74 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
|
75 |
+
elif output_format == 'ALE':
|
76 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
|
77 |
+
elif output_format == 'AEL':
|
78 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
|
79 |
+
|
80 |
+
elif output_format == 'LA':
|
81 |
+
output = f"Answer: {lecture} The answer is {answer}."
|
82 |
+
elif output_format == 'EA':
|
83 |
+
output = f"Answer: {solution} The answer is {answer}."
|
84 |
+
elif output_format == 'LEA':
|
85 |
+
output = f"Answer: {lecture} {solution} The answer is {answer}."
|
86 |
+
elif output_format == 'ELA':
|
87 |
+
output = f"Answer: {solution} {lecture} The answer is {answer}."
|
88 |
+
elif output_format == 'LEPA':
|
89 |
+
output = ''
|
90 |
+
if len(lecture.strip()) > 0:
|
91 |
+
output += f"LECTURE: {lecture}\n"
|
92 |
+
if len(solution.strip()) > 0:
|
93 |
+
output += f"SOLUTION: {solution}\n"
|
94 |
+
output += '###\n'
|
95 |
+
output += f"ANSWER: {answer}."
|
96 |
+
|
97 |
+
input = input.replace(" ", " ").strip()
|
98 |
+
output = output.replace(" ", " ").strip()
|
99 |
+
if input.endswith("BECAUSE:"):
|
100 |
+
input = input.replace("BECAUSE:", "").strip()
|
101 |
+
if output.endswith("BECAUSE:"):
|
102 |
+
output = output.replace("BECAUSE:", "").strip()
|
103 |
+
return input, output
|
104 |
+
|
105 |
+
|
106 |
+
def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
|
107 |
+
|
108 |
+
input_format, output_format = format.split("-")
|
109 |
+
|
110 |
+
## Inputs
|
111 |
+
if input_format == "CQM":
|
112 |
+
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
|
113 |
+
elif input_format == "QCM":
|
114 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
|
115 |
+
# upper bound experiment
|
116 |
+
elif input_format == "QCML":
|
117 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
|
118 |
+
elif input_format == "QCME":
|
119 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
|
120 |
+
elif input_format == "QCMLE":
|
121 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
|
122 |
+
|
123 |
+
elif input_format == "QCLM":
|
124 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
|
125 |
+
elif input_format == "QCEM":
|
126 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
|
127 |
+
elif input_format == "QCLEM":
|
128 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
|
129 |
+
|
130 |
+
# Outputs
|
131 |
+
if test_example:
|
132 |
+
output = "Answer:"
|
133 |
+
elif output_format == 'A':
|
134 |
+
output = f"Answer: The answer is {answer}."
|
135 |
+
|
136 |
+
elif output_format == 'AL':
|
137 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
|
138 |
+
elif output_format == 'AE':
|
139 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
|
140 |
+
elif output_format == 'ALE':
|
141 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
|
142 |
+
elif output_format == 'AEL':
|
143 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
|
144 |
+
|
145 |
+
elif output_format == 'LA':
|
146 |
+
output = f"Answer: {lecture} The answer is {answer}."
|
147 |
+
elif output_format == 'EA':
|
148 |
+
output = f"Answer: {solution} The answer is {answer}."
|
149 |
+
elif output_format == 'LEA':
|
150 |
+
output = f"Answer: {lecture} {solution} The answer is {answer}."
|
151 |
+
elif output_format == 'ELA':
|
152 |
+
output = f"Answer: {solution} {lecture} The answer is {answer}."
|
153 |
+
|
154 |
+
text = input + output
|
155 |
+
text = text.replace(" ", " ").strip()
|
156 |
+
if text.endswith("BECAUSE:"):
|
157 |
+
text = text.replace("BECAUSE:", "").strip()
|
158 |
+
return text
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
|
163 |
+
|
164 |
+
input_format, output_format = format.split("-")
|
165 |
+
|
166 |
+
## Inputs
|
167 |
+
if input_format == "CQM":
|
168 |
+
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
|
169 |
+
elif input_format == "QCM":
|
170 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
|
171 |
+
# upper bound experiment
|
172 |
+
elif input_format == "QCML":
|
173 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
|
174 |
+
elif input_format == "QCME":
|
175 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
|
176 |
+
elif input_format == "QCMLE":
|
177 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
|
178 |
+
|
179 |
+
elif input_format == "QCLM":
|
180 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
|
181 |
+
elif input_format == "QCEM":
|
182 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
|
183 |
+
elif input_format == "QCLEM":
|
184 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
|
185 |
+
|
186 |
+
# Outputs
|
187 |
+
if test_example:
|
188 |
+
output = "Answer:"
|
189 |
+
elif output_format == 'A':
|
190 |
+
output = f"Answer: The answer is {answer}."
|
191 |
+
|
192 |
+
elif output_format == 'AL':
|
193 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
|
194 |
+
elif output_format == 'AE':
|
195 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
|
196 |
+
elif output_format == 'ALE':
|
197 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
|
198 |
+
elif output_format == 'AEL':
|
199 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
|
200 |
+
|
201 |
+
elif output_format == 'LA':
|
202 |
+
output = f"Answer: {lecture} The answer is {answer}."
|
203 |
+
elif output_format == 'EA':
|
204 |
+
output = f"Answer: {solution} The answer is {answer}."
|
205 |
+
elif output_format == 'LEA':
|
206 |
+
output = f"Answer: {lecture} {solution} The answer is {answer}."
|
207 |
+
elif output_format == 'ELA':
|
208 |
+
output = f"Answer: {solution} {lecture} The answer is {answer}."
|
209 |
+
|
210 |
+
input = input.replace(" ", " ").strip()
|
211 |
+
output = output.replace(" ", " ").strip()
|
212 |
+
if output.endswith("BECAUSE:"):
|
213 |
+
output = output.replace("BECAUSE:", "").strip()
|
214 |
+
|
215 |
+
user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
|
216 |
+
assistant_prompt = {"role": "assistant", "content": f"{output}"}
|
217 |
+
|
218 |
+
return user_prompt, assistant_prompt
|
219 |
+
|
220 |
+
|
221 |
+
def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False):
|
222 |
+
examples = {}
|
223 |
+
|
224 |
+
for qid in shot_qids:
|
225 |
+
question = get_question_text(problems[qid])
|
226 |
+
context = get_context_text(problems[qid], use_caption)
|
227 |
+
choice = get_choice_text(problems[qid], options)
|
228 |
+
answer = get_answer(problems[qid], options)
|
229 |
+
lecture = get_lecture_text(problems[qid]).replace('\\n', '\n')
|
230 |
+
solution = get_solution_text(problems[qid]).replace('\\n', '\n')
|
231 |
+
|
232 |
+
train_example = create_one_example_chatbot(prompt_format,
|
233 |
+
question,
|
234 |
+
context,
|
235 |
+
choice,
|
236 |
+
answer,
|
237 |
+
lecture,
|
238 |
+
solution,
|
239 |
+
test_example=is_test)
|
240 |
+
examples[qid] = train_example
|
241 |
+
return examples
|
242 |
+
|
243 |
+
|
244 |
+
def build_prompt(problems, shot_qids, test_qid, args):
|
245 |
+
|
246 |
+
examples = []
|
247 |
+
|
248 |
+
# n-shot training examples
|
249 |
+
for qid in shot_qids:
|
250 |
+
question = get_question_text(problems[qid])
|
251 |
+
context = get_context_text(problems[qid], args.use_caption)
|
252 |
+
choice = get_choice_text(problems[qid], args.options)
|
253 |
+
answer = get_answer(problems[qid], args.options)
|
254 |
+
lecture = get_lecture_text(problems[qid])
|
255 |
+
solution = get_solution_text(problems[qid])
|
256 |
+
|
257 |
+
train_example = create_one_example(args.prompt_format,
|
258 |
+
question,
|
259 |
+
context,
|
260 |
+
choice,
|
261 |
+
answer,
|
262 |
+
lecture,
|
263 |
+
solution,
|
264 |
+
test_example=False)
|
265 |
+
examples.append(train_example)
|
266 |
+
|
267 |
+
# test example
|
268 |
+
question = get_question_text(problems[test_qid])
|
269 |
+
context = get_context_text(problems[test_qid], args.use_caption)
|
270 |
+
choice = get_choice_text(problems[test_qid], args.options)
|
271 |
+
answer = get_answer(problems[test_qid], args.options)
|
272 |
+
lecture = get_lecture_text(problems[test_qid])
|
273 |
+
solution = get_solution_text(problems[test_qid])
|
274 |
+
|
275 |
+
test_example = create_one_example(args.prompt_format,
|
276 |
+
question,
|
277 |
+
context,
|
278 |
+
choice,
|
279 |
+
answer,
|
280 |
+
lecture,
|
281 |
+
solution,
|
282 |
+
test_example=True)
|
283 |
+
examples.append(test_example)
|
284 |
+
|
285 |
+
# create the prompt input
|
286 |
+
prompt_input = '\n\n'.join(examples)
|
287 |
+
|
288 |
+
return prompt_input
|
289 |
+
|
290 |
+
|
291 |
+
def build_prompt_gpt4(problems, shot_qids, test_qid, args):
|
292 |
+
|
293 |
+
prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
|
294 |
+
|
295 |
+
# n-shot training examples
|
296 |
+
for qid in shot_qids:
|
297 |
+
question = get_question_text(problems[qid])
|
298 |
+
context = get_context_text(problems[qid], args.use_caption)
|
299 |
+
choice = get_choice_text(problems[qid], args.options)
|
300 |
+
answer = get_answer(problems[qid], args.options)
|
301 |
+
lecture = get_lecture_text(problems[qid])
|
302 |
+
solution = get_solution_text(problems[qid])
|
303 |
+
|
304 |
+
user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
|
305 |
+
question,
|
306 |
+
context,
|
307 |
+
choice,
|
308 |
+
answer,
|
309 |
+
lecture,
|
310 |
+
solution,
|
311 |
+
test_example=False)
|
312 |
+
prompt_array.append(user_prompt)
|
313 |
+
prompt_array.append(assistant_prompt)
|
314 |
+
|
315 |
+
# test example
|
316 |
+
question = get_question_text(problems[test_qid])
|
317 |
+
context = get_context_text(problems[test_qid], args.use_caption)
|
318 |
+
choice = get_choice_text(problems[test_qid], args.options)
|
319 |
+
answer = get_answer(problems[test_qid], args.options)
|
320 |
+
lecture = get_lecture_text(problems[test_qid])
|
321 |
+
solution = get_solution_text(problems[test_qid])
|
322 |
+
|
323 |
+
user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
|
324 |
+
question,
|
325 |
+
context,
|
326 |
+
choice,
|
327 |
+
answer,
|
328 |
+
lecture,
|
329 |
+
solution,
|
330 |
+
test_example=True)
|
331 |
+
prompt_array.append(user_prompt)
|
332 |
+
prompt_array.append(assistant_prompt)
|
333 |
+
|
334 |
+
return prompt_array
|
LLaVA/scripts/finetune_qlora.sh
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
|
4 |
+
|
5 |
+
# Uncomment and set the following variables correspondingly to run this script:
|
6 |
+
|
7 |
+
################## VICUNA ##################
|
8 |
+
# PROMPT_VERSION=v1
|
9 |
+
# MODEL_VERSION="vicuna-v1-3-7b"
|
10 |
+
################## VICUNA ##################
|
11 |
+
|
12 |
+
################## LLaMA-2 ##################
|
13 |
+
# PROMPT_VERSION="llava_llama_2"
|
14 |
+
# MODEL_VERSION="llama-2-7b-chat"
|
15 |
+
################## LLaMA-2 ##################
|
16 |
+
|
17 |
+
deepspeed llava/train/train_mem.py \
|
18 |
+
--deepspeed ./scripts/zero2.json \
|
19 |
+
--lora_enable True \
|
20 |
+
--bits 4 \
|
21 |
+
--model_name_or_path ./checkpoints/$MODEL_VERSION \
|
22 |
+
--version $PROMPT_VERSION \
|
23 |
+
--data_path ./playground/data/llava_instruct_80k.json \
|
24 |
+
--image_folder /path/to/coco/train2017 \
|
25 |
+
--vision_tower openai/clip-vit-large-patch14 \
|
26 |
+
--pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
|
27 |
+
--mm_vision_select_layer -2 \
|
28 |
+
--mm_use_im_start_end False \
|
29 |
+
--mm_use_im_patch_token False \
|
30 |
+
--bf16 True \
|
31 |
+
--output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
|
32 |
+
--num_train_epochs 1 \
|
33 |
+
--per_device_train_batch_size 16 \
|
34 |
+
--per_device_eval_batch_size 4 \
|
35 |
+
--gradient_accumulation_steps 1 \
|
36 |
+
--evaluation_strategy "no" \
|
37 |
+
--save_strategy "steps" \
|
38 |
+
--save_steps 50000 \
|
39 |
+
--save_total_limit 1 \
|
40 |
+
--learning_rate 2e-5 \
|
41 |
+
--weight_decay 0. \
|
42 |
+
--warmup_ratio 0.03 \
|
43 |
+
--lr_scheduler_type "cosine" \
|
44 |
+
--logging_steps 1 \
|
45 |
+
--tf32 True \
|
46 |
+
--model_max_length 2048 \
|
47 |
+
--gradient_checkpointing True \
|
48 |
+
--lazy_preprocess True \
|
49 |
+
--dataloader_num_workers 4 \
|
50 |
+
--report_to wandb
|
LLaVA/scripts/pretrain.sh
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
|
4 |
+
|
5 |
+
# Uncomment and set the following variables correspondingly to run this script:
|
6 |
+
|
7 |
+
# MODEL_VERSION=vicuna-v1-3-7b
|
8 |
+
# MODEL_VERSION=llama-2-7b-chat
|
9 |
+
|
10 |
+
########### DO NOT CHANGE ###########
|
11 |
+
########### USE THIS FOR BOTH ###########
|
12 |
+
PROMPT_VERSION=plain
|
13 |
+
########### DO NOT CHANGE ###########
|
14 |
+
|
15 |
+
deepspeed llava/train/train_mem.py \
|
16 |
+
--deepspeed ./scripts/zero2.json \
|
17 |
+
--model_name_or_path ./checkpoints/$MODEL_VERSION \
|
18 |
+
--version $PROMPT_VERSION \
|
19 |
+
--data_path /path/to/pretrain_data.json \
|
20 |
+
--image_folder /path/to/images \
|
21 |
+
--vision_tower openai/clip-vit-large-patch14 \
|
22 |
+
--tune_mm_mlp_adapter True \
|
23 |
+
--mm_vision_select_layer -2 \
|
24 |
+
--mm_use_im_start_end False \
|
25 |
+
--mm_use_im_patch_token False \
|
26 |
+
--bf16 True \
|
27 |
+
--output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
|
28 |
+
--num_train_epochs 1 \
|
29 |
+
--per_device_train_batch_size 16 \
|
30 |
+
--per_device_eval_batch_size 4 \
|
31 |
+
--gradient_accumulation_steps 1 \
|
32 |
+
--evaluation_strategy "no" \
|
33 |
+
--save_strategy "steps" \
|
34 |
+
--save_steps 24000 \
|
35 |
+
--save_total_limit 1 \
|
36 |
+
--learning_rate 2e-3 \
|
37 |
+
--weight_decay 0. \
|
38 |
+
--warmup_ratio 0.03 \
|
39 |
+
--lr_scheduler_type "cosine" \
|
40 |
+
--logging_steps 1 \
|
41 |
+
--tf32 True \
|
42 |
+
--model_max_length 2048 \
|
43 |
+
--gradient_checkpointing True \
|
44 |
+
--dataloader_num_workers 4 \
|
45 |
+
--lazy_preprocess True \
|
46 |
+
--report_to wandb
|
LLaVA/scripts/zero2.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"fp16": {
|
3 |
+
"enabled": "auto",
|
4 |
+
"loss_scale": 0,
|
5 |
+
"loss_scale_window": 1000,
|
6 |
+
"initial_scale_power": 16,
|
7 |
+
"hysteresis": 2,
|
8 |
+
"min_loss_scale": 1
|
9 |
+
},
|
10 |
+
"bf16": {
|
11 |
+
"enabled": "auto"
|
12 |
+
},
|
13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
14 |
+
"train_batch_size": "auto",
|
15 |
+
"gradient_accumulation_steps": "auto",
|
16 |
+
"zero_optimization": {
|
17 |
+
"stage": 2,
|
18 |
+
"overlap_comm": true,
|
19 |
+
"contiguous_gradients": true,
|
20 |
+
"sub_group_size": 1e9,
|
21 |
+
"reduce_bucket_size": "auto"
|
22 |
+
}
|
23 |
+
}
|
sglang/.github/ISSUE_TEMPLATE/2-feature-request.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: 🚀 Feature request
|
2 |
+
description: Suggest an idea for this project
|
3 |
+
title: "[Feature] "
|
4 |
+
|
5 |
+
body:
|
6 |
+
- type: checkboxes
|
7 |
+
attributes:
|
8 |
+
label: Checklist
|
9 |
+
options:
|
10 |
+
- label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
|
11 |
+
- label: 2. Please use English, otherwise it will be closed.
|
12 |
+
- type: textarea
|
13 |
+
attributes:
|
14 |
+
label: Motivation
|
15 |
+
description: |
|
16 |
+
A clear and concise description of the motivation of the feature.
|
17 |
+
validations:
|
18 |
+
required: true
|
19 |
+
- type: textarea
|
20 |
+
attributes:
|
21 |
+
label: Related resources
|
22 |
+
description: |
|
23 |
+
If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
|
sglang/.github/workflows/close-inactive-issues.yml
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Close Inactive Issues
|
2 |
+
|
3 |
+
on:
|
4 |
+
schedule:
|
5 |
+
- cron: '0 0 * * *'
|
6 |
+
workflow_dispatch:
|
7 |
+
|
8 |
+
permissions:
|
9 |
+
issues: write
|
10 |
+
contents: read
|
11 |
+
|
12 |
+
jobs:
|
13 |
+
close-inactive-issues:
|
14 |
+
if: github.repository == 'sgl-project/sglang'
|
15 |
+
runs-on: ubuntu-latest
|
16 |
+
steps:
|
17 |
+
- name: Check and close inactive issues
|
18 |
+
uses: actions/github-script@v6
|
19 |
+
with:
|
20 |
+
github-token: ${{secrets.GITHUB_TOKEN}}
|
21 |
+
script: |
|
22 |
+
const sixtyDaysAgo = new Date(Date.now() - 60 * 24 * 60 * 60 * 1000);
|
23 |
+
|
24 |
+
const [owner, repo] = process.env.GITHUB_REPOSITORY.split('/');
|
25 |
+
console.log(`Owner: ${owner}, Repo: ${repo}`);
|
26 |
+
|
27 |
+
async function fetchIssues(page = 1) {
|
28 |
+
console.log(`Fetching issues for ${owner}/${repo}, page ${page}`);
|
29 |
+
return await github.rest.issues.listForRepo({
|
30 |
+
owner,
|
31 |
+
repo,
|
32 |
+
state: 'open',
|
33 |
+
sort: 'updated',
|
34 |
+
direction: 'asc',
|
35 |
+
per_page: 100,
|
36 |
+
page: page
|
37 |
+
});
|
38 |
+
}
|
39 |
+
|
40 |
+
async function processIssues() {
|
41 |
+
console.log('Starting to process issues');
|
42 |
+
console.log(`Repository: ${owner}/${repo}`);
|
43 |
+
|
44 |
+
let page = 1;
|
45 |
+
let hasMoreIssues = true;
|
46 |
+
while (hasMoreIssues) {
|
47 |
+
try {
|
48 |
+
const issues = await fetchIssues(page);
|
49 |
+
console.log(`Fetched ${issues.data.length} issues on page ${page}`);
|
50 |
+
|
51 |
+
if (issues.data.length === 0) {
|
52 |
+
hasMoreIssues = false;
|
53 |
+
break;
|
54 |
+
}
|
55 |
+
|
56 |
+
for (const issue of issues.data) {
|
57 |
+
// Skip if the issue has 'good first issue' label
|
58 |
+
if (issue.labels.some(label => label.name === 'good first issue')) {
|
59 |
+
console.log(`Skipping issue #${issue.number} as it's marked as 'good first issue'`);
|
60 |
+
continue;
|
61 |
+
}
|
62 |
+
if (new Date(issue.updated_at) < sixtyDaysAgo) {
|
63 |
+
try {
|
64 |
+
await github.rest.issues.update({
|
65 |
+
owner,
|
66 |
+
repo,
|
67 |
+
issue_number: issue.number,
|
68 |
+
state: 'closed',
|
69 |
+
labels: [...issue.labels.map(l => l.name), 'inactive']
|
70 |
+
});
|
71 |
+
await github.rest.issues.createComment({
|
72 |
+
owner,
|
73 |
+
repo,
|
74 |
+
issue_number: issue.number,
|
75 |
+
body: 'This issue has been automatically closed due to inactivity. Please feel free to reopen it if needed.'
|
76 |
+
});
|
77 |
+
console.log(`Closed issue #${issue.number} due to inactivity.`);
|
78 |
+
} catch (error) {
|
79 |
+
console.error(`Failed to close issue #${issue.number}: ${error.message}`);
|
80 |
+
}
|
81 |
+
} else {
|
82 |
+
console.log(`Issue #${issue.number} is still active. Stopping processing.`);
|
83 |
+
hasMoreIssues = false;
|
84 |
+
break;
|
85 |
+
}
|
86 |
+
}
|
87 |
+
page += 1;
|
88 |
+
} catch (error) {
|
89 |
+
console.error(`Error fetching issues on page ${page}: ${error.message}`);
|
90 |
+
hasMoreIssues = false;
|
91 |
+
}
|
92 |
+
}
|
93 |
+
console.log('Finished processing issues');
|
94 |
+
}
|
95 |
+
|
96 |
+
await processIssues();
|
sglang/.github/workflows/execute-notebook.yml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Execute Notebooks
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [ main ]
|
6 |
+
paths:
|
7 |
+
- "python/sglang/**"
|
8 |
+
- "docs/**"
|
9 |
+
pull_request:
|
10 |
+
branches: [ main ]
|
11 |
+
paths:
|
12 |
+
- "python/sglang/**"
|
13 |
+
- "docs/**"
|
14 |
+
workflow_dispatch:
|
15 |
+
|
16 |
+
|
17 |
+
concurrency:
|
18 |
+
group: execute-notebook-${{ github.ref }}
|
19 |
+
cancel-in-progress: true
|
20 |
+
|
21 |
+
|
22 |
+
jobs:
|
23 |
+
run-all-notebooks:
|
24 |
+
runs-on: 1-gpu-runner
|
25 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
26 |
+
steps:
|
27 |
+
- name: Checkout code
|
28 |
+
uses: actions/checkout@v3
|
29 |
+
|
30 |
+
- name: Set up Python
|
31 |
+
uses: actions/setup-python@v4
|
32 |
+
with:
|
33 |
+
python-version: '3.9'
|
34 |
+
|
35 |
+
- name: Install dependencies
|
36 |
+
run: |
|
37 |
+
bash scripts/ci_install_dependency.sh
|
38 |
+
pip install -r docs/requirements.txt
|
39 |
+
|
40 |
+
- name: Setup Jupyter Kernel
|
41 |
+
run: |
|
42 |
+
python -m ipykernel install --user --name python3 --display-name "Python 3"
|
43 |
+
|
44 |
+
- name: Execute notebooks
|
45 |
+
timeout-minutes: 30
|
46 |
+
run: |
|
47 |
+
cd docs
|
48 |
+
make clean
|
49 |
+
make compile
|
sglang/.github/workflows/lint.yml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Lint
|
2 |
+
|
3 |
+
on: [pull_request]
|
4 |
+
|
5 |
+
jobs:
|
6 |
+
lint:
|
7 |
+
runs-on: ubuntu-latest
|
8 |
+
steps:
|
9 |
+
- uses: actions/checkout@v2
|
10 |
+
|
11 |
+
- name: Set up Python
|
12 |
+
uses: actions/setup-python@v4
|
13 |
+
with:
|
14 |
+
python-version: '3.9'
|
15 |
+
|
16 |
+
- name: Install pre-commit hook
|
17 |
+
run: |
|
18 |
+
python -m pip install pre-commit
|
19 |
+
pre-commit install
|
20 |
+
|
21 |
+
- name: Linting
|
22 |
+
run: pre-commit run --all-files
|
sglang/.github/workflows/nightly-test.yml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Nightly Test
|
2 |
+
|
3 |
+
on:
|
4 |
+
schedule:
|
5 |
+
- cron: '0 0 * * *'
|
6 |
+
push:
|
7 |
+
branches:
|
8 |
+
- main
|
9 |
+
paths:
|
10 |
+
- "python/sglang/version.py"
|
11 |
+
workflow_dispatch:
|
12 |
+
|
13 |
+
concurrency:
|
14 |
+
group: nightly-test-${{ github.ref }}
|
15 |
+
cancel-in-progress: true
|
16 |
+
|
17 |
+
jobs:
|
18 |
+
nightly-test:
|
19 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
20 |
+
runs-on: 2-gpu-runner
|
21 |
+
steps:
|
22 |
+
- name: Checkout code
|
23 |
+
uses: actions/checkout@v3
|
24 |
+
|
25 |
+
- name: Install dependencies
|
26 |
+
run: |
|
27 |
+
bash scripts/ci_install_dependency.sh
|
28 |
+
pip install --upgrade "evalplus[vllm] @ git+https://github.com/evalplus/evalplus"
|
29 |
+
|
30 |
+
- name: Run test
|
31 |
+
timeout-minutes: 120
|
32 |
+
run: |
|
33 |
+
cd test/srt
|
34 |
+
python3 run_suite.py --suite nightly --timeout-per-file 2400
|
sglang/.github/workflows/pr-test.yml
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: PR Test
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [ main ]
|
6 |
+
paths:
|
7 |
+
- "python/sglang/**"
|
8 |
+
- "test/**"
|
9 |
+
pull_request:
|
10 |
+
branches: [ main ]
|
11 |
+
paths:
|
12 |
+
- "python/sglang/**"
|
13 |
+
- "test/**"
|
14 |
+
workflow_dispatch:
|
15 |
+
inputs:
|
16 |
+
version:
|
17 |
+
description: "FlashInfer version"
|
18 |
+
required: true
|
19 |
+
type: choice
|
20 |
+
default: 'release'
|
21 |
+
options:
|
22 |
+
- 'release'
|
23 |
+
- 'nightly'
|
24 |
+
|
25 |
+
concurrency:
|
26 |
+
group: pr-test-${{ github.ref }}
|
27 |
+
cancel-in-progress: true
|
28 |
+
|
29 |
+
jobs:
|
30 |
+
|
31 |
+
unit-test-frontend:
|
32 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
33 |
+
runs-on: 1-gpu-runner
|
34 |
+
steps:
|
35 |
+
- name: Checkout code
|
36 |
+
uses: actions/checkout@v3
|
37 |
+
|
38 |
+
- name: Install dependencies
|
39 |
+
env:
|
40 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
41 |
+
run: |
|
42 |
+
bash scripts/ci_install_dependency.sh
|
43 |
+
|
44 |
+
- name: Run test
|
45 |
+
timeout-minutes: 10
|
46 |
+
run: |
|
47 |
+
cd test/lang
|
48 |
+
python3 run_suite.py --suite per-commit
|
49 |
+
|
50 |
+
unit-test-backend-1-gpu:
|
51 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
52 |
+
runs-on: 1-gpu-runner
|
53 |
+
strategy:
|
54 |
+
matrix:
|
55 |
+
range: [0-6, 6-16, 16-23, 23-30, 30-100]
|
56 |
+
steps:
|
57 |
+
- name: Checkout code
|
58 |
+
uses: actions/checkout@v3
|
59 |
+
|
60 |
+
- name: Install dependencies
|
61 |
+
env:
|
62 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
63 |
+
run: |
|
64 |
+
bash scripts/ci_install_dependency.sh
|
65 |
+
|
66 |
+
- name: Run test
|
67 |
+
timeout-minutes: 25
|
68 |
+
run: |
|
69 |
+
cd test/srt
|
70 |
+
RANGE=${{ matrix.range }}
|
71 |
+
range_begin=${RANGE%-*}
|
72 |
+
range_end=${RANGE#*-}
|
73 |
+
python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}
|
74 |
+
|
75 |
+
unit-test-backend-2-gpu:
|
76 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
77 |
+
runs-on: 2-gpu-runner
|
78 |
+
steps:
|
79 |
+
- name: Checkout code
|
80 |
+
uses: actions/checkout@v3
|
81 |
+
|
82 |
+
- name: Install dependencies
|
83 |
+
env:
|
84 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
85 |
+
run: |
|
86 |
+
bash scripts/ci_install_dependency.sh
|
87 |
+
|
88 |
+
- name: Evaluate data parallelism accuracy (DP=2)
|
89 |
+
timeout-minutes: 10
|
90 |
+
run: |
|
91 |
+
cd test/srt
|
92 |
+
python3 test_data_parallelism.py
|
93 |
+
|
94 |
+
- name: Evaluate MLA accuracy (TP=2)
|
95 |
+
timeout-minutes: 10
|
96 |
+
run: |
|
97 |
+
cd test/srt
|
98 |
+
python3 test_mla.py
|
99 |
+
python3 test_mla_fp8.py
|
100 |
+
python3 test_dp_attention.py
|
101 |
+
|
102 |
+
- name: Test update weights from distributed
|
103 |
+
timeout-minutes: 10
|
104 |
+
run: |
|
105 |
+
cd test/srt
|
106 |
+
python3 test_update_weights_from_distributed.py
|
107 |
+
|
108 |
+
- name: Evaluate MoE EP accuracy (TP=2)
|
109 |
+
timeout-minutes: 10
|
110 |
+
run: |
|
111 |
+
cd test/srt
|
112 |
+
python3 test_moe_ep.py
|
113 |
+
|
114 |
+
performance-test-1-gpu-part-1:
|
115 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
116 |
+
runs-on: 1-gpu-runner
|
117 |
+
steps:
|
118 |
+
- name: Checkout code
|
119 |
+
uses: actions/checkout@v3
|
120 |
+
|
121 |
+
- name: Install dependencies
|
122 |
+
env:
|
123 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
124 |
+
run: |
|
125 |
+
bash scripts/ci_install_dependency.sh
|
126 |
+
|
127 |
+
- name: Benchmark single latency
|
128 |
+
timeout-minutes: 10
|
129 |
+
run: |
|
130 |
+
cd test/srt
|
131 |
+
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default
|
132 |
+
|
133 |
+
- name: Benchmark online latency
|
134 |
+
timeout-minutes: 10
|
135 |
+
run: |
|
136 |
+
cd test/srt
|
137 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_default
|
138 |
+
|
139 |
+
- name: Benchmark offline throughput
|
140 |
+
timeout-minutes: 10
|
141 |
+
run: |
|
142 |
+
cd test/srt
|
143 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
|
144 |
+
|
145 |
+
- name: Benchmark offline throughput (Non-streaming, small batch size)
|
146 |
+
timeout-minutes: 10
|
147 |
+
run: |
|
148 |
+
cd test/srt
|
149 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
150 |
+
|
151 |
+
performance-test-1-gpu-part-2:
|
152 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
153 |
+
runs-on: 1-gpu-runner
|
154 |
+
steps:
|
155 |
+
- name: Checkout code
|
156 |
+
uses: actions/checkout@v3
|
157 |
+
|
158 |
+
- name: Install dependencies
|
159 |
+
env:
|
160 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
161 |
+
run: |
|
162 |
+
bash scripts/ci_install_dependency.sh
|
163 |
+
|
164 |
+
- name: Benchmark offline throughput (w/o RadixAttention)
|
165 |
+
timeout-minutes: 10
|
166 |
+
run: |
|
167 |
+
cd test/srt
|
168 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache
|
169 |
+
|
170 |
+
- name: Benchmark offline throughput (w/ Triton)
|
171 |
+
timeout-minutes: 10
|
172 |
+
run: |
|
173 |
+
cd test/srt
|
174 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend
|
175 |
+
|
176 |
+
- name: Benchmark offline throughput (w/ FP8)
|
177 |
+
timeout-minutes: 10
|
178 |
+
run: |
|
179 |
+
cd test/srt
|
180 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8
|
181 |
+
|
182 |
+
performance-test-2-gpu:
|
183 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
184 |
+
runs-on: 2-gpu-runner
|
185 |
+
steps:
|
186 |
+
- name: Checkout code
|
187 |
+
uses: actions/checkout@v3
|
188 |
+
|
189 |
+
- name: Install dependencies
|
190 |
+
env:
|
191 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
192 |
+
run: |
|
193 |
+
bash scripts/ci_install_dependency.sh
|
194 |
+
|
195 |
+
- name: Benchmark single latency (TP=2)
|
196 |
+
timeout-minutes: 10
|
197 |
+
run: |
|
198 |
+
cd test/srt
|
199 |
+
python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
|
200 |
+
|
201 |
+
- name: Benchmark offline throughput (TP=2)
|
202 |
+
timeout-minutes: 10
|
203 |
+
run: |
|
204 |
+
cd test/srt
|
205 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default
|
206 |
+
|
207 |
+
- name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
|
208 |
+
timeout-minutes: 10
|
209 |
+
run: |
|
210 |
+
cd test/srt
|
211 |
+
python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
|
212 |
+
|
213 |
+
accuracy-test-1-gpu:
|
214 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
215 |
+
runs-on: 1-gpu-runner
|
216 |
+
steps:
|
217 |
+
- name: Checkout code
|
218 |
+
uses: actions/checkout@v3
|
219 |
+
|
220 |
+
- name: Install dependencies
|
221 |
+
env:
|
222 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
223 |
+
run: |
|
224 |
+
bash scripts/ci_install_dependency.sh
|
225 |
+
|
226 |
+
git clone https://github.com/merrymercy/human-eval.git
|
227 |
+
cd human-eval
|
228 |
+
pip install -e .
|
229 |
+
|
230 |
+
- name: Evaluate accuracy
|
231 |
+
timeout-minutes: 20
|
232 |
+
run: |
|
233 |
+
cd test/srt
|
234 |
+
python3 test_eval_accuracy_large.py
|
235 |
+
|
236 |
+
|
237 |
+
accuracy-test-2-gpu:
|
238 |
+
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
239 |
+
runs-on: 2-gpu-runner
|
240 |
+
steps:
|
241 |
+
- name: Checkout code
|
242 |
+
uses: actions/checkout@v3
|
243 |
+
|
244 |
+
- name: Install dependencies
|
245 |
+
env:
|
246 |
+
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }}
|
247 |
+
run: |
|
248 |
+
bash scripts/ci_install_dependency.sh
|
249 |
+
|
250 |
+
git clone https://github.com/merrymercy/human-eval.git
|
251 |
+
cd human-eval
|
252 |
+
pip install -e .
|
253 |
+
|
254 |
+
- name: Evaluate accuracy (TP=2)
|
255 |
+
timeout-minutes: 20
|
256 |
+
run: |
|
257 |
+
cd test/srt
|
258 |
+
python3 test_moe_eval_accuracy_large.py
|
259 |
+
|
260 |
+
|
261 |
+
finish:
|
262 |
+
needs: [
|
263 |
+
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
|
264 |
+
performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
|
265 |
+
accuracy-test-1-gpu, accuracy-test-2-gpu
|
266 |
+
]
|
267 |
+
runs-on: ubuntu-latest
|
268 |
+
steps:
|
269 |
+
- name: Finish
|
270 |
+
run: echo "This is an empty step to ensure that all jobs are completed."
|
sglang/.github/workflows/release-docker-dev.yml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Build Development Docker Image
|
2 |
+
|
3 |
+
on:
|
4 |
+
workflow_dispatch:
|
5 |
+
schedule:
|
6 |
+
- cron: '0 0 * * *'
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
build-dev:
|
10 |
+
runs-on: ubuntu-22.04
|
11 |
+
steps:
|
12 |
+
- name: Checkout repository
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
|
15 |
+
- name: Free disk space
|
16 |
+
uses: jlumbroso/free-disk-space@main
|
17 |
+
with:
|
18 |
+
tool-cache: false
|
19 |
+
docker-images: false
|
20 |
+
android: true
|
21 |
+
dotnet: true
|
22 |
+
haskell: true
|
23 |
+
large-packages: true
|
24 |
+
swap-storage: false
|
25 |
+
|
26 |
+
- name: Login to Docker Hub
|
27 |
+
uses: docker/login-action@v2
|
28 |
+
with:
|
29 |
+
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
30 |
+
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
31 |
+
|
32 |
+
- name: Build and Push Dev Image
|
33 |
+
run: |
|
34 |
+
docker build . -f docker/Dockerfile.dev -t lmsysorg/sglang:dev --no-cache
|
35 |
+
docker push lmsysorg/sglang:dev
|
sglang/.github/workflows/release-docker.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release Docker Images
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
paths:
|
7 |
+
- "python/sglang/version.py"
|
8 |
+
workflow_dispatch:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
publish:
|
12 |
+
if: github.repository == 'sgl-project/sglang'
|
13 |
+
runs-on: ubuntu-latest
|
14 |
+
environment: 'prod'
|
15 |
+
strategy:
|
16 |
+
matrix:
|
17 |
+
cuda_version: ['11.8.0', '12.1.1', '12.4.1']
|
18 |
+
build_type: ['all', 'srt']
|
19 |
+
steps:
|
20 |
+
- name: Delete huge unnecessary tools folder
|
21 |
+
run: rm -rf /opt/hostedtoolcache
|
22 |
+
|
23 |
+
- name: Checkout repository
|
24 |
+
uses: actions/checkout@v3
|
25 |
+
|
26 |
+
- name: Login to Docker Hub
|
27 |
+
uses: docker/login-action@v2
|
28 |
+
with:
|
29 |
+
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
30 |
+
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
31 |
+
|
32 |
+
- name: Build and Push
|
33 |
+
run: |
|
34 |
+
version=$(cat python/sglang/version.py | cut -d'"' -f2)
|
35 |
+
|
36 |
+
if [ "${{ matrix.cuda_version }}" = "11.8.0" ]; then
|
37 |
+
cuda_tag="cu118"
|
38 |
+
elif [ "${{ matrix.cuda_version }}" = "12.1.1" ]; then
|
39 |
+
cuda_tag="cu121"
|
40 |
+
elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then
|
41 |
+
cuda_tag="cu124"
|
42 |
+
else
|
43 |
+
echo "Unsupported CUDA version"
|
44 |
+
exit 1
|
45 |
+
fi
|
46 |
+
|
47 |
+
tag=v${version}-${cuda_tag}
|
48 |
+
|
49 |
+
if [ "${{ matrix.build_type }}" = "all" ]; then
|
50 |
+
tag_suffix=""
|
51 |
+
elif [ "${{ matrix.build_type }}" = "srt" ]; then
|
52 |
+
tag_suffix="-srt"
|
53 |
+
else
|
54 |
+
echo "Unsupported build type"
|
55 |
+
exit 1
|
56 |
+
fi
|
57 |
+
|
58 |
+
docker build . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} --build-arg BUILD_TYPE=${{ matrix.build_type }} -t lmsysorg/sglang:${tag}${tag_suffix} --no-cache
|
59 |
+
docker push lmsysorg/sglang:${tag}${tag_suffix}
|
60 |
+
|
61 |
+
if [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then
|
62 |
+
docker tag lmsysorg/sglang:${tag}${tag_suffix} lmsysorg/sglang:latest${tag_suffix}
|
63 |
+
docker push lmsysorg/sglang:latest${tag_suffix}
|
64 |
+
fi
|
sglang/.github/workflows/release-pypi-kernel.yml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release SGLang Kernel to PyPI
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
paths:
|
8 |
+
- sgl-kernel/pyproject.toml
|
9 |
+
workflow_dispatch:
|
10 |
+
|
11 |
+
concurrency:
|
12 |
+
group: release-pypi-kernel-${{ github.ref }}
|
13 |
+
cancel-in-progress: true
|
14 |
+
|
15 |
+
jobs:
|
16 |
+
build-wheels:
|
17 |
+
runs-on: ubuntu-latest
|
18 |
+
strategy:
|
19 |
+
matrix:
|
20 |
+
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
21 |
+
cuda-version: ['12.1']
|
22 |
+
|
23 |
+
steps:
|
24 |
+
- uses: actions/checkout@v4
|
25 |
+
|
26 |
+
- name: Set up Python ${{ matrix.python-version }}
|
27 |
+
uses: actions/setup-python@v5
|
28 |
+
with:
|
29 |
+
python-version: ${{ matrix.python-version }}
|
30 |
+
|
31 |
+
- name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}
|
32 |
+
run: |
|
33 |
+
cd sgl-kernel
|
34 |
+
chmod +x ./build.sh
|
35 |
+
./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}"
|
36 |
+
|
37 |
+
- name: Upload to pypi
|
38 |
+
working-directory: sgl-kernel
|
39 |
+
run: |
|
40 |
+
pip install twine
|
41 |
+
python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }}
|
sglang/.github/workflows/release-pypi.yml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release PyPI
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches:
|
5 |
+
- main
|
6 |
+
paths:
|
7 |
+
- "python/sglang/version.py"
|
8 |
+
workflow_dispatch:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
publish:
|
12 |
+
if: github.repository == 'sgl-project/sglang'
|
13 |
+
runs-on: ubuntu-latest
|
14 |
+
environment: 'prod'
|
15 |
+
steps:
|
16 |
+
- name: Set up Python
|
17 |
+
uses: actions/setup-python@v4
|
18 |
+
with:
|
19 |
+
python-version: '3.9'
|
20 |
+
|
21 |
+
- name: Checkout repository
|
22 |
+
uses: actions/checkout@v3
|
23 |
+
|
24 |
+
- name: Upload to pypi
|
25 |
+
run: |
|
26 |
+
cd python
|
27 |
+
cp ../README.md ../LICENSE .
|
28 |
+
pip install build
|
29 |
+
python3 -m build
|
30 |
+
pip install twine
|
31 |
+
python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }}
|
sglang/3rdparty/amd/profiling/PROFILING.md
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Profiling SGLang Infer System with AMD GPUs
|
2 |
+
This AppNote describes the SGLang profiling technical, code augment and running steps for systems with AMD Instinct GPUs, nevertheless the same procedure may work with Nvidia GPUs too.
|
3 |
+
Examples and steps are provided in detail, to facilitate easy reproduce and use to localize performance problem towards optimizations.
|
4 |
+
Two primary methods are covered:
|
5 |
+
- [RPD](https://github.com/ROCm/rocmProfileData.git)
|
6 |
+
- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
|
7 |
+
|
8 |
+
### Profiling SGLang Infer System with RPD Profiler
|
9 |
+
RPD profiler is a low-overhead cross-platform profiler. Therefore, the same RPD code augment not only works for profiling on ROCm/AMD GPUs, but also works for profiling on CUDA/Nvidia GPUs as well. To do RPD profiling on SGLang repository, please use scripts and patch files included in this directory and follow the steps below:
|
10 |
+
1. Install RPD with rpd.patch applied during installation using install_rpd.sh, both files are in this directory.
|
11 |
+
|
12 |
+
install_rpd.sh
|
13 |
+
|
14 |
+
```bash
|
15 |
+
# download and install RPD
|
16 |
+
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
|
17 |
+
|
18 |
+
# install rpd module
|
19 |
+
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
|
20 |
+
cd rocmProfileData
|
21 |
+
git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
|
22 |
+
git apply rpd.patch
|
23 |
+
make && make install
|
24 |
+
cd rocpd_python && python setup.py install && cd ..
|
25 |
+
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
|
26 |
+
```
|
27 |
+
|
28 |
+
rpd.patch
|
29 |
+
|
30 |
+
```bash
|
31 |
+
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
|
32 |
+
index e9d9feb..b2e9e1a 100644
|
33 |
+
--- a/rpd_tracer/Makefile
|
34 |
+
+++ b/rpd_tracer/Makefile
|
35 |
+
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
|
36 |
+
$(info Building with roctracer)
|
37 |
+
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
|
38 |
+
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
|
39 |
+
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
|
40 |
+
+ RPD_SRCS += RoctracerDataSource.cpp
|
41 |
+
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
|
42 |
+
endif
|
43 |
+
```
|
44 |
+
2. Add loadTracer.sh file included in this directory to /sglang/python/sglang.
|
45 |
+
|
46 |
+
loadTracer.sh
|
47 |
+
|
48 |
+
```bash
|
49 |
+
#!/bin/bash
|
50 |
+
################################################################################
|
51 |
+
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
|
52 |
+
#
|
53 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
54 |
+
# of this software and associated documentation files (the "Software"), to deal
|
55 |
+
# in the Software without restriction, including without limitation the rights
|
56 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
57 |
+
# copies of the Software, and to permit persons to whom the Software is
|
58 |
+
# furnished to do so, subject to the following conditions:
|
59 |
+
#
|
60 |
+
# The above copyright notice and this permission notice shall be included in
|
61 |
+
# all copies or substantial portions of the Software.
|
62 |
+
#
|
63 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
64 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
65 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
66 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
67 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
68 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
69 |
+
# THE SOFTWARE.
|
70 |
+
################################################################################
|
71 |
+
OUTPUT_FILE="trace.rpd"
|
72 |
+
|
73 |
+
if [ "$1" = "-o" ] ; then
|
74 |
+
OUTPUT_FILE=$2
|
75 |
+
shift
|
76 |
+
shift
|
77 |
+
fi
|
78 |
+
|
79 |
+
if [ -e ${OUTPUT_FILE} ] ; then
|
80 |
+
rm ${OUTPUT_FILE}
|
81 |
+
fi
|
82 |
+
|
83 |
+
python3 -m rocpd.schema --create ${OUTPUT_FILE}
|
84 |
+
if [ $? != 0 ] ; then
|
85 |
+
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
|
86 |
+
exit
|
87 |
+
fi
|
88 |
+
|
89 |
+
export RPDT_FILENAME=${OUTPUT_FILE}
|
90 |
+
export RPDT_AUTOSTART=0
|
91 |
+
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
|
92 |
+
```
|
93 |
+
3. Apply patch (provided in this directory) with "git apply rpd_profile_server_enable.patch" if the main profiling purpose is to get info on gpu kernels as well as limited cpu activity info.
|
94 |
+
|
95 |
+
#### Common Notes 1
|
96 |
+
Please note that although we are doing TP=8 in the example, we purposely only log RPD profiling on 2 ranks in the patch file (i.e.tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can only load maximal 8GB json file for visualization. With 2 ranks logged in RPD profiling, we could still check whether there are issues among ranks (e.g. load imbalance issue, nccl issue), and at the same time, we could log relatively longer time duration before the json file generated from RPD file hits 8GB size.
|
97 |
+
|
98 |
+
rpd_profile_server_enable.patch
|
99 |
+
|
100 |
+
```bash
|
101 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
102 |
+
index 62d1ff9..9021c01 100644
|
103 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
104 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
105 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
106 |
+
suppress_other_loggers,
|
107 |
+
)
|
108 |
+
from sglang.utils import get_exception_traceback
|
109 |
+
+from rpdTracerControl import rpdTracerControl
|
110 |
+
+rpdTracerControl.skipCreate()
|
111 |
+
|
112 |
+
logger = logging.getLogger(__name__)
|
113 |
+
|
114 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
115 |
+
],
|
116 |
+
with_stack=True,
|
117 |
+
)
|
118 |
+
+ self.rpd = rpdTracerControl()
|
119 |
+
|
120 |
+
@torch.inference_mode()
|
121 |
+
def event_loop(self):
|
122 |
+
@@ -1027,15 +1030,24 @@ class Scheduler:
|
123 |
+
def start_profile(self) -> None:
|
124 |
+
if self.profiler is None:
|
125 |
+
raise RuntimeError("Profiler is not enabled.")
|
126 |
+
- self.profiler.start()
|
127 |
+
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
|
128 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
129 |
+
+ self.rpd.start()
|
130 |
+
+ self.rpd.rangePush("", "rpd profile range", "")
|
131 |
+
+ logger.info("rpd is enabled")
|
132 |
+
|
133 |
+
def stop_profile(self) -> None:
|
134 |
+
if self.profiler is None:
|
135 |
+
raise RuntimeError("Profiler is not enabled.")
|
136 |
+
- self.profiler.stop()
|
137 |
+
- self.profiler.export_chrome_trace(
|
138 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
139 |
+
- )
|
140 |
+
+ #self.profiler.stop()
|
141 |
+
+ #self.profiler.export_chrome_trace(
|
142 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
143 |
+
+ #)
|
144 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
145 |
+
+ self.rpd.rangePop()
|
146 |
+
+ self.rpd.stop()
|
147 |
+
+ self.rpd.flush()
|
148 |
+
+ logger.info("rpd is done")
|
149 |
+
logger.info("Profiler is done")
|
150 |
+
```
|
151 |
+
|
152 |
+
#### Advanced Debugging with RPD Profiler
|
153 |
+
Sometimes, we want to use rpd profiler to capture more CPU and python activities in order to debug some challenging issues (e.g. root cause of load imbalance across gpu processes, root cause of bubbles, etc). Only in such cases, we need to apply patch "git apply rpd_profile_server_enable_wCPU_activities.patch", where 3 files are modified.
|
154 |
+
|
155 |
+
rpd_profile_server_enable_wCPU_activities.patch
|
156 |
+
|
157 |
+
```bash
|
158 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
159 |
+
index 62d1ff9..2edb427 100644
|
160 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
161 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
162 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
163 |
+
suppress_other_loggers,
|
164 |
+
)
|
165 |
+
from sglang.utils import get_exception_traceback
|
166 |
+
+from rpdTracerControl import rpdTracerControl
|
167 |
+
+rpdTracerControl.skipCreate()
|
168 |
+
|
169 |
+
logger = logging.getLogger(__name__)
|
170 |
+
|
171 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
172 |
+
],
|
173 |
+
with_stack=True,
|
174 |
+
)
|
175 |
+
+ self.rpd = rpdTracerControl()
|
176 |
+
|
177 |
+
@torch.inference_mode()
|
178 |
+
def event_loop(self):
|
179 |
+
@@ -1027,15 +1030,26 @@ class Scheduler:
|
180 |
+
def start_profile(self) -> None:
|
181 |
+
if self.profiler is None:
|
182 |
+
raise RuntimeError("Profiler is not enabled.")
|
183 |
+
- self.profiler.start()
|
184 |
+
+ #self.profiler.start()
|
185 |
+
+ logger.info("torch profiler is disabled")
|
186 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
187 |
+
+ self.rpd.setPythonTrace(True)
|
188 |
+
+ self.rpd.start()
|
189 |
+
+ self.rpd.rangePush("", "scheduler", "")
|
190 |
+
+ logger.info("rpd is enabled inside scheduler profiling")
|
191 |
+
|
192 |
+
def stop_profile(self) -> None:
|
193 |
+
if self.profiler is None:
|
194 |
+
raise RuntimeError("Profiler is not enabled.")
|
195 |
+
- self.profiler.stop()
|
196 |
+
- self.profiler.export_chrome_trace(
|
197 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
198 |
+
- )
|
199 |
+
+ #self.profiler.stop()
|
200 |
+
+ #self.profiler.export_chrome_trace(
|
201 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
202 |
+
+ #)
|
203 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
204 |
+
+ self.rpd.rangePop()
|
205 |
+
+ self.rpd.stop()
|
206 |
+
+ self.rpd.flush()
|
207 |
+
+ logger.info("rpd is done inside scheduler")
|
208 |
+
logger.info("Profiler is done")
|
209 |
+
|
210 |
+
|
211 |
+
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
212 |
+
index 2621ccd..181df85 100644
|
213 |
+
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
214 |
+
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
215 |
+
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
216 |
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
217 |
+
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
218 |
+
|
219 |
+
+from rpdTracerControl import rpdTracerControl
|
220 |
+
+rpdTracerControl.skipCreate()
|
221 |
+
+
|
222 |
+
+
|
223 |
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
224 |
+
|
225 |
+
logger = logging.getLogger(__name__)
|
226 |
+
@@ -514,10 +518,20 @@ class TokenizerManager:
|
227 |
+
self.send_to_scheduler.send_pyobj(req)
|
228 |
+
|
229 |
+
def start_profile(self):
|
230 |
+
+ rpd = rpdTracerControl()
|
231 |
+
+ rpd.setPythonTrace(True)
|
232 |
+
+ rpd.start()
|
233 |
+
+ rpd.rangePush("", "tokenizer_manager", "")
|
234 |
+
+ logger.info("tokenizer_manager rpd profiling started!")
|
235 |
+
req = ProfileReq.START_PROFILE
|
236 |
+
self.send_to_scheduler.send_pyobj(req)
|
237 |
+
|
238 |
+
def stop_profile(self):
|
239 |
+
+ rpd = rpdTracerControl()
|
240 |
+
+ rpd.rangePop()
|
241 |
+
+ rpd.stop()
|
242 |
+
+ rpd.flush()
|
243 |
+
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
244 |
+
req = ProfileReq.STOP_PROFILE
|
245 |
+
self.send_to_scheduler.send_pyobj(req)
|
246 |
+
|
247 |
+
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
248 |
+
index 7111c93..2bd722c 100644
|
249 |
+
--- a/python/sglang/srt/server.py
|
250 |
+
+++ b/python/sglang/srt/server.py
|
251 |
+
@@ -30,6 +30,8 @@ import threading
|
252 |
+
import time
|
253 |
+
from http import HTTPStatus
|
254 |
+
from typing import Dict, List, Optional, Union
|
255 |
+
+from rpdTracerControl import rpdTracerControl
|
256 |
+
+rpdTracerControl.skipCreate()
|
257 |
+
|
258 |
+
# Fix a bug of Python threading
|
259 |
+
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
260 |
+
@@ -152,6 +154,11 @@ async def flush_cache():
|
261 |
+
@app.post("/start_profile")
|
262 |
+
async def start_profile():
|
263 |
+
"""Start profiling."""
|
264 |
+
+ rpd = rpdTracerControl()
|
265 |
+
+ rpd.setPythonTrace(True)
|
266 |
+
+ rpd.start()
|
267 |
+
+ rpd.rangePush("", "server rpd profile range", "")
|
268 |
+
+ logger.info("rpd profiling started in server.py!")
|
269 |
+
tokenizer_manager.start_profile()
|
270 |
+
return Response(
|
271 |
+
content="Start profiling.\n",
|
272 |
+
@@ -164,6 +171,11 @@ async def start_profile():
|
273 |
+
async def stop_profile():
|
274 |
+
"""Stop profiling."""
|
275 |
+
tokenizer_manager.stop_profile()
|
276 |
+
+ rpd = rpdTracerControl()
|
277 |
+
+ rpd.rangePop()
|
278 |
+
+ rpd.stop()
|
279 |
+
+ rpd.flush()
|
280 |
+
+ logger.info("rpd profiling is done in server.py!")
|
281 |
+
return Response(
|
282 |
+
content="Stop profiling. This will take some time.\n",
|
283 |
+
status_code=200,
|
284 |
+
```
|
285 |
+
|
286 |
+
4. As an example for grok1 profiling, we create a dummy_grok1 directory with config.json (see content below) inside this directory and copy this directory to the right path for "--model-path" if you want to use the example server.sh file provided.
|
287 |
+
```bash
|
288 |
+
cat ../dummy_grok1/config.json
|
289 |
+
{
|
290 |
+
"architectures": [
|
291 |
+
"Grok1ModelForCausalLM"
|
292 |
+
],
|
293 |
+
"embedding_multiplier_scale": 78.38367176906169,
|
294 |
+
"output_multiplier_scale": 0.5773502691896257,
|
295 |
+
"vocab_size": 131072,
|
296 |
+
"hidden_size": 6144,
|
297 |
+
"intermediate_size": 32768,
|
298 |
+
"max_position_embeddings": 8192,
|
299 |
+
"num_experts_per_tok": 2,
|
300 |
+
"num_local_experts": 8,
|
301 |
+
"num_attention_heads": 48,
|
302 |
+
"num_hidden_layers": 64,
|
303 |
+
"num_key_value_heads": 8,
|
304 |
+
"head_dim": 128,
|
305 |
+
"rms_norm_eps": 1e-05,
|
306 |
+
"rope_theta": 10000.0,
|
307 |
+
"model_type": "mixtral",
|
308 |
+
"torch_dtype": "bfloat16"
|
309 |
+
}
|
310 |
+
```
|
311 |
+
5. Launch server with rpd enabled script ./server.sh in one terminal inside the docker container.
|
312 |
+
|
313 |
+
#### Common Notes 2
|
314 |
+
- Remember to change model-path to the correct path
|
315 |
+
- loadTracer.sh is needed to conduct profiling
|
316 |
+
- SGLANG_TORCH_PROFILER_DIR is used for default torch profiler
|
317 |
+
- Do not use loadTracer.sh if you are using the torch profiler, simply use python3 -m sglang.launch_server.
|
318 |
+
|
319 |
+
|
320 |
+
server.sh
|
321 |
+
|
322 |
+
```bash
|
323 |
+
#!/bin/bash
|
324 |
+
|
325 |
+
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
|
326 |
+
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
|
327 |
+
|
328 |
+
# Get the current timestamp
|
329 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
330 |
+
|
331 |
+
# Define the log file with a timestamp
|
332 |
+
LOGFILE="sglang_server_log_$TIMESTAMP.json"
|
333 |
+
|
334 |
+
# Run the Python command and save the output to the log file
|
335 |
+
loadTracer.sh python3 -m sglang.launch_server \
|
336 |
+
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
337 |
+
--tokenizer-path Xenova/grok-1-tokenizer \
|
338 |
+
--load-format dummy \
|
339 |
+
--quant fp8 \
|
340 |
+
--tp 8 \
|
341 |
+
--port 30000 \
|
342 |
+
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
343 |
+
```
|
344 |
+
6. Open another terminal for the same docker container, and run the rpd enabled ./client.sh after you see "The server is fired up and is ready to roll!" message from server side terminal.
|
345 |
+
|
346 |
+
#### Common Notes 3
|
347 |
+
- Use curl http://localhost:30000/start_profile & curl http://localhost:30000/stop_profile to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
|
348 |
+
- Please don't use RPD profiler together with PyTorch profiler to avoid interference.
|
349 |
+
- The rocmProfileData/tools/rpd2tracing.py file is used to generate json file from RPD file.
|
350 |
+
|
351 |
+
client.sh
|
352 |
+
|
353 |
+
```bash
|
354 |
+
#!/bin/bash
|
355 |
+
|
356 |
+
# Start profiling via API
|
357 |
+
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
|
358 |
+
|
359 |
+
# Benchmark serving using sglang with random dataset and tokenizer
|
360 |
+
# Define the log file with a timestamp
|
361 |
+
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
362 |
+
LOGFILE="sglang_client_log_$TIMESTAMP.json"
|
363 |
+
|
364 |
+
# Run the benchmark with specified parameters and save logs
|
365 |
+
python3 -m sglang.bench_serving \
|
366 |
+
--backend sglang \
|
367 |
+
--tokenizer Xenova/grok-1-tokenizer \
|
368 |
+
--dataset-name random \
|
369 |
+
--random-input 1024\
|
370 |
+
--random-output 1024 \
|
371 |
+
--num-prompts 120 \
|
372 |
+
--request-rate 8 \
|
373 |
+
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
|
374 |
+
|
375 |
+
# Stop profiling via API
|
376 |
+
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
|
377 |
+
|
378 |
+
# Convert tracing file to csv & json
|
379 |
+
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
|
380 |
+
python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
|
381 |
+
```
|
382 |
+
7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 9GB.
|
383 |
+
|
384 |
+
### Profiling SGLang Infer System with PyTorch Profiler
|
385 |
+
|
386 |
+
Please use the steps as follows:
|
387 |
+
|
388 |
+
1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks be recorded in profiling.
|
389 |
+
|
390 |
+
torch_profiler.patch
|
391 |
+
```bash
|
392 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
393 |
+
index 62d1ff9..6ecd78c 100644
|
394 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
395 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
396 |
+
@@ -240,7 +240,6 @@ class Scheduler:
|
397 |
+
)
|
398 |
+
self.profiler = torch.profiler.profile(
|
399 |
+
activities=[
|
400 |
+
- torch.profiler.ProfilerActivity.CPU,
|
401 |
+
torch.profiler.ProfilerActivity.CUDA,
|
402 |
+
],
|
403 |
+
with_stack=True,
|
404 |
+
@@ -1033,9 +1032,11 @@ class Scheduler:
|
405 |
+
if self.profiler is None:
|
406 |
+
raise RuntimeError("Profiler is not enabled.")
|
407 |
+
self.profiler.stop()
|
408 |
+
- self.profiler.export_chrome_trace(
|
409 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
410 |
+
- )
|
411 |
+
+ if self.tp_rank == 0:
|
412 |
+
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
|
413 |
+
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
|
414 |
+
+ print("Profiling stats done.")
|
415 |
+
+
|
416 |
+
logger.info("Profiler is done")
|
417 |
+
```
|
418 |
+
|
419 |
+
2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided.
|
420 |
+
|
421 |
+
3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
|
422 |
+
|
423 |
+
4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
|
424 |
+
-------
|
425 |
+
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
|
sglang/3rdparty/amd/profiling/client.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Start profiling via API
|
4 |
+
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
|
5 |
+
|
6 |
+
# Benchmark serving using sglang with random dataset and tokenizer
|
7 |
+
# Define the log file with a timestamp
|
8 |
+
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
9 |
+
LOGFILE="sglang_client_log_$TIMESTAMP.json"
|
10 |
+
|
11 |
+
# Run the benchmark with specified parameters and save logs
|
12 |
+
python3 -m sglang.bench_serving \
|
13 |
+
--backend sglang \
|
14 |
+
--tokenizer Xenova/grok-1-tokenizer \
|
15 |
+
--dataset-name random \
|
16 |
+
--random-input 1024\
|
17 |
+
--random-output 1024 \
|
18 |
+
--num-prompts 240 \
|
19 |
+
--request-rate 8 \
|
20 |
+
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
|
21 |
+
|
22 |
+
# Stop profiling via API
|
23 |
+
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
|
24 |
+
|
25 |
+
# Convert tracing file to csv & json
|
26 |
+
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
|
27 |
+
python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
|
sglang/3rdparty/amd/profiling/install_rpd.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# download and install RPD
|
2 |
+
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
|
3 |
+
|
4 |
+
# install rpd module
|
5 |
+
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
|
6 |
+
cd rocmProfileData
|
7 |
+
git apply rpd.patch
|
8 |
+
make && make install
|
9 |
+
cd rocpd_python && python setup.py install && cd ..
|
10 |
+
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
|
sglang/3rdparty/amd/profiling/loadTracer.sh
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
################################################################################
|
3 |
+
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
|
4 |
+
#
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
#
|
12 |
+
# The above copyright notice and this permission notice shall be included in
|
13 |
+
# all copies or substantial portions of the Software.
|
14 |
+
#
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
21 |
+
# THE SOFTWARE.
|
22 |
+
################################################################################
|
23 |
+
OUTPUT_FILE="trace.rpd"
|
24 |
+
|
25 |
+
if [ "$1" = "-o" ] ; then
|
26 |
+
OUTPUT_FILE=$2
|
27 |
+
shift
|
28 |
+
shift
|
29 |
+
fi
|
30 |
+
|
31 |
+
if [ -e ${OUTPUT_FILE} ] ; then
|
32 |
+
rm ${OUTPUT_FILE}
|
33 |
+
fi
|
34 |
+
|
35 |
+
python3 -m rocpd.schema --create ${OUTPUT_FILE}
|
36 |
+
if [ $? != 0 ] ; then
|
37 |
+
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
|
38 |
+
exit
|
39 |
+
fi
|
40 |
+
|
41 |
+
export RPDT_FILENAME=${OUTPUT_FILE}
|
42 |
+
export RPDT_AUTOSTART=0
|
43 |
+
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
|
sglang/3rdparty/amd/profiling/rpd.patch
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
|
2 |
+
index e9d9feb..b2e9e1a 100644
|
3 |
+
--- a/rpd_tracer/Makefile
|
4 |
+
+++ b/rpd_tracer/Makefile
|
5 |
+
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
|
6 |
+
$(info Building with roctracer)
|
7 |
+
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
|
8 |
+
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
|
9 |
+
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
|
10 |
+
+ RPD_SRCS += RoctracerDataSource.cpp
|
11 |
+
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
|
12 |
+
endif
|
sglang/3rdparty/amd/profiling/rpd_profile_server_enable.patch
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
2 |
+
index 62d1ff9..9021c01 100644
|
3 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
4 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
5 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
6 |
+
suppress_other_loggers,
|
7 |
+
)
|
8 |
+
from sglang.utils import get_exception_traceback
|
9 |
+
+from rpdTracerControl import rpdTracerControl
|
10 |
+
+rpdTracerControl.skipCreate()
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
15 |
+
],
|
16 |
+
with_stack=True,
|
17 |
+
)
|
18 |
+
+ self.rpd = rpdTracerControl()
|
19 |
+
|
20 |
+
@torch.inference_mode()
|
21 |
+
def event_loop(self):
|
22 |
+
@@ -1027,15 +1030,24 @@ class Scheduler:
|
23 |
+
def start_profile(self) -> None:
|
24 |
+
if self.profiler is None:
|
25 |
+
raise RuntimeError("Profiler is not enabled.")
|
26 |
+
- self.profiler.start()
|
27 |
+
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
|
28 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
29 |
+
+ self.rpd.start()
|
30 |
+
+ self.rpd.rangePush("", "rpd profile range", "")
|
31 |
+
+ logger.info("rpd is enabled")
|
32 |
+
|
33 |
+
def stop_profile(self) -> None:
|
34 |
+
if self.profiler is None:
|
35 |
+
raise RuntimeError("Profiler is not enabled.")
|
36 |
+
- self.profiler.stop()
|
37 |
+
- self.profiler.export_chrome_trace(
|
38 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
39 |
+
- )
|
40 |
+
+ #self.profiler.stop()
|
41 |
+
+ #self.profiler.export_chrome_trace(
|
42 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
43 |
+
+ #)
|
44 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
45 |
+
+ self.rpd.rangePop()
|
46 |
+
+ self.rpd.stop()
|
47 |
+
+ self.rpd.flush()
|
48 |
+
+ logger.info("rpd is done")
|
49 |
+
logger.info("Profiler is done")
|
sglang/3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
2 |
+
index 62d1ff9..2edb427 100644
|
3 |
+
--- a/python/sglang/srt/managers/scheduler.py
|
4 |
+
+++ b/python/sglang/srt/managers/scheduler.py
|
5 |
+
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
6 |
+
suppress_other_loggers,
|
7 |
+
)
|
8 |
+
from sglang.utils import get_exception_traceback
|
9 |
+
+from rpdTracerControl import rpdTracerControl
|
10 |
+
+rpdTracerControl.skipCreate()
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
@@ -245,6 +247,7 @@ class Scheduler:
|
15 |
+
],
|
16 |
+
with_stack=True,
|
17 |
+
)
|
18 |
+
+ self.rpd = rpdTracerControl()
|
19 |
+
|
20 |
+
@torch.inference_mode()
|
21 |
+
def event_loop(self):
|
22 |
+
@@ -1027,15 +1030,26 @@ class Scheduler:
|
23 |
+
def start_profile(self) -> None:
|
24 |
+
if self.profiler is None:
|
25 |
+
raise RuntimeError("Profiler is not enabled.")
|
26 |
+
- self.profiler.start()
|
27 |
+
+ #self.profiler.start()
|
28 |
+
+ logger.info("torch profiler is disabled")
|
29 |
+
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
30 |
+
+ self.rpd.setPythonTrace(True)
|
31 |
+
+ self.rpd.start()
|
32 |
+
+ self.rpd.rangePush("", "scheduler", "")
|
33 |
+
+ logger.info("rpd is enabled inside scheduler profiling")
|
34 |
+
|
35 |
+
def stop_profile(self) -> None:
|
36 |
+
if self.profiler is None:
|
37 |
+
raise RuntimeError("Profiler is not enabled.")
|
38 |
+
- self.profiler.stop()
|
39 |
+
- self.profiler.export_chrome_trace(
|
40 |
+
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
41 |
+
- )
|
42 |
+
+ #self.profiler.stop()
|
43 |
+
+ #self.profiler.export_chrome_trace(
|
44 |
+
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
45 |
+
+ #)
|
46 |
+
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
47 |
+
+ self.rpd.rangePop()
|
48 |
+
+ self.rpd.stop()
|
49 |
+
+ self.rpd.flush()
|
50 |
+
+ logger.info("rpd is done inside scheduler")
|
51 |
+
logger.info("Profiler is done")
|
52 |
+
|
53 |
+
|
54 |
+
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
55 |
+
index 2621ccd..181df85 100644
|
56 |
+
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
57 |
+
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
58 |
+
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
59 |
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
60 |
+
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
61 |
+
|
62 |
+
+from rpdTracerControl import rpdTracerControl
|
63 |
+
+rpdTracerControl.skipCreate()
|
64 |
+
+
|
65 |
+
+
|
66 |
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
67 |
+
|
68 |
+
logger = logging.getLogger(__name__)
|
69 |
+
@@ -514,10 +518,20 @@ class TokenizerManager:
|
70 |
+
self.send_to_scheduler.send_pyobj(req)
|
71 |
+
|
72 |
+
def start_profile(self):
|
73 |
+
+ rpd = rpdTracerControl()
|
74 |
+
+ rpd.setPythonTrace(True)
|
75 |
+
+ rpd.start()
|
76 |
+
+ rpd.rangePush("", "tokenizer_manager", "")
|
77 |
+
+ logger.info("tokenizer_manager rpd profiling started!")
|
78 |
+
req = ProfileReq.START_PROFILE
|
79 |
+
self.send_to_scheduler.send_pyobj(req)
|
80 |
+
|
81 |
+
def stop_profile(self):
|
82 |
+
+ rpd = rpdTracerControl()
|
83 |
+
+ rpd.rangePop()
|
84 |
+
+ rpd.stop()
|
85 |
+
+ rpd.flush()
|
86 |
+
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
87 |
+
req = ProfileReq.STOP_PROFILE
|
88 |
+
self.send_to_scheduler.send_pyobj(req)
|
89 |
+
|
90 |
+
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
91 |
+
index 7111c93..2bd722c 100644
|
92 |
+
--- a/python/sglang/srt/server.py
|
93 |
+
+++ b/python/sglang/srt/server.py
|
94 |
+
@@ -30,6 +30,8 @@ import threading
|
95 |
+
import time
|
96 |
+
from http import HTTPStatus
|
97 |
+
from typing import Dict, List, Optional, Union
|
98 |
+
+from rpdTracerControl import rpdTracerControl
|
99 |
+
+rpdTracerControl.skipCreate()
|
100 |
+
|
101 |
+
# Fix a bug of Python threading
|
102 |
+
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
103 |
+
@@ -152,6 +154,11 @@ async def flush_cache():
|
104 |
+
@app.post("/start_profile")
|
105 |
+
async def start_profile():
|
106 |
+
"""Start profiling."""
|
107 |
+
+ rpd = rpdTracerControl()
|
108 |
+
+ rpd.setPythonTrace(True)
|
109 |
+
+ rpd.start()
|
110 |
+
+ rpd.rangePush("", "server rpd profile range", "")
|
111 |
+
+ logger.info("rpd profiling started in server.py!")
|
112 |
+
tokenizer_manager.start_profile()
|
113 |
+
return Response(
|
114 |
+
content="Start profiling.\n",
|
115 |
+
@@ -164,6 +171,11 @@ async def start_profile():
|
116 |
+
async def stop_profile():
|
117 |
+
"""Stop profiling."""
|
118 |
+
tokenizer_manager.stop_profile()
|
119 |
+
+ rpd = rpdTracerControl()
|
120 |
+
+ rpd.rangePop()
|
121 |
+
+ rpd.stop()
|
122 |
+
+ rpd.flush()
|
123 |
+
+ logger.info("rpd profiling is done in server.py!")
|
124 |
+
return Response(
|
125 |
+
content="Stop profiling. This will take some time.\n",
|
126 |
+
status_code=200,
|
sglang/3rdparty/amd/profiling/server.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
|
4 |
+
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
|
5 |
+
|
6 |
+
# Get the current timestamp
|
7 |
+
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
8 |
+
|
9 |
+
# Define the log file with a timestamp
|
10 |
+
LOGFILE="sglang_server_log_$TIMESTAMP.json"
|
11 |
+
|
12 |
+
# Run the Python command and save the output to the log file
|
13 |
+
loadTracer.sh python3 -m sglang.launch_server \
|
14 |
+
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
15 |
+
--tokenizer-path Xenova/grok-1-tokenizer \
|
16 |
+
--load-format dummy \
|
17 |
+
--quant fp8 \
|
18 |
+
--tp 8 \
|
19 |
+
--port 30000 \
|
20 |
+
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
sglang/3rdparty/amd/tuning/TUNING.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Tuning SGLang Infer System with AMD GPUs
|
2 |
+
This AppNote describes the SGLang performance tuning technical, code harness and running steps for systems with AMD Instinct GPUs.
|
3 |
+
Harness code, examples and steps are provided in detail, to facilitate easy reproduce & use to tune performance towards workloads.
|
4 |
+
Three primary runtime areas are covered:
|
5 |
+
|
6 |
+
## 1. Triton Kernels
|
7 |
+
To maximize Triton kernel efficiency, several strategies can be employed:
|
8 |
+
|
9 |
+
### Key Environment Variables:
|
10 |
+
- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
|
11 |
+
- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
|
12 |
+
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
|
13 |
+
- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
|
14 |
+
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
|
15 |
+
```python
|
16 |
+
@triton.autotune(configs=[
|
17 |
+
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
|
18 |
+
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
|
19 |
+
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
|
20 |
+
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
|
21 |
+
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
|
22 |
+
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
|
23 |
+
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
|
24 |
+
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
|
25 |
+
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
|
26 |
+
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
|
27 |
+
@triton.jit
|
28 |
+
def _triton_kernel_funtion():
|
29 |
+
...
|
30 |
+
```
|
31 |
+
## 2. Torch Tunable Operations
|
32 |
+
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
|
33 |
+
|
34 |
+
### Key Environment Variables:
|
35 |
+
1. **PYTORCH_TUNABLEOP_ENABLED**:
|
36 |
+
- Default: `0`
|
37 |
+
- Set to `1` to enable TunableOp.
|
38 |
+
|
39 |
+
2. **PYTORCH_TUNABLEOP_TUNING**:
|
40 |
+
- Default: `1`
|
41 |
+
- Set to `0` to disable tuning. If a tuned entry is not found, it will run the tuning step and record the entry when PYTORCH_TUNABLEOP_ENABLED is enabled.
|
42 |
+
|
43 |
+
3. **PYTORCH_TUNABLEOP_VERBOSE**:
|
44 |
+
- Default: `0`
|
45 |
+
- Set to `1` to enable verbose output for TunableOp.
|
46 |
+
|
47 |
+
### Usage Example:
|
48 |
+
To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
#Tuning
|
52 |
+
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
|
53 |
+
|
54 |
+
#Inference with tuning op
|
55 |
+
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
|
56 |
+
|
57 |
+
#Print out the log
|
58 |
+
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
|
59 |
+
|
60 |
+
```
|
61 |
+
## 3. Torch Compilation
|
62 |
+
|
63 |
+
|
64 |
+
The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
|
65 |
+
|
66 |
+
To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
|
67 |
+
|
68 |
+
### Key Configurations:
|
69 |
+
1. **Max Autotune**:
|
70 |
+
- Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
|
71 |
+
|
72 |
+
2. **Fine-Grained Control**:
|
73 |
+
- Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
|
74 |
+
- Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune.pointwise = True`.
|
75 |
+
|
76 |
+
3. **Backend Selection**:
|
77 |
+
- Use `torch._inductor.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
|
78 |
+
|
79 |
+
4. **Freezing for Inference**:
|
80 |
+
- Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
|
81 |
+
|
82 |
+
5. **Debugging**:
|
83 |
+
- Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
|
84 |
+
|
85 |
+
### Example Code Block:
|
86 |
+
```bash
|
87 |
+
#Gemm Tuning
|
88 |
+
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
|
89 |
+
|
90 |
+
#Specify your backend to TRITON for Gemm Tuning
|
91 |
+
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
|
92 |
+
|
93 |
+
#Inference with large improvement on AMD GPU
|
94 |
+
TORCHINDUCTOR_FREEZING=1 your_script.sh
|
95 |
+
```
|
96 |
+
## 4. Fused MOE kernel
|
97 |
+
To maximize moe kernel efficiency, need to use below scripts to find out the best launch configuration
|
98 |
+
|
99 |
+
### Key parameters:
|
100 |
+
- **--model**: what moe model type to do tuning, it will automatically decide the size of d_model, model_intermediate_size, num_layers
|
101 |
+
- **--tp-size**: simulate the whole model run configuration to set the dimension size using tp correctly
|
102 |
+
- **--batch**: M dimension size of moe kernel, for prefill moe kernel the value is batch*input_len, for decode moe kernel the value is batch
|
103 |
+
- **--dtype**: computation type
|
104 |
+
|
105 |
+
```bash
|
106 |
+
#Tuning
|
107 |
+
#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
|
108 |
+
#so we can tune decode moe use below command
|
109 |
+
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
|
110 |
+
# and use this command to tune prefill moe
|
111 |
+
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
|
112 |
+
```
|
113 |
+
|
114 |
+
## Reference
|
115 |
+
|
116 |
+
For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
|
117 |
+
|
118 |
+
[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)
|
sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import triton
|
9 |
+
import triton.language as tl
|
10 |
+
from tqdm import tqdm
|
11 |
+
from transformers import AutoConfig
|
12 |
+
|
13 |
+
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name
|
14 |
+
|
15 |
+
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
16 |
+
|
17 |
+
|
18 |
+
def main(model, tp_size, dtype: str, batches):
|
19 |
+
method = fused_moe
|
20 |
+
|
21 |
+
for bs in batches:
|
22 |
+
run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
|
23 |
+
|
24 |
+
|
25 |
+
def prune_configs(M, N, K, configs):
|
26 |
+
pruned_configs = []
|
27 |
+
elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
28 |
+
elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
29 |
+
|
30 |
+
mfma = 16 if M < 32 or N < 32 else 32
|
31 |
+
|
32 |
+
# TODO (zhanglx): figure out the boundary between large and small gemms
|
33 |
+
large_gemm = False
|
34 |
+
if M >= 2048 and N >= 2048:
|
35 |
+
large_gemm = True
|
36 |
+
|
37 |
+
for config in configs:
|
38 |
+
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
39 |
+
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
40 |
+
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
41 |
+
num_warps = config.get("num_warps")
|
42 |
+
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
43 |
+
# kpack = config.get("kpack")
|
44 |
+
if matrix_instr_nonkdim > mfma:
|
45 |
+
continue
|
46 |
+
if mfma == 4 and BLOCK_SIZE_K < 64:
|
47 |
+
continue
|
48 |
+
# some layouts could not work properly in case
|
49 |
+
# number elements per thread is less 1
|
50 |
+
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
51 |
+
continue
|
52 |
+
SPLIT_K = 1 # config.get("SPLIT_K")
|
53 |
+
GROUP_M = config.get("GROUP_SIZE_M")
|
54 |
+
if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
|
55 |
+
continue
|
56 |
+
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
57 |
+
continue
|
58 |
+
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
59 |
+
continue
|
60 |
+
# Skip BLOCK_SIZE that is too large compare to M/N
|
61 |
+
# unless BLOCK_SIZE is already small enough
|
62 |
+
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
63 |
+
continue
|
64 |
+
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
65 |
+
continue
|
66 |
+
# skip large split_k when not necessary
|
67 |
+
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
68 |
+
continue
|
69 |
+
# skip split_k that leads to EVEN_K = false
|
70 |
+
leap = SPLIT_K * BLOCK_SIZE_K
|
71 |
+
modv = K % leap
|
72 |
+
if modv != 0:
|
73 |
+
continue
|
74 |
+
# skip large GROUP_M
|
75 |
+
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
76 |
+
continue
|
77 |
+
# out of shared memory resource
|
78 |
+
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
79 |
+
LDS = (
|
80 |
+
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
81 |
+
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
82 |
+
)
|
83 |
+
if LDS > 65536:
|
84 |
+
continue
|
85 |
+
# Skip small block sizes and num_warps for large gemm
|
86 |
+
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
87 |
+
if large_gemm:
|
88 |
+
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
89 |
+
continue
|
90 |
+
if BLOCK_SIZE_K < 64:
|
91 |
+
continue
|
92 |
+
if num_warps < 4:
|
93 |
+
continue
|
94 |
+
|
95 |
+
pruned_configs.append(config)
|
96 |
+
|
97 |
+
return pruned_configs
|
98 |
+
|
99 |
+
|
100 |
+
def union_of_list_of_dicts(l1, l2):
|
101 |
+
result = []
|
102 |
+
temp_list = l1.copy()
|
103 |
+
temp_list.extend(l2)
|
104 |
+
for myDict in temp_list:
|
105 |
+
if myDict not in result:
|
106 |
+
result.append(myDict)
|
107 |
+
|
108 |
+
return result
|
109 |
+
|
110 |
+
|
111 |
+
def run_grid(bs, model, method, tp_size, dtype: str):
|
112 |
+
|
113 |
+
config = AutoConfig.from_pretrained(model)
|
114 |
+
|
115 |
+
top_k = config.num_experts_per_tok
|
116 |
+
d_model = config.hidden_size
|
117 |
+
model_intermediate_size = config.intermediate_size
|
118 |
+
num_layers = config.num_hidden_layers
|
119 |
+
hidden_states_dtype = config.torch_dtype
|
120 |
+
|
121 |
+
if config.num_experts_per_tok:
|
122 |
+
if config.architectures[0] == "Grok1ModelForCausalLM":
|
123 |
+
num_total_experts = config.num_experts
|
124 |
+
else:
|
125 |
+
num_total_experts = config.num_local_experts
|
126 |
+
else:
|
127 |
+
raise ValueError(f"Unsupported Mixtral model {model}")
|
128 |
+
|
129 |
+
# tp_size = 2
|
130 |
+
num_warmup_calls = 10
|
131 |
+
num_calls = 30
|
132 |
+
|
133 |
+
num_warmup_trials = 1
|
134 |
+
num_trials = 1
|
135 |
+
|
136 |
+
full_configs = []
|
137 |
+
|
138 |
+
block_m_range = [16, 32, 64, 128, 256]
|
139 |
+
block_n_range = [16, 32, 64, 128, 256]
|
140 |
+
block_k_range = [32, 64, 128, 256] # MUST >= 32
|
141 |
+
num_warps_range = [1, 2, 4, 8]
|
142 |
+
group_m_range = [1, 4, 8, 16, 32]
|
143 |
+
# For now we see better perf with num_stages=0 for all gemm configs we care
|
144 |
+
# But keep this explicit so that we do not forget we may need to set it to
|
145 |
+
# other values in the future
|
146 |
+
num_stage_range = [2]
|
147 |
+
waves_per_eu_range = [0, 1, 2, 4, 8]
|
148 |
+
# Remove 32 because of triton compiling error
|
149 |
+
matrix_instr_nonkdim_range = [16]
|
150 |
+
kpack_range = [1, 2]
|
151 |
+
|
152 |
+
for block_size_m in block_m_range:
|
153 |
+
for block_size_n in block_n_range:
|
154 |
+
for block_size_k in block_k_range:
|
155 |
+
for group_size_m in group_m_range:
|
156 |
+
for num_warps in num_warps_range:
|
157 |
+
for num_stages in num_stage_range:
|
158 |
+
for waves_per_eu in waves_per_eu_range:
|
159 |
+
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
|
160 |
+
for kpack in kpack_range:
|
161 |
+
full_configs.append(
|
162 |
+
{
|
163 |
+
"BLOCK_SIZE_M": block_size_m,
|
164 |
+
"BLOCK_SIZE_N": block_size_n,
|
165 |
+
"BLOCK_SIZE_K": block_size_k,
|
166 |
+
"GROUP_SIZE_M": group_size_m,
|
167 |
+
"num_warps": num_warps,
|
168 |
+
"num_stages": num_stages,
|
169 |
+
"waves_per_eu": waves_per_eu,
|
170 |
+
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
171 |
+
"kpack": kpack,
|
172 |
+
}
|
173 |
+
)
|
174 |
+
|
175 |
+
M1 = bs * 2
|
176 |
+
N1 = model_intermediate_size * 2 // tp_size
|
177 |
+
K1 = d_model
|
178 |
+
prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
|
179 |
+
|
180 |
+
M2 = bs * 2
|
181 |
+
N2 = d_model
|
182 |
+
K2 = model_intermediate_size // tp_size
|
183 |
+
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
|
184 |
+
|
185 |
+
configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
|
186 |
+
|
187 |
+
print(
|
188 |
+
f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
|
189 |
+
{len(prune_configs_2)=} | {len(configs)=}"
|
190 |
+
)
|
191 |
+
|
192 |
+
best_config = None
|
193 |
+
best_time_us = 1e20
|
194 |
+
|
195 |
+
print(f"{tp_size=} {bs=}")
|
196 |
+
|
197 |
+
for config in tqdm(configs):
|
198 |
+
# warmup
|
199 |
+
try:
|
200 |
+
print(config)
|
201 |
+
for _ in range(num_warmup_trials):
|
202 |
+
run_timing(
|
203 |
+
num_calls=num_warmup_calls,
|
204 |
+
bs=bs,
|
205 |
+
d_model=d_model,
|
206 |
+
num_total_experts=num_total_experts,
|
207 |
+
top_k=top_k,
|
208 |
+
tp_size=tp_size,
|
209 |
+
model_intermediate_size=model_intermediate_size,
|
210 |
+
method=method,
|
211 |
+
config=config,
|
212 |
+
dtype=dtype,
|
213 |
+
hidden_states_dtype=hidden_states_dtype,
|
214 |
+
)
|
215 |
+
except triton.runtime.autotuner.OutOfResources:
|
216 |
+
continue
|
217 |
+
|
218 |
+
# trial
|
219 |
+
for _ in range(num_trials):
|
220 |
+
kernel_dur_ms = run_timing(
|
221 |
+
num_calls=num_calls,
|
222 |
+
bs=bs,
|
223 |
+
d_model=d_model,
|
224 |
+
num_total_experts=num_total_experts,
|
225 |
+
top_k=top_k,
|
226 |
+
tp_size=tp_size,
|
227 |
+
model_intermediate_size=model_intermediate_size,
|
228 |
+
method=method,
|
229 |
+
config=config,
|
230 |
+
dtype=dtype,
|
231 |
+
hidden_states_dtype=hidden_states_dtype,
|
232 |
+
)
|
233 |
+
|
234 |
+
kernel_dur_us = 1000 * kernel_dur_ms
|
235 |
+
model_dur_ms = kernel_dur_ms * num_layers
|
236 |
+
|
237 |
+
if kernel_dur_us < best_time_us:
|
238 |
+
best_config = config
|
239 |
+
best_time_us = kernel_dur_us
|
240 |
+
|
241 |
+
tqdm.write(
|
242 |
+
f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
|
243 |
+
f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
|
244 |
+
f"{d_model=} {model_intermediate_size=} {num_layers=}"
|
245 |
+
)
|
246 |
+
|
247 |
+
print("best_time_us", best_time_us)
|
248 |
+
print("best_config", best_config)
|
249 |
+
|
250 |
+
# holds Dict[str, Dict[str, int]]
|
251 |
+
filename = get_config_file_name(
|
252 |
+
num_total_experts,
|
253 |
+
model_intermediate_size // tp_size,
|
254 |
+
"float8" if dtype == "float8" else None,
|
255 |
+
)
|
256 |
+
print(f"writing config to file {filename}")
|
257 |
+
existing_content = {}
|
258 |
+
if os.path.exists(filename):
|
259 |
+
with open(filename, "r") as f:
|
260 |
+
existing_content = json.load(f)
|
261 |
+
existing_content[str(bs)] = best_config
|
262 |
+
with open(filename, "w") as f:
|
263 |
+
json.dump(existing_content, f, indent=4)
|
264 |
+
f.write("\n")
|
265 |
+
|
266 |
+
|
267 |
+
def run_timing(
|
268 |
+
num_calls: int,
|
269 |
+
bs: int,
|
270 |
+
d_model: int,
|
271 |
+
num_total_experts: int,
|
272 |
+
top_k: int,
|
273 |
+
tp_size: int,
|
274 |
+
model_intermediate_size: int,
|
275 |
+
method,
|
276 |
+
config,
|
277 |
+
dtype: str,
|
278 |
+
hidden_states_dtype,
|
279 |
+
) -> float:
|
280 |
+
shard_intermediate_size = model_intermediate_size // tp_size
|
281 |
+
|
282 |
+
hidden_states = torch.rand(
|
283 |
+
(bs, d_model),
|
284 |
+
device="cuda:0",
|
285 |
+
dtype=hidden_states_dtype,
|
286 |
+
)
|
287 |
+
|
288 |
+
w1 = torch.rand(
|
289 |
+
(num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
|
290 |
+
device=hidden_states.device,
|
291 |
+
dtype=hidden_states.dtype,
|
292 |
+
)
|
293 |
+
|
294 |
+
w2 = torch.rand(
|
295 |
+
(num_total_experts, d_model, shard_intermediate_size + padding_size),
|
296 |
+
device=hidden_states.device,
|
297 |
+
dtype=hidden_states.dtype,
|
298 |
+
)
|
299 |
+
|
300 |
+
w1_scale = None
|
301 |
+
w2_scale = None
|
302 |
+
a1_scale = None
|
303 |
+
a2_scale = None
|
304 |
+
|
305 |
+
if dtype == "float8":
|
306 |
+
w1 = w1.to(torch.float8_e4m3fnuz)
|
307 |
+
w2 = w2.to(torch.float8_e4m3fnuz)
|
308 |
+
w1_scale = torch.ones(
|
309 |
+
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
310 |
+
)
|
311 |
+
w2_scale = torch.ones(
|
312 |
+
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
313 |
+
)
|
314 |
+
a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
315 |
+
a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
316 |
+
|
317 |
+
gating_output = F.softmax(
|
318 |
+
torch.rand(
|
319 |
+
(num_calls, bs, num_total_experts),
|
320 |
+
device=hidden_states.device,
|
321 |
+
dtype=torch.float32,
|
322 |
+
),
|
323 |
+
dim=-1,
|
324 |
+
)
|
325 |
+
|
326 |
+
##################################
|
327 |
+
|
328 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
329 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
330 |
+
|
331 |
+
start_event.record()
|
332 |
+
for i in range(num_calls):
|
333 |
+
hidden_states = method(
|
334 |
+
hidden_states=hidden_states,
|
335 |
+
w1=w1,
|
336 |
+
w2=w2,
|
337 |
+
w1_scale=w1_scale,
|
338 |
+
w2_scale=w2_scale,
|
339 |
+
a1_scale=a1_scale,
|
340 |
+
a2_scale=a2_scale,
|
341 |
+
gating_output=gating_output[0],
|
342 |
+
topk=top_k,
|
343 |
+
renormalize=True,
|
344 |
+
inplace=True,
|
345 |
+
override_config=config,
|
346 |
+
use_fp8=dtype == "float8",
|
347 |
+
)
|
348 |
+
|
349 |
+
end_event.record()
|
350 |
+
end_event.synchronize()
|
351 |
+
|
352 |
+
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
353 |
+
return dur_ms
|
354 |
+
|
355 |
+
|
356 |
+
if __name__ == "__main__":
|
357 |
+
parser = argparse.ArgumentParser(
|
358 |
+
prog="benchmark_mixtral_moe",
|
359 |
+
description="Benchmark and tune the fused_moe kernel",
|
360 |
+
)
|
361 |
+
parser.add_argument(
|
362 |
+
"--dtype",
|
363 |
+
type=str,
|
364 |
+
default="auto",
|
365 |
+
choices=["float8", "float16", "bfloat16"],
|
366 |
+
help="Data type used for fused_moe kernel computations",
|
367 |
+
)
|
368 |
+
parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
|
369 |
+
|
370 |
+
parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
|
371 |
+
parser.add_argument("-b", "--batches", type=str)
|
372 |
+
|
373 |
+
args = parser.parse_args()
|
374 |
+
|
375 |
+
batches = args.batches.split(",")
|
376 |
+
|
377 |
+
sys.exit(main(args.model, args.tp_size, args.dtype, batches))
|
sglang/benchmark/blog_v0_2/405b_sglang.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Create dummy weights:
|
2 |
+
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
|
3 |
+
# 2. Get `config.json`` from ./config.md
|
4 |
+
# 3. Download the tokenizer
|
5 |
+
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
|
6 |
+
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
7 |
+
|
8 |
+
# Launch sglang
|
9 |
+
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87
|
10 |
+
|
11 |
+
# offline
|
12 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
|
13 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12
|
14 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13
|
15 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14
|
16 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15
|
17 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21
|
18 |
+
|
19 |
+
# online
|
20 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31
|
21 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32
|
22 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33
|
23 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34
|
24 |
+
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35
|
sglang/benchmark/blog_v0_2/405b_trt.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Launch trtllm
|
2 |
+
# https://github.com/sgl-project/tensorrt-demo
|
3 |
+
|
4 |
+
# offline
|
5 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11
|
6 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12
|
7 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13
|
8 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14
|
9 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15
|
10 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21
|
11 |
+
|
12 |
+
# online
|
13 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31
|
14 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32
|
15 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33
|
16 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34
|
17 |
+
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35
|
sglang/benchmark/blog_v0_2/405b_vllm.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Create dummy weights:
|
2 |
+
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
|
3 |
+
# 2. Get `config.json`` from ./config.md
|
4 |
+
# 3. Download the tokenizer
|
5 |
+
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
|
6 |
+
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
|
7 |
+
|
8 |
+
# Launch vllm
|
9 |
+
# python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000
|
10 |
+
|
11 |
+
# offline
|
12 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11
|
13 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12
|
14 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13
|
15 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14
|
16 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15
|
17 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21
|
18 |
+
|
19 |
+
# online
|
20 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31
|
21 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32
|
22 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33
|
23 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34
|
24 |
+
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35
|
sglang/benchmark/dspy/README.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Install
|
2 |
+
|
3 |
+
```
|
4 |
+
pip3 install dspy-ai
|
5 |
+
```
|
6 |
+
|
7 |
+
Turn off cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10.
|
8 |
+
```
|
9 |
+
cache_turn_on = False
|
10 |
+
```
|
11 |
+
|
12 |
+
or set the environment variable
|
13 |
+
|
14 |
+
```
|
15 |
+
export DSP_CACHEBOOL=false
|
16 |
+
```
|
17 |
+
|
18 |
+
## Benchmark SGLang
|
19 |
+
```
|
20 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
21 |
+
```
|
22 |
+
|
23 |
+
```
|
24 |
+
python3 bench_dspy_intro.py --backend sglang
|
25 |
+
```
|
26 |
+
|
27 |
+
|
28 |
+
## Benchmark TGI
|
29 |
+
```
|
30 |
+
docker run --name tgi --rm -ti --gpus all --network host \
|
31 |
+
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
|
32 |
+
ghcr.io/huggingface/text-generation-inference:1.3.0 \
|
33 |
+
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
|
34 |
+
--max-input-length 2048 --max-total-tokens 4096 \
|
35 |
+
--port 24000
|
36 |
+
```
|
37 |
+
|
38 |
+
```
|
39 |
+
python3 bench_dspy_intro.py --backend tgi
|
40 |
+
```
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
## Benchmark vLLM
|
45 |
+
```
|
46 |
+
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
47 |
+
```
|
48 |
+
|
49 |
+
```
|
50 |
+
python3 bench_dspy_intro.py --backend vllm
|
51 |
+
```
|
sglang/benchmark/dspy/bench_dspy_intro.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adapted from
|
3 |
+
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import dspy
|
9 |
+
from dspy.datasets import HotPotQA
|
10 |
+
|
11 |
+
|
12 |
+
class BasicQA(dspy.Signature):
|
13 |
+
"""Answer questions with short factoid answers."""
|
14 |
+
|
15 |
+
question = dspy.InputField()
|
16 |
+
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
17 |
+
|
18 |
+
|
19 |
+
class GenerateAnswer(dspy.Signature):
|
20 |
+
"""Answer questions with short factoid answers."""
|
21 |
+
|
22 |
+
context = dspy.InputField(desc="may contain relevant facts")
|
23 |
+
question = dspy.InputField()
|
24 |
+
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
25 |
+
|
26 |
+
|
27 |
+
class RAG(dspy.Module):
|
28 |
+
def __init__(self, num_passages=3):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
self.retrieve = dspy.Retrieve(k=num_passages)
|
32 |
+
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
|
33 |
+
|
34 |
+
def forward(self, question):
|
35 |
+
context = self.retrieve(question).passages
|
36 |
+
prediction = self.generate_answer(context=context, question=question)
|
37 |
+
return dspy.Prediction(context=context, answer=prediction.answer)
|
38 |
+
|
39 |
+
|
40 |
+
def main(args):
|
41 |
+
# lm = dspy.OpenAI(model='gpt-3.5-turbo')
|
42 |
+
if args.backend == "tgi":
|
43 |
+
lm = dspy.HFClientTGI(
|
44 |
+
model="meta-llama/Llama-2-7b-chat-hf",
|
45 |
+
port=args.port,
|
46 |
+
url="http://localhost",
|
47 |
+
)
|
48 |
+
elif args.backend == "sglang":
|
49 |
+
lm = dspy.HFClientSGLang(
|
50 |
+
model="meta-llama/Llama-2-7b-chat-hf",
|
51 |
+
port=args.port,
|
52 |
+
url="http://localhost",
|
53 |
+
)
|
54 |
+
elif args.backend == "vllm":
|
55 |
+
lm = dspy.HFClientVLLM(
|
56 |
+
model="meta-llama/Llama-2-7b-chat-hf",
|
57 |
+
port=args.port,
|
58 |
+
url="http://localhost",
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
raise ValueError(f"Invalid backend: {args.backend}")
|
62 |
+
|
63 |
+
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
|
64 |
+
url="http://20.102.90.50:2017/wiki17_abstracts"
|
65 |
+
)
|
66 |
+
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
|
67 |
+
|
68 |
+
# Load the dataset.
|
69 |
+
dataset = HotPotQA(
|
70 |
+
train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
|
71 |
+
)
|
72 |
+
|
73 |
+
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
|
74 |
+
trainset = [x.with_inputs("question") for x in dataset.train]
|
75 |
+
devset = [x.with_inputs("question") for x in dataset.dev]
|
76 |
+
|
77 |
+
print(len(trainset), len(devset))
|
78 |
+
|
79 |
+
train_example = trainset[0]
|
80 |
+
print(f"Question: {train_example.question}")
|
81 |
+
print(f"Answer: {train_example.answer}")
|
82 |
+
|
83 |
+
dev_example = devset[18]
|
84 |
+
print(f"Question: {dev_example.question}")
|
85 |
+
print(f"Answer: {dev_example.answer}")
|
86 |
+
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
|
87 |
+
|
88 |
+
print(
|
89 |
+
f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
|
90 |
+
)
|
91 |
+
print(
|
92 |
+
f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
|
93 |
+
)
|
94 |
+
|
95 |
+
# Define the predictor.
|
96 |
+
generate_answer = dspy.Predict(BasicQA)
|
97 |
+
|
98 |
+
# Call the predictor on a particular input.
|
99 |
+
pred = generate_answer(question=dev_example.question)
|
100 |
+
|
101 |
+
# Print the input and the prediction.
|
102 |
+
print(f"Question: {dev_example.question}")
|
103 |
+
print(f"Predicted Answer: {pred.answer}")
|
104 |
+
|
105 |
+
lm.inspect_history(n=1)
|
106 |
+
|
107 |
+
# Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
|
108 |
+
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
|
109 |
+
|
110 |
+
# Call the predictor on the same input.
|
111 |
+
pred = generate_answer_with_chain_of_thought(question=dev_example.question)
|
112 |
+
|
113 |
+
# Print the input, the chain of thought, and the prediction.
|
114 |
+
print(f"Question: {dev_example.question}")
|
115 |
+
print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
|
116 |
+
print(f"Predicted Answer: {pred.answer}")
|
117 |
+
|
118 |
+
retrieve = dspy.Retrieve(k=3)
|
119 |
+
topK_passages = retrieve(dev_example.question).passages
|
120 |
+
|
121 |
+
print(
|
122 |
+
f"Top {retrieve.k} passages for question: {dev_example.question} \n",
|
123 |
+
"-" * 30,
|
124 |
+
"\n",
|
125 |
+
)
|
126 |
+
|
127 |
+
for idx, passage in enumerate(topK_passages):
|
128 |
+
print(f"{idx+1}]", passage, "\n")
|
129 |
+
|
130 |
+
retrieve("When was the first FIFA World Cup held?").passages[0]
|
131 |
+
|
132 |
+
from dspy.teleprompt import BootstrapFewShot
|
133 |
+
|
134 |
+
# Validation logic: check that the predicted answer is correct.
|
135 |
+
# Also check that the retrieved context does actually contain that answer.
|
136 |
+
def validate_context_and_answer(example, pred, trace=None):
|
137 |
+
answer_EM = dspy.evaluate.answer_exact_match(example, pred)
|
138 |
+
answer_PM = dspy.evaluate.answer_passage_match(example, pred)
|
139 |
+
return answer_EM and answer_PM
|
140 |
+
|
141 |
+
# Set up a basic teleprompter, which will compile our RAG program.
|
142 |
+
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
|
143 |
+
|
144 |
+
# Compile!
|
145 |
+
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
|
146 |
+
|
147 |
+
# Ask any question you like to this simple RAG program.
|
148 |
+
my_question = "What castle did David Gregory inherit?"
|
149 |
+
|
150 |
+
# Get the prediction. This contains `pred.context` and `pred.answer`.
|
151 |
+
pred = compiled_rag(my_question)
|
152 |
+
|
153 |
+
# Print the contexts and the answer.
|
154 |
+
print(f"Question: {my_question}")
|
155 |
+
print(f"Predicted Answer: {pred.answer}")
|
156 |
+
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
|
157 |
+
|
158 |
+
from dspy.evaluate.evaluate import Evaluate
|
159 |
+
|
160 |
+
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
|
161 |
+
evaluate_on_hotpotqa = Evaluate(
|
162 |
+
devset=devset,
|
163 |
+
num_threads=args.num_threads,
|
164 |
+
display_progress=True,
|
165 |
+
display_table=5,
|
166 |
+
)
|
167 |
+
|
168 |
+
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
|
169 |
+
metric = dspy.evaluate.answer_exact_match
|
170 |
+
evaluate_on_hotpotqa(compiled_rag, metric=metric)
|
171 |
+
|
172 |
+
|
173 |
+
if __name__ == "__main__":
|
174 |
+
parser = argparse.ArgumentParser()
|
175 |
+
parser.add_argument("--port", type=int)
|
176 |
+
parser.add_argument("--num-threads", type=int, default=32)
|
177 |
+
parser.add_argument("--dev-size", type=int, default=150)
|
178 |
+
parser.add_argument(
|
179 |
+
"--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
|
180 |
+
)
|
181 |
+
args = parser.parse_args()
|
182 |
+
|
183 |
+
if args.port is None:
|
184 |
+
default_port = {
|
185 |
+
"vllm": 21000,
|
186 |
+
"lightllm": 22000,
|
187 |
+
"tgi": 24000,
|
188 |
+
"sglang": 30000,
|
189 |
+
}
|
190 |
+
args.port = default_port.get(args.backend, None)
|
191 |
+
|
192 |
+
main(args)
|
sglang/benchmark/gsm8k/README.md
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Benchmark sglang
|
4 |
+
```
|
5 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
6 |
+
```
|
7 |
+
|
8 |
+
```
|
9 |
+
python3 bench_sglang.py --num-questions 200
|
10 |
+
```
|
11 |
+
|
12 |
+
|
13 |
+
### Benchmark vllm
|
14 |
+
```
|
15 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
16 |
+
```
|
17 |
+
|
18 |
+
```
|
19 |
+
python3 bench_other.py --num-questions 200 --backend vllm
|
20 |
+
```
|
21 |
+
|
22 |
+
|
23 |
+
### Benchmark lightllm
|
24 |
+
```
|
25 |
+
# A10G
|
26 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
27 |
+
```
|
28 |
+
|
29 |
+
```
|
30 |
+
python3 bench_other.py --num-questions 200 --backend lightllm
|
31 |
+
```
|
32 |
+
|
33 |
+
|
34 |
+
### Benchmark guidance
|
35 |
+
```
|
36 |
+
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Benchmark lmql
|
41 |
+
```
|
42 |
+
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
|
43 |
+
```
|
44 |
+
|
45 |
+
```
|
46 |
+
python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
|
47 |
+
```
|
sglang/benchmark/gsm8k/bench_other.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import ast
|
3 |
+
import asyncio
|
4 |
+
import json
|
5 |
+
import re
|
6 |
+
import time
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
13 |
+
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
14 |
+
|
15 |
+
INVALID = -9999999
|
16 |
+
|
17 |
+
|
18 |
+
def get_one_example(lines, i, include_answer):
|
19 |
+
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
20 |
+
if include_answer:
|
21 |
+
ret += " " + lines[i]["answer"]
|
22 |
+
return ret
|
23 |
+
|
24 |
+
|
25 |
+
def get_few_shot_examples(lines, k):
|
26 |
+
ret = ""
|
27 |
+
for i in range(k):
|
28 |
+
ret += get_one_example(lines, i, True) + "\n\n"
|
29 |
+
return ret
|
30 |
+
|
31 |
+
|
32 |
+
def get_answer_value(answer_str):
|
33 |
+
answer_str = answer_str.replace(",", "")
|
34 |
+
numbers = re.findall(r"\d+", answer_str)
|
35 |
+
if len(numbers) < 1:
|
36 |
+
return INVALID
|
37 |
+
try:
|
38 |
+
return ast.literal_eval(numbers[-1])
|
39 |
+
except SyntaxError:
|
40 |
+
return INVALID
|
41 |
+
|
42 |
+
|
43 |
+
def main(args):
|
44 |
+
# Select backend
|
45 |
+
call_generate = get_call_generate(args)
|
46 |
+
|
47 |
+
# Read data
|
48 |
+
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
49 |
+
filename = download_and_cache_file(url)
|
50 |
+
lines = list(read_jsonl(filename))
|
51 |
+
|
52 |
+
# Construct prompts
|
53 |
+
num_questions = args.num_questions
|
54 |
+
num_shots = args.num_shots
|
55 |
+
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
56 |
+
|
57 |
+
questions = []
|
58 |
+
labels = []
|
59 |
+
for i in range(len(lines[:num_questions])):
|
60 |
+
questions.append(get_one_example(lines, i, False))
|
61 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
62 |
+
assert all(l != INVALID for l in labels)
|
63 |
+
|
64 |
+
states = [None] * len(labels)
|
65 |
+
|
66 |
+
# Run requests
|
67 |
+
if args.backend != "lmql":
|
68 |
+
# Use thread pool
|
69 |
+
def get_one_answer(i):
|
70 |
+
answer = call_generate(
|
71 |
+
prompt=few_shot_examples + questions[i],
|
72 |
+
temperature=0,
|
73 |
+
max_tokens=256,
|
74 |
+
stop=["Question", "Assistant:", "<|separator|>"],
|
75 |
+
)
|
76 |
+
states[i] = answer
|
77 |
+
|
78 |
+
tic = time.time()
|
79 |
+
if args.parallel == 1:
|
80 |
+
for i in tqdm(range(len(questions))):
|
81 |
+
get_one_answer(i)
|
82 |
+
else:
|
83 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
84 |
+
list(
|
85 |
+
tqdm(
|
86 |
+
executor.map(get_one_answer, list(range(len(questions)))),
|
87 |
+
total=len(questions),
|
88 |
+
)
|
89 |
+
)
|
90 |
+
|
91 |
+
else:
|
92 |
+
# Use asyncio
|
93 |
+
async def batched_call(batch_size):
|
94 |
+
for i in range(0, len(questions), batch_size):
|
95 |
+
tasks = []
|
96 |
+
for q in questions[i : i + batch_size]:
|
97 |
+
tasks.append(
|
98 |
+
call_generate(
|
99 |
+
few_shot_examples + q,
|
100 |
+
temperature=0,
|
101 |
+
max_tokens=256,
|
102 |
+
stop="Question",
|
103 |
+
)
|
104 |
+
)
|
105 |
+
rets = await asyncio.gather(*tasks)
|
106 |
+
for j in range(len(rets)):
|
107 |
+
states[i + j] = rets[j]
|
108 |
+
|
109 |
+
tic = time.time()
|
110 |
+
asyncio.run(batched_call(batch_size=args.parallel))
|
111 |
+
latency = time.time() - tic
|
112 |
+
|
113 |
+
preds = []
|
114 |
+
for i in range(len(states)):
|
115 |
+
preds.append(get_answer_value(states[i]))
|
116 |
+
|
117 |
+
# Compute accuracy
|
118 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
119 |
+
invalid = np.mean(np.array(preds) == INVALID)
|
120 |
+
|
121 |
+
# Print results
|
122 |
+
print(f"Accuracy: {acc:.3f}")
|
123 |
+
print(f"Invalid: {invalid:.3f}")
|
124 |
+
print(f"Latency: {latency:.3f} s")
|
125 |
+
|
126 |
+
# Dump results
|
127 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
128 |
+
|
129 |
+
with open(args.result_file, "a") as fout:
|
130 |
+
value = {
|
131 |
+
"task": "gsm8k",
|
132 |
+
"backend": args.backend,
|
133 |
+
"num_gpus": 1,
|
134 |
+
"latency": round(latency, 3),
|
135 |
+
"accuracy": round(acc, 3),
|
136 |
+
"num_requests": args.num_questions,
|
137 |
+
"other": {
|
138 |
+
"num_questions": args.num_questions,
|
139 |
+
"parallel": args.parallel,
|
140 |
+
},
|
141 |
+
}
|
142 |
+
fout.write(json.dumps(value) + "\n")
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
parser = argparse.ArgumentParser()
|
147 |
+
parser.add_argument("--num-shots", type=int, default=5)
|
148 |
+
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
149 |
+
parser.add_argument("--num-questions", type=int, default=200)
|
150 |
+
args = add_common_other_args_and_parse(parser)
|
151 |
+
main(args)
|
sglang/benchmark/gsm8k/bench_sglang.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import ast
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from sglang.api import set_default_backend
|
10 |
+
from sglang.test.test_utils import (
|
11 |
+
add_common_sglang_args_and_parse,
|
12 |
+
select_sglang_backend,
|
13 |
+
)
|
14 |
+
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
15 |
+
|
16 |
+
INVALID = -9999999
|
17 |
+
|
18 |
+
|
19 |
+
def get_one_example(lines, i, include_answer):
|
20 |
+
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
21 |
+
if include_answer:
|
22 |
+
ret += " " + lines[i]["answer"]
|
23 |
+
return ret
|
24 |
+
|
25 |
+
|
26 |
+
def get_few_shot_examples(lines, k):
|
27 |
+
ret = ""
|
28 |
+
for i in range(k):
|
29 |
+
ret += get_one_example(lines, i, True) + "\n\n"
|
30 |
+
return ret
|
31 |
+
|
32 |
+
|
33 |
+
def get_answer_value(answer_str):
|
34 |
+
answer_str = answer_str.replace(",", "")
|
35 |
+
numbers = re.findall(r"\d+", answer_str)
|
36 |
+
if len(numbers) < 1:
|
37 |
+
return INVALID
|
38 |
+
try:
|
39 |
+
return ast.literal_eval(numbers[-1])
|
40 |
+
except SyntaxError:
|
41 |
+
return INVALID
|
42 |
+
|
43 |
+
|
44 |
+
def main(args):
|
45 |
+
# Select backend
|
46 |
+
set_default_backend(select_sglang_backend(args))
|
47 |
+
|
48 |
+
# Read data
|
49 |
+
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
50 |
+
filename = download_and_cache_file(url)
|
51 |
+
lines = list(read_jsonl(filename))
|
52 |
+
|
53 |
+
# Construct prompts
|
54 |
+
num_questions = args.num_questions
|
55 |
+
num_shots = args.num_shots
|
56 |
+
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
57 |
+
|
58 |
+
questions = []
|
59 |
+
labels = []
|
60 |
+
for i in range(len(lines[:num_questions])):
|
61 |
+
questions.append(get_one_example(lines, i, False))
|
62 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
63 |
+
assert all(l != INVALID for l in labels)
|
64 |
+
arguments = [{"question": q} for q in questions]
|
65 |
+
|
66 |
+
#####################################
|
67 |
+
######### SGL Program Begin #########
|
68 |
+
#####################################
|
69 |
+
|
70 |
+
import sglang as sgl
|
71 |
+
|
72 |
+
@sgl.function
|
73 |
+
def few_shot_gsm8k(s, question):
|
74 |
+
s += few_shot_examples + question
|
75 |
+
s += sgl.gen(
|
76 |
+
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
|
77 |
+
)
|
78 |
+
|
79 |
+
#####################################
|
80 |
+
########## SGL Program End ##########
|
81 |
+
#####################################
|
82 |
+
|
83 |
+
# Run requests
|
84 |
+
tic = time.time()
|
85 |
+
states = few_shot_gsm8k.run_batch(
|
86 |
+
arguments,
|
87 |
+
temperature=0,
|
88 |
+
num_threads=args.parallel,
|
89 |
+
progress_bar=True,
|
90 |
+
)
|
91 |
+
latency = time.time() - tic
|
92 |
+
|
93 |
+
preds = []
|
94 |
+
for i in range(len(states)):
|
95 |
+
preds.append(get_answer_value(states[i]["answer"]))
|
96 |
+
|
97 |
+
# print(f"{preds=}")
|
98 |
+
# print(f"{labels=}")
|
99 |
+
|
100 |
+
# Compute accuracy
|
101 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
102 |
+
invalid = np.mean(np.array(preds) == INVALID)
|
103 |
+
|
104 |
+
# Compute speed
|
105 |
+
num_output_tokens = sum(
|
106 |
+
s.get_meta_info("answer")["completion_tokens"] for s in states
|
107 |
+
)
|
108 |
+
output_throughput = num_output_tokens / latency
|
109 |
+
|
110 |
+
# Print results
|
111 |
+
print(f"Accuracy: {acc:.3f}")
|
112 |
+
print(f"Invalid: {invalid:.3f}")
|
113 |
+
print(f"Latency: {latency:.3f} s")
|
114 |
+
print(f"Output throughput: {output_throughput:.3f} token/s")
|
115 |
+
|
116 |
+
# Dump results
|
117 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
118 |
+
|
119 |
+
with open(args.result_file, "a") as fout:
|
120 |
+
value = {
|
121 |
+
"task": "gsm8k",
|
122 |
+
"backend": args.backend,
|
123 |
+
"num_gpus": 1,
|
124 |
+
"latency": round(latency, 3),
|
125 |
+
"accuracy": round(acc, 3),
|
126 |
+
"num_requests": args.num_questions,
|
127 |
+
"other": {
|
128 |
+
"num_questions": args.num_questions,
|
129 |
+
"parallel": args.parallel,
|
130 |
+
},
|
131 |
+
}
|
132 |
+
fout.write(json.dumps(value) + "\n")
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == "__main__":
|
136 |
+
parser = argparse.ArgumentParser()
|
137 |
+
parser.add_argument("--num-shots", type=int, default=5)
|
138 |
+
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
139 |
+
parser.add_argument("--num-questions", type=int, default=200)
|
140 |
+
args = add_common_sglang_args_and_parse(parser)
|
141 |
+
main(args)
|
sglang/benchmark/hellaswag/README.md
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Run benchmark
|
2 |
+
|
3 |
+
### Benchmark sglang
|
4 |
+
```
|
5 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
6 |
+
```
|
7 |
+
|
8 |
+
```
|
9 |
+
python3 bench_sglang.py --num-questions 200
|
10 |
+
```
|
11 |
+
|
12 |
+
|
13 |
+
### Benchmark vllm
|
14 |
+
```
|
15 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
16 |
+
```
|
17 |
+
|
18 |
+
```
|
19 |
+
python3 bench_other.py --num-questions 200 --backend vllm
|
20 |
+
```
|
21 |
+
|
22 |
+
|
23 |
+
### Benchmark lightllm
|
24 |
+
```
|
25 |
+
# A10G
|
26 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
27 |
+
```
|
28 |
+
|
29 |
+
```
|
30 |
+
python3 bench_other.py --num-questions 200 --backend lightllm
|
31 |
+
```
|
32 |
+
|
33 |
+
|
34 |
+
### Benchmark guidance
|
35 |
+
```
|
36 |
+
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Benchmark lmql
|
41 |
+
```
|
42 |
+
lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
|
43 |
+
```
|
44 |
+
|
45 |
+
```
|
46 |
+
python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
|
47 |
+
```
|
sglang/benchmark/hellaswag/bench_other.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
|
11 |
+
from sglang.utils import download_and_cache_file, read_jsonl
|
12 |
+
|
13 |
+
|
14 |
+
def get_one_example(lines, i, include_answer):
|
15 |
+
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
16 |
+
if include_answer:
|
17 |
+
ret += lines[i]["endings"][lines[i]["label"]]
|
18 |
+
return ret
|
19 |
+
|
20 |
+
|
21 |
+
def get_few_shot_examples(lines, k):
|
22 |
+
ret = ""
|
23 |
+
for i in range(k):
|
24 |
+
ret += get_one_example(lines, i, True) + "\n\n"
|
25 |
+
return ret
|
26 |
+
|
27 |
+
|
28 |
+
def main(args):
|
29 |
+
# Select backend
|
30 |
+
call_select = get_call_select(args)
|
31 |
+
|
32 |
+
# Read data
|
33 |
+
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
34 |
+
filename = download_and_cache_file(url)
|
35 |
+
lines = list(read_jsonl(filename))
|
36 |
+
|
37 |
+
# Construct prompts
|
38 |
+
num_questions = args.num_questions
|
39 |
+
num_shots = args.num_shots
|
40 |
+
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
41 |
+
|
42 |
+
questions = []
|
43 |
+
choices = []
|
44 |
+
labels = []
|
45 |
+
for i in range(len(lines[:num_questions])):
|
46 |
+
questions.append(get_one_example(lines, i, False))
|
47 |
+
choices.append(lines[i]["endings"])
|
48 |
+
labels.append(lines[i]["label"])
|
49 |
+
|
50 |
+
preds = [None] * len(labels)
|
51 |
+
|
52 |
+
# Run requests
|
53 |
+
if args.backend != "lmql":
|
54 |
+
# Use thread pool
|
55 |
+
def get_one_answer(i):
|
56 |
+
preds[i] = call_select(
|
57 |
+
context=few_shot_examples + questions[i], choices=choices[i]
|
58 |
+
)
|
59 |
+
|
60 |
+
tic = time.time()
|
61 |
+
if args.parallel == 1:
|
62 |
+
for i in tqdm(range(len(questions))):
|
63 |
+
get_one_answer(i)
|
64 |
+
else:
|
65 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
66 |
+
list(
|
67 |
+
tqdm(
|
68 |
+
executor.map(get_one_answer, list(range(len(questions)))),
|
69 |
+
total=len(questions),
|
70 |
+
)
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
# Use asyncio
|
74 |
+
async def batched_call(batch_size):
|
75 |
+
for i in range(0, len(questions), batch_size):
|
76 |
+
tasks = []
|
77 |
+
for q, c in zip(
|
78 |
+
questions[i : i + batch_size], choices[i : i + batch_size]
|
79 |
+
):
|
80 |
+
tasks.append(call_select(context=few_shot_examples + q, choices=c))
|
81 |
+
rets = await asyncio.gather(*tasks)
|
82 |
+
for j in range(len(rets)):
|
83 |
+
preds[i + j] = rets[j]
|
84 |
+
|
85 |
+
tic = time.time()
|
86 |
+
asyncio.run(batched_call(batch_size=args.parallel))
|
87 |
+
|
88 |
+
latency = time.time() - tic
|
89 |
+
|
90 |
+
# Compute accuracy
|
91 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
92 |
+
print(f"Latency: {latency:.3f}")
|
93 |
+
print(f"Accuracy: {acc:.3f}")
|
94 |
+
|
95 |
+
# Write results
|
96 |
+
with open(args.result_file, "a") as fout:
|
97 |
+
value = {
|
98 |
+
"task": "hellaswag",
|
99 |
+
"backend": args.backend,
|
100 |
+
"num_gpus": 1,
|
101 |
+
"latency": round(latency, 3),
|
102 |
+
"accuracy": round(acc, 3),
|
103 |
+
"num_requests": args.num_questions,
|
104 |
+
"other": {
|
105 |
+
"num_questions": args.num_questions,
|
106 |
+
"parallel": args.parallel,
|
107 |
+
},
|
108 |
+
}
|
109 |
+
fout.write(json.dumps(value) + "\n")
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == "__main__":
|
113 |
+
parser = argparse.ArgumentParser()
|
114 |
+
parser.add_argument("--num-shots", type=int, default=20)
|
115 |
+
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
|
116 |
+
parser.add_argument("--num-questions", type=int, default=200)
|
117 |
+
args = add_common_other_args_and_parse(parser)
|
118 |
+
main(args)
|
sglang/benchmark/lora/launch_server.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
NUM_LORAS = 8
|
5 |
+
LORA_PATH = {
|
6 |
+
"base": "mistralai/Mistral-7B-Instruct-v0.3",
|
7 |
+
"lora": "/home/ying/test_lora",
|
8 |
+
}
|
9 |
+
|
10 |
+
|
11 |
+
def launch_server(args):
|
12 |
+
base_path = LORA_PATH["base"]
|
13 |
+
lora_path = LORA_PATH["lora"]
|
14 |
+
|
15 |
+
if args.base_only:
|
16 |
+
cmd = f"python3 -m sglang.launch_server --model {base_path} "
|
17 |
+
else:
|
18 |
+
cmd = f"python3 -m sglang.launch_server --model {base_path} --lora-paths "
|
19 |
+
for i in range(NUM_LORAS):
|
20 |
+
lora_name = f"lora{i}"
|
21 |
+
cmd += f"{lora_name}={lora_path} "
|
22 |
+
cmd += f"--disable-radix --disable-cuda-graph "
|
23 |
+
cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
|
24 |
+
cmd += f"--max-running-requests {args.max_running_requests}"
|
25 |
+
print(cmd)
|
26 |
+
os.system(cmd)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument(
|
32 |
+
"--base-only",
|
33 |
+
action="store_true",
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--max-loras-per-batch",
|
37 |
+
type=int,
|
38 |
+
default=8,
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--max-running-requests",
|
42 |
+
type=int,
|
43 |
+
default=8,
|
44 |
+
)
|
45 |
+
args = parser.parse_args()
|
46 |
+
|
47 |
+
launch_server(args)
|
sglang/benchmark/lora/lora_bench.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-2024 SGLang Team
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
#
|
6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7 |
+
#
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
# ==============================================================================
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import asyncio
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import random
|
20 |
+
import resource
|
21 |
+
import sys
|
22 |
+
import time
|
23 |
+
import traceback
|
24 |
+
import warnings
|
25 |
+
from argparse import ArgumentParser
|
26 |
+
from dataclasses import dataclass, field
|
27 |
+
from datetime import datetime
|
28 |
+
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
29 |
+
|
30 |
+
import aiohttp
|
31 |
+
import numpy as np
|
32 |
+
import requests
|
33 |
+
from launch_server import LORA_PATH, NUM_LORAS
|
34 |
+
from tqdm.asyncio import tqdm
|
35 |
+
from transformers import (
|
36 |
+
AutoTokenizer,
|
37 |
+
PreTrainedTokenizer,
|
38 |
+
PreTrainedTokenizerBase,
|
39 |
+
PreTrainedTokenizerFast,
|
40 |
+
)
|
41 |
+
|
42 |
+
from sglang.bench_serving import (
|
43 |
+
AIOHTTP_TIMEOUT,
|
44 |
+
SHAREGPT_URL,
|
45 |
+
BenchmarkMetrics,
|
46 |
+
RequestFuncInput,
|
47 |
+
RequestFuncOutput,
|
48 |
+
calculate_metrics,
|
49 |
+
check_chat_template,
|
50 |
+
get_model,
|
51 |
+
get_request,
|
52 |
+
get_tokenizer,
|
53 |
+
parse_request_rate_range,
|
54 |
+
remove_prefix,
|
55 |
+
sample_random_requests,
|
56 |
+
)
|
57 |
+
|
58 |
+
global args
|
59 |
+
|
60 |
+
|
61 |
+
# set ignore_eos True by default
|
62 |
+
async def async_request_openai_completions(
|
63 |
+
request_func_input: RequestFuncInput,
|
64 |
+
pbar: Optional[tqdm] = None,
|
65 |
+
) -> RequestFuncOutput:
|
66 |
+
api_url = request_func_input.api_url
|
67 |
+
# assert api_url.endswith(
|
68 |
+
# "completions"
|
69 |
+
# ), "OpenAI Completions API URL must end with 'completions'."
|
70 |
+
|
71 |
+
prompt = request_func_input.prompt
|
72 |
+
|
73 |
+
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
74 |
+
# payload = {
|
75 |
+
# "model": request_func_input.model,
|
76 |
+
# "prompt": prompt,
|
77 |
+
# "temperature": 0.0,
|
78 |
+
# "best_of": 1,
|
79 |
+
# "max_tokens": request_func_input.output_len,
|
80 |
+
# "stream": not args.disable_stream,
|
81 |
+
# "ignore_eos": not args.disable_ignore_eos,
|
82 |
+
# **request_func_input.extra_request_body,
|
83 |
+
# }
|
84 |
+
# headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
85 |
+
if args.base_only:
|
86 |
+
payload = {
|
87 |
+
"text": prompt,
|
88 |
+
"sampling_params": {"max_new_tokens": request_func_input.output_len},
|
89 |
+
}
|
90 |
+
else:
|
91 |
+
payload = {
|
92 |
+
"text": prompt,
|
93 |
+
"sampling_params": {"max_new_tokens": request_func_input.output_len},
|
94 |
+
"lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
|
95 |
+
}
|
96 |
+
headers = {"Authorization": ""}
|
97 |
+
|
98 |
+
output = RequestFuncOutput()
|
99 |
+
output.prompt_len = request_func_input.prompt_len
|
100 |
+
|
101 |
+
generated_text = ""
|
102 |
+
ttft = 0.0
|
103 |
+
st = time.perf_counter()
|
104 |
+
most_recent_timestamp = st
|
105 |
+
try:
|
106 |
+
async with session.post(
|
107 |
+
url=api_url, json=payload, headers=headers
|
108 |
+
) as response:
|
109 |
+
if response.status == 200:
|
110 |
+
async for chunk_bytes in response.content:
|
111 |
+
chunk_bytes = chunk_bytes.strip()
|
112 |
+
if not chunk_bytes:
|
113 |
+
continue
|
114 |
+
|
115 |
+
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
116 |
+
latency = time.perf_counter() - st
|
117 |
+
if chunk == "[DONE]":
|
118 |
+
pass
|
119 |
+
else:
|
120 |
+
data = json.loads(chunk)
|
121 |
+
|
122 |
+
# NOTE: Some completion API might have a last
|
123 |
+
# usage summary response without a token so we
|
124 |
+
# want to check a token was generated
|
125 |
+
if data["text"]:
|
126 |
+
# if data["choices"][0]["text"]:
|
127 |
+
timestamp = time.perf_counter()
|
128 |
+
# First token
|
129 |
+
if ttft == 0.0:
|
130 |
+
ttft = time.perf_counter() - st
|
131 |
+
output.ttft = ttft
|
132 |
+
|
133 |
+
# Decoding phase
|
134 |
+
else:
|
135 |
+
output.itl.append(timestamp - most_recent_timestamp)
|
136 |
+
|
137 |
+
most_recent_timestamp = timestamp
|
138 |
+
# generated_text += data["choices"][0]["text"]
|
139 |
+
generated_text += data["text"]
|
140 |
+
|
141 |
+
output.generated_text = generated_text
|
142 |
+
output.success = True
|
143 |
+
output.latency = latency
|
144 |
+
output.output_len = request_func_input.output_len
|
145 |
+
else:
|
146 |
+
output.error = response.reason or ""
|
147 |
+
output.success = False
|
148 |
+
except Exception:
|
149 |
+
output.success = False
|
150 |
+
exc_info = sys.exc_info()
|
151 |
+
output.error = "".join(traceback.format_exception(*exc_info))
|
152 |
+
|
153 |
+
if pbar:
|
154 |
+
pbar.update(1)
|
155 |
+
return output
|
156 |
+
|
157 |
+
|
158 |
+
ASYNC_REQUEST_FUNCS = {
|
159 |
+
"sglang": async_request_openai_completions,
|
160 |
+
}
|
161 |
+
|
162 |
+
|
163 |
+
async def benchmark(
|
164 |
+
backend: str,
|
165 |
+
api_url: str,
|
166 |
+
model_id: str,
|
167 |
+
tokenizer: PreTrainedTokenizerBase,
|
168 |
+
input_requests: List[Tuple[str, int, int]],
|
169 |
+
request_rate: float,
|
170 |
+
disable_tqdm: bool,
|
171 |
+
extra_request_body: Dict[str, Any],
|
172 |
+
):
|
173 |
+
if backend in ASYNC_REQUEST_FUNCS:
|
174 |
+
request_func = ASYNC_REQUEST_FUNCS[backend]
|
175 |
+
else:
|
176 |
+
raise ValueError(f"Unknown backend: {backend}")
|
177 |
+
|
178 |
+
print("Starting initial single prompt test run...")
|
179 |
+
test_prompt, test_prompt_len, test_output_len = input_requests[0]
|
180 |
+
test_input = RequestFuncInput(
|
181 |
+
model=model_id,
|
182 |
+
prompt=test_prompt,
|
183 |
+
api_url=api_url,
|
184 |
+
prompt_len=test_prompt_len,
|
185 |
+
output_len=test_output_len,
|
186 |
+
extra_request_body=extra_request_body,
|
187 |
+
)
|
188 |
+
test_output = await request_func(request_func_input=test_input)
|
189 |
+
if not test_output.success:
|
190 |
+
raise ValueError(
|
191 |
+
"Initial test run failed - Please make sure benchmark arguments "
|
192 |
+
f"are correctly specified. Error: {test_output.error}"
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
print("Initial test run completed. Starting main benchmark run...")
|
196 |
+
|
197 |
+
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
198 |
+
|
199 |
+
benchmark_start_time = time.perf_counter()
|
200 |
+
tasks: List[asyncio.Task] = []
|
201 |
+
async for request in get_request(input_requests, request_rate):
|
202 |
+
prompt, prompt_len, output_len = request
|
203 |
+
request_func_input = RequestFuncInput(
|
204 |
+
model=model_id,
|
205 |
+
prompt=prompt,
|
206 |
+
api_url=api_url,
|
207 |
+
prompt_len=prompt_len,
|
208 |
+
output_len=output_len,
|
209 |
+
extra_request_body=extra_request_body,
|
210 |
+
)
|
211 |
+
tasks.append(
|
212 |
+
asyncio.create_task(
|
213 |
+
request_func(request_func_input=request_func_input, pbar=pbar)
|
214 |
+
)
|
215 |
+
)
|
216 |
+
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
217 |
+
|
218 |
+
if pbar is not None:
|
219 |
+
pbar.close()
|
220 |
+
|
221 |
+
benchmark_duration = time.perf_counter() - benchmark_start_time
|
222 |
+
|
223 |
+
metrics, output_lens = calculate_metrics(
|
224 |
+
input_requests=input_requests,
|
225 |
+
outputs=outputs,
|
226 |
+
dur_s=benchmark_duration,
|
227 |
+
tokenizer=tokenizer,
|
228 |
+
backend=backend,
|
229 |
+
)
|
230 |
+
|
231 |
+
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
232 |
+
print("{:<40} {:<10}".format("Backend:", backend))
|
233 |
+
print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
|
234 |
+
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
235 |
+
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
236 |
+
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
237 |
+
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
238 |
+
print(
|
239 |
+
"{:<40} {:<10}".format(
|
240 |
+
"Total generated tokens (retokenized):", metrics.total_output_retokenized
|
241 |
+
)
|
242 |
+
)
|
243 |
+
print(
|
244 |
+
"{:<40} {:<10.2f}".format(
|
245 |
+
"Request throughput (req/s):", metrics.request_throughput
|
246 |
+
)
|
247 |
+
)
|
248 |
+
print(
|
249 |
+
"{:<40} {:<10.2f}".format(
|
250 |
+
"Input token throughput (tok/s):", metrics.input_throughput
|
251 |
+
)
|
252 |
+
)
|
253 |
+
print(
|
254 |
+
"{:<40} {:<10.2f}".format(
|
255 |
+
"Output token throughput (tok/s):", metrics.output_throughput
|
256 |
+
)
|
257 |
+
)
|
258 |
+
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
259 |
+
print(
|
260 |
+
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
261 |
+
)
|
262 |
+
print(
|
263 |
+
"{:<40} {:<10.2f}".format(
|
264 |
+
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
|
265 |
+
)
|
266 |
+
)
|
267 |
+
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
|
268 |
+
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
269 |
+
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
270 |
+
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
271 |
+
print(
|
272 |
+
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
273 |
+
)
|
274 |
+
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
275 |
+
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
276 |
+
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
277 |
+
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
|
278 |
+
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
279 |
+
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
280 |
+
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
281 |
+
print("=" * 50)
|
282 |
+
|
283 |
+
if (
|
284 |
+
metrics.median_ttft_ms is not None
|
285 |
+
and metrics.mean_itl_ms is not None
|
286 |
+
and metrics.output_throughput is not None
|
287 |
+
):
|
288 |
+
result = {
|
289 |
+
"backend": args.backend,
|
290 |
+
"request_rate": request_rate,
|
291 |
+
"total_input_tokens": metrics.total_input,
|
292 |
+
"total_output_tokens": metrics.total_output,
|
293 |
+
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
294 |
+
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
295 |
+
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
296 |
+
"median_ttft_ms": metrics.median_ttft_ms,
|
297 |
+
"median_itl_ms": metrics.median_itl_ms,
|
298 |
+
"output_throughput": metrics.output_throughput,
|
299 |
+
"random_input_len": args.random_input_len,
|
300 |
+
"random_output_len": args.random_output_len,
|
301 |
+
"random_range_ratio": args.random_range_ratio,
|
302 |
+
"duration": benchmark_duration,
|
303 |
+
"completed": metrics.completed,
|
304 |
+
}
|
305 |
+
else:
|
306 |
+
print(f"Error running benchmark for request rate: {request_rate}")
|
307 |
+
print("-" * 30)
|
308 |
+
|
309 |
+
# Determine output file name
|
310 |
+
if args.output_file:
|
311 |
+
output_file_name = args.output_file
|
312 |
+
else:
|
313 |
+
now = datetime.now().strftime("%m%d")
|
314 |
+
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
315 |
+
|
316 |
+
# Append results to a JSONL file
|
317 |
+
with open(output_file_name, "a") as file:
|
318 |
+
file.write(json.dumps(result) + "\n")
|
319 |
+
|
320 |
+
result = {
|
321 |
+
"duration": benchmark_duration,
|
322 |
+
"completed": metrics.completed,
|
323 |
+
"total_input_tokens": metrics.total_input,
|
324 |
+
"total_output_tokens": metrics.total_output,
|
325 |
+
"total_output_tokens_retokenized": metrics.total_output_retokenized,
|
326 |
+
"request_throughput": metrics.request_throughput,
|
327 |
+
"input_throughput": metrics.input_throughput,
|
328 |
+
"output_throughput": metrics.output_throughput,
|
329 |
+
"mean_ttft_ms": metrics.mean_ttft_ms,
|
330 |
+
"median_ttft_ms": metrics.median_ttft_ms,
|
331 |
+
"std_ttft_ms": metrics.std_ttft_ms,
|
332 |
+
"p99_ttft_ms": metrics.p99_ttft_ms,
|
333 |
+
"mean_tpot_ms": metrics.mean_tpot_ms,
|
334 |
+
"median_tpot_ms": metrics.median_tpot_ms,
|
335 |
+
"std_tpot_ms": metrics.std_tpot_ms,
|
336 |
+
"p99_tpot_ms": metrics.p99_tpot_ms,
|
337 |
+
"mean_itl_ms": metrics.mean_itl_ms,
|
338 |
+
"median_itl_ms": metrics.median_itl_ms,
|
339 |
+
"std_itl_ms": metrics.std_itl_ms,
|
340 |
+
"p99_itl_ms": metrics.p99_itl_ms,
|
341 |
+
"input_lens": [output.prompt_len for output in outputs],
|
342 |
+
"output_lens": output_lens,
|
343 |
+
"ttfts": [output.ttft for output in outputs],
|
344 |
+
"itls": [output.itl for output in outputs],
|
345 |
+
"generated_texts": [output.generated_text for output in outputs],
|
346 |
+
"errors": [output.error for output in outputs],
|
347 |
+
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
|
348 |
+
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
|
349 |
+
}
|
350 |
+
return result
|
351 |
+
|
352 |
+
|
353 |
+
def run_benchmark(args_: argparse.Namespace):
|
354 |
+
global args
|
355 |
+
args = args_
|
356 |
+
|
357 |
+
# Set global environments
|
358 |
+
set_ulimit()
|
359 |
+
random.seed(args.seed)
|
360 |
+
np.random.seed(args.seed)
|
361 |
+
|
362 |
+
# Set url
|
363 |
+
if args.port is None:
|
364 |
+
args.port = {
|
365 |
+
"sglang": 30000,
|
366 |
+
}.get(args.backend, 30000)
|
367 |
+
|
368 |
+
# api_url = (
|
369 |
+
# f"{args.base_url}/v1/completions"
|
370 |
+
# if args.base_url
|
371 |
+
# else f"http://{args.host}:{args.port}/v1/completions"
|
372 |
+
# )
|
373 |
+
api_url = (
|
374 |
+
f"{args.base_url}/generate"
|
375 |
+
if args.base_url
|
376 |
+
else f"http://{args.host}:{args.port}/generate"
|
377 |
+
)
|
378 |
+
|
379 |
+
print(f"{args}\n")
|
380 |
+
|
381 |
+
# Read dataset
|
382 |
+
backend = args.backend
|
383 |
+
model_id = args.model = LORA_PATH["base"]
|
384 |
+
tokenizer_id = args.model
|
385 |
+
|
386 |
+
tokenizer = get_tokenizer(tokenizer_id)
|
387 |
+
|
388 |
+
input_requests = sample_random_requests(
|
389 |
+
input_len=args.random_input_len,
|
390 |
+
output_len=args.random_output_len,
|
391 |
+
num_prompts=args.num_prompts,
|
392 |
+
range_ratio=args.random_range_ratio,
|
393 |
+
tokenizer=tokenizer,
|
394 |
+
dataset_path="",
|
395 |
+
)
|
396 |
+
|
397 |
+
return asyncio.run(
|
398 |
+
benchmark(
|
399 |
+
backend=backend,
|
400 |
+
api_url=api_url,
|
401 |
+
model_id=model_id,
|
402 |
+
tokenizer=tokenizer,
|
403 |
+
input_requests=input_requests,
|
404 |
+
request_rate=args.request_rate,
|
405 |
+
disable_tqdm=False,
|
406 |
+
extra_request_body={},
|
407 |
+
)
|
408 |
+
)
|
409 |
+
|
410 |
+
|
411 |
+
def set_ulimit(target_soft_limit=65535):
|
412 |
+
resource_type = resource.RLIMIT_NOFILE
|
413 |
+
current_soft, current_hard = resource.getrlimit(resource_type)
|
414 |
+
|
415 |
+
if current_soft < target_soft_limit:
|
416 |
+
try:
|
417 |
+
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
|
418 |
+
except ValueError as e:
|
419 |
+
print(f"Fail to set RLIMIT_NOFILE: {e}")
|
420 |
+
|
421 |
+
|
422 |
+
if __name__ == "__main__":
|
423 |
+
parser = ArgumentParser(description="Benchmark the online lora serving throughput.")
|
424 |
+
parser.add_argument(
|
425 |
+
"--backend",
|
426 |
+
type=str,
|
427 |
+
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
428 |
+
default="sglang",
|
429 |
+
help="Must specify a backend, depending on the LLM Inference Engine.",
|
430 |
+
)
|
431 |
+
parser.add_argument(
|
432 |
+
"--base-url",
|
433 |
+
type=str,
|
434 |
+
default=None,
|
435 |
+
help="Server or API base url if not using http host and port.",
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
|
439 |
+
)
|
440 |
+
parser.add_argument(
|
441 |
+
"--port",
|
442 |
+
type=int,
|
443 |
+
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
|
444 |
+
)
|
445 |
+
parser.add_argument(
|
446 |
+
"--num-prompts",
|
447 |
+
type=int,
|
448 |
+
default=50,
|
449 |
+
help="Number of prompts to process. Default is 1000.",
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
"--random-input-len",
|
453 |
+
type=int,
|
454 |
+
default=1024,
|
455 |
+
help="Number of input tokens per request, used only for random dataset.",
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--random-output-len",
|
459 |
+
type=int,
|
460 |
+
default=128,
|
461 |
+
help="Number of output tokens per request, used only for random dataset.",
|
462 |
+
)
|
463 |
+
parser.add_argument(
|
464 |
+
"--random-range-ratio",
|
465 |
+
type=float,
|
466 |
+
default=0.0,
|
467 |
+
help="Range of sampled ratio of input/output length, "
|
468 |
+
"used only for random dataset.",
|
469 |
+
)
|
470 |
+
parser.add_argument(
|
471 |
+
"--request-rate",
|
472 |
+
type=float,
|
473 |
+
default=float("inf"),
|
474 |
+
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
|
475 |
+
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
|
476 |
+
)
|
477 |
+
parser.add_argument(
|
478 |
+
"--base-only",
|
479 |
+
action="store_true",
|
480 |
+
)
|
481 |
+
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
482 |
+
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
483 |
+
args = parser.parse_args()
|
484 |
+
run_benchmark(args)
|
sglang/benchmark/mmlu/README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download data
|
2 |
+
```
|
3 |
+
bash download_data.sh
|
4 |
+
```
|
5 |
+
|
6 |
+
## Run benchmark
|
7 |
+
|
8 |
+
### Benchmark sglang
|
9 |
+
```
|
10 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
11 |
+
```
|
12 |
+
|
13 |
+
```
|
14 |
+
python3 bench_sglang.py --nsub 10
|
15 |
+
```
|
16 |
+
|
17 |
+
```
|
18 |
+
# OpenAI models
|
19 |
+
python3 bench_sglang.py --backend gpt-3.5-turbo --parallel 8
|
20 |
+
```
|
21 |
+
|
22 |
+
### Benchmark vllm
|
23 |
+
```
|
24 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
25 |
+
```
|
26 |
+
|
27 |
+
```
|
28 |
+
python3 bench_other.py --nsub 10 --backend vllm
|
29 |
+
```
|
30 |
+
|
31 |
+
|
32 |
+
### Benchmark lightllm
|
33 |
+
```
|
34 |
+
# A10G
|
35 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
36 |
+
|
37 |
+
# V100
|
38 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 4500 --port 22000
|
39 |
+
```
|
40 |
+
|
41 |
+
```
|
42 |
+
python3 bench_other.py --nsub 10 --backend lightllm
|
43 |
+
```
|
44 |
+
|
45 |
+
|
46 |
+
### Benchmark guidance
|
47 |
+
```
|
48 |
+
python3 bench_other.py --nsub 10 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
49 |
+
```
|
50 |
+
|
51 |
+
|
52 |
+
### Benchmark lmql
|
53 |
+
```
|
54 |
+
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
|
55 |
+
```
|
56 |
+
|
57 |
+
```
|
58 |
+
python3 bench_other.py --nsub 10 --backend lmql --parallel 2
|
59 |
+
```
|
sglang/benchmark/mtbench/README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download Dataset
|
2 |
+
|
3 |
+
```sh
|
4 |
+
wget -O question.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl
|
5 |
+
```
|
6 |
+
|
7 |
+
## Run benchmark
|
8 |
+
|
9 |
+
### Benchmark sglang
|
10 |
+
```
|
11 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
12 |
+
```
|
13 |
+
|
14 |
+
```
|
15 |
+
python3 bench_sglang.py --num-questions 80
|
16 |
+
```
|
17 |
+
|
18 |
+
|
19 |
+
### Benchmark vllm
|
20 |
+
```
|
21 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
22 |
+
```
|
23 |
+
|
24 |
+
```
|
25 |
+
python3 bench_other.py --num-questions 80 --backend vllm
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
### Benchmark lightllm
|
30 |
+
```
|
31 |
+
# A10G
|
32 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
33 |
+
```
|
34 |
+
|
35 |
+
```
|
36 |
+
python3 bench_other.py --num-questions 80 --backend lightllm
|
37 |
+
```
|
sglang/benchmark/mtbench/bench_other.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import uuid
|
6 |
+
from concurrent.futures import ThreadPoolExecutor
|
7 |
+
|
8 |
+
from fastchat.model import get_conversation_template
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
12 |
+
|
13 |
+
|
14 |
+
def load_questions(filename):
|
15 |
+
questions = []
|
16 |
+
with open(filename, "r") as fin:
|
17 |
+
for line in fin:
|
18 |
+
obj = json.loads(line)
|
19 |
+
questions.append(obj)
|
20 |
+
return questions
|
21 |
+
|
22 |
+
|
23 |
+
def write_answers(filename, model_id, questions, answers):
|
24 |
+
with open(os.path.expanduser(filename), "w") as fout:
|
25 |
+
for i in range(len(answers)):
|
26 |
+
ans_json = {
|
27 |
+
"question_id": questions[i]["question_id"],
|
28 |
+
"answer_id": uuid.uuid4().hex,
|
29 |
+
"model_id": model_id,
|
30 |
+
"choices": {
|
31 |
+
"index": 0,
|
32 |
+
"turns": [answers[i][0], answers[i][1]],
|
33 |
+
},
|
34 |
+
"tstamp": time.time(),
|
35 |
+
}
|
36 |
+
fout.write(json.dumps(ans_json) + "\n")
|
37 |
+
|
38 |
+
|
39 |
+
def main(args):
|
40 |
+
questions = load_questions(args.question_file)
|
41 |
+
questions = (questions * 10)[: args.num_questions]
|
42 |
+
max_tokens = 256
|
43 |
+
model_id = "llama-2-chat"
|
44 |
+
|
45 |
+
conv_main = get_conversation_template(model_id)
|
46 |
+
|
47 |
+
# Select backend
|
48 |
+
call_generate = get_call_generate(args)
|
49 |
+
|
50 |
+
answers = [None] * len(questions)
|
51 |
+
|
52 |
+
def get_answer(i):
|
53 |
+
conv = conv_main.copy()
|
54 |
+
cur_answers = []
|
55 |
+
for j in range(2):
|
56 |
+
q = questions[i]["turns"][j]
|
57 |
+
conv.append_message(conv.roles[0], q)
|
58 |
+
conv.append_message(conv.roles[1], None)
|
59 |
+
|
60 |
+
prompt = conv.get_prompt()
|
61 |
+
output = call_generate(prompt, temperature=0, max_tokens=max_tokens).strip()
|
62 |
+
|
63 |
+
cur_answers.append(output)
|
64 |
+
conv.update_last_message(output)
|
65 |
+
|
66 |
+
answers[i] = cur_answers
|
67 |
+
|
68 |
+
# Run requests
|
69 |
+
tic = time.time()
|
70 |
+
if args.parallel == 1:
|
71 |
+
for i in tqdm(range(len(questions))):
|
72 |
+
get_answer(i)
|
73 |
+
else:
|
74 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
75 |
+
list(
|
76 |
+
tqdm(
|
77 |
+
executor.map(get_answer, list(range(len(questions)))),
|
78 |
+
total=len(questions),
|
79 |
+
)
|
80 |
+
)
|
81 |
+
|
82 |
+
latency = time.time() - tic
|
83 |
+
|
84 |
+
print(f"#questions: {len(questions)}, Latency: {latency:.2f}")
|
85 |
+
|
86 |
+
# Write results
|
87 |
+
answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
|
88 |
+
write_answers(answer_file, model_id, questions, answers)
|
89 |
+
|
90 |
+
with open(args.result_file, "a") as fout:
|
91 |
+
value = {
|
92 |
+
"task": "mtbench",
|
93 |
+
"backend": args.backend,
|
94 |
+
"num_gpus": 1,
|
95 |
+
"latency": round(latency, 3),
|
96 |
+
"num_requests": args.num_questions,
|
97 |
+
"other": {
|
98 |
+
"num_questions": args.num_questions,
|
99 |
+
"parallel": args.parallel,
|
100 |
+
},
|
101 |
+
}
|
102 |
+
fout.write(json.dumps(value) + "\n")
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == "__main__":
|
106 |
+
parser = argparse.ArgumentParser()
|
107 |
+
parser.add_argument("--question-file", type=str, default="question.jsonl")
|
108 |
+
parser.add_argument("--answer-file", type=str, default=None)
|
109 |
+
parser.add_argument("--num-questions", type=int, default=80)
|
110 |
+
args = add_common_other_args_and_parse(parser)
|
111 |
+
main(args)
|
sglang/benchmark/mtbench/bench_sglang.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import uuid
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
from sglang.test.test_utils import (
|
9 |
+
add_common_sglang_args_and_parse,
|
10 |
+
select_sglang_backend,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
def load_questions(filename):
|
15 |
+
questions = []
|
16 |
+
with open(filename, "r") as fin:
|
17 |
+
for line in fin:
|
18 |
+
obj = json.loads(line)
|
19 |
+
questions.append(obj)
|
20 |
+
return questions
|
21 |
+
|
22 |
+
|
23 |
+
def write_answers(filename, model_id, questions, answers):
|
24 |
+
with open(os.path.expanduser(filename), "w") as fout:
|
25 |
+
for i in range(len(answers)):
|
26 |
+
ans_json = {
|
27 |
+
"question_id": questions[i]["question_id"],
|
28 |
+
"answer_id": uuid.uuid4().hex,
|
29 |
+
"model_id": model_id,
|
30 |
+
"choices": {
|
31 |
+
"index": 0,
|
32 |
+
"turns": [answers[i][0], answers[i][1]],
|
33 |
+
},
|
34 |
+
"tstamp": time.time(),
|
35 |
+
}
|
36 |
+
fout.write(json.dumps(ans_json) + "\n")
|
37 |
+
|
38 |
+
|
39 |
+
@sgl.function
|
40 |
+
def answer_mt_bench(s, question_1, question_2):
|
41 |
+
s += sgl.system()
|
42 |
+
s += sgl.user(question_1)
|
43 |
+
s += sgl.assistant(sgl.gen("answer_1"))
|
44 |
+
s += sgl.user(question_2)
|
45 |
+
s += sgl.assistant(sgl.gen("answer_2"))
|
46 |
+
|
47 |
+
|
48 |
+
def main(args):
|
49 |
+
# Construct prompts
|
50 |
+
questions = load_questions(args.question_file)[: args.num_questions]
|
51 |
+
arguments = [
|
52 |
+
{"question_1": q["turns"][0], "question_2": q["turns"][1]} for q in questions
|
53 |
+
]
|
54 |
+
|
55 |
+
# Select backend
|
56 |
+
backend = select_sglang_backend(args)
|
57 |
+
sgl.set_default_backend(backend)
|
58 |
+
|
59 |
+
# Run requests
|
60 |
+
tic = time.time()
|
61 |
+
rets = answer_mt_bench.run_batch(
|
62 |
+
arguments,
|
63 |
+
temperature=0,
|
64 |
+
max_new_tokens=256,
|
65 |
+
num_threads=args.parallel,
|
66 |
+
progress_bar=True,
|
67 |
+
)
|
68 |
+
answers = [[s["answer_1"], s["answer_2"]] for s in rets]
|
69 |
+
latency = time.time() - tic
|
70 |
+
|
71 |
+
print(f"#questions: {len(questions)}, Latency: {latency:.2f}")
|
72 |
+
|
73 |
+
# Write results
|
74 |
+
model_id = backend.model_info["model_path"]
|
75 |
+
answer_file = args.answer_file or f"tmp_output_{args.backend}.txt"
|
76 |
+
write_answers(answer_file, model_id, questions, answers)
|
77 |
+
|
78 |
+
with open(args.result_file, "a") as fout:
|
79 |
+
value = {
|
80 |
+
"task": "mtbench",
|
81 |
+
"backend": args.backend,
|
82 |
+
"num_gpus": 1,
|
83 |
+
"latency": round(latency, 3),
|
84 |
+
"num_requests": args.num_questions,
|
85 |
+
"other": {
|
86 |
+
"num_questions": args.num_questions,
|
87 |
+
"parallel": args.parallel,
|
88 |
+
},
|
89 |
+
}
|
90 |
+
fout.write(json.dumps(value) + "\n")
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
parser = argparse.ArgumentParser()
|
95 |
+
parser.add_argument("--question-file", type=str, default="question.jsonl")
|
96 |
+
parser.add_argument("--answer-file", type=str, default=None)
|
97 |
+
parser.add_argument("--num-questions", type=int, default=80)
|
98 |
+
args = add_common_sglang_args_and_parse(parser)
|
99 |
+
main(args)
|
sglang/benchmark/multi_chain_reasoning/README.md
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Download data
|
2 |
+
```
|
3 |
+
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
4 |
+
```
|
5 |
+
|
6 |
+
## Run benchmark
|
7 |
+
|
8 |
+
### Benchmark sglang
|
9 |
+
```
|
10 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --schedule-conservativeness 1.3
|
11 |
+
```
|
12 |
+
|
13 |
+
```
|
14 |
+
python3 bench_sglang.py --num-questions 64
|
15 |
+
python3 bench_sglang.py --num-questions 32 --parallel 1
|
16 |
+
```
|
17 |
+
|
18 |
+
|
19 |
+
### Benchmark vllm
|
20 |
+
```
|
21 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
22 |
+
```
|
23 |
+
|
24 |
+
```
|
25 |
+
python3 bench_other.py --num-questions 64 --backend vllm
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
### Benchmark lightllm
|
30 |
+
```
|
31 |
+
# A10G
|
32 |
+
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
|
33 |
+
```
|
34 |
+
|
35 |
+
```
|
36 |
+
python3 bench_other.py --num-questions 64 --backend lightllm
|
37 |
+
```
|
38 |
+
|
39 |
+
|
40 |
+
### Benchmark guidance
|
41 |
+
```
|
42 |
+
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
43 |
+
```
|
44 |
+
|
45 |
+
### Benchmark lmql
|
46 |
+
|
47 |
+
```
|
48 |
+
python3 bench_other.py --num-questions 64 --backend lmql --parallel 1
|
49 |
+
```
|
sglang/benchmark/multi_chain_reasoning/bench_sglang.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import ast
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from sglang.test.test_utils import (
|
10 |
+
add_common_sglang_args_and_parse,
|
11 |
+
select_sglang_backend,
|
12 |
+
)
|
13 |
+
from sglang.utils import dump_state_text, read_jsonl
|
14 |
+
|
15 |
+
INVALID = -9999999
|
16 |
+
|
17 |
+
|
18 |
+
def get_answer_value(answer_str):
|
19 |
+
answer_str = answer_str.replace(",", "")
|
20 |
+
numbers = re.findall(r"\d+", answer_str)
|
21 |
+
if len(numbers) < 1:
|
22 |
+
return INVALID
|
23 |
+
try:
|
24 |
+
return ast.literal_eval(numbers[-1])
|
25 |
+
except SyntaxError:
|
26 |
+
return INVALID
|
27 |
+
|
28 |
+
|
29 |
+
prompt_lib = [
|
30 |
+
"Let us think step by step.",
|
31 |
+
"Approach this methodically. Let's dissect the problem into smaller, more manageable parts.",
|
32 |
+
"It's important to proceed step by step, ensuring accuracy at each stage.",
|
33 |
+
"Take a deep breath and break this down.",
|
34 |
+
"A little bit of arithmetic and a logical approach will help us quickly arrive at the solution to this problem.",
|
35 |
+
"I am extremely good at math.",
|
36 |
+
]
|
37 |
+
|
38 |
+
|
39 |
+
def main(args):
|
40 |
+
lines = read_jsonl(args.data_path)
|
41 |
+
|
42 |
+
# Construct prompts
|
43 |
+
# k = args.num_shot
|
44 |
+
# few_shot_examples = get_few_shot_examples(lines, k)
|
45 |
+
|
46 |
+
questions = []
|
47 |
+
labels = []
|
48 |
+
for i in range(len(lines[: args.num_questions])):
|
49 |
+
questions.append(lines[i]["question"])
|
50 |
+
labels.append(get_answer_value(lines[i]["answer"]))
|
51 |
+
assert all(l != INVALID for l in labels)
|
52 |
+
arguments = [{"question": q} for q in questions]
|
53 |
+
|
54 |
+
num_chains = args.num_chains
|
55 |
+
|
56 |
+
#####################################
|
57 |
+
######### SGL Program Begin #########
|
58 |
+
#####################################
|
59 |
+
|
60 |
+
import sglang as sgl
|
61 |
+
|
62 |
+
@sgl.function
|
63 |
+
def multi_chain_gsm8k(s, question):
|
64 |
+
s += "Question: " + question + "\n"
|
65 |
+
# s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
|
66 |
+
# temperature=0)
|
67 |
+
# return
|
68 |
+
|
69 |
+
forks = s.fork(num_chains)
|
70 |
+
for i in range(num_chains):
|
71 |
+
forks[i] += (
|
72 |
+
"Answer: "
|
73 |
+
+ prompt_lib[i % num_chains]
|
74 |
+
+ sgl.gen("chain", max_tokens=256, temperature=0.3, stop="Question")
|
75 |
+
)
|
76 |
+
forks.join()
|
77 |
+
|
78 |
+
s += "Answer: To answer this question, here are some possible solutions. "
|
79 |
+
s += "After considering all of them, I will do a majority vote.\n\n"
|
80 |
+
for i in range(num_chains):
|
81 |
+
s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
|
82 |
+
s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
83 |
+
s += sgl.gen("answer", max_tokens=16)
|
84 |
+
|
85 |
+
#####################################
|
86 |
+
########## SGL Program End ##########
|
87 |
+
#####################################
|
88 |
+
|
89 |
+
# Select backend
|
90 |
+
backend = select_sglang_backend(args)
|
91 |
+
|
92 |
+
# Run requests
|
93 |
+
tic = time.time()
|
94 |
+
states = multi_chain_gsm8k.run_batch(
|
95 |
+
arguments,
|
96 |
+
temperature=0,
|
97 |
+
backend=backend,
|
98 |
+
num_threads=args.parallel,
|
99 |
+
progress_bar=True,
|
100 |
+
)
|
101 |
+
latency = time.time() - tic
|
102 |
+
|
103 |
+
preds = []
|
104 |
+
for i in range(len(states)):
|
105 |
+
preds.append(get_answer_value(states[i]["answer"]))
|
106 |
+
|
107 |
+
# Compute accuracy
|
108 |
+
acc = np.mean(np.array(preds) == np.array(labels))
|
109 |
+
invalid = np.mean(np.array(preds) == INVALID)
|
110 |
+
print(f"Latency: {latency:.3f}")
|
111 |
+
print(f"Invalid: {invalid:.3f}")
|
112 |
+
print(f"Accuracy: {acc:.3f}")
|
113 |
+
|
114 |
+
# Write results
|
115 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
116 |
+
|
117 |
+
with open(args.result_file, "a") as fout:
|
118 |
+
value = {
|
119 |
+
"task": "multi_chain_gsm8k",
|
120 |
+
"backend": args.backend,
|
121 |
+
"num_gpus": 1,
|
122 |
+
"latency": round(latency, 3),
|
123 |
+
"accuracy": round(acc, 3),
|
124 |
+
"num_requests": args.num_questions,
|
125 |
+
"other": {
|
126 |
+
"num_questions": args.num_questions,
|
127 |
+
"parallel": args.parallel,
|
128 |
+
},
|
129 |
+
}
|
130 |
+
fout.write(json.dumps(value) + "\n")
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == "__main__":
|
134 |
+
parser = argparse.ArgumentParser()
|
135 |
+
parser.add_argument("--num-shot", type=int, default=0)
|
136 |
+
parser.add_argument("--num-chains", type=int, default=5)
|
137 |
+
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
138 |
+
parser.add_argument("--num-questions", type=int, default=50)
|
139 |
+
args = add_common_sglang_args_and_parse(parser)
|
140 |
+
main(args)
|
sglang/benchmark/multi_turn_chat/README.md
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Benchmark sglang
|
2 |
+
|
3 |
+
Run Llama-7B
|
4 |
+
|
5 |
+
```
|
6 |
+
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
7 |
+
```
|
8 |
+
|
9 |
+
Run Mixtral-8x7B
|
10 |
+
(When there is a CUDA out-of-memory error, try to reduce the `--mem-fraction-static`)
|
11 |
+
|
12 |
+
```
|
13 |
+
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
|
14 |
+
```
|
15 |
+
|
16 |
+
Benchmark(short output)
|
17 |
+
|
18 |
+
```
|
19 |
+
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
|
20 |
+
```
|
21 |
+
|
22 |
+
Benchmark(long output)
|
23 |
+
|
24 |
+
```
|
25 |
+
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
|
26 |
+
```
|
27 |
+
|
28 |
+
### Benchmark vLLM
|
29 |
+
|
30 |
+
Run Llama-7B
|
31 |
+
|
32 |
+
```
|
33 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
34 |
+
```
|
35 |
+
|
36 |
+
Run Mixtral-8x7B
|
37 |
+
|
38 |
+
```
|
39 |
+
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
|
40 |
+
```
|
41 |
+
|
42 |
+
Benchmark(short output)
|
43 |
+
|
44 |
+
```
|
45 |
+
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
|
46 |
+
```
|
47 |
+
|
48 |
+
Benchmark(long output)
|
49 |
+
|
50 |
+
```
|
51 |
+
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
|
52 |
+
```
|
53 |
+
|
54 |
+
### Benchmark guidance
|
55 |
+
|
56 |
+
Benchmark Llama-7B (short output)
|
57 |
+
|
58 |
+
```
|
59 |
+
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
|
60 |
+
```
|
61 |
+
|
62 |
+
Benchmark Llama-7B (long output)
|
63 |
+
|
64 |
+
```
|
65 |
+
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long
|
66 |
+
```
|
sglang/benchmark/multi_turn_chat/bench_other.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
from concurrent.futures import ThreadPoolExecutor
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from data_gen import gen_arguments
|
8 |
+
from tqdm import tqdm
|
9 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
10 |
+
|
11 |
+
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
12 |
+
from sglang.utils import dump_state_text
|
13 |
+
|
14 |
+
|
15 |
+
def multi_turns(generate, qas):
|
16 |
+
s = ""
|
17 |
+
for qa in qas:
|
18 |
+
s += qa["prompt"]
|
19 |
+
s += generate(s, max_tokens=qa["new_tokens"])
|
20 |
+
|
21 |
+
return s
|
22 |
+
|
23 |
+
|
24 |
+
def main(args):
|
25 |
+
print(args)
|
26 |
+
|
27 |
+
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
28 |
+
|
29 |
+
multi_qas = gen_arguments(args, tokenizer)
|
30 |
+
|
31 |
+
states = [None] * args.num_qa
|
32 |
+
|
33 |
+
call_generate = partial(get_call_generate(args), temperature=0)
|
34 |
+
|
35 |
+
def get_one_answer(i):
|
36 |
+
states[i] = multi_turns(generate=call_generate, **multi_qas[i])
|
37 |
+
|
38 |
+
tic = time.time()
|
39 |
+
if args.parallel == 1:
|
40 |
+
for i in tqdm(range(len(multi_qas))):
|
41 |
+
get_one_answer(i)
|
42 |
+
else:
|
43 |
+
with ThreadPoolExecutor(args.parallel) as executor:
|
44 |
+
rets = list(
|
45 |
+
tqdm(
|
46 |
+
executor.map(get_one_answer, list(range(len(multi_qas)))),
|
47 |
+
total=len(multi_qas),
|
48 |
+
)
|
49 |
+
)
|
50 |
+
for _ in rets:
|
51 |
+
pass
|
52 |
+
|
53 |
+
latency = time.time() - tic
|
54 |
+
|
55 |
+
# Compute accuracy
|
56 |
+
print(f"Latency: {latency:.3f}")
|
57 |
+
|
58 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
59 |
+
|
60 |
+
with open(args.result_file, "a") as fout:
|
61 |
+
value = {
|
62 |
+
"task": "multi_turn_chat",
|
63 |
+
"backend": args.backend,
|
64 |
+
"num_gpus": 1,
|
65 |
+
"latency": round(latency, 3),
|
66 |
+
"num_requests": args.num_qa,
|
67 |
+
"num_turns": args.turns,
|
68 |
+
"other": {
|
69 |
+
"parallel": args.parallel,
|
70 |
+
"output_mode": "long" if args.long else "short",
|
71 |
+
},
|
72 |
+
}
|
73 |
+
fout.write(json.dumps(value) + "\n")
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == "__main__":
|
77 |
+
parser = ArgumentParser()
|
78 |
+
parser.add_argument("--turns", type=int, default=4)
|
79 |
+
parser.add_argument("--num-qa", type=int, default=20)
|
80 |
+
parser.add_argument("--min-len-q", type=int, default=256)
|
81 |
+
parser.add_argument("--max-len-q", type=int, default=512)
|
82 |
+
parser.add_argument("--min-len-a", type=int, default=4)
|
83 |
+
parser.add_argument("--max-len-a", type=int, default=8)
|
84 |
+
parser.add_argument("--tokenizer", type=str, required=True)
|
85 |
+
parser.add_argument("--trust-remote-code", action="store_true")
|
86 |
+
parser.add_argument("--long", action="store_true")
|
87 |
+
args = add_common_other_args_and_parse(parser)
|
88 |
+
|
89 |
+
if args.long:
|
90 |
+
args.min_len_a = 256
|
91 |
+
args.max_len_a = 512
|
92 |
+
args.num_qa = 20
|
93 |
+
main(args)
|
sglang/benchmark/multi_turn_chat/bench_sglang.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
|
5 |
+
from data_gen import gen_arguments
|
6 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
7 |
+
|
8 |
+
import sglang as sgl
|
9 |
+
from sglang.test.test_utils import (
|
10 |
+
add_common_sglang_args_and_parse,
|
11 |
+
select_sglang_backend,
|
12 |
+
)
|
13 |
+
from sglang.utils import dump_state_text
|
14 |
+
|
15 |
+
|
16 |
+
@sgl.function
|
17 |
+
def multi_turns(s, qas):
|
18 |
+
for qa in qas:
|
19 |
+
s += qa["prompt"]
|
20 |
+
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
|
21 |
+
|
22 |
+
|
23 |
+
def main(args):
|
24 |
+
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
25 |
+
|
26 |
+
multi_qas = gen_arguments(args, tokenizer)
|
27 |
+
|
28 |
+
backend = select_sglang_backend(args)
|
29 |
+
|
30 |
+
tic = time.time()
|
31 |
+
states = multi_turns.run_batch(
|
32 |
+
multi_qas,
|
33 |
+
temperature=0,
|
34 |
+
backend=backend,
|
35 |
+
num_threads=args.parallel,
|
36 |
+
progress_bar=True,
|
37 |
+
)
|
38 |
+
latency = time.time() - tic
|
39 |
+
|
40 |
+
print(f"Latency: {latency:.3f}")
|
41 |
+
|
42 |
+
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
43 |
+
|
44 |
+
with open(args.result_file, "a") as fout:
|
45 |
+
value = {
|
46 |
+
"task": "multi_turn_chat",
|
47 |
+
"backend": args.backend,
|
48 |
+
"num_gpus": 1,
|
49 |
+
"latency": round(latency, 3),
|
50 |
+
"num_requests": args.num_qa,
|
51 |
+
"num_turns": args.turns,
|
52 |
+
"other": {
|
53 |
+
"parallel": args.parallel,
|
54 |
+
"output_mode": "long" if args.long else "short",
|
55 |
+
},
|
56 |
+
}
|
57 |
+
fout.write(json.dumps(value) + "\n")
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
parser = ArgumentParser()
|
62 |
+
parser.add_argument("--turns", type=int, default=4)
|
63 |
+
parser.add_argument("--num-qa", type=int, default=20)
|
64 |
+
parser.add_argument("--min-len-q", type=int, default=256)
|
65 |
+
parser.add_argument("--max-len-q", type=int, default=512)
|
66 |
+
parser.add_argument("--min-len-a", type=int, default=4)
|
67 |
+
parser.add_argument("--max-len-a", type=int, default=8)
|
68 |
+
parser.add_argument("--tokenizer", type=str, required=True)
|
69 |
+
parser.add_argument("--trust-remote-code", action="store_true")
|
70 |
+
parser.add_argument("--long", action="store_true")
|
71 |
+
args = add_common_sglang_args_and_parse(parser)
|
72 |
+
|
73 |
+
if args.long:
|
74 |
+
args.min_len_a = 256
|
75 |
+
args.max_len_a = 512
|
76 |
+
args.num_qa = 20
|
77 |
+
|
78 |
+
print(args)
|
79 |
+
main(args)
|