Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- InternVL/.github/CONTRIBUTING.md +234 -0
- InternVL/internvl_chat_llava/LICENSE +201 -0
- InternVL/internvl_chat_llava/README.md +506 -0
- InternVL/internvl_chat_llava/pyproject.toml +33 -0
- InternVL/internvl_g/README.md +497 -0
- InternVL/segmentation/dist_test.sh +9 -0
- InternVL/segmentation/dist_train.sh +9 -0
- InternVL/segmentation/train.py +220 -0
- InternVL/streamlit_demo/constants.py +23 -0
- InternVL/streamlit_demo/controller.py +291 -0
- InternVL/streamlit_demo/model_worker.py +442 -0
- InternVL/video_retrieval/test_msrvtt.py +156 -0
- sglang/examples/frontend_language/quick_start/gemini_example_chat.py +73 -0
- sglang/examples/frontend_language/quick_start/gemini_example_multimodal_chat.py +30 -0
- sglang/examples/frontend_language/quick_start/local_example_complete.py +70 -0
- sglang/examples/frontend_language/quick_start/local_example_llava_next.py +78 -0
- sglang/examples/frontend_language/quick_start/openai_example_chat.py +74 -0
- sglang/examples/frontend_language/quick_start/openai_example_complete.py +68 -0
- sglang/examples/frontend_language/quick_start/openrouter_example_chat.py +81 -0
- sglang/examples/frontend_language/quick_start/together_example_complete.py +76 -0
- sglang/examples/frontend_language/usage/chinese_regex.py +53 -0
- sglang/examples/frontend_language/usage/choices_logprob.py +44 -0
- sglang/examples/frontend_language/usage/cot_decoding.py +115 -0
- sglang/examples/frontend_language/usage/json_decode.py +83 -0
- sglang/examples/frontend_language/usage/json_logprobs.py +104 -0
- sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.py +260 -0
- sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.sh +131 -0
- sglang/examples/frontend_language/usage/openai_chat_speculative.py +155 -0
- sglang/examples/frontend_language/usage/openai_speculative.py +54 -0
- sglang/examples/frontend_language/usage/parallel_sample.py +40 -0
- sglang/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb +408 -0
- sglang/examples/frontend_language/usage/readme_examples.py +109 -0
- sglang/examples/frontend_language/usage/sgl_gen_min_tokens.py +35 -0
- sglang/examples/frontend_language/usage/streaming.py +49 -0
- sglang/examples/frontend_language/usage/triton/Dockerfile +10 -0
- sglang/examples/frontend_language/usage/triton/README.md +35 -0
- sglang/examples/frontend_language/usage/triton/models/character_generation/1/model.py +55 -0
- sglang/examples/frontend_language/usage/triton/models/character_generation/config.pbtxt +23 -0
- sglang/examples/monitoring/grafana.json +1720 -0
- sglang/examples/monitoring/prometheus.yaml +10 -0
- sglang/examples/runtime/lora.py +37 -0
- sglang/examples/runtime/openai_batch_complete.py +93 -0
- sglang/examples/runtime/openai_chat_with_response_prefill.py +34 -0
- sglang/scripts/deprecated/convert_yi_vl.py +38 -0
- sglang/scripts/deprecated/test_httpserver_classify.py +85 -0
- sglang/scripts/deprecated/test_httpserver_decode_stream.py +69 -0
- sglang/scripts/deprecated/test_httpserver_llava.py +88 -0
- sglang/scripts/deprecated/test_httpserver_reuse.py +42 -0
- sglang/scripts/deprecated/test_jump_forward.py +138 -0
- sglang/scripts/deprecated/test_robust.py +132 -0
InternVL/.github/CONTRIBUTING.md
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Contributing to InternLM
|
2 |
+
|
3 |
+
Welcome to the InternLM community, all kinds of contributions are welcomed, including but not limited to
|
4 |
+
|
5 |
+
**Fix bug**
|
6 |
+
|
7 |
+
You can directly post a Pull Request to fix typo in code or documents
|
8 |
+
|
9 |
+
The steps to fix the bug of code implementation are as follows.
|
10 |
+
|
11 |
+
1. If the modification involve significant changes, you should create an issue first and describe the error information and how to trigger the bug. Other developers will discuss with you and propose an proper solution.
|
12 |
+
|
13 |
+
2. Posting a pull request after fixing the bug and adding corresponding unit test.
|
14 |
+
|
15 |
+
**New Feature or Enhancement**
|
16 |
+
|
17 |
+
1. If the modification involve significant changes, you should create an issue to discuss with our developers to propose an proper design.
|
18 |
+
2. Post a Pull Request after implementing the new feature or enhancement and add corresponding unit test.
|
19 |
+
|
20 |
+
**Document**
|
21 |
+
|
22 |
+
You can directly post a pull request to fix documents. If you want to add a document, you should first create an issue to check if it is reasonable.
|
23 |
+
|
24 |
+
### Pull Request Workflow
|
25 |
+
|
26 |
+
If you're not familiar with Pull Request, don't worry! The following guidance will tell you how to create a Pull Request step by step. If you want to dive into the develop mode of Pull Request, you can refer to the [official documents](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests)
|
27 |
+
|
28 |
+
#### 1. Fork and clone
|
29 |
+
|
30 |
+
If you are posting a pull request for the first time, you should fork the OpenMMLab repositories by clicking the **Fork** button in the top right corner of the GitHub page, and the forked repositories will appear under your GitHub profile.
|
31 |
+
|
32 |
+
<img src="https://user-images.githubusercontent.com/57566630/167305749-43c7f4e9-449b-4e98-ade5-0c9276d5c9ce.png" width="1200">
|
33 |
+
|
34 |
+
Then, you can clone the repositories to local:
|
35 |
+
|
36 |
+
```shell
|
37 |
+
git clone [email protected]:{username}/lmdeploy.git
|
38 |
+
```
|
39 |
+
|
40 |
+
After that, you should add official repository as the upstream repository
|
41 |
+
|
42 |
+
```bash
|
43 |
+
git remote add upstream [email protected]:InternLM/lmdeploy.git
|
44 |
+
```
|
45 |
+
|
46 |
+
Check whether remote repository has been added successfully by `git remote -v`
|
47 |
+
|
48 |
+
```bash
|
49 |
+
origin [email protected]:{username}/lmdeploy.git (fetch)
|
50 |
+
origin [email protected]:{username}/lmdeploy.git (push)
|
51 |
+
upstream [email protected]:InternLM/lmdeploy.git (fetch)
|
52 |
+
upstream [email protected]:InternLM/lmdeploy.git (push)
|
53 |
+
```
|
54 |
+
|
55 |
+
> Here's a brief introduction to origin and upstream. When we use "git clone", we create an "origin" remote by default, which points to the repository cloned from. As for "upstream", we add it ourselves to point to the target repository. Of course, if you don't like the name "upstream", you could name it as you wish. Usually, we'll push the code to "origin". If the pushed code conflicts with the latest code in official("upstream"), we should pull the latest code from upstream to resolve the conflicts, and then push to "origin" again. The posted Pull Request will be updated automatically.
|
56 |
+
|
57 |
+
#### 2. Configure pre-commit
|
58 |
+
|
59 |
+
You should configure [pre-commit](https://pre-commit.com/#intro) in the local development environment to make sure the code style matches that of InternLM. **Note**: The following code should be executed under the lmdeploy directory.
|
60 |
+
|
61 |
+
```shell
|
62 |
+
pip install -U pre-commit
|
63 |
+
pre-commit install
|
64 |
+
```
|
65 |
+
|
66 |
+
Check that pre-commit is configured successfully, and install the hooks defined in `.pre-commit-config.yaml`.
|
67 |
+
|
68 |
+
```shell
|
69 |
+
pre-commit run --all-files
|
70 |
+
```
|
71 |
+
|
72 |
+
<img src="https://user-images.githubusercontent.com/57566630/173660750-3df20a63-cb66-4d33-a986-1f643f1d8aaf.png" width="1200">
|
73 |
+
|
74 |
+
<img src="https://user-images.githubusercontent.com/57566630/202368856-0465a90d-8fce-4345-918e-67b8b9c82614.png" width="1200">
|
75 |
+
|
76 |
+
If the installation process is interrupted, you can repeatedly run `pre-commit run ... ` to continue the installation.
|
77 |
+
|
78 |
+
If the code does not conform to the code style specification, pre-commit will raise a warning and fixes some of the errors automatically.
|
79 |
+
|
80 |
+
<img src="https://user-images.githubusercontent.com/57566630/202369176-67642454-0025-4023-a095-263529107aa3.png" width="1200">
|
81 |
+
|
82 |
+
If we want to commit our code bypassing the pre-commit hook, we can use the `--no-verify` option(**only for temporarily commit**).
|
83 |
+
|
84 |
+
```shell
|
85 |
+
git commit -m "xxx" --no-verify
|
86 |
+
```
|
87 |
+
|
88 |
+
#### 3. Create a development branch
|
89 |
+
|
90 |
+
After configuring the pre-commit, we should create a branch based on the master branch to develop the new feature or fix the bug. The proposed branch name is `username/pr_name`
|
91 |
+
|
92 |
+
```shell
|
93 |
+
git checkout -b yhc/refactor_contributing_doc
|
94 |
+
```
|
95 |
+
|
96 |
+
In subsequent development, if the master branch of the local repository is behind the master branch of "upstream", we need to pull the upstream for synchronization, and then execute the above command:
|
97 |
+
|
98 |
+
```shell
|
99 |
+
git pull upstream master
|
100 |
+
```
|
101 |
+
|
102 |
+
#### 4. Commit the code and pass the unit test
|
103 |
+
|
104 |
+
- lmdeploy introduces mypy to do static type checking to increase the robustness of the code. Therefore, we need to add Type Hints to our code and pass the mypy check. If you are not familiar with Type Hints, you can refer to [this tutorial](https://docs.python.org/3/library/typing.html).
|
105 |
+
|
106 |
+
- The committed code should pass through the unit test
|
107 |
+
|
108 |
+
```shell
|
109 |
+
# Pass all unit tests
|
110 |
+
pytest tests
|
111 |
+
|
112 |
+
# Pass the unit test of runner
|
113 |
+
pytest tests/test_runner/test_runner.py
|
114 |
+
```
|
115 |
+
|
116 |
+
If the unit test fails for lack of dependencies, you can install the dependencies referring to the [guidance](#unit-test)
|
117 |
+
|
118 |
+
- If the documents are modified/added, we should check the rendering result referring to [guidance](#document-rendering)
|
119 |
+
|
120 |
+
#### 5. Push the code to remote
|
121 |
+
|
122 |
+
We could push the local commits to remote after passing through the check of unit test and pre-commit. You can associate the local branch with remote branch by adding `-u` option.
|
123 |
+
|
124 |
+
```shell
|
125 |
+
git push -u origin {branch_name}
|
126 |
+
```
|
127 |
+
|
128 |
+
This will allow you to use the `git push` command to push code directly next time, without having to specify a branch or the remote repository.
|
129 |
+
|
130 |
+
#### 6. Create a Pull Request
|
131 |
+
|
132 |
+
(1) Create a pull request in GitHub's Pull request interface
|
133 |
+
|
134 |
+
<img src="https://user-images.githubusercontent.com/57566630/201533288-516f7ac4-0b14-4dc8-afbd-912475c368b5.png" width="1200">
|
135 |
+
|
136 |
+
(2) Modify the PR description according to the guidelines so that other developers can better understand your changes
|
137 |
+
|
138 |
+
<img src="https://user-images.githubusercontent.com/57566630/202242953-c91a18ff-e388-4ff9-8591-5fae0ead6c1e.png" width="1200">
|
139 |
+
|
140 |
+
Find more details about Pull Request description in [pull request guidelines](#pr-specs).
|
141 |
+
|
142 |
+
**note**
|
143 |
+
|
144 |
+
(a) The Pull Request description should contain the reason for the change, the content of the change, and the impact of the change, and be associated with the relevant Issue (see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))
|
145 |
+
|
146 |
+
(b) If it is your first contribution, please sign the CLA
|
147 |
+
|
148 |
+
<img src="https://user-images.githubusercontent.com/57566630/167307569-a794b967-6e28-4eac-a942-00deb657815f.png" width="1200">
|
149 |
+
|
150 |
+
(c) Check whether the Pull Request pass through the CI
|
151 |
+
|
152 |
+
<img src="https://user-images.githubusercontent.com/57566630/167307490-f9ebf9fa-63c0-4d83-8ba1-081ea169eb3a.png" width="1200">
|
153 |
+
|
154 |
+
IternLM will run unit test for the posted Pull Request on different platforms (Linux, Window, Mac), based on different versions of Python, PyTorch, CUDA to make sure the code is correct. We can see the specific test information by clicking `Details` in the above image so that we can modify the code.
|
155 |
+
|
156 |
+
(3) If the Pull Request passes the CI, then you can wait for the review from other developers. You'll modify the code based on the reviewer's comments, and repeat the steps [4](#4-commit-the-code-and-pass-the-unit-test)-[5](#5-push-the-code-to-remote) until all reviewers approve it. Then, we will merge it ASAP.
|
157 |
+
|
158 |
+
<img src="https://user-images.githubusercontent.com/57566630/202145400-cc2cd8c4-10b0-472f-ba37-07e6f50acc67.png" width="1200">
|
159 |
+
|
160 |
+
#### 7. Resolve conflicts
|
161 |
+
|
162 |
+
If your local branch conflicts with the latest master branch of "upstream", you'll need to resolove them. There are two ways to do this:
|
163 |
+
|
164 |
+
```shell
|
165 |
+
git fetch --all --prune
|
166 |
+
git rebase upstream/master
|
167 |
+
```
|
168 |
+
|
169 |
+
or
|
170 |
+
|
171 |
+
```shell
|
172 |
+
git fetch --all --prune
|
173 |
+
git merge upstream/master
|
174 |
+
```
|
175 |
+
|
176 |
+
If you are very good at handling conflicts, then you can use rebase to resolve conflicts, as this will keep your commit logs tidy. If you are not familiar with `rebase`, then you can use `merge` to resolve conflicts.
|
177 |
+
|
178 |
+
### Guidance
|
179 |
+
|
180 |
+
#### Document rendering
|
181 |
+
|
182 |
+
If the documents are modified/added, we should check the rendering result. We could install the dependencies and run the following command to render the documents and check the results:
|
183 |
+
|
184 |
+
```shell
|
185 |
+
pip install -r requirements/docs.txt
|
186 |
+
cd docs/zh_cn/
|
187 |
+
# or docs/en
|
188 |
+
make html
|
189 |
+
# check file in ./docs/zh_cn/_build/html/index.html
|
190 |
+
```
|
191 |
+
|
192 |
+
### Code style
|
193 |
+
|
194 |
+
#### Python
|
195 |
+
|
196 |
+
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
|
197 |
+
|
198 |
+
We use the following tools for linting and formatting:
|
199 |
+
|
200 |
+
- [flake8](https://github.com/PyCQA/flake8): A wrapper around some linter tools.
|
201 |
+
- [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.
|
202 |
+
- [yapf](https://github.com/google/yapf): A formatter for Python files.
|
203 |
+
- [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files.
|
204 |
+
- [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
|
205 |
+
- [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
|
206 |
+
|
207 |
+
We use [pre-commit hook](https://pre-commit.com/) that checks and formats for `flake8`, `yapf`, `isort`, `trailing whitespaces`, `markdown files`,
|
208 |
+
fixes `end-of-files`, `double-quoted-strings`, `python-encoding-pragma`, `mixed-line-ending`, sorts `requirments.txt` automatically on every commit.
|
209 |
+
The config for a pre-commit hook is stored in [.pre-commit-config](../.pre-commit-config.yaml).
|
210 |
+
|
211 |
+
#### C++ and CUDA
|
212 |
+
|
213 |
+
The clang-format config is stored in [.clang-format](../.clang-format). And it's recommended to use clang-format version **11**. Please do not use older or newer versions as they will result in differences after formatting, which can cause the [lint](https://github.com/InternLM/lmdeploy/blob/main/.github/workflows/lint.yml#L25) to fail.
|
214 |
+
|
215 |
+
### PR Specs
|
216 |
+
|
217 |
+
1. Use [pre-commit](https://pre-commit.com) hook to avoid issues of code style
|
218 |
+
|
219 |
+
2. One short-time branch should be matched with only one PR
|
220 |
+
|
221 |
+
3. Accomplish a detailed change in one PR. Avoid large PR
|
222 |
+
|
223 |
+
- Bad: Support Faster R-CNN
|
224 |
+
- Acceptable: Add a box head to Faster R-CNN
|
225 |
+
- Good: Add a parameter to box head to support custom conv-layer number
|
226 |
+
|
227 |
+
4. Provide clear and significant commit message
|
228 |
+
|
229 |
+
5. Provide clear and meaningful PR description
|
230 |
+
|
231 |
+
- Task name should be clarified in title. The general format is: \[Prefix\] Short description of the PR (Suffix)
|
232 |
+
- Prefix: add new feature \[Feature\], fix bug \[Fix\], related to documents \[Docs\], in developing \[WIP\] (which will not be reviewed temporarily)
|
233 |
+
- Introduce main changes, results and influences on other modules in short description
|
234 |
+
- Associate related issues and pull requests with a milestone
|
InternVL/internvl_chat_llava/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
InternVL/internvl_chat_llava/README.md
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# InternVL for Multimodal Dialogue using LLaVA Codebase
|
2 |
+
|
3 |
+
This folder contains the implementation of the InternVL-Chat V1.0, which corresponds to Section 4.4 of our [InternVL 1.0 paper](https://arxiv.org/pdf/2312.14238).
|
4 |
+
|
5 |
+
In this part, we mainly use the [LLaVA codebase](https://github.com/haotian-liu/LLaVA) to evaluate InternVL in creating multimodal dialogue systems. Thanks for this great work.
|
6 |
+
We have retained the original documentation of LLaVA-1.5 as a more detailed manual. In most cases, you will only need to refer to the new documentation that we have added.
|
7 |
+
|
8 |
+
> Note: To unify the environment across different tasks, we have made some compatibility modifications to the LLaVA-1.5 code, allowing it to support `transformers==4.37.2` (originally locked at 4.31.0). Please note that `transformers==4.37.2` should be installed.
|
9 |
+
|
10 |
+
## 🛠️ Installation
|
11 |
+
|
12 |
+
First, follow the [installation guide](../INSTALLATION.md) to perform some basic installations.
|
13 |
+
|
14 |
+
In addition, using this codebase requires executing the following steps:
|
15 |
+
|
16 |
+
- Install other requirements:
|
17 |
+
|
18 |
+
```bash
|
19 |
+
pip install --upgrade pip # enable PEP 660 support
|
20 |
+
pip install -e .
|
21 |
+
```
|
22 |
+
|
23 |
+
## 📦 Model Preparation
|
24 |
+
|
25 |
+
| model name | type | download | size |
|
26 |
+
| ----------------------- | ----------- | ---------------------------------------------------------------------- | :-----: |
|
27 |
+
| InternViT-6B-224px | huggingface | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-224px) | 12 GB |
|
28 |
+
| InternViT-6B-448px-V1-0 | huggingface | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-0) | 12 GB |
|
29 |
+
| vicuna-13b-v1.5 | huggingface | 🤗 [HF link](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 13.5 GB |
|
30 |
+
| vicuna-7b-v1.5 | huggingface | 🤗 [HF link](https://huggingface.co/lmsys/vicuna-13b-v1.5) | 26.1 GB |
|
31 |
+
|
32 |
+
Please download the above model weights and place them in the `pretrained/` folder.
|
33 |
+
|
34 |
+
```sh
|
35 |
+
cd pretrained/
|
36 |
+
# pip install -U huggingface_hub
|
37 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternViT-6B-224px --local-dir InternViT-6B-224px
|
38 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternViT-6B-448px-V1-0 --local-dir InternViT-6B-448px
|
39 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False lmsys/vicuna-13b-v1.5 --local-dir vicuna-13b-v1.5
|
40 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False lmsys/vicuna-7b-v1.5 --local-dir vicuna-7b-v1.5
|
41 |
+
```
|
42 |
+
|
43 |
+
The directory structure is:
|
44 |
+
|
45 |
+
```sh
|
46 |
+
pretrained
|
47 |
+
│── InternViT-6B-224px/
|
48 |
+
│── InternViT-6B-448px/
|
49 |
+
│── vicuna-13b-v1.5/
|
50 |
+
└── vicuna-7b-v1.5/
|
51 |
+
```
|
52 |
+
|
53 |
+
## 🔥 Training
|
54 |
+
|
55 |
+
- InternViT-6B-224px + Vicuna-7B:
|
56 |
+
|
57 |
+
```shell
|
58 |
+
# pretrain
|
59 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_224to336_vicuna7b.sh
|
60 |
+
# finetune
|
61 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_224to336_vicuna7b.sh
|
62 |
+
```
|
63 |
+
|
64 |
+
- InternViT-6B-224px + Vicuna-13B:
|
65 |
+
|
66 |
+
```shell
|
67 |
+
# pretrain
|
68 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_224to336_vicuna13b.sh
|
69 |
+
# finetune
|
70 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_224to336_vicuna13b.sh
|
71 |
+
```
|
72 |
+
|
73 |
+
- InternViT-6B-448px + Vicuna-7B:
|
74 |
+
|
75 |
+
```shell
|
76 |
+
# pretrain
|
77 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_448_vicuna7b.sh
|
78 |
+
# finetune
|
79 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_448_vicuna7b.sh
|
80 |
+
```
|
81 |
+
|
82 |
+
- InternViT-6B-448px + Vicuna-13B:
|
83 |
+
|
84 |
+
```shell
|
85 |
+
# pretrain
|
86 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_448_vicuna13b.sh
|
87 |
+
# finetune
|
88 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_448_vicuna13b.sh
|
89 |
+
```
|
90 |
+
|
91 |
+
## 🤗 Model Zoo
|
92 |
+
|
93 |
+
| method | vision encoder | LLM | res. | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MMB | MMB<sub>CN</sub> | MMVet | Download |
|
94 |
+
| ----------------- | :------------: | :---: | :--: | :---: | :--: | :----: | :--: | :-----: | :--: | :----: | :--: | :--------------: | :---: | :----------------------------------------------------------------------------------: |
|
95 |
+
| LLaVA-1.5 | CLIP-L-336px | V-7B | 336 | 78.5 | 62.0 | 50.0 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 | 58.3 | 30.5 | 🤗 [HF link](https://huggingface.co/liuhaotian/llava-v1.5-7b) |
|
96 |
+
| LLaVA-1.5 | CLIP-L-336px | V-13B | 336 | 80.0 | 63.3 | 53.6 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 | 63.6 | 35.4 | 🤗 [HF link](https://huggingface.co/liuhaotian/llava-v1.5-13b) |
|
97 |
+
| InternVL-Chat-1.0 | IViT-6B-224px | V-7B | 336 | 79.3 | 62.9 | 52.5 | 66.2 | 57.0 | 86.4 | 1525.1 | 64.6 | 57.6 | 31.2 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B) |
|
98 |
+
| InternVL-Chat-1.0 | IViT-6B-224px | V-13B | 336 | 80.2 | 63.9 | 54.6 | 70.1 | 58.7 | 87.1 | 1546.9 | 66.5 | 61.9 | 33.7 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B) |
|
99 |
+
| InternVL-Chat-1.0 | IViT-6B-448px | V-13B | 448 | 82.0 | 64.1 | 60.1 | 71.6 | 64.8 | 87.2 | 1579.0 | 68.2 | 64.0 | 36.7 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B-448px) |
|
100 |
+
|
101 |
+
Please download the above model weights and place them in the `pretrained/` folder.
|
102 |
+
|
103 |
+
```shell
|
104 |
+
cd pretrained/
|
105 |
+
# pip install -U huggingface_hub
|
106 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B --local-dir InternVL-Chat-ViT-6B-Vicuna-7B
|
107 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B --local-dir InternVL-Chat-ViT-6B-Vicuna-13B
|
108 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B-448px --local-dir InternVL-Chat-ViT-6B-Vicuna-13B-448px
|
109 |
+
|
110 |
+
```
|
111 |
+
|
112 |
+
The directory structure is:
|
113 |
+
|
114 |
+
```
|
115 |
+
pretrained
|
116 |
+
│── InternViT-6B-224px/
|
117 |
+
│── InternViT-6B-448px/
|
118 |
+
│── vicuna-13b-v1.5/
|
119 |
+
│── vicuna-7b-v1.5/
|
120 |
+
│── InternVL-Chat-ViT-6B-Vicuna-7B/
|
121 |
+
│── InternVL-Chat-ViT-6B-Vicuna-13B/
|
122 |
+
└── InternVL-Chat-ViT-6B-Vicuna-13B-448px/
|
123 |
+
```
|
124 |
+
|
125 |
+
## 🖥️ Demo
|
126 |
+
|
127 |
+
The method for deploying the demo is consistent with LLaVA-1.5. You only need to change the model path. The specific steps are as follows:
|
128 |
+
|
129 |
+
**Launch a controller**
|
130 |
+
|
131 |
+
```shell
|
132 |
+
python -m llava.serve.controller --host 0.0.0.0 --port 10000
|
133 |
+
```
|
134 |
+
|
135 |
+
**Launch a gradio web server**
|
136 |
+
|
137 |
+
```shell
|
138 |
+
python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 10038
|
139 |
+
```
|
140 |
+
|
141 |
+
**Launch a model worker**
|
142 |
+
|
143 |
+
```shell
|
144 |
+
# OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B
|
145 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path ./pretrained/InternVL-Chat-ViT-6B-Vicuna-7B
|
146 |
+
# OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B
|
147 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-path ./pretrained/InternVL-Chat-ViT-6B-Vicuna-13B
|
148 |
+
```
|
149 |
+
|
150 |
+
After completing the above steps, you can access the web demo at `http://localhost:10038` and see the following page. Note that the models deployed here are `InternVL-Chat-ViT-6B-Vicuna-7B` and `InternVL-Chat-ViT-6B-Vicuna-13B`, which are the two models of our InternVL 1.0. The only difference from LLaVA-1.5 is that the CLIP-ViT-300M has been replaced with our InternViT-6B.
|
151 |
+
|
152 |
+
If you need a more effective MLLM, please check out our InternVL2 series models.
|
153 |
+
For more details on deploying the demo, please refer to [here](#gradio-web-ui).
|
154 |
+
|
155 |
+

|
156 |
+
|
157 |
+
## 💡 Testing
|
158 |
+
|
159 |
+
The method for testing the model remains the same as LLaVA-1.5; you just need to change the path of the script. Our scripts are located in `scripts_internvl/`.
|
160 |
+
|
161 |
+
For example, testing `MME` using a single GPU:
|
162 |
+
|
163 |
+
```shell
|
164 |
+
sh scripts_internvl/eval/mme.sh pretrained/InternVL-Chat-ViT-6B-Vicuna-7B/
|
165 |
+
```
|
166 |
+
|
167 |
+
______________________________________________________________________
|
168 |
+
|
169 |
+
## 🌋 LLaVA: Large Language and Vision Assistant
|
170 |
+
|
171 |
+
*Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*
|
172 |
+
|
173 |
+
\[[Project Page](https://llava-vl.github.io/)\] \[[Demo](https://llava.hliu.cc/)\] \[[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)\] \[[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)\]
|
174 |
+
|
175 |
+
🤝Community Contributions: \[[llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436)\] \[[Colab](https://github.com/camenduru/LLaVA-colab)\] \[[🤗Space](https://huggingface.co/spaces/badayvedat/LLaVA)\]
|
176 |
+
|
177 |
+
**Improved Baselines with Visual Instruction Tuning** \[[Paper](https://arxiv.org/abs/2310.03744)\] <br>
|
178 |
+
[Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
|
179 |
+
|
180 |
+
**Visual Instruction Tuning** (NeurIPS 2023, **Oral**) \[[Paper](https://arxiv.org/abs/2304.08485)\]<br>
|
181 |
+
[Haotian Liu\*](https://hliu.cc), [Chunyuan Li\*](https://chunyuan.li/), [Qingyang Wu](https://scholar.google.ca/citations?user=HDiw-TsAAAAJ&hl=en/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) (\*Equal Contribution)
|
182 |
+
|
183 |
+
### Release
|
184 |
+
|
185 |
+
- \[10/12\] 🔥 Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! \[[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)\]
|
186 |
+
|
187 |
+
- \[10/12\] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
|
188 |
+
|
189 |
+
- \[10/11\] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
|
190 |
+
|
191 |
+
- \[10/5\] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
|
192 |
+
|
193 |
+
- \[9/26\] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [\[LLavA-RLHF\]](https://llava-rlhf.github.io/)
|
194 |
+
|
195 |
+
- \[9/22\] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
|
196 |
+
|
197 |
+
- \[9/20\] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [\`\`Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
|
198 |
+
|
199 |
+
<p align="center">
|
200 |
+
<img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
|
201 |
+
</p>
|
202 |
+
|
203 |
+
- \[7/19\] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
|
204 |
+
|
205 |
+
- \[6/26\] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out \[[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)\] \[[Notes](https://arxiv.org/abs/2306.14895)\] \[[YouTube](https://youtu.be/mkI7EPD1vp8)\] \[[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)\].
|
206 |
+
|
207 |
+
- \[6/11\] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md).
|
208 |
+
|
209 |
+
- \[6/1\] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
|
210 |
+
|
211 |
+
- \[5/6\] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
|
212 |
+
|
213 |
+
- \[5/2\] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
|
214 |
+
|
215 |
+
- \[4/27\] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
|
216 |
+
|
217 |
+
- \[4/17\] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
|
218 |
+
|
219 |
+
<!-- <a href="https://llava.hliu.cc/"><img src="assets/demo.gif" width="70%"></a> -->
|
220 |
+
|
221 |
+
[](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
|
222 |
+
[](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
|
223 |
+
**Usage and License Notices**: The data and checkpoint is intended and licensed for research use only. They are also restricted to uses that follow the license agreement of LLaMA, Vicuna and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.
|
224 |
+
|
225 |
+
### Contents
|
226 |
+
|
227 |
+
- [Install](#install)
|
228 |
+
- [LLaVA Weights](#llava-weights)
|
229 |
+
- [Demo](#Demo)
|
230 |
+
- [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
|
231 |
+
- [Dataset](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)
|
232 |
+
- [Train](#train)
|
233 |
+
- [Evaluation](#evaluation)
|
234 |
+
|
235 |
+
### Install
|
236 |
+
|
237 |
+
1. Clone this repository and navigate to LLaVA folder
|
238 |
+
|
239 |
+
```bash
|
240 |
+
git clone https://github.com/haotian-liu/LLaVA.git
|
241 |
+
cd LLaVA
|
242 |
+
```
|
243 |
+
|
244 |
+
2. Install Package
|
245 |
+
|
246 |
+
```Shell
|
247 |
+
conda create -n llava python=3.10 -y
|
248 |
+
conda activate llava
|
249 |
+
pip install --upgrade pip # enable PEP 660 support
|
250 |
+
pip install -e .
|
251 |
+
```
|
252 |
+
|
253 |
+
3. Install additional packages for training cases
|
254 |
+
|
255 |
+
```
|
256 |
+
pip install ninja
|
257 |
+
pip install flash-attn --no-build-isolation
|
258 |
+
```
|
259 |
+
|
260 |
+
#### Upgrade to latest code base
|
261 |
+
|
262 |
+
```Shell
|
263 |
+
git pull
|
264 |
+
pip uninstall transformers
|
265 |
+
pip install -e .
|
266 |
+
```
|
267 |
+
|
268 |
+
### LLaVA Weights
|
269 |
+
|
270 |
+
Please check out our [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) for all public LLaVA checkpoints, and the instructions of how to use the weights.
|
271 |
+
|
272 |
+
### Demo
|
273 |
+
|
274 |
+
To run our demo, you need to prepare LLaVA checkpoints locally. Please follow the instructions [here](#llava-weights) to download the checkpoints.
|
275 |
+
|
276 |
+
#### Gradio Web UI
|
277 |
+
|
278 |
+
To launch a Gradio demo locally, please run the following commands one by one. If you plan to launch multiple model workers to compare between different checkpoints, you only need to launch the controller and the web server *ONCE*.
|
279 |
+
|
280 |
+
##### Launch a controller
|
281 |
+
|
282 |
+
```Shell
|
283 |
+
python -m llava.serve.controller --host 0.0.0.0 --port 10000
|
284 |
+
```
|
285 |
+
|
286 |
+
##### Launch a gradio web server.
|
287 |
+
|
288 |
+
```Shell
|
289 |
+
python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload
|
290 |
+
```
|
291 |
+
|
292 |
+
You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker.
|
293 |
+
|
294 |
+
##### Launch a model worker
|
295 |
+
|
296 |
+
This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model specified in `--model-path`.
|
297 |
+
|
298 |
+
```Shell
|
299 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
|
300 |
+
```
|
301 |
+
|
302 |
+
Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list.
|
303 |
+
|
304 |
+
You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker.
|
305 |
+
|
306 |
+
```Shell
|
307 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <different from 40000, say 40001> --worker http://localhost:<change accordingly, i.e. 40001> --model-path <ckpt2>
|
308 |
+
```
|
309 |
+
|
310 |
+
If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.
|
311 |
+
|
312 |
+
##### Launch a model worker (Multiple GPUs, when GPU VRAM \<= 24GB)
|
313 |
+
|
314 |
+
If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
|
315 |
+
|
316 |
+
```Shell
|
317 |
+
CUDA_VISIBLE_DEVICES=0,1 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
|
318 |
+
```
|
319 |
+
|
320 |
+
##### Launch a model worker (4-bit, 8-bit inference, quantized)
|
321 |
+
|
322 |
+
You can launch the model worker with quantized bits (4-bit, 8-bit), which allows you to run the inference with reduced GPU memory footprint, potentially allowing you to run on a GPU with as few as 12GB VRAM. Note that inference with quantized bits may not be as accurate as the full-precision model. Simply append `--load-4bit` or `--load-8bit` to the **model worker** command that you are executing. Below is an example of running with 4-bit quantization.
|
323 |
+
|
324 |
+
```Shell
|
325 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b --load-4bit
|
326 |
+
```
|
327 |
+
|
328 |
+
##### Launch a model worker (LoRA weights, unmerged)
|
329 |
+
|
330 |
+
You can launch the model worker with LoRA weights, without merging them with the base checkpoint, to save disk space. There will be additional loading time, while the inference speed is the same as the merged checkpoints. Unmerged LoRA checkpoints do not have `lora-merge` in the model name, and are usually much smaller (less than 1GB) than the merged checkpoints (13G for 7B, and 25G for 13B).
|
331 |
+
|
332 |
+
To load unmerged LoRA weights, you simply need to pass an additional argument `--model-base`, which is the base LLM that is used to train the LoRA weights. You can check the base LLM of each LoRA weights in the [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
|
333 |
+
|
334 |
+
```Shell
|
335 |
+
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1-0719-336px-lora-vicuna-13b-v1.3 --model-base lmsys/vicuna-13b-v1.3
|
336 |
+
```
|
337 |
+
|
338 |
+
#### CLI Inference
|
339 |
+
|
340 |
+
Chat about images using LLaVA without the need of Gradio interface. It also supports multiple GPUs, 4-bit and 8-bit quantized inference. With 4-bit quantization, for our LLaVA-1.5-7B, it uses less than 8GB VRAM on a single GPU.
|
341 |
+
|
342 |
+
```Shell
|
343 |
+
python -m llava.serve.cli \
|
344 |
+
--model-path liuhaotian/llava-v1.5-7b \
|
345 |
+
--image-file "https://llava-vl.github.io/static/images/view.jpg" \
|
346 |
+
--load-4bit
|
347 |
+
```
|
348 |
+
|
349 |
+
### Train
|
350 |
+
|
351 |
+
*Below is the latest training configuration for LLaVA v1.5. For legacy models, please refer to README of [this](https://github.com/haotian-liu/LLaVA/tree/v1.0.1) version for now. We'll add them in a separate doc later.*
|
352 |
+
|
353 |
+
LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions.
|
354 |
+
|
355 |
+
LLaVA is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
|
356 |
+
|
357 |
+
#### Hyperparameters
|
358 |
+
|
359 |
+
We use a similar set of hyperparameters as Vicuna in finetuning. Both hyperparameters used in pretraining and finetuning are provided below.
|
360 |
+
|
361 |
+
1. Pretraining
|
362 |
+
|
363 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
364 |
+
| -------------- | ----------------: | ------------: | -----: | ---------: | -----------: |
|
365 |
+
| LLaVA-v1.5-13B | 256 | 1e-3 | 1 | 2048 | 0 |
|
366 |
+
|
367 |
+
2. Finetuning
|
368 |
+
|
369 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
370 |
+
| -------------- | ----------------: | ------------: | -----: | ---------: | -----------: |
|
371 |
+
| LLaVA-v1.5-13B | 128 | 2e-5 | 1 | 2048 | 0 |
|
372 |
+
|
373 |
+
#### Download Vicuna checkpoints (automatically)
|
374 |
+
|
375 |
+
Our base model Vicuna v1.5, which is an instruction-tuned chatbot, will be downloaded automatically when you run our provided training scripts. No action is needed.
|
376 |
+
|
377 |
+
#### Pretrain (feature alignment)
|
378 |
+
|
379 |
+
Please download the 558K subset of the LAION-CC-SBU dataset with BLIP captions we use in the paper [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain).
|
380 |
+
|
381 |
+
Pretrain takes around 5.5 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 3.5 hours for LLaVA-v1.5-7B.
|
382 |
+
|
383 |
+
Training script with DeepSpeed ZeRO-2: [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/pretrain.sh).
|
384 |
+
|
385 |
+
- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
|
386 |
+
- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
|
387 |
+
|
388 |
+
#### Visual Instruction Tuning
|
389 |
+
|
390 |
+
1. Prepare data
|
391 |
+
|
392 |
+
Please download the annotation of the final mixture our instruction tuning data [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json), and download the images from constituting datasets:
|
393 |
+
|
394 |
+
- COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
|
395 |
+
- GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
|
396 |
+
- OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing)
|
397 |
+
- TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
|
398 |
+
- VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
|
399 |
+
|
400 |
+
After downloading all of them, organize the data as follows in `./playground/data`,
|
401 |
+
|
402 |
+
```
|
403 |
+
├── coco
|
404 |
+
│ └── train2017
|
405 |
+
├── gqa
|
406 |
+
│ └── images
|
407 |
+
├── ocr_vqa
|
408 |
+
│ └── images
|
409 |
+
├── textvqa
|
410 |
+
│ └── train_images
|
411 |
+
└── vg
|
412 |
+
├── VG_100K
|
413 |
+
└── VG_100K_2
|
414 |
+
```
|
415 |
+
|
416 |
+
2. Start training!
|
417 |
+
|
418 |
+
You may download our pretrained projectors in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). It is not recommended to use legacy projectors, as they may be trained with a different version of the codebase, and if any option is off, the model will not function/train as we expected.
|
419 |
+
|
420 |
+
Visual instruction tuning takes around 20 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 10 hours for LLaVA-v1.5-7B on 8x A100 (40G).
|
421 |
+
|
422 |
+
Training script with DeepSpeed ZeRO-3: [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh).
|
423 |
+
|
424 |
+
New options to note:
|
425 |
+
|
426 |
+
- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
|
427 |
+
- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
|
428 |
+
- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
|
429 |
+
- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct). It makes the training sampler only sample a single modality (either image or language) during training, which we observe to speed up training by ~25%, and does not affect the final outcome.
|
430 |
+
|
431 |
+
### Evaluation
|
432 |
+
|
433 |
+
In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs.
|
434 |
+
|
435 |
+
See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).
|
436 |
+
|
437 |
+
#### GPT-assisted Evaluation
|
438 |
+
|
439 |
+
Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our paper for more details.
|
440 |
+
|
441 |
+
1. Generate LLaVA responses
|
442 |
+
|
443 |
+
```Shell
|
444 |
+
python model_vqa.py \
|
445 |
+
--model-path ./checkpoints/LLaVA-13B-v0 \
|
446 |
+
--question-file \
|
447 |
+
playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
|
448 |
+
--image-folder \
|
449 |
+
/path/to/coco2014_val \
|
450 |
+
--answers-file \
|
451 |
+
/path/to/answer-file-our.jsonl
|
452 |
+
```
|
453 |
+
|
454 |
+
2. Evaluate the generated responses. In our case, [`answer-file-ref.jsonl`](./playground/data/coco2014_val_qa_eval/qa90_gpt4_answer.jsonl) is the response generated by text-only GPT-4 (0314), with the context captions/boxes provided.
|
455 |
+
|
456 |
+
```Shell
|
457 |
+
OPENAI_API_KEY="sk-***********************************" python llava/eval/eval_gpt_review_visual.py \
|
458 |
+
--question playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
|
459 |
+
--context llava/eval/table/caps_boxes_coco2014_val_80.jsonl \
|
460 |
+
--answer-list \
|
461 |
+
/path/to/answer-file-ref.jsonl \
|
462 |
+
/path/to/answer-file-our.jsonl \
|
463 |
+
--rule llava/eval/table/rule.json \
|
464 |
+
--output /path/to/review.json
|
465 |
+
```
|
466 |
+
|
467 |
+
3. Summarize the evaluation results
|
468 |
+
|
469 |
+
```Shell
|
470 |
+
python summarize_gpt_review.py
|
471 |
+
```
|
472 |
+
|
473 |
+
### Citation
|
474 |
+
|
475 |
+
If you find LLaVA useful for your research and applications, please cite using this BibTeX:
|
476 |
+
|
477 |
+
```bibtex
|
478 |
+
@misc{liu2023improvedllava,
|
479 |
+
title={Improved Baselines with Visual Instruction Tuning},
|
480 |
+
author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
|
481 |
+
publisher={arXiv:2310.03744},
|
482 |
+
year={2023},
|
483 |
+
}
|
484 |
+
|
485 |
+
@misc{liu2023llava,
|
486 |
+
title={Visual Instruction Tuning},
|
487 |
+
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
|
488 |
+
publisher={arXiv:2304.08485},
|
489 |
+
year={2023},
|
490 |
+
}
|
491 |
+
```
|
492 |
+
|
493 |
+
### Acknowledgement
|
494 |
+
|
495 |
+
- [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities!
|
496 |
+
|
497 |
+
### Related Projects
|
498 |
+
|
499 |
+
- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
500 |
+
- [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
|
501 |
+
- [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
|
502 |
+
|
503 |
+
For future project ideas, please check out:
|
504 |
+
|
505 |
+
- [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
|
506 |
+
- [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
|
InternVL/internvl_chat_llava/pyproject.toml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=61.0"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "llava"
|
7 |
+
version = "1.1.1"
|
8 |
+
description = "Towards GPT-4 like large language and visual assistant."
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">=3.8"
|
11 |
+
classifiers = [
|
12 |
+
"Programming Language :: Python :: 3",
|
13 |
+
"License :: OSI Approved :: Apache Software License",
|
14 |
+
]
|
15 |
+
dependencies = [
|
16 |
+
"torch>=2", "torchvision>=0.15",
|
17 |
+
"transformers>=4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
|
18 |
+
"accelerate", "peft>=0.4.0", "bitsandbytes==0.41.0",
|
19 |
+
"pydantic", "markdown2[all]", "numpy", "scikit-learn>=1.2.2",
|
20 |
+
"gradio==3.35.2", "gradio_client==0.2.9",
|
21 |
+
"requests", "httpx==0.24.0", "uvicorn", "fastapi",
|
22 |
+
"deepspeed==0.13.5", "einops", "einops-exts", "timm==0.9.12",
|
23 |
+
]
|
24 |
+
|
25 |
+
[project.urls]
|
26 |
+
"Homepage" = "https://github.com/OpenGVLab/InternVL"
|
27 |
+
"Bug Tracker" = "https://github.com/OpenGVLab/InternVL/issues"
|
28 |
+
|
29 |
+
[tool.setuptools.packages.find]
|
30 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
31 |
+
|
32 |
+
[tool.wheel]
|
33 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
InternVL/internvl_g/README.md
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# InternVL Stage-2 Pre-training & Retrieval Fine-tuning
|
2 |
+
|
3 |
+
This folder contains the implementation of the InternVL 1.0 for stage2 pre-training and retrieval fine-tuning, which corresponds to Section 4.3 of our [InternVL 1.0 paper](https://arxiv.org/pdf/2312.14238).
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
## 🛠️ Installation
|
8 |
+
|
9 |
+
Follow the [installation guide](../INSTALLATION.md) to perform installations.
|
10 |
+
|
11 |
+
## 📦 Data Preparation
|
12 |
+
|
13 |
+
Three datasets need to be prepared: COCO Caption, Flickr30K, and NoCaps.
|
14 |
+
|
15 |
+
<details open>
|
16 |
+
<summary>COCO Caption</summary>
|
17 |
+
|
18 |
+
```bash
|
19 |
+
mkdir -p data/coco && cd data/coco
|
20 |
+
|
21 |
+
# download coco images
|
22 |
+
wget http://images.cocodataset.org/zips/train2014.zip && unzip train2014.zip
|
23 |
+
wget http://images.cocodataset.org/zips/val2014.zip && unzip val2014.zip
|
24 |
+
wget http://images.cocodataset.org/zips/test2015.zip && unzip test2015.zip
|
25 |
+
|
26 |
+
mkdir -p annotations && cd annotations/
|
27 |
+
# download converted annotation files
|
28 |
+
wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
|
29 |
+
wget https://github.com/OpenGVLab/InternVL/releases/download/data/coco_karpathy_test.json
|
30 |
+
wget https://github.com/OpenGVLab/InternVL/releases/download/data/coco_karpathy_test_gt.json
|
31 |
+
cd ../../../
|
32 |
+
```
|
33 |
+
|
34 |
+
</details>
|
35 |
+
|
36 |
+
<details open>
|
37 |
+
<summary>Flickr30K</summary>
|
38 |
+
|
39 |
+
```bash
|
40 |
+
mkdir -p data/flickr30k && cd data/flickr30k
|
41 |
+
|
42 |
+
# download images from https://bryanplummer.com/Flickr30kEntities/
|
43 |
+
# karpathy split annotations can be downloaded from the following link:
|
44 |
+
# https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_test_karpathy.txt
|
45 |
+
# this file is provided by the clip-benchmark repository.
|
46 |
+
# We convert this txt file to json format, download the converted file:
|
47 |
+
wget https://github.com/OpenGVLab/InternVL/releases/download/data/flickr30k_cn_test.txt
|
48 |
+
wget https://github.com/OpenGVLab/InternVL/releases/download/data/flickr30k_cn_train.txt
|
49 |
+
wget https://github.com/OpenGVLab/InternVL/releases/download/data/flickr30k_test_karpathy.json
|
50 |
+
wget https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_test_karpathy.txt
|
51 |
+
wget https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_train_karpathy.txt
|
52 |
+
wget https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_val_karpathy.txt
|
53 |
+
|
54 |
+
cd ../..
|
55 |
+
```
|
56 |
+
|
57 |
+
</details>
|
58 |
+
|
59 |
+
<details open>
|
60 |
+
<summary>NoCaps</summary>
|
61 |
+
|
62 |
+
```bash
|
63 |
+
mkdir -p data/nocaps && cd data/nocaps
|
64 |
+
|
65 |
+
# download images from https://nocaps.org/download
|
66 |
+
# original annotations can be downloaded from https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json
|
67 |
+
wget https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json
|
68 |
+
|
69 |
+
cd ../..
|
70 |
+
```
|
71 |
+
|
72 |
+
</details>
|
73 |
+
|
74 |
+
After the download is complete, the directory structure is:
|
75 |
+
|
76 |
+
```shell
|
77 |
+
data
|
78 |
+
├── coco
|
79 |
+
│ ├── annotations
|
80 |
+
│ │ ├── coco_karpathy_train.json
|
81 |
+
│ ├── test2017
|
82 |
+
│ ├── train2014
|
83 |
+
│ ├── train2017
|
84 |
+
│ ├── val2014
|
85 |
+
│ └── val2017
|
86 |
+
├── flickr30k
|
87 |
+
│ ├── flickr30k_cn_test.txt
|
88 |
+
│ ├── flickr30k_cn_train.txt
|
89 |
+
│ ├── flickr30k_test_karpathy.json
|
90 |
+
│ ├── flickr30k_test_karpathy.txt
|
91 |
+
│ ├── flickr30k_train_karpathy.txt
|
92 |
+
│ ├── flickr30k_val_karpathy.txt
|
93 |
+
│ └── Images
|
94 |
+
└── nocaps
|
95 |
+
├── images
|
96 |
+
└── nocaps_val_4500_captions.json
|
97 |
+
```
|
98 |
+
|
99 |
+
## 📦 Model Preparation
|
100 |
+
|
101 |
+
| model name | type | download | size |
|
102 |
+
| ------------------ | ----------- | ----------------------------------------------------------------- | :-----: |
|
103 |
+
| InternVL-14B-224px | huggingface | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-14B-224px) | 27.7 GB |
|
104 |
+
|
105 |
+
Please download the above model weights and place them in the `pretrained/` folder.
|
106 |
+
|
107 |
+
```sh
|
108 |
+
cd pretrained/
|
109 |
+
# pip install -U huggingface_hub
|
110 |
+
huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-14B-224px --local-dir InternVL-14B-224px
|
111 |
+
```
|
112 |
+
|
113 |
+
The directory structure is:
|
114 |
+
|
115 |
+
```sh
|
116 |
+
pretrained
|
117 |
+
└── InternVL-14B-224px/
|
118 |
+
```
|
119 |
+
|
120 |
+
## 🔥 Generative Pre-training
|
121 |
+
|
122 |
+
There are currently no plans to release this part of the code.
|
123 |
+
|
124 |
+
## 📊 Evaluation
|
125 |
+
|
126 |
+
### Zero-Shot Image Captioning
|
127 |
+
|
128 |
+
| model | dataset | BLEU4 | METEOR | CIDEr |
|
129 |
+
| ---------- | ----------------------- | ----- | ------ | ----- |
|
130 |
+
| InternVL-G | COCO Karpathy test | 37.1 | 30.1 | 128.2 |
|
131 |
+
| InternVL-G | Flickr30K Karpathy test | 27.0 | 25.3 | 79.2 |
|
132 |
+
| InternVL-G | NoCaps val | 44.3 | 30.1 | 113.7 |
|
133 |
+
|
134 |
+
<details>
|
135 |
+
<summary>[InternVL-G] COCO Karpathy test</summary>
|
136 |
+
|
137 |
+
```bash
|
138 |
+
sh evaluate.sh pretrained/InternVL-14B-224px caption-coco
|
139 |
+
```
|
140 |
+
|
141 |
+
Expected results:
|
142 |
+
|
143 |
+
```
|
144 |
+
['coco', 'English caption:', 10.5974, dict_items([('Bleu_1', 0.7876323287981284), ('Bleu_2', 0.6353512494727918), ('Bleu_3', 0.49108984183589743), ('Bleu_4', 0.37062736733849205), ('METEOR', 0.30106315496945923), ('ROUGE_L', 0.5898249189475652), ('CIDEr', 1.281844384075423)])]
|
145 |
+
```
|
146 |
+
|
147 |
+
</details>
|
148 |
+
|
149 |
+
<details>
|
150 |
+
<summary>[InternVL-G] Flickr30K Karpathy test</summary>
|
151 |
+
|
152 |
+
```
|
153 |
+
sh evaluate.sh pretrained/InternVL-14B-224px caption-flickr30k
|
154 |
+
```
|
155 |
+
|
156 |
+
Expected results:
|
157 |
+
|
158 |
+
```bash
|
159 |
+
['flickr30k', 'English caption:', 10.666, dict_items([('Bleu_1', 0.7182900534357628), ('Bleu_2', 0.5353390037921949), ('Bleu_3', 0.3834462132295285), ('Bleu_4', 0.2702131471765472), ('METEOR', 0.25263515267930103), ('ROUGE_L', 0.5305876871149064), ('CIDEr', 0.7919734768328237)])]
|
160 |
+
```
|
161 |
+
|
162 |
+
</details>
|
163 |
+
|
164 |
+
<details>
|
165 |
+
<summary>[InternVL-G] NoCaps val</summary>
|
166 |
+
|
167 |
+
```bash
|
168 |
+
sh evaluate.sh pretrained/InternVL-14B-224px caption-nocaps
|
169 |
+
```
|
170 |
+
|
171 |
+
Expected results:
|
172 |
+
|
173 |
+
```
|
174 |
+
['nocaps', 'English caption:', 10.463111111111111, dict_items([('Bleu_1', 0.8518290482155187), ('Bleu_2', 0.7165227921485106), ('Bleu_3', 0.5733723839888316), ('Bleu_4', 0.44268902150723105), ('METEOR', 0.30078174807736896), ('ROUGE_L', 0.6070208063052156), ('CIDEr', 1.1371742045267772)])]
|
175 |
+
```
|
176 |
+
|
177 |
+
</details>
|
178 |
+
|
179 |
+
### Fine-tuned Image-Text Retrieval
|
180 |
+
|
181 |
+
#### Flickr30K fine-tuned model: [InternVL-14B-Flickr30K-FT-364px](https://huggingface.co/OpenGVLab/InternVL-14B-Flickr30K-FT-364px)
|
182 |
+
|
183 |
+
<table>
|
184 |
+
<tr align=center>
|
185 |
+
<td rowspan="3" align=center><b>model</b></td>
|
186 |
+
<td colspan="6" align=center><b>Flickr30K</b></td>
|
187 |
+
<td rowspan="3" align=center><b>avg</b></td>
|
188 |
+
|
189 |
+
</tr>
|
190 |
+
<tr align=center>
|
191 |
+
<td colspan="3" align=center><b>image-to-text</b></td>
|
192 |
+
<td colspan="3" align=center><b>text-to-image</b></td>
|
193 |
+
</tr>
|
194 |
+
<tr>
|
195 |
+
<td>R@1</td>
|
196 |
+
<td>R@5</td>
|
197 |
+
<td>R@10</td>
|
198 |
+
<td>R@1</td>
|
199 |
+
<td>R@5</td>
|
200 |
+
<td>R@10</td>
|
201 |
+
</tr>
|
202 |
+
|
203 |
+
<tr align=center>
|
204 |
+
<td>InternVL-C-FT</td>
|
205 |
+
<td>97.2</td>
|
206 |
+
<td>100.0</td>
|
207 |
+
<td>100.0</td>
|
208 |
+
<td>88.5</td>
|
209 |
+
<td>98.4</td>
|
210 |
+
<td>99.2</td>
|
211 |
+
<td>97.2</td>
|
212 |
+
</tr>
|
213 |
+
<tr align=center>
|
214 |
+
<td>InternVL-G-FT</td>
|
215 |
+
<td>97.9</td>
|
216 |
+
<td>100.0</td>
|
217 |
+
<td>100.0</td>
|
218 |
+
<td>89.6</td>
|
219 |
+
<td>98.6</td>
|
220 |
+
<td>99.2</td>
|
221 |
+
<td>97.6</td>
|
222 |
+
</tr>
|
223 |
+
|
224 |
+
</table>
|
225 |
+
|
226 |
+
<details>
|
227 |
+
<summary>[InternVL-C-FT] Flickr30K</summary>
|
228 |
+
|
229 |
+
```bash
|
230 |
+
cd ../clip_benchmark/
|
231 |
+
CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "en" --task "zeroshot_retrieval" \
|
232 |
+
--dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_c_retrieval_hf \
|
233 |
+
--pretrained ./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10/ --output result_ft.json
|
234 |
+
```
|
235 |
+
|
236 |
+
Expected results:
|
237 |
+
|
238 |
+
```
|
239 |
+
{"dataset": "flickr30k", "model": "internvl_c_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10", "task": "zeroshot_retrieval",
|
240 |
+
"metrics": {"image_retrieval_recall@1": 0.8853999972343445, "text_retrieval_recall@1": 0.972000002861023,
|
241 |
+
"image_retrieval_recall@5": 0.9836000204086304, "text_retrieval_recall@5": 1.0,
|
242 |
+
"image_retrieval_recall@10": 0.9923999905586243, "text_retrieval_recall@10": 1.0}, "language": "en"}
|
243 |
+
```
|
244 |
+
|
245 |
+
</details>
|
246 |
+
|
247 |
+
<details>
|
248 |
+
<summary>[InternVL-G-FT] Flickr30K</summary>
|
249 |
+
|
250 |
+
```bash
|
251 |
+
cd ../clip_benchmark/
|
252 |
+
CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "en" --task "zeroshot_retrieval" \
|
253 |
+
--dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_g_retrieval_hf \
|
254 |
+
--pretrained ./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10/ --output result_ft.json
|
255 |
+
```
|
256 |
+
|
257 |
+
Expected results:
|
258 |
+
|
259 |
+
```
|
260 |
+
{"dataset": "flickr30k", "model": "internvl_g_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10", "task": "zeroshot_retrieval",
|
261 |
+
"metrics": {"image_retrieval_recall@1": 0.895799994468689, "text_retrieval_recall@1": 0.9789999723434448,
|
262 |
+
"image_retrieval_recall@5": 0.9861999750137329, "text_retrieval_recall@5": 1.0,
|
263 |
+
"image_retrieval_recall@10": 0.9922000169754028, "text_retrieval_recall@10": 1.0}, "language": "en"}
|
264 |
+
```
|
265 |
+
|
266 |
+
</details>
|
267 |
+
|
268 |
+
#### Flickr30K-CN fine-tuned model: [InternVL-14B-FlickrCN-FT-364px](https://huggingface.co/OpenGVLab/InternVL-14B-FlickrCN-FT-364px)
|
269 |
+
|
270 |
+
<table>
|
271 |
+
<tr align=center>
|
272 |
+
<td rowspan="3" align=center><b>model</b></td>
|
273 |
+
<td colspan="6" align=center><b>Flickr30K-CN</b></td>
|
274 |
+
<td rowspan="3" align=center><b>avg</b></td>
|
275 |
+
|
276 |
+
</tr>
|
277 |
+
<tr align=center>
|
278 |
+
<td colspan="3" align=center><b>image-to-text</b></td>
|
279 |
+
<td colspan="3" align=center><b>text-to-image</b></td>
|
280 |
+
</tr>
|
281 |
+
<tr>
|
282 |
+
<td>R@1</td>
|
283 |
+
<td>R@5</td>
|
284 |
+
<td>R@10</td>
|
285 |
+
<td>R@1</td>
|
286 |
+
<td>R@5</td>
|
287 |
+
<td>R@10</td>
|
288 |
+
</tr>
|
289 |
+
|
290 |
+
<tr align=center>
|
291 |
+
<td>InternVL-C-FT</td>
|
292 |
+
<td>96.5</td>
|
293 |
+
<td>99.9</td>
|
294 |
+
<td>100.0</td>
|
295 |
+
<td>85.2</td>
|
296 |
+
<td>97.0</td>
|
297 |
+
<td>98.5</td>
|
298 |
+
<td>96.2</td>
|
299 |
+
</tr>
|
300 |
+
<tr align=center>
|
301 |
+
<td>InternVL-G-FT</td>
|
302 |
+
<td>96.9</td>
|
303 |
+
<td>99.9</td>
|
304 |
+
<td>100.0</td>
|
305 |
+
<td>85.9</td>
|
306 |
+
<td>97.1</td>
|
307 |
+
<td>98.7</td>
|
308 |
+
<td>96.4</td>
|
309 |
+
</tr>
|
310 |
+
|
311 |
+
</table>
|
312 |
+
|
313 |
+
<details>
|
314 |
+
<summary>[InternVL-C-FT] Flickr30K-CN</summary>
|
315 |
+
|
316 |
+
```bash
|
317 |
+
cd ../clip_benchmark/
|
318 |
+
CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "cn" --task "zeroshot_retrieval" \
|
319 |
+
--dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_c_retrieval_hf \
|
320 |
+
--pretrained ./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10/ --output result_ft.json
|
321 |
+
```
|
322 |
+
|
323 |
+
Expected results:
|
324 |
+
|
325 |
+
```
|
326 |
+
{"dataset": "flickr30k", "model": "internvl_c_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10", "task": "zeroshot_retrieval",
|
327 |
+
"metrics": {"image_retrieval_recall@1": 0.8521999716758728, "text_retrieval_recall@1": 0.9649999737739563,
|
328 |
+
"image_retrieval_recall@5": 0.9697999954223633, "text_retrieval_recall@5": 0.9990000128746033,
|
329 |
+
"image_retrieval_recall@10": 0.9854000210762024, "text_retrieval_recall@10": 1.0}, "language": "cn"}
|
330 |
+
```
|
331 |
+
|
332 |
+
</details>
|
333 |
+
|
334 |
+
<details>
|
335 |
+
<summary>[InternVL-G-FT] Flickr30K-CN</summary>
|
336 |
+
|
337 |
+
```bash
|
338 |
+
cd ../clip_benchmark/
|
339 |
+
CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "cn" --task "zeroshot_retrieval" \
|
340 |
+
--dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_g_retrieval_hf \
|
341 |
+
--pretrained ./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10/ --output result_ft.json
|
342 |
+
```
|
343 |
+
|
344 |
+
Expected results:
|
345 |
+
|
346 |
+
```
|
347 |
+
{"dataset": "flickr30k", "model": "internvl_g_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10", "task": "zeroshot_retrieval",
|
348 |
+
"metrics": {"image_retrieval_recall@1": 0.8587999939918518, "text_retrieval_recall@1": 0.968999981880188,
|
349 |
+
"image_retrieval_recall@5": 0.9714000225067139, "text_retrieval_recall@5": 0.9990000128746033,
|
350 |
+
"image_retrieval_recall@10": 0.9865999817848206, "text_retrieval_recall@10": 1.0}, "language": "cn"}
|
351 |
+
```
|
352 |
+
|
353 |
+
</details>
|
354 |
+
|
355 |
+
## 🔥 Retrieval Fine-tuning (Fully)
|
356 |
+
|
357 |
+
> Note: In our experiments, full parameter fine-tuning achieves the best results on image-text retrieval tasks in Flickr30K and COCO. By following the experimental hyperparameters in this section, you can reproduce the model performance reported in the [Evaluation section](#evaluation).
|
358 |
+
|
359 |
+
To fine-tune InternVL on Flickr30K with 32 GPUs and slurm system, run:
|
360 |
+
|
361 |
+
```bash
|
362 |
+
PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh
|
363 |
+
```
|
364 |
+
|
365 |
+
To fine-tune InternVL on Flickr30K-CN with 32 GPUs and slurm system, run:
|
366 |
+
|
367 |
+
```shell
|
368 |
+
PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_flickrcn_364_bs1024_ep10.sh
|
369 |
+
```
|
370 |
+
|
371 |
+
To fine-tune InternVL on COCO with 32 GPUs and slurm system, run:
|
372 |
+
|
373 |
+
```shell
|
374 |
+
PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh
|
375 |
+
```
|
376 |
+
|
377 |
+
The hyperparameters used here are:
|
378 |
+
|
379 |
+
| config | Flickr30K | Flickr30K-CN | COCO |
|
380 |
+
| --------------------------- | ----------------------------------- | ----------------------------------- | ----------------------------------- |
|
381 |
+
| learning rate | 1e-6 | 1e-6 | 1e-6 |
|
382 |
+
| layer-wise lr<br>decay rate | InternViT-6B (0.9),<br>QLLaMA (0.9) | InternViT-6B (0.9),<br>QLLaMA (0.9) | InternViT-6B (0.9),<br>QLLaMA (0.9) |
|
383 |
+
| optimizer | AdamW | AdamW | AdamW |
|
384 |
+
| weight decay | 0.05 | 0.05 | 0.05 |
|
385 |
+
| input resolution | 364x364 | 364x364 | 364x364 |
|
386 |
+
| total batch size | 1024 | 1024 | 1024 |
|
387 |
+
| warm-up iterations | 100 | 100 | 100 |
|
388 |
+
| training epochs | 10 | 10 | 5 |
|
389 |
+
| drop path rate | 0.3 | 0.3 | 0.3 |
|
390 |
+
| numerical precision | zero1 + bf16 | zero1 + bf16 | zero1 + bf16 |
|
391 |
+
| trainable / total params | 14B / 14B | 14B / 14B | 14B / 14B |
|
392 |
+
| GPUs for training | 32×A100 (80G) | 32×A100 (80G) | 32×A100 (80G) |
|
393 |
+
| Required GPU memory | 80G | 80G | 80G |
|
394 |
+
|
395 |
+
## 🔥 Retrieval Fine-tuning (Head)
|
396 |
+
|
397 |
+
> Note: This section demonstrates how to perform a cost-effective fine-tuning of our model. The hyperparameters shown here are not optimized for any specific task. For practical applications, further adjustments to the hyperparameters may be necessary to achieve optimal performance.
|
398 |
+
|
399 |
+
To fine-tune the head of InternVL on Flickr30K with 4 GPUs, run:
|
400 |
+
|
401 |
+
```bash
|
402 |
+
GPUS=4 BATCH_SIZE=32 sh shell/head_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_head_4gpu.sh
|
403 |
+
```
|
404 |
+
|
405 |
+
To fine-tune the head of InternVL on Flickr30K-CN with 4 GPUs, run:
|
406 |
+
|
407 |
+
```shell
|
408 |
+
GPUS=4 BATCH_SIZE=32 sh shell/head_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_head_4gpu.sh
|
409 |
+
```
|
410 |
+
|
411 |
+
To fine-tune the head of InternVL on COCO with 4 GPUs, run:
|
412 |
+
|
413 |
+
```shell
|
414 |
+
GPUS=4 BATCH_SIZE=32 shell/head_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_head_4gpu.sh
|
415 |
+
```
|
416 |
+
|
417 |
+
The hyperparameters used here are:
|
418 |
+
|
419 |
+
| config | Flickr30K | Flickr30K-CN | COCO |
|
420 |
+
| ------------------------ | ------------- | ------------- | ------------- |
|
421 |
+
| learning rate | 1e-6 | 1e-6 | 1e-6 |
|
422 |
+
| optimizer | AdamW | AdamW | AdamW |
|
423 |
+
| weight decay | 0.05 | 0.05 | 0.05 |
|
424 |
+
| input resolution | 224x224 | 224x224 | 224x224 |
|
425 |
+
| total batch size | 4x32 | 4x32 | 4x32 |
|
426 |
+
| warm-up iterations | 100 | 100 | 100 |
|
427 |
+
| training epochs | 10 | 10 | 5 |
|
428 |
+
| drop path rate | 0.0 | 0.0 | 0.3 |
|
429 |
+
| numerical precision | zero3 + bf16 | zero3 + bf16 | zero1 + bf16 |
|
430 |
+
| trainable / total params | 0.2B / 14B | 0.2B / 14B | 0.2B / 14B |
|
431 |
+
| GPUs for training | 4×GPU (>=32G) | 4×GPU (>=32G) | 4×GPU (>=32G) |
|
432 |
+
| Required GPU memory | 24G | 24G | 24G |
|
433 |
+
|
434 |
+
## 🔥 Retrieval Fine-tuning (LoRA)
|
435 |
+
|
436 |
+
> Note: This section demonstrates how to perform a cost-effective fine-tuning of our model. The hyperparameters shown here are not optimized for any specific task. For practical applications, further adjustments to the hyperparameters may be necessary to achieve optimal performance.
|
437 |
+
|
438 |
+
To fine-tune InternVL using LoRA on Flickr30K with 4 GPUs, run:
|
439 |
+
|
440 |
+
```bash
|
441 |
+
GPUS=4 BATCH_SIZE=32 sh shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh
|
442 |
+
```
|
443 |
+
|
444 |
+
To fine-tune InternVL using LoRA on Flickr30K-CN with 4 GPUs, run:
|
445 |
+
|
446 |
+
```shell
|
447 |
+
GPUS=4 BATCH_SIZE=32 sh shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh
|
448 |
+
```
|
449 |
+
|
450 |
+
To fine-tune InternVL using LoRA on COCO with 4 GPUs, run:
|
451 |
+
|
452 |
+
```shell
|
453 |
+
GPUS=4 BATCH_SIZE=32 shell/lora_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_lora16_4gpu.sh
|
454 |
+
```
|
455 |
+
|
456 |
+
The hyperparameters used here are:
|
457 |
+
|
458 |
+
| config | Flickr30K | Flickr30K-CN | COCO |
|
459 |
+
| ------------------------ | ------------- | ------------- | ------------- |
|
460 |
+
| learning rate | 1e-6 | 1e-6 | 1e-6 |
|
461 |
+
| optimizer | AdamW | AdamW | AdamW |
|
462 |
+
| lora rank | 16 | 16 | 16 |
|
463 |
+
| weight decay | 0.05 | 0.05 | 0.05 |
|
464 |
+
| input resolution | 224x224 | 224x224 | 224x224 |
|
465 |
+
| total batch size | 4x32 | 4x32 | 4x32 |
|
466 |
+
| warm-up iterations | 100 | 100 | 100 |
|
467 |
+
| training epochs | 10 | 10 | 5 |
|
468 |
+
| drop path rate | 0.0 | 0.0 | 0.3 |
|
469 |
+
| numerical precision | zero3 + bf16 | zero3 + bf16 | zero1 + bf16 |
|
470 |
+
| trainable / total params | 0.3B / 14B | 0.3B / 14B | 0.3B / 14B |
|
471 |
+
| GPUs for training | 4×GPU (>=40G) | 4×GPU (>=40G) | 4×GPU (>=40G) |
|
472 |
+
| Required GPU memory | 37G | 37G | 37G |
|
473 |
+
|
474 |
+
## Fine-Tuning a Custom Dataset
|
475 |
+
|
476 |
+
1. **Organize Your Data**: Format your dataset similar to COCO or Flickr30K.
|
477 |
+
|
478 |
+
2. **Update Meta Information**: Add your dataset's meta information to the `ds_collections` dictionary in `internvl_g/internvl/train/internvl_stage2_finetune.py`. For example:
|
479 |
+
|
480 |
+
```python
|
481 |
+
ds_collections = {
|
482 |
+
'my_dataset_flickr_format': {
|
483 |
+
'root': './data/my_dataset/images/',
|
484 |
+
'annotation': './data/my_dataset/annotations.txt',
|
485 |
+
},
|
486 |
+
'my_dataset_coco_format': {
|
487 |
+
'root': './data/my_dataset/',
|
488 |
+
'annotation': './data/my_dataset/annotations.json',
|
489 |
+
},
|
490 |
+
}
|
491 |
+
```
|
492 |
+
|
493 |
+
3. **Name Your Dataset**:
|
494 |
+
|
495 |
+
- Include `flickr_format` or `coco_format` in your dataset's `dataset_name`. This will allow the script to reuse the Flickr30K or COCO dataloader accordingly.
|
496 |
+
|
497 |
+
By following these steps, you can easily fine-tune the InternVL model on your custom dataset using the existing COCO or Flickr30K data loading mechanisms.
|
InternVL/segmentation/dist_test.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
CONFIG=$1
|
4 |
+
CHECKPOINT=$2
|
5 |
+
GPUS=$3
|
6 |
+
PORT=${PORT:-29510}
|
7 |
+
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
|
8 |
+
torchrun --nproc_per_node=$GPUS --master_port=$PORT \
|
9 |
+
$(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
|
InternVL/segmentation/dist_train.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
CONFIG=$1
|
4 |
+
GPUS=$2
|
5 |
+
PORT=${PORT:-29300}
|
6 |
+
|
7 |
+
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
|
8 |
+
torchrun --nproc_per_node=$GPUS --master_port=$PORT \
|
9 |
+
$(dirname "$0")/train.py $CONFIG --launcher pytorch --deterministic ${@:3}
|
InternVL/segmentation/train.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2023 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
import argparse
|
7 |
+
import copy
|
8 |
+
import os
|
9 |
+
import os.path as osp
|
10 |
+
import time
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
import mmcv
|
14 |
+
import mmcv_custom # noqa: F401,F403
|
15 |
+
import mmseg_custom # noqa: F401,F403
|
16 |
+
import torch
|
17 |
+
from mmcv.cnn.utils import revert_sync_batchnorm
|
18 |
+
from mmcv.runner import get_dist_info, init_dist
|
19 |
+
from mmcv.utils import Config, DictAction, get_git_hash
|
20 |
+
from mmseg import __version__
|
21 |
+
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
|
22 |
+
from mmseg.datasets import build_dataset
|
23 |
+
from mmseg.models import build_segmentor
|
24 |
+
from mmseg.utils import collect_env, get_root_logger
|
25 |
+
|
26 |
+
|
27 |
+
def parse_args():
|
28 |
+
parser = argparse.ArgumentParser(description='Train a segmentor')
|
29 |
+
parser.add_argument('config', help='train config file path')
|
30 |
+
parser.add_argument('--work-dir', help='the dir to save logs and models')
|
31 |
+
parser.add_argument(
|
32 |
+
'--load-from', help='the checkpoint file to load weights from')
|
33 |
+
parser.add_argument(
|
34 |
+
'--resume-from', help='the checkpoint file to resume from')
|
35 |
+
parser.add_argument(
|
36 |
+
'--no-validate',
|
37 |
+
action='store_true',
|
38 |
+
help='whether not to evaluate the checkpoint during training')
|
39 |
+
group_gpus = parser.add_mutually_exclusive_group()
|
40 |
+
group_gpus.add_argument(
|
41 |
+
'--gpus',
|
42 |
+
type=int,
|
43 |
+
help='number of gpus to use '
|
44 |
+
'(only applicable to non-distributed training)')
|
45 |
+
group_gpus.add_argument(
|
46 |
+
'--gpu-ids',
|
47 |
+
type=int,
|
48 |
+
nargs='+',
|
49 |
+
help='ids of gpus to use '
|
50 |
+
'(only applicable to non-distributed training)')
|
51 |
+
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
52 |
+
parser.add_argument(
|
53 |
+
'--deterministic',
|
54 |
+
action='store_true',
|
55 |
+
help='whether to set deterministic options for CUDNN backend.')
|
56 |
+
parser.add_argument(
|
57 |
+
'--options',
|
58 |
+
nargs='+',
|
59 |
+
action=DictAction,
|
60 |
+
help="--options is deprecated in favor of --cfg_options' and it will "
|
61 |
+
'not be supported in version v0.22.0. Override some settings in the '
|
62 |
+
'used config, the key-value pair in xxx=yyy format will be merged '
|
63 |
+
'into config file. If the value to be overwritten is a list, it '
|
64 |
+
'should be like key="[a,b]" or key=a,b It also allows nested '
|
65 |
+
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
|
66 |
+
'marks are necessary and that no white space is allowed.')
|
67 |
+
parser.add_argument(
|
68 |
+
'--cfg-options',
|
69 |
+
nargs='+',
|
70 |
+
action=DictAction,
|
71 |
+
help='override some settings in the used config, the key-value pair '
|
72 |
+
'in xxx=yyy format will be merged into config file. If the value to '
|
73 |
+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
74 |
+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
75 |
+
'Note that the quotation marks are necessary and that no white space '
|
76 |
+
'is allowed.')
|
77 |
+
parser.add_argument(
|
78 |
+
'--launcher',
|
79 |
+
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
80 |
+
default='none',
|
81 |
+
help='job launcher')
|
82 |
+
parser.add_argument('--local_rank', type=int, default=0)
|
83 |
+
parser.add_argument(
|
84 |
+
'--auto-resume',
|
85 |
+
action='store_true',
|
86 |
+
help='resume from the latest checkpoint automatically.')
|
87 |
+
args = parser.parse_args()
|
88 |
+
if 'LOCAL_RANK' not in os.environ:
|
89 |
+
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
90 |
+
|
91 |
+
if args.options and args.cfg_options:
|
92 |
+
raise ValueError(
|
93 |
+
'--options and --cfg-options cannot be both '
|
94 |
+
'specified, --options is deprecated in favor of --cfg-options. '
|
95 |
+
'--options will not be supported in version v0.22.0.')
|
96 |
+
if args.options:
|
97 |
+
warnings.warn('--options is deprecated in favor of --cfg-options. '
|
98 |
+
'--options will not be supported in version v0.22.0.')
|
99 |
+
args.cfg_options = args.options
|
100 |
+
|
101 |
+
return args
|
102 |
+
|
103 |
+
|
104 |
+
def main():
|
105 |
+
args = parse_args()
|
106 |
+
|
107 |
+
cfg = Config.fromfile(args.config)
|
108 |
+
if args.cfg_options is not None:
|
109 |
+
cfg.merge_from_dict(args.cfg_options)
|
110 |
+
# set cudnn_benchmark
|
111 |
+
if cfg.get('cudnn_benchmark', False):
|
112 |
+
torch.backends.cudnn.benchmark = True
|
113 |
+
|
114 |
+
# work_dir is determined in this priority: CLI > segment in file > filename
|
115 |
+
if args.work_dir is not None:
|
116 |
+
# update configs according to CLI args if args.work_dir is not None
|
117 |
+
cfg.work_dir = args.work_dir
|
118 |
+
elif cfg.get('work_dir', None) is None:
|
119 |
+
# use config filename as default work_dir if cfg.work_dir is None
|
120 |
+
cfg.work_dir = osp.join('./work_dirs',
|
121 |
+
osp.splitext(osp.basename(args.config))[0])
|
122 |
+
if args.load_from is not None:
|
123 |
+
cfg.load_from = args.load_from
|
124 |
+
if args.resume_from is not None:
|
125 |
+
cfg.resume_from = args.resume_from
|
126 |
+
if args.gpu_ids is not None:
|
127 |
+
cfg.gpu_ids = args.gpu_ids
|
128 |
+
else:
|
129 |
+
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
|
130 |
+
cfg.auto_resume = args.auto_resume
|
131 |
+
|
132 |
+
# init distributed env first, since logger depends on the dist info.
|
133 |
+
if args.launcher == 'none':
|
134 |
+
distributed = False
|
135 |
+
else:
|
136 |
+
distributed = True
|
137 |
+
init_dist(args.launcher, **cfg.dist_params)
|
138 |
+
# gpu_ids is used to calculate iter when resuming checkpoint
|
139 |
+
_, world_size = get_dist_info()
|
140 |
+
cfg.gpu_ids = range(world_size)
|
141 |
+
|
142 |
+
cfg.device = 'cuda' # fix 'ConfigDict' object has no attribute 'device'
|
143 |
+
# create work_dir
|
144 |
+
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
|
145 |
+
# dump config
|
146 |
+
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
|
147 |
+
# init the logger before other steps
|
148 |
+
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
149 |
+
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
|
150 |
+
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
151 |
+
|
152 |
+
# init the meta dict to record some important information such as
|
153 |
+
# environment info and seed, which will be logged
|
154 |
+
meta = dict()
|
155 |
+
# log env info
|
156 |
+
env_info_dict = collect_env()
|
157 |
+
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
|
158 |
+
dash_line = '-' * 60 + '\n'
|
159 |
+
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
|
160 |
+
dash_line)
|
161 |
+
meta['env_info'] = env_info
|
162 |
+
|
163 |
+
# log some basic info
|
164 |
+
logger.info(f'Distributed training: {distributed}')
|
165 |
+
logger.info(f'Config:\n{cfg.pretty_text}')
|
166 |
+
|
167 |
+
# set random seeds
|
168 |
+
seed = init_random_seed(args.seed)
|
169 |
+
logger.info(f'Set random seed to {seed}, '
|
170 |
+
f'deterministic: {args.deterministic}')
|
171 |
+
set_random_seed(seed, deterministic=args.deterministic)
|
172 |
+
cfg.seed = seed
|
173 |
+
meta['seed'] = seed
|
174 |
+
meta['exp_name'] = osp.basename(args.config)
|
175 |
+
|
176 |
+
model = build_segmentor(
|
177 |
+
cfg.model,
|
178 |
+
train_cfg=cfg.get('train_cfg'),
|
179 |
+
test_cfg=cfg.get('test_cfg'))
|
180 |
+
model.init_weights()
|
181 |
+
|
182 |
+
# SyncBN is not support for DP
|
183 |
+
if not distributed:
|
184 |
+
warnings.warn(
|
185 |
+
'SyncBN is only supported with DDP. To be compatible with DP, '
|
186 |
+
'we convert SyncBN to BN. Please use dist_train.sh which can '
|
187 |
+
'avoid this error.')
|
188 |
+
model = revert_sync_batchnorm(model)
|
189 |
+
|
190 |
+
logger.info(model)
|
191 |
+
|
192 |
+
datasets = [build_dataset(cfg.data.train)]
|
193 |
+
if len(cfg.workflow) == 2:
|
194 |
+
val_dataset = copy.deepcopy(cfg.data.val)
|
195 |
+
val_dataset.pipeline = cfg.data.train.pipeline
|
196 |
+
datasets.append(build_dataset(val_dataset))
|
197 |
+
if cfg.checkpoint_config is not None:
|
198 |
+
# save mmseg version, config file content and class names in
|
199 |
+
# checkpoints as meta data
|
200 |
+
cfg.checkpoint_config.meta = dict(
|
201 |
+
mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
|
202 |
+
config=cfg.pretty_text,
|
203 |
+
CLASSES=datasets[0].CLASSES,
|
204 |
+
PALETTE=datasets[0].PALETTE)
|
205 |
+
# add an attribute for visualization convenience
|
206 |
+
model.CLASSES = datasets[0].CLASSES
|
207 |
+
# passing checkpoint meta for saving best checkpoint
|
208 |
+
meta.update(cfg.checkpoint_config.meta)
|
209 |
+
train_segmentor(
|
210 |
+
model,
|
211 |
+
datasets,
|
212 |
+
cfg,
|
213 |
+
distributed=distributed,
|
214 |
+
validate=(not args.no_validate),
|
215 |
+
timestamp=timestamp,
|
216 |
+
meta=meta)
|
217 |
+
|
218 |
+
|
219 |
+
if __name__ == '__main__':
|
220 |
+
main()
|
InternVL/streamlit_demo/constants.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2024 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
8 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
9 |
+
|
10 |
+
LOGDIR = 'logs/'
|
11 |
+
|
12 |
+
# Model Constants
|
13 |
+
IGNORE_INDEX = -100
|
14 |
+
IMAGE_TOKEN_INDEX = -200
|
15 |
+
DEFAULT_IMAGE_TOKEN = '<image>'
|
16 |
+
DEFAULT_IMAGE_PATCH_TOKEN = '<IMG_CONTEXT>'
|
17 |
+
DEFAULT_IM_START_TOKEN = '<img>'
|
18 |
+
DEFAULT_IM_END_TOKEN = '</img>'
|
19 |
+
IMAGE_PLACEHOLDER = '<image-placeholder>'
|
20 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
21 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
22 |
+
|
23 |
+
server_error_msg = '**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**'
|
InternVL/streamlit_demo/controller.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import dataclasses
|
7 |
+
import json
|
8 |
+
import re
|
9 |
+
import threading
|
10 |
+
import time
|
11 |
+
from enum import Enum, auto
|
12 |
+
from typing import List
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import requests
|
16 |
+
import uvicorn
|
17 |
+
from fastapi import FastAPI, Request
|
18 |
+
from fastapi.responses import StreamingResponse
|
19 |
+
from utils import build_logger, server_error_msg
|
20 |
+
|
21 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
22 |
+
logger = build_logger('controller', 'controller.log')
|
23 |
+
|
24 |
+
|
25 |
+
class DispatchMethod(Enum):
|
26 |
+
LOTTERY = auto()
|
27 |
+
SHORTEST_QUEUE = auto()
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_str(cls, name):
|
31 |
+
if name == 'lottery':
|
32 |
+
return cls.LOTTERY
|
33 |
+
elif name == 'shortest_queue':
|
34 |
+
return cls.SHORTEST_QUEUE
|
35 |
+
else:
|
36 |
+
raise ValueError(f'Invalid dispatch method')
|
37 |
+
|
38 |
+
|
39 |
+
@dataclasses.dataclass
|
40 |
+
class WorkerInfo:
|
41 |
+
model_names: List[str]
|
42 |
+
speed: int
|
43 |
+
queue_length: int
|
44 |
+
check_heart_beat: bool
|
45 |
+
last_heart_beat: str
|
46 |
+
|
47 |
+
|
48 |
+
def heart_beat_controller(controller):
|
49 |
+
while True:
|
50 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
51 |
+
controller.remove_stable_workers_by_expiration()
|
52 |
+
|
53 |
+
|
54 |
+
class Controller:
|
55 |
+
def __init__(self, dispatch_method: str):
|
56 |
+
# Dict[str -> WorkerInfo]
|
57 |
+
self.worker_info = {}
|
58 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
59 |
+
|
60 |
+
self.heart_beat_thread = threading.Thread(
|
61 |
+
target=heart_beat_controller, args=(self,))
|
62 |
+
self.heart_beat_thread.start()
|
63 |
+
|
64 |
+
logger.info('Init controller')
|
65 |
+
|
66 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
67 |
+
worker_status: dict):
|
68 |
+
if worker_name not in self.worker_info:
|
69 |
+
logger.info(f'Register a new worker: {worker_name}')
|
70 |
+
else:
|
71 |
+
logger.info(f'Register an existing worker: {worker_name}')
|
72 |
+
|
73 |
+
if not worker_status:
|
74 |
+
worker_status = self.get_worker_status(worker_name)
|
75 |
+
if not worker_status:
|
76 |
+
return False
|
77 |
+
|
78 |
+
self.worker_info[worker_name] = WorkerInfo(
|
79 |
+
worker_status['model_names'], worker_status['speed'], worker_status['queue_length'],
|
80 |
+
check_heart_beat, time.time())
|
81 |
+
|
82 |
+
logger.info(f'Register done: {worker_name}, {worker_status}')
|
83 |
+
return True
|
84 |
+
|
85 |
+
def get_worker_status(self, worker_name: str):
|
86 |
+
try:
|
87 |
+
r = requests.post(worker_name + '/worker_get_status', timeout=5)
|
88 |
+
except requests.exceptions.RequestException as e:
|
89 |
+
logger.error(f'Get status fails: {worker_name}, {e}')
|
90 |
+
return None
|
91 |
+
|
92 |
+
if r.status_code != 200:
|
93 |
+
logger.error(f'Get status fails: {worker_name}, {r}')
|
94 |
+
return None
|
95 |
+
|
96 |
+
return r.json()
|
97 |
+
|
98 |
+
def remove_worker(self, worker_name: str):
|
99 |
+
del self.worker_info[worker_name]
|
100 |
+
|
101 |
+
def refresh_all_workers(self):
|
102 |
+
old_info = dict(self.worker_info)
|
103 |
+
self.worker_info = {}
|
104 |
+
|
105 |
+
for w_name, w_info in old_info.items():
|
106 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
107 |
+
logger.info(f'Remove stale worker: {w_name}')
|
108 |
+
|
109 |
+
def list_models(self):
|
110 |
+
model_names = set()
|
111 |
+
|
112 |
+
for w_name, w_info in self.worker_info.items():
|
113 |
+
model_names.update(w_info.model_names)
|
114 |
+
|
115 |
+
def extract_key(s):
|
116 |
+
if 'Pro' in s:
|
117 |
+
return 999
|
118 |
+
match = re.match(r'InternVL2-(\d+)B', s)
|
119 |
+
if match:
|
120 |
+
return int(match.group(1))
|
121 |
+
return -1
|
122 |
+
|
123 |
+
def custom_sort_key(s):
|
124 |
+
key = extract_key(s)
|
125 |
+
# Return a tuple where -1 will ensure that non-matching items come last
|
126 |
+
return (0 if key != -1 else 1, -key if key != -1 else s)
|
127 |
+
|
128 |
+
sorted_list = sorted(list(model_names), key=custom_sort_key)
|
129 |
+
return sorted_list
|
130 |
+
|
131 |
+
def get_worker_address(self, model_name: str):
|
132 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
133 |
+
worker_names = []
|
134 |
+
worker_speeds = []
|
135 |
+
for w_name, w_info in self.worker_info.items():
|
136 |
+
if model_name in w_info.model_names:
|
137 |
+
worker_names.append(w_name)
|
138 |
+
worker_speeds.append(w_info.speed)
|
139 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
140 |
+
norm = np.sum(worker_speeds)
|
141 |
+
if norm < 1e-4:
|
142 |
+
return ''
|
143 |
+
worker_speeds = worker_speeds / norm
|
144 |
+
if True: # Directly return address
|
145 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
146 |
+
p=worker_speeds)
|
147 |
+
worker_name = worker_names[pt]
|
148 |
+
return worker_name
|
149 |
+
|
150 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
151 |
+
worker_names = []
|
152 |
+
worker_qlen = []
|
153 |
+
for w_name, w_info in self.worker_info.items():
|
154 |
+
if model_name in w_info.model_names:
|
155 |
+
worker_names.append(w_name)
|
156 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
157 |
+
if len(worker_names) == 0:
|
158 |
+
return ''
|
159 |
+
min_index = np.argmin(worker_qlen)
|
160 |
+
w_name = worker_names[min_index]
|
161 |
+
self.worker_info[w_name].queue_length += 1
|
162 |
+
logger.info(f'names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}')
|
163 |
+
return w_name
|
164 |
+
else:
|
165 |
+
raise ValueError(f'Invalid dispatch method: {self.dispatch_method}')
|
166 |
+
|
167 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
168 |
+
if worker_name not in self.worker_info:
|
169 |
+
logger.info(f'Receive unknown heart beat. {worker_name}')
|
170 |
+
return False
|
171 |
+
|
172 |
+
self.worker_info[worker_name].queue_length = queue_length
|
173 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
174 |
+
logger.info(f'Receive heart beat. {worker_name}')
|
175 |
+
return True
|
176 |
+
|
177 |
+
def remove_stable_workers_by_expiration(self):
|
178 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
179 |
+
to_delete = []
|
180 |
+
for worker_name, w_info in self.worker_info.items():
|
181 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
182 |
+
to_delete.append(worker_name)
|
183 |
+
|
184 |
+
for worker_name in to_delete:
|
185 |
+
self.remove_worker(worker_name)
|
186 |
+
|
187 |
+
def worker_api_generate_stream(self, params):
|
188 |
+
worker_addr = self.get_worker_address(params['model'])
|
189 |
+
if not worker_addr:
|
190 |
+
logger.info(f"no worker: {params['model']}")
|
191 |
+
ret = {
|
192 |
+
'text': server_error_msg,
|
193 |
+
'error_code': 2,
|
194 |
+
}
|
195 |
+
yield json.dumps(ret).encode() + b'\0'
|
196 |
+
|
197 |
+
try:
|
198 |
+
response = requests.post(worker_addr + '/worker_generate_stream',
|
199 |
+
json=params, stream=True, timeout=5)
|
200 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\0'):
|
201 |
+
if chunk:
|
202 |
+
yield chunk + b'\0'
|
203 |
+
except requests.exceptions.RequestException as e:
|
204 |
+
logger.info(f'worker timeout: {worker_addr}')
|
205 |
+
ret = {
|
206 |
+
'text': server_error_msg,
|
207 |
+
'error_code': 3,
|
208 |
+
}
|
209 |
+
yield json.dumps(ret).encode() + b'\0'
|
210 |
+
|
211 |
+
# Let the controller act as a worker to achieve hierarchical
|
212 |
+
# management. This can be used to connect isolated sub networks.
|
213 |
+
def worker_api_get_status(self):
|
214 |
+
model_names = set()
|
215 |
+
speed = 0
|
216 |
+
queue_length = 0
|
217 |
+
|
218 |
+
for w_name in self.worker_info:
|
219 |
+
worker_status = self.get_worker_status(w_name)
|
220 |
+
if worker_status is not None:
|
221 |
+
model_names.update(worker_status['model_names'])
|
222 |
+
speed += worker_status['speed']
|
223 |
+
queue_length += worker_status['queue_length']
|
224 |
+
|
225 |
+
return {
|
226 |
+
'model_names': list(model_names),
|
227 |
+
'speed': speed,
|
228 |
+
'queue_length': queue_length,
|
229 |
+
}
|
230 |
+
|
231 |
+
|
232 |
+
app = FastAPI()
|
233 |
+
|
234 |
+
|
235 |
+
@app.post('/register_worker')
|
236 |
+
async def register_worker(request: Request):
|
237 |
+
data = await request.json()
|
238 |
+
controller.register_worker(
|
239 |
+
data['worker_name'], data['check_heart_beat'],
|
240 |
+
data.get('worker_status', None))
|
241 |
+
|
242 |
+
|
243 |
+
@app.post('/refresh_all_workers')
|
244 |
+
async def refresh_all_workers():
|
245 |
+
models = controller.refresh_all_workers()
|
246 |
+
|
247 |
+
|
248 |
+
@app.post('/list_models')
|
249 |
+
async def list_models():
|
250 |
+
models = controller.list_models()
|
251 |
+
return {'models': models}
|
252 |
+
|
253 |
+
|
254 |
+
@app.post('/get_worker_address')
|
255 |
+
async def get_worker_address(request: Request):
|
256 |
+
data = await request.json()
|
257 |
+
addr = controller.get_worker_address(data['model'])
|
258 |
+
return {'address': addr}
|
259 |
+
|
260 |
+
|
261 |
+
@app.post('/receive_heart_beat')
|
262 |
+
async def receive_heart_beat(request: Request):
|
263 |
+
data = await request.json()
|
264 |
+
exist = controller.receive_heart_beat(
|
265 |
+
data['worker_name'], data['queue_length'])
|
266 |
+
return {'exist': exist}
|
267 |
+
|
268 |
+
|
269 |
+
@app.post('/worker_generate_stream')
|
270 |
+
async def worker_api_generate_stream(request: Request):
|
271 |
+
params = await request.json()
|
272 |
+
generator = controller.worker_api_generate_stream(params)
|
273 |
+
return StreamingResponse(generator)
|
274 |
+
|
275 |
+
|
276 |
+
@app.post('/worker_get_status')
|
277 |
+
async def worker_api_get_status(request: Request):
|
278 |
+
return controller.worker_api_get_status()
|
279 |
+
|
280 |
+
|
281 |
+
if __name__ == '__main__':
|
282 |
+
parser = argparse.ArgumentParser()
|
283 |
+
parser.add_argument('--host', type=str, default='0.0.0.0')
|
284 |
+
parser.add_argument('--port', type=int, default=10075)
|
285 |
+
parser.add_argument('--dispatch-method', type=str, choices=[
|
286 |
+
'lottery', 'shortest_queue'], default='shortest_queue')
|
287 |
+
args = parser.parse_args()
|
288 |
+
logger.info(f'args: {args}')
|
289 |
+
|
290 |
+
controller = Controller(args.dispatch_method)
|
291 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level='info')
|
InternVL/streamlit_demo/model_worker.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# InternVL
|
3 |
+
# Copyright (c) 2024 OpenGVLab
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
"""
|
8 |
+
A model worker executes the model.
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
import asyncio
|
12 |
+
import base64
|
13 |
+
import json
|
14 |
+
import math
|
15 |
+
import threading
|
16 |
+
import time
|
17 |
+
import uuid
|
18 |
+
from functools import partial
|
19 |
+
from io import BytesIO
|
20 |
+
from threading import Thread
|
21 |
+
|
22 |
+
import requests
|
23 |
+
import torch
|
24 |
+
import torchvision.transforms as T
|
25 |
+
import uvicorn
|
26 |
+
from constants import IMAGENET_MEAN, IMAGENET_STD, WORKER_HEART_BEAT_INTERVAL
|
27 |
+
from fastapi import BackgroundTasks, FastAPI, Request
|
28 |
+
from fastapi.responses import StreamingResponse
|
29 |
+
from PIL import Image
|
30 |
+
from torchvision.transforms.functional import InterpolationMode
|
31 |
+
from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
|
32 |
+
from utils import build_logger, pretty_print_semaphore, server_error_msg
|
33 |
+
|
34 |
+
worker_id = str(uuid.uuid4())[:6]
|
35 |
+
logger = build_logger('model_worker', f'model_worker_{worker_id}.log')
|
36 |
+
global_counter = 0
|
37 |
+
model_semaphore = None
|
38 |
+
|
39 |
+
|
40 |
+
def load_image_from_base64(image):
|
41 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
42 |
+
|
43 |
+
|
44 |
+
def build_transform(input_size):
|
45 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
46 |
+
transform = T.Compose([
|
47 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
48 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
49 |
+
T.ToTensor(),
|
50 |
+
T.Normalize(mean=MEAN, std=STD)
|
51 |
+
])
|
52 |
+
return transform
|
53 |
+
|
54 |
+
|
55 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
56 |
+
best_ratio_diff = float('inf')
|
57 |
+
best_ratio = (1, 1)
|
58 |
+
area = width * height
|
59 |
+
for ratio in target_ratios:
|
60 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
61 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
62 |
+
if ratio_diff < best_ratio_diff:
|
63 |
+
best_ratio_diff = ratio_diff
|
64 |
+
best_ratio = ratio
|
65 |
+
elif ratio_diff == best_ratio_diff:
|
66 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
67 |
+
best_ratio = ratio
|
68 |
+
return best_ratio
|
69 |
+
|
70 |
+
|
71 |
+
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
|
72 |
+
orig_width, orig_height = image.size
|
73 |
+
aspect_ratio = orig_width / orig_height
|
74 |
+
|
75 |
+
# calculate the existing image aspect ratio
|
76 |
+
target_ratios = set(
|
77 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
78 |
+
i * j <= max_num and i * j >= min_num)
|
79 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
80 |
+
|
81 |
+
# find the closest aspect ratio to the target
|
82 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
83 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
84 |
+
|
85 |
+
# calculate the target width and height
|
86 |
+
target_width = image_size * target_aspect_ratio[0]
|
87 |
+
target_height = image_size * target_aspect_ratio[1]
|
88 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
89 |
+
|
90 |
+
# resize the image
|
91 |
+
resized_img = image.resize((target_width, target_height))
|
92 |
+
processed_images = []
|
93 |
+
for i in range(blocks):
|
94 |
+
box = (
|
95 |
+
(i % (target_width // image_size)) * image_size,
|
96 |
+
(i // (target_width // image_size)) * image_size,
|
97 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
98 |
+
((i // (target_width // image_size)) + 1) * image_size
|
99 |
+
)
|
100 |
+
# split the image
|
101 |
+
split_img = resized_img.crop(box)
|
102 |
+
processed_images.append(split_img)
|
103 |
+
assert len(processed_images) == blocks
|
104 |
+
if use_thumbnail and len(processed_images) != 1:
|
105 |
+
thumbnail_img = image.resize((image_size, image_size))
|
106 |
+
processed_images.append(thumbnail_img)
|
107 |
+
return processed_images
|
108 |
+
|
109 |
+
|
110 |
+
def heart_beat_worker(controller):
|
111 |
+
while True:
|
112 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
113 |
+
controller.send_heart_beat()
|
114 |
+
|
115 |
+
|
116 |
+
def split_model(model_name, vit_alpha=0.5):
|
117 |
+
device_map = {}
|
118 |
+
world_size = torch.cuda.device_count()
|
119 |
+
num_layers = {
|
120 |
+
'InternVL-Chat-V1-1': 40, 'InternVL-Chat-V1-2': 60, 'InternVL-Chat-V1-2-Plus': 60,
|
121 |
+
'Mini-InternVL-2B-V1-5': 24, 'Mini-InternVL-4B-V1-5': 32, 'InternVL-Chat-V1-5': 48,
|
122 |
+
'InternVL2-8B': 32, 'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80,
|
123 |
+
'InternVL2-78B': 80, 'InternVL2-Pro': 80}[model_name]
|
124 |
+
# Since the first GPU will be used for ViT, treat it as half a GPU.
|
125 |
+
num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
|
126 |
+
num_layers_per_gpu = [num_layers_per_gpu] * world_size
|
127 |
+
num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha))
|
128 |
+
layer_cnt = 0
|
129 |
+
for i, num_layer in enumerate(num_layers_per_gpu):
|
130 |
+
for j in range(num_layer):
|
131 |
+
device_map[f'language_model.model.layers.{layer_cnt}'] = i
|
132 |
+
layer_cnt += 1
|
133 |
+
device_map['vision_model'] = 0
|
134 |
+
device_map['mlp1'] = 0
|
135 |
+
device_map['language_model.model.tok_embeddings'] = 0
|
136 |
+
device_map['language_model.model.embed_tokens'] = 0
|
137 |
+
device_map['language_model.output'] = 0
|
138 |
+
device_map['language_model.model.norm'] = 0
|
139 |
+
device_map['language_model.lm_head'] = 0
|
140 |
+
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
|
141 |
+
|
142 |
+
return device_map
|
143 |
+
|
144 |
+
|
145 |
+
class ModelWorker:
|
146 |
+
def __init__(self, controller_addr, worker_addr, worker_id, model_path, model_name,
|
147 |
+
load_8bit, device, context_len=8192):
|
148 |
+
self.controller_addr = controller_addr
|
149 |
+
self.worker_addr = worker_addr
|
150 |
+
self.worker_id = worker_id
|
151 |
+
if model_path.endswith('/'):
|
152 |
+
model_path = model_path[:-1]
|
153 |
+
if model_name is None:
|
154 |
+
model_paths = model_path.split('/')
|
155 |
+
if model_paths[-1].startswith('checkpoint-'):
|
156 |
+
self.model_name = model_paths[-2] + '_' + model_paths[-1]
|
157 |
+
else:
|
158 |
+
self.model_name = model_paths[-1]
|
159 |
+
else:
|
160 |
+
self.model_name = model_name
|
161 |
+
|
162 |
+
logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...')
|
163 |
+
|
164 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
|
165 |
+
tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
|
166 |
+
tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
|
167 |
+
self.tokenizer = tokenizer
|
168 |
+
|
169 |
+
if device == 'auto':
|
170 |
+
device_map = split_model(self.model_name)
|
171 |
+
self.model = AutoModel.from_pretrained(
|
172 |
+
model_path,
|
173 |
+
load_in_8bit=load_8bit,
|
174 |
+
torch_dtype=torch.bfloat16,
|
175 |
+
device_map=device_map,
|
176 |
+
trust_remote_code=True).eval()
|
177 |
+
else:
|
178 |
+
self.model = AutoModel.from_pretrained(
|
179 |
+
model_path,
|
180 |
+
load_in_8bit=load_8bit,
|
181 |
+
torch_dtype=torch.bfloat16,
|
182 |
+
trust_remote_code=True).eval()
|
183 |
+
if not load_8bit and not device == 'auto':
|
184 |
+
self.model = self.model.cuda()
|
185 |
+
self.load_8bit = load_8bit
|
186 |
+
self.device = device
|
187 |
+
self.model_path = model_path
|
188 |
+
self.image_size = self.model.config.force_image_size
|
189 |
+
self.context_len = context_len
|
190 |
+
self.register_to_controller()
|
191 |
+
self.heart_beat_thread = threading.Thread(
|
192 |
+
target=heart_beat_worker, args=(self,))
|
193 |
+
self.heart_beat_thread.start()
|
194 |
+
|
195 |
+
def reload_model(self):
|
196 |
+
del self.model
|
197 |
+
torch.cuda.empty_cache()
|
198 |
+
if self.device == 'auto':
|
199 |
+
device_map = split_model(self.model_name)
|
200 |
+
self.model = AutoModel.from_pretrained(
|
201 |
+
self.model_path,
|
202 |
+
load_in_8bit=self.load_8bit,
|
203 |
+
torch_dtype=torch.bfloat16,
|
204 |
+
device_map=device_map,
|
205 |
+
trust_remote_code=True).eval()
|
206 |
+
else:
|
207 |
+
self.model = AutoModel.from_pretrained(
|
208 |
+
self.model_path,
|
209 |
+
load_in_8bit=self.load_8bit,
|
210 |
+
torch_dtype=torch.bfloat16,
|
211 |
+
trust_remote_code=True).eval()
|
212 |
+
if not self.load_8bit and not self.device == 'auto':
|
213 |
+
self.model = self.model.cuda()
|
214 |
+
|
215 |
+
def register_to_controller(self):
|
216 |
+
logger.info('Register to controller')
|
217 |
+
|
218 |
+
url = self.controller_addr + '/register_worker'
|
219 |
+
data = {
|
220 |
+
'worker_name': self.worker_addr,
|
221 |
+
'check_heart_beat': True,
|
222 |
+
'worker_status': self.get_status()
|
223 |
+
}
|
224 |
+
r = requests.post(url, json=data)
|
225 |
+
assert r.status_code == 200
|
226 |
+
|
227 |
+
def send_heart_beat(self):
|
228 |
+
logger.info(f'Send heart beat. Models: {[self.model_name]}. '
|
229 |
+
f'Semaphore: {pretty_print_semaphore(model_semaphore)}. '
|
230 |
+
f'global_counter: {global_counter}')
|
231 |
+
|
232 |
+
url = self.controller_addr + '/receive_heart_beat'
|
233 |
+
|
234 |
+
while True:
|
235 |
+
try:
|
236 |
+
ret = requests.post(url, json={
|
237 |
+
'worker_name': self.worker_addr,
|
238 |
+
'queue_length': self.get_queue_length()}, timeout=5)
|
239 |
+
exist = ret.json()['exist']
|
240 |
+
break
|
241 |
+
except requests.exceptions.RequestException as e:
|
242 |
+
logger.error(f'heart beat error: {e}')
|
243 |
+
time.sleep(5)
|
244 |
+
|
245 |
+
if not exist:
|
246 |
+
self.register_to_controller()
|
247 |
+
|
248 |
+
def get_queue_length(self):
|
249 |
+
if model_semaphore is None:
|
250 |
+
return 0
|
251 |
+
else:
|
252 |
+
return args.limit_model_concurrency - model_semaphore._value + (len(
|
253 |
+
model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
|
254 |
+
|
255 |
+
def get_status(self):
|
256 |
+
return {
|
257 |
+
'model_names': [self.model_name],
|
258 |
+
'speed': 1,
|
259 |
+
'queue_length': self.get_queue_length(),
|
260 |
+
}
|
261 |
+
|
262 |
+
@torch.inference_mode()
|
263 |
+
def generate_stream(self, params):
|
264 |
+
system_message = params['prompt'][0]['content']
|
265 |
+
send_messages = params['prompt'][1:]
|
266 |
+
max_input_tiles = params['max_input_tiles']
|
267 |
+
temperature = params['temperature']
|
268 |
+
top_p = params['top_p']
|
269 |
+
max_new_tokens = params['max_new_tokens']
|
270 |
+
repetition_penalty = params['repetition_penalty']
|
271 |
+
do_sample = True if temperature > 0.0 else False
|
272 |
+
|
273 |
+
global_image_cnt = 0
|
274 |
+
history, pil_images, max_input_tile_list = [], [], []
|
275 |
+
for message in send_messages:
|
276 |
+
if message['role'] == 'user':
|
277 |
+
prefix = ''
|
278 |
+
if 'image' in message:
|
279 |
+
max_input_tile_temp = []
|
280 |
+
for image_str in message['image']:
|
281 |
+
pil_images.append(load_image_from_base64(image_str))
|
282 |
+
prefix += f'Image-{global_image_cnt + 1}: <image>\n'
|
283 |
+
global_image_cnt += 1
|
284 |
+
max_input_tile_temp.append(max(1, max_input_tiles // len(message['image'])))
|
285 |
+
if len(max_input_tile_temp) > 0:
|
286 |
+
max_input_tile_list.append(max_input_tile_temp)
|
287 |
+
content = prefix + message['content']
|
288 |
+
history.append([content, ])
|
289 |
+
else:
|
290 |
+
history[-1].append(message['content'])
|
291 |
+
question, history = history[-1][0], history[:-1]
|
292 |
+
|
293 |
+
if global_image_cnt == 1:
|
294 |
+
question = question.replace('Image-1: <image>\n', '<image>\n')
|
295 |
+
history = [[item[0].replace('Image-1: <image>\n', '<image>\n'), item[1]] for item in history]
|
296 |
+
|
297 |
+
# Create a new list to store processed sublists
|
298 |
+
flattened_list = []
|
299 |
+
# Iterate through all but the last sublist in max_input_tile_list and process them
|
300 |
+
for sublist in max_input_tile_list[:-1]:
|
301 |
+
processed_sublist = [1] * len(sublist) # Change each element in the sublist to 1
|
302 |
+
flattened_list.extend(processed_sublist) # Flatten the processed sublist and add to the new list
|
303 |
+
# If max_input_tile_list is not empty, add the last sublist to the new list
|
304 |
+
if max_input_tile_list:
|
305 |
+
flattened_list.extend(max_input_tile_list[-1])
|
306 |
+
max_input_tile_list = flattened_list
|
307 |
+
assert len(max_input_tile_list) == len(pil_images), 'The number of max_input_tile_list and pil_images should be the same.'
|
308 |
+
|
309 |
+
old_system_message = self.model.system_message
|
310 |
+
self.model.system_message = system_message
|
311 |
+
image_tiles, num_patches_list = [], []
|
312 |
+
transform = build_transform(input_size=self.image_size)
|
313 |
+
if len(pil_images) > 0:
|
314 |
+
for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
|
315 |
+
if self.model.config.dynamic_image_size:
|
316 |
+
tiles = dynamic_preprocess(
|
317 |
+
pil_image, image_size=self.image_size, max_num=current_max_input_tiles,
|
318 |
+
use_thumbnail=self.model.config.use_thumbnail)
|
319 |
+
else:
|
320 |
+
tiles = [pil_image]
|
321 |
+
num_patches_list.append(len(tiles))
|
322 |
+
image_tiles += tiles
|
323 |
+
pixel_values = [transform(item) for item in image_tiles]
|
324 |
+
pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
|
325 |
+
logger.info(f'Split images to {pixel_values.shape}')
|
326 |
+
else:
|
327 |
+
pixel_values = None
|
328 |
+
|
329 |
+
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
|
330 |
+
generation_config = dict(
|
331 |
+
num_beams=1,
|
332 |
+
max_new_tokens=max_new_tokens,
|
333 |
+
do_sample=do_sample,
|
334 |
+
temperature=temperature,
|
335 |
+
repetition_penalty=repetition_penalty,
|
336 |
+
max_length=self.context_len,
|
337 |
+
top_p=top_p,
|
338 |
+
streamer=streamer,
|
339 |
+
)
|
340 |
+
logger.info(f'Generation config: {generation_config}')
|
341 |
+
|
342 |
+
thread = Thread(target=self.model.chat, kwargs=dict(
|
343 |
+
tokenizer=self.tokenizer,
|
344 |
+
pixel_values=pixel_values,
|
345 |
+
num_patches_list=num_patches_list,
|
346 |
+
question=question,
|
347 |
+
history=history,
|
348 |
+
return_history=False,
|
349 |
+
generation_config=generation_config,
|
350 |
+
))
|
351 |
+
thread.start()
|
352 |
+
|
353 |
+
generated_text = ''
|
354 |
+
for new_text in streamer:
|
355 |
+
generated_text += new_text
|
356 |
+
if generated_text.endswith(self.model.conv_template.sep):
|
357 |
+
generated_text = generated_text[:-len(self.model.conv_template.sep)]
|
358 |
+
yield json.dumps({'text': generated_text, 'error_code': 0}).encode() + b'\0'
|
359 |
+
logger.info(f'max_input_tile_list: {max_input_tile_list}, history: {history}, '
|
360 |
+
f'question: {question}, answer: {generated_text}')
|
361 |
+
self.model.system_message = old_system_message
|
362 |
+
|
363 |
+
def generate_stream_gate(self, params):
|
364 |
+
try:
|
365 |
+
for x in self.generate_stream(params):
|
366 |
+
yield x
|
367 |
+
except ValueError as e:
|
368 |
+
print('Caught ValueError:', e)
|
369 |
+
ret = {
|
370 |
+
'text': server_error_msg,
|
371 |
+
'error_code': 1,
|
372 |
+
}
|
373 |
+
yield json.dumps(ret).encode() + b'\0'
|
374 |
+
except torch.cuda.CudaError as e:
|
375 |
+
print('Caught torch.cuda.CudaError:', e)
|
376 |
+
ret = {
|
377 |
+
'text': server_error_msg,
|
378 |
+
'error_code': 1,
|
379 |
+
}
|
380 |
+
yield json.dumps(ret).encode() + b'\0'
|
381 |
+
except Exception as e:
|
382 |
+
print('Caught Unknown Error', e)
|
383 |
+
ret = {
|
384 |
+
'text': server_error_msg,
|
385 |
+
'error_code': 1,
|
386 |
+
}
|
387 |
+
yield json.dumps(ret).encode() + b'\0'
|
388 |
+
|
389 |
+
|
390 |
+
app = FastAPI()
|
391 |
+
|
392 |
+
|
393 |
+
def release_model_semaphore(fn=None):
|
394 |
+
model_semaphore.release()
|
395 |
+
if fn is not None:
|
396 |
+
fn()
|
397 |
+
|
398 |
+
|
399 |
+
@app.post('/worker_generate_stream')
|
400 |
+
async def generate_stream(request: Request):
|
401 |
+
global model_semaphore, global_counter
|
402 |
+
global_counter += 1
|
403 |
+
params = await request.json()
|
404 |
+
|
405 |
+
if model_semaphore is None:
|
406 |
+
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
407 |
+
await model_semaphore.acquire()
|
408 |
+
worker.send_heart_beat()
|
409 |
+
generator = worker.generate_stream_gate(params)
|
410 |
+
background_tasks = BackgroundTasks()
|
411 |
+
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
412 |
+
return StreamingResponse(generator, background=background_tasks)
|
413 |
+
|
414 |
+
|
415 |
+
@app.post('/worker_get_status')
|
416 |
+
async def get_status(request: Request):
|
417 |
+
return worker.get_status()
|
418 |
+
|
419 |
+
|
420 |
+
if __name__ == '__main__':
|
421 |
+
parser = argparse.ArgumentParser()
|
422 |
+
parser.add_argument('--host', type=str, default='0.0.0.0')
|
423 |
+
parser.add_argument('--port', type=int, default=21002)
|
424 |
+
parser.add_argument('--worker-address', type=str, default='http://localhost:21002')
|
425 |
+
parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
|
426 |
+
parser.add_argument('--model-path', type=str, default='facebook/opt-350m')
|
427 |
+
parser.add_argument('--model-name', type=str)
|
428 |
+
parser.add_argument('--device', type=str, default='cuda')
|
429 |
+
parser.add_argument('--limit-model-concurrency', type=int, default=5)
|
430 |
+
parser.add_argument('--stream-interval', type=int, default=1)
|
431 |
+
parser.add_argument('--load-8bit', action='store_true')
|
432 |
+
args = parser.parse_args()
|
433 |
+
logger.info(f'args: {args}')
|
434 |
+
|
435 |
+
worker = ModelWorker(args.controller_address,
|
436 |
+
args.worker_address,
|
437 |
+
worker_id,
|
438 |
+
args.model_path,
|
439 |
+
args.model_name,
|
440 |
+
args.load_8bit,
|
441 |
+
args.device)
|
442 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level='info')
|
InternVL/video_retrieval/test_msrvtt.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
|
7 |
+
import decord
|
8 |
+
import mmengine
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import tqdm
|
12 |
+
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor
|
13 |
+
|
14 |
+
|
15 |
+
def recall_at_k(scores, positive_pairs, k):
|
16 |
+
"""
|
17 |
+
Compute the recall at k for each sample
|
18 |
+
:param scores: compability score between text and image embeddings (nb texts, nb images)
|
19 |
+
:param k: number of images to consider per text, for retrieval
|
20 |
+
:param positive_pairs: boolean matrix of positive pairs (nb texts, nb images)
|
21 |
+
:return: recall at k averaged over all texts
|
22 |
+
"""
|
23 |
+
nb_texts, nb_images = scores.shape
|
24 |
+
# for each text, sort according to image scores in decreasing order
|
25 |
+
topk_indices = torch.topk(scores, k, dim=1)[1]
|
26 |
+
# compute number of positives for each text
|
27 |
+
nb_positive = positive_pairs.sum(dim=1)
|
28 |
+
# nb_texts, k, nb_images
|
29 |
+
topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
|
30 |
+
# compute number of true positives
|
31 |
+
positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
|
32 |
+
# a true positive means a positive among the topk
|
33 |
+
nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
|
34 |
+
# compute recall at k
|
35 |
+
recall_at_k = (nb_true_positive / nb_positive)
|
36 |
+
return recall_at_k
|
37 |
+
|
38 |
+
|
39 |
+
def batchify(func, X, Y, batch_size, device, *args, **kwargs):
|
40 |
+
results = []
|
41 |
+
for start in range(0, len(X), batch_size):
|
42 |
+
end = start + batch_size
|
43 |
+
x = X[start:end].to(device)
|
44 |
+
y = Y[start:end].to(device)
|
45 |
+
result = func(x, y, *args, **kwargs).cpu()
|
46 |
+
results.append(result)
|
47 |
+
return torch.cat(results)
|
48 |
+
|
49 |
+
|
50 |
+
def validate_msrvtt(model, tokenizer, image_processor, root, metadata,
|
51 |
+
num_frames=1, prefix='summarize:', mode='InternVL-G', recall_k_list=[1, 5, 10],
|
52 |
+
use_dsl=True, eval_batch_size=32):
|
53 |
+
metadata = json.load(open(metadata))
|
54 |
+
|
55 |
+
video_features = []
|
56 |
+
text_features = []
|
57 |
+
|
58 |
+
# compute text features
|
59 |
+
print('Computing text features', flush=True)
|
60 |
+
for data in tqdm.tqdm(metadata):
|
61 |
+
caption = prefix + data['caption']
|
62 |
+
input_ids = tokenizer(caption, return_tensors='pt', max_length=80,
|
63 |
+
truncation=True, padding='max_length').input_ids.cuda()
|
64 |
+
with torch.no_grad():
|
65 |
+
feat = model.encode_text(input_ids)
|
66 |
+
text_features.append(feat.cpu())
|
67 |
+
text_features = torch.cat(text_features)
|
68 |
+
|
69 |
+
# compute video features
|
70 |
+
print('Computing video features', flush=True)
|
71 |
+
for data in tqdm.tqdm(metadata):
|
72 |
+
video_id = data['video']
|
73 |
+
video_path = os.path.join(root, video_id)
|
74 |
+
video_data = mmengine.get(video_path)
|
75 |
+
video_data = io.BytesIO(video_data)
|
76 |
+
video_reader = decord.VideoReader(video_data)
|
77 |
+
|
78 |
+
# uniformly sample frames
|
79 |
+
interval = math.ceil(len(video_reader) / num_frames)
|
80 |
+
frames_id = np.arange(0, len(video_reader), interval) + interval // 2
|
81 |
+
assert len(frames_id) == num_frames and frames_id[-1] < len(video_reader)
|
82 |
+
|
83 |
+
frames = video_reader.get_batch(frames_id).asnumpy()
|
84 |
+
|
85 |
+
pixel_values = image_processor(images=frames, return_tensors='pt').pixel_values
|
86 |
+
with torch.no_grad():
|
87 |
+
pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
88 |
+
feat = model.encode_image(pixel_values, mode=mode)
|
89 |
+
feat = feat.mean(dim=0, keepdim=True)
|
90 |
+
video_features.append(feat.cpu())
|
91 |
+
video_features = torch.cat(video_features)
|
92 |
+
|
93 |
+
print('Computing metrics', flush=True)
|
94 |
+
texts_emb = text_features / text_features.norm(dim=-1, keepdim=True)
|
95 |
+
images_emb = video_features / video_features.norm(dim=-1, keepdim=True)
|
96 |
+
|
97 |
+
# get the score for each text and image pair
|
98 |
+
scores = texts_emb @ images_emb.t()
|
99 |
+
|
100 |
+
# construct a the positive pair matrix, which tells whether each text-image pair is a positive or not
|
101 |
+
positive_pairs = torch.zeros_like(scores, dtype=bool)
|
102 |
+
positive_pairs[torch.arange(len(scores)), torch.arange(len(scores))] = True
|
103 |
+
|
104 |
+
scores_T = scores.T
|
105 |
+
positive_pairs_T = positive_pairs.T
|
106 |
+
|
107 |
+
if use_dsl:
|
108 |
+
scores = scores * scores.softmax(dim=0)
|
109 |
+
scores_T = scores_T * scores_T.softmax(dim=0)
|
110 |
+
|
111 |
+
metrics = {}
|
112 |
+
for recall_k in recall_k_list:
|
113 |
+
# Note that recall_at_k computes **actual** recall i.e. nb_true_positive/nb_positives, where the number
|
114 |
+
# of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts matching that image among the top-k.
|
115 |
+
# Also, the number of positives are the total number of texts matching the image in the dataset, as we have a set of captions
|
116 |
+
# for each image, that number will be greater than 1 for text retrieval.
|
117 |
+
# However, image/text retrieval recall@k, the way it is done in CLIP-like papers, is a bit different.
|
118 |
+
# recall@k, in CLIP-like papers, is, for each image, either 1 or 0. It is 1 if atleast one text matches the image among the top-k.
|
119 |
+
# so we can easily compute that using the actual recall, by checking whether there is at least one true positive,
|
120 |
+
# which would be the case if the recall is greater than 0. One we compute the recal for each image (or text), we average
|
121 |
+
# it over the dataset.
|
122 |
+
metrics[f't2v_retrieval_recall@{recall_k}'] = (
|
123 |
+
batchify(recall_at_k, scores, positive_pairs, eval_batch_size, scores.device,
|
124 |
+
k=recall_k) > 0).float().mean().item()
|
125 |
+
metrics[f'v2t_retrieval_recall@{recall_k}'] = (
|
126 |
+
batchify(recall_at_k, scores_T, positive_pairs_T, eval_batch_size, scores.device,
|
127 |
+
k=recall_k) > 0).float().mean().item()
|
128 |
+
|
129 |
+
print(metrics)
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
parser = argparse.ArgumentParser(description='validate MSR-VTT', add_help=False)
|
134 |
+
parser.add_argument('--video-root', type=str)
|
135 |
+
parser.add_argument('--metadata', type=str)
|
136 |
+
parser.add_argument('--mode', type=str, default='InternVL-C',choices=['InternVL-C', 'InternVL-G'])
|
137 |
+
parser.add_argument('--num-frames', type=int, default=1)
|
138 |
+
args = parser.parse_args()
|
139 |
+
|
140 |
+
model = AutoModel.from_pretrained(
|
141 |
+
'OpenGVLab/InternVL-14B-224px',
|
142 |
+
torch_dtype=torch.bfloat16,
|
143 |
+
low_cpu_mem_usage=True,
|
144 |
+
trust_remote_code=True).cuda().eval()
|
145 |
+
|
146 |
+
image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternVL-14B-224px')
|
147 |
+
|
148 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
149 |
+
'OpenGVLab/InternVL-14B-224px', use_fast=False, add_eos_token=True)
|
150 |
+
tokenizer.pad_token_id = 0 # set pad_token_id to 0
|
151 |
+
|
152 |
+
metrics = validate_msrvtt(model, tokenizer, image_processor,
|
153 |
+
root=args.video_root,
|
154 |
+
metadata=args.metadata,
|
155 |
+
mode=args.mode,
|
156 |
+
num_frames=args.num_frames,)
|
sglang/examples/frontend_language/quick_start/gemini_example_chat.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
export GCP_PROJECT_ID=******
|
4 |
+
python3 gemini_example_chat.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
|
9 |
+
|
10 |
+
@sgl.function
|
11 |
+
def multi_turn_question(s, question_1, question_2):
|
12 |
+
s += sgl.user(question_1)
|
13 |
+
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
|
14 |
+
s += sgl.user(question_2)
|
15 |
+
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
|
16 |
+
|
17 |
+
|
18 |
+
def single():
|
19 |
+
state = multi_turn_question.run(
|
20 |
+
question_1="What is the capital of the United States?",
|
21 |
+
question_2="List two local attractions.",
|
22 |
+
)
|
23 |
+
|
24 |
+
for m in state.messages():
|
25 |
+
print(m["role"], ":", m["content"])
|
26 |
+
|
27 |
+
print("\n-- answer_1 --\n", state["answer_1"])
|
28 |
+
|
29 |
+
|
30 |
+
def stream():
|
31 |
+
state = multi_turn_question.run(
|
32 |
+
question_1="What is the capital of the United States?",
|
33 |
+
question_2="List two local attractions.",
|
34 |
+
stream=True,
|
35 |
+
)
|
36 |
+
|
37 |
+
for out in state.text_iter():
|
38 |
+
print(out, end="", flush=True)
|
39 |
+
print()
|
40 |
+
|
41 |
+
|
42 |
+
def batch():
|
43 |
+
states = multi_turn_question.run_batch(
|
44 |
+
[
|
45 |
+
{
|
46 |
+
"question_1": "What is the capital of the United States?",
|
47 |
+
"question_2": "List two local attractions.",
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"question_1": "What is the capital of France?",
|
51 |
+
"question_2": "What is the population of this city?",
|
52 |
+
},
|
53 |
+
]
|
54 |
+
)
|
55 |
+
|
56 |
+
for s in states:
|
57 |
+
print(s.messages())
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
sgl.set_default_backend(sgl.VertexAI("gemini-pro"))
|
62 |
+
|
63 |
+
# Run a single request
|
64 |
+
print("\n========== single ==========\n")
|
65 |
+
single()
|
66 |
+
|
67 |
+
# Stream output
|
68 |
+
print("\n========== stream ==========\n")
|
69 |
+
stream()
|
70 |
+
|
71 |
+
# Run a batch of requests
|
72 |
+
print("\n========== batch ==========\n")
|
73 |
+
batch()
|
sglang/examples/frontend_language/quick_start/gemini_example_multimodal_chat.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
export GCP_PROJECT_ID=******
|
4 |
+
python3 gemini_example_multimodal_chat.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
|
9 |
+
|
10 |
+
@sgl.function
|
11 |
+
def image_qa(s, image_file1, image_file2, question):
|
12 |
+
s += sgl.user(sgl.image(image_file1) + sgl.image(image_file2) + question)
|
13 |
+
s += sgl.assistant(sgl.gen("answer", max_tokens=256))
|
14 |
+
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
|
18 |
+
|
19 |
+
state = image_qa.run(
|
20 |
+
image_file1="./images/cat.jpeg",
|
21 |
+
image_file2="./images/dog.jpeg",
|
22 |
+
question="Describe difference of the two images in one sentence.",
|
23 |
+
stream=True,
|
24 |
+
)
|
25 |
+
|
26 |
+
for out in state.text_iter("answer"):
|
27 |
+
print(out, end="", flush=True)
|
28 |
+
print()
|
29 |
+
|
30 |
+
print(state["answer"])
|
sglang/examples/frontend_language/quick_start/local_example_complete.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 local_example_complete.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import sglang as sgl
|
7 |
+
|
8 |
+
|
9 |
+
@sgl.function
|
10 |
+
def few_shot_qa(s, question):
|
11 |
+
s += """The following are questions with answers.
|
12 |
+
Q: What is the capital of France?
|
13 |
+
A: Paris
|
14 |
+
Q: What is the capital of Germany?
|
15 |
+
A: Berlin
|
16 |
+
Q: What is the capital of Italy?
|
17 |
+
A: Rome
|
18 |
+
"""
|
19 |
+
s += "Q: " + question + "\n"
|
20 |
+
s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
|
21 |
+
|
22 |
+
|
23 |
+
def single():
|
24 |
+
state = few_shot_qa.run(question="What is the capital of the United States?")
|
25 |
+
answer = state["answer"].strip().lower()
|
26 |
+
|
27 |
+
assert "washington" in answer, f"answer: {state['answer']}"
|
28 |
+
|
29 |
+
print(state.text())
|
30 |
+
|
31 |
+
|
32 |
+
def stream():
|
33 |
+
state = few_shot_qa.run(
|
34 |
+
question="What is the capital of the United States?", stream=True
|
35 |
+
)
|
36 |
+
|
37 |
+
for out in state.text_iter("answer"):
|
38 |
+
print(out, end="", flush=True)
|
39 |
+
print()
|
40 |
+
|
41 |
+
|
42 |
+
def batch():
|
43 |
+
states = few_shot_qa.run_batch(
|
44 |
+
[
|
45 |
+
{"question": "What is the capital of the United States?"},
|
46 |
+
{"question": "What is the capital of China?"},
|
47 |
+
]
|
48 |
+
)
|
49 |
+
|
50 |
+
for s in states:
|
51 |
+
print(s["answer"])
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
|
56 |
+
sgl.set_default_backend(runtime)
|
57 |
+
|
58 |
+
# Run a single request
|
59 |
+
print("\n========== single ==========\n")
|
60 |
+
single()
|
61 |
+
|
62 |
+
# Stream output
|
63 |
+
print("\n========== stream ==========\n")
|
64 |
+
stream()
|
65 |
+
|
66 |
+
# Run a batch of requests
|
67 |
+
print("\n========== batch ==========\n")
|
68 |
+
batch()
|
69 |
+
|
70 |
+
runtime.shutdown()
|
sglang/examples/frontend_language/quick_start/local_example_llava_next.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage: python3 local_example_llava_next.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang.lang.chat_template import get_chat_template
|
7 |
+
|
8 |
+
|
9 |
+
@sgl.function
|
10 |
+
def image_qa(s, image_path, question):
|
11 |
+
s += sgl.user(sgl.image(image_path) + question)
|
12 |
+
s += sgl.assistant(sgl.gen("answer"))
|
13 |
+
|
14 |
+
|
15 |
+
def single():
|
16 |
+
state = image_qa.run(
|
17 |
+
image_path="images/cat.jpeg", question="What is this?", max_new_tokens=128
|
18 |
+
)
|
19 |
+
print(state["answer"], "\n")
|
20 |
+
|
21 |
+
|
22 |
+
def stream():
|
23 |
+
state = image_qa.run(
|
24 |
+
image_path="images/cat.jpeg",
|
25 |
+
question="What is this?",
|
26 |
+
max_new_tokens=64,
|
27 |
+
stream=True,
|
28 |
+
)
|
29 |
+
|
30 |
+
for out in state.text_iter("answer"):
|
31 |
+
print(out, end="", flush=True)
|
32 |
+
print()
|
33 |
+
|
34 |
+
|
35 |
+
def batch():
|
36 |
+
states = image_qa.run_batch(
|
37 |
+
[
|
38 |
+
{"image_path": "images/cat.jpeg", "question": "What is this?"},
|
39 |
+
{"image_path": "images/dog.jpeg", "question": "What is this?"},
|
40 |
+
],
|
41 |
+
max_new_tokens=128,
|
42 |
+
)
|
43 |
+
for s in states:
|
44 |
+
print(s["answer"], "\n")
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
import multiprocessing as mp
|
49 |
+
|
50 |
+
mp.set_start_method("spawn", force=True)
|
51 |
+
|
52 |
+
runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
|
53 |
+
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava")
|
54 |
+
|
55 |
+
# Or you can use the 72B model
|
56 |
+
# runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
|
57 |
+
# runtime.endpoint.chat_template = get_chat_template("chatml-llava")
|
58 |
+
|
59 |
+
sgl.set_default_backend(runtime)
|
60 |
+
print(f"chat template: {runtime.endpoint.chat_template.name}")
|
61 |
+
|
62 |
+
# Or you can use API models
|
63 |
+
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
|
64 |
+
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
|
65 |
+
|
66 |
+
# Run a single request
|
67 |
+
print("\n========== single ==========\n")
|
68 |
+
single()
|
69 |
+
|
70 |
+
# Stream output
|
71 |
+
print("\n========== stream ==========\n")
|
72 |
+
stream()
|
73 |
+
|
74 |
+
# Run a batch of requests
|
75 |
+
print("\n========== batch ==========\n")
|
76 |
+
batch()
|
77 |
+
|
78 |
+
runtime.shutdown()
|
sglang/examples/frontend_language/quick_start/openai_example_chat.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
export OPENAI_API_KEY=sk-******
|
4 |
+
python3 openai_example_chat.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
|
9 |
+
|
10 |
+
@sgl.function
|
11 |
+
def multi_turn_question(s, question_1, question_2):
|
12 |
+
s += sgl.system("You are a helpful assistant.")
|
13 |
+
s += sgl.user(question_1)
|
14 |
+
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
|
15 |
+
s += sgl.user(question_2)
|
16 |
+
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
|
17 |
+
|
18 |
+
|
19 |
+
def single():
|
20 |
+
state = multi_turn_question.run(
|
21 |
+
question_1="What is the capital of the United States?",
|
22 |
+
question_2="List two local attractions.",
|
23 |
+
)
|
24 |
+
|
25 |
+
for m in state.messages():
|
26 |
+
print(m["role"], ":", m["content"])
|
27 |
+
|
28 |
+
print("\n-- answer_1 --\n", state["answer_1"])
|
29 |
+
|
30 |
+
|
31 |
+
def stream():
|
32 |
+
state = multi_turn_question.run(
|
33 |
+
question_1="What is the capital of the United States?",
|
34 |
+
question_2="List two local attractions.",
|
35 |
+
stream=True,
|
36 |
+
)
|
37 |
+
|
38 |
+
for out in state.text_iter():
|
39 |
+
print(out, end="", flush=True)
|
40 |
+
print()
|
41 |
+
|
42 |
+
|
43 |
+
def batch():
|
44 |
+
states = multi_turn_question.run_batch(
|
45 |
+
[
|
46 |
+
{
|
47 |
+
"question_1": "What is the capital of the United States?",
|
48 |
+
"question_2": "List two local attractions.",
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"question_1": "What is the capital of France?",
|
52 |
+
"question_2": "What is the population of this city?",
|
53 |
+
},
|
54 |
+
]
|
55 |
+
)
|
56 |
+
|
57 |
+
for s in states:
|
58 |
+
print(s.messages())
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == "__main__":
|
62 |
+
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
|
63 |
+
|
64 |
+
# Run a single request
|
65 |
+
print("\n========== single ==========\n")
|
66 |
+
single()
|
67 |
+
|
68 |
+
# Stream output
|
69 |
+
print("\n========== stream ==========\n")
|
70 |
+
stream()
|
71 |
+
|
72 |
+
# Run a batch of requests
|
73 |
+
print("\n========== batch ==========\n")
|
74 |
+
batch()
|
sglang/examples/frontend_language/quick_start/openai_example_complete.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
export OPENAI_API_KEY=sk-******
|
4 |
+
python3 openai_example_complete.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
|
9 |
+
|
10 |
+
@sgl.function
|
11 |
+
def few_shot_qa(s, question):
|
12 |
+
s += """The following are questions with answers.
|
13 |
+
Q: What is the capital of France?
|
14 |
+
A: Paris
|
15 |
+
Q: What is the capital of Germany?
|
16 |
+
A: Berlin
|
17 |
+
Q: What is the capital of Italy?
|
18 |
+
A: Rome
|
19 |
+
"""
|
20 |
+
s += "Q: " + question + "\n"
|
21 |
+
s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
|
22 |
+
|
23 |
+
|
24 |
+
def single():
|
25 |
+
state = few_shot_qa.run(question="What is the capital of the United States?")
|
26 |
+
answer = state["answer"].strip().lower()
|
27 |
+
|
28 |
+
assert "washington" in answer, f"answer: {state['answer']}"
|
29 |
+
|
30 |
+
print(state.text())
|
31 |
+
|
32 |
+
|
33 |
+
def stream():
|
34 |
+
state = few_shot_qa.run(
|
35 |
+
question="What is the capital of the United States?", stream=True
|
36 |
+
)
|
37 |
+
|
38 |
+
for out in state.text_iter("answer"):
|
39 |
+
print(out, end="", flush=True)
|
40 |
+
print()
|
41 |
+
|
42 |
+
|
43 |
+
def batch():
|
44 |
+
states = few_shot_qa.run_batch(
|
45 |
+
[
|
46 |
+
{"question": "What is the capital of the United States?"},
|
47 |
+
{"question": "What is the capital of China?"},
|
48 |
+
]
|
49 |
+
)
|
50 |
+
|
51 |
+
for s in states:
|
52 |
+
print(s["answer"])
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
|
57 |
+
|
58 |
+
# Run a single request
|
59 |
+
print("\n========== single ==========\n")
|
60 |
+
single()
|
61 |
+
|
62 |
+
# Stream output
|
63 |
+
print("\n========== stream ==========\n")
|
64 |
+
stream()
|
65 |
+
|
66 |
+
# Run a batch of requests
|
67 |
+
print("\n========== batch ==========\n")
|
68 |
+
batch()
|
sglang/examples/frontend_language/quick_start/openrouter_example_chat.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
export OPENROUTER_API_KEY=sk-******
|
4 |
+
python3 together_example_chat.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
|
9 |
+
import sglang as sgl
|
10 |
+
|
11 |
+
|
12 |
+
@sgl.function
|
13 |
+
def multi_turn_question(s, question_1, question_2):
|
14 |
+
s += sgl.system("You are a helpful assistant.")
|
15 |
+
s += sgl.user(question_1)
|
16 |
+
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
|
17 |
+
s += sgl.user(question_2)
|
18 |
+
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
|
19 |
+
|
20 |
+
|
21 |
+
def single():
|
22 |
+
state = multi_turn_question.run(
|
23 |
+
question_1="What is the capital of the United States?",
|
24 |
+
question_2="List two local attractions.",
|
25 |
+
)
|
26 |
+
|
27 |
+
for m in state.messages():
|
28 |
+
print(m["role"], ":", m["content"])
|
29 |
+
|
30 |
+
print("\n-- answer_1 --\n", state["answer_1"])
|
31 |
+
|
32 |
+
|
33 |
+
def stream():
|
34 |
+
state = multi_turn_question.run(
|
35 |
+
question_1="What is the capital of the United States?",
|
36 |
+
question_2="List two local attractions.",
|
37 |
+
stream=True,
|
38 |
+
)
|
39 |
+
|
40 |
+
for out in state.text_iter():
|
41 |
+
print(out, end="", flush=True)
|
42 |
+
print()
|
43 |
+
|
44 |
+
|
45 |
+
def batch():
|
46 |
+
states = multi_turn_question.run_batch(
|
47 |
+
[
|
48 |
+
{
|
49 |
+
"question_1": "What is the capital of the United States?",
|
50 |
+
"question_2": "List two local attractions.",
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"question_1": "What is the capital of France?",
|
54 |
+
"question_2": "What is the population of this city?",
|
55 |
+
},
|
56 |
+
]
|
57 |
+
)
|
58 |
+
|
59 |
+
for s in states:
|
60 |
+
print(s.messages())
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == "__main__":
|
64 |
+
backend = sgl.OpenAI(
|
65 |
+
model_name="google/gemma-7b-it:free",
|
66 |
+
base_url="https://openrouter.ai/api/v1",
|
67 |
+
api_key=os.environ.get("OPENROUTER_API_KEY"),
|
68 |
+
)
|
69 |
+
sgl.set_default_backend(backend)
|
70 |
+
|
71 |
+
# Run a single request
|
72 |
+
print("\n========== single ==========\n")
|
73 |
+
single()
|
74 |
+
|
75 |
+
# Stream output
|
76 |
+
print("\n========== stream ==========\n")
|
77 |
+
stream()
|
78 |
+
|
79 |
+
# Run a batch of requests
|
80 |
+
print("\n========== batch ==========\n")
|
81 |
+
batch()
|
sglang/examples/frontend_language/quick_start/together_example_complete.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
export TOGETHER_API_KEY=sk-******
|
4 |
+
python3 together_example_complete.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
|
9 |
+
import sglang as sgl
|
10 |
+
|
11 |
+
|
12 |
+
@sgl.function
|
13 |
+
def few_shot_qa(s, question):
|
14 |
+
s += """The following are questions with answers.
|
15 |
+
Q: What is the capital of France?
|
16 |
+
A: Paris
|
17 |
+
Q: What is the capital of Germany?
|
18 |
+
A: Berlin
|
19 |
+
Q: What is the capital of Italy?
|
20 |
+
A: Rome
|
21 |
+
"""
|
22 |
+
s += "Q: " + question + "\n"
|
23 |
+
s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
|
24 |
+
|
25 |
+
|
26 |
+
def single():
|
27 |
+
state = few_shot_qa.run(question="What is the capital of the United States?")
|
28 |
+
answer = state["answer"].strip().lower()
|
29 |
+
|
30 |
+
assert "washington" in answer, f"answer: {state['answer']}"
|
31 |
+
|
32 |
+
print(state.text())
|
33 |
+
|
34 |
+
|
35 |
+
def stream():
|
36 |
+
state = few_shot_qa.run(
|
37 |
+
question="What is the capital of the United States?", stream=True
|
38 |
+
)
|
39 |
+
|
40 |
+
for out in state.text_iter("answer"):
|
41 |
+
print(out, end="", flush=True)
|
42 |
+
print()
|
43 |
+
|
44 |
+
|
45 |
+
def batch():
|
46 |
+
states = few_shot_qa.run_batch(
|
47 |
+
[
|
48 |
+
{"question": "What is the capital of the United States?"},
|
49 |
+
{"question": "What is the capital of China?"},
|
50 |
+
]
|
51 |
+
)
|
52 |
+
|
53 |
+
for s in states:
|
54 |
+
print(s["answer"])
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
backend = sgl.OpenAI(
|
59 |
+
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
60 |
+
is_chat_model=False,
|
61 |
+
base_url="https://api.together.xyz/v1",
|
62 |
+
api_key=os.environ.get("TOGETHER_API_KEY"),
|
63 |
+
)
|
64 |
+
sgl.set_default_backend(backend)
|
65 |
+
|
66 |
+
# Run a single request
|
67 |
+
print("\n========== single ==========\n")
|
68 |
+
single()
|
69 |
+
|
70 |
+
# Stream output
|
71 |
+
print("\n========== stream ==========\n")
|
72 |
+
stream()
|
73 |
+
|
74 |
+
# Run a batch of requests
|
75 |
+
print("\n========== batch ==========\n")
|
76 |
+
batch()
|
sglang/examples/frontend_language/usage/chinese_regex.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sglang as sgl
|
2 |
+
|
3 |
+
character_regex = (
|
4 |
+
r"""\{\n"""
|
5 |
+
+ r""" "姓名": "[^"]{1,32}",\n"""
|
6 |
+
+ r""" "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
|
7 |
+
+ r""" "血型": "(纯血|混血|麻瓜)",\n"""
|
8 |
+
+ r""" "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n"""
|
9 |
+
+ r""" "魔杖": \{\n"""
|
10 |
+
+ r""" "材质": "[^"]{1,32}",\n"""
|
11 |
+
+ r""" "杖芯": "[^"]{1,32}",\n"""
|
12 |
+
+ r""" "长度": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
13 |
+
+ r""" \},\n"""
|
14 |
+
+ r""" "存活": "(存活|死亡)",\n"""
|
15 |
+
+ r""" "守护神": "[^"]{1,32}",\n"""
|
16 |
+
+ r""" "博格特": "[^"]{1,32}"\n"""
|
17 |
+
+ r"""\}"""
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
@sgl.function
|
22 |
+
def character_gen(s, name):
|
23 |
+
s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。"
|
24 |
+
s += """\
|
25 |
+
这是一个例子
|
26 |
+
{
|
27 |
+
"姓名": "哈利波特",
|
28 |
+
"学院": "格兰芬多",
|
29 |
+
"血型": "混血",
|
30 |
+
"职业": "学生",
|
31 |
+
"魔杖": {
|
32 |
+
"材质": "冬青木",
|
33 |
+
"杖芯": "凤凰尾羽",
|
34 |
+
"长度": 11.0
|
35 |
+
},
|
36 |
+
"存活": "存活",
|
37 |
+
"守护神": "麋鹿",
|
38 |
+
"博格特": "摄魂怪"
|
39 |
+
}
|
40 |
+
"""
|
41 |
+
s += f"现在请你填写{name}的信息:\n"
|
42 |
+
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
|
43 |
+
|
44 |
+
|
45 |
+
def main():
|
46 |
+
backend = sgl.RuntimeEndpoint("http://localhost:30000")
|
47 |
+
sgl.set_default_backend(backend)
|
48 |
+
ret = character_gen.run(name="赫敏格兰杰", temperature=0)
|
49 |
+
print(ret.text())
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
main()
|
sglang/examples/frontend_language/usage/choices_logprob.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
4 |
+
python choices_logprob.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
|
9 |
+
|
10 |
+
@sgl.function
|
11 |
+
def tool_use(s, question):
|
12 |
+
s += "To answer this question: " + question + ", "
|
13 |
+
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
# Run one case
|
18 |
+
question = "What is 5 + 5?"
|
19 |
+
state = tool_use.run(question)
|
20 |
+
print("questions:", question)
|
21 |
+
print("choice:", state["tool"])
|
22 |
+
meta_info = state.get_meta_info("tool")
|
23 |
+
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
|
24 |
+
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
|
25 |
+
print("-" * 50)
|
26 |
+
|
27 |
+
# Run a batch
|
28 |
+
questions = [
|
29 |
+
"What is 5 + 6?",
|
30 |
+
"Who is Michael Jordan?",
|
31 |
+
]
|
32 |
+
states = tool_use.run_batch([{"question": q} for q in questions])
|
33 |
+
for question, state in zip(questions, states):
|
34 |
+
print("questions:", question)
|
35 |
+
print("choice:", state["tool"])
|
36 |
+
meta_info = state.get_meta_info("tool")
|
37 |
+
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
|
38 |
+
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
|
39 |
+
print("-" * 50)
|
40 |
+
|
41 |
+
|
42 |
+
if __name__ == "__main__":
|
43 |
+
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
44 |
+
main()
|
sglang/examples/frontend_language/usage/cot_decoding.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import exp
|
2 |
+
from pprint import pformat
|
3 |
+
|
4 |
+
import sglang as sgl
|
5 |
+
|
6 |
+
YELLOW = "\033[1;33m"
|
7 |
+
GREEN = "\033[1;32m"
|
8 |
+
BLUE = "\033[1;34m"
|
9 |
+
CLEAR = "\033[1;0m"
|
10 |
+
|
11 |
+
|
12 |
+
@sgl.function
|
13 |
+
def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
14 |
+
"""CoT Decoding: http://arxiv.org/abs/2402.10200"""
|
15 |
+
|
16 |
+
if is_chat_model:
|
17 |
+
s += sgl.user("Question: " + question + "\nAnswer:")
|
18 |
+
s += sgl.assistant_begin()
|
19 |
+
else:
|
20 |
+
s += "Question: " + question + "\nAnswer:"
|
21 |
+
|
22 |
+
step_0 = s.fork(1)[0]
|
23 |
+
forks = s.fork(get_top_k)
|
24 |
+
answer_forks = s.fork(get_top_k)
|
25 |
+
|
26 |
+
# decoding step 0
|
27 |
+
step_0 += sgl.gen(
|
28 |
+
"get_top_k",
|
29 |
+
max_tokens=0,
|
30 |
+
return_logprob=True,
|
31 |
+
top_logprobs_num=get_top_k,
|
32 |
+
return_text_in_logprobs=True,
|
33 |
+
)
|
34 |
+
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
|
35 |
+
|
36 |
+
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
|
37 |
+
for idx, (f, token) in enumerate(zip(forks, logprobs)):
|
38 |
+
logprob, token_id, text = token
|
39 |
+
f += text
|
40 |
+
|
41 |
+
if text == "<|end_of_text|>":
|
42 |
+
print(
|
43 |
+
f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}"
|
44 |
+
)
|
45 |
+
continue
|
46 |
+
|
47 |
+
# continue greedy decoding
|
48 |
+
f += sgl.gen(
|
49 |
+
"answer",
|
50 |
+
temperature=0,
|
51 |
+
max_tokens=1024,
|
52 |
+
return_logprob=True,
|
53 |
+
top_logprobs_num=2,
|
54 |
+
return_text_in_logprobs=True,
|
55 |
+
)
|
56 |
+
|
57 |
+
# calculate probability disparity between the top and secondary tokens
|
58 |
+
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
59 |
+
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
60 |
+
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
61 |
+
delta = (sum(x1s) - sum(x2s)) / len(x1s)
|
62 |
+
|
63 |
+
# extract the answer span (without the '<|end_of_text|>' token)
|
64 |
+
answer_forks[idx] += text + f["answer"] + "\nSo the answer is"
|
65 |
+
answer_forks[idx] += sgl.gen(
|
66 |
+
"answer_span",
|
67 |
+
temperature=0,
|
68 |
+
max_tokens=64,
|
69 |
+
return_logprob=True,
|
70 |
+
top_logprobs_num=2,
|
71 |
+
return_text_in_logprobs=True,
|
72 |
+
)
|
73 |
+
answer = answer_forks[idx]["answer_span"].replace("\n", " ").strip(":")
|
74 |
+
print(
|
75 |
+
f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}"
|
76 |
+
)
|
77 |
+
generated_text = str(answer_forks[idx])[len("ProgramState(") : -1]
|
78 |
+
print(f"{BLUE}{pformat(generated_text)}{CLEAR}")
|
79 |
+
|
80 |
+
if verbose:
|
81 |
+
answer_tokens = [
|
82 |
+
xt[0][2]
|
83 |
+
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
84 |
+
"output_top_logprobs"
|
85 |
+
]
|
86 |
+
]
|
87 |
+
answer_x1s = [
|
88 |
+
exp(xt[0][0])
|
89 |
+
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
90 |
+
"output_top_logprobs"
|
91 |
+
]
|
92 |
+
]
|
93 |
+
answer_x2s = [
|
94 |
+
exp(xt[1][0])
|
95 |
+
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
96 |
+
"output_top_logprobs"
|
97 |
+
]
|
98 |
+
]
|
99 |
+
|
100 |
+
for token, x1, x2 in zip(tokens, x1s, x2s):
|
101 |
+
print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="")
|
102 |
+
print("\n===========")
|
103 |
+
for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s):
|
104 |
+
print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="")
|
105 |
+
print()
|
106 |
+
|
107 |
+
|
108 |
+
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
109 |
+
|
110 |
+
state = cot_decoding.run(
|
111 |
+
question=r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
|
112 |
+
get_top_k=10,
|
113 |
+
is_chat_model=True,
|
114 |
+
verbose=False,
|
115 |
+
)
|
sglang/examples/frontend_language/usage/json_decode.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
4 |
+
python json_decode.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
from enum import Enum
|
8 |
+
|
9 |
+
from pydantic import BaseModel
|
10 |
+
|
11 |
+
import sglang as sgl
|
12 |
+
from sglang.srt.constrained import build_regex_from_object
|
13 |
+
|
14 |
+
character_regex = (
|
15 |
+
r"""\{\n"""
|
16 |
+
+ r""" "name": "[\w\d\s]{1,16}",\n"""
|
17 |
+
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
|
18 |
+
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
|
19 |
+
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
|
20 |
+
+ r""" "wand": \{\n"""
|
21 |
+
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
|
22 |
+
+ r""" "core": "[\w\d\s]{1,16}",\n"""
|
23 |
+
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
|
24 |
+
+ r""" \},\n"""
|
25 |
+
+ r""" "alive": "(Alive|Deceased)",\n"""
|
26 |
+
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
|
27 |
+
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
|
28 |
+
+ r"""\}"""
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
@sgl.function
|
33 |
+
def character_gen(s, name):
|
34 |
+
s += (
|
35 |
+
name
|
36 |
+
+ " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
37 |
+
)
|
38 |
+
s += "The constrained regex is:\n"
|
39 |
+
s += character_regex + "\n"
|
40 |
+
s += "The JSON output is:\n"
|
41 |
+
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
|
42 |
+
|
43 |
+
|
44 |
+
def driver_character_gen():
|
45 |
+
state = character_gen.run(name="Hermione Granger")
|
46 |
+
print(state.text())
|
47 |
+
|
48 |
+
|
49 |
+
class Weapon(str, Enum):
|
50 |
+
sword = "sword"
|
51 |
+
axe = "axe"
|
52 |
+
mace = "mace"
|
53 |
+
spear = "spear"
|
54 |
+
bow = "bow"
|
55 |
+
crossbow = "crossbow"
|
56 |
+
|
57 |
+
|
58 |
+
class Wizard(BaseModel):
|
59 |
+
name: str
|
60 |
+
age: int
|
61 |
+
weapon: Weapon
|
62 |
+
|
63 |
+
|
64 |
+
@sgl.function
|
65 |
+
def pydantic_wizard_gen(s):
|
66 |
+
s += "Give me a description about a wizard in the JSON format.\n"
|
67 |
+
s += sgl.gen(
|
68 |
+
"character",
|
69 |
+
max_tokens=128,
|
70 |
+
temperature=0,
|
71 |
+
regex=build_regex_from_object(Wizard), # Requires pydantic >= 2.0
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def driver_pydantic_wizard_gen():
|
76 |
+
state = pydantic_wizard_gen.run()
|
77 |
+
print(state.text())
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
82 |
+
driver_character_gen()
|
83 |
+
# driver_pydantic_wizard_gen()
|
sglang/examples/frontend_language/usage/json_logprobs.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NOTE: Currently this can only be run through HTTP requests.
|
2 |
+
import json
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
|
5 |
+
from json_decode import character_regex
|
6 |
+
|
7 |
+
from sglang.utils import http_request
|
8 |
+
|
9 |
+
character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]
|
10 |
+
|
11 |
+
base_url = "http://localhost:30000"
|
12 |
+
|
13 |
+
prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"
|
14 |
+
|
15 |
+
|
16 |
+
def openai_api_request(name):
|
17 |
+
data = {
|
18 |
+
"model": "",
|
19 |
+
"prompt": name + prompt,
|
20 |
+
"temperature": 0,
|
21 |
+
"max_tokens": 128,
|
22 |
+
"regex": character_regex,
|
23 |
+
"logprobs": 3,
|
24 |
+
}
|
25 |
+
res = http_request(base_url + "/v1/completions", json=data).json()
|
26 |
+
|
27 |
+
# with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
|
28 |
+
# fout.write(json.dumps(res, indent=4))
|
29 |
+
|
30 |
+
logprobs = res["choices"][0]["logprobs"]
|
31 |
+
usage = res["usage"]
|
32 |
+
assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
|
33 |
+
assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
|
34 |
+
assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1
|
35 |
+
|
36 |
+
return res
|
37 |
+
|
38 |
+
|
39 |
+
def srt_api_request(name):
|
40 |
+
data = {
|
41 |
+
"text": name + prompt,
|
42 |
+
"sampling_params": {
|
43 |
+
"temperature": 0,
|
44 |
+
"max_new_tokens": 128,
|
45 |
+
"regex": character_regex,
|
46 |
+
},
|
47 |
+
"return_logprob": True,
|
48 |
+
"logprob_start_len": 0,
|
49 |
+
"top_logprobs_num": 3,
|
50 |
+
"return_text_in_logprobs": True,
|
51 |
+
}
|
52 |
+
|
53 |
+
res = http_request(base_url + "/generate", json=data).json()
|
54 |
+
|
55 |
+
# with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
|
56 |
+
# fout.write(json.dumps(res, indent=4))
|
57 |
+
|
58 |
+
meta_info = res["meta_info"]
|
59 |
+
assert len(meta_info["input_token_logprobs"]) == len(
|
60 |
+
meta_info["input_top_logprobs"]
|
61 |
+
)
|
62 |
+
assert len(meta_info["output_token_logprobs"]) == len(
|
63 |
+
meta_info["output_top_logprobs"]
|
64 |
+
)
|
65 |
+
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
|
66 |
+
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
67 |
+
|
68 |
+
return res
|
69 |
+
|
70 |
+
|
71 |
+
def pretty_print(res):
|
72 |
+
meta_info = res["meta_info"]
|
73 |
+
|
74 |
+
print("\n\n", "=" * 30, "Prefill", "=" * 30)
|
75 |
+
for i in range(len(meta_info["input_token_logprobs"])):
|
76 |
+
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
|
77 |
+
top_ks = (
|
78 |
+
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
|
79 |
+
if meta_info["input_top_logprobs"][i]
|
80 |
+
else []
|
81 |
+
)
|
82 |
+
for top_k in top_ks:
|
83 |
+
print(f"{top_k: <15}", end="")
|
84 |
+
print()
|
85 |
+
|
86 |
+
print("\n\n", "=" * 30, "Decode", "=" * 30)
|
87 |
+
for i in range(len(meta_info["output_token_logprobs"])):
|
88 |
+
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
|
89 |
+
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
|
90 |
+
for top_k in top_ks:
|
91 |
+
print(f"{top_k: <15}", end="")
|
92 |
+
print()
|
93 |
+
|
94 |
+
print(res["text"])
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
with ThreadPoolExecutor() as executor:
|
99 |
+
ress = executor.map(srt_api_request, character_names)
|
100 |
+
|
101 |
+
for res in ress:
|
102 |
+
pretty_print(res)
|
103 |
+
|
104 |
+
openai_api_request("Hermione Granger")
|
sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
pip install opencv-python-headless
|
4 |
+
|
5 |
+
python3 srt_example_llava_v.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import csv
|
10 |
+
import json
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
|
14 |
+
import requests
|
15 |
+
|
16 |
+
import sglang as sgl
|
17 |
+
|
18 |
+
|
19 |
+
@sgl.function
|
20 |
+
def video_qa(s, num_frames, video_path, question):
|
21 |
+
s += sgl.user(sgl.video(video_path, num_frames) + question)
|
22 |
+
s += sgl.assistant(sgl.gen("answer"))
|
23 |
+
|
24 |
+
|
25 |
+
def single(path, num_frames=16):
|
26 |
+
state = video_qa.run(
|
27 |
+
num_frames=num_frames,
|
28 |
+
video_path=path,
|
29 |
+
question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes",
|
30 |
+
temperature=0.0,
|
31 |
+
max_new_tokens=1024,
|
32 |
+
)
|
33 |
+
print(state["answer"], "\n")
|
34 |
+
|
35 |
+
|
36 |
+
def split_into_chunks(lst, num_chunks):
|
37 |
+
"""Split a list into a specified number of chunks."""
|
38 |
+
# Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible.
|
39 |
+
chunk_size = len(lst) // num_chunks
|
40 |
+
|
41 |
+
if chunk_size == 0:
|
42 |
+
chunk_size = len(lst)
|
43 |
+
# Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible.
|
44 |
+
chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
45 |
+
# Ensure we have exactly num_chunks chunks, even if some are empty
|
46 |
+
chunks.extend([[] for _ in range(num_chunks - len(chunks))])
|
47 |
+
return chunks
|
48 |
+
|
49 |
+
|
50 |
+
def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
|
51 |
+
csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
|
52 |
+
with open(csv_filename, "w", newline="") as csvfile:
|
53 |
+
writer = csv.writer(csvfile)
|
54 |
+
writer.writerow(["video_name", "answer"])
|
55 |
+
for video_path, state in zip(batch_video_files, states):
|
56 |
+
video_name = os.path.basename(video_path)
|
57 |
+
writer.writerow([video_name, state["answer"]])
|
58 |
+
|
59 |
+
|
60 |
+
def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
|
61 |
+
final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
|
62 |
+
with open(final_csv_filename, "w", newline="") as final_csvfile:
|
63 |
+
writer = csv.writer(final_csvfile)
|
64 |
+
writer.writerow(["video_name", "answer"])
|
65 |
+
for batch_idx in range(num_batches):
|
66 |
+
batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
|
67 |
+
with open(batch_csv_filename, "r") as batch_csvfile:
|
68 |
+
reader = csv.reader(batch_csvfile)
|
69 |
+
next(reader) # Skip header row
|
70 |
+
for row in reader:
|
71 |
+
writer.writerow(row)
|
72 |
+
os.remove(batch_csv_filename)
|
73 |
+
|
74 |
+
|
75 |
+
def find_video_files(video_dir):
|
76 |
+
# Check if the video_dir is actually a file
|
77 |
+
if os.path.isfile(video_dir):
|
78 |
+
# If it's a file, return it as a single-element list
|
79 |
+
return [video_dir]
|
80 |
+
|
81 |
+
# Original logic to find video files in a directory
|
82 |
+
video_files = []
|
83 |
+
for root, dirs, files in os.walk(video_dir):
|
84 |
+
for file in files:
|
85 |
+
if file.endswith((".mp4", ".avi", ".mov")):
|
86 |
+
video_files.append(os.path.join(root, file))
|
87 |
+
return video_files
|
88 |
+
|
89 |
+
|
90 |
+
def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
|
91 |
+
video_files = find_video_files(video_dir)
|
92 |
+
chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
|
93 |
+
num_batches = 0
|
94 |
+
|
95 |
+
for i in range(0, len(chunked_video_files), batch_size):
|
96 |
+
batch_video_files = chunked_video_files[i : i + batch_size]
|
97 |
+
print(f"Processing batch of {len(batch_video_files)} video(s)...")
|
98 |
+
|
99 |
+
if not batch_video_files:
|
100 |
+
print("No video files found in the specified directory.")
|
101 |
+
return
|
102 |
+
|
103 |
+
batch_input = [
|
104 |
+
{
|
105 |
+
"num_frames": num_frames,
|
106 |
+
"video_path": video_path,
|
107 |
+
"question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
|
108 |
+
}
|
109 |
+
for video_path in batch_video_files
|
110 |
+
]
|
111 |
+
|
112 |
+
start_time = time.time()
|
113 |
+
states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
|
114 |
+
total_time = time.time() - start_time
|
115 |
+
average_time = total_time / len(batch_video_files)
|
116 |
+
print(
|
117 |
+
f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds"
|
118 |
+
)
|
119 |
+
|
120 |
+
save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
|
121 |
+
num_batches += 1
|
122 |
+
|
123 |
+
compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
|
128 |
+
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
|
129 |
+
|
130 |
+
cache_dir = os.path.expanduser("~/.cache")
|
131 |
+
file_path = os.path.join(cache_dir, "jobs.mp4")
|
132 |
+
|
133 |
+
os.makedirs(cache_dir, exist_ok=True)
|
134 |
+
|
135 |
+
response = requests.get(url)
|
136 |
+
response.raise_for_status() # Raise an exception for bad responses
|
137 |
+
|
138 |
+
with open(file_path, "wb") as f:
|
139 |
+
f.write(response.content)
|
140 |
+
|
141 |
+
print(f"File downloaded and saved to: {file_path}")
|
142 |
+
# Create the parser
|
143 |
+
parser = argparse.ArgumentParser(
|
144 |
+
description="Run video processing with specified port."
|
145 |
+
)
|
146 |
+
|
147 |
+
# Add an argument for the port
|
148 |
+
parser.add_argument(
|
149 |
+
"--port",
|
150 |
+
type=int,
|
151 |
+
default=30000,
|
152 |
+
help="The master port for distributed serving.",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--chunk-idx", type=int, default=0, help="The index of the chunk to process."
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--num-chunks", type=int, default=8, help="The number of chunks to process."
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--save-dir",
|
162 |
+
type=str,
|
163 |
+
default="./work_dirs/llava_video",
|
164 |
+
help="The directory to save the processed video files.",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--video-dir",
|
168 |
+
type=str,
|
169 |
+
default=os.path.expanduser("~/.cache/jobs.mp4"),
|
170 |
+
help="The directory or path for the processed video files.",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--model-path",
|
174 |
+
type=str,
|
175 |
+
default="lmms-lab/LLaVA-NeXT-Video-7B",
|
176 |
+
help="The model path for the video processing.",
|
177 |
+
)
|
178 |
+
parser.add_argument(
|
179 |
+
"--num-frames",
|
180 |
+
type=int,
|
181 |
+
default=16,
|
182 |
+
help="The number of frames to process in each video.",
|
183 |
+
)
|
184 |
+
parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)
|
185 |
+
|
186 |
+
# Parse the arguments
|
187 |
+
args = parser.parse_args()
|
188 |
+
cur_port = args.port
|
189 |
+
cur_chunk = args.chunk_idx
|
190 |
+
num_chunks = args.num_chunks
|
191 |
+
num_frames = args.num_frames
|
192 |
+
|
193 |
+
if "34b" in args.model_path.lower():
|
194 |
+
tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
|
195 |
+
elif "7b" in args.model_path.lower():
|
196 |
+
tokenizer_path = "llava-hf/llava-1.5-7b-hf"
|
197 |
+
else:
|
198 |
+
print("Invalid model path. Please specify a valid model path.")
|
199 |
+
exit()
|
200 |
+
|
201 |
+
model_override_args = {}
|
202 |
+
model_override_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
|
203 |
+
model_override_args["architectures"] = ["LlavaVidForCausalLM"]
|
204 |
+
model_override_args["num_frames"] = args.num_frames
|
205 |
+
model_override_args["model_type"] = "llava"
|
206 |
+
|
207 |
+
if "34b" in args.model_path.lower():
|
208 |
+
model_override_args["image_token_index"] = 64002
|
209 |
+
|
210 |
+
if args.num_frames == 32:
|
211 |
+
model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
|
212 |
+
model_override_args["max_sequence_length"] = 4096 * 2
|
213 |
+
model_override_args["tokenizer_model_max_length"] = 4096 * 2
|
214 |
+
elif args.num_frames < 32:
|
215 |
+
pass
|
216 |
+
else:
|
217 |
+
print(
|
218 |
+
"The maximum number of frames to process is 32. Please specify a valid number of frames."
|
219 |
+
)
|
220 |
+
exit()
|
221 |
+
|
222 |
+
runtime = sgl.Runtime(
|
223 |
+
model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b",
|
224 |
+
tokenizer_path=tokenizer_path,
|
225 |
+
port=cur_port,
|
226 |
+
json_model_override_args=json.dumps(model_override_args),
|
227 |
+
tp_size=1,
|
228 |
+
)
|
229 |
+
sgl.set_default_backend(runtime)
|
230 |
+
print(f"chat template: {runtime.endpoint.chat_template.name}")
|
231 |
+
|
232 |
+
# Run a single request
|
233 |
+
print("\n========== single ==========\n")
|
234 |
+
root = args.video_dir
|
235 |
+
if os.path.isfile(root):
|
236 |
+
video_files = [root]
|
237 |
+
else:
|
238 |
+
video_files = [
|
239 |
+
os.path.join(root, f)
|
240 |
+
for f in os.listdir(root)
|
241 |
+
if f.endswith((".mp4", ".avi", ".mov"))
|
242 |
+
] # Add more extensions if needed
|
243 |
+
start_time = time.time() # Start time for processing a single video
|
244 |
+
for cur_video in video_files[:1]:
|
245 |
+
print(cur_video)
|
246 |
+
single(cur_video, num_frames)
|
247 |
+
end_time = time.time() # End time for processing a single video
|
248 |
+
total_time = end_time - start_time
|
249 |
+
average_time = total_time / len(
|
250 |
+
video_files
|
251 |
+
) # Calculate the average processing time
|
252 |
+
print(f"Average processing time per video: {average_time:.2f} seconds")
|
253 |
+
runtime.shutdown()
|
254 |
+
|
255 |
+
# # Run a batch of requests
|
256 |
+
# print("\n========== batch ==========\n")
|
257 |
+
# if not os.path.exists(args.save_dir):
|
258 |
+
# os.makedirs(args.save_dir)
|
259 |
+
# batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
|
260 |
+
# runtime.shutdown()
|
sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.sh
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
##### USAGE #####
|
4 |
+
# - First node:
|
5 |
+
# ```sh
|
6 |
+
# bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
|
7 |
+
# ```
|
8 |
+
# - Second node:
|
9 |
+
# ```sh
|
10 |
+
# bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
|
11 |
+
# ```
|
12 |
+
# - The K node:
|
13 |
+
# ```sh
|
14 |
+
# bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
|
15 |
+
# ```
|
16 |
+
|
17 |
+
|
18 |
+
# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.
|
19 |
+
# CURRENT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
20 |
+
CURRENT_ROOT=$(dirname "$0")
|
21 |
+
|
22 |
+
echo ${CURRENT_ROOT}
|
23 |
+
|
24 |
+
cd ${CURRENT_ROOT}
|
25 |
+
|
26 |
+
export PYTHONWARNINGS=ignore
|
27 |
+
|
28 |
+
START_TIME=$(date +%s) # Capture start time
|
29 |
+
|
30 |
+
NUM_NODES=$1
|
31 |
+
|
32 |
+
CUR_NODES_IDX=$2
|
33 |
+
|
34 |
+
VIDEO_DIR=$3
|
35 |
+
|
36 |
+
MODEL_PATH=$4
|
37 |
+
|
38 |
+
NUM_FRAMES=$5
|
39 |
+
|
40 |
+
|
41 |
+
# FRAME_FORMAT=$6
|
42 |
+
|
43 |
+
# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')
|
44 |
+
|
45 |
+
# # Check if FRAME_FORMAT is either JPEG or PNG
|
46 |
+
# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
|
47 |
+
# echo "Error: FRAME_FORMAT must be either JPEG or PNG."
|
48 |
+
# exit 1
|
49 |
+
# fi
|
50 |
+
|
51 |
+
# export TARGET_FRAMES=$TARGET_FRAMES
|
52 |
+
|
53 |
+
echo "Each video you will sample $NUM_FRAMES frames"
|
54 |
+
|
55 |
+
# export FRAME_FORMAT=$FRAME_FORMAT
|
56 |
+
|
57 |
+
# echo "The frame format is $FRAME_FORMAT"
|
58 |
+
|
59 |
+
# Assuming GPULIST is a bash array containing your GPUs
|
60 |
+
GPULIST=(0 1 2 3 4 5 6 7)
|
61 |
+
LOCAL_CHUNKS=${#GPULIST[@]}
|
62 |
+
|
63 |
+
echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS"
|
64 |
+
|
65 |
+
ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))
|
66 |
+
|
67 |
+
# Calculate GPUs per chunk
|
68 |
+
GPUS_PER_CHUNK=1
|
69 |
+
|
70 |
+
echo $GPUS_PER_CHUNK
|
71 |
+
|
72 |
+
for IDX in $(seq 1 $LOCAL_CHUNKS); do
|
73 |
+
(
|
74 |
+
START=$(((IDX-1) * GPUS_PER_CHUNK))
|
75 |
+
LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index
|
76 |
+
|
77 |
+
CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})
|
78 |
+
|
79 |
+
# Convert the chunk GPUs array to a comma-separated string
|
80 |
+
CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")
|
81 |
+
|
82 |
+
LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))
|
83 |
+
|
84 |
+
echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"
|
85 |
+
|
86 |
+
# Calculate the port for this chunk. Ensure it's incremented by 5 for each chunk.
|
87 |
+
PORT=$((10000 + RANDOM % 55536))
|
88 |
+
|
89 |
+
MAX_RETRIES=10
|
90 |
+
RETRY_COUNT=0
|
91 |
+
COMMAND_STATUS=1 # Initialize as failed
|
92 |
+
|
93 |
+
while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
|
94 |
+
echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"
|
95 |
+
|
96 |
+
#!/bin/bash
|
97 |
+
CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 srt_example_llava_v.py \
|
98 |
+
--port $PORT \
|
99 |
+
--num-chunks $ALL_CHUNKS \
|
100 |
+
--chunk-idx $(($LOCAL_IDX - 1)) \
|
101 |
+
--save-dir work_dirs/llava_next_video_inference_results \
|
102 |
+
--video-dir $VIDEO_DIR \
|
103 |
+
--model-path $MODEL_PATH \
|
104 |
+
--num-frames $NUM_FRAMES #&
|
105 |
+
|
106 |
+
wait $! # Wait for the process to finish and capture its exit status
|
107 |
+
COMMAND_STATUS=$?
|
108 |
+
|
109 |
+
if [ $COMMAND_STATUS -ne 0 ]; then
|
110 |
+
echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
|
111 |
+
RETRY_COUNT=$(($RETRY_COUNT + 1))
|
112 |
+
sleep 180 # Wait a bit before retrying
|
113 |
+
else
|
114 |
+
echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
|
115 |
+
fi
|
116 |
+
done
|
117 |
+
|
118 |
+
if [ $COMMAND_STATUS -ne 0 ]; then
|
119 |
+
echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
|
120 |
+
fi
|
121 |
+
) #&
|
122 |
+
sleep 2 # Slight delay to stagger the start times
|
123 |
+
done
|
124 |
+
|
125 |
+
wait
|
126 |
+
|
127 |
+
cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv
|
128 |
+
|
129 |
+
END_TIME=$(date +%s) # Capture end time
|
130 |
+
ELAPSED_TIME=$(($END_TIME - $START_TIME))
|
131 |
+
echo "Total execution time: $ELAPSED_TIME seconds."
|
sglang/examples/frontend_language/usage/openai_chat_speculative.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
***Note: for speculative execution to work, user must put all "gen" in "assistant".
|
4 |
+
Show in "assistant" the desired answer format. Each "gen" term should have a stop token.
|
5 |
+
The stream mode is not supported in speculative execution.
|
6 |
+
|
7 |
+
E.g.
|
8 |
+
correct:
|
9 |
+
sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
|
10 |
+
incorrect:
|
11 |
+
s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
|
12 |
+
s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
|
13 |
+
s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
|
14 |
+
|
15 |
+
export OPENAI_API_KEY=sk-******
|
16 |
+
python3 openai_chat_speculative.py
|
17 |
+
"""
|
18 |
+
|
19 |
+
import sglang as sgl
|
20 |
+
from sglang import OpenAI, function, set_default_backend
|
21 |
+
|
22 |
+
|
23 |
+
@function(num_api_spec_tokens=256)
|
24 |
+
def gen_character_spec(s):
|
25 |
+
s += sgl.system("You are a helpful assistant.")
|
26 |
+
s += sgl.user("Construct a character within the following format:")
|
27 |
+
s += sgl.assistant(
|
28 |
+
"Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
29 |
+
)
|
30 |
+
s += sgl.user("Please generate new Name, Birthday and Job.\n")
|
31 |
+
s += sgl.assistant(
|
32 |
+
"Name:"
|
33 |
+
+ sgl.gen("name", stop="\n")
|
34 |
+
+ "\nBirthday:"
|
35 |
+
+ sgl.gen("birthday", stop="\n")
|
36 |
+
+ "\nJob:"
|
37 |
+
+ sgl.gen("job", stop="\n")
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
@function(num_api_spec_tokens=256)
|
42 |
+
def gen_character_spec_no_few_shot(s):
|
43 |
+
s += sgl.user("Construct a character. For each field stop with a newline\n")
|
44 |
+
s += sgl.assistant(
|
45 |
+
"Name:"
|
46 |
+
+ sgl.gen("name", stop="\n")
|
47 |
+
+ "\nAge:"
|
48 |
+
+ sgl.gen("age", stop="\n")
|
49 |
+
+ "\nJob:"
|
50 |
+
+ sgl.gen("job", stop="\n")
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
@function
|
55 |
+
def gen_character_normal(s):
|
56 |
+
s += sgl.system("You are a helpful assistant.")
|
57 |
+
s += sgl.user("What's the answer of 23 + 8?")
|
58 |
+
s += sgl.assistant(sgl.gen("answer", max_tokens=64))
|
59 |
+
|
60 |
+
|
61 |
+
@function(num_api_spec_tokens=1024)
|
62 |
+
def multi_turn_question(s, question_1, question_2):
|
63 |
+
s += sgl.system("You are a helpful assistant.")
|
64 |
+
s += sgl.user("Answer questions in the following format:")
|
65 |
+
s += sgl.user(
|
66 |
+
"Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n"
|
67 |
+
)
|
68 |
+
s += sgl.assistant(
|
69 |
+
"Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n"
|
70 |
+
)
|
71 |
+
s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
|
72 |
+
s += sgl.assistant(
|
73 |
+
"Answer 1: "
|
74 |
+
+ sgl.gen("answer_1", stop="\n")
|
75 |
+
+ "\nAnswer 2: "
|
76 |
+
+ sgl.gen("answer_2", stop="\n")
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
def test_spec_single_turn():
|
81 |
+
backend.token_usage.reset()
|
82 |
+
|
83 |
+
state = gen_character_spec.run()
|
84 |
+
for m in state.messages():
|
85 |
+
print(m["role"], ":", m["content"])
|
86 |
+
|
87 |
+
print("\n-- name:", state["name"])
|
88 |
+
print("-- birthday:", state["birthday"])
|
89 |
+
print("-- job:", state["job"])
|
90 |
+
print(backend.token_usage)
|
91 |
+
|
92 |
+
|
93 |
+
def test_inaccurate_spec_single_turn():
|
94 |
+
state = gen_character_spec_no_few_shot.run()
|
95 |
+
for m in state.messages():
|
96 |
+
print(m["role"], ":", m["content"])
|
97 |
+
|
98 |
+
print("\n-- name:", state["name"])
|
99 |
+
print("\n-- age:", state["age"])
|
100 |
+
print("\n-- job:", state["job"])
|
101 |
+
|
102 |
+
|
103 |
+
def test_normal_single_turn():
|
104 |
+
state = gen_character_normal.run()
|
105 |
+
for m in state.messages():
|
106 |
+
print(m["role"], ":", m["content"])
|
107 |
+
|
108 |
+
|
109 |
+
def test_spec_multi_turn():
|
110 |
+
state = multi_turn_question.run(
|
111 |
+
question_1="What is the capital of the United States?",
|
112 |
+
question_2="List two local attractions in the capital of the United States.",
|
113 |
+
)
|
114 |
+
|
115 |
+
for m in state.messages():
|
116 |
+
print(m["role"], ":", m["content"])
|
117 |
+
|
118 |
+
print("\n-- answer_1 --\n", state["answer_1"])
|
119 |
+
print("\n-- answer_2 --\n", state["answer_2"])
|
120 |
+
|
121 |
+
|
122 |
+
def test_spec_multi_turn_stream():
|
123 |
+
state = multi_turn_question.run(
|
124 |
+
question_1="What is the capital of the United States?",
|
125 |
+
question_2="List two local attractions.",
|
126 |
+
stream=True,
|
127 |
+
)
|
128 |
+
|
129 |
+
for out in state.text_iter():
|
130 |
+
print(out, end="", flush=True)
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == "__main__":
|
134 |
+
backend = OpenAI("gpt-4-turbo")
|
135 |
+
set_default_backend(backend)
|
136 |
+
|
137 |
+
print("\n========== test spec single turn ==========\n")
|
138 |
+
# expect reasonable answer for each field
|
139 |
+
test_spec_single_turn()
|
140 |
+
|
141 |
+
print("\n========== test inaccurate spec single turn ==========\n")
|
142 |
+
# expect incomplete or unreasonable answers
|
143 |
+
test_inaccurate_spec_single_turn()
|
144 |
+
|
145 |
+
print("\n========== test normal single turn ==========\n")
|
146 |
+
# expect reasonable answer
|
147 |
+
test_normal_single_turn()
|
148 |
+
|
149 |
+
print("\n========== test spec multi turn ==========\n")
|
150 |
+
# expect answer with same format as in the few shot
|
151 |
+
test_spec_multi_turn()
|
152 |
+
|
153 |
+
print("\n========== test spec multi turn stream ==========\n")
|
154 |
+
# expect error in stream_executor: stream is not supported...
|
155 |
+
test_spec_multi_turn_stream()
|
sglang/examples/frontend_language/usage/openai_speculative.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 openai_speculative.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
from sglang import OpenAI, function, gen, set_default_backend
|
7 |
+
|
8 |
+
|
9 |
+
@function(num_api_spec_tokens=64)
|
10 |
+
def gen_character_spec(s):
|
11 |
+
s += "Construct a character within the following format:\n"
|
12 |
+
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
13 |
+
s += "\nPlease generate new Name, Birthday and Job.\n"
|
14 |
+
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
15 |
+
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
16 |
+
|
17 |
+
|
18 |
+
@function
|
19 |
+
def gen_character_no_spec(s):
|
20 |
+
s += "Construct a character within the following format:\n"
|
21 |
+
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
|
22 |
+
s += "\nPlease generate new Name, Birthday and Job.\n"
|
23 |
+
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
24 |
+
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
25 |
+
|
26 |
+
|
27 |
+
@function(num_api_spec_tokens=64)
|
28 |
+
def gen_character_spec_no_few_shot(s):
|
29 |
+
# s += "Construct a character with name, birthday, and job:\n"
|
30 |
+
s += "Construct a character:\n"
|
31 |
+
s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
|
32 |
+
s += "\nJob:" + gen("job", stop="\n") + "\n"
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
backend = OpenAI("gpt-3.5-turbo-instruct")
|
37 |
+
set_default_backend(backend)
|
38 |
+
|
39 |
+
for function in [
|
40 |
+
gen_character_spec,
|
41 |
+
gen_character_no_spec,
|
42 |
+
gen_character_spec_no_few_shot,
|
43 |
+
]:
|
44 |
+
backend.token_usage.reset()
|
45 |
+
|
46 |
+
print(f"function: {function.func.__name__}")
|
47 |
+
|
48 |
+
state = function.run()
|
49 |
+
|
50 |
+
print("...name:", state["name"])
|
51 |
+
print("...birthday:", state["birthday"])
|
52 |
+
print("...job:", state["job"])
|
53 |
+
print(backend.token_usage)
|
54 |
+
print()
|
sglang/examples/frontend_language/usage/parallel_sample.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 parallel_sample.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import sglang as sgl
|
7 |
+
|
8 |
+
|
9 |
+
@sgl.function
|
10 |
+
def parallel_sample(s, question, n):
|
11 |
+
s += (
|
12 |
+
"Question: Compute 1 + 2 + 3\n"
|
13 |
+
"Reasoning: I need to use a calculator.\n"
|
14 |
+
"Tool: calculator\n"
|
15 |
+
"Answer: 6\n"
|
16 |
+
"Question: Compute 3 + 2 + 2\n"
|
17 |
+
"Reasoning: I will try a calculator.\n"
|
18 |
+
"Tool: calculator\n"
|
19 |
+
"Answer: 7\n"
|
20 |
+
)
|
21 |
+
s += "Question: " + question + "\n"
|
22 |
+
forks = s.fork(n)
|
23 |
+
forks += "Reasoning:" + sgl.gen("reasoning", stop="\n") + "\n"
|
24 |
+
forks += "Tool:" + sgl.gen("tool", choices=["calculator", "browser"]) + "\n"
|
25 |
+
forks += "Answer:" + sgl.gen("answer", stop="\n") + "\n"
|
26 |
+
forks.join()
|
27 |
+
|
28 |
+
|
29 |
+
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
|
30 |
+
# sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
31 |
+
|
32 |
+
state = parallel_sample.run(question="Compute 5 + 2 + 4.", n=5, temperature=1.0)
|
33 |
+
|
34 |
+
for i in range(5):
|
35 |
+
obj = {
|
36 |
+
"reasoning": state["reasoning"][i],
|
37 |
+
"tool": state["tool"][i],
|
38 |
+
"answer": state["answer"][i],
|
39 |
+
}
|
40 |
+
print(f"[{i}], {obj}")
|
sglang/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# RAG Powered by SGLang & Chroma Evaluated using Parea\n",
|
8 |
+
"\n",
|
9 |
+
"In this notebook, we will build a simple RAG pipeline using SGLang to execute our LLM calls, Chroma as vector database for retrieval and [Parea](https://www.parea.ai) for tracing and evaluation. We will then evaluate the performance of our RAG pipeline. The dataset we will use was created by [Virat](https://twitter.com/virattt) and contains 100 questions, contexts and answers from the Airbnb 2023 10k filing.\n",
|
10 |
+
"\n",
|
11 |
+
"The RAG pipeline consists of two steps:\n",
|
12 |
+
"1. Retrieval: Given a question, we retrieve the relevant context from all provided contexts.\n",
|
13 |
+
"2. Generation: Given the question and the retrieved context, we generate an answer.\n",
|
14 |
+
"\n",
|
15 |
+
"ℹ️ This notebook requires an OpenAI API key.\n",
|
16 |
+
"\n",
|
17 |
+
"ℹ️ This notebook requires a Parea API key, which can be created [here](https://docs.parea.ai/api-reference/authentication#parea-api-key)."
|
18 |
+
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {},
|
23 |
+
"source": [
|
24 |
+
"## Setting up the environment\n",
|
25 |
+
"\n",
|
26 |
+
"We will first install the necessary packages: `sglang`, `parea-ai` and `chromadb`."
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"# note, if you use a Mac M1 chip, you might need to install grpcio 1.59.0 first such that installing chromadb works\n",
|
36 |
+
"# !pip install grpcio==1.59.0\n",
|
37 |
+
"\n",
|
38 |
+
"!pip install sglang[openai] parea-ai chromadb"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "markdown",
|
43 |
+
"metadata": {},
|
44 |
+
"source": [
|
45 |
+
"Create a Parea API key as outlined [here](https://docs.parea.ai/api-reference/authentication#parea-api-key) and save it in a `.env` file as `PAREA_API_KEY=your-api-key`."
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "markdown",
|
50 |
+
"metadata": {},
|
51 |
+
"source": [
|
52 |
+
"## Indexing the data\n",
|
53 |
+
"\n",
|
54 |
+
"Now it's time to download the data & index it! For that, we create a collection called `contexts` in Chroma and add the contexts as documents."
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": null,
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"import json\n",
|
64 |
+
"import os\n",
|
65 |
+
"from typing import List\n",
|
66 |
+
"\n",
|
67 |
+
"import chromadb\n",
|
68 |
+
"\n",
|
69 |
+
"path_qca = \"airbnb-2023-10k-qca.json\"\n",
|
70 |
+
"\n",
|
71 |
+
"if not os.path.exists(path_qca):\n",
|
72 |
+
" !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n",
|
73 |
+
"\n",
|
74 |
+
"with open(path_qca, \"r\") as f:\n",
|
75 |
+
" question_context_answers = json.load(f)\n",
|
76 |
+
"\n",
|
77 |
+
"chroma_client = chromadb.PersistentClient()\n",
|
78 |
+
"collection = chroma_client.get_or_create_collection(name=\"contexts\")\n",
|
79 |
+
"if collection.count() == 0:\n",
|
80 |
+
" collection.add(\n",
|
81 |
+
" documents=[qca[\"context\"] for qca in question_context_answers],\n",
|
82 |
+
" ids=[str(i) for i in range(len(question_context_answers))],\n",
|
83 |
+
" )"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "markdown",
|
88 |
+
"metadata": {},
|
89 |
+
"source": [
|
90 |
+
"## Defining the RAG pipeline\n",
|
91 |
+
"\n",
|
92 |
+
"We will start with importing the necessary packages, setting up tracing of OpenAI calls via Parea and setting OpenAI as the default backend for SGLang."
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": null,
|
98 |
+
"metadata": {},
|
99 |
+
"outputs": [],
|
100 |
+
"source": [
|
101 |
+
"import os\n",
|
102 |
+
"import time\n",
|
103 |
+
"\n",
|
104 |
+
"from dotenv import load_dotenv\n",
|
105 |
+
"\n",
|
106 |
+
"from sglang import function, user, assistant, gen, set_default_backend, OpenAI\n",
|
107 |
+
"from sglang.lang.interpreter import ProgramState\n",
|
108 |
+
"from parea import Parea, trace\n",
|
109 |
+
"\n",
|
110 |
+
"\n",
|
111 |
+
"load_dotenv()\n",
|
112 |
+
"\n",
|
113 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
114 |
+
"\n",
|
115 |
+
"p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n",
|
116 |
+
"p.integrate_with_sglang()\n",
|
117 |
+
"\n",
|
118 |
+
"set_default_backend(OpenAI(\"gpt-3.5-turbo\"))"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "markdown",
|
123 |
+
"metadata": {},
|
124 |
+
"source": [
|
125 |
+
"Now we can define our retrieval step shown below. Notice, the `trace` decorator which will automatically trace inputs, output, latency, etc. of that call."
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": null,
|
131 |
+
"metadata": {},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"@trace\n",
|
135 |
+
"def retrieval(question: str) -> List[str]:\n",
|
136 |
+
" return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "markdown",
|
141 |
+
"metadata": {},
|
142 |
+
"source": [
|
143 |
+
"Next we will define the generation step which uses SGLang to execute the LLM call."
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": null,
|
149 |
+
"metadata": {},
|
150 |
+
"outputs": [],
|
151 |
+
"source": [
|
152 |
+
"@function\n",
|
153 |
+
"def generation_sglang(s, question: str, *context: str):\n",
|
154 |
+
" context = \"\\n\".join(context)\n",
|
155 |
+
" s += user(\n",
|
156 |
+
" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
|
157 |
+
" )\n",
|
158 |
+
" s += assistant(gen(\"answer\"))\n",
|
159 |
+
"\n",
|
160 |
+
"\n",
|
161 |
+
"@trace\n",
|
162 |
+
"def generation(question: str, *context):\n",
|
163 |
+
" state: ProgramState = generation_sglang.run(question, *context)\n",
|
164 |
+
" while not state.stream_executor.is_finished:\n",
|
165 |
+
" time.sleep(1)\n",
|
166 |
+
" return state.stream_executor.variables[\"answer\"]"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"cell_type": "markdown",
|
171 |
+
"metadata": {},
|
172 |
+
"source": [
|
173 |
+
"Finally, we can tie it together and execute a sample query."
|
174 |
+
]
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"cell_type": "code",
|
178 |
+
"execution_count": null,
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [],
|
181 |
+
"source": [
|
182 |
+
"@trace\n",
|
183 |
+
"def rag_pipeline(question: str) -> str:\n",
|
184 |
+
" contexts = retrieval(question)\n",
|
185 |
+
" return generation(question, *contexts)\n",
|
186 |
+
"\n",
|
187 |
+
"\n",
|
188 |
+
"rag_pipeline(\n",
|
189 |
+
" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
|
190 |
+
")"
|
191 |
+
]
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "markdown",
|
195 |
+
"metadata": {},
|
196 |
+
"source": [
|
197 |
+
"## Debug Trace\n",
|
198 |
+
"\n",
|
199 |
+
"The output is unfortunately wrong! Using the traced pipeline, we can see that\n",
|
200 |
+
"\n",
|
201 |
+
"- the context is relevant to the question and contains the correct information\n",
|
202 |
+
"- but, the generation step is cut off as max tokens is set to 16\n",
|
203 |
+
"\n",
|
204 |
+
"When opening the generation step in the playground and rerunning the prompt with max. tokens set to 1000, the correct answer is produced.\n",
|
205 |
+
"\n",
|
206 |
+
""
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "markdown",
|
211 |
+
"metadata": {},
|
212 |
+
"source": [
|
213 |
+
"## Evaluating RAG Pipelines\n",
|
214 |
+
"\n",
|
215 |
+
"Before we apply above's fix, let's dive into evaluating RAG pipelines.\n",
|
216 |
+
"\n",
|
217 |
+
"RAG pipelines consist of a retrieval step to fetch relevant information and a generation step to generate a response to a users question. A RAG pipeline can fail at either step. E.g. the retrieval step can fail to find relevant information which makes generating the correct impossible. Another failure mode is that the generation step doesn't leverage the retrieved information correctly. We will apply the following evaluation metrics to understand different failure modes:\n",
|
218 |
+
"\n",
|
219 |
+
"- `context_relevancy`: measures how relevant the context is given the question\n",
|
220 |
+
"- `percent_target_supported_by_context`: measures how much of the target answer is supported by the context; this will give an upper ceiling of how well the generation step can perform\n",
|
221 |
+
"- `answer_context_faithfulness`: measures how much the generated answer utilizes the context\n",
|
222 |
+
"- `answer_matches_target`: measures how well the generated answer matches the target answer judged by a LLM and gives a sense of accuracy of our entire pipeline\n",
|
223 |
+
"\n",
|
224 |
+
"To use these evaluation metrics, we can import them from `parea.evals.rag` and `parea.evals.general` and apply them to a function by specifying in the `trace` decorator which evaluation metrics to use. The `@trace` decorator will automatically log the results of the evaluation metrics to the Parea dashboard.\n",
|
225 |
+
"\n",
|
226 |
+
"Applying them to the retrieval step:"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": null,
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"from parea.evals.rag import (\n",
|
236 |
+
" context_query_relevancy_factory,\n",
|
237 |
+
" percent_target_supported_by_context_factory,\n",
|
238 |
+
")\n",
|
239 |
+
"\n",
|
240 |
+
"\n",
|
241 |
+
"context_relevancy_eval = context_query_relevancy_factory()\n",
|
242 |
+
"percent_target_supported_by_context = percent_target_supported_by_context_factory()\n",
|
243 |
+
"\n",
|
244 |
+
"\n",
|
245 |
+
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
|
246 |
+
"def retrieval(question: str) -> List[str]:\n",
|
247 |
+
" return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"cell_type": "markdown",
|
252 |
+
"metadata": {},
|
253 |
+
"source": [
|
254 |
+
"Now we can apply `answer_context_faithfulness` and `answer_matches_target` to the generation step."
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": null,
|
260 |
+
"metadata": {},
|
261 |
+
"outputs": [],
|
262 |
+
"source": [
|
263 |
+
"from parea.evals.general import answer_matches_target_llm_grader_factory\n",
|
264 |
+
"from parea.evals.rag import answer_context_faithfulness_statement_level_factory\n",
|
265 |
+
"\n",
|
266 |
+
"\n",
|
267 |
+
"answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n",
|
268 |
+
"answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n",
|
269 |
+
"\n",
|
270 |
+
"\n",
|
271 |
+
"@function\n",
|
272 |
+
"def generation_sglang(s, question: str, *context: str):\n",
|
273 |
+
" context = \"\\n\".join(context)\n",
|
274 |
+
" s += user(\n",
|
275 |
+
" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
|
276 |
+
" )\n",
|
277 |
+
" s += assistant(gen(\"answer\", max_tokens=1_000))\n",
|
278 |
+
"\n",
|
279 |
+
"\n",
|
280 |
+
"@trace(eval_funcs=[answer_context_faithfulness, answer_matches_target_llm_grader])\n",
|
281 |
+
"def generation(question: str, *context):\n",
|
282 |
+
" state: ProgramState = generation_sglang.run(question, *context)\n",
|
283 |
+
" while not state.stream_executor.is_finished:\n",
|
284 |
+
" time.sleep(1)\n",
|
285 |
+
" return state.stream_executor.variables[\"answer\"]"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "markdown",
|
290 |
+
"metadata": {},
|
291 |
+
"source": [
|
292 |
+
"Finally, we tie them together & execute the original sample query."
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"cell_type": "code",
|
297 |
+
"execution_count": null,
|
298 |
+
"metadata": {},
|
299 |
+
"outputs": [],
|
300 |
+
"source": [
|
301 |
+
"@trace\n",
|
302 |
+
"def rag_pipeline(question: str) -> str:\n",
|
303 |
+
" contexts = retrieval(question)\n",
|
304 |
+
" return generation(question, *contexts)\n",
|
305 |
+
"\n",
|
306 |
+
"\n",
|
307 |
+
"rag_pipeline(\n",
|
308 |
+
" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
|
309 |
+
")"
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "markdown",
|
314 |
+
"metadata": {},
|
315 |
+
"source": [
|
316 |
+
"Great, the answer is correct! Can you spot the line where we fixed the output truncation issue?\n",
|
317 |
+
"\n",
|
318 |
+
"The evaluation scores appear in the bottom right of the logs (screenshot below). Note, that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n",
|
319 |
+
"\n",
|
320 |
+
""
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"cell_type": "markdown",
|
325 |
+
"metadata": {},
|
326 |
+
"source": [
|
327 |
+
"## Running an experiment\n",
|
328 |
+
"\n",
|
329 |
+
"Now we are (almost) ready to evaluate the performance of our RAG pipeline on the entire dataset. First, we will need to apply the `nest_asyncio` package to avoid issues with the Jupyter notebook event loop."
|
330 |
+
]
|
331 |
+
},
|
332 |
+
{
|
333 |
+
"cell_type": "code",
|
334 |
+
"execution_count": null,
|
335 |
+
"metadata": {},
|
336 |
+
"outputs": [],
|
337 |
+
"source": [
|
338 |
+
"!pip install nest-asyncio\n",
|
339 |
+
"import nest_asyncio\n",
|
340 |
+
"\n",
|
341 |
+
"nest_asyncio.apply()"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"cell_type": "markdown",
|
346 |
+
"metadata": {},
|
347 |
+
"source": [
|
348 |
+
"Running the actual experiment is straight-forward. For that we use `p.experiment` to initialize the experiment with a name, the data (list of key-value pairs fed into our entry function) and the entry function. We then call `run` on the experiment to execute it. Note, that `target` is a reserved key in the data dictionary and will be used as the target answer for evaluation."
|
349 |
+
]
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"cell_type": "code",
|
353 |
+
"execution_count": null,
|
354 |
+
"metadata": {},
|
355 |
+
"outputs": [],
|
356 |
+
"source": [
|
357 |
+
"e = p.experiment(\n",
|
358 |
+
" \"RAG\",\n",
|
359 |
+
" data=[\n",
|
360 |
+
" {\n",
|
361 |
+
" \"question\": qca[\"question\"],\n",
|
362 |
+
" \"target\": qca[\"answer\"],\n",
|
363 |
+
" }\n",
|
364 |
+
" for qca in question_context_answers\n",
|
365 |
+
" ],\n",
|
366 |
+
" func=rag_pipeline,\n",
|
367 |
+
").run()"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "markdown",
|
372 |
+
"metadata": {},
|
373 |
+
"source": [
|
374 |
+
"## Analyzing the results\n",
|
375 |
+
"\n",
|
376 |
+
"When opening above experiment, we will see an overview of the experiment as shown below. The upper half shows a summary of the statistics on the left and charts to investigate the distribution and relationships of scores on the right. The lower half is a table with the individual traces which we can use to debug individual samples.\n",
|
377 |
+
"\n",
|
378 |
+
"When looking at the statistics, we can see that the accuracy of our RAG pipeline is 22% as measured by `answer_matches_target_llm_grader`. Though when checking the quality of our retrieval step (`context_query_relevancy`), we can see that our retrival step is fetching relevant information in only 27% of all samples. As shown in the GIF, we investigate the relationship between the two and see the two scores have 95% agreement. This confirms that the retrieval step is a major bottleneck for our RAG pipeline. So, now it's your turn to improve the retrieval step!\n",
|
379 |
+
"\n",
|
380 |
+
"Note, above link isn't publicly accessible but the experiment can be accessed through [here](https://app.parea.ai/public-experiments/parea/rag_sglang/30f0244a-d56c-44ff-bdfb-8f47626304b6).\n",
|
381 |
+
"\n",
|
382 |
+
""
|
383 |
+
]
|
384 |
+
},
|
385 |
+
{
|
386 |
+
"cell_type": "code",
|
387 |
+
"execution_count": null,
|
388 |
+
"metadata": {},
|
389 |
+
"outputs": [],
|
390 |
+
"source": []
|
391 |
+
}
|
392 |
+
],
|
393 |
+
"metadata": {
|
394 |
+
"language_info": {
|
395 |
+
"codemirror_mode": {
|
396 |
+
"name": "ipython",
|
397 |
+
"version": 2
|
398 |
+
},
|
399 |
+
"file_extension": ".py",
|
400 |
+
"mimetype": "text/x-python",
|
401 |
+
"name": "python",
|
402 |
+
"nbconvert_exporter": "python",
|
403 |
+
"pygments_lexer": "ipython2"
|
404 |
+
}
|
405 |
+
},
|
406 |
+
"nbformat": 4,
|
407 |
+
"nbformat_minor": 0
|
408 |
+
}
|
sglang/examples/frontend_language/usage/readme_examples.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
4 |
+
python readme_examples.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
|
9 |
+
|
10 |
+
@sgl.function
|
11 |
+
def tool_use(s, question):
|
12 |
+
s += "To answer this question: " + question + ". "
|
13 |
+
s += (
|
14 |
+
"I need to use a "
|
15 |
+
+ sgl.gen("tool", choices=["calculator", "search engine"])
|
16 |
+
+ ". "
|
17 |
+
)
|
18 |
+
|
19 |
+
if s["tool"] == "calculator":
|
20 |
+
s += "The math expression is" + sgl.gen("expression")
|
21 |
+
elif s["tool"] == "search engine":
|
22 |
+
s += "The key word to search is" + sgl.gen("word")
|
23 |
+
|
24 |
+
|
25 |
+
@sgl.function
|
26 |
+
def tip_suggestion(s):
|
27 |
+
s += (
|
28 |
+
"Here are two tips for staying healthy: "
|
29 |
+
"1. Balanced Diet. 2. Regular Exercise.\n\n"
|
30 |
+
)
|
31 |
+
|
32 |
+
forks = s.fork(2)
|
33 |
+
for i, f in enumerate(forks):
|
34 |
+
f += f"Now, expand tip {i+1} into a paragraph:\n"
|
35 |
+
f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
|
36 |
+
|
37 |
+
s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
|
38 |
+
s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
|
39 |
+
s += "In summary" + sgl.gen("summary")
|
40 |
+
|
41 |
+
|
42 |
+
@sgl.function
|
43 |
+
def regular_expression_gen(s):
|
44 |
+
s += "Q: What is the IP address of the Google DNS servers?\n"
|
45 |
+
s += "A: " + sgl.gen(
|
46 |
+
"answer",
|
47 |
+
temperature=0,
|
48 |
+
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
@sgl.function
|
53 |
+
def text_qa(s, question):
|
54 |
+
s += "Q: " + question + "\n"
|
55 |
+
s += "A:" + sgl.gen("answer", stop="\n")
|
56 |
+
|
57 |
+
|
58 |
+
def driver_tool_use():
|
59 |
+
state = tool_use.run(question="What is the capital of the United States?")
|
60 |
+
print(state.text())
|
61 |
+
print("\n")
|
62 |
+
|
63 |
+
|
64 |
+
def driver_tip_suggestion():
|
65 |
+
state = tip_suggestion.run()
|
66 |
+
print(state.text())
|
67 |
+
print("\n")
|
68 |
+
|
69 |
+
|
70 |
+
def driver_regex():
|
71 |
+
state = regular_expression_gen.run()
|
72 |
+
print(state.text())
|
73 |
+
print("\n")
|
74 |
+
|
75 |
+
|
76 |
+
def driver_batching():
|
77 |
+
states = text_qa.run_batch(
|
78 |
+
[
|
79 |
+
{"question": "What is the capital of the United Kingdom?"},
|
80 |
+
{"question": "What is the capital of France?"},
|
81 |
+
{"question": "What is the capital of Japan?"},
|
82 |
+
],
|
83 |
+
progress_bar=True,
|
84 |
+
)
|
85 |
+
|
86 |
+
for s in states:
|
87 |
+
print(s.text())
|
88 |
+
print("\n")
|
89 |
+
|
90 |
+
|
91 |
+
def driver_stream():
|
92 |
+
state = text_qa.run(
|
93 |
+
question="What is the capital of France?", temperature=0.1, stream=True
|
94 |
+
)
|
95 |
+
|
96 |
+
for out in state.text_iter():
|
97 |
+
print(out, end="", flush=True)
|
98 |
+
print("\n")
|
99 |
+
|
100 |
+
|
101 |
+
if __name__ == "__main__":
|
102 |
+
# sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
|
103 |
+
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
104 |
+
|
105 |
+
driver_tool_use()
|
106 |
+
driver_tip_suggestion()
|
107 |
+
driver_regex()
|
108 |
+
driver_batching()
|
109 |
+
driver_stream()
|
sglang/examples/frontend_language/usage/sgl_gen_min_tokens.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This example demonstrates how to use `min_tokens` to enforce sgl.gen to generate a longer sequence
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 sgl_gen_min_tokens.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
import sglang as sgl
|
9 |
+
|
10 |
+
|
11 |
+
@sgl.function
|
12 |
+
def long_answer(s):
|
13 |
+
s += sgl.user("What is the capital of the United States?")
|
14 |
+
s += sgl.assistant(sgl.gen("answer", min_tokens=64, max_tokens=128))
|
15 |
+
|
16 |
+
|
17 |
+
@sgl.function
|
18 |
+
def short_answer(s):
|
19 |
+
s += sgl.user("What is the capital of the United States?")
|
20 |
+
s += sgl.assistant(sgl.gen("answer"))
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
|
25 |
+
sgl.set_default_backend(runtime)
|
26 |
+
|
27 |
+
state = long_answer.run()
|
28 |
+
print("=" * 20)
|
29 |
+
print("Longer Answer", state["answer"])
|
30 |
+
|
31 |
+
state = short_answer.run()
|
32 |
+
print("=" * 20)
|
33 |
+
print("Short Answer", state["answer"])
|
34 |
+
|
35 |
+
runtime.shutdown()
|
sglang/examples/frontend_language/usage/streaming.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 streaming.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import asyncio
|
7 |
+
|
8 |
+
import sglang as sgl
|
9 |
+
|
10 |
+
|
11 |
+
@sgl.function
|
12 |
+
def multi_turn_question(s, question_1, question_2):
|
13 |
+
s += sgl.system("You are a helpful assistant.")
|
14 |
+
s += sgl.user(question_1)
|
15 |
+
s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
|
16 |
+
s += sgl.user(question_2)
|
17 |
+
s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
|
18 |
+
|
19 |
+
|
20 |
+
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
|
21 |
+
|
22 |
+
|
23 |
+
def stream_a_variable():
|
24 |
+
state = multi_turn_question.run(
|
25 |
+
question_1="What is the capital of the United States?",
|
26 |
+
question_2="List two local attractions.",
|
27 |
+
stream=True,
|
28 |
+
)
|
29 |
+
|
30 |
+
for out in state.text_iter(var_name="answer_2"):
|
31 |
+
print(out, end="", flush=True)
|
32 |
+
print("\n")
|
33 |
+
|
34 |
+
|
35 |
+
async def async_stream():
|
36 |
+
state = multi_turn_question.run(
|
37 |
+
question_1="What is the capital of the United States?",
|
38 |
+
question_2="List two local attractions.",
|
39 |
+
stream=True,
|
40 |
+
)
|
41 |
+
|
42 |
+
async for out in state.text_async_iter(var_name="answer_2"):
|
43 |
+
print(out, end="", flush=True)
|
44 |
+
print("\n")
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
stream_a_variable()
|
49 |
+
asyncio.run(async_stream())
|
sglang/examples/frontend_language/usage/triton/Dockerfile
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/tritonserver:24.01-py3
|
2 |
+
|
3 |
+
WORKDIR /opt
|
4 |
+
|
5 |
+
RUN git clone https://github.com/sgl-project/sglang.git
|
6 |
+
|
7 |
+
WORKDIR /opt/sglang
|
8 |
+
RUN pip install --upgrade pip && \
|
9 |
+
pip install -e "python[all]" && \
|
10 |
+
pip install datasets
|
sglang/examples/frontend_language/usage/triton/README.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# sglang_triton
|
2 |
+
|
3 |
+
Build the docker image:
|
4 |
+
```
|
5 |
+
docker build -t sglang-triton .
|
6 |
+
```
|
7 |
+
|
8 |
+
Then do:
|
9 |
+
```
|
10 |
+
docker run -ti --gpus=all --network=host --name sglang-triton -v ./models:/mnt/models sglang-triton
|
11 |
+
```
|
12 |
+
|
13 |
+
inside the docker container:
|
14 |
+
```
|
15 |
+
cd sglang
|
16 |
+
python3 -m sglang.launch_server --model-path mistralai/Mistral-7B-Instruct-v0.2 --port 30000 --mem-fraction-static 0.9
|
17 |
+
```
|
18 |
+
|
19 |
+
with another shell, inside the docker container:
|
20 |
+
```
|
21 |
+
docker exec -ti sglang-triton /bin/bash
|
22 |
+
cd /mnt
|
23 |
+
tritonserver --model-repository=/mnt/models
|
24 |
+
```
|
25 |
+
|
26 |
+
|
27 |
+
Send request to the server:
|
28 |
+
```
|
29 |
+
curl -X POST http://localhost:8000/v2/models/character_generation/generate \
|
30 |
+
-H "Content-Type: application/json" \
|
31 |
+
-d '{
|
32 |
+
"INPUT_TEXT": ["harry"]
|
33 |
+
}'
|
34 |
+
|
35 |
+
```
|
sglang/examples/frontend_language/usage/triton/models/character_generation/1/model.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy
|
2 |
+
import triton_python_backend_utils as pb_utils
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
import sglang as sgl
|
6 |
+
from sglang import function, set_default_backend
|
7 |
+
from sglang.srt.constrained import build_regex_from_object
|
8 |
+
|
9 |
+
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
10 |
+
|
11 |
+
|
12 |
+
class Character(BaseModel):
|
13 |
+
name: str
|
14 |
+
eye_color: str
|
15 |
+
house: str
|
16 |
+
|
17 |
+
|
18 |
+
@function
|
19 |
+
def character_gen(s, name):
|
20 |
+
s += (
|
21 |
+
name
|
22 |
+
+ " is a character in Harry Potter. Please fill in the following information about this character.\n"
|
23 |
+
)
|
24 |
+
s += sgl.gen(
|
25 |
+
"json_output", max_tokens=256, regex=build_regex_from_object(Character)
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class TritonPythonModel:
|
30 |
+
def initialize(self, args):
|
31 |
+
print("Initialized.")
|
32 |
+
|
33 |
+
def execute(self, requests):
|
34 |
+
responses = []
|
35 |
+
for request in requests:
|
36 |
+
tensor_in = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
|
37 |
+
if tensor_in is None:
|
38 |
+
return pb_utils.InferenceResponse(output_tensors=[])
|
39 |
+
|
40 |
+
input_list_names = [
|
41 |
+
i.decode("utf-8") if isinstance(i, bytes) else i
|
42 |
+
for i in tensor_in.as_numpy().tolist()
|
43 |
+
]
|
44 |
+
|
45 |
+
input_list_dicts = [{"name": i} for i in input_list_names]
|
46 |
+
|
47 |
+
states = character_gen.run_batch(input_list_dicts)
|
48 |
+
character_strs = [state.text() for state in states]
|
49 |
+
|
50 |
+
tensor_out = pb_utils.Tensor(
|
51 |
+
"OUTPUT_TEXT", numpy.array(character_strs, dtype=object)
|
52 |
+
)
|
53 |
+
|
54 |
+
responses.append(pb_utils.InferenceResponse(output_tensors=[tensor_out]))
|
55 |
+
return responses
|
sglang/examples/frontend_language/usage/triton/models/character_generation/config.pbtxt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: "character_generation"
|
2 |
+
backend: "python"
|
3 |
+
input [
|
4 |
+
{
|
5 |
+
name: "INPUT_TEXT"
|
6 |
+
data_type: TYPE_STRING
|
7 |
+
dims: [ -1 ]
|
8 |
+
}
|
9 |
+
]
|
10 |
+
output [
|
11 |
+
{
|
12 |
+
name: "OUTPUT_TEXT"
|
13 |
+
data_type: TYPE_STRING
|
14 |
+
dims: [ -1 ]
|
15 |
+
}
|
16 |
+
]
|
17 |
+
instance_group [
|
18 |
+
{
|
19 |
+
count: 1
|
20 |
+
kind: KIND_GPU
|
21 |
+
gpus: [ 0 ]
|
22 |
+
}
|
23 |
+
]
|
sglang/examples/monitoring/grafana.json
ADDED
@@ -0,0 +1,1720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"annotations": {
|
3 |
+
"list": [
|
4 |
+
{
|
5 |
+
"builtIn": 1,
|
6 |
+
"datasource": {
|
7 |
+
"type": "grafana",
|
8 |
+
"uid": "-- Grafana --"
|
9 |
+
},
|
10 |
+
"enable": true,
|
11 |
+
"hide": true,
|
12 |
+
"iconColor": "rgba(0, 211, 255, 1)",
|
13 |
+
"name": "Annotations & Alerts",
|
14 |
+
"type": "dashboard"
|
15 |
+
}
|
16 |
+
]
|
17 |
+
},
|
18 |
+
"editable": true,
|
19 |
+
"fiscalYearStartMonth": 0,
|
20 |
+
"graphTooltip": 0,
|
21 |
+
"id": 1,
|
22 |
+
"links": [],
|
23 |
+
"panels": [
|
24 |
+
{
|
25 |
+
"datasource": {
|
26 |
+
"default": true,
|
27 |
+
"type": "prometheus",
|
28 |
+
"uid": "ee2vha8w6f5kwf"
|
29 |
+
},
|
30 |
+
"description": "max-running-requests from server argument",
|
31 |
+
"fieldConfig": {
|
32 |
+
"defaults": {
|
33 |
+
"color": {
|
34 |
+
"mode": "thresholds"
|
35 |
+
},
|
36 |
+
"mappings": [],
|
37 |
+
"thresholds": {
|
38 |
+
"mode": "absolute",
|
39 |
+
"steps": [
|
40 |
+
{
|
41 |
+
"color": "green",
|
42 |
+
"value": null
|
43 |
+
}
|
44 |
+
]
|
45 |
+
}
|
46 |
+
},
|
47 |
+
"overrides": []
|
48 |
+
},
|
49 |
+
"gridPos": {
|
50 |
+
"h": 3,
|
51 |
+
"w": 3,
|
52 |
+
"x": 0,
|
53 |
+
"y": 0
|
54 |
+
},
|
55 |
+
"id": 2,
|
56 |
+
"options": {
|
57 |
+
"colorMode": "value",
|
58 |
+
"graphMode": "none",
|
59 |
+
"justifyMode": "auto",
|
60 |
+
"orientation": "auto",
|
61 |
+
"percentChangeColorMode": "standard",
|
62 |
+
"reduceOptions": {
|
63 |
+
"calcs": [
|
64 |
+
"last"
|
65 |
+
],
|
66 |
+
"fields": "",
|
67 |
+
"values": false
|
68 |
+
},
|
69 |
+
"showPercentChange": false,
|
70 |
+
"textMode": "auto",
|
71 |
+
"wideLayout": true
|
72 |
+
},
|
73 |
+
"pluginVersion": "11.2.0",
|
74 |
+
"targets": [
|
75 |
+
{
|
76 |
+
"datasource": {
|
77 |
+
"type": "prometheus",
|
78 |
+
"uid": "ddyfngn31dg5cf"
|
79 |
+
},
|
80 |
+
"disableTextWrap": false,
|
81 |
+
"editorMode": "builder",
|
82 |
+
"expr": "sglang:max_running_requests{name=\"$name\", instance=\"$instance\"}",
|
83 |
+
"fullMetaSearch": false,
|
84 |
+
"includeNullMetadata": true,
|
85 |
+
"instant": false,
|
86 |
+
"legendFormat": "__auto",
|
87 |
+
"range": true,
|
88 |
+
"refId": "A",
|
89 |
+
"useBackend": false
|
90 |
+
}
|
91 |
+
],
|
92 |
+
"title": "Max Running Requests",
|
93 |
+
"type": "stat"
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"datasource": {
|
97 |
+
"default": true,
|
98 |
+
"type": "prometheus",
|
99 |
+
"uid": "ee2vha8w6f5kwf"
|
100 |
+
},
|
101 |
+
"description": "Supported context length with loaded model",
|
102 |
+
"fieldConfig": {
|
103 |
+
"defaults": {
|
104 |
+
"color": {
|
105 |
+
"mode": "thresholds"
|
106 |
+
},
|
107 |
+
"mappings": [],
|
108 |
+
"thresholds": {
|
109 |
+
"mode": "absolute",
|
110 |
+
"steps": [
|
111 |
+
{
|
112 |
+
"color": "green",
|
113 |
+
"value": null
|
114 |
+
}
|
115 |
+
]
|
116 |
+
}
|
117 |
+
},
|
118 |
+
"overrides": []
|
119 |
+
},
|
120 |
+
"gridPos": {
|
121 |
+
"h": 3,
|
122 |
+
"w": 3,
|
123 |
+
"x": 3,
|
124 |
+
"y": 0
|
125 |
+
},
|
126 |
+
"id": 1,
|
127 |
+
"options": {
|
128 |
+
"colorMode": "value",
|
129 |
+
"graphMode": "none",
|
130 |
+
"justifyMode": "auto",
|
131 |
+
"orientation": "auto",
|
132 |
+
"percentChangeColorMode": "standard",
|
133 |
+
"reduceOptions": {
|
134 |
+
"calcs": [
|
135 |
+
"last"
|
136 |
+
],
|
137 |
+
"fields": "",
|
138 |
+
"values": false
|
139 |
+
},
|
140 |
+
"showPercentChange": false,
|
141 |
+
"textMode": "auto",
|
142 |
+
"wideLayout": true
|
143 |
+
},
|
144 |
+
"pluginVersion": "11.2.0",
|
145 |
+
"targets": [
|
146 |
+
{
|
147 |
+
"datasource": {
|
148 |
+
"type": "prometheus",
|
149 |
+
"uid": "ddyfngn31dg5cf"
|
150 |
+
},
|
151 |
+
"disableTextWrap": false,
|
152 |
+
"editorMode": "builder",
|
153 |
+
"expr": "sglang:context_len{instance=\"$instance\", name=\"$name\"}",
|
154 |
+
"fullMetaSearch": false,
|
155 |
+
"includeNullMetadata": true,
|
156 |
+
"instant": false,
|
157 |
+
"legendFormat": "__auto",
|
158 |
+
"range": true,
|
159 |
+
"refId": "A",
|
160 |
+
"useBackend": false
|
161 |
+
}
|
162 |
+
],
|
163 |
+
"title": "Max Context Length",
|
164 |
+
"type": "stat"
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"datasource": {
|
168 |
+
"default": true,
|
169 |
+
"type": "prometheus",
|
170 |
+
"uid": "ee2vha8w6f5kwf"
|
171 |
+
},
|
172 |
+
"description": "max_total_tokens",
|
173 |
+
"fieldConfig": {
|
174 |
+
"defaults": {
|
175 |
+
"color": {
|
176 |
+
"mode": "thresholds"
|
177 |
+
},
|
178 |
+
"mappings": [],
|
179 |
+
"thresholds": {
|
180 |
+
"mode": "absolute",
|
181 |
+
"steps": [
|
182 |
+
{
|
183 |
+
"color": "green",
|
184 |
+
"value": null
|
185 |
+
}
|
186 |
+
]
|
187 |
+
}
|
188 |
+
},
|
189 |
+
"overrides": []
|
190 |
+
},
|
191 |
+
"gridPos": {
|
192 |
+
"h": 3,
|
193 |
+
"w": 3,
|
194 |
+
"x": 6,
|
195 |
+
"y": 0
|
196 |
+
},
|
197 |
+
"id": 4,
|
198 |
+
"options": {
|
199 |
+
"colorMode": "value",
|
200 |
+
"graphMode": "none",
|
201 |
+
"justifyMode": "auto",
|
202 |
+
"orientation": "auto",
|
203 |
+
"percentChangeColorMode": "standard",
|
204 |
+
"reduceOptions": {
|
205 |
+
"calcs": [
|
206 |
+
"last"
|
207 |
+
],
|
208 |
+
"fields": "",
|
209 |
+
"values": false
|
210 |
+
},
|
211 |
+
"showPercentChange": false,
|
212 |
+
"textMode": "auto",
|
213 |
+
"wideLayout": true
|
214 |
+
},
|
215 |
+
"pluginVersion": "11.2.0",
|
216 |
+
"targets": [
|
217 |
+
{
|
218 |
+
"datasource": {
|
219 |
+
"type": "prometheus",
|
220 |
+
"uid": "ddyfngn31dg5cf"
|
221 |
+
},
|
222 |
+
"disableTextWrap": false,
|
223 |
+
"editorMode": "builder",
|
224 |
+
"expr": "sglang:max_total_num_tokens{instance=\"$instance\", name=\"$name\"}",
|
225 |
+
"fullMetaSearch": false,
|
226 |
+
"includeNullMetadata": true,
|
227 |
+
"instant": false,
|
228 |
+
"legendFormat": "__auto",
|
229 |
+
"range": true,
|
230 |
+
"refId": "A",
|
231 |
+
"useBackend": false
|
232 |
+
}
|
233 |
+
],
|
234 |
+
"title": "Max Total Num Tokens",
|
235 |
+
"type": "stat"
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"datasource": {
|
239 |
+
"default": true,
|
240 |
+
"type": "prometheus",
|
241 |
+
"uid": "ee2vha8w6f5kwf"
|
242 |
+
},
|
243 |
+
"description": "max_prefill_tokens from server args",
|
244 |
+
"fieldConfig": {
|
245 |
+
"defaults": {
|
246 |
+
"color": {
|
247 |
+
"mode": "thresholds"
|
248 |
+
},
|
249 |
+
"mappings": [],
|
250 |
+
"thresholds": {
|
251 |
+
"mode": "absolute",
|
252 |
+
"steps": [
|
253 |
+
{
|
254 |
+
"color": "green",
|
255 |
+
"value": null
|
256 |
+
}
|
257 |
+
]
|
258 |
+
}
|
259 |
+
},
|
260 |
+
"overrides": []
|
261 |
+
},
|
262 |
+
"gridPos": {
|
263 |
+
"h": 3,
|
264 |
+
"w": 3,
|
265 |
+
"x": 9,
|
266 |
+
"y": 0
|
267 |
+
},
|
268 |
+
"id": 3,
|
269 |
+
"options": {
|
270 |
+
"colorMode": "value",
|
271 |
+
"graphMode": "none",
|
272 |
+
"justifyMode": "auto",
|
273 |
+
"orientation": "auto",
|
274 |
+
"percentChangeColorMode": "standard",
|
275 |
+
"reduceOptions": {
|
276 |
+
"calcs": [
|
277 |
+
"last"
|
278 |
+
],
|
279 |
+
"fields": "",
|
280 |
+
"values": false
|
281 |
+
},
|
282 |
+
"showPercentChange": false,
|
283 |
+
"textMode": "auto",
|
284 |
+
"wideLayout": true
|
285 |
+
},
|
286 |
+
"pluginVersion": "11.2.0",
|
287 |
+
"targets": [
|
288 |
+
{
|
289 |
+
"datasource": {
|
290 |
+
"type": "prometheus",
|
291 |
+
"uid": "ddyfngn31dg5cf"
|
292 |
+
},
|
293 |
+
"disableTextWrap": false,
|
294 |
+
"editorMode": "code",
|
295 |
+
"expr": "sglang:max_prefill_tokens{instance=\"$instance\", name=\"$name\"}",
|
296 |
+
"fullMetaSearch": false,
|
297 |
+
"includeNullMetadata": true,
|
298 |
+
"instant": false,
|
299 |
+
"legendFormat": "__auto",
|
300 |
+
"range": true,
|
301 |
+
"refId": "A",
|
302 |
+
"useBackend": false
|
303 |
+
}
|
304 |
+
],
|
305 |
+
"title": "Max Prefill Tokens",
|
306 |
+
"type": "stat"
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"datasource": {
|
310 |
+
"default": true,
|
311 |
+
"type": "prometheus",
|
312 |
+
"uid": "ee2vha8w6f5kwf"
|
313 |
+
},
|
314 |
+
"fieldConfig": {
|
315 |
+
"defaults": {
|
316 |
+
"color": {
|
317 |
+
"mode": "thresholds"
|
318 |
+
},
|
319 |
+
"mappings": [],
|
320 |
+
"thresholds": {
|
321 |
+
"mode": "absolute",
|
322 |
+
"steps": [
|
323 |
+
{
|
324 |
+
"color": "green",
|
325 |
+
"value": null
|
326 |
+
}
|
327 |
+
]
|
328 |
+
}
|
329 |
+
},
|
330 |
+
"overrides": []
|
331 |
+
},
|
332 |
+
"gridPos": {
|
333 |
+
"h": 3,
|
334 |
+
"w": 6,
|
335 |
+
"x": 12,
|
336 |
+
"y": 0
|
337 |
+
},
|
338 |
+
"id": 6,
|
339 |
+
"options": {
|
340 |
+
"colorMode": "value",
|
341 |
+
"graphMode": "area",
|
342 |
+
"justifyMode": "auto",
|
343 |
+
"orientation": "auto",
|
344 |
+
"percentChangeColorMode": "standard",
|
345 |
+
"reduceOptions": {
|
346 |
+
"calcs": [
|
347 |
+
"lastNotNull"
|
348 |
+
],
|
349 |
+
"fields": "",
|
350 |
+
"values": false
|
351 |
+
},
|
352 |
+
"showPercentChange": false,
|
353 |
+
"textMode": "auto",
|
354 |
+
"wideLayout": true
|
355 |
+
},
|
356 |
+
"pluginVersion": "11.2.0",
|
357 |
+
"targets": [
|
358 |
+
{
|
359 |
+
"datasource": {
|
360 |
+
"type": "prometheus",
|
361 |
+
"uid": "ddyfngn31dg5cf"
|
362 |
+
},
|
363 |
+
"disableTextWrap": false,
|
364 |
+
"editorMode": "code",
|
365 |
+
"expr": "sglang:cached_token{instance=\"$instance\", name=\"$name\"}",
|
366 |
+
"fullMetaSearch": false,
|
367 |
+
"includeNullMetadata": true,
|
368 |
+
"instant": false,
|
369 |
+
"legendFormat": "{{__name__}}",
|
370 |
+
"range": true,
|
371 |
+
"refId": "A",
|
372 |
+
"useBackend": false
|
373 |
+
}
|
374 |
+
],
|
375 |
+
"title": "Cached Tokens",
|
376 |
+
"type": "stat"
|
377 |
+
},
|
378 |
+
{
|
379 |
+
"datasource": {
|
380 |
+
"default": true,
|
381 |
+
"type": "prometheus",
|
382 |
+
"uid": "ee2vha8w6f5kwf"
|
383 |
+
},
|
384 |
+
"description": "",
|
385 |
+
"fieldConfig": {
|
386 |
+
"defaults": {
|
387 |
+
"color": {
|
388 |
+
"mode": "thresholds"
|
389 |
+
},
|
390 |
+
"mappings": [],
|
391 |
+
"thresholds": {
|
392 |
+
"mode": "absolute",
|
393 |
+
"steps": [
|
394 |
+
{
|
395 |
+
"color": "green",
|
396 |
+
"value": null
|
397 |
+
}
|
398 |
+
]
|
399 |
+
}
|
400 |
+
},
|
401 |
+
"overrides": []
|
402 |
+
},
|
403 |
+
"gridPos": {
|
404 |
+
"h": 3,
|
405 |
+
"w": 6,
|
406 |
+
"x": 18,
|
407 |
+
"y": 0
|
408 |
+
},
|
409 |
+
"id": 5,
|
410 |
+
"options": {
|
411 |
+
"colorMode": "value",
|
412 |
+
"graphMode": "area",
|
413 |
+
"justifyMode": "auto",
|
414 |
+
"orientation": "auto",
|
415 |
+
"percentChangeColorMode": "standard",
|
416 |
+
"reduceOptions": {
|
417 |
+
"calcs": [
|
418 |
+
"lastNotNull"
|
419 |
+
],
|
420 |
+
"fields": "",
|
421 |
+
"values": false
|
422 |
+
},
|
423 |
+
"showPercentChange": false,
|
424 |
+
"textMode": "auto",
|
425 |
+
"wideLayout": true
|
426 |
+
},
|
427 |
+
"pluginVersion": "11.2.0",
|
428 |
+
"targets": [
|
429 |
+
{
|
430 |
+
"datasource": {
|
431 |
+
"type": "prometheus",
|
432 |
+
"uid": "ddyfngn31dg5cf"
|
433 |
+
},
|
434 |
+
"disableTextWrap": false,
|
435 |
+
"editorMode": "code",
|
436 |
+
"expr": "sglang:cache_hit_rate{instance=\"$instance\", name=\"$name\"}",
|
437 |
+
"fullMetaSearch": false,
|
438 |
+
"includeNullMetadata": true,
|
439 |
+
"instant": false,
|
440 |
+
"legendFormat": "{{__name__}}",
|
441 |
+
"range": true,
|
442 |
+
"refId": "A",
|
443 |
+
"useBackend": false
|
444 |
+
}
|
445 |
+
],
|
446 |
+
"title": "Cache Hit Rate (%)",
|
447 |
+
"type": "stat"
|
448 |
+
},
|
449 |
+
{
|
450 |
+
"datasource": {
|
451 |
+
"default": true,
|
452 |
+
"type": "prometheus",
|
453 |
+
"uid": "ee2vha8w6f5kwf"
|
454 |
+
},
|
455 |
+
"fieldConfig": {
|
456 |
+
"defaults": {
|
457 |
+
"color": {
|
458 |
+
"mode": "palette-classic"
|
459 |
+
},
|
460 |
+
"custom": {
|
461 |
+
"axisBorderShow": false,
|
462 |
+
"axisCenteredZero": false,
|
463 |
+
"axisColorMode": "text",
|
464 |
+
"axisLabel": "",
|
465 |
+
"axisPlacement": "auto",
|
466 |
+
"barAlignment": 0,
|
467 |
+
"barWidthFactor": 0.6,
|
468 |
+
"drawStyle": "line",
|
469 |
+
"fillOpacity": 0,
|
470 |
+
"gradientMode": "none",
|
471 |
+
"hideFrom": {
|
472 |
+
"legend": false,
|
473 |
+
"tooltip": false,
|
474 |
+
"viz": false
|
475 |
+
},
|
476 |
+
"insertNulls": false,
|
477 |
+
"lineInterpolation": "linear",
|
478 |
+
"lineWidth": 1,
|
479 |
+
"pointSize": 5,
|
480 |
+
"scaleDistribution": {
|
481 |
+
"type": "linear"
|
482 |
+
},
|
483 |
+
"showPoints": "auto",
|
484 |
+
"spanNulls": false,
|
485 |
+
"stacking": {
|
486 |
+
"group": "A",
|
487 |
+
"mode": "none"
|
488 |
+
},
|
489 |
+
"thresholdsStyle": {
|
490 |
+
"mode": "off"
|
491 |
+
}
|
492 |
+
},
|
493 |
+
"mappings": [],
|
494 |
+
"thresholds": {
|
495 |
+
"mode": "absolute",
|
496 |
+
"steps": [
|
497 |
+
{
|
498 |
+
"color": "green",
|
499 |
+
"value": null
|
500 |
+
},
|
501 |
+
{
|
502 |
+
"color": "red",
|
503 |
+
"value": 80
|
504 |
+
}
|
505 |
+
]
|
506 |
+
}
|
507 |
+
},
|
508 |
+
"overrides": []
|
509 |
+
},
|
510 |
+
"gridPos": {
|
511 |
+
"h": 8,
|
512 |
+
"w": 12,
|
513 |
+
"x": 0,
|
514 |
+
"y": 3
|
515 |
+
},
|
516 |
+
"id": 14,
|
517 |
+
"options": {
|
518 |
+
"legend": {
|
519 |
+
"calcs": [],
|
520 |
+
"displayMode": "list",
|
521 |
+
"placement": "bottom",
|
522 |
+
"showLegend": true
|
523 |
+
},
|
524 |
+
"tooltip": {
|
525 |
+
"mode": "single",
|
526 |
+
"sort": "none"
|
527 |
+
}
|
528 |
+
},
|
529 |
+
"targets": [
|
530 |
+
{
|
531 |
+
"datasource": {
|
532 |
+
"type": "prometheus",
|
533 |
+
"uid": "ddyfngn31dg5cf"
|
534 |
+
},
|
535 |
+
"disableTextWrap": false,
|
536 |
+
"editorMode": "code",
|
537 |
+
"expr": "histogram_quantile(0.99, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$name\"}[$__rate_interval])))",
|
538 |
+
"fullMetaSearch": false,
|
539 |
+
"includeNullMetadata": true,
|
540 |
+
"instant": false,
|
541 |
+
"legendFormat": "P99",
|
542 |
+
"range": true,
|
543 |
+
"refId": "A",
|
544 |
+
"useBackend": false
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"datasource": {
|
548 |
+
"type": "prometheus",
|
549 |
+
"uid": "ddyfngn31dg5cf"
|
550 |
+
},
|
551 |
+
"disableTextWrap": false,
|
552 |
+
"editorMode": "code",
|
553 |
+
"expr": "histogram_quantile(0.9, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$name\"}[$__rate_interval])))",
|
554 |
+
"fullMetaSearch": false,
|
555 |
+
"hide": false,
|
556 |
+
"includeNullMetadata": true,
|
557 |
+
"instant": false,
|
558 |
+
"legendFormat": "P90",
|
559 |
+
"range": true,
|
560 |
+
"refId": "B",
|
561 |
+
"useBackend": false
|
562 |
+
},
|
563 |
+
{
|
564 |
+
"datasource": {
|
565 |
+
"type": "prometheus",
|
566 |
+
"uid": "ddyfngn31dg5cf"
|
567 |
+
},
|
568 |
+
"disableTextWrap": false,
|
569 |
+
"editorMode": "builder",
|
570 |
+
"expr": "histogram_quantile(0.95, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$model_name\"}[$__rate_interval])))",
|
571 |
+
"fullMetaSearch": false,
|
572 |
+
"hide": false,
|
573 |
+
"includeNullMetadata": true,
|
574 |
+
"instant": false,
|
575 |
+
"legendFormat": "P95",
|
576 |
+
"range": true,
|
577 |
+
"refId": "C",
|
578 |
+
"useBackend": false
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"datasource": {
|
582 |
+
"type": "prometheus",
|
583 |
+
"uid": "ddyfngn31dg5cf"
|
584 |
+
},
|
585 |
+
"disableTextWrap": false,
|
586 |
+
"editorMode": "builder",
|
587 |
+
"expr": "histogram_quantile(0.5, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$model_name\"}[$__rate_interval])))",
|
588 |
+
"fullMetaSearch": false,
|
589 |
+
"hide": false,
|
590 |
+
"includeNullMetadata": true,
|
591 |
+
"instant": false,
|
592 |
+
"legendFormat": "P50",
|
593 |
+
"range": true,
|
594 |
+
"refId": "D",
|
595 |
+
"useBackend": false
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"datasource": {
|
599 |
+
"type": "prometheus",
|
600 |
+
"uid": "ddyfngn31dg5cf"
|
601 |
+
},
|
602 |
+
"disableTextWrap": false,
|
603 |
+
"editorMode": "builder",
|
604 |
+
"expr": "rate(sglang:e2e_request_latency_seconds_sum{instance=\"$instance\", name=\"$model_name\"}[$__rate_interval]) / rate(sglang:e2e_request_latency_seconds_count[$__rate_interval])",
|
605 |
+
"fullMetaSearch": false,
|
606 |
+
"hide": false,
|
607 |
+
"includeNullMetadata": true,
|
608 |
+
"instant": false,
|
609 |
+
"legendFormat": "Average",
|
610 |
+
"range": true,
|
611 |
+
"refId": "E",
|
612 |
+
"useBackend": false
|
613 |
+
}
|
614 |
+
],
|
615 |
+
"title": "E2E Request Latency (S)",
|
616 |
+
"type": "timeseries"
|
617 |
+
},
|
618 |
+
{
|
619 |
+
"datasource": {
|
620 |
+
"default": true,
|
621 |
+
"type": "prometheus",
|
622 |
+
"uid": "ee2vha8w6f5kwf"
|
623 |
+
},
|
624 |
+
"fieldConfig": {
|
625 |
+
"defaults": {
|
626 |
+
"color": {
|
627 |
+
"mode": "palette-classic"
|
628 |
+
},
|
629 |
+
"custom": {
|
630 |
+
"axisBorderShow": false,
|
631 |
+
"axisCenteredZero": false,
|
632 |
+
"axisColorMode": "text",
|
633 |
+
"axisLabel": "",
|
634 |
+
"axisPlacement": "auto",
|
635 |
+
"barAlignment": 0,
|
636 |
+
"barWidthFactor": 0.6,
|
637 |
+
"drawStyle": "line",
|
638 |
+
"fillOpacity": 0,
|
639 |
+
"gradientMode": "none",
|
640 |
+
"hideFrom": {
|
641 |
+
"legend": false,
|
642 |
+
"tooltip": false,
|
643 |
+
"viz": false
|
644 |
+
},
|
645 |
+
"insertNulls": false,
|
646 |
+
"lineInterpolation": "linear",
|
647 |
+
"lineWidth": 1,
|
648 |
+
"pointSize": 5,
|
649 |
+
"scaleDistribution": {
|
650 |
+
"type": "linear"
|
651 |
+
},
|
652 |
+
"showPoints": "auto",
|
653 |
+
"spanNulls": false,
|
654 |
+
"stacking": {
|
655 |
+
"group": "A",
|
656 |
+
"mode": "none"
|
657 |
+
},
|
658 |
+
"thresholdsStyle": {
|
659 |
+
"mode": "off"
|
660 |
+
}
|
661 |
+
},
|
662 |
+
"mappings": [],
|
663 |
+
"thresholds": {
|
664 |
+
"mode": "absolute",
|
665 |
+
"steps": [
|
666 |
+
{
|
667 |
+
"color": "green",
|
668 |
+
"value": null
|
669 |
+
},
|
670 |
+
{
|
671 |
+
"color": "red",
|
672 |
+
"value": 80
|
673 |
+
}
|
674 |
+
]
|
675 |
+
}
|
676 |
+
},
|
677 |
+
"overrides": []
|
678 |
+
},
|
679 |
+
"gridPos": {
|
680 |
+
"h": 8,
|
681 |
+
"w": 12,
|
682 |
+
"x": 12,
|
683 |
+
"y": 3
|
684 |
+
},
|
685 |
+
"id": 18,
|
686 |
+
"options": {
|
687 |
+
"legend": {
|
688 |
+
"calcs": [],
|
689 |
+
"displayMode": "list",
|
690 |
+
"placement": "bottom",
|
691 |
+
"showLegend": true
|
692 |
+
},
|
693 |
+
"tooltip": {
|
694 |
+
"mode": "single",
|
695 |
+
"sort": "none"
|
696 |
+
}
|
697 |
+
},
|
698 |
+
"targets": [
|
699 |
+
{
|
700 |
+
"datasource": {
|
701 |
+
"type": "prometheus",
|
702 |
+
"uid": "ddyfngn31dg5cf"
|
703 |
+
},
|
704 |
+
"editorMode": "code",
|
705 |
+
"expr": "sglang:gen_throughput{instance=\"$instance\", name=\"$name\"}",
|
706 |
+
"instant": false,
|
707 |
+
"legendFormat": "__auto",
|
708 |
+
"range": true,
|
709 |
+
"refId": "A"
|
710 |
+
}
|
711 |
+
],
|
712 |
+
"title": "Generation Throughput (Token / S)",
|
713 |
+
"type": "timeseries"
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"datasource": {
|
717 |
+
"default": true,
|
718 |
+
"type": "prometheus",
|
719 |
+
"uid": "ee2vha8w6f5kwf"
|
720 |
+
},
|
721 |
+
"fieldConfig": {
|
722 |
+
"defaults": {
|
723 |
+
"color": {
|
724 |
+
"mode": "palette-classic"
|
725 |
+
},
|
726 |
+
"custom": {
|
727 |
+
"axisBorderShow": false,
|
728 |
+
"axisCenteredZero": false,
|
729 |
+
"axisColorMode": "text",
|
730 |
+
"axisLabel": "",
|
731 |
+
"axisPlacement": "auto",
|
732 |
+
"barAlignment": 0,
|
733 |
+
"barWidthFactor": 0.6,
|
734 |
+
"drawStyle": "line",
|
735 |
+
"fillOpacity": 0,
|
736 |
+
"gradientMode": "none",
|
737 |
+
"hideFrom": {
|
738 |
+
"legend": false,
|
739 |
+
"tooltip": false,
|
740 |
+
"viz": false
|
741 |
+
},
|
742 |
+
"insertNulls": false,
|
743 |
+
"lineInterpolation": "linear",
|
744 |
+
"lineWidth": 1,
|
745 |
+
"pointSize": 5,
|
746 |
+
"scaleDistribution": {
|
747 |
+
"type": "linear"
|
748 |
+
},
|
749 |
+
"showPoints": "auto",
|
750 |
+
"spanNulls": false,
|
751 |
+
"stacking": {
|
752 |
+
"group": "A",
|
753 |
+
"mode": "none"
|
754 |
+
},
|
755 |
+
"thresholdsStyle": {
|
756 |
+
"mode": "off"
|
757 |
+
}
|
758 |
+
},
|
759 |
+
"mappings": [],
|
760 |
+
"thresholds": {
|
761 |
+
"mode": "absolute",
|
762 |
+
"steps": [
|
763 |
+
{
|
764 |
+
"color": "green",
|
765 |
+
"value": null
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"color": "red",
|
769 |
+
"value": 80
|
770 |
+
}
|
771 |
+
]
|
772 |
+
}
|
773 |
+
},
|
774 |
+
"overrides": []
|
775 |
+
},
|
776 |
+
"gridPos": {
|
777 |
+
"h": 8,
|
778 |
+
"w": 12,
|
779 |
+
"x": 0,
|
780 |
+
"y": 11
|
781 |
+
},
|
782 |
+
"id": 7,
|
783 |
+
"options": {
|
784 |
+
"legend": {
|
785 |
+
"calcs": [],
|
786 |
+
"displayMode": "list",
|
787 |
+
"placement": "bottom",
|
788 |
+
"showLegend": true
|
789 |
+
},
|
790 |
+
"tooltip": {
|
791 |
+
"mode": "single",
|
792 |
+
"sort": "none"
|
793 |
+
}
|
794 |
+
},
|
795 |
+
"targets": [
|
796 |
+
{
|
797 |
+
"datasource": {
|
798 |
+
"type": "prometheus",
|
799 |
+
"uid": "ddyfngn31dg5cf"
|
800 |
+
},
|
801 |
+
"disableTextWrap": false,
|
802 |
+
"editorMode": "code",
|
803 |
+
"expr": "sglang:num_requests_running{instance=\"$instance\", name=\"$name\"}",
|
804 |
+
"fullMetaSearch": false,
|
805 |
+
"includeNullMetadata": true,
|
806 |
+
"instant": false,
|
807 |
+
"legendFormat": "{{__name__}}",
|
808 |
+
"range": true,
|
809 |
+
"refId": "A",
|
810 |
+
"useBackend": false
|
811 |
+
}
|
812 |
+
],
|
813 |
+
"title": "Num Requests Running",
|
814 |
+
"type": "timeseries"
|
815 |
+
},
|
816 |
+
{
|
817 |
+
"datasource": {
|
818 |
+
"default": true,
|
819 |
+
"type": "prometheus",
|
820 |
+
"uid": "ee2vha8w6f5kwf"
|
821 |
+
},
|
822 |
+
"fieldConfig": {
|
823 |
+
"defaults": {
|
824 |
+
"color": {
|
825 |
+
"mode": "palette-classic"
|
826 |
+
},
|
827 |
+
"custom": {
|
828 |
+
"axisBorderShow": false,
|
829 |
+
"axisCenteredZero": false,
|
830 |
+
"axisColorMode": "text",
|
831 |
+
"axisLabel": "",
|
832 |
+
"axisPlacement": "auto",
|
833 |
+
"barAlignment": 0,
|
834 |
+
"barWidthFactor": 0.6,
|
835 |
+
"drawStyle": "line",
|
836 |
+
"fillOpacity": 0,
|
837 |
+
"gradientMode": "none",
|
838 |
+
"hideFrom": {
|
839 |
+
"legend": false,
|
840 |
+
"tooltip": false,
|
841 |
+
"viz": false
|
842 |
+
},
|
843 |
+
"insertNulls": false,
|
844 |
+
"lineInterpolation": "linear",
|
845 |
+
"lineWidth": 1,
|
846 |
+
"pointSize": 5,
|
847 |
+
"scaleDistribution": {
|
848 |
+
"type": "linear"
|
849 |
+
},
|
850 |
+
"showPoints": "auto",
|
851 |
+
"spanNulls": false,
|
852 |
+
"stacking": {
|
853 |
+
"group": "A",
|
854 |
+
"mode": "none"
|
855 |
+
},
|
856 |
+
"thresholdsStyle": {
|
857 |
+
"mode": "off"
|
858 |
+
}
|
859 |
+
},
|
860 |
+
"mappings": [],
|
861 |
+
"thresholds": {
|
862 |
+
"mode": "absolute",
|
863 |
+
"steps": [
|
864 |
+
{
|
865 |
+
"color": "green",
|
866 |
+
"value": null
|
867 |
+
},
|
868 |
+
{
|
869 |
+
"color": "red",
|
870 |
+
"value": 80
|
871 |
+
}
|
872 |
+
]
|
873 |
+
}
|
874 |
+
},
|
875 |
+
"overrides": []
|
876 |
+
},
|
877 |
+
"gridPos": {
|
878 |
+
"h": 8,
|
879 |
+
"w": 12,
|
880 |
+
"x": 12,
|
881 |
+
"y": 11
|
882 |
+
},
|
883 |
+
"id": 8,
|
884 |
+
"options": {
|
885 |
+
"legend": {
|
886 |
+
"calcs": [],
|
887 |
+
"displayMode": "list",
|
888 |
+
"placement": "bottom",
|
889 |
+
"showLegend": true
|
890 |
+
},
|
891 |
+
"tooltip": {
|
892 |
+
"mode": "single",
|
893 |
+
"sort": "none"
|
894 |
+
}
|
895 |
+
},
|
896 |
+
"targets": [
|
897 |
+
{
|
898 |
+
"datasource": {
|
899 |
+
"type": "prometheus",
|
900 |
+
"uid": "ddyfngn31dg5cf"
|
901 |
+
},
|
902 |
+
"disableTextWrap": false,
|
903 |
+
"editorMode": "code",
|
904 |
+
"expr": "sglang:num_requests_waiting{instance=\"$instance\", name=\"$name\"}",
|
905 |
+
"fullMetaSearch": false,
|
906 |
+
"includeNullMetadata": true,
|
907 |
+
"instant": false,
|
908 |
+
"legendFormat": "{{__name__}}",
|
909 |
+
"range": true,
|
910 |
+
"refId": "A",
|
911 |
+
"useBackend": false
|
912 |
+
}
|
913 |
+
],
|
914 |
+
"title": "Number of Requests Waiting",
|
915 |
+
"type": "timeseries"
|
916 |
+
},
|
917 |
+
{
|
918 |
+
"datasource": {
|
919 |
+
"default": true,
|
920 |
+
"type": "prometheus",
|
921 |
+
"uid": "ee2vha8w6f5kwf"
|
922 |
+
},
|
923 |
+
"fieldConfig": {
|
924 |
+
"defaults": {
|
925 |
+
"color": {
|
926 |
+
"mode": "palette-classic"
|
927 |
+
},
|
928 |
+
"custom": {
|
929 |
+
"axisBorderShow": false,
|
930 |
+
"axisCenteredZero": false,
|
931 |
+
"axisColorMode": "text",
|
932 |
+
"axisLabel": "",
|
933 |
+
"axisPlacement": "auto",
|
934 |
+
"barAlignment": 0,
|
935 |
+
"barWidthFactor": 0.6,
|
936 |
+
"drawStyle": "line",
|
937 |
+
"fillOpacity": 0,
|
938 |
+
"gradientMode": "none",
|
939 |
+
"hideFrom": {
|
940 |
+
"legend": false,
|
941 |
+
"tooltip": false,
|
942 |
+
"viz": false
|
943 |
+
},
|
944 |
+
"insertNulls": false,
|
945 |
+
"lineInterpolation": "linear",
|
946 |
+
"lineWidth": 1,
|
947 |
+
"pointSize": 5,
|
948 |
+
"scaleDistribution": {
|
949 |
+
"type": "linear"
|
950 |
+
},
|
951 |
+
"showPoints": "auto",
|
952 |
+
"spanNulls": false,
|
953 |
+
"stacking": {
|
954 |
+
"group": "A",
|
955 |
+
"mode": "none"
|
956 |
+
},
|
957 |
+
"thresholdsStyle": {
|
958 |
+
"mode": "off"
|
959 |
+
}
|
960 |
+
},
|
961 |
+
"mappings": [],
|
962 |
+
"thresholds": {
|
963 |
+
"mode": "absolute",
|
964 |
+
"steps": [
|
965 |
+
{
|
966 |
+
"color": "green",
|
967 |
+
"value": null
|
968 |
+
},
|
969 |
+
{
|
970 |
+
"color": "red",
|
971 |
+
"value": 80
|
972 |
+
}
|
973 |
+
]
|
974 |
+
}
|
975 |
+
},
|
976 |
+
"overrides": []
|
977 |
+
},
|
978 |
+
"gridPos": {
|
979 |
+
"h": 8,
|
980 |
+
"w": 12,
|
981 |
+
"x": 0,
|
982 |
+
"y": 19
|
983 |
+
},
|
984 |
+
"id": 16,
|
985 |
+
"options": {
|
986 |
+
"legend": {
|
987 |
+
"calcs": [],
|
988 |
+
"displayMode": "list",
|
989 |
+
"placement": "bottom",
|
990 |
+
"showLegend": true
|
991 |
+
},
|
992 |
+
"tooltip": {
|
993 |
+
"mode": "single",
|
994 |
+
"sort": "none"
|
995 |
+
}
|
996 |
+
},
|
997 |
+
"targets": [
|
998 |
+
{
|
999 |
+
"datasource": {
|
1000 |
+
"type": "prometheus",
|
1001 |
+
"uid": "ddyfngn31dg5cf"
|
1002 |
+
},
|
1003 |
+
"disableTextWrap": false,
|
1004 |
+
"editorMode": "code",
|
1005 |
+
"expr": "histogram_quantile(0.99, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1006 |
+
"fullMetaSearch": false,
|
1007 |
+
"includeNullMetadata": true,
|
1008 |
+
"instant": false,
|
1009 |
+
"legendFormat": "P99",
|
1010 |
+
"range": true,
|
1011 |
+
"refId": "A",
|
1012 |
+
"useBackend": false
|
1013 |
+
},
|
1014 |
+
{
|
1015 |
+
"datasource": {
|
1016 |
+
"type": "prometheus",
|
1017 |
+
"uid": "ddyfngn31dg5cf"
|
1018 |
+
},
|
1019 |
+
"disableTextWrap": false,
|
1020 |
+
"editorMode": "code",
|
1021 |
+
"expr": "histogram_quantile(0.9, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1022 |
+
"fullMetaSearch": false,
|
1023 |
+
"hide": false,
|
1024 |
+
"includeNullMetadata": true,
|
1025 |
+
"instant": false,
|
1026 |
+
"legendFormat": "P90",
|
1027 |
+
"range": true,
|
1028 |
+
"refId": "B",
|
1029 |
+
"useBackend": false
|
1030 |
+
},
|
1031 |
+
{
|
1032 |
+
"datasource": {
|
1033 |
+
"type": "prometheus",
|
1034 |
+
"uid": "ddyfngn31dg5cf"
|
1035 |
+
},
|
1036 |
+
"disableTextWrap": false,
|
1037 |
+
"editorMode": "code",
|
1038 |
+
"expr": "histogram_quantile(0.95, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1039 |
+
"fullMetaSearch": false,
|
1040 |
+
"hide": false,
|
1041 |
+
"includeNullMetadata": true,
|
1042 |
+
"instant": false,
|
1043 |
+
"legendFormat": "P95",
|
1044 |
+
"range": true,
|
1045 |
+
"refId": "C",
|
1046 |
+
"useBackend": false
|
1047 |
+
},
|
1048 |
+
{
|
1049 |
+
"datasource": {
|
1050 |
+
"type": "prometheus",
|
1051 |
+
"uid": "ddyfngn31dg5cf"
|
1052 |
+
},
|
1053 |
+
"disableTextWrap": false,
|
1054 |
+
"editorMode": "code",
|
1055 |
+
"expr": "histogram_quantile(0.5, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1056 |
+
"fullMetaSearch": false,
|
1057 |
+
"hide": false,
|
1058 |
+
"includeNullMetadata": true,
|
1059 |
+
"instant": false,
|
1060 |
+
"legendFormat": "P50",
|
1061 |
+
"range": true,
|
1062 |
+
"refId": "D",
|
1063 |
+
"useBackend": false
|
1064 |
+
},
|
1065 |
+
{
|
1066 |
+
"datasource": {
|
1067 |
+
"type": "prometheus",
|
1068 |
+
"uid": "ddyfngn31dg5cf"
|
1069 |
+
},
|
1070 |
+
"disableTextWrap": false,
|
1071 |
+
"editorMode": "code",
|
1072 |
+
"expr": "rate(sglang:e2e_request_latency_seconds_sum{name=\"$name\"}[$__rate_interval]) / rate(sglang:e2e_request_latency_seconds_count{name=\"$name\"}[$__rate_interval])",
|
1073 |
+
"fullMetaSearch": false,
|
1074 |
+
"hide": false,
|
1075 |
+
"includeNullMetadata": true,
|
1076 |
+
"instant": false,
|
1077 |
+
"legendFormat": "Average",
|
1078 |
+
"range": true,
|
1079 |
+
"refId": "E",
|
1080 |
+
"useBackend": false
|
1081 |
+
}
|
1082 |
+
],
|
1083 |
+
"title": "Time Request Decoding (S)",
|
1084 |
+
"type": "timeseries"
|
1085 |
+
},
|
1086 |
+
{
|
1087 |
+
"datasource": {
|
1088 |
+
"default": true,
|
1089 |
+
"type": "prometheus",
|
1090 |
+
"uid": "ee2vha8w6f5kwf"
|
1091 |
+
},
|
1092 |
+
"description": "Time requests waiting before added to batch",
|
1093 |
+
"fieldConfig": {
|
1094 |
+
"defaults": {
|
1095 |
+
"color": {
|
1096 |
+
"mode": "palette-classic"
|
1097 |
+
},
|
1098 |
+
"custom": {
|
1099 |
+
"axisBorderShow": false,
|
1100 |
+
"axisCenteredZero": false,
|
1101 |
+
"axisColorMode": "text",
|
1102 |
+
"axisLabel": "",
|
1103 |
+
"axisPlacement": "auto",
|
1104 |
+
"barAlignment": 0,
|
1105 |
+
"barWidthFactor": 0.6,
|
1106 |
+
"drawStyle": "line",
|
1107 |
+
"fillOpacity": 0,
|
1108 |
+
"gradientMode": "none",
|
1109 |
+
"hideFrom": {
|
1110 |
+
"legend": false,
|
1111 |
+
"tooltip": false,
|
1112 |
+
"viz": false
|
1113 |
+
},
|
1114 |
+
"insertNulls": false,
|
1115 |
+
"lineInterpolation": "linear",
|
1116 |
+
"lineWidth": 1,
|
1117 |
+
"pointSize": 5,
|
1118 |
+
"scaleDistribution": {
|
1119 |
+
"type": "linear"
|
1120 |
+
},
|
1121 |
+
"showPoints": "auto",
|
1122 |
+
"spanNulls": false,
|
1123 |
+
"stacking": {
|
1124 |
+
"group": "A",
|
1125 |
+
"mode": "none"
|
1126 |
+
},
|
1127 |
+
"thresholdsStyle": {
|
1128 |
+
"mode": "off"
|
1129 |
+
}
|
1130 |
+
},
|
1131 |
+
"mappings": [],
|
1132 |
+
"thresholds": {
|
1133 |
+
"mode": "absolute",
|
1134 |
+
"steps": [
|
1135 |
+
{
|
1136 |
+
"color": "green",
|
1137 |
+
"value": null
|
1138 |
+
},
|
1139 |
+
{
|
1140 |
+
"color": "red",
|
1141 |
+
"value": 80
|
1142 |
+
}
|
1143 |
+
]
|
1144 |
+
}
|
1145 |
+
},
|
1146 |
+
"overrides": []
|
1147 |
+
},
|
1148 |
+
"gridPos": {
|
1149 |
+
"h": 8,
|
1150 |
+
"w": 12,
|
1151 |
+
"x": 12,
|
1152 |
+
"y": 19
|
1153 |
+
},
|
1154 |
+
"id": 15,
|
1155 |
+
"options": {
|
1156 |
+
"legend": {
|
1157 |
+
"calcs": [],
|
1158 |
+
"displayMode": "list",
|
1159 |
+
"placement": "bottom",
|
1160 |
+
"showLegend": true
|
1161 |
+
},
|
1162 |
+
"tooltip": {
|
1163 |
+
"mode": "single",
|
1164 |
+
"sort": "none"
|
1165 |
+
}
|
1166 |
+
},
|
1167 |
+
"targets": [
|
1168 |
+
{
|
1169 |
+
"datasource": {
|
1170 |
+
"type": "prometheus",
|
1171 |
+
"uid": "ddyfngn31dg5cf"
|
1172 |
+
},
|
1173 |
+
"editorMode": "code",
|
1174 |
+
"expr": "histogram_quantile(0.99, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1175 |
+
"instant": false,
|
1176 |
+
"legendFormat": "P99",
|
1177 |
+
"range": true,
|
1178 |
+
"refId": "A"
|
1179 |
+
},
|
1180 |
+
{
|
1181 |
+
"datasource": {
|
1182 |
+
"type": "prometheus",
|
1183 |
+
"uid": "ddyfngn31dg5cf"
|
1184 |
+
},
|
1185 |
+
"editorMode": "code",
|
1186 |
+
"expr": "histogram_quantile(0.95, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1187 |
+
"hide": false,
|
1188 |
+
"instant": false,
|
1189 |
+
"legendFormat": "P95",
|
1190 |
+
"range": true,
|
1191 |
+
"refId": "B"
|
1192 |
+
},
|
1193 |
+
{
|
1194 |
+
"datasource": {
|
1195 |
+
"type": "prometheus",
|
1196 |
+
"uid": "ddyfngn31dg5cf"
|
1197 |
+
},
|
1198 |
+
"editorMode": "code",
|
1199 |
+
"expr": "histogram_quantile(0.9, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1200 |
+
"hide": false,
|
1201 |
+
"instant": false,
|
1202 |
+
"legendFormat": "P90",
|
1203 |
+
"range": true,
|
1204 |
+
"refId": "C"
|
1205 |
+
},
|
1206 |
+
{
|
1207 |
+
"datasource": {
|
1208 |
+
"type": "prometheus",
|
1209 |
+
"uid": "ddyfngn31dg5cf"
|
1210 |
+
},
|
1211 |
+
"editorMode": "code",
|
1212 |
+
"expr": "histogram_quantile(0.5, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
|
1213 |
+
"hide": false,
|
1214 |
+
"instant": false,
|
1215 |
+
"legendFormat": "P50",
|
1216 |
+
"range": true,
|
1217 |
+
"refId": "D"
|
1218 |
+
},
|
1219 |
+
{
|
1220 |
+
"datasource": {
|
1221 |
+
"type": "prometheus",
|
1222 |
+
"uid": "ddyfngn31dg5cf"
|
1223 |
+
},
|
1224 |
+
"editorMode": "code",
|
1225 |
+
"expr": "rate(sglang:waiting_request_latency_seconds_sum{name=\"$name\"}[$__rate_interval])\r\n/\r\nrate(sglang:waiting_request_latency_seconds_count{name=\"$name\"}[$__rate_interval])",
|
1226 |
+
"hide": false,
|
1227 |
+
"instant": false,
|
1228 |
+
"legendFormat": "Average",
|
1229 |
+
"range": true,
|
1230 |
+
"refId": "E"
|
1231 |
+
}
|
1232 |
+
],
|
1233 |
+
"title": "Time Request Waiting (S)",
|
1234 |
+
"type": "timeseries"
|
1235 |
+
},
|
1236 |
+
{
|
1237 |
+
"datasource": {
|
1238 |
+
"default": true,
|
1239 |
+
"type": "prometheus",
|
1240 |
+
"uid": "ee2vha8w6f5kwf"
|
1241 |
+
},
|
1242 |
+
"fieldConfig": {
|
1243 |
+
"defaults": {
|
1244 |
+
"color": {
|
1245 |
+
"mode": "palette-classic"
|
1246 |
+
},
|
1247 |
+
"custom": {
|
1248 |
+
"axisBorderShow": false,
|
1249 |
+
"axisCenteredZero": false,
|
1250 |
+
"axisColorMode": "text",
|
1251 |
+
"axisLabel": "",
|
1252 |
+
"axisPlacement": "auto",
|
1253 |
+
"barAlignment": 0,
|
1254 |
+
"barWidthFactor": 0.6,
|
1255 |
+
"drawStyle": "line",
|
1256 |
+
"fillOpacity": 0,
|
1257 |
+
"gradientMode": "none",
|
1258 |
+
"hideFrom": {
|
1259 |
+
"legend": false,
|
1260 |
+
"tooltip": false,
|
1261 |
+
"viz": false
|
1262 |
+
},
|
1263 |
+
"insertNulls": false,
|
1264 |
+
"lineInterpolation": "linear",
|
1265 |
+
"lineWidth": 1,
|
1266 |
+
"pointSize": 5,
|
1267 |
+
"scaleDistribution": {
|
1268 |
+
"type": "linear"
|
1269 |
+
},
|
1270 |
+
"showPoints": "auto",
|
1271 |
+
"spanNulls": false,
|
1272 |
+
"stacking": {
|
1273 |
+
"group": "A",
|
1274 |
+
"mode": "none"
|
1275 |
+
},
|
1276 |
+
"thresholdsStyle": {
|
1277 |
+
"mode": "off"
|
1278 |
+
}
|
1279 |
+
},
|
1280 |
+
"mappings": [],
|
1281 |
+
"thresholds": {
|
1282 |
+
"mode": "absolute",
|
1283 |
+
"steps": [
|
1284 |
+
{
|
1285 |
+
"color": "green",
|
1286 |
+
"value": null
|
1287 |
+
},
|
1288 |
+
{
|
1289 |
+
"color": "red",
|
1290 |
+
"value": 80
|
1291 |
+
}
|
1292 |
+
]
|
1293 |
+
}
|
1294 |
+
},
|
1295 |
+
"overrides": []
|
1296 |
+
},
|
1297 |
+
"gridPos": {
|
1298 |
+
"h": 8,
|
1299 |
+
"w": 12,
|
1300 |
+
"x": 0,
|
1301 |
+
"y": 27
|
1302 |
+
},
|
1303 |
+
"id": 11,
|
1304 |
+
"options": {
|
1305 |
+
"legend": {
|
1306 |
+
"calcs": [],
|
1307 |
+
"displayMode": "list",
|
1308 |
+
"placement": "bottom",
|
1309 |
+
"showLegend": true
|
1310 |
+
},
|
1311 |
+
"tooltip": {
|
1312 |
+
"mode": "single",
|
1313 |
+
"sort": "none"
|
1314 |
+
}
|
1315 |
+
},
|
1316 |
+
"targets": [
|
1317 |
+
{
|
1318 |
+
"datasource": {
|
1319 |
+
"type": "prometheus",
|
1320 |
+
"uid": "ddyfngn31dg5cf"
|
1321 |
+
},
|
1322 |
+
"disableTextWrap": false,
|
1323 |
+
"editorMode": "code",
|
1324 |
+
"expr": "sum(rate(sglang:request_prompt_tokens_sum{instance=\"$instance\", name=\"$name\"}[$__rate_interval])) by (instance, name)",
|
1325 |
+
"fullMetaSearch": false,
|
1326 |
+
"includeNullMetadata": true,
|
1327 |
+
"instant": false,
|
1328 |
+
"legendFormat": "{{__name__}}",
|
1329 |
+
"range": true,
|
1330 |
+
"refId": "A",
|
1331 |
+
"useBackend": false
|
1332 |
+
},
|
1333 |
+
{
|
1334 |
+
"datasource": {
|
1335 |
+
"type": "prometheus",
|
1336 |
+
"uid": "ddyfngn31dg5cf"
|
1337 |
+
},
|
1338 |
+
"disableTextWrap": false,
|
1339 |
+
"editorMode": "code",
|
1340 |
+
"expr": "",
|
1341 |
+
"fullMetaSearch": false,
|
1342 |
+
"hide": false,
|
1343 |
+
"includeNullMetadata": true,
|
1344 |
+
"instant": false,
|
1345 |
+
"legendFormat": "__auto",
|
1346 |
+
"range": true,
|
1347 |
+
"refId": "B",
|
1348 |
+
"useBackend": false
|
1349 |
+
}
|
1350 |
+
],
|
1351 |
+
"title": "Prompt Tokens",
|
1352 |
+
"type": "timeseries"
|
1353 |
+
},
|
1354 |
+
{
|
1355 |
+
"datasource": {
|
1356 |
+
"default": true,
|
1357 |
+
"type": "prometheus",
|
1358 |
+
"uid": "ee2vha8w6f5kwf"
|
1359 |
+
},
|
1360 |
+
"fieldConfig": {
|
1361 |
+
"defaults": {
|
1362 |
+
"color": {
|
1363 |
+
"mode": "palette-classic"
|
1364 |
+
},
|
1365 |
+
"custom": {
|
1366 |
+
"axisBorderShow": false,
|
1367 |
+
"axisCenteredZero": false,
|
1368 |
+
"axisColorMode": "text",
|
1369 |
+
"axisLabel": "",
|
1370 |
+
"axisPlacement": "auto",
|
1371 |
+
"barAlignment": 0,
|
1372 |
+
"barWidthFactor": 0.6,
|
1373 |
+
"drawStyle": "line",
|
1374 |
+
"fillOpacity": 0,
|
1375 |
+
"gradientMode": "none",
|
1376 |
+
"hideFrom": {
|
1377 |
+
"legend": false,
|
1378 |
+
"tooltip": false,
|
1379 |
+
"viz": false
|
1380 |
+
},
|
1381 |
+
"insertNulls": false,
|
1382 |
+
"lineInterpolation": "linear",
|
1383 |
+
"lineWidth": 1,
|
1384 |
+
"pointSize": 5,
|
1385 |
+
"scaleDistribution": {
|
1386 |
+
"type": "linear"
|
1387 |
+
},
|
1388 |
+
"showPoints": "auto",
|
1389 |
+
"spanNulls": false,
|
1390 |
+
"stacking": {
|
1391 |
+
"group": "A",
|
1392 |
+
"mode": "none"
|
1393 |
+
},
|
1394 |
+
"thresholdsStyle": {
|
1395 |
+
"mode": "off"
|
1396 |
+
}
|
1397 |
+
},
|
1398 |
+
"mappings": [],
|
1399 |
+
"thresholds": {
|
1400 |
+
"mode": "absolute",
|
1401 |
+
"steps": [
|
1402 |
+
{
|
1403 |
+
"color": "green",
|
1404 |
+
"value": null
|
1405 |
+
},
|
1406 |
+
{
|
1407 |
+
"color": "red",
|
1408 |
+
"value": 80
|
1409 |
+
}
|
1410 |
+
]
|
1411 |
+
}
|
1412 |
+
},
|
1413 |
+
"overrides": []
|
1414 |
+
},
|
1415 |
+
"gridPos": {
|
1416 |
+
"h": 8,
|
1417 |
+
"w": 12,
|
1418 |
+
"x": 12,
|
1419 |
+
"y": 27
|
1420 |
+
},
|
1421 |
+
"id": 17,
|
1422 |
+
"options": {
|
1423 |
+
"legend": {
|
1424 |
+
"calcs": [],
|
1425 |
+
"displayMode": "list",
|
1426 |
+
"placement": "bottom",
|
1427 |
+
"showLegend": true
|
1428 |
+
},
|
1429 |
+
"tooltip": {
|
1430 |
+
"mode": "single",
|
1431 |
+
"sort": "none"
|
1432 |
+
}
|
1433 |
+
},
|
1434 |
+
"targets": [
|
1435 |
+
{
|
1436 |
+
"datasource": {
|
1437 |
+
"type": "prometheus",
|
1438 |
+
"uid": "ddyfngn31dg5cf"
|
1439 |
+
},
|
1440 |
+
"disableTextWrap": false,
|
1441 |
+
"editorMode": "code",
|
1442 |
+
"expr": "sum(rate(sglang:request_generation_tokens_sum{instance=\"$instance\", name=\"$name\"}[$__rate_interval])) by (instance, name)",
|
1443 |
+
"fullMetaSearch": false,
|
1444 |
+
"includeNullMetadata": true,
|
1445 |
+
"instant": false,
|
1446 |
+
"legendFormat": "{{__name__}}",
|
1447 |
+
"range": true,
|
1448 |
+
"refId": "A",
|
1449 |
+
"useBackend": false
|
1450 |
+
}
|
1451 |
+
],
|
1452 |
+
"title": "Generated Tokens",
|
1453 |
+
"type": "timeseries"
|
1454 |
+
},
|
1455 |
+
{
|
1456 |
+
"datasource": {
|
1457 |
+
"default": true,
|
1458 |
+
"type": "prometheus",
|
1459 |
+
"uid": "ee2vha8w6f5kwf"
|
1460 |
+
},
|
1461 |
+
"fieldConfig": {
|
1462 |
+
"defaults": {
|
1463 |
+
"custom": {
|
1464 |
+
"hideFrom": {
|
1465 |
+
"legend": false,
|
1466 |
+
"tooltip": false,
|
1467 |
+
"viz": false
|
1468 |
+
},
|
1469 |
+
"scaleDistribution": {
|
1470 |
+
"type": "linear"
|
1471 |
+
}
|
1472 |
+
}
|
1473 |
+
},
|
1474 |
+
"overrides": []
|
1475 |
+
},
|
1476 |
+
"gridPos": {
|
1477 |
+
"h": 8,
|
1478 |
+
"w": 12,
|
1479 |
+
"x": 0,
|
1480 |
+
"y": 35
|
1481 |
+
},
|
1482 |
+
"id": 13,
|
1483 |
+
"options": {
|
1484 |
+
"calculate": false,
|
1485 |
+
"calculation": {
|
1486 |
+
"yBuckets": {
|
1487 |
+
"scale": {
|
1488 |
+
"log": 2,
|
1489 |
+
"type": "log"
|
1490 |
+
}
|
1491 |
+
}
|
1492 |
+
},
|
1493 |
+
"cellGap": 1,
|
1494 |
+
"color": {
|
1495 |
+
"exponent": 0.5,
|
1496 |
+
"fill": "dark-orange",
|
1497 |
+
"mode": "scheme",
|
1498 |
+
"reverse": false,
|
1499 |
+
"scale": "exponential",
|
1500 |
+
"scheme": "Oranges",
|
1501 |
+
"steps": 64
|
1502 |
+
},
|
1503 |
+
"exemplars": {
|
1504 |
+
"color": "rgba(255,0,255,0.7)"
|
1505 |
+
},
|
1506 |
+
"filterValues": {
|
1507 |
+
"le": 1e-9
|
1508 |
+
},
|
1509 |
+
"legend": {
|
1510 |
+
"show": true
|
1511 |
+
},
|
1512 |
+
"rowsFrame": {
|
1513 |
+
"layout": "auto"
|
1514 |
+
},
|
1515 |
+
"tooltip": {
|
1516 |
+
"mode": "single",
|
1517 |
+
"showColorScale": false,
|
1518 |
+
"yHistogram": false
|
1519 |
+
},
|
1520 |
+
"yAxis": {
|
1521 |
+
"axisPlacement": "left",
|
1522 |
+
"reverse": false
|
1523 |
+
}
|
1524 |
+
},
|
1525 |
+
"pluginVersion": "11.2.0",
|
1526 |
+
"targets": [
|
1527 |
+
{
|
1528 |
+
"datasource": {
|
1529 |
+
"type": "prometheus",
|
1530 |
+
"uid": "ddyfngn31dg5cf"
|
1531 |
+
},
|
1532 |
+
"disableTextWrap": false,
|
1533 |
+
"editorMode": "code",
|
1534 |
+
"expr": "sum by(le) (increase(sglang:request_prompt_tokens_bucket{name=\"$name\", instance=\"$instance\"}[$__rate_interval]))",
|
1535 |
+
"fullMetaSearch": false,
|
1536 |
+
"includeNullMetadata": true,
|
1537 |
+
"instant": false,
|
1538 |
+
"legendFormat": "{{__name__}}",
|
1539 |
+
"range": true,
|
1540 |
+
"refId": "A",
|
1541 |
+
"useBackend": false
|
1542 |
+
}
|
1543 |
+
],
|
1544 |
+
"title": "Request Prompt Tokens",
|
1545 |
+
"type": "heatmap"
|
1546 |
+
},
|
1547 |
+
{
|
1548 |
+
"datasource": {
|
1549 |
+
"default": true,
|
1550 |
+
"type": "prometheus",
|
1551 |
+
"uid": "ee2vha8w6f5kwf"
|
1552 |
+
},
|
1553 |
+
"description": "",
|
1554 |
+
"fieldConfig": {
|
1555 |
+
"defaults": {
|
1556 |
+
"custom": {
|
1557 |
+
"hideFrom": {
|
1558 |
+
"legend": false,
|
1559 |
+
"tooltip": false,
|
1560 |
+
"viz": false
|
1561 |
+
},
|
1562 |
+
"scaleDistribution": {
|
1563 |
+
"type": "linear"
|
1564 |
+
}
|
1565 |
+
}
|
1566 |
+
},
|
1567 |
+
"overrides": []
|
1568 |
+
},
|
1569 |
+
"gridPos": {
|
1570 |
+
"h": 8,
|
1571 |
+
"w": 12,
|
1572 |
+
"x": 12,
|
1573 |
+
"y": 35
|
1574 |
+
},
|
1575 |
+
"id": 12,
|
1576 |
+
"options": {
|
1577 |
+
"calculate": false,
|
1578 |
+
"calculation": {
|
1579 |
+
"xBuckets": {
|
1580 |
+
"mode": "size",
|
1581 |
+
"value": ""
|
1582 |
+
},
|
1583 |
+
"yBuckets": {
|
1584 |
+
"mode": "size",
|
1585 |
+
"scale": {
|
1586 |
+
"log": 2,
|
1587 |
+
"type": "log"
|
1588 |
+
},
|
1589 |
+
"value": ""
|
1590 |
+
}
|
1591 |
+
},
|
1592 |
+
"cellGap": 1,
|
1593 |
+
"color": {
|
1594 |
+
"exponent": 0.5,
|
1595 |
+
"fill": "dark-orange",
|
1596 |
+
"min": 0,
|
1597 |
+
"mode": "scheme",
|
1598 |
+
"reverse": false,
|
1599 |
+
"scale": "exponential",
|
1600 |
+
"scheme": "Spectral",
|
1601 |
+
"steps": 64
|
1602 |
+
},
|
1603 |
+
"exemplars": {
|
1604 |
+
"color": "rgba(255,0,255,0.7)"
|
1605 |
+
},
|
1606 |
+
"filterValues": {
|
1607 |
+
"le": 1e-9
|
1608 |
+
},
|
1609 |
+
"legend": {
|
1610 |
+
"show": true
|
1611 |
+
},
|
1612 |
+
"rowsFrame": {
|
1613 |
+
"layout": "auto",
|
1614 |
+
"value": "Request count"
|
1615 |
+
},
|
1616 |
+
"tooltip": {
|
1617 |
+
"mode": "single",
|
1618 |
+
"showColorScale": false,
|
1619 |
+
"yHistogram": true
|
1620 |
+
},
|
1621 |
+
"yAxis": {
|
1622 |
+
"axisLabel": "Generation Length",
|
1623 |
+
"axisPlacement": "left",
|
1624 |
+
"reverse": false,
|
1625 |
+
"unit": "none"
|
1626 |
+
}
|
1627 |
+
},
|
1628 |
+
"pluginVersion": "11.2.0",
|
1629 |
+
"targets": [
|
1630 |
+
{
|
1631 |
+
"datasource": {
|
1632 |
+
"type": "prometheus",
|
1633 |
+
"uid": "ddyfngn31dg5cf"
|
1634 |
+
},
|
1635 |
+
"disableTextWrap": false,
|
1636 |
+
"editorMode": "code",
|
1637 |
+
"expr": "sum by(le) (increase(sglang:request_generation_tokens_bucket{name=\"$name\", instance=\"$instance\"}[$__rate_interval]))",
|
1638 |
+
"fullMetaSearch": false,
|
1639 |
+
"includeNullMetadata": true,
|
1640 |
+
"instant": false,
|
1641 |
+
"legendFormat": "{{__name__}}",
|
1642 |
+
"range": true,
|
1643 |
+
"refId": "A",
|
1644 |
+
"useBackend": false
|
1645 |
+
}
|
1646 |
+
],
|
1647 |
+
"title": "Request Generation Tokens",
|
1648 |
+
"type": "heatmap"
|
1649 |
+
}
|
1650 |
+
],
|
1651 |
+
"refresh": "5s",
|
1652 |
+
"schemaVersion": 39,
|
1653 |
+
"tags": [],
|
1654 |
+
"templating": {
|
1655 |
+
"list": [
|
1656 |
+
{
|
1657 |
+
"current": {
|
1658 |
+
"selected": false,
|
1659 |
+
"text": "127.0.0.1:30000",
|
1660 |
+
"value": "127.0.0.1:30000"
|
1661 |
+
},
|
1662 |
+
"datasource": {
|
1663 |
+
"type": "prometheus",
|
1664 |
+
"uid": "ddyfngn31dg5cf"
|
1665 |
+
},
|
1666 |
+
"definition": "label_values(instance)",
|
1667 |
+
"hide": 0,
|
1668 |
+
"includeAll": false,
|
1669 |
+
"label": "instance",
|
1670 |
+
"multi": false,
|
1671 |
+
"name": "instance",
|
1672 |
+
"options": [],
|
1673 |
+
"query": {
|
1674 |
+
"qryType": 1,
|
1675 |
+
"query": "label_values(instance)",
|
1676 |
+
"refId": "PrometheusVariableQueryEditor-VariableQuery"
|
1677 |
+
},
|
1678 |
+
"refresh": 1,
|
1679 |
+
"regex": "",
|
1680 |
+
"skipUrlSync": false,
|
1681 |
+
"sort": 0,
|
1682 |
+
"type": "query"
|
1683 |
+
},
|
1684 |
+
{
|
1685 |
+
"current": {
|
1686 |
+
"selected": true,
|
1687 |
+
"text": "google/gemma-2-9b-it",
|
1688 |
+
"value": "google/gemma-2-9b-it"
|
1689 |
+
},
|
1690 |
+
"definition": "label_values(name)",
|
1691 |
+
"hide": 1,
|
1692 |
+
"includeAll": false,
|
1693 |
+
"label": "name",
|
1694 |
+
"multi": false,
|
1695 |
+
"name": "name",
|
1696 |
+
"options": [],
|
1697 |
+
"query": {
|
1698 |
+
"qryType": 1,
|
1699 |
+
"query": "label_values(name)",
|
1700 |
+
"refId": "PrometheusVariableQueryEditor-VariableQuery"
|
1701 |
+
},
|
1702 |
+
"refresh": 1,
|
1703 |
+
"regex": "",
|
1704 |
+
"skipUrlSync": false,
|
1705 |
+
"sort": 0,
|
1706 |
+
"type": "query"
|
1707 |
+
}
|
1708 |
+
]
|
1709 |
+
},
|
1710 |
+
"time": {
|
1711 |
+
"from": "now-30m",
|
1712 |
+
"to": "now"
|
1713 |
+
},
|
1714 |
+
"timepicker": {},
|
1715 |
+
"timezone": "browser",
|
1716 |
+
"title": "SGLang Dashboard",
|
1717 |
+
"uid": "ddyp55uq7brpcc",
|
1718 |
+
"version": 3,
|
1719 |
+
"weekStart": ""
|
1720 |
+
}
|
sglang/examples/monitoring/prometheus.yaml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# prometheus.yaml
|
2 |
+
global:
|
3 |
+
scrape_interval: 5s
|
4 |
+
evaluation_interval: 30s
|
5 |
+
|
6 |
+
scrape_configs:
|
7 |
+
- job_name: sglang
|
8 |
+
static_configs:
|
9 |
+
- targets:
|
10 |
+
- '127.0.0.1:30000'
|
sglang/examples/runtime/lora.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# launch server
|
2 |
+
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
|
3 |
+
|
4 |
+
# send requests
|
5 |
+
# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
|
6 |
+
# use None to specify base-only prompt, e.x. "lora_path": [None, "/home/ying/test_lora"]
|
7 |
+
import json
|
8 |
+
|
9 |
+
import requests
|
10 |
+
|
11 |
+
url = "http://127.0.0.1:30000"
|
12 |
+
json_data = {
|
13 |
+
"text": [
|
14 |
+
"prompt 1",
|
15 |
+
"prompt 2",
|
16 |
+
"prompt 3",
|
17 |
+
"prompt 4",
|
18 |
+
"prompt 5",
|
19 |
+
"prompt 6",
|
20 |
+
"prompt 7",
|
21 |
+
],
|
22 |
+
"sampling_params": {"max_new_tokens": 32},
|
23 |
+
"lora_path": [
|
24 |
+
"/home/ying/test_lora",
|
25 |
+
"lora1",
|
26 |
+
"lora2",
|
27 |
+
"lora1",
|
28 |
+
"lora2",
|
29 |
+
None,
|
30 |
+
None,
|
31 |
+
],
|
32 |
+
}
|
33 |
+
response = requests.post(
|
34 |
+
url + "/generate",
|
35 |
+
json=json_data,
|
36 |
+
)
|
37 |
+
print(json.dumps(response.json()))
|
sglang/examples/runtime/openai_batch_complete.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
4 |
+
python openai_batch_complete.py
|
5 |
+
Note: Before running this script,
|
6 |
+
you should create the input.jsonl file with the following content:
|
7 |
+
{"custom_id": "request-1", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 3 names of famous soccer player: ", "max_tokens": 200}}
|
8 |
+
{"custom_id": "request-2", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}}
|
9 |
+
{"custom_id": "request-3", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}}
|
10 |
+
"""
|
11 |
+
|
12 |
+
import json
|
13 |
+
import time
|
14 |
+
|
15 |
+
import openai
|
16 |
+
|
17 |
+
|
18 |
+
class OpenAIBatchProcessor:
|
19 |
+
def __init__(self):
|
20 |
+
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
21 |
+
|
22 |
+
self.client = client
|
23 |
+
|
24 |
+
def process_batch(self, input_file_path, endpoint, completion_window):
|
25 |
+
|
26 |
+
# Upload the input file
|
27 |
+
with open(input_file_path, "rb") as file:
|
28 |
+
uploaded_file = self.client.files.create(file=file, purpose="batch")
|
29 |
+
|
30 |
+
# Create the batch job
|
31 |
+
batch_job = self.client.batches.create(
|
32 |
+
input_file_id=uploaded_file.id,
|
33 |
+
endpoint=endpoint,
|
34 |
+
completion_window=completion_window,
|
35 |
+
)
|
36 |
+
|
37 |
+
# Monitor the batch job status
|
38 |
+
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
39 |
+
time.sleep(3) # Wait for 3 seconds before checking the status again
|
40 |
+
print(
|
41 |
+
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
42 |
+
)
|
43 |
+
batch_job = self.client.batches.retrieve(batch_job.id)
|
44 |
+
|
45 |
+
# Check the batch job status and errors
|
46 |
+
if batch_job.status == "failed":
|
47 |
+
print(f"Batch job failed with status: {batch_job.status}")
|
48 |
+
print(f"Batch job errors: {batch_job.errors}")
|
49 |
+
return None
|
50 |
+
|
51 |
+
# If the batch job is completed, process the results
|
52 |
+
if batch_job.status == "completed":
|
53 |
+
|
54 |
+
# print result of batch job
|
55 |
+
print("batch", batch_job.request_counts)
|
56 |
+
|
57 |
+
result_file_id = batch_job.output_file_id
|
58 |
+
# Retrieve the file content from the server
|
59 |
+
file_response = self.client.files.content(result_file_id)
|
60 |
+
result_content = file_response.read() # Read the content of the file
|
61 |
+
|
62 |
+
# Save the content to a local file
|
63 |
+
result_file_name = "batch_job_complete_results.jsonl"
|
64 |
+
with open(result_file_name, "wb") as file:
|
65 |
+
file.write(result_content) # Write the binary content to the file
|
66 |
+
# Load data from the saved JSONL file
|
67 |
+
results = []
|
68 |
+
with open(result_file_name, "r", encoding="utf-8") as file:
|
69 |
+
for line in file:
|
70 |
+
json_object = json.loads(
|
71 |
+
line.strip()
|
72 |
+
) # Parse each line as a JSON object
|
73 |
+
results.append(json_object)
|
74 |
+
|
75 |
+
return results
|
76 |
+
else:
|
77 |
+
print(f"Batch job failed with status: {batch_job.status}")
|
78 |
+
return None
|
79 |
+
|
80 |
+
|
81 |
+
# Initialize the OpenAIBatchProcessor
|
82 |
+
processor = OpenAIBatchProcessor()
|
83 |
+
|
84 |
+
# Process the batch job
|
85 |
+
input_file_path = "input.jsonl"
|
86 |
+
endpoint = "/v1/completions"
|
87 |
+
completion_window = "24h"
|
88 |
+
|
89 |
+
# Process the batch job
|
90 |
+
results = processor.process_batch(input_file_path, endpoint, completion_window)
|
91 |
+
|
92 |
+
# Print the results
|
93 |
+
print(results)
|
sglang/examples/runtime/openai_chat_with_response_prefill.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
|
4 |
+
python openai_chat.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import openai
|
8 |
+
from openai import OpenAI
|
9 |
+
|
10 |
+
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
11 |
+
|
12 |
+
response = client.chat.completions.create(
|
13 |
+
model="meta-llama/Llama-3.1-8B-Instruct",
|
14 |
+
messages=[
|
15 |
+
{"role": "system", "content": "You are a helpful AI assistant"},
|
16 |
+
{
|
17 |
+
"role": "user",
|
18 |
+
"content": """
|
19 |
+
Extract the name, size, price, and color from this product description as a JSON object:
|
20 |
+
|
21 |
+
<description>
|
22 |
+
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
|
23 |
+
</description>
|
24 |
+
""",
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"role": "assistant",
|
28 |
+
"content": "{\n",
|
29 |
+
},
|
30 |
+
],
|
31 |
+
temperature=0,
|
32 |
+
)
|
33 |
+
|
34 |
+
print(response.choices[0].message.content)
|
sglang/scripts/deprecated/convert_yi_vl.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Convert Yi-VL config into a format useable with SGLang
|
3 |
+
|
4 |
+
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
|
11 |
+
from transformers import AutoConfig, AutoTokenizer
|
12 |
+
|
13 |
+
|
14 |
+
def add_image_token(model_path: str):
|
15 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
16 |
+
tokenizer.add_tokens(["<image_placeholder>"], special_tokens=True)
|
17 |
+
|
18 |
+
print(tokenizer)
|
19 |
+
tokenizer.save_pretrained(model_path)
|
20 |
+
|
21 |
+
|
22 |
+
def edit_model_config(model_path):
|
23 |
+
config = AutoConfig.from_pretrained(model_path)
|
24 |
+
|
25 |
+
setattr(config, "architectures", ["YiVLForCausalLM"])
|
26 |
+
setattr(config, "image_token_index", 64002)
|
27 |
+
|
28 |
+
print(config)
|
29 |
+
config.save_pretrained(model_path)
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
parser = argparse.ArgumentParser()
|
34 |
+
parser.add_argument("--model-path", type=str)
|
35 |
+
args = parser.parse_args()
|
36 |
+
|
37 |
+
add_image_token(args.model_path)
|
38 |
+
edit_model_config(args.model_path)
|
sglang/scripts/deprecated/test_httpserver_classify.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
|
4 |
+
|
5 |
+
python3 test_httpserver_classify.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
|
13 |
+
|
14 |
+
def get_logits_deprecated(url: str, prompt: str):
|
15 |
+
response = requests.post(
|
16 |
+
url + "/generate",
|
17 |
+
json={
|
18 |
+
"text": prompt,
|
19 |
+
"sampling_params": {
|
20 |
+
"max_new_tokens": 0,
|
21 |
+
},
|
22 |
+
"return_logprob": True,
|
23 |
+
},
|
24 |
+
)
|
25 |
+
return response.json()["meta_info"]["normalized_prompt_logprob"]
|
26 |
+
|
27 |
+
|
28 |
+
def get_logits_batch_deprecated(url: str, prompts: list[str]):
|
29 |
+
response = requests.post(
|
30 |
+
url + "/generate",
|
31 |
+
json={
|
32 |
+
"text": prompts,
|
33 |
+
"sampling_params": {
|
34 |
+
"max_new_tokens": 0,
|
35 |
+
},
|
36 |
+
"return_logprob": True,
|
37 |
+
},
|
38 |
+
)
|
39 |
+
ret = response.json()
|
40 |
+
logits = np.array(
|
41 |
+
list(
|
42 |
+
ret[i]["meta_info"]["normalized_prompt_logprob"]
|
43 |
+
for i in range(len(prompts))
|
44 |
+
)
|
45 |
+
)
|
46 |
+
return logits
|
47 |
+
|
48 |
+
|
49 |
+
def get_logits(url: str, prompt: str):
|
50 |
+
response = requests.post(
|
51 |
+
url + "/classify",
|
52 |
+
json={"text": prompt},
|
53 |
+
)
|
54 |
+
return response.json()["embedding"]
|
55 |
+
|
56 |
+
|
57 |
+
def get_logits_batch(url: str, prompts: list[str]):
|
58 |
+
response = requests.post(
|
59 |
+
url + "/classify",
|
60 |
+
json={"text": prompts},
|
61 |
+
)
|
62 |
+
return np.array([x["embedding"] for x in response.json()])
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
parser = argparse.ArgumentParser()
|
67 |
+
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
68 |
+
parser.add_argument("--port", type=int, default=30000)
|
69 |
+
args = parser.parse_args()
|
70 |
+
|
71 |
+
url = f"{args.host}:{args.port}"
|
72 |
+
|
73 |
+
# A single request
|
74 |
+
prompt = "This is a test prompt.<|eot_id|>"
|
75 |
+
logits = get_logits(url, prompt)
|
76 |
+
print(f"{logits=}")
|
77 |
+
|
78 |
+
# A batch of requests
|
79 |
+
prompts = [
|
80 |
+
"This is a test prompt.<|eot_id|>",
|
81 |
+
"This is another test prompt.<|eot_id|>",
|
82 |
+
"This is a long long long long test prompt.<|eot_id|>",
|
83 |
+
]
|
84 |
+
logits = get_logits_batch(url, prompts)
|
85 |
+
print(f"{logits=}")
|
sglang/scripts/deprecated/test_httpserver_decode_stream.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
4 |
+
python3 test_httpserver_decode_stream.py
|
5 |
+
|
6 |
+
Output:
|
7 |
+
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
8 |
+
"""
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
import json
|
12 |
+
|
13 |
+
import requests
|
14 |
+
|
15 |
+
|
16 |
+
def test_decode_stream(url, return_logprob, top_logprobs_num):
|
17 |
+
response = requests.post(
|
18 |
+
url + "/generate",
|
19 |
+
json={
|
20 |
+
"text": "The capital of France is",
|
21 |
+
"sampling_params": {
|
22 |
+
"temperature": 0,
|
23 |
+
"max_new_tokens": 128,
|
24 |
+
},
|
25 |
+
"stream": True,
|
26 |
+
"return_logprob": return_logprob,
|
27 |
+
"top_logprobs_num": top_logprobs_num,
|
28 |
+
"return_text_in_logprobs": True,
|
29 |
+
"logprob_start_len": 0,
|
30 |
+
},
|
31 |
+
stream=True,
|
32 |
+
)
|
33 |
+
|
34 |
+
prev = 0
|
35 |
+
for chunk in response.iter_lines(decode_unicode=False):
|
36 |
+
chunk = chunk.decode("utf-8")
|
37 |
+
if chunk and chunk.startswith("data:"):
|
38 |
+
if chunk == "data: [DONE]":
|
39 |
+
break
|
40 |
+
data = json.loads(chunk[5:].strip("\n"))
|
41 |
+
|
42 |
+
if return_logprob:
|
43 |
+
assert data["meta_info"]["input_token_logprobs"] is not None
|
44 |
+
assert data["meta_info"]["output_token_logprobs"] is not None
|
45 |
+
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
46 |
+
for logprob, token_id, token_text in data["meta_info"][
|
47 |
+
"output_token_logprobs"
|
48 |
+
][prev:]:
|
49 |
+
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
|
50 |
+
prev = len(data["meta_info"]["output_token_logprobs"])
|
51 |
+
else:
|
52 |
+
output = data["text"].strip()
|
53 |
+
print(output[prev:], end="", flush=True)
|
54 |
+
prev = len(output)
|
55 |
+
|
56 |
+
print("=" * 100)
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
parser = argparse.ArgumentParser()
|
61 |
+
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
62 |
+
parser.add_argument("--port", type=int, default=30000)
|
63 |
+
args = parser.parse_args()
|
64 |
+
|
65 |
+
url = f"{args.host}:{args.port}"
|
66 |
+
|
67 |
+
test_decode_stream(url, False, 0)
|
68 |
+
test_decode_stream(url, True, 0)
|
69 |
+
test_decode_stream(url, True, 3)
|
sglang/scripts/deprecated/test_httpserver_llava.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
4 |
+
python3 test_httpserver_llava.py
|
5 |
+
|
6 |
+
Output:
|
7 |
+
The image features a man standing on the back of a yellow taxi cab, holding
|
8 |
+
"""
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
import asyncio
|
12 |
+
import json
|
13 |
+
|
14 |
+
import aiohttp
|
15 |
+
import requests
|
16 |
+
|
17 |
+
|
18 |
+
async def send_request(url, data, delay=0):
|
19 |
+
await asyncio.sleep(delay)
|
20 |
+
async with aiohttp.ClientSession() as session:
|
21 |
+
async with session.post(url, json=data) as resp:
|
22 |
+
output = await resp.json()
|
23 |
+
return output
|
24 |
+
|
25 |
+
|
26 |
+
async def test_concurrent(args):
|
27 |
+
url = f"{args.host}:{args.port}"
|
28 |
+
|
29 |
+
response = []
|
30 |
+
for i in range(8):
|
31 |
+
response.append(
|
32 |
+
send_request(
|
33 |
+
url + "/generate",
|
34 |
+
{
|
35 |
+
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
36 |
+
"image_data": "example_image.png",
|
37 |
+
"sampling_params": {
|
38 |
+
"temperature": 0,
|
39 |
+
"max_new_tokens": 64,
|
40 |
+
},
|
41 |
+
},
|
42 |
+
)
|
43 |
+
)
|
44 |
+
|
45 |
+
rets = await asyncio.gather(*response)
|
46 |
+
for ret in rets:
|
47 |
+
print(ret["text"])
|
48 |
+
|
49 |
+
|
50 |
+
def test_streaming(args):
|
51 |
+
url = f"{args.host}:{args.port}"
|
52 |
+
|
53 |
+
response = requests.post(
|
54 |
+
url + "/generate",
|
55 |
+
json={
|
56 |
+
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
|
57 |
+
"image_data": "example_image.png",
|
58 |
+
"sampling_params": {
|
59 |
+
"temperature": 0,
|
60 |
+
"max_new_tokens": 128,
|
61 |
+
},
|
62 |
+
"stream": True,
|
63 |
+
},
|
64 |
+
stream=True,
|
65 |
+
)
|
66 |
+
|
67 |
+
prev = 0
|
68 |
+
for chunk in response.iter_lines(decode_unicode=False):
|
69 |
+
chunk = chunk.decode("utf-8")
|
70 |
+
if chunk and chunk.startswith("data:"):
|
71 |
+
if chunk == "data: [DONE]":
|
72 |
+
break
|
73 |
+
data = json.loads(chunk[5:].strip("\n"))
|
74 |
+
output = data["text"].strip()
|
75 |
+
print(output[prev:], end="", flush=True)
|
76 |
+
prev = len(output)
|
77 |
+
print("")
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
parser = argparse.ArgumentParser()
|
82 |
+
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
83 |
+
parser.add_argument("--port", type=int, default=30000)
|
84 |
+
args = parser.parse_args()
|
85 |
+
|
86 |
+
asyncio.run(test_concurrent(args))
|
87 |
+
|
88 |
+
test_streaming(args)
|
sglang/scripts/deprecated/test_httpserver_reuse.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
3 |
+
|
4 |
+
Output:
|
5 |
+
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
import requests
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
15 |
+
parser.add_argument("--port", type=int, default=30000)
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
url = f"{args.host}:{args.port}"
|
19 |
+
|
20 |
+
response = requests.post(
|
21 |
+
url + "/generate",
|
22 |
+
json={
|
23 |
+
"text": "The capital of France is",
|
24 |
+
"sampling_params": {
|
25 |
+
"temperature": 0,
|
26 |
+
"max_new_tokens": 32,
|
27 |
+
},
|
28 |
+
},
|
29 |
+
)
|
30 |
+
print(response.json())
|
31 |
+
|
32 |
+
response = requests.post(
|
33 |
+
url + "/generate",
|
34 |
+
json={
|
35 |
+
"text": "The capital of France is Paris.\nThe capital of the United States is",
|
36 |
+
"sampling_params": {
|
37 |
+
"temperature": 0,
|
38 |
+
"max_new_tokens": 32,
|
39 |
+
},
|
40 |
+
},
|
41 |
+
)
|
42 |
+
print(response.json())
|
sglang/scripts/deprecated/test_jump_forward.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
from pydantic import BaseModel, constr
|
5 |
+
|
6 |
+
import sglang as sgl
|
7 |
+
from sglang.srt.constrained import build_regex_from_object
|
8 |
+
from sglang.test.test_utils import (
|
9 |
+
add_common_sglang_args_and_parse,
|
10 |
+
select_sglang_backend,
|
11 |
+
)
|
12 |
+
|
13 |
+
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
14 |
+
|
15 |
+
ip_jump_forward = (
|
16 |
+
r"The google's DNS sever address is "
|
17 |
+
+ IP_REGEX
|
18 |
+
+ r" and "
|
19 |
+
+ IP_REGEX
|
20 |
+
+ r". "
|
21 |
+
+ r"The google's website domain name is "
|
22 |
+
+ r"www\.(\w)+\.(\w)+"
|
23 |
+
+ r"."
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
# fmt: off
|
28 |
+
@sgl.function
|
29 |
+
def regex_gen(s):
|
30 |
+
s += "Q: What is the IP address of the Google DNS servers?\n"
|
31 |
+
s += "A: " + sgl.gen(
|
32 |
+
"answer",
|
33 |
+
max_tokens=128,
|
34 |
+
temperature=0,
|
35 |
+
regex=ip_jump_forward,
|
36 |
+
)
|
37 |
+
# fmt: on
|
38 |
+
|
39 |
+
json_jump_forward = (
|
40 |
+
r"""The information about Hogwarts is in the following JSON format\.\n"""
|
41 |
+
+ r"""\n\{\n"""
|
42 |
+
+ r""" "name": "[\w\d\s]*",\n"""
|
43 |
+
+ r""" "country": "[\w\d\s]*",\n"""
|
44 |
+
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
|
45 |
+
+ r""" "population": [-+]?[0-9]+,\n"""
|
46 |
+
+ r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
|
47 |
+
+ r"""\}\n"""
|
48 |
+
)
|
49 |
+
|
50 |
+
# fmt: off
|
51 |
+
@sgl.function
|
52 |
+
def json_gen(s):
|
53 |
+
s += sgl.gen(
|
54 |
+
"json",
|
55 |
+
max_tokens=128,
|
56 |
+
temperature=0,
|
57 |
+
regex=json_jump_forward,
|
58 |
+
)
|
59 |
+
# fmt: on
|
60 |
+
|
61 |
+
|
62 |
+
class Weapon(str, Enum):
|
63 |
+
sword = "sword"
|
64 |
+
axe = "axe"
|
65 |
+
mace = "mace"
|
66 |
+
spear = "spear"
|
67 |
+
bow = "bow"
|
68 |
+
crossbow = "crossbow"
|
69 |
+
|
70 |
+
|
71 |
+
class Armor(str, Enum):
|
72 |
+
leather = "leather"
|
73 |
+
chainmail = "chainmail"
|
74 |
+
plate = "plate"
|
75 |
+
|
76 |
+
|
77 |
+
class Character(BaseModel):
|
78 |
+
name: constr(max_length=10)
|
79 |
+
age: int
|
80 |
+
armor: Armor
|
81 |
+
weapon: Weapon
|
82 |
+
strength: int
|
83 |
+
|
84 |
+
|
85 |
+
@sgl.function
|
86 |
+
def character_gen(s):
|
87 |
+
s += "Give me a character description who is a wizard.\n"
|
88 |
+
s += sgl.gen(
|
89 |
+
"character",
|
90 |
+
max_tokens=128,
|
91 |
+
temperature=0,
|
92 |
+
regex=build_regex_from_object(Character),
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def main(args):
|
97 |
+
# Select backend
|
98 |
+
backend = select_sglang_backend(args)
|
99 |
+
sgl.set_default_backend(backend)
|
100 |
+
|
101 |
+
state = regex_gen.run(temperature=0)
|
102 |
+
|
103 |
+
print("=" * 20, "IP TEST", "=" * 20)
|
104 |
+
print(state.text())
|
105 |
+
|
106 |
+
state = json_gen.run(temperature=0)
|
107 |
+
|
108 |
+
print("=" * 20, "JSON TEST", "=" * 20)
|
109 |
+
print(state.text())
|
110 |
+
|
111 |
+
state = character_gen.run(temperature=0)
|
112 |
+
|
113 |
+
print("=" * 20, "CHARACTER TEST", "=" * 20)
|
114 |
+
print(state.text())
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
parser = argparse.ArgumentParser()
|
119 |
+
args = add_common_sglang_args_and_parse(parser)
|
120 |
+
main(args)
|
121 |
+
|
122 |
+
# ==================== IP TEST ====================
|
123 |
+
# Q: What is the IP address of the Google DNS servers?
|
124 |
+
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
|
125 |
+
# ==================== JSON TEST ====================
|
126 |
+
# The information about Hogwarts is in the following JSON format.
|
127 |
+
|
128 |
+
# {
|
129 |
+
# "name": "Hogwarts School of Witchcraft and Wizardry",
|
130 |
+
# "country": "Scotland",
|
131 |
+
# "latitude": 55.566667,
|
132 |
+
# "population": 1000,
|
133 |
+
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
|
134 |
+
# }
|
135 |
+
|
136 |
+
# ==================== CHARACTER TEST ====================
|
137 |
+
# Give me a character description who is a wizard.
|
138 |
+
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
|
sglang/scripts/deprecated/test_robust.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import string
|
4 |
+
|
5 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
6 |
+
|
7 |
+
import sglang as sgl
|
8 |
+
from sglang.test.test_utils import (
|
9 |
+
add_common_sglang_args_and_parse,
|
10 |
+
select_sglang_backend,
|
11 |
+
)
|
12 |
+
|
13 |
+
TOKENIZER = None
|
14 |
+
RANDOM_PREFILL_LEN = None
|
15 |
+
RANDOM_DECODE_LEN = None
|
16 |
+
|
17 |
+
|
18 |
+
def gen_prompt(token_num):
|
19 |
+
if RANDOM_PREFILL_LEN:
|
20 |
+
token_num = random.randint(1, token_num)
|
21 |
+
|
22 |
+
cha_set = string.ascii_letters + string.digits
|
23 |
+
ret = "".join(random.choices(cha_set, k=token_num))
|
24 |
+
while len(TOKENIZER(ret).input_ids) < token_num:
|
25 |
+
ret += random.choice(cha_set)
|
26 |
+
|
27 |
+
return ret
|
28 |
+
|
29 |
+
|
30 |
+
def robust_test_dfs(s, d, args, leaf_states):
|
31 |
+
if d == 0:
|
32 |
+
s += "END"
|
33 |
+
leaf_states.append(s)
|
34 |
+
return
|
35 |
+
|
36 |
+
s += gen_prompt(args.len_prefill)
|
37 |
+
forks = s.fork(args.num_fork)
|
38 |
+
for fork_s in forks:
|
39 |
+
fork_s += gen_prompt(args.len_prefill)
|
40 |
+
new_tokens = (
|
41 |
+
args.len_decode
|
42 |
+
if not RANDOM_DECODE_LEN
|
43 |
+
else random.randint(1, args.len_decode)
|
44 |
+
)
|
45 |
+
fork_s += sgl.gen(
|
46 |
+
max_tokens=new_tokens,
|
47 |
+
ignore_eos=True,
|
48 |
+
)
|
49 |
+
|
50 |
+
for fork_s in forks:
|
51 |
+
robust_test_dfs(fork_s, d - 1, args, leaf_states)
|
52 |
+
|
53 |
+
|
54 |
+
def robust_test_bfs(s, args, leaf_states):
|
55 |
+
old_forks = [s]
|
56 |
+
new_forks = []
|
57 |
+
for _ in range(args.depth):
|
58 |
+
for old_fork in old_forks:
|
59 |
+
old_fork += gen_prompt(args.len_prefill)
|
60 |
+
forks = old_fork.fork(args.num_fork)
|
61 |
+
for fork_s in forks:
|
62 |
+
fork_s += gen_prompt(args.len_prefill)
|
63 |
+
new_tokens = (
|
64 |
+
args.len_decode
|
65 |
+
if not RANDOM_DECODE_LEN
|
66 |
+
else random.randint(1, args.len_decode)
|
67 |
+
)
|
68 |
+
fork_s += sgl.gen(
|
69 |
+
max_tokens=new_tokens,
|
70 |
+
ignore_eos=True,
|
71 |
+
)
|
72 |
+
new_forks.extend(forks)
|
73 |
+
|
74 |
+
old_forks = new_forks
|
75 |
+
new_forks = []
|
76 |
+
|
77 |
+
for old_fork in old_forks:
|
78 |
+
old_fork += "END"
|
79 |
+
leaf_states.append(old_fork)
|
80 |
+
|
81 |
+
|
82 |
+
@sgl.function
|
83 |
+
def robust_test(s, args):
|
84 |
+
leaf_states = []
|
85 |
+
if args.mode == "bfs":
|
86 |
+
robust_test_bfs(s, args, leaf_states)
|
87 |
+
else:
|
88 |
+
robust_test_dfs(s, args.depth, args, leaf_states)
|
89 |
+
return leaf_states
|
90 |
+
|
91 |
+
|
92 |
+
def main(args):
|
93 |
+
backend = select_sglang_backend(args)
|
94 |
+
|
95 |
+
arguments = [{"args": args} for _ in range(args.num_req)]
|
96 |
+
|
97 |
+
states = robust_test.run_batch(
|
98 |
+
arguments, temperature=0, backend=backend, num_threads=args.parallel
|
99 |
+
)
|
100 |
+
|
101 |
+
with open(f"tmp_robust_{args.mode}.txt", "w") as f:
|
102 |
+
for state in states:
|
103 |
+
leaf_states = state.ret_value
|
104 |
+
for leaf_state in leaf_states:
|
105 |
+
assert leaf_state.text()[-3:] == "END"
|
106 |
+
f.write(leaf_state.text()[:-3] + "\n")
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
# fmt: off
|
111 |
+
parser = argparse.ArgumentParser()
|
112 |
+
parser.add_argument("--num-req", type=int, default=2)
|
113 |
+
parser.add_argument("--depth", type=int, default=3)
|
114 |
+
parser.add_argument("--num-fork", type=int, default=2)
|
115 |
+
parser.add_argument("--len-prefill", type=int, default=128)
|
116 |
+
parser.add_argument("--len-decode", type=int, default=128)
|
117 |
+
parser.add_argument("--random-prefill-len", action="store_true")
|
118 |
+
parser.add_argument("--random-decode-len", action="store_true")
|
119 |
+
parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
|
120 |
+
parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
|
121 |
+
parser.add_argument("--trust-remote-code", action="store_true")
|
122 |
+
parser.add_argument("--seed", type=int, default=42)
|
123 |
+
args = add_common_sglang_args_and_parse(parser)
|
124 |
+
# fmt: on
|
125 |
+
|
126 |
+
RANDOM_PREFILL_LEN = args.random_prefill_len
|
127 |
+
RANDOM_DECODE_LEN = args.random_decode_len
|
128 |
+
TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
129 |
+
|
130 |
+
random.seed(args.seed)
|
131 |
+
|
132 |
+
main(args)
|