tuandunghcmut committed
Commit f2fefbb · verified · 1 Parent(s): 25f579c

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. InternVL/.github/CONTRIBUTING.md +234 -0
  2. InternVL/internvl_chat_llava/LICENSE +201 -0
  3. InternVL/internvl_chat_llava/README.md +506 -0
  4. InternVL/internvl_chat_llava/pyproject.toml +33 -0
  5. InternVL/internvl_g/README.md +497 -0
  6. InternVL/segmentation/dist_test.sh +9 -0
  7. InternVL/segmentation/dist_train.sh +9 -0
  8. InternVL/segmentation/train.py +220 -0
  9. InternVL/streamlit_demo/constants.py +23 -0
  10. InternVL/streamlit_demo/controller.py +291 -0
  11. InternVL/streamlit_demo/model_worker.py +442 -0
  12. InternVL/video_retrieval/test_msrvtt.py +156 -0
  13. sglang/examples/frontend_language/quick_start/gemini_example_chat.py +73 -0
  14. sglang/examples/frontend_language/quick_start/gemini_example_multimodal_chat.py +30 -0
  15. sglang/examples/frontend_language/quick_start/local_example_complete.py +70 -0
  16. sglang/examples/frontend_language/quick_start/local_example_llava_next.py +78 -0
  17. sglang/examples/frontend_language/quick_start/openai_example_chat.py +74 -0
  18. sglang/examples/frontend_language/quick_start/openai_example_complete.py +68 -0
  19. sglang/examples/frontend_language/quick_start/openrouter_example_chat.py +81 -0
  20. sglang/examples/frontend_language/quick_start/together_example_complete.py +76 -0
  21. sglang/examples/frontend_language/usage/chinese_regex.py +53 -0
  22. sglang/examples/frontend_language/usage/choices_logprob.py +44 -0
  23. sglang/examples/frontend_language/usage/cot_decoding.py +115 -0
  24. sglang/examples/frontend_language/usage/json_decode.py +83 -0
  25. sglang/examples/frontend_language/usage/json_logprobs.py +104 -0
  26. sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.py +260 -0
  27. sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.sh +131 -0
  28. sglang/examples/frontend_language/usage/openai_chat_speculative.py +155 -0
  29. sglang/examples/frontend_language/usage/openai_speculative.py +54 -0
  30. sglang/examples/frontend_language/usage/parallel_sample.py +40 -0
  31. sglang/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb +408 -0
  32. sglang/examples/frontend_language/usage/readme_examples.py +109 -0
  33. sglang/examples/frontend_language/usage/sgl_gen_min_tokens.py +35 -0
  34. sglang/examples/frontend_language/usage/streaming.py +49 -0
  35. sglang/examples/frontend_language/usage/triton/Dockerfile +10 -0
  36. sglang/examples/frontend_language/usage/triton/README.md +35 -0
  37. sglang/examples/frontend_language/usage/triton/models/character_generation/1/model.py +55 -0
  38. sglang/examples/frontend_language/usage/triton/models/character_generation/config.pbtxt +23 -0
  39. sglang/examples/monitoring/grafana.json +1720 -0
  40. sglang/examples/monitoring/prometheus.yaml +10 -0
  41. sglang/examples/runtime/lora.py +37 -0
  42. sglang/examples/runtime/openai_batch_complete.py +93 -0
  43. sglang/examples/runtime/openai_chat_with_response_prefill.py +34 -0
  44. sglang/scripts/deprecated/convert_yi_vl.py +38 -0
  45. sglang/scripts/deprecated/test_httpserver_classify.py +85 -0
  46. sglang/scripts/deprecated/test_httpserver_decode_stream.py +69 -0
  47. sglang/scripts/deprecated/test_httpserver_llava.py +88 -0
  48. sglang/scripts/deprecated/test_httpserver_reuse.py +42 -0
  49. sglang/scripts/deprecated/test_jump_forward.py +138 -0
  50. sglang/scripts/deprecated/test_robust.py +132 -0
InternVL/.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,234 @@
1
+ ## Contributing to InternLM
2
+
3
+ Welcome to the InternLM community! All kinds of contributions are welcome, including but not limited to the following.
4
+
5
+ **Fix bug**
6
+
7
+ You can directly post a pull request to fix typos in code or documents.
8
+
9
+ The steps to fix a bug in the code implementation are as follows.
10
+
11
+ 1. If the modification involves significant changes, you should create an issue first that describes the error and how to trigger the bug. Other developers will discuss it with you and propose a proper solution.
12
+
13
+ 2. Post a pull request after fixing the bug and adding the corresponding unit tests.
14
+
15
+ **New Feature or Enhancement**
16
+
17
+ 1. If the modification involves significant changes, you should create an issue to discuss a proper design with our developers.
18
+ 2. Post a pull request after implementing the new feature or enhancement and adding the corresponding unit tests.
19
+
20
+ **Document**
21
+
22
+ You can directly post a pull request to fix the documentation. If you want to add a document, you should first create an issue to check whether it is reasonable.
23
+
24
+ ### Pull Request Workflow
25
+
26
+ If you're not familiar with pull requests, don't worry! The following guidance will tell you how to create a pull request step by step. If you want to dive deeper into the pull request workflow, you can refer to the [official documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests).
27
+
28
+ #### 1. Fork and clone
29
+
30
+ If you are posting a pull request for the first time, you should fork the repository by clicking the **Fork** button in the top-right corner of the GitHub page, and the forked repository will appear under your GitHub profile.
31
+
32
+ <img src="https://user-images.githubusercontent.com/57566630/167305749-43c7f4e9-449b-4e98-ade5-0c9276d5c9ce.png" width="1200">
33
+
34
+ Then, you can clone the repository locally:
35
+
36
+ ```shell
37
+ git clone git@github.com:{username}/lmdeploy.git
38
+ ```
39
+
40
+ After that, you should add the official repository as the upstream repository:
41
+
42
+ ```bash
43
+ git remote add upstream git@github.com:InternLM/lmdeploy.git
44
+ ```
45
+
46
+ Check whether the remote repository has been added successfully with `git remote -v`:
47
+
48
+ ```bash
49
+ origin git@github.com:{username}/lmdeploy.git (fetch)
50
+ origin git@github.com:{username}/lmdeploy.git (push)
51
+ upstream git@github.com:InternLM/lmdeploy.git (fetch)
52
+ upstream git@github.com:InternLM/lmdeploy.git (push)
53
+ ```
54
+
55
+ > Here's a brief introduction to origin and upstream. When we use `git clone`, an "origin" remote is created by default, pointing to the repository we cloned from. As for "upstream", we add it ourselves to point to the target repository. Of course, if you don't like the name "upstream", you can name it as you wish. Usually, we push the code to "origin". If the pushed code conflicts with the latest code in the official repository ("upstream"), we should pull the latest code from upstream to resolve the conflicts and then push to "origin" again. The posted pull request will be updated automatically.
56
+
57
+ #### 2. Configure pre-commit
58
+
59
+ You should configure [pre-commit](https://pre-commit.com/#intro) in the local development environment to make sure the code style matches that of InternLM. **Note**: The following code should be executed under the lmdeploy directory.
60
+
61
+ ```shell
62
+ pip install -U pre-commit
63
+ pre-commit install
64
+ ```
65
+
66
+ Check that pre-commit has been configured successfully and that the hooks defined in `.pre-commit-config.yaml` are installed:
67
+
68
+ ```shell
69
+ pre-commit run --all-files
70
+ ```
71
+
72
+ <img src="https://user-images.githubusercontent.com/57566630/173660750-3df20a63-cb66-4d33-a986-1f643f1d8aaf.png" width="1200">
73
+
74
+ <img src="https://user-images.githubusercontent.com/57566630/202368856-0465a90d-8fce-4345-918e-67b8b9c82614.png" width="1200">
75
+
76
+ If the installation process is interrupted, you can run `pre-commit run ...` again to continue the installation.
77
+
78
+ If the code does not conform to the code style specification, pre-commit will raise a warning and automatically fix some of the errors.
79
+
80
+ <img src="https://user-images.githubusercontent.com/57566630/202369176-67642454-0025-4023-a095-263529107aa3.png" width="1200">
81
+
82
+ If we want to commit our code while bypassing the pre-commit hook, we can use the `--no-verify` option (**only for temporary commits**).
83
+
84
+ ```shell
85
+ git commit -m "xxx" --no-verify
86
+ ```
87
+
88
+ #### 3. Create a development branch
89
+
90
+ After configuring pre-commit, we should create a branch based on the master branch to develop the new feature or fix the bug. The proposed branch name format is `username/pr_name`:
91
+
92
+ ```shell
93
+ git checkout -b yhc/refactor_contributing_doc
94
+ ```
95
+
96
+ In subsequent development, if the master branch of the local repository falls behind the master branch of "upstream", we need to pull from upstream to synchronize and then execute the above command:
97
+
98
+ ```shell
99
+ git pull upstream master
100
+ ```
101
+
102
+ #### 4. Commit the code and pass the unit test
103
+
104
+ - lmdeploy introduces mypy for static type checking to increase the robustness of the code. Therefore, we need to add type hints to our code and pass the mypy check. If you are not familiar with type hints, you can refer to [this tutorial](https://docs.python.org/3/library/typing.html); a short illustrative sketch is also shown after this list.
105
+
106
+ - The committed code should pass the unit tests:
107
+
108
+ ```shell
109
+ # Pass all unit tests
110
+ pytest tests
111
+
112
+ # Pass the unit test of runner
113
+ pytest tests/test_runner/test_runner.py
114
+ ```
115
+
116
+ If the unit tests fail due to missing dependencies, you can install them by referring to the [guidance](#unit-test).
117
+
118
+ - If documents are modified or added, we should check the rendering result by referring to the [guidance](#document-rendering).
119
+
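As a minimal, hypothetical sketch of what type-hinted code looks like (the function and values below are illustrative and not taken from the lmdeploy codebase), consider:

```python
from typing import List, Optional


def concat_prompts(prompts: List[str],
                   separator: str = "\n",
                   max_length: Optional[int] = None) -> str:
    """Join prompts with a separator, optionally truncating the result."""
    text = separator.join(prompts)
    if max_length is not None:
        text = text[:max_length]
    return text
```

Running `mypy` on a file containing this function will flag callers that, for example, pass an `int` where `List[str]` is expected.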
120
+ #### 5. Push the code to remote
121
+
122
+ We can push the local commits to the remote repository after passing the unit tests and pre-commit checks. You can associate the local branch with the remote branch by adding the `-u` option.
123
+
124
+ ```shell
125
+ git push -u origin {branch_name}
126
+ ```
127
+
128
+ This will allow you to use the `git push` command to push code directly next time, without having to specify a branch or the remote repository.
129
+
130
+ #### 6. Create a Pull Request
131
+
132
+ (1) Create a pull request in GitHub's Pull request interface
133
+
134
+ <img src="https://user-images.githubusercontent.com/57566630/201533288-516f7ac4-0b14-4dc8-afbd-912475c368b5.png" width="1200">
135
+
136
+ (2) Modify the PR description according to the guidelines so that other developers can better understand your changes
137
+
138
+ <img src="https://user-images.githubusercontent.com/57566630/202242953-c91a18ff-e388-4ff9-8591-5fae0ead6c1e.png" width="1200">
139
+
140
+ Find more details about Pull Request description in [pull request guidelines](#pr-specs).
141
+
142
+ **note**
143
+
144
+ (a) The Pull Request description should contain the reason for the change, the content of the change, and the impact of the change, and be associated with the relevant Issue (see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))
145
+
146
+ (b) If it is your first contribution, please sign the CLA
147
+
148
+ <img src="https://user-images.githubusercontent.com/57566630/167307569-a794b967-6e28-4eac-a942-00deb657815f.png" width="1200">
149
+
150
+ (c) Check whether the pull request passes the CI
151
+
152
+ <img src="https://user-images.githubusercontent.com/57566630/167307490-f9ebf9fa-63c0-4d83-8ba1-081ea169eb3a.png" width="1200">
153
+
154
+ InternLM will run unit tests for the posted pull request on different platforms (Linux, Windows, Mac) and with different versions of Python, PyTorch, and CUDA to make sure the code is correct. We can see the specific test information by clicking `Details` in the image above so that we can fix the code accordingly.
155
+
156
+ (3) If the Pull Request passes the CI, then you can wait for the review from other developers. You'll modify the code based on the reviewer's comments, and repeat the steps [4](#4-commit-the-code-and-pass-the-unit-test)-[5](#5-push-the-code-to-remote) until all reviewers approve it. Then, we will merge it ASAP.
157
+
158
+ <img src="https://user-images.githubusercontent.com/57566630/202145400-cc2cd8c4-10b0-472f-ba37-07e6f50acc67.png" width="1200">
159
+
160
+ #### 7. Resolve conflicts
161
+
162
+ If your local branch conflicts with the latest master branch of "upstream", you'll need to resolve the conflicts. There are two ways to do this:
163
+
164
+ ```shell
165
+ git fetch --all --prune
166
+ git rebase upstream/master
167
+ ```
168
+
169
+ or
170
+
171
+ ```shell
172
+ git fetch --all --prune
173
+ git merge upstream/master
174
+ ```
175
+
176
+ If you are comfortable handling conflicts, you can use `rebase` to resolve them, as this keeps your commit log tidy. If you are not familiar with `rebase`, you can use `merge` instead.
177
+
178
+ ### Guidance
179
+
180
+ #### Document rendering
181
+
182
+ If the documents are modified/added, we should check the rendering result. We could install the dependencies and run the following command to render the documents and check the results:
183
+
184
+ ```shell
185
+ pip install -r requirements/docs.txt
186
+ cd docs/zh_cn/
187
+ # or docs/en
188
+ make html
189
+ # check file in ./docs/zh_cn/_build/html/index.html
190
+ ```
191
+
192
+ ### Code style
193
+
194
+ #### Python
195
+
196
+ We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
197
+
198
+ We use the following tools for linting and formatting:
199
+
200
+ - [flake8](https://github.com/PyCQA/flake8): A wrapper around some linter tools.
201
+ - [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.
202
+ - [yapf](https://github.com/google/yapf): A formatter for Python files.
203
+ - [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files.
204
+ - [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
205
+ - [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
206
+
207
+ We use [pre-commit hooks](https://pre-commit.com/) that check and format code with `flake8`, `yapf`, and `isort`, check `trailing whitespaces` and `markdown files`,
208
+ fix `end-of-files`, `double-quoted-strings`, `python-encoding-pragma`, and `mixed-line-ending`, and sort `requirements.txt` automatically on every commit.
209
+ The config for a pre-commit hook is stored in [.pre-commit-config](../.pre-commit-config.yaml).
210
+
211
+ #### C++ and CUDA
212
+
213
+ The clang-format config is stored in [.clang-format](../.clang-format). And it's recommended to use clang-format version **11**. Please do not use older or newer versions as they will result in differences after formatting, which can cause the [lint](https://github.com/InternLM/lmdeploy/blob/main/.github/workflows/lint.yml#L25) to fail.
214
+
215
+ ### PR Specs
216
+
217
+ 1. Use the [pre-commit](https://pre-commit.com) hook to avoid code style issues
218
+
219
+ 2. One short-lived branch should correspond to only one PR
220
+
221
+ 3. Accomplish one well-scoped change in each PR. Avoid large PRs
222
+
223
+ - Bad: Support Faster R-CNN
224
+ - Acceptable: Add a box head to Faster R-CNN
225
+ - Good: Add a parameter to box head to support custom conv-layer number
226
+
227
+ 4. Provide clear and meaningful commit messages
228
+
229
+ 5. Provide clear and meaningful PR description
230
+
231
+ - Task name should be clarified in title. The general format is: \[Prefix\] Short description of the PR (Suffix)
232
+ - Prefix: add new feature \[Feature\], fix bug \[Fix\], related to documents \[Docs\], in development \[WIP\] (which will not be reviewed for the time being)
233
+ - Introduce main changes, results and influences on other modules in short description
234
+ - Associate related issues and pull requests with a milestone
InternVL/internvl_chat_llava/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
InternVL/internvl_chat_llava/README.md ADDED
@@ -0,0 +1,506 @@
1
+ # InternVL for Multimodal Dialogue using LLaVA Codebase
2
+
3
+ This folder contains the implementation of the InternVL-Chat V1.0, which corresponds to Section 4.4 of our [InternVL 1.0 paper](https://arxiv.org/pdf/2312.14238).
4
+
5
+ In this part, we mainly use the [LLaVA codebase](https://github.com/haotian-liu/LLaVA) to evaluate InternVL in creating multimodal dialogue systems. Thanks for this great work.
6
+ We have retained the original documentation of LLaVA-1.5 as a more detailed manual. In most cases, you will only need to refer to the new documentation that we have added.
7
+
8
+ > Note: To unify the environment across different tasks, we have made some compatibility modifications to the LLaVA-1.5 code, allowing it to support `transformers==4.37.2` (originally locked at 4.31.0). Please note that `transformers==4.37.2` should be installed.
9
+
10
+ ## 🛠️ Installation
11
+
12
+ First, follow the [installation guide](../INSTALLATION.md) to perform some basic installations.
13
+
14
+ In addition, using this codebase requires executing the following steps:
15
+
16
+ - Install other requirements:
17
+
18
+ ```bash
19
+ pip install --upgrade pip # enable PEP 660 support
20
+ pip install -e .
21
+ ```
22
+
23
+ ## 📦 Model Preparation
24
+
25
+ | model name | type | download | size |
26
+ | ----------------------- | ----------- | ---------------------------------------------------------------------- | :-----: |
27
+ | InternViT-6B-224px | huggingface | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-224px) | 12 GB |
28
+ | InternViT-6B-448px-V1-0 | huggingface | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-0) | 12 GB |
29
+ | vicuna-13b-v1.5 | huggingface | 🤗 [HF link](https://huggingface.co/lmsys/vicuna-13b-v1.5) | 26.1 GB |
30
+ | vicuna-7b-v1.5 | huggingface | 🤗 [HF link](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 13.5 GB |
31
+
32
+ Please download the above model weights and place them in the `pretrained/` folder.
33
+
34
+ ```sh
35
+ cd pretrained/
36
+ # pip install -U huggingface_hub
37
+ huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternViT-6B-224px --local-dir InternViT-6B-224px
38
+ huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternViT-6B-448px-V1-0 --local-dir InternViT-6B-448px
39
+ huggingface-cli download --resume-download --local-dir-use-symlinks False lmsys/vicuna-13b-v1.5 --local-dir vicuna-13b-v1.5
40
+ huggingface-cli download --resume-download --local-dir-use-symlinks False lmsys/vicuna-7b-v1.5 --local-dir vicuna-7b-v1.5
41
+ ```
42
+
43
+ The directory structure is:
44
+
45
+ ```sh
46
+ pretrained
47
+ │── InternViT-6B-224px/
48
+ │── InternViT-6B-448px/
49
+ │── vicuna-13b-v1.5/
50
+ └── vicuna-7b-v1.5/
51
+ ```
52
+
53
+ ## 🔥 Training
54
+
55
+ - InternViT-6B-224px + Vicuna-7B:
56
+
57
+ ```shell
58
+ # pretrain
59
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_224to336_vicuna7b.sh
60
+ # finetune
61
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_224to336_vicuna7b.sh
62
+ ```
63
+
64
+ - InternViT-6B-224px + Vicuna-13B:
65
+
66
+ ```shell
67
+ # pretrain
68
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_224to336_vicuna13b.sh
69
+ # finetune
70
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_224to336_vicuna13b.sh
71
+ ```
72
+
73
+ - InternViT-6B-448px + Vicuna-7B:
74
+
75
+ ```shell
76
+ # pretrain
77
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_448_vicuna7b.sh
78
+ # finetune
79
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_448_vicuna7b.sh
80
+ ```
81
+
82
+ - InternViT-6B-448px + Vicuna-13B:
83
+
84
+ ```shell
85
+ # pretrain
86
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/pretrain_internvit6b_448_vicuna13b.sh
87
+ # finetune
88
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh scripts_internvl/finetune_internvit6b_448_vicuna13b.sh
89
+ ```
90
+
91
+ ## 🤗 Model Zoo
92
+
93
+ | method | vision encoder | LLM | res. | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MMB | MMB<sub>CN</sub> | MMVet | Download |
94
+ | ----------------- | :------------: | :---: | :--: | :---: | :--: | :----: | :--: | :-----: | :--: | :----: | :--: | :--------------: | :---: | :----------------------------------------------------------------------------------: |
95
+ | LLaVA-1.5 | CLIP-L-336px | V-7B | 336 | 78.5 | 62.0 | 50.0 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 | 58.3 | 30.5 | 🤗 [HF link](https://huggingface.co/liuhaotian/llava-v1.5-7b) |
96
+ | LLaVA-1.5 | CLIP-L-336px | V-13B | 336 | 80.0 | 63.3 | 53.6 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 | 63.6 | 35.4 | 🤗 [HF link](https://huggingface.co/liuhaotian/llava-v1.5-13b) |
97
+ | InternVL-Chat-1.0 | IViT-6B-224px | V-7B | 336 | 79.3 | 62.9 | 52.5 | 66.2 | 57.0 | 86.4 | 1525.1 | 64.6 | 57.6 | 31.2 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B) |
98
+ | InternVL-Chat-1.0 | IViT-6B-224px | V-13B | 336 | 80.2 | 63.9 | 54.6 | 70.1 | 58.7 | 87.1 | 1546.9 | 66.5 | 61.9 | 33.7 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B) |
99
+ | InternVL-Chat-1.0 | IViT-6B-448px | V-13B | 448 | 82.0 | 64.1 | 60.1 | 71.6 | 64.8 | 87.2 | 1579.0 | 68.2 | 64.0 | 36.7 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B-448px) |
100
+
101
+ Please download the above model weights and place them in the `pretrained/` folder.
102
+
103
+ ```shell
104
+ cd pretrained/
105
+ # pip install -U huggingface_hub
106
+ huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B --local-dir InternVL-Chat-ViT-6B-Vicuna-7B
107
+ huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B --local-dir InternVL-Chat-ViT-6B-Vicuna-13B
108
+ huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B-448px --local-dir InternVL-Chat-ViT-6B-Vicuna-13B-448px
109
+
110
+ ```
111
+
112
+ The directory structure is:
113
+
114
+ ```
115
+ pretrained
116
+ │── InternViT-6B-224px/
117
+ │── InternViT-6B-448px/
118
+ │── vicuna-13b-v1.5/
119
+ │── vicuna-7b-v1.5/
120
+ │── InternVL-Chat-ViT-6B-Vicuna-7B/
121
+ │── InternVL-Chat-ViT-6B-Vicuna-13B/
122
+ └── InternVL-Chat-ViT-6B-Vicuna-13B-448px/
123
+ ```
124
+
125
+ ## 🖥️ Demo
126
+
127
+ The method for deploying the demo is consistent with LLaVA-1.5. You only need to change the model path. The specific steps are as follows:
128
+
129
+ **Launch a controller**
130
+
131
+ ```shell
132
+ python -m llava.serve.controller --host 0.0.0.0 --port 10000
133
+ ```
134
+
135
+ **Launch a gradio web server**
136
+
137
+ ```shell
138
+ python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --port 10038
139
+ ```
140
+
141
+ **Launch a model worker**
142
+
143
+ ```shell
144
+ # OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-7B
145
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path ./pretrained/InternVL-Chat-ViT-6B-Vicuna-7B
146
+ # OpenGVLab/InternVL-Chat-ViT-6B-Vicuna-13B
147
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40001 --worker http://localhost:40001 --model-path ./pretrained/InternVL-Chat-ViT-6B-Vicuna-13B
148
+ ```
149
+
150
+ After completing the above steps, you can access the web demo at `http://localhost:10038` and see the following page. Note that the models deployed here are `InternVL-Chat-ViT-6B-Vicuna-7B` and `InternVL-Chat-ViT-6B-Vicuna-13B`, which are the two models of our InternVL 1.0. The only difference from LLaVA-1.5 is that the CLIP-ViT-300M has been replaced with our InternViT-6B.
151
+
152
+ If you need a more effective MLLM, please check out our InternVL2 series models.
153
+ For more details on deploying the demo, please refer to [here](#gradio-web-ui).
154
+
155
+ ![llava_webui](https://github.com/user-attachments/assets/2ca2180f-70b9-41c7-8174-c518d4054248)
156
+
157
+ ## 💡 Testing
158
+
159
+ The method for testing the model remains the same as LLaVA-1.5; you just need to change the path of the script. Our scripts are located in `scripts_internvl/`.
160
+
161
+ For example, testing `MME` using a single GPU:
162
+
163
+ ```shell
164
+ sh scripts_internvl/eval/mme.sh pretrained/InternVL-Chat-ViT-6B-Vicuna-7B/
165
+ ```
166
+
167
+ ______________________________________________________________________
168
+
169
+ ## 🌋 LLaVA: Large Language and Vision Assistant
170
+
171
+ *Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*
172
+
173
+ \[[Project Page](https://llava-vl.github.io/)\] \[[Demo](https://llava.hliu.cc/)\] \[[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)\] \[[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)\]
174
+
175
+ 🤝Community Contributions: \[[llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436)\] \[[Colab](https://github.com/camenduru/LLaVA-colab)\] \[[🤗Space](https://huggingface.co/spaces/badayvedat/LLaVA)\]
176
+
177
+ **Improved Baselines with Visual Instruction Tuning** \[[Paper](https://arxiv.org/abs/2310.03744)\] <br>
178
+ [Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
179
+
180
+ **Visual Instruction Tuning** (NeurIPS 2023, **Oral**) \[[Paper](https://arxiv.org/abs/2304.08485)\]<br>
181
+ [Haotian Liu\*](https://hliu.cc), [Chunyuan Li\*](https://chunyuan.li/), [Qingyang Wu](https://scholar.google.ca/citations?user=HDiw-TsAAAAJ&hl=en/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) (\*Equal Contribution)
182
+
183
+ ### Release
184
+
185
+ - \[10/12\] 🔥 Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! \[[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)\]
186
+
187
+ - \[10/12\] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support!
188
+
189
+ - \[10/11\] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
190
+
191
+ - \[10/5\] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
192
+
193
+ - \[9/26\] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [\[LLavA-RLHF\]](https://llava-rlhf.github.io/)
194
+
195
+ - \[9/22\] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.
196
+
197
+ - \[9/20\] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [\`\`Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)
198
+
199
+ <p align="center">
200
+ <img src="https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings/blob/main/images/mfm_evolution.jpeg?raw=true" width=50%/>
201
+ </p>
202
+
203
+ - \[7/19\] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)!
204
+
205
+ - \[6/26\] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out \[[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)\] \[[Notes](https://arxiv.org/abs/2306.14895)\] \[[YouTube](https://youtu.be/mkI7EPD1vp8)\] \[[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)\].
206
+
207
+ - \[6/11\] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md).
208
+
209
+ - \[6/1\] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med).
210
+
211
+ - \[5/6\] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
212
+
213
+ - \[5/2\] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
214
+
215
+ - \[4/27\] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava).
216
+
217
+ - \[4/17\] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
218
+
219
+ <!-- <a href="https://llava.hliu.cc/"><img src="assets/demo.gif" width="70%"></a> -->
220
+
221
+ [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
222
+ [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
223
+ **Usage and License Notices**: The data and checkpoints are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes.
224
+
225
+ ### Contents
226
+
227
+ - [Install](#install)
228
+ - [LLaVA Weights](#llava-weights)
229
+ - [Demo](#Demo)
230
+ - [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
231
+ - [Dataset](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)
232
+ - [Train](#train)
233
+ - [Evaluation](#evaluation)
234
+
235
+ ### Install
236
+
237
+ 1. Clone this repository and navigate to LLaVA folder
238
+
239
+ ```bash
240
+ git clone https://github.com/haotian-liu/LLaVA.git
241
+ cd LLaVA
242
+ ```
243
+
244
+ 2. Install Package
245
+
246
+ ```Shell
247
+ conda create -n llava python=3.10 -y
248
+ conda activate llava
249
+ pip install --upgrade pip # enable PEP 660 support
250
+ pip install -e .
251
+ ```
252
+
253
+ 3. Install additional packages for training cases
254
+
255
+ ```
256
+ pip install ninja
257
+ pip install flash-attn --no-build-isolation
258
+ ```
259
+
260
+ #### Upgrade to latest code base
261
+
262
+ ```Shell
263
+ git pull
264
+ pip uninstall transformers
265
+ pip install -e .
266
+ ```
267
+
268
+ ### LLaVA Weights
269
+
270
+ Please check out our [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) for all public LLaVA checkpoints, and the instructions of how to use the weights.
271
+
272
+ ### Demo
273
+
274
+ To run our demo, you need to prepare LLaVA checkpoints locally. Please follow the instructions [here](#llava-weights) to download the checkpoints.
275
+
276
+ #### Gradio Web UI
277
+
278
+ To launch a Gradio demo locally, please run the following commands one by one. If you plan to launch multiple model workers to compare between different checkpoints, you only need to launch the controller and the web server *ONCE*.
279
+
280
+ ##### Launch a controller
281
+
282
+ ```Shell
283
+ python -m llava.serve.controller --host 0.0.0.0 --port 10000
284
+ ```
285
+
286
+ ##### Launch a gradio web server.
287
+
288
+ ```Shell
289
+ python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload
290
+ ```
291
+
292
+ You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker.
293
+
294
+ ##### Launch a model worker
295
+
296
+ This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model specified in `--model-path`.
297
+
298
+ ```Shell
299
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
300
+ ```
301
+
302
+ Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list.
303
+
304
+ You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker.
305
+
306
+ ```Shell
307
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port <different from 40000, say 40001> --worker http://localhost:<change accordingly, i.e. 40001> --model-path <ckpt2>
308
+ ```
309
+
310
+ If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`.
311
+
312
+ ##### Launch a model worker (Multiple GPUs, when GPU VRAM \<= 24GB)
313
+
314
+ If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs.
315
+
316
+ ```Shell
317
+ CUDA_VISIBLE_DEVICES=0,1 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b
318
+ ```
319
+
320
+ ##### Launch a model worker (4-bit, 8-bit inference, quantized)
321
+
322
+ You can launch the model worker with quantized bits (4-bit, 8-bit), which allows you to run the inference with reduced GPU memory footprint, potentially allowing you to run on a GPU with as few as 12GB VRAM. Note that inference with quantized bits may not be as accurate as the full-precision model. Simply append `--load-4bit` or `--load-8bit` to the **model worker** command that you are executing. Below is an example of running with 4-bit quantization.
323
+
324
+ ```Shell
325
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b --load-4bit
326
+ ```
327
+
328
+ ##### Launch a model worker (LoRA weights, unmerged)
329
+
330
+ You can launch the model worker with LoRA weights, without merging them with the base checkpoint, to save disk space. There will be additional loading time, while the inference speed is the same as the merged checkpoints. Unmerged LoRA checkpoints do not have `lora-merge` in the model name, and are usually much smaller (less than 1GB) than the merged checkpoints (13G for 7B, and 25G for 13B).
331
+
332
+ To load unmerged LoRA weights, you simply need to pass an additional argument `--model-base`, which is the base LLM that is used to train the LoRA weights. You can check the base LLM of each LoRA weights in the [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
333
+
334
+ ```Shell
335
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1-0719-336px-lora-vicuna-13b-v1.3 --model-base lmsys/vicuna-13b-v1.3
336
+ ```
337
+
338
+ #### CLI Inference
339
+
340
+ Chat about images using LLaVA without the need of Gradio interface. It also supports multiple GPUs, 4-bit and 8-bit quantized inference. With 4-bit quantization, for our LLaVA-1.5-7B, it uses less than 8GB VRAM on a single GPU.
341
+
342
+ ```Shell
343
+ python -m llava.serve.cli \
344
+ --model-path liuhaotian/llava-v1.5-7b \
345
+ --image-file "https://llava-vl.github.io/static/images/view.jpg" \
346
+ --load-4bit
347
+ ```
348
+
349
+ ### Train
350
+
351
+ *Below is the latest training configuration for LLaVA v1.5. For legacy models, please refer to README of [this](https://github.com/haotian-liu/LLaVA/tree/v1.0.1) version for now. We'll add them in a separate doc later.*
352
+
353
+ LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions.
354
+
355
+ LLaVA is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
356
+
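As a concrete illustration of keeping the global batch size fixed (the numbers below are hypothetical and not taken from the released training scripts):

```python
# Reference setup: 8 GPUs, per-device batch size 16, no gradient accumulation.
# per_device_train_batch_size * gradient_accumulation_steps * num_gpus
global_batch_size = 16 * 1 * 8  # = 128

# Training on 4 GPUs instead: double gradient_accumulation_steps so the
# global batch size stays the same.
assert 16 * 2 * 4 == global_batch_size
```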
357
+ #### Hyperparameters
358
+
359
+ We use a similar set of hyperparameters as Vicuna in finetuning. The hyperparameters used in pretraining and finetuning are both provided below.
360
+
361
+ 1. Pretraining
362
+
363
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
364
+ | -------------- | ----------------: | ------------: | -----: | ---------: | -----------: |
365
+ | LLaVA-v1.5-13B | 256 | 1e-3 | 1 | 2048 | 0 |
366
+
367
+ 2. Finetuning
368
+
369
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
370
+ | -------------- | ----------------: | ------------: | -----: | ---------: | -----------: |
371
+ | LLaVA-v1.5-13B | 128 | 2e-5 | 1 | 2048 | 0 |
372
+
373
+ #### Download Vicuna checkpoints (automatically)
374
+
375
+ Our base model Vicuna v1.5, which is an instruction-tuned chatbot, will be downloaded automatically when you run our provided training scripts. No action is needed.
376
+
377
+ #### Pretrain (feature alignment)
378
+
379
+ Please download the 558K subset of the LAION-CC-SBU dataset with BLIP captions we use in the paper [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain).
380
+
381
+ Pretrain takes around 5.5 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 3.5 hours for LLaVA-v1.5-7B.
382
+
383
+ Training script with DeepSpeed ZeRO-2: [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/pretrain.sh).
384
+
385
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector (see the sketch below).
386
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
387
+
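For intuition, a two-layer GELU MLP connector of the kind selected by `--mm_projector_type mlp2x_gelu` can be sketched as follows; the dimensions are illustrative, and this is not the exact projector-building code of the repository.

```python
import torch
import torch.nn as nn

# Illustrative sizes: vision feature width -> LLM hidden width.
vision_hidden_size = 1024
llm_hidden_size = 4096

mm_projector = nn.Sequential(
    nn.Linear(vision_hidden_size, llm_hidden_size),
    nn.GELU(),
    nn.Linear(llm_hidden_size, llm_hidden_size),
)

# Project a dummy batch of 576 visual tokens (24x24 patches at 336px) into the LLM space.
visual_tokens = torch.randn(1, 576, vision_hidden_size)
print(mm_projector(visual_tokens).shape)  # torch.Size([1, 576, 4096])
```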
388
+ #### Visual Instruction Tuning
389
+
390
+ 1. Prepare data
391
+
392
+ Please download the annotation of the final mixture of our instruction tuning data, [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json), and download the images from the constituent datasets:
393
+
394
+ - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
395
+ - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
396
+ - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing)
397
+ - TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
398
+ - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
399
+
400
+ After downloading all of them, organize the data as follows in `./playground/data`,
401
+
402
+ ```
403
+ ├── coco
404
+ │ └── train2017
405
+ ├── gqa
406
+ │ └── images
407
+ ├── ocr_vqa
408
+ │ └── images
409
+ ├── textvqa
410
+ │ └── train_images
411
+ └── vg
412
+ ├── VG_100K
413
+ └── VG_100K_2
414
+ ```
415
+
416
+ 2. Start training!
417
+
418
+ You may download our pretrained projectors in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). It is not recommended to use legacy projectors, as they may be trained with a different version of the codebase, and if any option is off, the model will not function/train as we expected.
419
+
420
+ Visual instruction tuning takes around 20 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 10 hours for LLaVA-v1.5-7B on 8x A100 (40G).
421
+
422
+ Training script with DeepSpeed ZeRO-3: [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh).
423
+
424
+ New options to note:
425
+
426
+ - `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector.
427
+ - `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px.
428
+ - `--image_aspect_ratio pad`: this pads non-square images to a square instead of cropping them; it slightly reduces hallucination (a small sketch of this padding is shown after this list).
429
+ - `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct). It makes the training sampler only sample a single modality (either image or language) during training, which we observe to speed up training by ~25%, and does not affect the final outcome.
430
+
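To make the `pad` option concrete, here is a minimal sketch of padding an image to a square canvas; the fill color and helper name are illustrative, and the actual preprocessing in the codebase may differ in detail.

```python
from PIL import Image


def expand_to_square(image: Image.Image, fill=(127, 127, 127)) -> Image.Image:
    """Pad a PIL image to a square canvas, centering the original content."""
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), fill)
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas
```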
431
+ ### Evaluation
432
+
433
+ In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search, so that the inference process is consistent with the real-time outputs of the chat demo.
434
+
435
+ See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md).
436
+
437
+ #### GPT-assisted Evaluation
438
+
439
+ Our GPT-assisted evaluation pipeline for multimodal modeling is provided to enable a comprehensive assessment of the capabilities of vision-language models. Please see our paper for more details.
440
+
441
+ 1. Generate LLaVA responses
442
+
443
+ ```Shell
444
+ python model_vqa.py \
445
+ --model-path ./checkpoints/LLaVA-13B-v0 \
446
+ --question-file \
447
+ playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
448
+ --image-folder \
449
+ /path/to/coco2014_val \
450
+ --answers-file \
451
+ /path/to/answer-file-our.jsonl
452
+ ```
453
+
454
+ 2. Evaluate the generated responses. In our case, [`answer-file-ref.jsonl`](./playground/data/coco2014_val_qa_eval/qa90_gpt4_answer.jsonl) contains the responses generated by text-only GPT-4 (0314), with the context captions/boxes provided.
455
+
456
+ ```Shell
457
+ OPENAI_API_KEY="sk-***********************************" python llava/eval/eval_gpt_review_visual.py \
458
+ --question playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \
459
+ --context llava/eval/table/caps_boxes_coco2014_val_80.jsonl \
460
+ --answer-list \
461
+ /path/to/answer-file-ref.jsonl \
462
+ /path/to/answer-file-our.jsonl \
463
+ --rule llava/eval/table/rule.json \
464
+ --output /path/to/review.json
465
+ ```
466
+
467
+ 3. Summarize the evaluation results
468
+
469
+ ```Shell
470
+ python summarize_gpt_review.py
471
+ ```
472
+
473
+ ### Citation
474
+
475
+ If you find LLaVA useful for your research and applications, please cite using this BibTeX:
476
+
477
+ ```bibtex
478
+ @misc{liu2023improvedllava,
479
+ title={Improved Baselines with Visual Instruction Tuning},
480
+ author={Liu, Haotian and Li, Chunyuan and Li, Yuheng and Lee, Yong Jae},
481
+ publisher={arXiv:2310.03744},
482
+ year={2023},
483
+ }
484
+
485
+ @misc{liu2023llava,
486
+ title={Visual Instruction Tuning},
487
+ author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
488
+ publisher={arXiv:2304.08485},
489
+ year={2023},
490
+ }
491
+ ```
492
+
493
+ ### Acknowledgement
494
+
495
+ - [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B with its amazing language capabilities!
496
+
497
+ ### Related Projects
498
+
499
+ - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
500
+ - [LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day](https://github.com/microsoft/LLaVA-Med)
501
+ - [Otter: In-Context Multi-Modal Instruction Tuning](https://github.com/Luodian/Otter)
502
+
503
+ For future project ideas, please check out:
504
+
505
+ - [SEEM: Segment Everything Everywhere All at Once](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once)
506
+ - [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything) to detect, segment, and generate anything by marrying [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) and [Segment-Anything](https://github.com/facebookresearch/segment-anything).
InternVL/internvl_chat_llava/pyproject.toml ADDED
@@ -0,0 +1,33 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "llava"
7
+ version = "1.1.1"
8
+ description = "Towards GPT-4 like large language and visual assistant."
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "torch>=2", "torchvision>=0.15",
17
+ "transformers>=4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
18
+ "accelerate", "peft>=0.4.0", "bitsandbytes==0.41.0",
19
+ "pydantic", "markdown2[all]", "numpy", "scikit-learn>=1.2.2",
20
+ "gradio==3.35.2", "gradio_client==0.2.9",
21
+ "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22
+ "deepspeed==0.13.5", "einops", "einops-exts", "timm==0.9.12",
23
+ ]
24
+
25
+ [project.urls]
26
+ "Homepage" = "https://github.com/OpenGVLab/InternVL"
27
+ "Bug Tracker" = "https://github.com/OpenGVLab/InternVL/issues"
28
+
29
+ [tool.setuptools.packages.find]
30
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
31
+
32
+ [tool.wheel]
33
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
InternVL/internvl_g/README.md ADDED
@@ -0,0 +1,497 @@
1
+ # InternVL Stage-2 Pre-training & Retrieval Fine-tuning
2
+
3
+ This folder contains the implementation of InternVL 1.0 for stage-2 pre-training and retrieval fine-tuning, which corresponds to Section 4.3 of our [InternVL 1.0 paper](https://arxiv.org/pdf/2312.14238).
4
+
5
+ ![image](https://github.com/user-attachments/assets/239f38b2-8867-4539-9dd8-c1a1eaa40aef)
6
+
7
+ ## 🛠️ Installation
8
+
9
+ Follow the [installation guide](../INSTALLATION.md) to set up the environment.
10
+
11
+ ## 📦 Data Preparation
12
+
13
+ Three datasets need to be prepared: COCO Caption, Flickr30K, and NoCaps.
14
+
15
+ <details open>
16
+ <summary>COCO Caption</summary>
17
+
18
+ ```bash
19
+ mkdir -p data/coco && cd data/coco
20
+
21
+ # download coco images
22
+ wget http://images.cocodataset.org/zips/train2014.zip && unzip train2014.zip
23
+ wget http://images.cocodataset.org/zips/val2014.zip && unzip val2014.zip
24
+ wget http://images.cocodataset.org/zips/test2015.zip && unzip test2015.zip
25
+
26
+ mkdir -p annotations && cd annotations/
27
+ # download converted annotation files
28
+ wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json
29
+ wget https://github.com/OpenGVLab/InternVL/releases/download/data/coco_karpathy_test.json
30
+ wget https://github.com/OpenGVLab/InternVL/releases/download/data/coco_karpathy_test_gt.json
31
+ cd ../../../
32
+ ```
33
+
34
+ </details>
35
+
36
+ <details open>
37
+ <summary>Flickr30K</summary>
38
+
39
+ ```bash
40
+ mkdir -p data/flickr30k && cd data/flickr30k
41
+
42
+ # download images from https://bryanplummer.com/Flickr30kEntities/
43
+ # karpathy split annotations can be downloaded from the following link:
44
+ # https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_test_karpathy.txt
45
+ # this file is provided by the clip-benchmark repository.
46
+ # we converted these txt files to json format; download the converted files:
47
+ wget https://github.com/OpenGVLab/InternVL/releases/download/data/flickr30k_cn_test.txt
48
+ wget https://github.com/OpenGVLab/InternVL/releases/download/data/flickr30k_cn_train.txt
49
+ wget https://github.com/OpenGVLab/InternVL/releases/download/data/flickr30k_test_karpathy.json
50
+ wget https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_test_karpathy.txt
51
+ wget https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_train_karpathy.txt
52
+ wget https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/flickr30k_val_karpathy.txt
53
+
54
+ cd ../..
55
+ ```
56
+
57
+ </details>
58
+
59
+ <details open>
60
+ <summary>NoCaps</summary>
61
+
62
+ ```bash
63
+ mkdir -p data/nocaps && cd data/nocaps
64
+
65
+ # download images from https://nocaps.org/download
66
+ # original annotations can be downloaded from https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json
67
+ wget https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json
68
+
69
+ cd ../..
70
+ ```
71
+
72
+ </details>
73
+
74
+ After the download is complete, the directory structure is:
75
+
76
+ ```shell
77
+ data
78
+ ├── coco
79
+ │   ├── annotations
80
+ │   │   ├── coco_karpathy_train.json
81
+ │   ├── test2017
82
+ │   ├── train2014
83
+ │   ├── train2017
84
+ │   ├── val2014
85
+ │   └── val2017
86
+ ├── flickr30k
87
+ │   ├── flickr30k_cn_test.txt
88
+ │   ├── flickr30k_cn_train.txt
89
+ │   ├── flickr30k_test_karpathy.json
90
+ │   ├── flickr30k_test_karpathy.txt
91
+ │   ├── flickr30k_train_karpathy.txt
92
+ │   ├── flickr30k_val_karpathy.txt
93
+ │   └── Images
94
+ └── nocaps
95
+ ├── images
96
+ └── nocaps_val_4500_captions.json
97
+ ```
98
+
99
+ ## 📦 Model Preparation
100
+
101
+ | model name | type | download | size |
102
+ | ------------------ | ----------- | ----------------------------------------------------------------- | :-----: |
103
+ | InternVL-14B-224px | huggingface | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-14B-224px) | 27.7 GB |
104
+
105
+ Please download the above model weights and place them in the `pretrained/` folder.
106
+
107
+ ```sh
108
+ cd pretrained/
109
+ # pip install -U huggingface_hub
110
+ huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/InternVL-14B-224px --local-dir InternVL-14B-224px
111
+ ```
112
+
113
+ The directory structure is:
114
+
115
+ ```sh
116
+ pretrained
117
+ └── InternVL-14B-224px/
118
+ ```
119
+
120
+ ## 🔥 Generative Pre-training
121
+
122
+ There are currently no plans to release this part of the code.
123
+
124
+ ## 📊 Evaluation
125
+
126
+ ### Zero-Shot Image Captioning
127
+
128
+ | model | dataset | BLEU4 | METEOR | CIDEr |
129
+ | ---------- | ----------------------- | ----- | ------ | ----- |
130
+ | InternVL-G | COCO Karpathy test | 37.1 | 30.1 | 128.2 |
131
+ | InternVL-G | Flickr30K Karpathy test | 27.0 | 25.3 | 79.2 |
132
+ | InternVL-G | NoCaps val | 44.3 | 30.1 | 113.7 |
133
+
134
+ <details>
135
+ <summary>[InternVL-G] COCO Karpathy test</summary>
136
+
137
+ ```bash
138
+ sh evaluate.sh pretrained/InternVL-14B-224px caption-coco
139
+ ```
140
+
141
+ Expected results:
142
+
143
+ ```
144
+ ['coco', 'English caption:', 10.5974, dict_items([('Bleu_1', 0.7876323287981284), ('Bleu_2', 0.6353512494727918), ('Bleu_3', 0.49108984183589743), ('Bleu_4', 0.37062736733849205), ('METEOR', 0.30106315496945923), ('ROUGE_L', 0.5898249189475652), ('CIDEr', 1.281844384075423)])]
145
+ ```
146
+
147
+ </details>
148
+
149
+ <details>
150
+ <summary>[InternVL-G] Flickr30K Karpathy test</summary>
151
+
152
+ ```
153
+ sh evaluate.sh pretrained/InternVL-14B-224px caption-flickr30k
154
+ ```
155
+
156
+ Expected results:
157
+
158
+ ```bash
159
+ ['flickr30k', 'English caption:', 10.666, dict_items([('Bleu_1', 0.7182900534357628), ('Bleu_2', 0.5353390037921949), ('Bleu_3', 0.3834462132295285), ('Bleu_4', 0.2702131471765472), ('METEOR', 0.25263515267930103), ('ROUGE_L', 0.5305876871149064), ('CIDEr', 0.7919734768328237)])]
160
+ ```
161
+
162
+ </details>
163
+
164
+ <details>
165
+ <summary>[InternVL-G] NoCaps val</summary>
166
+
167
+ ```bash
168
+ sh evaluate.sh pretrained/InternVL-14B-224px caption-nocaps
169
+ ```
170
+
171
+ Expected results:
172
+
173
+ ```
174
+ ['nocaps', 'English caption:', 10.463111111111111, dict_items([('Bleu_1', 0.8518290482155187), ('Bleu_2', 0.7165227921485106), ('Bleu_3', 0.5733723839888316), ('Bleu_4', 0.44268902150723105), ('METEOR', 0.30078174807736896), ('ROUGE_L', 0.6070208063052156), ('CIDEr', 1.1371742045267772)])]
175
+ ```
176
+
177
+ </details>
178
+
179
+ ### Fine-tuned Image-Text Retrieval
180
+
181
+ #### Flickr30K fine-tuned model: [InternVL-14B-Flickr30K-FT-364px](https://huggingface.co/OpenGVLab/InternVL-14B-Flickr30K-FT-364px)
182
+
183
+ <table>
184
+ <tr align=center>
185
+ <td rowspan="3" align=center><b>model</b></td>
186
+ <td colspan="6" align=center><b>Flickr30K</b></td>
187
+ <td rowspan="3" align=center><b>avg</b></td>
188
+
189
+ </tr>
190
+ <tr align=center>
191
+ <td colspan="3" align=center><b>image-to-text</b></td>
192
+ <td colspan="3" align=center><b>text-to-image</b></td>
193
+ </tr>
194
+ <tr>
195
+ <td>R@1</td>
196
+ <td>R@5</td>
197
+ <td>R@10</td>
198
+ <td>R@1</td>
199
+ <td>R@5</td>
200
+ <td>R@10</td>
201
+ </tr>
202
+
203
+ <tr align=center>
204
+ <td>InternVL-C-FT</td>
205
+ <td>97.2</td>
206
+ <td>100.0</td>
207
+ <td>100.0</td>
208
+ <td>88.5</td>
209
+ <td>98.4</td>
210
+ <td>99.2</td>
211
+ <td>97.2</td>
212
+ </tr>
213
+ <tr align=center>
214
+ <td>InternVL-G-FT</td>
215
+ <td>97.9</td>
216
+ <td>100.0</td>
217
+ <td>100.0</td>
218
+ <td>89.6</td>
219
+ <td>98.6</td>
220
+ <td>99.2</td>
221
+ <td>97.6</td>
222
+ </tr>
223
+
224
+ </table>
225
+
226
+ <details>
227
+ <summary>[InternVL-C-FT] Flickr30K</summary>
228
+
229
+ ```bash
230
+ cd ../clip_benchmark/
231
+ CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "en" --task "zeroshot_retrieval" \
232
+ --dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_c_retrieval_hf \
233
+ --pretrained ./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10/ --output result_ft.json
234
+ ```
235
+
236
+ Expected results:
237
+
238
+ ```
239
+ {"dataset": "flickr30k", "model": "internvl_c_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10", "task": "zeroshot_retrieval",
240
+ "metrics": {"image_retrieval_recall@1": 0.8853999972343445, "text_retrieval_recall@1": 0.972000002861023,
241
+ "image_retrieval_recall@5": 0.9836000204086304, "text_retrieval_recall@5": 1.0,
242
+ "image_retrieval_recall@10": 0.9923999905586243, "text_retrieval_recall@10": 1.0}, "language": "en"}
243
+ ```
244
+
245
+ </details>
246
+
247
+ <details>
248
+ <summary>[InternVL-G-FT] Flickr30K</summary>
249
+
250
+ ```bash
251
+ cd ../clip_benchmark/
252
+ CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "en" --task "zeroshot_retrieval" \
253
+ --dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_g_retrieval_hf \
254
+ --pretrained ./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10/ --output result_ft.json
255
+ ```
256
+
257
+ Expected results:
258
+
259
+ ```
260
+ {"dataset": "flickr30k", "model": "internvl_g_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickr_364_bs1024_ep10", "task": "zeroshot_retrieval",
261
+ "metrics": {"image_retrieval_recall@1": 0.895799994468689, "text_retrieval_recall@1": 0.9789999723434448,
262
+ "image_retrieval_recall@5": 0.9861999750137329, "text_retrieval_recall@5": 1.0,
263
+ "image_retrieval_recall@10": 0.9922000169754028, "text_retrieval_recall@10": 1.0}, "language": "en"}
264
+ ```
265
+
266
+ </details>
267
+
268
+ #### Flickr30K-CN fine-tuned model: [InternVL-14B-FlickrCN-FT-364px](https://huggingface.co/OpenGVLab/InternVL-14B-FlickrCN-FT-364px)
269
+
270
+ <table>
271
+ <tr align=center>
272
+ <td rowspan="3" align=center><b>model</b></td>
273
+ <td colspan="6" align=center><b>Flickr30K-CN</b></td>
274
+ <td rowspan="3" align=center><b>avg</b></td>
275
+
276
+ </tr>
277
+ <tr align=center>
278
+ <td colspan="3" align=center><b>image-to-text</b></td>
279
+ <td colspan="3" align=center><b>text-to-image</b></td>
280
+ </tr>
281
+ <tr>
282
+ <td>R@1</td>
283
+ <td>R@5</td>
284
+ <td>R@10</td>
285
+ <td>R@1</td>
286
+ <td>R@5</td>
287
+ <td>R@10</td>
288
+ </tr>
289
+
290
+ <tr align=center>
291
+ <td>InternVL-C-FT</td>
292
+ <td>96.5</td>
293
+ <td>99.9</td>
294
+ <td>100.0</td>
295
+ <td>85.2</td>
296
+ <td>97.0</td>
297
+ <td>98.5</td>
298
+ <td>96.2</td>
299
+ </tr>
300
+ <tr align=center>
301
+ <td>InternVL-G-FT</td>
302
+ <td>96.9</td>
303
+ <td>99.9</td>
304
+ <td>100.0</td>
305
+ <td>85.9</td>
306
+ <td>97.1</td>
307
+ <td>98.7</td>
308
+ <td>96.4</td>
309
+ </tr>
310
+
311
+ </table>
312
+
313
+ <details>
314
+ <summary>[InternVL-C-FT] Flickr30K-CN</summary>
315
+
316
+ ```bash
317
+ cd ../clip_benchmark/
318
+ CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "cn" --task "zeroshot_retrieval" \
319
+ --dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_c_retrieval_hf \
320
+ --pretrained ./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10/ --output result_ft.json
321
+ ```
322
+
323
+ Expected results:
324
+
325
+ ```
326
+ {"dataset": "flickr30k", "model": "internvl_c_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10", "task": "zeroshot_retrieval",
327
+ "metrics": {"image_retrieval_recall@1": 0.8521999716758728, "text_retrieval_recall@1": 0.9649999737739563,
328
+ "image_retrieval_recall@5": 0.9697999954223633, "text_retrieval_recall@5": 0.9990000128746033,
329
+ "image_retrieval_recall@10": 0.9854000210762024, "text_retrieval_recall@10": 1.0}, "language": "cn"}
330
+ ```
331
+
332
+ </details>
333
+
334
+ <details>
335
+ <summary>[InternVL-G-FT] Flickr30K-CN</summary>
336
+
337
+ ```bash
338
+ cd ../clip_benchmark/
339
+ CUDA_VISIBLE_DEVICES=0 python3 clip_benchmark/cli.py eval --model_type internvl --language "cn" --task "zeroshot_retrieval" \
340
+ --dataset "flickr30k" --dataset_root ./data/flickr30k --model internvl_g_retrieval_hf \
341
+ --pretrained ./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10/ --output result_ft.json
342
+ ```
343
+
344
+ Expected results:
345
+
346
+ ```
347
+ {"dataset": "flickr30k", "model": "internvl_g_retrieval_hf", "pretrained": "./work_dirs/internvl_stage2_finetune_flickrcn_364_bs1024_ep10", "task": "zeroshot_retrieval",
348
+ "metrics": {"image_retrieval_recall@1": 0.8587999939918518, "text_retrieval_recall@1": 0.968999981880188,
349
+ "image_retrieval_recall@5": 0.9714000225067139, "text_retrieval_recall@5": 0.9990000128746033,
350
+ "image_retrieval_recall@10": 0.9865999817848206, "text_retrieval_recall@10": 1.0}, "language": "cn"}
351
+ ```
352
+
353
+ </details>
354
+
355
+ ## 🔥 Retrieval Fine-tuning (Fully)
356
+
357
+ > Note: In our experiments, full-parameter fine-tuning achieves the best results on the Flickr30K and COCO image-text retrieval tasks. By following the experimental hyperparameters in this section, you can reproduce the model performance reported in the [Evaluation section](#evaluation).
358
+
359
+ To fine-tune InternVL on Flickr30K with 32 GPUs on a SLURM system, run:
360
+
361
+ ```bash
362
+ PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh
363
+ ```
364
+
365
+ To fine-tune InternVL on Flickr30K-CN with 32 GPUs on a SLURM system, run:
366
+
367
+ ```shell
368
+ PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_flickrcn_364_bs1024_ep10.sh
369
+ ```
370
+
371
+ To fine-tune InternVL on COCO with 32 GPUs on a SLURM system, run:
372
+
373
+ ```shell
374
+ PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_coco_364_bs1024_ep5.sh
375
+ ```
376
+
377
+ The hyperparameters used here are:
378
+
379
+ | config | Flickr30K | Flickr30K-CN | COCO |
380
+ | --------------------------- | ----------------------------------- | ----------------------------------- | ----------------------------------- |
381
+ | learning rate | 1e-6 | 1e-6 | 1e-6 |
382
+ | layer-wise lr<br>decay rate | InternViT-6B (0.9),<br>QLLaMA (0.9) | InternViT-6B (0.9),<br>QLLaMA (0.9) | InternViT-6B (0.9),<br>QLLaMA (0.9) |
383
+ | optimizer | AdamW | AdamW | AdamW |
384
+ | weight decay | 0.05 | 0.05 | 0.05 |
385
+ | input resolution | 364x364 | 364x364 | 364x364 |
386
+ | total batch size | 1024 | 1024 | 1024 |
387
+ | warm-up iterations | 100 | 100 | 100 |
388
+ | training epochs | 10 | 10 | 5 |
389
+ | drop path rate | 0.3 | 0.3 | 0.3 |
390
+ | numerical precision | zero1 + bf16 | zero1 + bf16 | zero1 + bf16 |
391
+ | trainable / total params | 14B / 14B | 14B / 14B | 14B / 14B |
392
+ | GPUs for training | 32×A100 (80G) | 32×A100 (80G) | 32×A100 (80G) |
393
+ | Required GPU memory | 80G | 80G | 80G |
394
+
395
+ ## 🔥 Retrieval Fine-tuning (Head)
396
+
397
+ > Note: This section demonstrates how to perform a cost-effective fine-tuning of our model. The hyperparameters shown here are not optimized for any specific task. For practical applications, further adjustments to the hyperparameters may be necessary to achieve optimal performance.
398
+
399
+ To fine-tune the head of InternVL on Flickr30K with 4 GPUs, run:
400
+
401
+ ```bash
402
+ GPUS=4 BATCH_SIZE=32 sh shell/head_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_head_4gpu.sh
403
+ ```
404
+
405
+ To fine-tune the head of InternVL on Flickr30K-CN with 4 GPUs, run:
406
+
407
+ ```shell
408
+ GPUS=4 BATCH_SIZE=32 sh shell/head_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_head_4gpu.sh
409
+ ```
410
+
411
+ To fine-tune the head of InternVL on COCO with 4 GPUs, run:
412
+
413
+ ```shell
414
+ GPUS=4 BATCH_SIZE=32 sh shell/head_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_head_4gpu.sh
415
+ ```
416
+
417
+ The hyperparameters used here are:
418
+
419
+ | config | Flickr30K | Flickr30K-CN | COCO |
420
+ | ------------------------ | ------------- | ------------- | ------------- |
421
+ | learning rate | 1e-6 | 1e-6 | 1e-6 |
422
+ | optimizer | AdamW | AdamW | AdamW |
423
+ | weight decay | 0.05 | 0.05 | 0.05 |
424
+ | input resolution | 224x224 | 224x224 | 224x224 |
425
+ | total batch size | 4x32 | 4x32 | 4x32 |
426
+ | warm-up iterations | 100 | 100 | 100 |
427
+ | training epochs | 10 | 10 | 5 |
428
+ | drop path rate | 0.0 | 0.0 | 0.3 |
429
+ | numerical precision | zero3 + bf16 | zero3 + bf16 | zero1 + bf16 |
430
+ | trainable / total params | 0.2B / 14B | 0.2B / 14B | 0.2B / 14B |
431
+ | GPUs for training | 4×GPU (>=32G) | 4×GPU (>=32G) | 4×GPU (>=32G) |
432
+ | Required GPU memory | 24G | 24G | 24G |
433
+
434
+ ## 🔥 Retrieval Fine-tuning (LoRA)
435
+
436
+ > Note: This section demonstrates how to perform a cost-effective fine-tuning of our model. The hyperparameters shown here are not optimized for any specific task. For practical applications, further adjustments to the hyperparameters may be necessary to achieve optimal performance.
437
+
438
+ To fine-tune InternVL using LoRA on Flickr30K with 4 GPUs, run:
439
+
440
+ ```bash
441
+ GPUS=4 BATCH_SIZE=32 sh shell/lora_finetune/internvl_stage2_finetune_flickr_224_bs1024_ep10_lora16_4gpu.sh
442
+ ```
443
+
444
+ To fine-tune InternVL using LoRA on Flickr30K-CN with 4 GPUs, run:
445
+
446
+ ```shell
447
+ GPUS=4 BATCH_SIZE=32 sh shell/lora_finetune/internvl_stage2_finetune_flickrcn_224_bs1024_ep10_lora16_4gpu.sh
448
+ ```
449
+
450
+ To fine-tune InternVL using LoRA on COCO with 4 GPUs, run:
451
+
452
+ ```shell
453
+ GPUS=4 BATCH_SIZE=32 sh shell/lora_finetune/internvl_stage2_finetune_coco_224_bs1024_ep5_lora16_4gpu.sh
454
+ ```
455
+
456
+ The hyperparameters used here are:
457
+
458
+ | config | Flickr30K | Flickr30K-CN | COCO |
459
+ | ------------------------ | ------------- | ------------- | ------------- |
460
+ | learning rate | 1e-6 | 1e-6 | 1e-6 |
461
+ | optimizer | AdamW | AdamW | AdamW |
462
+ | lora rank | 16 | 16 | 16 |
463
+ | weight decay | 0.05 | 0.05 | 0.05 |
464
+ | input resolution | 224x224 | 224x224 | 224x224 |
465
+ | total batch size | 4x32 | 4x32 | 4x32 |
466
+ | warm-up iterations | 100 | 100 | 100 |
467
+ | training epochs | 10 | 10 | 5 |
468
+ | drop path rate | 0.0 | 0.0 | 0.3 |
469
+ | numerical precision | zero3 + bf16 | zero3 + bf16 | zero1 + bf16 |
470
+ | trainable / total params | 0.3B / 14B | 0.3B / 14B | 0.3B / 14B |
471
+ | GPUs for training | 4×GPU (>=40G) | 4×GPU (>=40G) | 4×GPU (>=40G) |
472
+ | Required GPU memory | 37G | 37G | 37G |
473
+
474
+ ## Fine-Tuning a Custom Dataset
475
+
476
+ 1. **Organize Your Data**: Format your dataset in the same way as COCO or Flickr30K.
477
+
478
+ 2. **Update Meta Information**: Add your dataset's meta information to the `ds_collections` dictionary in `internvl_g/internvl/train/internvl_stage2_finetune.py`. For example:
479
+
480
+ ```python
481
+ ds_collections = {
482
+ 'my_dataset_flickr_format': {
483
+ 'root': './data/my_dataset/images/',
484
+ 'annotation': './data/my_dataset/annotations.txt',
485
+ },
486
+ 'my_dataset_coco_format': {
487
+ 'root': './data/my_dataset/',
488
+ 'annotation': './data/my_dataset/annotations.json',
489
+ },
490
+ }
491
+ ```
492
+
493
+ 3. **Name Your Dataset**:
494
+
495
+ - Include `flickr_format` or `coco_format` in your dataset's `dataset_name`. This will allow the script to reuse the Flickr30K or COCO dataloader accordingly.
496
+
497
+ By following these steps, you can easily fine-tune the InternVL model on your custom dataset using the existing COCO or Flickr30K data loading mechanisms.
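+ 
+ As a concrete launch example, the sketch below assumes you copy one of the provided fine-tuning scripts and point it at the dataset key you registered in `ds_collections`; the copied script name and the GPU/partition settings are placeholders.
+ 
+ ```bash
+ # Hypothetical sketch: reuse an existing fine-tuning script for a custom dataset.
+ cp shell/finetune/internvl_stage2_finetune_flickr_364_bs1024_ep10.sh \
+    shell/finetune/internvl_stage2_finetune_my_dataset_flickr_format.sh
+ # Edit the copied script so that it trains on 'my_dataset_flickr_format', then run:
+ PARTITION='your partition' GPUS=32 sh shell/finetune/internvl_stage2_finetune_my_dataset_flickr_format.sh
+ ```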
InternVL/segmentation/dist_test.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/env bash
2
+
3
+ CONFIG=$1
4
+ CHECKPOINT=$2
5
+ GPUS=$3
6
+ PORT=${PORT:-29510}
7
+ PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8
+ torchrun --nproc_per_node=$GPUS --master_port=$PORT \
9
+ $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
InternVL/segmentation/dist_train.sh ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/env bash
2
+
3
+ CONFIG=$1
4
+ GPUS=$2
5
+ PORT=${PORT:-29300}
6
+
7
+ PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8
+ torchrun --nproc_per_node=$GPUS --master_port=$PORT \
9
+ $(dirname "$0")/train.py $CONFIG --launcher pytorch --deterministic ${@:3}
InternVL/segmentation/train.py ADDED
@@ -0,0 +1,220 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import argparse
7
+ import copy
8
+ import os
9
+ import os.path as osp
10
+ import time
11
+ import warnings
12
+
13
+ import mmcv
14
+ import mmcv_custom # noqa: F401,F403
15
+ import mmseg_custom # noqa: F401,F403
16
+ import torch
17
+ from mmcv.cnn.utils import revert_sync_batchnorm
18
+ from mmcv.runner import get_dist_info, init_dist
19
+ from mmcv.utils import Config, DictAction, get_git_hash
20
+ from mmseg import __version__
21
+ from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
22
+ from mmseg.datasets import build_dataset
23
+ from mmseg.models import build_segmentor
24
+ from mmseg.utils import collect_env, get_root_logger
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser(description='Train a segmentor')
29
+ parser.add_argument('config', help='train config file path')
30
+ parser.add_argument('--work-dir', help='the dir to save logs and models')
31
+ parser.add_argument(
32
+ '--load-from', help='the checkpoint file to load weights from')
33
+ parser.add_argument(
34
+ '--resume-from', help='the checkpoint file to resume from')
35
+ parser.add_argument(
36
+ '--no-validate',
37
+ action='store_true',
38
+ help='whether not to evaluate the checkpoint during training')
39
+ group_gpus = parser.add_mutually_exclusive_group()
40
+ group_gpus.add_argument(
41
+ '--gpus',
42
+ type=int,
43
+ help='number of gpus to use '
44
+ '(only applicable to non-distributed training)')
45
+ group_gpus.add_argument(
46
+ '--gpu-ids',
47
+ type=int,
48
+ nargs='+',
49
+ help='ids of gpus to use '
50
+ '(only applicable to non-distributed training)')
51
+ parser.add_argument('--seed', type=int, default=None, help='random seed')
52
+ parser.add_argument(
53
+ '--deterministic',
54
+ action='store_true',
55
+ help='whether to set deterministic options for CUDNN backend.')
56
+ parser.add_argument(
57
+ '--options',
58
+ nargs='+',
59
+ action=DictAction,
60
+ help="--options is deprecated in favor of --cfg_options' and it will "
61
+ 'not be supported in version v0.22.0. Override some settings in the '
62
+ 'used config, the key-value pair in xxx=yyy format will be merged '
63
+ 'into config file. If the value to be overwritten is a list, it '
64
+ 'should be like key="[a,b]" or key=a,b It also allows nested '
65
+ 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
66
+ 'marks are necessary and that no white space is allowed.')
67
+ parser.add_argument(
68
+ '--cfg-options',
69
+ nargs='+',
70
+ action=DictAction,
71
+ help='override some settings in the used config, the key-value pair '
72
+ 'in xxx=yyy format will be merged into config file. If the value to '
73
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
74
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
75
+ 'Note that the quotation marks are necessary and that no white space '
76
+ 'is allowed.')
77
+ parser.add_argument(
78
+ '--launcher',
79
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
80
+ default='none',
81
+ help='job launcher')
82
+ parser.add_argument('--local_rank', type=int, default=0)
83
+ parser.add_argument(
84
+ '--auto-resume',
85
+ action='store_true',
86
+ help='resume from the latest checkpoint automatically.')
87
+ args = parser.parse_args()
88
+ if 'LOCAL_RANK' not in os.environ:
89
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
90
+
91
+ if args.options and args.cfg_options:
92
+ raise ValueError(
93
+ '--options and --cfg-options cannot be both '
94
+ 'specified, --options is deprecated in favor of --cfg-options. '
95
+ '--options will not be supported in version v0.22.0.')
96
+ if args.options:
97
+ warnings.warn('--options is deprecated in favor of --cfg-options. '
98
+ '--options will not be supported in version v0.22.0.')
99
+ args.cfg_options = args.options
100
+
101
+ return args
102
+
103
+
104
+ def main():
105
+ args = parse_args()
106
+
107
+ cfg = Config.fromfile(args.config)
108
+ if args.cfg_options is not None:
109
+ cfg.merge_from_dict(args.cfg_options)
110
+ # set cudnn_benchmark
111
+ if cfg.get('cudnn_benchmark', False):
112
+ torch.backends.cudnn.benchmark = True
113
+
114
+ # work_dir is determined in this priority: CLI > segment in file > filename
115
+ if args.work_dir is not None:
116
+ # update configs according to CLI args if args.work_dir is not None
117
+ cfg.work_dir = args.work_dir
118
+ elif cfg.get('work_dir', None) is None:
119
+ # use config filename as default work_dir if cfg.work_dir is None
120
+ cfg.work_dir = osp.join('./work_dirs',
121
+ osp.splitext(osp.basename(args.config))[0])
122
+ if args.load_from is not None:
123
+ cfg.load_from = args.load_from
124
+ if args.resume_from is not None:
125
+ cfg.resume_from = args.resume_from
126
+ if args.gpu_ids is not None:
127
+ cfg.gpu_ids = args.gpu_ids
128
+ else:
129
+ cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
130
+ cfg.auto_resume = args.auto_resume
131
+
132
+ # init distributed env first, since logger depends on the dist info.
133
+ if args.launcher == 'none':
134
+ distributed = False
135
+ else:
136
+ distributed = True
137
+ init_dist(args.launcher, **cfg.dist_params)
138
+ # gpu_ids is used to calculate iter when resuming checkpoint
139
+ _, world_size = get_dist_info()
140
+ cfg.gpu_ids = range(world_size)
141
+
142
+ cfg.device = 'cuda' # fix 'ConfigDict' object has no attribute 'device'
143
+ # create work_dir
144
+ mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
145
+ # dump config
146
+ cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
147
+ # init the logger before other steps
148
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
149
+ log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
150
+ logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
151
+
152
+ # init the meta dict to record some important information such as
153
+ # environment info and seed, which will be logged
154
+ meta = dict()
155
+ # log env info
156
+ env_info_dict = collect_env()
157
+ env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
158
+ dash_line = '-' * 60 + '\n'
159
+ logger.info('Environment info:\n' + dash_line + env_info + '\n' +
160
+ dash_line)
161
+ meta['env_info'] = env_info
162
+
163
+ # log some basic info
164
+ logger.info(f'Distributed training: {distributed}')
165
+ logger.info(f'Config:\n{cfg.pretty_text}')
166
+
167
+ # set random seeds
168
+ seed = init_random_seed(args.seed)
169
+ logger.info(f'Set random seed to {seed}, '
170
+ f'deterministic: {args.deterministic}')
171
+ set_random_seed(seed, deterministic=args.deterministic)
172
+ cfg.seed = seed
173
+ meta['seed'] = seed
174
+ meta['exp_name'] = osp.basename(args.config)
175
+
176
+ model = build_segmentor(
177
+ cfg.model,
178
+ train_cfg=cfg.get('train_cfg'),
179
+ test_cfg=cfg.get('test_cfg'))
180
+ model.init_weights()
181
+
182
+ # SyncBN is not supported for DP
183
+ if not distributed:
184
+ warnings.warn(
185
+ 'SyncBN is only supported with DDP. To be compatible with DP, '
186
+ 'we convert SyncBN to BN. Please use dist_train.sh which can '
187
+ 'avoid this error.')
188
+ model = revert_sync_batchnorm(model)
189
+
190
+ logger.info(model)
191
+
192
+ datasets = [build_dataset(cfg.data.train)]
193
+ if len(cfg.workflow) == 2:
194
+ val_dataset = copy.deepcopy(cfg.data.val)
195
+ val_dataset.pipeline = cfg.data.train.pipeline
196
+ datasets.append(build_dataset(val_dataset))
197
+ if cfg.checkpoint_config is not None:
198
+ # save mmseg version, config file content and class names in
199
+ # checkpoints as meta data
200
+ cfg.checkpoint_config.meta = dict(
201
+ mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
202
+ config=cfg.pretty_text,
203
+ CLASSES=datasets[0].CLASSES,
204
+ PALETTE=datasets[0].PALETTE)
205
+ # add an attribute for visualization convenience
206
+ model.CLASSES = datasets[0].CLASSES
207
+ # passing checkpoint meta for saving best checkpoint
208
+ meta.update(cfg.checkpoint_config.meta)
209
+ train_segmentor(
210
+ model,
211
+ datasets,
212
+ cfg,
213
+ distributed=distributed,
214
+ validate=(not args.no_validate),
215
+ timestamp=timestamp,
216
+ meta=meta)
217
+
218
+
219
+ if __name__ == '__main__':
220
+ main()
InternVL/streamlit_demo/constants.py ADDED
@@ -0,0 +1,23 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
8
+ WORKER_HEART_BEAT_INTERVAL = 15
9
+
10
+ LOGDIR = 'logs/'
11
+
12
+ # Model Constants
13
+ IGNORE_INDEX = -100
14
+ IMAGE_TOKEN_INDEX = -200
15
+ DEFAULT_IMAGE_TOKEN = '<image>'
16
+ DEFAULT_IMAGE_PATCH_TOKEN = '<IMG_CONTEXT>'
17
+ DEFAULT_IM_START_TOKEN = '<img>'
18
+ DEFAULT_IM_END_TOKEN = '</img>'
19
+ IMAGE_PLACEHOLDER = '<image-placeholder>'
20
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
21
+ IMAGENET_STD = (0.229, 0.224, 0.225)
22
+
23
+ server_error_msg = '**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**'
InternVL/streamlit_demo/controller.py ADDED
@@ -0,0 +1,291 @@
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import dataclasses
7
+ import json
8
+ import re
9
+ import threading
10
+ import time
11
+ from enum import Enum, auto
12
+ from typing import List
13
+
14
+ import numpy as np
15
+ import requests
16
+ import uvicorn
17
+ from fastapi import FastAPI, Request
18
+ from fastapi.responses import StreamingResponse
19
+ from utils import build_logger, server_error_msg
20
+
21
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
22
+ logger = build_logger('controller', 'controller.log')
23
+
24
+
25
+ class DispatchMethod(Enum):
26
+ LOTTERY = auto()
27
+ SHORTEST_QUEUE = auto()
28
+
29
+ @classmethod
30
+ def from_str(cls, name):
31
+ if name == 'lottery':
32
+ return cls.LOTTERY
33
+ elif name == 'shortest_queue':
34
+ return cls.SHORTEST_QUEUE
35
+ else:
36
+ raise ValueError(f'Invalid dispatch method: {name}')
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class WorkerInfo:
41
+ model_names: List[str]
42
+ speed: int
43
+ queue_length: int
44
+ check_heart_beat: bool
45
+ last_heart_beat: float
46
+
47
+
48
+ def heart_beat_controller(controller):
49
+ while True:
50
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
51
+ controller.remove_stable_workers_by_expiration()
52
+
53
+
54
+ class Controller:
55
+ def __init__(self, dispatch_method: str):
56
+ # Dict[str -> WorkerInfo]
57
+ self.worker_info = {}
58
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
59
+
60
+ self.heart_beat_thread = threading.Thread(
61
+ target=heart_beat_controller, args=(self,))
62
+ self.heart_beat_thread.start()
63
+
64
+ logger.info('Init controller')
65
+
66
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
67
+ worker_status: dict):
68
+ if worker_name not in self.worker_info:
69
+ logger.info(f'Register a new worker: {worker_name}')
70
+ else:
71
+ logger.info(f'Register an existing worker: {worker_name}')
72
+
73
+ if not worker_status:
74
+ worker_status = self.get_worker_status(worker_name)
75
+ if not worker_status:
76
+ return False
77
+
78
+ self.worker_info[worker_name] = WorkerInfo(
79
+ worker_status['model_names'], worker_status['speed'], worker_status['queue_length'],
80
+ check_heart_beat, time.time())
81
+
82
+ logger.info(f'Register done: {worker_name}, {worker_status}')
83
+ return True
84
+
85
+ def get_worker_status(self, worker_name: str):
86
+ try:
87
+ r = requests.post(worker_name + '/worker_get_status', timeout=5)
88
+ except requests.exceptions.RequestException as e:
89
+ logger.error(f'Get status fails: {worker_name}, {e}')
90
+ return None
91
+
92
+ if r.status_code != 200:
93
+ logger.error(f'Get status fails: {worker_name}, {r}')
94
+ return None
95
+
96
+ return r.json()
97
+
98
+ def remove_worker(self, worker_name: str):
99
+ del self.worker_info[worker_name]
100
+
101
+ def refresh_all_workers(self):
102
+ old_info = dict(self.worker_info)
103
+ self.worker_info = {}
104
+
105
+ for w_name, w_info in old_info.items():
106
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
107
+ logger.info(f'Remove stale worker: {w_name}')
108
+
109
+ def list_models(self):
110
+ model_names = set()
111
+
112
+ for w_name, w_info in self.worker_info.items():
113
+ model_names.update(w_info.model_names)
114
+
115
+ def extract_key(s):
116
+ if 'Pro' in s:
117
+ return 999
118
+ match = re.match(r'InternVL2-(\d+)B', s)
119
+ if match:
120
+ return int(match.group(1))
121
+ return -1
122
+
123
+ def custom_sort_key(s):
124
+ key = extract_key(s)
125
+ # Return a tuple where -1 will ensure that non-matching items come last
126
+ return (0 if key != -1 else 1, -key if key != -1 else s)
127
+
128
+ sorted_list = sorted(list(model_names), key=custom_sort_key)
129
+ return sorted_list
130
+
131
+ def get_worker_address(self, model_name: str):
132
+ if self.dispatch_method == DispatchMethod.LOTTERY:
133
+ worker_names = []
134
+ worker_speeds = []
135
+ for w_name, w_info in self.worker_info.items():
136
+ if model_name in w_info.model_names:
137
+ worker_names.append(w_name)
138
+ worker_speeds.append(w_info.speed)
139
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
140
+ norm = np.sum(worker_speeds)
141
+ if norm < 1e-4:
142
+ return ''
143
+ worker_speeds = worker_speeds / norm
144
+ if True: # Directly return address
145
+ pt = np.random.choice(np.arange(len(worker_names)),
146
+ p=worker_speeds)
147
+ worker_name = worker_names[pt]
148
+ return worker_name
149
+
150
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
151
+ worker_names = []
152
+ worker_qlen = []
153
+ for w_name, w_info in self.worker_info.items():
154
+ if model_name in w_info.model_names:
155
+ worker_names.append(w_name)
156
+ worker_qlen.append(w_info.queue_length / w_info.speed)
157
+ if len(worker_names) == 0:
158
+ return ''
159
+ min_index = np.argmin(worker_qlen)
160
+ w_name = worker_names[min_index]
161
+ self.worker_info[w_name].queue_length += 1
162
+ logger.info(f'names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}')
163
+ return w_name
164
+ else:
165
+ raise ValueError(f'Invalid dispatch method: {self.dispatch_method}')
166
+
167
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
168
+ if worker_name not in self.worker_info:
169
+ logger.info(f'Receive unknown heart beat. {worker_name}')
170
+ return False
171
+
172
+ self.worker_info[worker_name].queue_length = queue_length
173
+ self.worker_info[worker_name].last_heart_beat = time.time()
174
+ logger.info(f'Receive heart beat. {worker_name}')
175
+ return True
176
+
177
+ def remove_stable_workers_by_expiration(self):
178
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
179
+ to_delete = []
180
+ for worker_name, w_info in self.worker_info.items():
181
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
182
+ to_delete.append(worker_name)
183
+
184
+ for worker_name in to_delete:
185
+ self.remove_worker(worker_name)
186
+
187
+ def worker_api_generate_stream(self, params):
188
+ worker_addr = self.get_worker_address(params['model'])
189
+ if not worker_addr:
190
+ logger.info(f"no worker: {params['model']}")
191
+ ret = {
192
+ 'text': server_error_msg,
193
+ 'error_code': 2,
194
+ }
195
+ yield json.dumps(ret).encode() + b'\0'
196
+
197
+ try:
198
+ response = requests.post(worker_addr + '/worker_generate_stream',
199
+ json=params, stream=True, timeout=5)
200
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\0'):
201
+ if chunk:
202
+ yield chunk + b'\0'
203
+ except requests.exceptions.RequestException as e:
204
+ logger.info(f'worker timeout: {worker_addr}')
205
+ ret = {
206
+ 'text': server_error_msg,
207
+ 'error_code': 3,
208
+ }
209
+ yield json.dumps(ret).encode() + b'\0'
210
+
211
+ # Let the controller act as a worker to achieve hierarchical
212
+ # management. This can be used to connect isolated sub networks.
213
+ def worker_api_get_status(self):
214
+ model_names = set()
215
+ speed = 0
216
+ queue_length = 0
217
+
218
+ for w_name in self.worker_info:
219
+ worker_status = self.get_worker_status(w_name)
220
+ if worker_status is not None:
221
+ model_names.update(worker_status['model_names'])
222
+ speed += worker_status['speed']
223
+ queue_length += worker_status['queue_length']
224
+
225
+ return {
226
+ 'model_names': list(model_names),
227
+ 'speed': speed,
228
+ 'queue_length': queue_length,
229
+ }
230
+
231
+
232
+ app = FastAPI()
233
+
234
+
235
+ @app.post('/register_worker')
236
+ async def register_worker(request: Request):
237
+ data = await request.json()
238
+ controller.register_worker(
239
+ data['worker_name'], data['check_heart_beat'],
240
+ data.get('worker_status', None))
241
+
242
+
243
+ @app.post('/refresh_all_workers')
244
+ async def refresh_all_workers():
245
+ models = controller.refresh_all_workers()
246
+
247
+
248
+ @app.post('/list_models')
249
+ async def list_models():
250
+ models = controller.list_models()
251
+ return {'models': models}
252
+
253
+
254
+ @app.post('/get_worker_address')
255
+ async def get_worker_address(request: Request):
256
+ data = await request.json()
257
+ addr = controller.get_worker_address(data['model'])
258
+ return {'address': addr}
259
+
260
+
261
+ @app.post('/receive_heart_beat')
262
+ async def receive_heart_beat(request: Request):
263
+ data = await request.json()
264
+ exist = controller.receive_heart_beat(
265
+ data['worker_name'], data['queue_length'])
266
+ return {'exist': exist}
267
+
268
+
269
+ @app.post('/worker_generate_stream')
270
+ async def worker_api_generate_stream(request: Request):
271
+ params = await request.json()
272
+ generator = controller.worker_api_generate_stream(params)
273
+ return StreamingResponse(generator)
274
+
275
+
276
+ @app.post('/worker_get_status')
277
+ async def worker_api_get_status(request: Request):
278
+ return controller.worker_api_get_status()
279
+
280
+
281
+ if __name__ == '__main__':
282
+ parser = argparse.ArgumentParser()
283
+ parser.add_argument('--host', type=str, default='0.0.0.0')
284
+ parser.add_argument('--port', type=int, default=10075)
285
+ parser.add_argument('--dispatch-method', type=str, choices=[
286
+ 'lottery', 'shortest_queue'], default='shortest_queue')
287
+ args = parser.parse_args()
288
+ logger.info(f'args: {args}')
289
+
290
+ controller = Controller(args.dispatch_method)
291
+ uvicorn.run(app, host=args.host, port=args.port, log_level='info')
InternVL/streamlit_demo/model_worker.py ADDED
@@ -0,0 +1,442 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ """
8
+ A model worker executes the model.
9
+ """
10
+ import argparse
11
+ import asyncio
12
+ import base64
13
+ import json
14
+ import math
15
+ import threading
16
+ import time
17
+ import uuid
18
+ from functools import partial
19
+ from io import BytesIO
20
+ from threading import Thread
21
+
22
+ import requests
23
+ import torch
24
+ import torchvision.transforms as T
25
+ import uvicorn
26
+ from constants import IMAGENET_MEAN, IMAGENET_STD, WORKER_HEART_BEAT_INTERVAL
27
+ from fastapi import BackgroundTasks, FastAPI, Request
28
+ from fastapi.responses import StreamingResponse
29
+ from PIL import Image
30
+ from torchvision.transforms.functional import InterpolationMode
31
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
32
+ from utils import build_logger, pretty_print_semaphore, server_error_msg
33
+
34
+ worker_id = str(uuid.uuid4())[:6]
35
+ logger = build_logger('model_worker', f'model_worker_{worker_id}.log')
36
+ global_counter = 0
37
+ model_semaphore = None
38
+
39
+
40
+ def load_image_from_base64(image):
41
+ return Image.open(BytesIO(base64.b64decode(image)))
42
+
43
+
44
+ def build_transform(input_size):
45
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
46
+ transform = T.Compose([
47
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
48
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
49
+ T.ToTensor(),
50
+ T.Normalize(mean=MEAN, std=STD)
51
+ ])
52
+ return transform
53
+
54
+
55
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
56
+ best_ratio_diff = float('inf')
57
+ best_ratio = (1, 1)
58
+ area = width * height
59
+ for ratio in target_ratios:
60
+ target_aspect_ratio = ratio[0] / ratio[1]
61
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
62
+ if ratio_diff < best_ratio_diff:
63
+ best_ratio_diff = ratio_diff
64
+ best_ratio = ratio
65
+ elif ratio_diff == best_ratio_diff:
66
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
67
+ best_ratio = ratio
68
+ return best_ratio
69
+
70
+
71
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
72
+ orig_width, orig_height = image.size
73
+ aspect_ratio = orig_width / orig_height
74
+
75
+ # calculate the existing image aspect ratio
76
+ target_ratios = set(
77
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
78
+ i * j <= max_num and i * j >= min_num)
79
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
80
+
81
+ # find the closest aspect ratio to the target
82
+ target_aspect_ratio = find_closest_aspect_ratio(
83
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
84
+
85
+ # calculate the target width and height
86
+ target_width = image_size * target_aspect_ratio[0]
87
+ target_height = image_size * target_aspect_ratio[1]
88
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
89
+
90
+ # resize the image
91
+ resized_img = image.resize((target_width, target_height))
92
+ processed_images = []
93
+ for i in range(blocks):
94
+ box = (
95
+ (i % (target_width // image_size)) * image_size,
96
+ (i // (target_width // image_size)) * image_size,
97
+ ((i % (target_width // image_size)) + 1) * image_size,
98
+ ((i // (target_width // image_size)) + 1) * image_size
99
+ )
100
+ # split the image
101
+ split_img = resized_img.crop(box)
102
+ processed_images.append(split_img)
103
+ assert len(processed_images) == blocks
104
+ if use_thumbnail and len(processed_images) != 1:
105
+ thumbnail_img = image.resize((image_size, image_size))
106
+ processed_images.append(thumbnail_img)
107
+ return processed_images
108
+
109
+
110
+ def heart_beat_worker(controller):
111
+ while True:
112
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
113
+ controller.send_heart_beat()
114
+
115
+
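+ # Note (added comment): split_model builds a device_map that spreads the language-model
+ # layers across all visible GPUs while pinning the vision tower, embeddings, and output
+ # head to GPU 0; vit_alpha discounts GPU 0's share of LLM layers to leave room for the ViT.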
116
+ def split_model(model_name, vit_alpha=0.5):
117
+ device_map = {}
118
+ world_size = torch.cuda.device_count()
119
+ num_layers = {
120
+ 'InternVL-Chat-V1-1': 40, 'InternVL-Chat-V1-2': 60, 'InternVL-Chat-V1-2-Plus': 60,
121
+ 'Mini-InternVL-2B-V1-5': 24, 'Mini-InternVL-4B-V1-5': 32, 'InternVL-Chat-V1-5': 48,
122
+ 'InternVL2-8B': 32, 'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80,
123
+ 'InternVL2-78B': 80, 'InternVL2-Pro': 80}[model_name]
124
+ # Since the first GPU will be used for ViT, treat it as half a GPU.
125
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
126
+ num_layers_per_gpu = [num_layers_per_gpu] * world_size
127
+ num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha))
128
+ layer_cnt = 0
129
+ for i, num_layer in enumerate(num_layers_per_gpu):
130
+ for j in range(num_layer):
131
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
132
+ layer_cnt += 1
133
+ device_map['vision_model'] = 0
134
+ device_map['mlp1'] = 0
135
+ device_map['language_model.model.tok_embeddings'] = 0
136
+ device_map['language_model.model.embed_tokens'] = 0
137
+ device_map['language_model.output'] = 0
138
+ device_map['language_model.model.norm'] = 0
139
+ device_map['language_model.lm_head'] = 0
140
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
141
+
142
+ return device_map
143
+
144
+
145
+ class ModelWorker:
146
+ def __init__(self, controller_addr, worker_addr, worker_id, model_path, model_name,
147
+ load_8bit, device, context_len=8192):
148
+ self.controller_addr = controller_addr
149
+ self.worker_addr = worker_addr
150
+ self.worker_id = worker_id
151
+ if model_path.endswith('/'):
152
+ model_path = model_path[:-1]
153
+ if model_name is None:
154
+ model_paths = model_path.split('/')
155
+ if model_paths[-1].startswith('checkpoint-'):
156
+ self.model_name = model_paths[-2] + '_' + model_paths[-1]
157
+ else:
158
+ self.model_name = model_paths[-1]
159
+ else:
160
+ self.model_name = model_name
161
+
162
+ logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...')
163
+
164
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
165
+ tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>']
166
+ tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep]
167
+ self.tokenizer = tokenizer
168
+
169
+ if device == 'auto':
170
+ device_map = split_model(self.model_name)
171
+ self.model = AutoModel.from_pretrained(
172
+ model_path,
173
+ load_in_8bit=load_8bit,
174
+ torch_dtype=torch.bfloat16,
175
+ device_map=device_map,
176
+ trust_remote_code=True).eval()
177
+ else:
178
+ self.model = AutoModel.from_pretrained(
179
+ model_path,
180
+ load_in_8bit=load_8bit,
181
+ torch_dtype=torch.bfloat16,
182
+ trust_remote_code=True).eval()
183
+ if not load_8bit and not device == 'auto':
184
+ self.model = self.model.cuda()
185
+ self.load_8bit = load_8bit
186
+ self.device = device
187
+ self.model_path = model_path
188
+ self.image_size = self.model.config.force_image_size
189
+ self.context_len = context_len
190
+ self.register_to_controller()
191
+ self.heart_beat_thread = threading.Thread(
192
+ target=heart_beat_worker, args=(self,))
193
+ self.heart_beat_thread.start()
194
+
195
+ def reload_model(self):
196
+ del self.model
197
+ torch.cuda.empty_cache()
198
+ if self.device == 'auto':
199
+ device_map = split_model(self.model_name)
200
+ self.model = AutoModel.from_pretrained(
201
+ self.model_path,
202
+ load_in_8bit=self.load_8bit,
203
+ torch_dtype=torch.bfloat16,
204
+ device_map=device_map,
205
+ trust_remote_code=True).eval()
206
+ else:
207
+ self.model = AutoModel.from_pretrained(
208
+ self.model_path,
209
+ load_in_8bit=self.load_8bit,
210
+ torch_dtype=torch.bfloat16,
211
+ trust_remote_code=True).eval()
212
+ if not self.load_8bit and not self.device == 'auto':
213
+ self.model = self.model.cuda()
214
+
215
+ def register_to_controller(self):
216
+ logger.info('Register to controller')
217
+
218
+ url = self.controller_addr + '/register_worker'
219
+ data = {
220
+ 'worker_name': self.worker_addr,
221
+ 'check_heart_beat': True,
222
+ 'worker_status': self.get_status()
223
+ }
224
+ r = requests.post(url, json=data)
225
+ assert r.status_code == 200
226
+
227
+ def send_heart_beat(self):
228
+ logger.info(f'Send heart beat. Models: {[self.model_name]}. '
229
+ f'Semaphore: {pretty_print_semaphore(model_semaphore)}. '
230
+ f'global_counter: {global_counter}')
231
+
232
+ url = self.controller_addr + '/receive_heart_beat'
233
+
234
+ while True:
235
+ try:
236
+ ret = requests.post(url, json={
237
+ 'worker_name': self.worker_addr,
238
+ 'queue_length': self.get_queue_length()}, timeout=5)
239
+ exist = ret.json()['exist']
240
+ break
241
+ except requests.exceptions.RequestException as e:
242
+ logger.error(f'heart beat error: {e}')
243
+ time.sleep(5)
244
+
245
+ if not exist:
246
+ self.register_to_controller()
247
+
248
+ def get_queue_length(self):
249
+ if model_semaphore is None:
250
+ return 0
251
+ else:
252
+ return args.limit_model_concurrency - model_semaphore._value + (len(
253
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
254
+
255
+ def get_status(self):
256
+ return {
257
+ 'model_names': [self.model_name],
258
+ 'speed': 1,
259
+ 'queue_length': self.get_queue_length(),
260
+ }
261
+
262
+ @torch.inference_mode()
263
+ def generate_stream(self, params):
264
+ system_message = params['prompt'][0]['content']
265
+ send_messages = params['prompt'][1:]
266
+ max_input_tiles = params['max_input_tiles']
267
+ temperature = params['temperature']
268
+ top_p = params['top_p']
269
+ max_new_tokens = params['max_new_tokens']
270
+ repetition_penalty = params['repetition_penalty']
271
+ do_sample = temperature > 0.0
272
+
273
+ global_image_cnt = 0
274
+ history, pil_images, max_input_tile_list = [], [], []
275
+ for message in send_messages:
276
+ if message['role'] == 'user':
277
+ prefix = ''
278
+ if 'image' in message:
279
+ max_input_tile_temp = []
280
+ for image_str in message['image']:
281
+ pil_images.append(load_image_from_base64(image_str))
282
+ prefix += f'Image-{global_image_cnt + 1}: <image>\n'
283
+ global_image_cnt += 1
284
+ max_input_tile_temp.append(max(1, max_input_tiles // len(message['image'])))
285
+ if len(max_input_tile_temp) > 0:
286
+ max_input_tile_list.append(max_input_tile_temp)
287
+ content = prefix + message['content']
288
+ history.append([content, ])
289
+ else:
290
+ history[-1].append(message['content'])
291
+ question, history = history[-1][0], history[:-1]
292
+
293
+ if global_image_cnt == 1:
294
+ question = question.replace('Image-1: <image>\n', '<image>\n')
295
+ history = [[item[0].replace('Image-1: <image>\n', '<image>\n'), item[1]] for item in history]
296
+
297
+ # Create a new list to store processed sublists
298
+ flattened_list = []
299
+ # Iterate through all but the last sublist in max_input_tile_list and process them
300
+ for sublist in max_input_tile_list[:-1]:
301
+ processed_sublist = [1] * len(sublist) # Change each element in the sublist to 1
302
+ flattened_list.extend(processed_sublist) # Flatten the processed sublist and add to the new list
303
+ # If max_input_tile_list is not empty, add the last sublist to the new list
304
+ if max_input_tile_list:
305
+ flattened_list.extend(max_input_tile_list[-1])
306
+ max_input_tile_list = flattened_list
307
+ assert len(max_input_tile_list) == len(pil_images), 'max_input_tile_list and pil_images must have the same length.'
308
+
309
+ old_system_message = self.model.system_message
310
+ self.model.system_message = system_message
311
+ image_tiles, num_patches_list = [], []
312
+ transform = build_transform(input_size=self.image_size)
313
+ if len(pil_images) > 0:
314
+ for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
315
+ if self.model.config.dynamic_image_size:
316
+ tiles = dynamic_preprocess(
317
+ pil_image, image_size=self.image_size, max_num=current_max_input_tiles,
318
+ use_thumbnail=self.model.config.use_thumbnail)
319
+ else:
320
+ tiles = [pil_image]
321
+ num_patches_list.append(len(tiles))
322
+ image_tiles += tiles
323
+ pixel_values = [transform(item) for item in image_tiles]
324
+ pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
325
+ logger.info(f'Split images to {pixel_values.shape}')
326
+ else:
327
+ pixel_values = None
328
+
329
+ streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
330
+ generation_config = dict(
331
+ num_beams=1,
332
+ max_new_tokens=max_new_tokens,
333
+ do_sample=do_sample,
334
+ temperature=temperature,
335
+ repetition_penalty=repetition_penalty,
336
+ max_length=self.context_len,
337
+ top_p=top_p,
338
+ streamer=streamer,
339
+ )
340
+ logger.info(f'Generation config: {generation_config}')
341
+
342
+ thread = Thread(target=self.model.chat, kwargs=dict(
343
+ tokenizer=self.tokenizer,
344
+ pixel_values=pixel_values,
345
+ num_patches_list=num_patches_list,
346
+ question=question,
347
+ history=history,
348
+ return_history=False,
349
+ generation_config=generation_config,
350
+ ))
351
+ thread.start()
352
+
353
+ generated_text = ''
354
+ for new_text in streamer:
355
+ generated_text += new_text
356
+ if generated_text.endswith(self.model.conv_template.sep):
357
+ generated_text = generated_text[:-len(self.model.conv_template.sep)]
358
+ yield json.dumps({'text': generated_text, 'error_code': 0}).encode() + b'\0'
359
+ logger.info(f'max_input_tile_list: {max_input_tile_list}, history: {history}, '
360
+ f'question: {question}, answer: {generated_text}')
361
+ self.model.system_message = old_system_message
362
+
363
+ def generate_stream_gate(self, params):
364
+ try:
365
+ for x in self.generate_stream(params):
366
+ yield x
367
+ except ValueError as e:
368
+ print('Caught ValueError:', e)
369
+ ret = {
370
+ 'text': server_error_msg,
371
+ 'error_code': 1,
372
+ }
373
+ yield json.dumps(ret).encode() + b'\0'
374
+ except torch.cuda.CudaError as e:
375
+ print('Caught torch.cuda.CudaError:', e)
376
+ ret = {
377
+ 'text': server_error_msg,
378
+ 'error_code': 1,
379
+ }
380
+ yield json.dumps(ret).encode() + b'\0'
381
+ except Exception as e:
382
+ print('Caught Unknown Error', e)
383
+ ret = {
384
+ 'text': server_error_msg,
385
+ 'error_code': 1,
386
+ }
387
+ yield json.dumps(ret).encode() + b'\0'
388
+
389
+
390
+ app = FastAPI()
391
+
392
+
393
+ def release_model_semaphore(fn=None):
394
+ model_semaphore.release()
395
+ if fn is not None:
396
+ fn()
397
+
398
+
399
+ @app.post('/worker_generate_stream')
400
+ async def generate_stream(request: Request):
401
+ global model_semaphore, global_counter
402
+ global_counter += 1
403
+ params = await request.json()
404
+
405
+ if model_semaphore is None:
406
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
407
+ await model_semaphore.acquire()
408
+ worker.send_heart_beat()
409
+ generator = worker.generate_stream_gate(params)
410
+ background_tasks = BackgroundTasks()
411
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
412
+ return StreamingResponse(generator, background=background_tasks)
413
+
414
+
415
+ @app.post('/worker_get_status')
416
+ async def get_status(request: Request):
417
+ return worker.get_status()
418
+
419
+
420
+ if __name__ == '__main__':
421
+ parser = argparse.ArgumentParser()
422
+ parser.add_argument('--host', type=str, default='0.0.0.0')
423
+ parser.add_argument('--port', type=int, default=21002)
424
+ parser.add_argument('--worker-address', type=str, default='http://localhost:21002')
425
+ parser.add_argument('--controller-address', type=str, default='http://localhost:21001')
426
+ parser.add_argument('--model-path', type=str, default='facebook/opt-350m')
427
+ parser.add_argument('--model-name', type=str)
428
+ parser.add_argument('--device', type=str, default='cuda')
429
+ parser.add_argument('--limit-model-concurrency', type=int, default=5)
430
+ parser.add_argument('--stream-interval', type=int, default=1)
431
+ parser.add_argument('--load-8bit', action='store_true')
432
+ args = parser.parse_args()
433
+ logger.info(f'args: {args}')
434
+
435
+ worker = ModelWorker(args.controller_address,
436
+ args.worker_address,
437
+ worker_id,
438
+ args.model_path,
439
+ args.model_name,
440
+ args.load_8bit,
441
+ args.device)
442
+ uvicorn.run(app, host=args.host, port=args.port, log_level='info')
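The tile-budget handling in generate_stream above reduces to a simple rule: images from earlier user turns are capped at one tile each, and only the most recent turn's images share the full max_input_tiles budget. A minimal sketch of that rule, assuming only per-turn image counts are known (flatten_tile_budget is an illustrative name, not part of the worker):

# Sketch of the tile-budget flattening used in generate_stream above:
# older turns get 1 tile per image, the latest turn splits the full budget
# evenly among its images (at least 1 tile each).
def flatten_tile_budget(images_per_turn, max_input_tiles):
    per_turn = [[max(1, max_input_tiles // n)] * n for n in images_per_turn if n > 0]
    if not per_turn:
        return []
    flattened = []
    for sublist in per_turn[:-1]:
        flattened.extend([1] * len(sublist))  # older turns: one tile per image
    flattened.extend(per_turn[-1])            # latest turn keeps its share
    return flattened

# Example: two turns with 2 and 3 images and a budget of 12 tiles -> [1, 1, 4, 4, 4]
print(flatten_tile_budget([2, 3], 12))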
InternVL/video_retrieval/test_msrvtt.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+
7
+ import decord
8
+ import mmengine
9
+ import numpy as np
10
+ import torch
11
+ import tqdm
12
+ from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor
13
+
14
+
15
+ def recall_at_k(scores, positive_pairs, k):
16
+ """
17
+ Compute the recall at k for each sample
18
+ :param scores: compatibility score between text and image embeddings (nb texts, nb images)
19
+ :param k: number of images to consider per text, for retrieval
20
+ :param positive_pairs: boolean matrix of positive pairs (nb texts, nb images)
21
+ :return: recall at k averaged over all texts
22
+ """
23
+ nb_texts, nb_images = scores.shape
24
+ # for each text, sort according to image scores in decreasing order
25
+ topk_indices = torch.topk(scores, k, dim=1)[1]
26
+ # compute number of positives for each text
27
+ nb_positive = positive_pairs.sum(dim=1)
28
+ # nb_texts, k, nb_images
29
+ topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
30
+ # compute number of true positives
31
+ positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
32
+ # a true positive means a positive among the topk
33
+ nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
34
+ # compute recall at k
35
+ recall_at_k = (nb_true_positive / nb_positive)
36
+ return recall_at_k
37
+
38
+
39
+ def batchify(func, X, Y, batch_size, device, *args, **kwargs):
40
+ results = []
41
+ for start in range(0, len(X), batch_size):
42
+ end = start + batch_size
43
+ x = X[start:end].to(device)
44
+ y = Y[start:end].to(device)
45
+ result = func(x, y, *args, **kwargs).cpu()
46
+ results.append(result)
47
+ return torch.cat(results)
48
+
49
+
50
+ def validate_msrvtt(model, tokenizer, image_processor, root, metadata,
51
+ num_frames=1, prefix='summarize:', mode='InternVL-G', recall_k_list=[1, 5, 10],
52
+ use_dsl=True, eval_batch_size=32):
53
+ metadata = json.load(open(metadata))
54
+
55
+ video_features = []
56
+ text_features = []
57
+
58
+ # compute text features
59
+ print('Computing text features', flush=True)
60
+ for data in tqdm.tqdm(metadata):
61
+ caption = prefix + data['caption']
62
+ input_ids = tokenizer(caption, return_tensors='pt', max_length=80,
63
+ truncation=True, padding='max_length').input_ids.cuda()
64
+ with torch.no_grad():
65
+ feat = model.encode_text(input_ids)
66
+ text_features.append(feat.cpu())
67
+ text_features = torch.cat(text_features)
68
+
69
+ # compute video features
70
+ print('Computing video features', flush=True)
71
+ for data in tqdm.tqdm(metadata):
72
+ video_id = data['video']
73
+ video_path = os.path.join(root, video_id)
74
+ video_data = mmengine.get(video_path)
75
+ video_data = io.BytesIO(video_data)
76
+ video_reader = decord.VideoReader(video_data)
77
+
78
+ # uniformly sample frames
79
+ interval = math.ceil(len(video_reader) / num_frames)
80
+ frames_id = np.arange(0, len(video_reader), interval) + interval // 2
81
+ assert len(frames_id) == num_frames and frames_id[-1] < len(video_reader)
82
+
83
+ frames = video_reader.get_batch(frames_id).asnumpy()
84
+
85
+ pixel_values = image_processor(images=frames, return_tensors='pt').pixel_values
86
+ with torch.no_grad():
87
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
88
+ feat = model.encode_image(pixel_values, mode=mode)
89
+ feat = feat.mean(dim=0, keepdim=True)
90
+ video_features.append(feat.cpu())
91
+ video_features = torch.cat(video_features)
92
+
93
+ print('Computing metrics', flush=True)
94
+ texts_emb = text_features / text_features.norm(dim=-1, keepdim=True)
95
+ images_emb = video_features / video_features.norm(dim=-1, keepdim=True)
96
+
97
+ # get the score for each text and image pair
98
+ scores = texts_emb @ images_emb.t()
99
+
100
+ # construct the positive pair matrix, which tells whether each text-image pair is a positive or not
101
+ positive_pairs = torch.zeros_like(scores, dtype=bool)
102
+ positive_pairs[torch.arange(len(scores)), torch.arange(len(scores))] = True
103
+
104
+ scores_T = scores.T
105
+ positive_pairs_T = positive_pairs.T
106
+
107
+ if use_dsl:
108
+ scores = scores * scores.softmax(dim=0)
109
+ scores_T = scores_T * scores_T.softmax(dim=0)
110
+
111
+ metrics = {}
112
+ for recall_k in recall_k_list:
113
+ # Note that recall_at_k computes **actual** recall i.e. nb_true_positive/nb_positives, where the number
114
+ # of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts matching that image among the top-k.
115
+ # Also, the number of positives is the total number of texts matching the image in the dataset, as we have a set of captions
116
+ # for each image, that number will be greater than 1 for text retrieval.
117
+ # However, image/text retrieval recall@k, the way it is done in CLIP-like papers, is a bit different.
118
+ # recall@k, in CLIP-like papers, is, for each image, either 1 or 0. It is 1 if at least one text matches the image among the top-k.
119
+ # so we can easily compute that using the actual recall, by checking whether there is at least one true positive,
120
+ # which would be the case if the recall is greater than 0. Once we compute the recall for each image (or text), we average
121
+ # it over the dataset.
122
+ metrics[f't2v_retrieval_recall@{recall_k}'] = (
123
+ batchify(recall_at_k, scores, positive_pairs, eval_batch_size, scores.device,
124
+ k=recall_k) > 0).float().mean().item()
125
+ metrics[f'v2t_retrieval_recall@{recall_k}'] = (
126
+ batchify(recall_at_k, scores_T, positive_pairs_T, eval_batch_size, scores.device,
127
+ k=recall_k) > 0).float().mean().item()
128
+
129
+ print(metrics)
130
+
131
+
132
+ if __name__ == '__main__':
133
+ parser = argparse.ArgumentParser(description='validate MSR-VTT', add_help=False)
134
+ parser.add_argument('--video-root', type=str)
135
+ parser.add_argument('--metadata', type=str)
136
+ parser.add_argument('--mode', type=str, default='InternVL-C', choices=['InternVL-C', 'InternVL-G'])
137
+ parser.add_argument('--num-frames', type=int, default=1)
138
+ args = parser.parse_args()
139
+
140
+ model = AutoModel.from_pretrained(
141
+ 'OpenGVLab/InternVL-14B-224px',
142
+ torch_dtype=torch.bfloat16,
143
+ low_cpu_mem_usage=True,
144
+ trust_remote_code=True).cuda().eval()
145
+
146
+ image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternVL-14B-224px')
147
+
148
+ tokenizer = AutoTokenizer.from_pretrained(
149
+ 'OpenGVLab/InternVL-14B-224px', use_fast=False, add_eos_token=True)
150
+ tokenizer.pad_token_id = 0 # set pad_token_id to 0
151
+
152
+ metrics = validate_msrvtt(model, tokenizer, image_processor,
153
+ root=args.video_root,
154
+ metadata=args.metadata,
155
+ mode=args.mode,
156
+ num_frames=args.num_frames,)
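As a sanity check on the metric above: CLIP-style recall@k reduces each row's actual recall to a 0/1 hit before averaging. A small self-contained sketch with synthetic scores and an identity ground truth (as in MSR-VTT, where caption i matches video i); the numbers are made up and this is not part of the evaluation script:

import torch

def recall_at_k(scores, positive_pairs, k):
    # Same logic as the function above, condensed: per-row actual recall.
    nb_texts, nb_images = scores.shape
    topk = torch.topk(scores, k, dim=1)[1]
    onehot = torch.nn.functional.one_hot(topk, num_classes=nb_images)
    true_pos = (onehot * positive_pairs.view(nb_texts, 1, nb_images)).sum(dim=(1, 2))
    return true_pos / positive_pairs.sum(dim=1)

scores = torch.tensor([[0.9, 0.1, 0.0],
                       [0.2, 0.7, 0.1],
                       [0.3, 0.5, 0.2]])
positive = torch.eye(3, dtype=torch.bool)
# CLIP-style recall@1: a row counts as a hit if any positive lands in the top-1.
print((recall_at_k(scores, positive, k=1) > 0).float().mean().item())  # 0.666...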
sglang/examples/frontend_language/quick_start/gemini_example_chat.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ Usage:
3
+ export GCP_PROJECT_ID=******
4
+ python3 gemini_example_chat.py
5
+ """
6
+
7
+ import sglang as sgl
8
+
9
+
10
+ @sgl.function
11
+ def multi_turn_question(s, question_1, question_2):
12
+ s += sgl.user(question_1)
13
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
14
+ s += sgl.user(question_2)
15
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
16
+
17
+
18
+ def single():
19
+ state = multi_turn_question.run(
20
+ question_1="What is the capital of the United States?",
21
+ question_2="List two local attractions.",
22
+ )
23
+
24
+ for m in state.messages():
25
+ print(m["role"], ":", m["content"])
26
+
27
+ print("\n-- answer_1 --\n", state["answer_1"])
28
+
29
+
30
+ def stream():
31
+ state = multi_turn_question.run(
32
+ question_1="What is the capital of the United States?",
33
+ question_2="List two local attractions.",
34
+ stream=True,
35
+ )
36
+
37
+ for out in state.text_iter():
38
+ print(out, end="", flush=True)
39
+ print()
40
+
41
+
42
+ def batch():
43
+ states = multi_turn_question.run_batch(
44
+ [
45
+ {
46
+ "question_1": "What is the capital of the United States?",
47
+ "question_2": "List two local attractions.",
48
+ },
49
+ {
50
+ "question_1": "What is the capital of France?",
51
+ "question_2": "What is the population of this city?",
52
+ },
53
+ ]
54
+ )
55
+
56
+ for s in states:
57
+ print(s.messages())
58
+
59
+
60
+ if __name__ == "__main__":
61
+ sgl.set_default_backend(sgl.VertexAI("gemini-pro"))
62
+
63
+ # Run a single request
64
+ print("\n========== single ==========\n")
65
+ single()
66
+
67
+ # Stream output
68
+ print("\n========== stream ==========\n")
69
+ stream()
70
+
71
+ # Run a batch of requests
72
+ print("\n========== batch ==========\n")
73
+ batch()
sglang/examples/frontend_language/quick_start/gemini_example_multimodal_chat.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Usage:
3
+ export GCP_PROJECT_ID=******
4
+ python3 gemini_example_multimodal_chat.py
5
+ """
6
+
7
+ import sglang as sgl
8
+
9
+
10
+ @sgl.function
11
+ def image_qa(s, image_file1, image_file2, question):
12
+ s += sgl.user(sgl.image(image_file1) + sgl.image(image_file2) + question)
13
+ s += sgl.assistant(sgl.gen("answer", max_tokens=256))
14
+
15
+
16
+ if __name__ == "__main__":
17
+ sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
18
+
19
+ state = image_qa.run(
20
+ image_file1="./images/cat.jpeg",
21
+ image_file2="./images/dog.jpeg",
22
+ question="Describe difference of the two images in one sentence.",
23
+ stream=True,
24
+ )
25
+
26
+ for out in state.text_iter("answer"):
27
+ print(out, end="", flush=True)
28
+ print()
29
+
30
+ print(state["answer"])
sglang/examples/frontend_language/quick_start/local_example_complete.py ADDED
@@ -0,0 +1,70 @@
1
+ """
2
+ Usage:
3
+ python3 local_example_complete.py
4
+ """
5
+
6
+ import sglang as sgl
7
+
8
+
9
+ @sgl.function
10
+ def few_shot_qa(s, question):
11
+ s += """The following are questions with answers.
12
+ Q: What is the capital of France?
13
+ A: Paris
14
+ Q: What is the capital of Germany?
15
+ A: Berlin
16
+ Q: What is the capital of Italy?
17
+ A: Rome
18
+ """
19
+ s += "Q: " + question + "\n"
20
+ s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
21
+
22
+
23
+ def single():
24
+ state = few_shot_qa.run(question="What is the capital of the United States?")
25
+ answer = state["answer"].strip().lower()
26
+
27
+ assert "washington" in answer, f"answer: {state['answer']}"
28
+
29
+ print(state.text())
30
+
31
+
32
+ def stream():
33
+ state = few_shot_qa.run(
34
+ question="What is the capital of the United States?", stream=True
35
+ )
36
+
37
+ for out in state.text_iter("answer"):
38
+ print(out, end="", flush=True)
39
+ print()
40
+
41
+
42
+ def batch():
43
+ states = few_shot_qa.run_batch(
44
+ [
45
+ {"question": "What is the capital of the United States?"},
46
+ {"question": "What is the capital of China?"},
47
+ ]
48
+ )
49
+
50
+ for s in states:
51
+ print(s["answer"])
52
+
53
+
54
+ if __name__ == "__main__":
55
+ runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
56
+ sgl.set_default_backend(runtime)
57
+
58
+ # Run a single request
59
+ print("\n========== single ==========\n")
60
+ single()
61
+
62
+ # Stream output
63
+ print("\n========== stream ==========\n")
64
+ stream()
65
+
66
+ # Run a batch of requests
67
+ print("\n========== batch ==========\n")
68
+ batch()
69
+
70
+ runtime.shutdown()
sglang/examples/frontend_language/quick_start/local_example_llava_next.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Usage: python3 local_example_llava_next.py
3
+ """
4
+
5
+ import sglang as sgl
6
+ from sglang.lang.chat_template import get_chat_template
7
+
8
+
9
+ @sgl.function
10
+ def image_qa(s, image_path, question):
11
+ s += sgl.user(sgl.image(image_path) + question)
12
+ s += sgl.assistant(sgl.gen("answer"))
13
+
14
+
15
+ def single():
16
+ state = image_qa.run(
17
+ image_path="images/cat.jpeg", question="What is this?", max_new_tokens=128
18
+ )
19
+ print(state["answer"], "\n")
20
+
21
+
22
+ def stream():
23
+ state = image_qa.run(
24
+ image_path="images/cat.jpeg",
25
+ question="What is this?",
26
+ max_new_tokens=64,
27
+ stream=True,
28
+ )
29
+
30
+ for out in state.text_iter("answer"):
31
+ print(out, end="", flush=True)
32
+ print()
33
+
34
+
35
+ def batch():
36
+ states = image_qa.run_batch(
37
+ [
38
+ {"image_path": "images/cat.jpeg", "question": "What is this?"},
39
+ {"image_path": "images/dog.jpeg", "question": "What is this?"},
40
+ ],
41
+ max_new_tokens=128,
42
+ )
43
+ for s in states:
44
+ print(s["answer"], "\n")
45
+
46
+
47
+ if __name__ == "__main__":
48
+ import multiprocessing as mp
49
+
50
+ mp.set_start_method("spawn", force=True)
51
+
52
+ runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
53
+ runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava")
54
+
55
+ # Or you can use the 72B model
56
+ # runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
57
+ # runtime.endpoint.chat_template = get_chat_template("chatml-llava")
58
+
59
+ sgl.set_default_backend(runtime)
60
+ print(f"chat template: {runtime.endpoint.chat_template.name}")
61
+
62
+ # Or you can use API models
63
+ # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
64
+ # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
65
+
66
+ # Run a single request
67
+ print("\n========== single ==========\n")
68
+ single()
69
+
70
+ # Stream output
71
+ print("\n========== stream ==========\n")
72
+ stream()
73
+
74
+ # Run a batch of requests
75
+ print("\n========== batch ==========\n")
76
+ batch()
77
+
78
+ runtime.shutdown()
sglang/examples/frontend_language/quick_start/openai_example_chat.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ Usage:
3
+ export OPENAI_API_KEY=sk-******
4
+ python3 openai_example_chat.py
5
+ """
6
+
7
+ import sglang as sgl
8
+
9
+
10
+ @sgl.function
11
+ def multi_turn_question(s, question_1, question_2):
12
+ s += sgl.system("You are a helpful assistant.")
13
+ s += sgl.user(question_1)
14
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
15
+ s += sgl.user(question_2)
16
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
17
+
18
+
19
+ def single():
20
+ state = multi_turn_question.run(
21
+ question_1="What is the capital of the United States?",
22
+ question_2="List two local attractions.",
23
+ )
24
+
25
+ for m in state.messages():
26
+ print(m["role"], ":", m["content"])
27
+
28
+ print("\n-- answer_1 --\n", state["answer_1"])
29
+
30
+
31
+ def stream():
32
+ state = multi_turn_question.run(
33
+ question_1="What is the capital of the United States?",
34
+ question_2="List two local attractions.",
35
+ stream=True,
36
+ )
37
+
38
+ for out in state.text_iter():
39
+ print(out, end="", flush=True)
40
+ print()
41
+
42
+
43
+ def batch():
44
+ states = multi_turn_question.run_batch(
45
+ [
46
+ {
47
+ "question_1": "What is the capital of the United States?",
48
+ "question_2": "List two local attractions.",
49
+ },
50
+ {
51
+ "question_1": "What is the capital of France?",
52
+ "question_2": "What is the population of this city?",
53
+ },
54
+ ]
55
+ )
56
+
57
+ for s in states:
58
+ print(s.messages())
59
+
60
+
61
+ if __name__ == "__main__":
62
+ sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
63
+
64
+ # Run a single request
65
+ print("\n========== single ==========\n")
66
+ single()
67
+
68
+ # Stream output
69
+ print("\n========== stream ==========\n")
70
+ stream()
71
+
72
+ # Run a batch of requests
73
+ print("\n========== batch ==========\n")
74
+ batch()
sglang/examples/frontend_language/quick_start/openai_example_complete.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ Usage:
3
+ export OPENAI_API_KEY=sk-******
4
+ python3 openai_example_complete.py
5
+ """
6
+
7
+ import sglang as sgl
8
+
9
+
10
+ @sgl.function
11
+ def few_shot_qa(s, question):
12
+ s += """The following are questions with answers.
13
+ Q: What is the capital of France?
14
+ A: Paris
15
+ Q: What is the capital of Germany?
16
+ A: Berlin
17
+ Q: What is the capital of Italy?
18
+ A: Rome
19
+ """
20
+ s += "Q: " + question + "\n"
21
+ s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
22
+
23
+
24
+ def single():
25
+ state = few_shot_qa.run(question="What is the capital of the United States?")
26
+ answer = state["answer"].strip().lower()
27
+
28
+ assert "washington" in answer, f"answer: {state['answer']}"
29
+
30
+ print(state.text())
31
+
32
+
33
+ def stream():
34
+ state = few_shot_qa.run(
35
+ question="What is the capital of the United States?", stream=True
36
+ )
37
+
38
+ for out in state.text_iter("answer"):
39
+ print(out, end="", flush=True)
40
+ print()
41
+
42
+
43
+ def batch():
44
+ states = few_shot_qa.run_batch(
45
+ [
46
+ {"question": "What is the capital of the United States?"},
47
+ {"question": "What is the capital of China?"},
48
+ ]
49
+ )
50
+
51
+ for s in states:
52
+ print(s["answer"])
53
+
54
+
55
+ if __name__ == "__main__":
56
+ sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
57
+
58
+ # Run a single request
59
+ print("\n========== single ==========\n")
60
+ single()
61
+
62
+ # Stream output
63
+ print("\n========== stream ==========\n")
64
+ stream()
65
+
66
+ # Run a batch of requests
67
+ print("\n========== batch ==========\n")
68
+ batch()
sglang/examples/frontend_language/quick_start/openrouter_example_chat.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Usage:
3
+ export OPENROUTER_API_KEY=sk-******
4
+ python3 openrouter_example_chat.py
5
+ """
6
+
7
+ import os
8
+
9
+ import sglang as sgl
10
+
11
+
12
+ @sgl.function
13
+ def multi_turn_question(s, question_1, question_2):
14
+ s += sgl.system("You are a helpful assistant.")
15
+ s += sgl.user(question_1)
16
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
17
+ s += sgl.user(question_2)
18
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
19
+
20
+
21
+ def single():
22
+ state = multi_turn_question.run(
23
+ question_1="What is the capital of the United States?",
24
+ question_2="List two local attractions.",
25
+ )
26
+
27
+ for m in state.messages():
28
+ print(m["role"], ":", m["content"])
29
+
30
+ print("\n-- answer_1 --\n", state["answer_1"])
31
+
32
+
33
+ def stream():
34
+ state = multi_turn_question.run(
35
+ question_1="What is the capital of the United States?",
36
+ question_2="List two local attractions.",
37
+ stream=True,
38
+ )
39
+
40
+ for out in state.text_iter():
41
+ print(out, end="", flush=True)
42
+ print()
43
+
44
+
45
+ def batch():
46
+ states = multi_turn_question.run_batch(
47
+ [
48
+ {
49
+ "question_1": "What is the capital of the United States?",
50
+ "question_2": "List two local attractions.",
51
+ },
52
+ {
53
+ "question_1": "What is the capital of France?",
54
+ "question_2": "What is the population of this city?",
55
+ },
56
+ ]
57
+ )
58
+
59
+ for s in states:
60
+ print(s.messages())
61
+
62
+
63
+ if __name__ == "__main__":
64
+ backend = sgl.OpenAI(
65
+ model_name="google/gemma-7b-it:free",
66
+ base_url="https://openrouter.ai/api/v1",
67
+ api_key=os.environ.get("OPENROUTER_API_KEY"),
68
+ )
69
+ sgl.set_default_backend(backend)
70
+
71
+ # Run a single request
72
+ print("\n========== single ==========\n")
73
+ single()
74
+
75
+ # Stream output
76
+ print("\n========== stream ==========\n")
77
+ stream()
78
+
79
+ # Run a batch of requests
80
+ print("\n========== batch ==========\n")
81
+ batch()
sglang/examples/frontend_language/quick_start/together_example_complete.py ADDED
@@ -0,0 +1,76 @@
1
+ """
2
+ Usage:
3
+ export TOGETHER_API_KEY=sk-******
4
+ python3 together_example_complete.py
5
+ """
6
+
7
+ import os
8
+
9
+ import sglang as sgl
10
+
11
+
12
+ @sgl.function
13
+ def few_shot_qa(s, question):
14
+ s += """The following are questions with answers.
15
+ Q: What is the capital of France?
16
+ A: Paris
17
+ Q: What is the capital of Germany?
18
+ A: Berlin
19
+ Q: What is the capital of Italy?
20
+ A: Rome
21
+ """
22
+ s += "Q: " + question + "\n"
23
+ s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
24
+
25
+
26
+ def single():
27
+ state = few_shot_qa.run(question="What is the capital of the United States?")
28
+ answer = state["answer"].strip().lower()
29
+
30
+ assert "washington" in answer, f"answer: {state['answer']}"
31
+
32
+ print(state.text())
33
+
34
+
35
+ def stream():
36
+ state = few_shot_qa.run(
37
+ question="What is the capital of the United States?", stream=True
38
+ )
39
+
40
+ for out in state.text_iter("answer"):
41
+ print(out, end="", flush=True)
42
+ print()
43
+
44
+
45
+ def batch():
46
+ states = few_shot_qa.run_batch(
47
+ [
48
+ {"question": "What is the capital of the United States?"},
49
+ {"question": "What is the capital of China?"},
50
+ ]
51
+ )
52
+
53
+ for s in states:
54
+ print(s["answer"])
55
+
56
+
57
+ if __name__ == "__main__":
58
+ backend = sgl.OpenAI(
59
+ model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
60
+ is_chat_model=False,
61
+ base_url="https://api.together.xyz/v1",
62
+ api_key=os.environ.get("TOGETHER_API_KEY"),
63
+ )
64
+ sgl.set_default_backend(backend)
65
+
66
+ # Run a single request
67
+ print("\n========== single ==========\n")
68
+ single()
69
+
70
+ # Stream output
71
+ print("\n========== stream ==========\n")
72
+ stream()
73
+
74
+ # Run a batch of requests
75
+ print("\n========== batch ==========\n")
76
+ batch()
sglang/examples/frontend_language/usage/chinese_regex.py ADDED
@@ -0,0 +1,53 @@
1
+ import sglang as sgl
2
+
3
+ character_regex = (
4
+ r"""\{\n"""
5
+ + r""" "姓名": "[^"]{1,32}",\n"""
6
+ + r""" "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
7
+ + r""" "血型": "(纯血|混血|麻瓜)",\n"""
8
+ + r""" "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n"""
9
+ + r""" "魔杖": \{\n"""
10
+ + r""" "材质": "[^"]{1,32}",\n"""
11
+ + r""" "杖芯": "[^"]{1,32}",\n"""
12
+ + r""" "长度": [0-9]{1,2}\.[0-9]{0,2}\n"""
13
+ + r""" \},\n"""
14
+ + r""" "存活": "(存活|死亡)",\n"""
15
+ + r""" "守护神": "[^"]{1,32}",\n"""
16
+ + r""" "博格特": "[^"]{1,32}"\n"""
17
+ + r"""\}"""
18
+ )
19
+
20
+
21
+ @sgl.function
22
+ def character_gen(s, name):
23
+ s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。"
24
+ s += """\
25
+ 这是一个例子
26
+ {
27
+ "姓名": "哈利波特",
28
+ "学院": "格兰芬多",
29
+ "血型": "混血",
30
+ "职业": "学生",
31
+ "魔杖": {
32
+ "材质": "冬青木",
33
+ "杖芯": "凤凰尾羽",
34
+ "长度": 11.0
35
+ },
36
+ "存活": "存活",
37
+ "守护神": "麋鹿",
38
+ "博格特": "摄魂怪"
39
+ }
40
+ """
41
+ s += f"现在请你填写{name}的信息:\n"
42
+ s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
43
+
44
+
45
+ def main():
46
+ backend = sgl.RuntimeEndpoint("http://localhost:30000")
47
+ sgl.set_default_backend(backend)
48
+ ret = character_gen.run(name="赫敏格兰杰", temperature=0)
49
+ print(ret.text())
50
+
51
+
52
+ if __name__ == "__main__":
53
+ main()
sglang/examples/frontend_language/usage/choices_logprob.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ Usage:
3
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
4
+ python choices_logprob.py
5
+ """
6
+
7
+ import sglang as sgl
8
+
9
+
10
+ @sgl.function
11
+ def tool_use(s, question):
12
+ s += "To answer this question: " + question + ", "
13
+ s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
14
+
15
+
16
+ def main():
17
+ # Run one case
18
+ question = "What is 5 + 5?"
19
+ state = tool_use.run(question)
20
+ print("questions:", question)
21
+ print("choice:", state["tool"])
22
+ meta_info = state.get_meta_info("tool")
23
+ print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
24
+ print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
25
+ print("-" * 50)
26
+
27
+ # Run a batch
28
+ questions = [
29
+ "What is 5 + 6?",
30
+ "Who is Michael Jordan?",
31
+ ]
32
+ states = tool_use.run_batch([{"question": q} for q in questions])
33
+ for question, state in zip(questions, states):
34
+ print("questions:", question)
35
+ print("choice:", state["tool"])
36
+ meta_info = state.get_meta_info("tool")
37
+ print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
38
+ print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
39
+ print("-" * 50)
40
+
41
+
42
+ if __name__ == "__main__":
43
+ sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
44
+ main()
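The choice log-probabilities printed above can be turned into a selection rule by scoring each candidate, for example by its length-normalized sum of token log-probabilities. The toy below uses made-up numbers and illustrates the comparison only; it is not necessarily the exact normalization sglang applies:

# Hypothetical per-token logprobs for the two candidate continuations.
choice_logprobs = {
    "calculator": [-0.2, -0.5, -0.1],
    "search engine": [-0.4, -1.3, -0.9, -0.6],
}
# Score each choice by its average token log-probability and pick the best.
scores = {name: sum(lps) / len(lps) for name, lps in choice_logprobs.items()}
print(scores)                       # {'calculator': -0.266..., 'search engine': -0.8}
print(max(scores, key=scores.get))  # calculator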
sglang/examples/frontend_language/usage/cot_decoding.py ADDED
@@ -0,0 +1,115 @@
1
+ from math import exp
2
+ from pprint import pformat
3
+
4
+ import sglang as sgl
5
+
6
+ YELLOW = "\033[1;33m"
7
+ GREEN = "\033[1;32m"
8
+ BLUE = "\033[1;34m"
9
+ CLEAR = "\033[1;0m"
10
+
11
+
12
+ @sgl.function
13
+ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
14
+ """CoT Decoding: http://arxiv.org/abs/2402.10200"""
15
+
16
+ if is_chat_model:
17
+ s += sgl.user("Question: " + question + "\nAnswer:")
18
+ s += sgl.assistant_begin()
19
+ else:
20
+ s += "Question: " + question + "\nAnswer:"
21
+
22
+ step_0 = s.fork(1)[0]
23
+ forks = s.fork(get_top_k)
24
+ answer_forks = s.fork(get_top_k)
25
+
26
+ # decoding step 0
27
+ step_0 += sgl.gen(
28
+ "get_top_k",
29
+ max_tokens=0,
30
+ return_logprob=True,
31
+ top_logprobs_num=get_top_k,
32
+ return_text_in_logprobs=True,
33
+ )
34
+ logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
35
+
36
+ print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
37
+ for idx, (f, token) in enumerate(zip(forks, logprobs)):
38
+ logprob, token_id, text = token
39
+ f += text
40
+
41
+ if text == "<|end_of_text|>":
42
+ print(
43
+ f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}"
44
+ )
45
+ continue
46
+
47
+ # continue greedy decoding
48
+ f += sgl.gen(
49
+ "answer",
50
+ temperature=0,
51
+ max_tokens=1024,
52
+ return_logprob=True,
53
+ top_logprobs_num=2,
54
+ return_text_in_logprobs=True,
55
+ )
56
+
57
+ # calculate probability disparity between the top and secondary tokens
58
+ x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
59
+ x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
60
+ tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
61
+ delta = (sum(x1s) - sum(x2s)) / len(x1s)
62
+
63
+ # extract the answer span (without the '<|end_of_text|>' token)
64
+ answer_forks[idx] += text + f["answer"] + "\nSo the answer is"
65
+ answer_forks[idx] += sgl.gen(
66
+ "answer_span",
67
+ temperature=0,
68
+ max_tokens=64,
69
+ return_logprob=True,
70
+ top_logprobs_num=2,
71
+ return_text_in_logprobs=True,
72
+ )
73
+ answer = answer_forks[idx]["answer_span"].replace("\n", " ").strip(":")
74
+ print(
75
+ f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}"
76
+ )
77
+ generated_text = str(answer_forks[idx])[len("ProgramState(") : -1]
78
+ print(f"{BLUE}{pformat(generated_text)}{CLEAR}")
79
+
80
+ if verbose:
81
+ answer_tokens = [
82
+ xt[0][2]
83
+ for xt in answer_forks[idx].get_meta_info("answer_span")[
84
+ "output_top_logprobs"
85
+ ]
86
+ ]
87
+ answer_x1s = [
88
+ exp(xt[0][0])
89
+ for xt in answer_forks[idx].get_meta_info("answer_span")[
90
+ "output_top_logprobs"
91
+ ]
92
+ ]
93
+ answer_x2s = [
94
+ exp(xt[1][0])
95
+ for xt in answer_forks[idx].get_meta_info("answer_span")[
96
+ "output_top_logprobs"
97
+ ]
98
+ ]
99
+
100
+ for token, x1, x2 in zip(tokens, x1s, x2s):
101
+ print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="")
102
+ print("\n===========")
103
+ for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s):
104
+ print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", end="")
105
+ print()
106
+
107
+
108
+ sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
109
+
110
+ state = cot_decoding.run(
111
+ question=r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
112
+ get_top_k=10,
113
+ is_chat_model=True,
114
+ verbose=False,
115
+ )
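The path score printed above is the mean gap between the top-1 and top-2 token probabilities along each greedy continuation. The standalone sketch below recomputes that score from a hypothetical output_top_logprobs structure shaped like the one returned by get_meta_info above; the values are invented:

from math import exp

# One entry per generated token; each entry lists (logprob, token_id, text)
# tuples sorted by probability, mirroring output_top_logprobs above.
output_top_logprobs = [
    [(-0.10, 11, "The"), (-2.50, 12, "A")],
    [(-0.30, 13, " answer"), (-1.80, 14, " total")],
    [(-0.05, 15, " is"), (-3.20, 16, " was")],
]

x1s = [exp(step[0][0]) for step in output_top_logprobs]  # top-1 probabilities
x2s = [exp(step[1][0]) for step in output_top_logprobs]  # top-2 probabilities
delta = (sum(x1s) - sum(x2s)) / len(x1s)  # larger gap => more confident path
print(f"path score = {delta:.3f}")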
sglang/examples/frontend_language/usage/json_decode.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Usage:
3
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
4
+ python json_decode.py
5
+ """
6
+
7
+ from enum import Enum
8
+
9
+ from pydantic import BaseModel
10
+
11
+ import sglang as sgl
12
+ from sglang.srt.constrained import build_regex_from_object
13
+
14
+ character_regex = (
15
+ r"""\{\n"""
16
+ + r""" "name": "[\w\d\s]{1,16}",\n"""
17
+ + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
18
+ + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
19
+ + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
20
+ + r""" "wand": \{\n"""
21
+ + r""" "wood": "[\w\d\s]{1,16}",\n"""
22
+ + r""" "core": "[\w\d\s]{1,16}",\n"""
23
+ + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
24
+ + r""" \},\n"""
25
+ + r""" "alive": "(Alive|Deceased)",\n"""
26
+ + r""" "patronus": "[\w\d\s]{1,16}",\n"""
27
+ + r""" "bogart": "[\w\d\s]{1,16}"\n"""
28
+ + r"""\}"""
29
+ )
30
+
31
+
32
+ @sgl.function
33
+ def character_gen(s, name):
34
+ s += (
35
+ name
36
+ + " is a character in Harry Potter. Please fill in the following information about this character.\n"
37
+ )
38
+ s += "The constrained regex is:\n"
39
+ s += character_regex + "\n"
40
+ s += "The JSON output is:\n"
41
+ s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
42
+
43
+
44
+ def driver_character_gen():
45
+ state = character_gen.run(name="Hermione Granger")
46
+ print(state.text())
47
+
48
+
49
+ class Weapon(str, Enum):
50
+ sword = "sword"
51
+ axe = "axe"
52
+ mace = "mace"
53
+ spear = "spear"
54
+ bow = "bow"
55
+ crossbow = "crossbow"
56
+
57
+
58
+ class Wizard(BaseModel):
59
+ name: str
60
+ age: int
61
+ weapon: Weapon
62
+
63
+
64
+ @sgl.function
65
+ def pydantic_wizard_gen(s):
66
+ s += "Give me a description about a wizard in the JSON format.\n"
67
+ s += sgl.gen(
68
+ "character",
69
+ max_tokens=128,
70
+ temperature=0,
71
+ regex=build_regex_from_object(Wizard), # Requires pydantic >= 2.0
72
+ )
73
+
74
+
75
+ def driver_pydantic_wizard_gen():
76
+ state = pydantic_wizard_gen.run()
77
+ print(state.text())
78
+
79
+
80
+ if __name__ == "__main__":
81
+ sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
82
+ driver_character_gen()
83
+ # driver_pydantic_wizard_gen()
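The same constrained-decoding idea works for any JSON shape, whether the regex is written by hand like character_regex above or derived from a pydantic model via build_regex_from_object. A hand-written two-field variant, for illustration only (field names and bounds are arbitrary, and this is not the regex build_regex_from_object would emit):

import re

# A hand-written JSON-shaped regex in the same style as character_regex above.
pet_regex = (
    r"""\{\n"""
    + r"""    "name": "[\w\d\s]{1,16}",\n"""
    + r"""    "species": "(cat|dog|owl)"\n"""
    + r"""\}"""
)

sample = '{\n    "name": "Hedwig",\n    "species": "owl"\n}'
assert re.fullmatch(pet_regex, sample)  # any constrained output must match this shape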
sglang/examples/frontend_language/usage/json_logprobs.py ADDED
@@ -0,0 +1,104 @@
1
+ # NOTE: Currently this can only be run through HTTP requests.
2
+ import json
3
+ from concurrent.futures import ThreadPoolExecutor
4
+
5
+ from json_decode import character_regex
6
+
7
+ from sglang.utils import http_request
8
+
9
+ character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]
10
+
11
+ base_url = "http://localhost:30000"
12
+
13
+ prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"
14
+
15
+
16
+ def openai_api_request(name):
17
+ data = {
18
+ "model": "",
19
+ "prompt": name + prompt,
20
+ "temperature": 0,
21
+ "max_tokens": 128,
22
+ "regex": character_regex,
23
+ "logprobs": 3,
24
+ }
25
+ res = http_request(base_url + "/v1/completions", json=data).json()
26
+
27
+ # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
28
+ # fout.write(json.dumps(res, indent=4))
29
+
30
+ logprobs = res["choices"][0]["logprobs"]
31
+ usage = res["usage"]
32
+ assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
33
+ assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
34
+ assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1
35
+
36
+ return res
37
+
38
+
39
+ def srt_api_request(name):
40
+ data = {
41
+ "text": name + prompt,
42
+ "sampling_params": {
43
+ "temperature": 0,
44
+ "max_new_tokens": 128,
45
+ "regex": character_regex,
46
+ },
47
+ "return_logprob": True,
48
+ "logprob_start_len": 0,
49
+ "top_logprobs_num": 3,
50
+ "return_text_in_logprobs": True,
51
+ }
52
+
53
+ res = http_request(base_url + "/generate", json=data).json()
54
+
55
+ # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
56
+ # fout.write(json.dumps(res, indent=4))
57
+
58
+ meta_info = res["meta_info"]
59
+ assert len(meta_info["input_token_logprobs"]) == len(
60
+ meta_info["input_top_logprobs"]
61
+ )
62
+ assert len(meta_info["output_token_logprobs"]) == len(
63
+ meta_info["output_top_logprobs"]
64
+ )
65
+ assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
66
+ assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
67
+
68
+ return res
69
+
70
+
71
+ def pretty_print(res):
72
+ meta_info = res["meta_info"]
73
+
74
+ print("\n\n", "=" * 30, "Prefill", "=" * 30)
75
+ for i in range(len(meta_info["input_token_logprobs"])):
76
+ print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
77
+ top_ks = (
78
+ [str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
79
+ if meta_info["input_top_logprobs"][i]
80
+ else []
81
+ )
82
+ for top_k in top_ks:
83
+ print(f"{top_k: <15}", end="")
84
+ print()
85
+
86
+ print("\n\n", "=" * 30, "Decode", "=" * 30)
87
+ for i in range(len(meta_info["output_token_logprobs"])):
88
+ print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
89
+ top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
90
+ for top_k in top_ks:
91
+ print(f"{top_k: <15}", end="")
92
+ print()
93
+
94
+ print(res["text"])
95
+
96
+
97
+ if __name__ == "__main__":
98
+ with ThreadPoolExecutor() as executor:
99
+ ress = executor.map(srt_api_request, character_names)
100
+
101
+ for res in ress:
102
+ pretty_print(res)
103
+
104
+ openai_api_request("Hermione Granger")
sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.py ADDED
@@ -0,0 +1,260 @@
1
+ """
2
+ Usage:
3
+ pip install opencv-python-headless
4
+
5
+ python3 srt_example_llava_v.py
6
+ """
7
+
8
+ import argparse
9
+ import csv
10
+ import json
11
+ import os
12
+ import time
13
+
14
+ import requests
15
+
16
+ import sglang as sgl
17
+
18
+
19
+ @sgl.function
20
+ def video_qa(s, num_frames, video_path, question):
21
+ s += sgl.user(sgl.video(video_path, num_frames) + question)
22
+ s += sgl.assistant(sgl.gen("answer"))
23
+
24
+
25
+ def single(path, num_frames=16):
26
+ state = video_qa.run(
27
+ num_frames=num_frames,
28
+ video_path=path,
29
+ question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes",
30
+ temperature=0.0,
31
+ max_new_tokens=1024,
32
+ )
33
+ print(state["answer"], "\n")
34
+
35
+
36
+ def split_into_chunks(lst, num_chunks):
37
+ """Split a list into a specified number of chunks."""
38
+ # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible.
39
+ chunk_size = len(lst) // num_chunks
40
+
41
+ if chunk_size == 0:
42
+ chunk_size = len(lst)
43
+ # Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible.
44
+ chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
45
+ # Ensure we have exactly num_chunks chunks, even if some are empty
46
+ chunks.extend([[] for _ in range(num_chunks - len(chunks))])
47
+ return chunks
48
+
49
+
50
+ def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
51
+ csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
52
+ with open(csv_filename, "w", newline="") as csvfile:
53
+ writer = csv.writer(csvfile)
54
+ writer.writerow(["video_name", "answer"])
55
+ for video_path, state in zip(batch_video_files, states):
56
+ video_name = os.path.basename(video_path)
57
+ writer.writerow([video_name, state["answer"]])
58
+
59
+
60
+ def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
61
+ final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
62
+ with open(final_csv_filename, "w", newline="") as final_csvfile:
63
+ writer = csv.writer(final_csvfile)
64
+ writer.writerow(["video_name", "answer"])
65
+ for batch_idx in range(num_batches):
66
+ batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
67
+ with open(batch_csv_filename, "r") as batch_csvfile:
68
+ reader = csv.reader(batch_csvfile)
69
+ next(reader) # Skip header row
70
+ for row in reader:
71
+ writer.writerow(row)
72
+ os.remove(batch_csv_filename)
73
+
74
+
75
+ def find_video_files(video_dir):
76
+ # Check if the video_dir is actually a file
77
+ if os.path.isfile(video_dir):
78
+ # If it's a file, return it as a single-element list
79
+ return [video_dir]
80
+
81
+ # Original logic to find video files in a directory
82
+ video_files = []
83
+ for root, dirs, files in os.walk(video_dir):
84
+ for file in files:
85
+ if file.endswith((".mp4", ".avi", ".mov")):
86
+ video_files.append(os.path.join(root, file))
87
+ return video_files
88
+
89
+
90
+ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
91
+ video_files = find_video_files(video_dir)
92
+ chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
93
+ num_batches = 0
94
+
95
+ for i in range(0, len(chunked_video_files), batch_size):
96
+ batch_video_files = chunked_video_files[i : i + batch_size]
97
+ print(f"Processing batch of {len(batch_video_files)} video(s)...")
98
+
99
+ if not batch_video_files:
100
+ print("No video files found in the specified directory.")
101
+ return
102
+
103
+ batch_input = [
104
+ {
105
+ "num_frames": num_frames,
106
+ "video_path": video_path,
107
+ "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
108
+ }
109
+ for video_path in batch_video_files
110
+ ]
111
+
112
+ start_time = time.time()
113
+ states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
114
+ total_time = time.time() - start_time
115
+ average_time = total_time / len(batch_video_files)
116
+ print(
117
+ f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds"
118
+ )
119
+
120
+ save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
121
+ num_batches += 1
122
+
123
+ compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)
124
+
125
+
126
+ if __name__ == "__main__":
127
+
128
+ url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
129
+
130
+ cache_dir = os.path.expanduser("~/.cache")
131
+ file_path = os.path.join(cache_dir, "jobs.mp4")
132
+
133
+ os.makedirs(cache_dir, exist_ok=True)
134
+
135
+ response = requests.get(url)
136
+ response.raise_for_status() # Raise an exception for bad responses
137
+
138
+ with open(file_path, "wb") as f:
139
+ f.write(response.content)
140
+
141
+ print(f"File downloaded and saved to: {file_path}")
142
+ # Create the parser
143
+ parser = argparse.ArgumentParser(
144
+ description="Run video processing with specified port."
145
+ )
146
+
147
+ # Add an argument for the port
148
+ parser.add_argument(
149
+ "--port",
150
+ type=int,
151
+ default=30000,
152
+ help="The master port for distributed serving.",
153
+ )
154
+ parser.add_argument(
155
+ "--chunk-idx", type=int, default=0, help="The index of the chunk to process."
156
+ )
157
+ parser.add_argument(
158
+ "--num-chunks", type=int, default=8, help="The number of chunks to process."
159
+ )
160
+ parser.add_argument(
161
+ "--save-dir",
162
+ type=str,
163
+ default="./work_dirs/llava_video",
164
+ help="The directory to save the processed video files.",
165
+ )
166
+ parser.add_argument(
167
+ "--video-dir",
168
+ type=str,
169
+ default=os.path.expanduser("~/.cache/jobs.mp4"),
170
+ help="The directory or path for the processed video files.",
171
+ )
172
+ parser.add_argument(
173
+ "--model-path",
174
+ type=str,
175
+ default="lmms-lab/LLaVA-NeXT-Video-7B",
176
+ help="The model path for the video processing.",
177
+ )
178
+ parser.add_argument(
179
+ "--num-frames",
180
+ type=int,
181
+ default=16,
182
+ help="The number of frames to process in each video.",
183
+ )
184
+ parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)
185
+
186
+ # Parse the arguments
187
+ args = parser.parse_args()
188
+ cur_port = args.port
189
+ cur_chunk = args.chunk_idx
190
+ num_chunks = args.num_chunks
191
+ num_frames = args.num_frames
192
+
193
+ if "34b" in args.model_path.lower():
194
+ tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
195
+ elif "7b" in args.model_path.lower():
196
+ tokenizer_path = "llava-hf/llava-1.5-7b-hf"
197
+ else:
198
+ print("Invalid model path. Please specify a valid model path.")
199
+ exit()
200
+
201
+ model_override_args = {}
202
+ model_override_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
203
+ model_override_args["architectures"] = ["LlavaVidForCausalLM"]
204
+ model_override_args["num_frames"] = args.num_frames
205
+ model_override_args["model_type"] = "llava"
206
+
207
+ if "34b" in args.model_path.lower():
208
+ model_override_args["image_token_index"] = 64002
209
+
210
+ if args.num_frames == 32:
211
+ model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
212
+ model_override_args["max_sequence_length"] = 4096 * 2
213
+ model_override_args["tokenizer_model_max_length"] = 4096 * 2
214
+ elif args.num_frames < 32:
215
+ pass
216
+ else:
217
+ print(
218
+ "The maximum number of frames to process is 32. Please specify a valid number of frames."
219
+ )
220
+ exit()
221
+
222
+ runtime = sgl.Runtime(
223
+ model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b",
224
+ tokenizer_path=tokenizer_path,
225
+ port=cur_port,
226
+ json_model_override_args=json.dumps(model_override_args),
227
+ tp_size=1,
228
+ )
229
+ sgl.set_default_backend(runtime)
230
+ print(f"chat template: {runtime.endpoint.chat_template.name}")
231
+
232
+ # Run a single request
233
+ print("\n========== single ==========\n")
234
+ root = args.video_dir
235
+ if os.path.isfile(root):
236
+ video_files = [root]
237
+ else:
238
+ video_files = [
239
+ os.path.join(root, f)
240
+ for f in os.listdir(root)
241
+ if f.endswith((".mp4", ".avi", ".mov"))
242
+ ] # Add more extensions if needed
243
+ start_time = time.time() # Start time for processing a single video
244
+ for cur_video in video_files[:1]:
245
+ print(cur_video)
246
+ single(cur_video, num_frames)
247
+ end_time = time.time() # End time for processing a single video
248
+ total_time = end_time - start_time
249
+ average_time = total_time / len(
250
+ video_files
251
+ ) # Calculate the average processing time
252
+ print(f"Average processing time per video: {average_time:.2f} seconds")
253
+ runtime.shutdown()
254
+
255
+ # # Run a batch of requests
256
+ # print("\n========== batch ==========\n")
257
+ # if not os.path.exists(args.save_dir):
258
+ # os.makedirs(args.save_dir)
259
+ # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
260
+ # runtime.shutdown()
sglang/examples/frontend_language/usage/llava_video/srt_example_llava_v.sh ADDED
@@ -0,0 +1,131 @@
1
+ #!/bin/bash
2
+
3
+ ##### USAGE #####
4
+ # - First node:
5
+ # ```sh
6
+ # bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
7
+ # ```
8
+ # - Second node:
9
+ # ```sh
10
+ # bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
11
+ # ```
12
+ # - The K node:
13
+ # ```sh
14
+ # bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
15
+ # ```
16
+
17
+
18
+ # Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.
19
+ # CURRENT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
20
+ CURRENT_ROOT=$(dirname "$0")
21
+
22
+ echo ${CURRENT_ROOT}
23
+
24
+ cd ${CURRENT_ROOT}
25
+
26
+ export PYTHONWARNINGS=ignore
27
+
28
+ START_TIME=$(date +%s) # Capture start time
29
+
30
+ NUM_NODES=$1
31
+
32
+ CUR_NODES_IDX=$2
33
+
34
+ VIDEO_DIR=$3
35
+
36
+ MODEL_PATH=$4
37
+
38
+ NUM_FRAMES=$5
39
+
40
+
41
+ # FRAME_FORMAT=$6
42
+
43
+ # FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')
44
+
45
+ # # Check if FRAME_FORMAT is either JPEG or PNG
46
+ # if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
47
+ # echo "Error: FRAME_FORMAT must be either JPEG or PNG."
48
+ # exit 1
49
+ # fi
50
+
51
+ # export TARGET_FRAMES=$TARGET_FRAMES
52
+
53
+ echo "Each video you will sample $NUM_FRAMES frames"
54
+
55
+ # export FRAME_FORMAT=$FRAME_FORMAT
56
+
57
+ # echo "The frame format is $FRAME_FORMAT"
58
+
59
+ # Assuming GPULIST is a bash array containing your GPUs
60
+ GPULIST=(0 1 2 3 4 5 6 7)
61
+ LOCAL_CHUNKS=${#GPULIST[@]}
62
+
63
+ echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS"
64
+
65
+ ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))
66
+
67
+ # Calculate GPUs per chunk
68
+ GPUS_PER_CHUNK=1
69
+
70
+ echo $GPUS_PER_CHUNK
71
+
72
+ for IDX in $(seq 1 $LOCAL_CHUNKS); do
73
+ (
74
+ START=$(((IDX-1) * GPUS_PER_CHUNK))
75
+ LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index
76
+
77
+ CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})
78
+
79
+ # Convert the chunk GPUs array to a comma-separated string
80
+ CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")
81
+
82
+ LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))
83
+
84
+ echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"
85
+
86
+ # Pick a random port for this chunk to avoid collisions between parallel chunks.
87
+ PORT=$((10000 + RANDOM % 55536))
88
+
89
+ MAX_RETRIES=10
90
+ RETRY_COUNT=0
91
+ COMMAND_STATUS=1 # Initialize as failed
92
+
93
+ while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
94
+ echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"
95
+
96
+ #!/bin/bash
97
+ CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 srt_example_llava_v.py \
98
+ --port $PORT \
99
+ --num-chunks $ALL_CHUNKS \
100
+ --chunk-idx $(($LOCAL_IDX - 1)) \
101
+ --save-dir work_dirs/llava_next_video_inference_results \
102
+ --video-dir $VIDEO_DIR \
103
+ --model-path $MODEL_PATH \
104
+ --num-frames $NUM_FRAMES #&
105
+
106
+ wait $! # Wait for the process to finish and capture its exit status
107
+ COMMAND_STATUS=$?
108
+
109
+ if [ $COMMAND_STATUS -ne 0 ]; then
110
+ echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
111
+ RETRY_COUNT=$(($RETRY_COUNT + 1))
112
+ sleep 180 # Wait a bit before retrying
113
+ else
114
+ echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
115
+ fi
116
+ done
117
+
118
+ if [ $COMMAND_STATUS -ne 0 ]; then
119
+ echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
120
+ fi
121
+ ) #&
122
+ sleep 2 # Slight delay to stagger the start times
123
+ done
124
+
125
+ wait
126
+
127
+ cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv
128
+
129
+ END_TIME=$(date +%s) # Capture end time
130
+ ELAPSED_TIME=$(($END_TIME - $START_TIME))
131
+ echo "Total execution time: $ELAPSED_TIME seconds."
sglang/examples/frontend_language/usage/openai_chat_speculative.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ Usage:
3
+ ***Note: for speculative execution to work, the user must put all "gen" calls in a single "assistant" block.
4
+ Show the desired answer format in "assistant". Each "gen" term should have a stop token.
5
+ Stream mode is not supported in speculative execution.
6
+
7
+ E.g.
8
+ correct:
9
+ sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
10
+ incorrect:
11
+ s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
12
+ s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
13
+ s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
14
+
15
+ export OPENAI_API_KEY=sk-******
16
+ python3 openai_chat_speculative.py
17
+ """
18
+
19
+ import sglang as sgl
20
+ from sglang import OpenAI, function, set_default_backend
21
+
22
+
23
+ @function(num_api_spec_tokens=256)
24
+ def gen_character_spec(s):
25
+ s += sgl.system("You are a helpful assistant.")
26
+ s += sgl.user("Construct a character within the following format:")
27
+ s += sgl.assistant(
28
+ "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
29
+ )
30
+ s += sgl.user("Please generate new Name, Birthday and Job.\n")
31
+ s += sgl.assistant(
32
+ "Name:"
33
+ + sgl.gen("name", stop="\n")
34
+ + "\nBirthday:"
35
+ + sgl.gen("birthday", stop="\n")
36
+ + "\nJob:"
37
+ + sgl.gen("job", stop="\n")
38
+ )
39
+
40
+
41
+ @function(num_api_spec_tokens=256)
42
+ def gen_character_spec_no_few_shot(s):
43
+ s += sgl.user("Construct a character. For each field stop with a newline\n")
44
+ s += sgl.assistant(
45
+ "Name:"
46
+ + sgl.gen("name", stop="\n")
47
+ + "\nAge:"
48
+ + sgl.gen("age", stop="\n")
49
+ + "\nJob:"
50
+ + sgl.gen("job", stop="\n")
51
+ )
52
+
53
+
54
+ @function
55
+ def gen_character_normal(s):
56
+ s += sgl.system("You are a helpful assistant.")
57
+ s += sgl.user("What's the answer of 23 + 8?")
58
+ s += sgl.assistant(sgl.gen("answer", max_tokens=64))
59
+
60
+
61
+ @function(num_api_spec_tokens=1024)
62
+ def multi_turn_question(s, question_1, question_2):
63
+ s += sgl.system("You are a helpful assistant.")
64
+ s += sgl.user("Answer questions in the following format:")
65
+ s += sgl.user(
66
+ "Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n"
67
+ )
68
+ s += sgl.assistant(
69
+ "Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n"
70
+ )
71
+ s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
72
+ s += sgl.assistant(
73
+ "Answer 1: "
74
+ + sgl.gen("answer_1", stop="\n")
75
+ + "\nAnswer 2: "
76
+ + sgl.gen("answer_2", stop="\n")
77
+ )
78
+
79
+
80
+ def test_spec_single_turn():
81
+ backend.token_usage.reset()
82
+
83
+ state = gen_character_spec.run()
84
+ for m in state.messages():
85
+ print(m["role"], ":", m["content"])
86
+
87
+ print("\n-- name:", state["name"])
88
+ print("-- birthday:", state["birthday"])
89
+ print("-- job:", state["job"])
90
+ print(backend.token_usage)
91
+
92
+
93
+ def test_inaccurate_spec_single_turn():
94
+ state = gen_character_spec_no_few_shot.run()
95
+ for m in state.messages():
96
+ print(m["role"], ":", m["content"])
97
+
98
+ print("\n-- name:", state["name"])
99
+ print("\n-- age:", state["age"])
100
+ print("\n-- job:", state["job"])
101
+
102
+
103
+ def test_normal_single_turn():
104
+ state = gen_character_normal.run()
105
+ for m in state.messages():
106
+ print(m["role"], ":", m["content"])
107
+
108
+
109
+ def test_spec_multi_turn():
110
+ state = multi_turn_question.run(
111
+ question_1="What is the capital of the United States?",
112
+ question_2="List two local attractions in the capital of the United States.",
113
+ )
114
+
115
+ for m in state.messages():
116
+ print(m["role"], ":", m["content"])
117
+
118
+ print("\n-- answer_1 --\n", state["answer_1"])
119
+ print("\n-- answer_2 --\n", state["answer_2"])
120
+
121
+
122
+ def test_spec_multi_turn_stream():
123
+ state = multi_turn_question.run(
124
+ question_1="What is the capital of the United States?",
125
+ question_2="List two local attractions.",
126
+ stream=True,
127
+ )
128
+
129
+ for out in state.text_iter():
130
+ print(out, end="", flush=True)
131
+
132
+
133
+ if __name__ == "__main__":
134
+ backend = OpenAI("gpt-4-turbo")
135
+ set_default_backend(backend)
136
+
137
+ print("\n========== test spec single turn ==========\n")
138
+ # expect reasonable answer for each field
139
+ test_spec_single_turn()
140
+
141
+ print("\n========== test inaccurate spec single turn ==========\n")
142
+ # expect incomplete or unreasonable answers
143
+ test_inaccurate_spec_single_turn()
144
+
145
+ print("\n========== test normal single turn ==========\n")
146
+ # expect reasonable answer
147
+ test_normal_single_turn()
148
+
149
+ print("\n========== test spec multi turn ==========\n")
150
+ # expect answer with same format as in the few shot
151
+ test_spec_multi_turn()
152
+
153
+ print("\n========== test spec multi turn stream ==========\n")
154
+ # expect error in stream_executor: stream is not supported...
155
+ test_spec_multi_turn_stream()
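A note on the stop tokens required above: with `num_api_spec_tokens`, the backend requests one longer completion and then slices it back into the individual `gen` fields at their stop tokens, which is why the all-in-one `assistant` pattern is needed. The snippet below is only a hypothetical sketch of that kind of client-side slicing for the Name/Birthday/Job template, not the sglang implementation.

```python
# Hypothetical sketch: split a single speculative completion back into named fields.
# Assumes the completion follows the "Name: ...\nBirthday: ...\nJob: ..." template above.
def split_speculative(completion: str) -> dict:
    fields = {}
    for line in completion.strip().splitlines():
        if ":" in line:
            key, value = line.split(":", 1)
            fields[key.strip().lower()] = value.strip()
    return fields


print(split_speculative("Name: Ada Lovelace\nBirthday: December 10, 1815\nJob: Mathematician"))
# {'name': 'Ada Lovelace', 'birthday': 'December 10, 1815', 'job': 'Mathematician'}
```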
sglang/examples/frontend_language/usage/openai_speculative.py ADDED
@@ -0,0 +1,54 @@
1
+ """
2
+ Usage:
3
+ python3 openai_speculative.py
4
+ """
5
+
6
+ from sglang import OpenAI, function, gen, set_default_backend
7
+
8
+
9
+ @function(num_api_spec_tokens=64)
10
+ def gen_character_spec(s):
11
+ s += "Construct a character within the following format:\n"
12
+ s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
13
+ s += "\nPlease generate new Name, Birthday and Job.\n"
14
+ s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
15
+ s += "\nJob:" + gen("job", stop="\n") + "\n"
16
+
17
+
18
+ @function
19
+ def gen_character_no_spec(s):
20
+ s += "Construct a character within the following format:\n"
21
+ s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
22
+ s += "\nPlease generate new Name, Birthday and Job.\n"
23
+ s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
24
+ s += "\nJob:" + gen("job", stop="\n") + "\n"
25
+
26
+
27
+ @function(num_api_spec_tokens=64)
28
+ def gen_character_spec_no_few_shot(s):
29
+ # s += "Construct a character with name, birthday, and job:\n"
30
+ s += "Construct a character:\n"
31
+ s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
32
+ s += "\nJob:" + gen("job", stop="\n") + "\n"
33
+
34
+
35
+ if __name__ == "__main__":
36
+ backend = OpenAI("gpt-3.5-turbo-instruct")
37
+ set_default_backend(backend)
38
+
39
+ for function in [
40
+ gen_character_spec,
41
+ gen_character_no_spec,
42
+ gen_character_spec_no_few_shot,
43
+ ]:
44
+ backend.token_usage.reset()
45
+
46
+ print(f"function: {function.func.__name__}")
47
+
48
+ state = function.run()
49
+
50
+ print("...name:", state["name"])
51
+ print("...birthday:", state["birthday"])
52
+ print("...job:", state["job"])
53
+ print(backend.token_usage)
54
+ print()
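Besides the token usage printed above, wall-clock time is another simple way to compare the speculative and non-speculative variants. The helper below is only a sketch; it assumes the functions and backend defined in this file are already set up.

```python
# Sketch: time a run of one of the functions defined above (assumes the same backend setup).
import time


def timed_run(fn):
    start = time.perf_counter()
    state = fn.run()
    return state, time.perf_counter() - start


# Example usage inside the __main__ loop:
#   state, seconds = timed_run(gen_character_spec)
#   print(f"{gen_character_spec.func.__name__}: {seconds:.2f}s")
```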
sglang/examples/frontend_language/usage/parallel_sample.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Usage:
3
+ python3 parallel_sample.py
4
+ """
5
+
6
+ import sglang as sgl
7
+
8
+
9
+ @sgl.function
10
+ def parallel_sample(s, question, n):
11
+ s += (
12
+ "Question: Compute 1 + 2 + 3\n"
13
+ "Reasoning: I need to use a calculator.\n"
14
+ "Tool: calculator\n"
15
+ "Answer: 6\n"
16
+ "Question: Compute 3 + 2 + 2\n"
17
+ "Reasoning: I will try a calculator.\n"
18
+ "Tool: calculator\n"
19
+ "Answer: 7\n"
20
+ )
21
+ s += "Question: " + question + "\n"
22
+ forks = s.fork(n)
23
+ forks += "Reasoning:" + sgl.gen("reasoning", stop="\n") + "\n"
24
+ forks += "Tool:" + sgl.gen("tool", choices=["calculator", "browser"]) + "\n"
25
+ forks += "Answer:" + sgl.gen("answer", stop="\n") + "\n"
26
+ forks.join()
27
+
28
+
29
+ sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
30
+ # sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
31
+
32
+ state = parallel_sample.run(question="Compute 5 + 2 + 4.", n=5, temperature=1.0)
33
+
34
+ for i in range(5):
35
+ obj = {
36
+ "reasoning": state["reasoning"][i],
37
+ "tool": state["tool"][i],
38
+ "answer": state["answer"][i],
39
+ }
40
+ print(f"[{i}], {obj}")
sglang/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb ADDED
@@ -0,0 +1,408 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# RAG Powered by SGLang & Chroma Evaluated using Parea\n",
8
+ "\n",
9
+ "In this notebook, we will build a simple RAG pipeline using SGLang to execute our LLM calls, Chroma as vector database for retrieval and [Parea](https://www.parea.ai) for tracing and evaluation. We will then evaluate the performance of our RAG pipeline. The dataset we will use was created by [Virat](https://twitter.com/virattt) and contains 100 questions, contexts and answers from the Airbnb 2023 10k filing.\n",
10
+ "\n",
11
+ "The RAG pipeline consists of two steps:\n",
12
+ "1. Retrieval: Given a question, we retrieve the relevant context from all provided contexts.\n",
13
+ "2. Generation: Given the question and the retrieved context, we generate an answer.\n",
14
+ "\n",
15
+ "ℹ️ This notebook requires an OpenAI API key.\n",
16
+ "\n",
17
+ "ℹ️ This notebook requires a Parea API key, which can be created [here](https://docs.parea.ai/api-reference/authentication#parea-api-key)."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## Setting up the environment\n",
25
+ "\n",
26
+ "We will first install the necessary packages: `sglang`, `parea-ai` and `chromadb`."
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "# note, if you use a Mac M1 chip, you might need to install grpcio 1.59.0 first such that installing chromadb works\n",
36
+ "# !pip install grpcio==1.59.0\n",
37
+ "\n",
38
+ "!pip install sglang[openai] parea-ai chromadb"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {},
44
+ "source": [
45
+ "Create a Parea API key as outlined [here](https://docs.parea.ai/api-reference/authentication#parea-api-key) and save it in a `.env` file as `PAREA_API_KEY=your-api-key`."
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "metadata": {},
51
+ "source": [
52
+ "## Indexing the data\n",
53
+ "\n",
54
+ "Now it's time to download the data & index it! For that, we create a collection called `contexts` in Chroma and add the contexts as documents."
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "import json\n",
64
+ "import os\n",
65
+ "from typing import List\n",
66
+ "\n",
67
+ "import chromadb\n",
68
+ "\n",
69
+ "path_qca = \"airbnb-2023-10k-qca.json\"\n",
70
+ "\n",
71
+ "if not os.path.exists(path_qca):\n",
72
+ " !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n",
73
+ "\n",
74
+ "with open(path_qca, \"r\") as f:\n",
75
+ " question_context_answers = json.load(f)\n",
76
+ "\n",
77
+ "chroma_client = chromadb.PersistentClient()\n",
78
+ "collection = chroma_client.get_or_create_collection(name=\"contexts\")\n",
79
+ "if collection.count() == 0:\n",
80
+ " collection.add(\n",
81
+ " documents=[qca[\"context\"] for qca in question_context_answers],\n",
82
+ " ids=[str(i) for i in range(len(question_context_answers))],\n",
83
+ " )"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## Defining the RAG pipeline\n",
91
+ "\n",
92
+ "We will start with importing the necessary packages, setting up tracing of OpenAI calls via Parea and setting OpenAI as the default backend for SGLang."
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "import os\n",
102
+ "import time\n",
103
+ "\n",
104
+ "from dotenv import load_dotenv\n",
105
+ "\n",
106
+ "from sglang import function, user, assistant, gen, set_default_backend, OpenAI\n",
107
+ "from sglang.lang.interpreter import ProgramState\n",
108
+ "from parea import Parea, trace\n",
109
+ "\n",
110
+ "\n",
111
+ "load_dotenv()\n",
112
+ "\n",
113
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
114
+ "\n",
115
+ "p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n",
116
+ "p.integrate_with_sglang()\n",
117
+ "\n",
118
+ "set_default_backend(OpenAI(\"gpt-3.5-turbo\"))"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {},
124
+ "source": [
125
+ "Now we can define our retrieval step shown below. Notice, the `trace` decorator which will automatically trace inputs, output, latency, etc. of that call."
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "@trace\n",
135
+ "def retrieval(question: str) -> List[str]:\n",
136
+ " return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "Next we will define the generation step which uses SGLang to execute the LLM call."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "@function\n",
153
+ "def generation_sglang(s, question: str, *context: str):\n",
154
+ " context = \"\\n\".join(context)\n",
155
+ " s += user(\n",
156
+ " f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
157
+ " )\n",
158
+ " s += assistant(gen(\"answer\"))\n",
159
+ "\n",
160
+ "\n",
161
+ "@trace\n",
162
+ "def generation(question: str, *context):\n",
163
+ " state: ProgramState = generation_sglang.run(question, *context)\n",
164
+ " while not state.stream_executor.is_finished:\n",
165
+ " time.sleep(1)\n",
166
+ " return state.stream_executor.variables[\"answer\"]"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {},
172
+ "source": [
173
+ "Finally, we can tie it together and execute a sample query."
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "@trace\n",
183
+ "def rag_pipeline(question: str) -> str:\n",
184
+ " contexts = retrieval(question)\n",
185
+ " return generation(question, *contexts)\n",
186
+ "\n",
187
+ "\n",
188
+ "rag_pipeline(\n",
189
+ " \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
190
+ ")"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "metadata": {},
196
+ "source": [
197
+ "## Debug Trace\n",
198
+ "\n",
199
+ "The output is unfortunately wrong! Using the traced pipeline, we can see that\n",
200
+ "\n",
201
+ "- the context is relevant to the question and contains the correct information\n",
202
+ "- but, the generation step is cut off as max tokens is set to 16\n",
203
+ "\n",
204
+ "When opening the generation step in the playground and rerunning the prompt with max. tokens set to 1000, the correct answer is produced.\n",
205
+ "\n",
206
+ "![RAG Trace](https://drive.google.com/uc?id=1QI243ogGjzbO01tUrR72g9rFoGzUJqVH)"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "markdown",
211
+ "metadata": {},
212
+ "source": [
213
+ "## Evaluating RAG Pipelines\n",
214
+ "\n",
215
+ "Before we apply above's fix, let's dive into evaluating RAG pipelines.\n",
216
+ "\n",
217
+ "RAG pipelines consist of a retrieval step to fetch relevant information and a generation step to generate a response to a users question. A RAG pipeline can fail at either step. E.g. the retrieval step can fail to find relevant information which makes generating the correct impossible. Another failure mode is that the generation step doesn't leverage the retrieved information correctly. We will apply the following evaluation metrics to understand different failure modes:\n",
218
+ "\n",
219
+ "- `context_relevancy`: measures how relevant the context is given the question\n",
220
+ "- `percent_target_supported_by_context`: measures how much of the target answer is supported by the context; this will give an upper ceiling of how well the generation step can perform\n",
221
+ "- `answer_context_faithfulness`: measures how much the generated answer utilizes the context\n",
222
+ "- `answer_matches_target`: measures how well the generated answer matches the target answer judged by a LLM and gives a sense of accuracy of our entire pipeline\n",
223
+ "\n",
224
+ "To use these evaluation metrics, we can import them from `parea.evals.rag` and `parea.evals.general` and apply them to a function by specifying in the `trace` decorator which evaluation metrics to use. The `@trace` decorator will automatically log the results of the evaluation metrics to the Parea dashboard.\n",
225
+ "\n",
226
+ "Applying them to the retrieval step:"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "from parea.evals.rag import (\n",
236
+ " context_query_relevancy_factory,\n",
237
+ " percent_target_supported_by_context_factory,\n",
238
+ ")\n",
239
+ "\n",
240
+ "\n",
241
+ "context_relevancy_eval = context_query_relevancy_factory()\n",
242
+ "percent_target_supported_by_context = percent_target_supported_by_context_factory()\n",
243
+ "\n",
244
+ "\n",
245
+ "@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
246
+ "def retrieval(question: str) -> List[str]:\n",
247
+ " return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "Now we can apply `answer_context_faithfulness` and `answer_matches_target` to the generation step."
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "from parea.evals.general import answer_matches_target_llm_grader_factory\n",
264
+ "from parea.evals.rag import answer_context_faithfulness_statement_level_factory\n",
265
+ "\n",
266
+ "\n",
267
+ "answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n",
268
+ "answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n",
269
+ "\n",
270
+ "\n",
271
+ "@function\n",
272
+ "def generation_sglang(s, question: str, *context: str):\n",
273
+ " context = \"\\n\".join(context)\n",
274
+ " s += user(\n",
275
+ " f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
276
+ " )\n",
277
+ " s += assistant(gen(\"answer\", max_tokens=1_000))\n",
278
+ "\n",
279
+ "\n",
280
+ "@trace(eval_funcs=[answer_context_faithfulness, answer_matches_target_llm_grader])\n",
281
+ "def generation(question: str, *context):\n",
282
+ " state: ProgramState = generation_sglang.run(question, *context)\n",
283
+ " while not state.stream_executor.is_finished:\n",
284
+ " time.sleep(1)\n",
285
+ " return state.stream_executor.variables[\"answer\"]"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "metadata": {},
291
+ "source": [
292
+ "Finally, we tie them together & execute the original sample query."
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": null,
298
+ "metadata": {},
299
+ "outputs": [],
300
+ "source": [
301
+ "@trace\n",
302
+ "def rag_pipeline(question: str) -> str:\n",
303
+ " contexts = retrieval(question)\n",
304
+ " return generation(question, *contexts)\n",
305
+ "\n",
306
+ "\n",
307
+ "rag_pipeline(\n",
308
+ " \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
309
+ ")"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "markdown",
314
+ "metadata": {},
315
+ "source": [
316
+ "Great, the answer is correct! Can you spot the line where we fixed the output truncation issue?\n",
317
+ "\n",
318
+ "The evaluation scores appear in the bottom right of the logs (screenshot below). Note, that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n",
319
+ "\n",
320
+ "![Fixed Max. Tokens](max-tokens-fixed-rag-trace.png)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "markdown",
325
+ "metadata": {},
326
+ "source": [
327
+ "## Running an experiment\n",
328
+ "\n",
329
+ "Now we are (almost) ready to evaluate the performance of our RAG pipeline on the entire dataset. First, we will need to apply the `nest_asyncio` package to avoid issues with the Jupyter notebook event loop."
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": null,
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "!pip install nest-asyncio\n",
339
+ "import nest_asyncio\n",
340
+ "\n",
341
+ "nest_asyncio.apply()"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "markdown",
346
+ "metadata": {},
347
+ "source": [
348
+ "Running the actual experiment is straight-forward. For that we use `p.experiment` to initialize the experiment with a name, the data (list of key-value pairs fed into our entry function) and the entry function. We then call `run` on the experiment to execute it. Note, that `target` is a reserved key in the data dictionary and will be used as the target answer for evaluation."
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "metadata": {},
355
+ "outputs": [],
356
+ "source": [
357
+ "e = p.experiment(\n",
358
+ " \"RAG\",\n",
359
+ " data=[\n",
360
+ " {\n",
361
+ " \"question\": qca[\"question\"],\n",
362
+ " \"target\": qca[\"answer\"],\n",
363
+ " }\n",
364
+ " for qca in question_context_answers\n",
365
+ " ],\n",
366
+ " func=rag_pipeline,\n",
367
+ ").run()"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "markdown",
372
+ "metadata": {},
373
+ "source": [
374
+ "## Analyzing the results\n",
375
+ "\n",
376
+ "When opening above experiment, we will see an overview of the experiment as shown below. The upper half shows a summary of the statistics on the left and charts to investigate the distribution and relationships of scores on the right. The lower half is a table with the individual traces which we can use to debug individual samples.\n",
377
+ "\n",
378
+ "When looking at the statistics, we can see that the accuracy of our RAG pipeline is 22% as measured by `answer_matches_target_llm_grader`. Though when checking the quality of our retrieval step (`context_query_relevancy`), we can see that our retrival step is fetching relevant information in only 27% of all samples. As shown in the GIF, we investigate the relationship between the two and see the two scores have 95% agreement. This confirms that the retrieval step is a major bottleneck for our RAG pipeline. So, now it's your turn to improve the retrieval step!\n",
379
+ "\n",
380
+ "Note, above link isn't publicly accessible but the experiment can be accessed through [here](https://app.parea.ai/public-experiments/parea/rag_sglang/30f0244a-d56c-44ff-bdfb-8f47626304b6).\n",
381
+ "\n",
382
+ "![Experiment Results](https://drive.google.com/uc?id=1KMtJBU47nPB02Pvv3SPPTK7RnHRh5YdA)"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": []
391
+ }
392
+ ],
393
+ "metadata": {
394
+ "language_info": {
395
+ "codemirror_mode": {
396
+ "name": "ipython",
397
+ "version": 2
398
+ },
399
+ "file_extension": ".py",
400
+ "mimetype": "text/x-python",
401
+ "name": "python",
402
+ "nbconvert_exporter": "python",
403
+ "pygments_lexer": "ipython2"
404
+ }
405
+ },
406
+ "nbformat": 4,
407
+ "nbformat_minor": 0
408
+ }
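One simple lever for the retrieval step the notebook ends on is to return more than one context per question. The standalone sketch below uses a throwaway in-memory Chroma collection purely to illustrate the `n_results` knob; in the notebook one would change `n_results=1` to a larger value inside `retrieval()` and re-run the experiment to see whether it helps.

```python
# Standalone sketch of top-k retrieval with Chroma (demo collection, not the notebook's data).
from typing import List

import chromadb

client = chromadb.Client()  # in-memory client for illustration
demo = client.get_or_create_collection(name="contexts_demo")
demo.add(
    documents=["Context about topic A.", "Context about topic B.", "Context about topic C."],
    ids=["0", "1", "2"],
)


def retrieval_top_k(question: str, k: int = 3) -> List[str]:
    return demo.query(query_texts=[question], n_results=k)["documents"][0]


print(retrieval_top_k("Tell me about topic B", k=2))
```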
sglang/examples/frontend_language/usage/readme_examples.py ADDED
@@ -0,0 +1,109 @@
1
+ """
2
+ Usage:
3
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
4
+ python readme_examples.py
5
+ """
6
+
7
+ import sglang as sgl
8
+
9
+
10
+ @sgl.function
11
+ def tool_use(s, question):
12
+ s += "To answer this question: " + question + ". "
13
+ s += (
14
+ "I need to use a "
15
+ + sgl.gen("tool", choices=["calculator", "search engine"])
16
+ + ". "
17
+ )
18
+
19
+ if s["tool"] == "calculator":
20
+ s += "The math expression is" + sgl.gen("expression")
21
+ elif s["tool"] == "search engine":
22
+ s += "The key word to search is" + sgl.gen("word")
23
+
24
+
25
+ @sgl.function
26
+ def tip_suggestion(s):
27
+ s += (
28
+ "Here are two tips for staying healthy: "
29
+ "1. Balanced Diet. 2. Regular Exercise.\n\n"
30
+ )
31
+
32
+ forks = s.fork(2)
33
+ for i, f in enumerate(forks):
34
+ f += f"Now, expand tip {i+1} into a paragraph:\n"
35
+ f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
36
+
37
+ s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
38
+ s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
39
+ s += "In summary" + sgl.gen("summary")
40
+
41
+
42
+ @sgl.function
43
+ def regular_expression_gen(s):
44
+ s += "Q: What is the IP address of the Google DNS servers?\n"
45
+ s += "A: " + sgl.gen(
46
+ "answer",
47
+ temperature=0,
48
+ regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
49
+ )
50
+
51
+
52
+ @sgl.function
53
+ def text_qa(s, question):
54
+ s += "Q: " + question + "\n"
55
+ s += "A:" + sgl.gen("answer", stop="\n")
56
+
57
+
58
+ def driver_tool_use():
59
+ state = tool_use.run(question="What is the capital of the United States?")
60
+ print(state.text())
61
+ print("\n")
62
+
63
+
64
+ def driver_tip_suggestion():
65
+ state = tip_suggestion.run()
66
+ print(state.text())
67
+ print("\n")
68
+
69
+
70
+ def driver_regex():
71
+ state = regular_expression_gen.run()
72
+ print(state.text())
73
+ print("\n")
74
+
75
+
76
+ def driver_batching():
77
+ states = text_qa.run_batch(
78
+ [
79
+ {"question": "What is the capital of the United Kingdom?"},
80
+ {"question": "What is the capital of France?"},
81
+ {"question": "What is the capital of Japan?"},
82
+ ],
83
+ progress_bar=True,
84
+ )
85
+
86
+ for s in states:
87
+ print(s.text())
88
+ print("\n")
89
+
90
+
91
+ def driver_stream():
92
+ state = text_qa.run(
93
+ question="What is the capital of France?", temperature=0.1, stream=True
94
+ )
95
+
96
+ for out in state.text_iter():
97
+ print(out, end="", flush=True)
98
+ print("\n")
99
+
100
+
101
+ if __name__ == "__main__":
102
+ # sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
103
+ sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
104
+
105
+ driver_tool_use()
106
+ driver_tip_suggestion()
107
+ driver_regex()
108
+ driver_batching()
109
+ driver_stream()
sglang/examples/frontend_language/usage/sgl_gen_min_tokens.py ADDED
@@ -0,0 +1,35 @@
1
+ """
2
+ This example demonstrates how to use `min_tokens` to enforce sgl.gen to generate a longer sequence
3
+
4
+ Usage:
5
+ python3 sgl_gen_min_tokens.py
6
+ """
7
+
8
+ import sglang as sgl
9
+
10
+
11
+ @sgl.function
12
+ def long_answer(s):
13
+ s += sgl.user("What is the capital of the United States?")
14
+ s += sgl.assistant(sgl.gen("answer", min_tokens=64, max_tokens=128))
15
+
16
+
17
+ @sgl.function
18
+ def short_answer(s):
19
+ s += sgl.user("What is the capital of the United States?")
20
+ s += sgl.assistant(sgl.gen("answer"))
21
+
22
+
23
+ if __name__ == "__main__":
24
+ runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
25
+ sgl.set_default_backend(runtime)
26
+
27
+ state = long_answer.run()
28
+ print("=" * 20)
29
+ print("Longer Answer", state["answer"])
30
+
31
+ state = short_answer.run()
32
+ print("=" * 20)
33
+ print("Short Answer", state["answer"])
34
+
35
+ runtime.shutdown()
sglang/examples/frontend_language/usage/streaming.py ADDED
@@ -0,0 +1,49 @@
1
+ """
2
+ Usage:
3
+ python3 streaming.py
4
+ """
5
+
6
+ import asyncio
7
+
8
+ import sglang as sgl
9
+
10
+
11
+ @sgl.function
12
+ def multi_turn_question(s, question_1, question_2):
13
+ s += sgl.system("You are a helpful assistant.")
14
+ s += sgl.user(question_1)
15
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
16
+ s += sgl.user(question_2)
17
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
18
+
19
+
20
+ sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
21
+
22
+
23
+ def stream_a_variable():
24
+ state = multi_turn_question.run(
25
+ question_1="What is the capital of the United States?",
26
+ question_2="List two local attractions.",
27
+ stream=True,
28
+ )
29
+
30
+ for out in state.text_iter(var_name="answer_2"):
31
+ print(out, end="", flush=True)
32
+ print("\n")
33
+
34
+
35
+ async def async_stream():
36
+ state = multi_turn_question.run(
37
+ question_1="What is the capital of the United States?",
38
+ question_2="List two local attractions.",
39
+ stream=True,
40
+ )
41
+
42
+ async for out in state.text_async_iter(var_name="answer_2"):
43
+ print(out, end="", flush=True)
44
+ print("\n")
45
+
46
+
47
+ if __name__ == "__main__":
48
+ stream_a_variable()
49
+ asyncio.run(async_stream())
sglang/examples/frontend_language/usage/triton/Dockerfile ADDED
@@ -0,0 +1,10 @@
1
+ FROM nvcr.io/nvidia/tritonserver:24.01-py3
2
+
3
+ WORKDIR /opt
4
+
5
+ RUN git clone https://github.com/sgl-project/sglang.git
6
+
7
+ WORKDIR /opt/sglang
8
+ RUN pip install --upgrade pip && \
9
+ pip install -e "python[all]" && \
10
+ pip install datasets
sglang/examples/frontend_language/usage/triton/README.md ADDED
@@ -0,0 +1,35 @@
1
+ # sglang_triton
2
+
3
+ Build the docker image:
4
+ ```
5
+ docker build -t sglang-triton .
6
+ ```
7
+
8
+ Then run the container:
9
+ ```
10
+ docker run -ti --gpus=all --network=host --name sglang-triton -v ./models:/mnt/models sglang-triton
11
+ ```
12
+
13
+ Inside the docker container, start the SGLang server:
14
+ ```
15
+ cd sglang
16
+ python3 -m sglang.launch_server --model-path mistralai/Mistral-7B-Instruct-v0.2 --port 30000 --mem-fraction-static 0.9
17
+ ```
18
+
19
+ In another shell inside the docker container, start the Triton server:
20
+ ```
21
+ docker exec -ti sglang-triton /bin/bash
22
+ cd /mnt
23
+ tritonserver --model-repository=/mnt/models
24
+ ```
25
+
26
+
27
+ Send request to the server:
28
+ ```
29
+ curl -X POST http://localhost:8000/v2/models/character_generation/generate \
30
+ -H "Content-Type: application/json" \
31
+ -d '{
32
+ "INPUT_TEXT": ["harry"]
33
+ }'
34
+
35
+ ```
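The same generate endpoint can also be called from Python. A short sketch using the `requests` library with the endpoint and payload from the curl command above (assumes the Triton server from this README is running locally):

```python
# Sketch: call the Triton generate endpoint shown above via HTTP.
import requests

resp = requests.post(
    "http://localhost:8000/v2/models/character_generation/generate",
    json={"INPUT_TEXT": ["harry"]},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())  # the response should contain the OUTPUT_TEXT produced by the model
```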
sglang/examples/frontend_language/usage/triton/models/character_generation/1/model.py ADDED
@@ -0,0 +1,55 @@
1
+ import numpy
2
+ import triton_python_backend_utils as pb_utils
3
+ from pydantic import BaseModel
4
+
5
+ import sglang as sgl
6
+ from sglang import function, set_default_backend
7
+ from sglang.srt.constrained import build_regex_from_object
8
+
9
+ sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
10
+
11
+
12
+ class Character(BaseModel):
13
+ name: str
14
+ eye_color: str
15
+ house: str
16
+
17
+
18
+ @function
19
+ def character_gen(s, name):
20
+ s += (
21
+ name
22
+ + " is a character in Harry Potter. Please fill in the following information about this character.\n"
23
+ )
24
+ s += sgl.gen(
25
+ "json_output", max_tokens=256, regex=build_regex_from_object(Character)
26
+ )
27
+
28
+
29
+ class TritonPythonModel:
30
+ def initialize(self, args):
31
+ print("Initialized.")
32
+
33
+ def execute(self, requests):
34
+ responses = []
35
+ for request in requests:
36
+ tensor_in = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
37
+ if tensor_in is None:
38
+ return pb_utils.InferenceResponse(output_tensors=[])
39
+
40
+ input_list_names = [
41
+ i.decode("utf-8") if isinstance(i, bytes) else i
42
+ for i in tensor_in.as_numpy().tolist()
43
+ ]
44
+
45
+ input_list_dicts = [{"name": i} for i in input_list_names]
46
+
47
+ states = character_gen.run_batch(input_list_dicts)
48
+ character_strs = [state.text() for state in states]
49
+
50
+ tensor_out = pb_utils.Tensor(
51
+ "OUTPUT_TEXT", numpy.array(character_strs, dtype=object)
52
+ )
53
+
54
+ responses.append(pb_utils.InferenceResponse(output_tensors=[tensor_out]))
55
+ return responses
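Because `build_regex_from_object(Character)` constrains `json_output` to JSON matching the `Character` schema, a caller can load the generated string back into the pydantic model. A minimal sketch with a made-up sample string:

```python
# Sketch: validate regex-constrained output against the Character schema (sample string is made up).
import json

from pydantic import BaseModel


class Character(BaseModel):
    name: str
    eye_color: str
    house: str


raw = '{"name": "Harry Potter", "eye_color": "green", "house": "Gryffindor"}'
character = Character(**json.loads(raw))
print(character)
```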
sglang/examples/frontend_language/usage/triton/models/character_generation/config.pbtxt ADDED
@@ -0,0 +1,23 @@
1
+ name: "character_generation"
2
+ backend: "python"
3
+ input [
4
+ {
5
+ name: "INPUT_TEXT"
6
+ data_type: TYPE_STRING
7
+ dims: [ -1 ]
8
+ }
9
+ ]
10
+ output [
11
+ {
12
+ name: "OUTPUT_TEXT"
13
+ data_type: TYPE_STRING
14
+ dims: [ -1 ]
15
+ }
16
+ ]
17
+ instance_group [
18
+ {
19
+ count: 1
20
+ kind: KIND_GPU
21
+ gpus: [ 0 ]
22
+ }
23
+ ]
sglang/examples/monitoring/grafana.json ADDED
@@ -0,0 +1,1720 @@
1
+ {
2
+ "annotations": {
3
+ "list": [
4
+ {
5
+ "builtIn": 1,
6
+ "datasource": {
7
+ "type": "grafana",
8
+ "uid": "-- Grafana --"
9
+ },
10
+ "enable": true,
11
+ "hide": true,
12
+ "iconColor": "rgba(0, 211, 255, 1)",
13
+ "name": "Annotations & Alerts",
14
+ "type": "dashboard"
15
+ }
16
+ ]
17
+ },
18
+ "editable": true,
19
+ "fiscalYearStartMonth": 0,
20
+ "graphTooltip": 0,
21
+ "id": 1,
22
+ "links": [],
23
+ "panels": [
24
+ {
25
+ "datasource": {
26
+ "default": true,
27
+ "type": "prometheus",
28
+ "uid": "ee2vha8w6f5kwf"
29
+ },
30
+ "description": "max-running-requests from server argument",
31
+ "fieldConfig": {
32
+ "defaults": {
33
+ "color": {
34
+ "mode": "thresholds"
35
+ },
36
+ "mappings": [],
37
+ "thresholds": {
38
+ "mode": "absolute",
39
+ "steps": [
40
+ {
41
+ "color": "green",
42
+ "value": null
43
+ }
44
+ ]
45
+ }
46
+ },
47
+ "overrides": []
48
+ },
49
+ "gridPos": {
50
+ "h": 3,
51
+ "w": 3,
52
+ "x": 0,
53
+ "y": 0
54
+ },
55
+ "id": 2,
56
+ "options": {
57
+ "colorMode": "value",
58
+ "graphMode": "none",
59
+ "justifyMode": "auto",
60
+ "orientation": "auto",
61
+ "percentChangeColorMode": "standard",
62
+ "reduceOptions": {
63
+ "calcs": [
64
+ "last"
65
+ ],
66
+ "fields": "",
67
+ "values": false
68
+ },
69
+ "showPercentChange": false,
70
+ "textMode": "auto",
71
+ "wideLayout": true
72
+ },
73
+ "pluginVersion": "11.2.0",
74
+ "targets": [
75
+ {
76
+ "datasource": {
77
+ "type": "prometheus",
78
+ "uid": "ddyfngn31dg5cf"
79
+ },
80
+ "disableTextWrap": false,
81
+ "editorMode": "builder",
82
+ "expr": "sglang:max_running_requests{name=\"$name\", instance=\"$instance\"}",
83
+ "fullMetaSearch": false,
84
+ "includeNullMetadata": true,
85
+ "instant": false,
86
+ "legendFormat": "__auto",
87
+ "range": true,
88
+ "refId": "A",
89
+ "useBackend": false
90
+ }
91
+ ],
92
+ "title": "Max Running Requests",
93
+ "type": "stat"
94
+ },
95
+ {
96
+ "datasource": {
97
+ "default": true,
98
+ "type": "prometheus",
99
+ "uid": "ee2vha8w6f5kwf"
100
+ },
101
+ "description": "Supported context length with loaded model",
102
+ "fieldConfig": {
103
+ "defaults": {
104
+ "color": {
105
+ "mode": "thresholds"
106
+ },
107
+ "mappings": [],
108
+ "thresholds": {
109
+ "mode": "absolute",
110
+ "steps": [
111
+ {
112
+ "color": "green",
113
+ "value": null
114
+ }
115
+ ]
116
+ }
117
+ },
118
+ "overrides": []
119
+ },
120
+ "gridPos": {
121
+ "h": 3,
122
+ "w": 3,
123
+ "x": 3,
124
+ "y": 0
125
+ },
126
+ "id": 1,
127
+ "options": {
128
+ "colorMode": "value",
129
+ "graphMode": "none",
130
+ "justifyMode": "auto",
131
+ "orientation": "auto",
132
+ "percentChangeColorMode": "standard",
133
+ "reduceOptions": {
134
+ "calcs": [
135
+ "last"
136
+ ],
137
+ "fields": "",
138
+ "values": false
139
+ },
140
+ "showPercentChange": false,
141
+ "textMode": "auto",
142
+ "wideLayout": true
143
+ },
144
+ "pluginVersion": "11.2.0",
145
+ "targets": [
146
+ {
147
+ "datasource": {
148
+ "type": "prometheus",
149
+ "uid": "ddyfngn31dg5cf"
150
+ },
151
+ "disableTextWrap": false,
152
+ "editorMode": "builder",
153
+ "expr": "sglang:context_len{instance=\"$instance\", name=\"$name\"}",
154
+ "fullMetaSearch": false,
155
+ "includeNullMetadata": true,
156
+ "instant": false,
157
+ "legendFormat": "__auto",
158
+ "range": true,
159
+ "refId": "A",
160
+ "useBackend": false
161
+ }
162
+ ],
163
+ "title": "Max Context Length",
164
+ "type": "stat"
165
+ },
166
+ {
167
+ "datasource": {
168
+ "default": true,
169
+ "type": "prometheus",
170
+ "uid": "ee2vha8w6f5kwf"
171
+ },
172
+ "description": "max_total_tokens",
173
+ "fieldConfig": {
174
+ "defaults": {
175
+ "color": {
176
+ "mode": "thresholds"
177
+ },
178
+ "mappings": [],
179
+ "thresholds": {
180
+ "mode": "absolute",
181
+ "steps": [
182
+ {
183
+ "color": "green",
184
+ "value": null
185
+ }
186
+ ]
187
+ }
188
+ },
189
+ "overrides": []
190
+ },
191
+ "gridPos": {
192
+ "h": 3,
193
+ "w": 3,
194
+ "x": 6,
195
+ "y": 0
196
+ },
197
+ "id": 4,
198
+ "options": {
199
+ "colorMode": "value",
200
+ "graphMode": "none",
201
+ "justifyMode": "auto",
202
+ "orientation": "auto",
203
+ "percentChangeColorMode": "standard",
204
+ "reduceOptions": {
205
+ "calcs": [
206
+ "last"
207
+ ],
208
+ "fields": "",
209
+ "values": false
210
+ },
211
+ "showPercentChange": false,
212
+ "textMode": "auto",
213
+ "wideLayout": true
214
+ },
215
+ "pluginVersion": "11.2.0",
216
+ "targets": [
217
+ {
218
+ "datasource": {
219
+ "type": "prometheus",
220
+ "uid": "ddyfngn31dg5cf"
221
+ },
222
+ "disableTextWrap": false,
223
+ "editorMode": "builder",
224
+ "expr": "sglang:max_total_num_tokens{instance=\"$instance\", name=\"$name\"}",
225
+ "fullMetaSearch": false,
226
+ "includeNullMetadata": true,
227
+ "instant": false,
228
+ "legendFormat": "__auto",
229
+ "range": true,
230
+ "refId": "A",
231
+ "useBackend": false
232
+ }
233
+ ],
234
+ "title": "Max Total Num Tokens",
235
+ "type": "stat"
236
+ },
237
+ {
238
+ "datasource": {
239
+ "default": true,
240
+ "type": "prometheus",
241
+ "uid": "ee2vha8w6f5kwf"
242
+ },
243
+ "description": "max_prefill_tokens from server args",
244
+ "fieldConfig": {
245
+ "defaults": {
246
+ "color": {
247
+ "mode": "thresholds"
248
+ },
249
+ "mappings": [],
250
+ "thresholds": {
251
+ "mode": "absolute",
252
+ "steps": [
253
+ {
254
+ "color": "green",
255
+ "value": null
256
+ }
257
+ ]
258
+ }
259
+ },
260
+ "overrides": []
261
+ },
262
+ "gridPos": {
263
+ "h": 3,
264
+ "w": 3,
265
+ "x": 9,
266
+ "y": 0
267
+ },
268
+ "id": 3,
269
+ "options": {
270
+ "colorMode": "value",
271
+ "graphMode": "none",
272
+ "justifyMode": "auto",
273
+ "orientation": "auto",
274
+ "percentChangeColorMode": "standard",
275
+ "reduceOptions": {
276
+ "calcs": [
277
+ "last"
278
+ ],
279
+ "fields": "",
280
+ "values": false
281
+ },
282
+ "showPercentChange": false,
283
+ "textMode": "auto",
284
+ "wideLayout": true
285
+ },
286
+ "pluginVersion": "11.2.0",
287
+ "targets": [
288
+ {
289
+ "datasource": {
290
+ "type": "prometheus",
291
+ "uid": "ddyfngn31dg5cf"
292
+ },
293
+ "disableTextWrap": false,
294
+ "editorMode": "code",
295
+ "expr": "sglang:max_prefill_tokens{instance=\"$instance\", name=\"$name\"}",
296
+ "fullMetaSearch": false,
297
+ "includeNullMetadata": true,
298
+ "instant": false,
299
+ "legendFormat": "__auto",
300
+ "range": true,
301
+ "refId": "A",
302
+ "useBackend": false
303
+ }
304
+ ],
305
+ "title": "Max Prefill Tokens",
306
+ "type": "stat"
307
+ },
308
+ {
309
+ "datasource": {
310
+ "default": true,
311
+ "type": "prometheus",
312
+ "uid": "ee2vha8w6f5kwf"
313
+ },
314
+ "fieldConfig": {
315
+ "defaults": {
316
+ "color": {
317
+ "mode": "thresholds"
318
+ },
319
+ "mappings": [],
320
+ "thresholds": {
321
+ "mode": "absolute",
322
+ "steps": [
323
+ {
324
+ "color": "green",
325
+ "value": null
326
+ }
327
+ ]
328
+ }
329
+ },
330
+ "overrides": []
331
+ },
332
+ "gridPos": {
333
+ "h": 3,
334
+ "w": 6,
335
+ "x": 12,
336
+ "y": 0
337
+ },
338
+ "id": 6,
339
+ "options": {
340
+ "colorMode": "value",
341
+ "graphMode": "area",
342
+ "justifyMode": "auto",
343
+ "orientation": "auto",
344
+ "percentChangeColorMode": "standard",
345
+ "reduceOptions": {
346
+ "calcs": [
347
+ "lastNotNull"
348
+ ],
349
+ "fields": "",
350
+ "values": false
351
+ },
352
+ "showPercentChange": false,
353
+ "textMode": "auto",
354
+ "wideLayout": true
355
+ },
356
+ "pluginVersion": "11.2.0",
357
+ "targets": [
358
+ {
359
+ "datasource": {
360
+ "type": "prometheus",
361
+ "uid": "ddyfngn31dg5cf"
362
+ },
363
+ "disableTextWrap": false,
364
+ "editorMode": "code",
365
+ "expr": "sglang:cached_token{instance=\"$instance\", name=\"$name\"}",
366
+ "fullMetaSearch": false,
367
+ "includeNullMetadata": true,
368
+ "instant": false,
369
+ "legendFormat": "{{__name__}}",
370
+ "range": true,
371
+ "refId": "A",
372
+ "useBackend": false
373
+ }
374
+ ],
375
+ "title": "Cached Tokens",
376
+ "type": "stat"
377
+ },
378
+ {
379
+ "datasource": {
380
+ "default": true,
381
+ "type": "prometheus",
382
+ "uid": "ee2vha8w6f5kwf"
383
+ },
384
+ "description": "",
385
+ "fieldConfig": {
386
+ "defaults": {
387
+ "color": {
388
+ "mode": "thresholds"
389
+ },
390
+ "mappings": [],
391
+ "thresholds": {
392
+ "mode": "absolute",
393
+ "steps": [
394
+ {
395
+ "color": "green",
396
+ "value": null
397
+ }
398
+ ]
399
+ }
400
+ },
401
+ "overrides": []
402
+ },
403
+ "gridPos": {
404
+ "h": 3,
405
+ "w": 6,
406
+ "x": 18,
407
+ "y": 0
408
+ },
409
+ "id": 5,
410
+ "options": {
411
+ "colorMode": "value",
412
+ "graphMode": "area",
413
+ "justifyMode": "auto",
414
+ "orientation": "auto",
415
+ "percentChangeColorMode": "standard",
416
+ "reduceOptions": {
417
+ "calcs": [
418
+ "lastNotNull"
419
+ ],
420
+ "fields": "",
421
+ "values": false
422
+ },
423
+ "showPercentChange": false,
424
+ "textMode": "auto",
425
+ "wideLayout": true
426
+ },
427
+ "pluginVersion": "11.2.0",
428
+ "targets": [
429
+ {
430
+ "datasource": {
431
+ "type": "prometheus",
432
+ "uid": "ddyfngn31dg5cf"
433
+ },
434
+ "disableTextWrap": false,
435
+ "editorMode": "code",
436
+ "expr": "sglang:cache_hit_rate{instance=\"$instance\", name=\"$name\"}",
437
+ "fullMetaSearch": false,
438
+ "includeNullMetadata": true,
439
+ "instant": false,
440
+ "legendFormat": "{{__name__}}",
441
+ "range": true,
442
+ "refId": "A",
443
+ "useBackend": false
444
+ }
445
+ ],
446
+ "title": "Cache Hit Rate (%)",
447
+ "type": "stat"
448
+ },
449
+ {
450
+ "datasource": {
451
+ "default": true,
452
+ "type": "prometheus",
453
+ "uid": "ee2vha8w6f5kwf"
454
+ },
455
+ "fieldConfig": {
456
+ "defaults": {
457
+ "color": {
458
+ "mode": "palette-classic"
459
+ },
460
+ "custom": {
461
+ "axisBorderShow": false,
462
+ "axisCenteredZero": false,
463
+ "axisColorMode": "text",
464
+ "axisLabel": "",
465
+ "axisPlacement": "auto",
466
+ "barAlignment": 0,
467
+ "barWidthFactor": 0.6,
468
+ "drawStyle": "line",
469
+ "fillOpacity": 0,
470
+ "gradientMode": "none",
471
+ "hideFrom": {
472
+ "legend": false,
473
+ "tooltip": false,
474
+ "viz": false
475
+ },
476
+ "insertNulls": false,
477
+ "lineInterpolation": "linear",
478
+ "lineWidth": 1,
479
+ "pointSize": 5,
480
+ "scaleDistribution": {
481
+ "type": "linear"
482
+ },
483
+ "showPoints": "auto",
484
+ "spanNulls": false,
485
+ "stacking": {
486
+ "group": "A",
487
+ "mode": "none"
488
+ },
489
+ "thresholdsStyle": {
490
+ "mode": "off"
491
+ }
492
+ },
493
+ "mappings": [],
494
+ "thresholds": {
495
+ "mode": "absolute",
496
+ "steps": [
497
+ {
498
+ "color": "green",
499
+ "value": null
500
+ },
501
+ {
502
+ "color": "red",
503
+ "value": 80
504
+ }
505
+ ]
506
+ }
507
+ },
508
+ "overrides": []
509
+ },
510
+ "gridPos": {
511
+ "h": 8,
512
+ "w": 12,
513
+ "x": 0,
514
+ "y": 3
515
+ },
516
+ "id": 14,
517
+ "options": {
518
+ "legend": {
519
+ "calcs": [],
520
+ "displayMode": "list",
521
+ "placement": "bottom",
522
+ "showLegend": true
523
+ },
524
+ "tooltip": {
525
+ "mode": "single",
526
+ "sort": "none"
527
+ }
528
+ },
529
+ "targets": [
530
+ {
531
+ "datasource": {
532
+ "type": "prometheus",
533
+ "uid": "ddyfngn31dg5cf"
534
+ },
535
+ "disableTextWrap": false,
536
+ "editorMode": "code",
537
+ "expr": "histogram_quantile(0.99, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$name\"}[$__rate_interval])))",
538
+ "fullMetaSearch": false,
539
+ "includeNullMetadata": true,
540
+ "instant": false,
541
+ "legendFormat": "P99",
542
+ "range": true,
543
+ "refId": "A",
544
+ "useBackend": false
545
+ },
546
+ {
547
+ "datasource": {
548
+ "type": "prometheus",
549
+ "uid": "ddyfngn31dg5cf"
550
+ },
551
+ "disableTextWrap": false,
552
+ "editorMode": "code",
553
+ "expr": "histogram_quantile(0.9, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$name\"}[$__rate_interval])))",
554
+ "fullMetaSearch": false,
555
+ "hide": false,
556
+ "includeNullMetadata": true,
557
+ "instant": false,
558
+ "legendFormat": "P90",
559
+ "range": true,
560
+ "refId": "B",
561
+ "useBackend": false
562
+ },
563
+ {
564
+ "datasource": {
565
+ "type": "prometheus",
566
+ "uid": "ddyfngn31dg5cf"
567
+ },
568
+ "disableTextWrap": false,
569
+ "editorMode": "builder",
570
+ "expr": "histogram_quantile(0.95, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$model_name\"}[$__rate_interval])))",
571
+ "fullMetaSearch": false,
572
+ "hide": false,
573
+ "includeNullMetadata": true,
574
+ "instant": false,
575
+ "legendFormat": "P95",
576
+ "range": true,
577
+ "refId": "C",
578
+ "useBackend": false
579
+ },
580
+ {
581
+ "datasource": {
582
+ "type": "prometheus",
583
+ "uid": "ddyfngn31dg5cf"
584
+ },
585
+ "disableTextWrap": false,
586
+ "editorMode": "builder",
587
+ "expr": "histogram_quantile(0.5, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{instance=\"$instance\", name=\"$model_name\"}[$__rate_interval])))",
588
+ "fullMetaSearch": false,
589
+ "hide": false,
590
+ "includeNullMetadata": true,
591
+ "instant": false,
592
+ "legendFormat": "P50",
593
+ "range": true,
594
+ "refId": "D",
595
+ "useBackend": false
596
+ },
597
+ {
598
+ "datasource": {
599
+ "type": "prometheus",
600
+ "uid": "ddyfngn31dg5cf"
601
+ },
602
+ "disableTextWrap": false,
603
+ "editorMode": "builder",
604
+ "expr": "rate(sglang:e2e_request_latency_seconds_sum{instance=\"$instance\", name=\"$model_name\"}[$__rate_interval]) / rate(sglang:e2e_request_latency_seconds_count[$__rate_interval])",
605
+ "fullMetaSearch": false,
606
+ "hide": false,
607
+ "includeNullMetadata": true,
608
+ "instant": false,
609
+ "legendFormat": "Average",
610
+ "range": true,
611
+ "refId": "E",
612
+ "useBackend": false
613
+ }
614
+ ],
615
+ "title": "E2E Request Latency (S)",
616
+ "type": "timeseries"
617
+ },
618
+ {
619
+ "datasource": {
620
+ "default": true,
621
+ "type": "prometheus",
622
+ "uid": "ee2vha8w6f5kwf"
623
+ },
624
+ "fieldConfig": {
625
+ "defaults": {
626
+ "color": {
627
+ "mode": "palette-classic"
628
+ },
629
+ "custom": {
630
+ "axisBorderShow": false,
631
+ "axisCenteredZero": false,
632
+ "axisColorMode": "text",
633
+ "axisLabel": "",
634
+ "axisPlacement": "auto",
635
+ "barAlignment": 0,
636
+ "barWidthFactor": 0.6,
637
+ "drawStyle": "line",
638
+ "fillOpacity": 0,
639
+ "gradientMode": "none",
640
+ "hideFrom": {
641
+ "legend": false,
642
+ "tooltip": false,
643
+ "viz": false
644
+ },
645
+ "insertNulls": false,
646
+ "lineInterpolation": "linear",
647
+ "lineWidth": 1,
648
+ "pointSize": 5,
649
+ "scaleDistribution": {
650
+ "type": "linear"
651
+ },
652
+ "showPoints": "auto",
653
+ "spanNulls": false,
654
+ "stacking": {
655
+ "group": "A",
656
+ "mode": "none"
657
+ },
658
+ "thresholdsStyle": {
659
+ "mode": "off"
660
+ }
661
+ },
662
+ "mappings": [],
663
+ "thresholds": {
664
+ "mode": "absolute",
665
+ "steps": [
666
+ {
667
+ "color": "green",
668
+ "value": null
669
+ },
670
+ {
671
+ "color": "red",
672
+ "value": 80
673
+ }
674
+ ]
675
+ }
676
+ },
677
+ "overrides": []
678
+ },
679
+ "gridPos": {
680
+ "h": 8,
681
+ "w": 12,
682
+ "x": 12,
683
+ "y": 3
684
+ },
685
+ "id": 18,
686
+ "options": {
687
+ "legend": {
688
+ "calcs": [],
689
+ "displayMode": "list",
690
+ "placement": "bottom",
691
+ "showLegend": true
692
+ },
693
+ "tooltip": {
694
+ "mode": "single",
695
+ "sort": "none"
696
+ }
697
+ },
698
+ "targets": [
699
+ {
700
+ "datasource": {
701
+ "type": "prometheus",
702
+ "uid": "ddyfngn31dg5cf"
703
+ },
704
+ "editorMode": "code",
705
+ "expr": "sglang:gen_throughput{instance=\"$instance\", name=\"$name\"}",
706
+ "instant": false,
707
+ "legendFormat": "__auto",
708
+ "range": true,
709
+ "refId": "A"
710
+ }
711
+ ],
712
+ "title": "Generation Throughput (Token / S)",
713
+ "type": "timeseries"
714
+ },
715
+ {
716
+ "datasource": {
717
+ "default": true,
718
+ "type": "prometheus",
719
+ "uid": "ee2vha8w6f5kwf"
720
+ },
721
+ "fieldConfig": {
722
+ "defaults": {
723
+ "color": {
724
+ "mode": "palette-classic"
725
+ },
726
+ "custom": {
727
+ "axisBorderShow": false,
728
+ "axisCenteredZero": false,
729
+ "axisColorMode": "text",
730
+ "axisLabel": "",
731
+ "axisPlacement": "auto",
732
+ "barAlignment": 0,
733
+ "barWidthFactor": 0.6,
734
+ "drawStyle": "line",
735
+ "fillOpacity": 0,
736
+ "gradientMode": "none",
737
+ "hideFrom": {
738
+ "legend": false,
739
+ "tooltip": false,
740
+ "viz": false
741
+ },
742
+ "insertNulls": false,
743
+ "lineInterpolation": "linear",
744
+ "lineWidth": 1,
745
+ "pointSize": 5,
746
+ "scaleDistribution": {
747
+ "type": "linear"
748
+ },
749
+ "showPoints": "auto",
750
+ "spanNulls": false,
751
+ "stacking": {
752
+ "group": "A",
753
+ "mode": "none"
754
+ },
755
+ "thresholdsStyle": {
756
+ "mode": "off"
757
+ }
758
+ },
759
+ "mappings": [],
760
+ "thresholds": {
761
+ "mode": "absolute",
762
+ "steps": [
763
+ {
764
+ "color": "green",
765
+ "value": null
766
+ },
767
+ {
768
+ "color": "red",
769
+ "value": 80
770
+ }
771
+ ]
772
+ }
773
+ },
774
+ "overrides": []
775
+ },
776
+ "gridPos": {
777
+ "h": 8,
778
+ "w": 12,
779
+ "x": 0,
780
+ "y": 11
781
+ },
782
+ "id": 7,
783
+ "options": {
784
+ "legend": {
785
+ "calcs": [],
786
+ "displayMode": "list",
787
+ "placement": "bottom",
788
+ "showLegend": true
789
+ },
790
+ "tooltip": {
791
+ "mode": "single",
792
+ "sort": "none"
793
+ }
794
+ },
795
+ "targets": [
796
+ {
797
+ "datasource": {
798
+ "type": "prometheus",
799
+ "uid": "ddyfngn31dg5cf"
800
+ },
801
+ "disableTextWrap": false,
802
+ "editorMode": "code",
803
+ "expr": "sglang:num_requests_running{instance=\"$instance\", name=\"$name\"}",
804
+ "fullMetaSearch": false,
805
+ "includeNullMetadata": true,
806
+ "instant": false,
807
+ "legendFormat": "{{__name__}}",
808
+ "range": true,
809
+ "refId": "A",
810
+ "useBackend": false
811
+ }
812
+ ],
813
+ "title": "Num Requests Running",
814
+ "type": "timeseries"
815
+ },
816
+ {
817
+ "datasource": {
818
+ "default": true,
819
+ "type": "prometheus",
820
+ "uid": "ee2vha8w6f5kwf"
821
+ },
822
+ "fieldConfig": {
823
+ "defaults": {
824
+ "color": {
825
+ "mode": "palette-classic"
826
+ },
827
+ "custom": {
828
+ "axisBorderShow": false,
829
+ "axisCenteredZero": false,
830
+ "axisColorMode": "text",
831
+ "axisLabel": "",
832
+ "axisPlacement": "auto",
833
+ "barAlignment": 0,
834
+ "barWidthFactor": 0.6,
835
+ "drawStyle": "line",
836
+ "fillOpacity": 0,
837
+ "gradientMode": "none",
838
+ "hideFrom": {
839
+ "legend": false,
840
+ "tooltip": false,
841
+ "viz": false
842
+ },
843
+ "insertNulls": false,
844
+ "lineInterpolation": "linear",
845
+ "lineWidth": 1,
846
+ "pointSize": 5,
847
+ "scaleDistribution": {
848
+ "type": "linear"
849
+ },
850
+ "showPoints": "auto",
851
+ "spanNulls": false,
852
+ "stacking": {
853
+ "group": "A",
854
+ "mode": "none"
855
+ },
856
+ "thresholdsStyle": {
857
+ "mode": "off"
858
+ }
859
+ },
860
+ "mappings": [],
861
+ "thresholds": {
862
+ "mode": "absolute",
863
+ "steps": [
864
+ {
865
+ "color": "green",
866
+ "value": null
867
+ },
868
+ {
869
+ "color": "red",
870
+ "value": 80
871
+ }
872
+ ]
873
+ }
874
+ },
875
+ "overrides": []
876
+ },
877
+ "gridPos": {
878
+ "h": 8,
879
+ "w": 12,
880
+ "x": 12,
881
+ "y": 11
882
+ },
883
+ "id": 8,
884
+ "options": {
885
+ "legend": {
886
+ "calcs": [],
887
+ "displayMode": "list",
888
+ "placement": "bottom",
889
+ "showLegend": true
890
+ },
891
+ "tooltip": {
892
+ "mode": "single",
893
+ "sort": "none"
894
+ }
895
+ },
896
+ "targets": [
897
+ {
898
+ "datasource": {
899
+ "type": "prometheus",
900
+ "uid": "ddyfngn31dg5cf"
901
+ },
902
+ "disableTextWrap": false,
903
+ "editorMode": "code",
904
+ "expr": "sglang:num_requests_waiting{instance=\"$instance\", name=\"$name\"}",
905
+ "fullMetaSearch": false,
906
+ "includeNullMetadata": true,
907
+ "instant": false,
908
+ "legendFormat": "{{__name__}}",
909
+ "range": true,
910
+ "refId": "A",
911
+ "useBackend": false
912
+ }
913
+ ],
914
+ "title": "Number of Requests Waiting",
915
+ "type": "timeseries"
916
+ },
917
+ {
918
+ "datasource": {
919
+ "default": true,
920
+ "type": "prometheus",
921
+ "uid": "ee2vha8w6f5kwf"
922
+ },
923
+ "fieldConfig": {
924
+ "defaults": {
925
+ "color": {
926
+ "mode": "palette-classic"
927
+ },
928
+ "custom": {
929
+ "axisBorderShow": false,
930
+ "axisCenteredZero": false,
931
+ "axisColorMode": "text",
932
+ "axisLabel": "",
933
+ "axisPlacement": "auto",
934
+ "barAlignment": 0,
935
+ "barWidthFactor": 0.6,
936
+ "drawStyle": "line",
937
+ "fillOpacity": 0,
938
+ "gradientMode": "none",
939
+ "hideFrom": {
940
+ "legend": false,
941
+ "tooltip": false,
942
+ "viz": false
943
+ },
944
+ "insertNulls": false,
945
+ "lineInterpolation": "linear",
946
+ "lineWidth": 1,
947
+ "pointSize": 5,
948
+ "scaleDistribution": {
949
+ "type": "linear"
950
+ },
951
+ "showPoints": "auto",
952
+ "spanNulls": false,
953
+ "stacking": {
954
+ "group": "A",
955
+ "mode": "none"
956
+ },
957
+ "thresholdsStyle": {
958
+ "mode": "off"
959
+ }
960
+ },
961
+ "mappings": [],
962
+ "thresholds": {
963
+ "mode": "absolute",
964
+ "steps": [
965
+ {
966
+ "color": "green",
967
+ "value": null
968
+ },
969
+ {
970
+ "color": "red",
971
+ "value": 80
972
+ }
973
+ ]
974
+ }
975
+ },
976
+ "overrides": []
977
+ },
978
+ "gridPos": {
979
+ "h": 8,
980
+ "w": 12,
981
+ "x": 0,
982
+ "y": 19
983
+ },
984
+ "id": 16,
985
+ "options": {
986
+ "legend": {
987
+ "calcs": [],
988
+ "displayMode": "list",
989
+ "placement": "bottom",
990
+ "showLegend": true
991
+ },
992
+ "tooltip": {
993
+ "mode": "single",
994
+ "sort": "none"
995
+ }
996
+ },
997
+ "targets": [
998
+ {
999
+ "datasource": {
1000
+ "type": "prometheus",
1001
+ "uid": "ddyfngn31dg5cf"
1002
+ },
1003
+ "disableTextWrap": false,
1004
+ "editorMode": "code",
1005
+ "expr": "histogram_quantile(0.99, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1006
+ "fullMetaSearch": false,
1007
+ "includeNullMetadata": true,
1008
+ "instant": false,
1009
+ "legendFormat": "P99",
1010
+ "range": true,
1011
+ "refId": "A",
1012
+ "useBackend": false
1013
+ },
1014
+ {
1015
+ "datasource": {
1016
+ "type": "prometheus",
1017
+ "uid": "ddyfngn31dg5cf"
1018
+ },
1019
+ "disableTextWrap": false,
1020
+ "editorMode": "code",
1021
+ "expr": "histogram_quantile(0.9, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1022
+ "fullMetaSearch": false,
1023
+ "hide": false,
1024
+ "includeNullMetadata": true,
1025
+ "instant": false,
1026
+ "legendFormat": "P90",
1027
+ "range": true,
1028
+ "refId": "B",
1029
+ "useBackend": false
1030
+ },
1031
+ {
1032
+ "datasource": {
1033
+ "type": "prometheus",
1034
+ "uid": "ddyfngn31dg5cf"
1035
+ },
1036
+ "disableTextWrap": false,
1037
+ "editorMode": "code",
1038
+ "expr": "histogram_quantile(0.95, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1039
+ "fullMetaSearch": false,
1040
+ "hide": false,
1041
+ "includeNullMetadata": true,
1042
+ "instant": false,
1043
+ "legendFormat": "P95",
1044
+ "range": true,
1045
+ "refId": "C",
1046
+ "useBackend": false
1047
+ },
1048
+ {
1049
+ "datasource": {
1050
+ "type": "prometheus",
1051
+ "uid": "ddyfngn31dg5cf"
1052
+ },
1053
+ "disableTextWrap": false,
1054
+ "editorMode": "code",
1055
+ "expr": "histogram_quantile(0.5, sum by(le) (rate(sglang:e2e_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1056
+ "fullMetaSearch": false,
1057
+ "hide": false,
1058
+ "includeNullMetadata": true,
1059
+ "instant": false,
1060
+ "legendFormat": "P50",
1061
+ "range": true,
1062
+ "refId": "D",
1063
+ "useBackend": false
1064
+ },
1065
+ {
1066
+ "datasource": {
1067
+ "type": "prometheus",
1068
+ "uid": "ddyfngn31dg5cf"
1069
+ },
1070
+ "disableTextWrap": false,
1071
+ "editorMode": "code",
1072
+ "expr": "rate(sglang:e2e_request_latency_seconds_sum{name=\"$name\"}[$__rate_interval]) / rate(sglang:e2e_request_latency_seconds_count{name=\"$name\"}[$__rate_interval])",
1073
+ "fullMetaSearch": false,
1074
+ "hide": false,
1075
+ "includeNullMetadata": true,
1076
+ "instant": false,
1077
+ "legendFormat": "Average",
1078
+ "range": true,
1079
+ "refId": "E",
1080
+ "useBackend": false
1081
+ }
1082
+ ],
1083
+ "title": "Time Request Decoding (S)",
1084
+ "type": "timeseries"
1085
+ },
1086
+ {
1087
+ "datasource": {
1088
+ "default": true,
1089
+ "type": "prometheus",
1090
+ "uid": "ee2vha8w6f5kwf"
1091
+ },
1092
+ "description": "Time requests waiting before added to batch",
1093
+ "fieldConfig": {
1094
+ "defaults": {
1095
+ "color": {
1096
+ "mode": "palette-classic"
1097
+ },
1098
+ "custom": {
1099
+ "axisBorderShow": false,
1100
+ "axisCenteredZero": false,
1101
+ "axisColorMode": "text",
1102
+ "axisLabel": "",
1103
+ "axisPlacement": "auto",
1104
+ "barAlignment": 0,
1105
+ "barWidthFactor": 0.6,
1106
+ "drawStyle": "line",
1107
+ "fillOpacity": 0,
1108
+ "gradientMode": "none",
1109
+ "hideFrom": {
1110
+ "legend": false,
1111
+ "tooltip": false,
1112
+ "viz": false
1113
+ },
1114
+ "insertNulls": false,
1115
+ "lineInterpolation": "linear",
1116
+ "lineWidth": 1,
1117
+ "pointSize": 5,
1118
+ "scaleDistribution": {
1119
+ "type": "linear"
1120
+ },
1121
+ "showPoints": "auto",
1122
+ "spanNulls": false,
1123
+ "stacking": {
1124
+ "group": "A",
1125
+ "mode": "none"
1126
+ },
1127
+ "thresholdsStyle": {
1128
+ "mode": "off"
1129
+ }
1130
+ },
1131
+ "mappings": [],
1132
+ "thresholds": {
1133
+ "mode": "absolute",
1134
+ "steps": [
1135
+ {
1136
+ "color": "green",
1137
+ "value": null
1138
+ },
1139
+ {
1140
+ "color": "red",
1141
+ "value": 80
1142
+ }
1143
+ ]
1144
+ }
1145
+ },
1146
+ "overrides": []
1147
+ },
1148
+ "gridPos": {
1149
+ "h": 8,
1150
+ "w": 12,
1151
+ "x": 12,
1152
+ "y": 19
1153
+ },
1154
+ "id": 15,
1155
+ "options": {
1156
+ "legend": {
1157
+ "calcs": [],
1158
+ "displayMode": "list",
1159
+ "placement": "bottom",
1160
+ "showLegend": true
1161
+ },
1162
+ "tooltip": {
1163
+ "mode": "single",
1164
+ "sort": "none"
1165
+ }
1166
+ },
1167
+ "targets": [
1168
+ {
1169
+ "datasource": {
1170
+ "type": "prometheus",
1171
+ "uid": "ddyfngn31dg5cf"
1172
+ },
1173
+ "editorMode": "code",
1174
+ "expr": "histogram_quantile(0.99, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1175
+ "instant": false,
1176
+ "legendFormat": "P99",
1177
+ "range": true,
1178
+ "refId": "A"
1179
+ },
1180
+ {
1181
+ "datasource": {
1182
+ "type": "prometheus",
1183
+ "uid": "ddyfngn31dg5cf"
1184
+ },
1185
+ "editorMode": "code",
1186
+ "expr": "histogram_quantile(0.95, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1187
+ "hide": false,
1188
+ "instant": false,
1189
+ "legendFormat": "P95",
1190
+ "range": true,
1191
+ "refId": "B"
1192
+ },
1193
+ {
1194
+ "datasource": {
1195
+ "type": "prometheus",
1196
+ "uid": "ddyfngn31dg5cf"
1197
+ },
1198
+ "editorMode": "code",
1199
+ "expr": "histogram_quantile(0.9, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1200
+ "hide": false,
1201
+ "instant": false,
1202
+ "legendFormat": "P90",
1203
+ "range": true,
1204
+ "refId": "C"
1205
+ },
1206
+ {
1207
+ "datasource": {
1208
+ "type": "prometheus",
1209
+ "uid": "ddyfngn31dg5cf"
1210
+ },
1211
+ "editorMode": "code",
1212
+ "expr": "histogram_quantile(0.5, sum by (le) (rate(sglang:waiting_request_latency_seconds_bucket{name=\"$name\"}[$__rate_interval])))",
1213
+ "hide": false,
1214
+ "instant": false,
1215
+ "legendFormat": "P50",
1216
+ "range": true,
1217
+ "refId": "D"
1218
+ },
1219
+ {
1220
+ "datasource": {
1221
+ "type": "prometheus",
1222
+ "uid": "ddyfngn31dg5cf"
1223
+ },
1224
+ "editorMode": "code",
1225
+ "expr": "rate(sglang:waiting_request_latency_seconds_sum{name=\"$name\"}[$__rate_interval])\r\n/\r\nrate(sglang:waiting_request_latency_seconds_count{name=\"$name\"}[$__rate_interval])",
1226
+ "hide": false,
1227
+ "instant": false,
1228
+ "legendFormat": "Average",
1229
+ "range": true,
1230
+ "refId": "E"
1231
+ }
1232
+ ],
1233
+ "title": "Time Request Waiting (S)",
1234
+ "type": "timeseries"
1235
+ },
1236
+ {
1237
+ "datasource": {
1238
+ "default": true,
1239
+ "type": "prometheus",
1240
+ "uid": "ee2vha8w6f5kwf"
1241
+ },
1242
+ "fieldConfig": {
1243
+ "defaults": {
1244
+ "color": {
1245
+ "mode": "palette-classic"
1246
+ },
1247
+ "custom": {
1248
+ "axisBorderShow": false,
1249
+ "axisCenteredZero": false,
1250
+ "axisColorMode": "text",
1251
+ "axisLabel": "",
1252
+ "axisPlacement": "auto",
1253
+ "barAlignment": 0,
1254
+ "barWidthFactor": 0.6,
1255
+ "drawStyle": "line",
1256
+ "fillOpacity": 0,
1257
+ "gradientMode": "none",
1258
+ "hideFrom": {
1259
+ "legend": false,
1260
+ "tooltip": false,
1261
+ "viz": false
1262
+ },
1263
+ "insertNulls": false,
1264
+ "lineInterpolation": "linear",
1265
+ "lineWidth": 1,
1266
+ "pointSize": 5,
1267
+ "scaleDistribution": {
1268
+ "type": "linear"
1269
+ },
1270
+ "showPoints": "auto",
1271
+ "spanNulls": false,
1272
+ "stacking": {
1273
+ "group": "A",
1274
+ "mode": "none"
1275
+ },
1276
+ "thresholdsStyle": {
1277
+ "mode": "off"
1278
+ }
1279
+ },
1280
+ "mappings": [],
1281
+ "thresholds": {
1282
+ "mode": "absolute",
1283
+ "steps": [
1284
+ {
1285
+ "color": "green",
1286
+ "value": null
1287
+ },
1288
+ {
1289
+ "color": "red",
1290
+ "value": 80
1291
+ }
1292
+ ]
1293
+ }
1294
+ },
1295
+ "overrides": []
1296
+ },
1297
+ "gridPos": {
1298
+ "h": 8,
1299
+ "w": 12,
1300
+ "x": 0,
1301
+ "y": 27
1302
+ },
1303
+ "id": 11,
1304
+ "options": {
1305
+ "legend": {
1306
+ "calcs": [],
1307
+ "displayMode": "list",
1308
+ "placement": "bottom",
1309
+ "showLegend": true
1310
+ },
1311
+ "tooltip": {
1312
+ "mode": "single",
1313
+ "sort": "none"
1314
+ }
1315
+ },
1316
+ "targets": [
1317
+ {
1318
+ "datasource": {
1319
+ "type": "prometheus",
1320
+ "uid": "ddyfngn31dg5cf"
1321
+ },
1322
+ "disableTextWrap": false,
1323
+ "editorMode": "code",
1324
+ "expr": "sum(rate(sglang:request_prompt_tokens_sum{instance=\"$instance\", name=\"$name\"}[$__rate_interval])) by (instance, name)",
1325
+ "fullMetaSearch": false,
1326
+ "includeNullMetadata": true,
1327
+ "instant": false,
1328
+ "legendFormat": "{{__name__}}",
1329
+ "range": true,
1330
+ "refId": "A",
1331
+ "useBackend": false
1332
+ },
1333
+ {
1334
+ "datasource": {
1335
+ "type": "prometheus",
1336
+ "uid": "ddyfngn31dg5cf"
1337
+ },
1338
+ "disableTextWrap": false,
1339
+ "editorMode": "code",
1340
+ "expr": "",
1341
+ "fullMetaSearch": false,
1342
+ "hide": false,
1343
+ "includeNullMetadata": true,
1344
+ "instant": false,
1345
+ "legendFormat": "__auto",
1346
+ "range": true,
1347
+ "refId": "B",
1348
+ "useBackend": false
1349
+ }
1350
+ ],
1351
+ "title": "Prompt Tokens",
1352
+ "type": "timeseries"
1353
+ },
1354
+ {
1355
+ "datasource": {
1356
+ "default": true,
1357
+ "type": "prometheus",
1358
+ "uid": "ee2vha8w6f5kwf"
1359
+ },
1360
+ "fieldConfig": {
1361
+ "defaults": {
1362
+ "color": {
1363
+ "mode": "palette-classic"
1364
+ },
1365
+ "custom": {
1366
+ "axisBorderShow": false,
1367
+ "axisCenteredZero": false,
1368
+ "axisColorMode": "text",
1369
+ "axisLabel": "",
1370
+ "axisPlacement": "auto",
1371
+ "barAlignment": 0,
1372
+ "barWidthFactor": 0.6,
1373
+ "drawStyle": "line",
1374
+ "fillOpacity": 0,
1375
+ "gradientMode": "none",
1376
+ "hideFrom": {
1377
+ "legend": false,
1378
+ "tooltip": false,
1379
+ "viz": false
1380
+ },
1381
+ "insertNulls": false,
1382
+ "lineInterpolation": "linear",
1383
+ "lineWidth": 1,
1384
+ "pointSize": 5,
1385
+ "scaleDistribution": {
1386
+ "type": "linear"
1387
+ },
1388
+ "showPoints": "auto",
1389
+ "spanNulls": false,
1390
+ "stacking": {
1391
+ "group": "A",
1392
+ "mode": "none"
1393
+ },
1394
+ "thresholdsStyle": {
1395
+ "mode": "off"
1396
+ }
1397
+ },
1398
+ "mappings": [],
1399
+ "thresholds": {
1400
+ "mode": "absolute",
1401
+ "steps": [
1402
+ {
1403
+ "color": "green",
1404
+ "value": null
1405
+ },
1406
+ {
1407
+ "color": "red",
1408
+ "value": 80
1409
+ }
1410
+ ]
1411
+ }
1412
+ },
1413
+ "overrides": []
1414
+ },
1415
+ "gridPos": {
1416
+ "h": 8,
1417
+ "w": 12,
1418
+ "x": 12,
1419
+ "y": 27
1420
+ },
1421
+ "id": 17,
1422
+ "options": {
1423
+ "legend": {
1424
+ "calcs": [],
1425
+ "displayMode": "list",
1426
+ "placement": "bottom",
1427
+ "showLegend": true
1428
+ },
1429
+ "tooltip": {
1430
+ "mode": "single",
1431
+ "sort": "none"
1432
+ }
1433
+ },
1434
+ "targets": [
1435
+ {
1436
+ "datasource": {
1437
+ "type": "prometheus",
1438
+ "uid": "ddyfngn31dg5cf"
1439
+ },
1440
+ "disableTextWrap": false,
1441
+ "editorMode": "code",
1442
+ "expr": "sum(rate(sglang:request_generation_tokens_sum{instance=\"$instance\", name=\"$name\"}[$__rate_interval])) by (instance, name)",
1443
+ "fullMetaSearch": false,
1444
+ "includeNullMetadata": true,
1445
+ "instant": false,
1446
+ "legendFormat": "{{__name__}}",
1447
+ "range": true,
1448
+ "refId": "A",
1449
+ "useBackend": false
1450
+ }
1451
+ ],
1452
+ "title": "Generated Tokens",
1453
+ "type": "timeseries"
1454
+ },
1455
+ {
1456
+ "datasource": {
1457
+ "default": true,
1458
+ "type": "prometheus",
1459
+ "uid": "ee2vha8w6f5kwf"
1460
+ },
1461
+ "fieldConfig": {
1462
+ "defaults": {
1463
+ "custom": {
1464
+ "hideFrom": {
1465
+ "legend": false,
1466
+ "tooltip": false,
1467
+ "viz": false
1468
+ },
1469
+ "scaleDistribution": {
1470
+ "type": "linear"
1471
+ }
1472
+ }
1473
+ },
1474
+ "overrides": []
1475
+ },
1476
+ "gridPos": {
1477
+ "h": 8,
1478
+ "w": 12,
1479
+ "x": 0,
1480
+ "y": 35
1481
+ },
1482
+ "id": 13,
1483
+ "options": {
1484
+ "calculate": false,
1485
+ "calculation": {
1486
+ "yBuckets": {
1487
+ "scale": {
1488
+ "log": 2,
1489
+ "type": "log"
1490
+ }
1491
+ }
1492
+ },
1493
+ "cellGap": 1,
1494
+ "color": {
1495
+ "exponent": 0.5,
1496
+ "fill": "dark-orange",
1497
+ "mode": "scheme",
1498
+ "reverse": false,
1499
+ "scale": "exponential",
1500
+ "scheme": "Oranges",
1501
+ "steps": 64
1502
+ },
1503
+ "exemplars": {
1504
+ "color": "rgba(255,0,255,0.7)"
1505
+ },
1506
+ "filterValues": {
1507
+ "le": 1e-9
1508
+ },
1509
+ "legend": {
1510
+ "show": true
1511
+ },
1512
+ "rowsFrame": {
1513
+ "layout": "auto"
1514
+ },
1515
+ "tooltip": {
1516
+ "mode": "single",
1517
+ "showColorScale": false,
1518
+ "yHistogram": false
1519
+ },
1520
+ "yAxis": {
1521
+ "axisPlacement": "left",
1522
+ "reverse": false
1523
+ }
1524
+ },
1525
+ "pluginVersion": "11.2.0",
1526
+ "targets": [
1527
+ {
1528
+ "datasource": {
1529
+ "type": "prometheus",
1530
+ "uid": "ddyfngn31dg5cf"
1531
+ },
1532
+ "disableTextWrap": false,
1533
+ "editorMode": "code",
1534
+ "expr": "sum by(le) (increase(sglang:request_prompt_tokens_bucket{name=\"$name\", instance=\"$instance\"}[$__rate_interval]))",
1535
+ "fullMetaSearch": false,
1536
+ "includeNullMetadata": true,
1537
+ "instant": false,
1538
+ "legendFormat": "{{__name__}}",
1539
+ "range": true,
1540
+ "refId": "A",
1541
+ "useBackend": false
1542
+ }
1543
+ ],
1544
+ "title": "Request Prompt Tokens",
1545
+ "type": "heatmap"
1546
+ },
1547
+ {
1548
+ "datasource": {
1549
+ "default": true,
1550
+ "type": "prometheus",
1551
+ "uid": "ee2vha8w6f5kwf"
1552
+ },
1553
+ "description": "",
1554
+ "fieldConfig": {
1555
+ "defaults": {
1556
+ "custom": {
1557
+ "hideFrom": {
1558
+ "legend": false,
1559
+ "tooltip": false,
1560
+ "viz": false
1561
+ },
1562
+ "scaleDistribution": {
1563
+ "type": "linear"
1564
+ }
1565
+ }
1566
+ },
1567
+ "overrides": []
1568
+ },
1569
+ "gridPos": {
1570
+ "h": 8,
1571
+ "w": 12,
1572
+ "x": 12,
1573
+ "y": 35
1574
+ },
1575
+ "id": 12,
1576
+ "options": {
1577
+ "calculate": false,
1578
+ "calculation": {
1579
+ "xBuckets": {
1580
+ "mode": "size",
1581
+ "value": ""
1582
+ },
1583
+ "yBuckets": {
1584
+ "mode": "size",
1585
+ "scale": {
1586
+ "log": 2,
1587
+ "type": "log"
1588
+ },
1589
+ "value": ""
1590
+ }
1591
+ },
1592
+ "cellGap": 1,
1593
+ "color": {
1594
+ "exponent": 0.5,
1595
+ "fill": "dark-orange",
1596
+ "min": 0,
1597
+ "mode": "scheme",
1598
+ "reverse": false,
1599
+ "scale": "exponential",
1600
+ "scheme": "Spectral",
1601
+ "steps": 64
1602
+ },
1603
+ "exemplars": {
1604
+ "color": "rgba(255,0,255,0.7)"
1605
+ },
1606
+ "filterValues": {
1607
+ "le": 1e-9
1608
+ },
1609
+ "legend": {
1610
+ "show": true
1611
+ },
1612
+ "rowsFrame": {
1613
+ "layout": "auto",
1614
+ "value": "Request count"
1615
+ },
1616
+ "tooltip": {
1617
+ "mode": "single",
1618
+ "showColorScale": false,
1619
+ "yHistogram": true
1620
+ },
1621
+ "yAxis": {
1622
+ "axisLabel": "Generation Length",
1623
+ "axisPlacement": "left",
1624
+ "reverse": false,
1625
+ "unit": "none"
1626
+ }
1627
+ },
1628
+ "pluginVersion": "11.2.0",
1629
+ "targets": [
1630
+ {
1631
+ "datasource": {
1632
+ "type": "prometheus",
1633
+ "uid": "ddyfngn31dg5cf"
1634
+ },
1635
+ "disableTextWrap": false,
1636
+ "editorMode": "code",
1637
+ "expr": "sum by(le) (increase(sglang:request_generation_tokens_bucket{name=\"$name\", instance=\"$instance\"}[$__rate_interval]))",
1638
+ "fullMetaSearch": false,
1639
+ "includeNullMetadata": true,
1640
+ "instant": false,
1641
+ "legendFormat": "{{__name__}}",
1642
+ "range": true,
1643
+ "refId": "A",
1644
+ "useBackend": false
1645
+ }
1646
+ ],
1647
+ "title": "Request Generation Tokens",
1648
+ "type": "heatmap"
1649
+ }
1650
+ ],
1651
+ "refresh": "5s",
1652
+ "schemaVersion": 39,
1653
+ "tags": [],
1654
+ "templating": {
1655
+ "list": [
1656
+ {
1657
+ "current": {
1658
+ "selected": false,
1659
+ "text": "127.0.0.1:30000",
1660
+ "value": "127.0.0.1:30000"
1661
+ },
1662
+ "datasource": {
1663
+ "type": "prometheus",
1664
+ "uid": "ddyfngn31dg5cf"
1665
+ },
1666
+ "definition": "label_values(instance)",
1667
+ "hide": 0,
1668
+ "includeAll": false,
1669
+ "label": "instance",
1670
+ "multi": false,
1671
+ "name": "instance",
1672
+ "options": [],
1673
+ "query": {
1674
+ "qryType": 1,
1675
+ "query": "label_values(instance)",
1676
+ "refId": "PrometheusVariableQueryEditor-VariableQuery"
1677
+ },
1678
+ "refresh": 1,
1679
+ "regex": "",
1680
+ "skipUrlSync": false,
1681
+ "sort": 0,
1682
+ "type": "query"
1683
+ },
1684
+ {
1685
+ "current": {
1686
+ "selected": true,
1687
+ "text": "google/gemma-2-9b-it",
1688
+ "value": "google/gemma-2-9b-it"
1689
+ },
1690
+ "definition": "label_values(name)",
1691
+ "hide": 1,
1692
+ "includeAll": false,
1693
+ "label": "name",
1694
+ "multi": false,
1695
+ "name": "name",
1696
+ "options": [],
1697
+ "query": {
1698
+ "qryType": 1,
1699
+ "query": "label_values(name)",
1700
+ "refId": "PrometheusVariableQueryEditor-VariableQuery"
1701
+ },
1702
+ "refresh": 1,
1703
+ "regex": "",
1704
+ "skipUrlSync": false,
1705
+ "sort": 0,
1706
+ "type": "query"
1707
+ }
1708
+ ]
1709
+ },
1710
+ "time": {
1711
+ "from": "now-30m",
1712
+ "to": "now"
1713
+ },
1714
+ "timepicker": {},
1715
+ "timezone": "browser",
1716
+ "title": "SGLang Dashboard",
1717
+ "uid": "ddyp55uq7brpcc",
1718
+ "version": 3,
1719
+ "weekStart": ""
1720
+ }
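
The panels above all read SGLang's Prometheus series (sglang:gen_throughput, sglang:num_requests_running, sglang:num_requests_waiting, the e2e/waiting latency histograms, and the prompt/generation token histograms), filtered by the $instance and $name template variables and refreshed every 5 seconds over a 30-minute window. As a quick way to confirm those series are actually being scraped before importing the dashboard, here is a minimal sketch that queries one of them through the Prometheus HTTP API; the Prometheus address (127.0.0.1:9090) is an assumption, and the label values are just the dashboard's template defaults.

# Sketch: query the same series the "Generation Throughput" panel plots.
# Assumption: Prometheus listens on 127.0.0.1:9090 (its default port).
import requests

PROMETHEUS_URL = "http://127.0.0.1:9090"
query = 'sglang:gen_throughput{instance="127.0.0.1:30000", name="google/gemma-2-9b-it"}'

resp = requests.get(f"{PROMETHEUS_URL}/api/v1/query", params={"query": query})
for series in resp.json()["data"]["result"]:
    timestamp, value = series["value"]  # instant-vector result: [unix_ts, "value"]
    print(series["metric"], value)
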
sglang/examples/monitoring/prometheus.yaml ADDED
@@ -0,0 +1,10 @@
1
+ # prometheus.yaml
2
+ global:
3
+ scrape_interval: 5s
4
+ evaluation_interval: 30s
5
+
6
+ scrape_configs:
7
+ - job_name: sglang
8
+ static_configs:
9
+ - targets:
10
+ - '127.0.0.1:30000'
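
This scrape config polls the SGLang server's /metrics endpoint at 127.0.0.1:30000 every 5 seconds (Prometheus is started against it with the usual --config.file=prometheus.yaml flag). A small sanity-check sketch, assuming the server was launched with metrics enabled and exposes Prometheus text format at /metrics:

# Sketch: fetch the raw text Prometheus will scrape and print the SGLang series.
import requests

metrics_text = requests.get("http://127.0.0.1:30000/metrics").text
for line in metrics_text.splitlines():
    if line.startswith("sglang:"):
        print(line)
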
sglang/examples/runtime/lora.py ADDED
@@ -0,0 +1,37 @@
1
+ # launch server
2
+ # python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
3
+
4
+ # send requests
5
+ # lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
6
+ # use None to specify a base-model-only prompt, e.g. "lora_path": [None, "/home/ying/test_lora"]
7
+ import json
8
+
9
+ import requests
10
+
11
+ url = "http://127.0.0.1:30000"
12
+ json_data = {
13
+ "text": [
14
+ "prompt 1",
15
+ "prompt 2",
16
+ "prompt 3",
17
+ "prompt 4",
18
+ "prompt 5",
19
+ "prompt 6",
20
+ "prompt 7",
21
+ ],
22
+ "sampling_params": {"max_new_tokens": 32},
23
+ "lora_path": [
24
+ "/home/ying/test_lora",
25
+ "lora1",
26
+ "lora2",
27
+ "lora1",
28
+ "lora2",
29
+ None,
30
+ None,
31
+ ],
32
+ }
33
+ response = requests.post(
34
+ url + "/generate",
35
+ json=json_data,
36
+ )
37
+ print(json.dumps(response.json()))
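
The /generate response for this batch is a list with one entry per prompt, in the same order as "text" and "lora_path". A small follow-up sketch, meant to be appended to the script above; it assumes each entry carries a "text" field with the generated completion:

# Sketch (append to the script above): show which adapter produced each completion.
for prompt, lora, item in zip(json_data["text"], json_data["lora_path"], response.json()):
    adapter = lora if lora is not None else "<base model>"  # None means base-only prompt
    print(f"[{adapter}] {prompt!r} -> {item['text']!r}")
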
sglang/examples/runtime/openai_batch_complete.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ Usage:
3
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
4
+ python openai_batch_complete.py
5
+ Note: Before running this script,
6
+ you should create the input.jsonl file with the following content:
7
+ {"custom_id": "request-1", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 3 names of famous soccer player: ", "max_tokens": 200}}
8
+ {"custom_id": "request-2", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}}
9
+ {"custom_id": "request-3", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}}
10
+ """
11
+
12
+ import json
13
+ import time
14
+
15
+ import openai
16
+
17
+
18
+ class OpenAIBatchProcessor:
19
+ def __init__(self):
20
+ client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
21
+
22
+ self.client = client
23
+
24
+ def process_batch(self, input_file_path, endpoint, completion_window):
25
+
26
+ # Upload the input file
27
+ with open(input_file_path, "rb") as file:
28
+ uploaded_file = self.client.files.create(file=file, purpose="batch")
29
+
30
+ # Create the batch job
31
+ batch_job = self.client.batches.create(
32
+ input_file_id=uploaded_file.id,
33
+ endpoint=endpoint,
34
+ completion_window=completion_window,
35
+ )
36
+
37
+ # Monitor the batch job status
38
+ while batch_job.status not in ["completed", "failed", "cancelled"]:
39
+ time.sleep(3) # Wait for 3 seconds before checking the status again
40
+ print(
41
+ f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
42
+ )
43
+ batch_job = self.client.batches.retrieve(batch_job.id)
44
+
45
+ # Check the batch job status and errors
46
+ if batch_job.status == "failed":
47
+ print(f"Batch job failed with status: {batch_job.status}")
48
+ print(f"Batch job errors: {batch_job.errors}")
49
+ return None
50
+
51
+ # If the batch job is completed, process the results
52
+ if batch_job.status == "completed":
53
+
54
+ # print result of batch job
55
+ print("batch", batch_job.request_counts)
56
+
57
+ result_file_id = batch_job.output_file_id
58
+ # Retrieve the file content from the server
59
+ file_response = self.client.files.content(result_file_id)
60
+ result_content = file_response.read() # Read the content of the file
61
+
62
+ # Save the content to a local file
63
+ result_file_name = "batch_job_complete_results.jsonl"
64
+ with open(result_file_name, "wb") as file:
65
+ file.write(result_content) # Write the binary content to the file
66
+ # Load data from the saved JSONL file
67
+ results = []
68
+ with open(result_file_name, "r", encoding="utf-8") as file:
69
+ for line in file:
70
+ json_object = json.loads(
71
+ line.strip()
72
+ ) # Parse each line as a JSON object
73
+ results.append(json_object)
74
+
75
+ return results
76
+ else:
77
+ print(f"Batch job failed with status: {batch_job.status}")
78
+ return None
79
+
80
+
81
+ # Initialize the OpenAIBatchProcessor
82
+ processor = OpenAIBatchProcessor()
83
+
84
+ # Process the batch job
85
+ input_file_path = "input.jsonl"
86
+ endpoint = "/v1/completions"
87
+ completion_window = "24h"
88
+
89
+ # Process the batch job
90
+ results = processor.process_batch(input_file_path, endpoint, completion_window)
91
+
92
+ # Print the results
93
+ print(results)
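
The docstring above asks for a hand-written input.jsonl; a small sketch that generates the same three requests programmatically (same custom_id/method/url/body schema as shown in the docstring) can be handy when building larger batches:

# Sketch: write input.jsonl with the request schema shown in the docstring.
import json

batch_requests = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/completions",
        "body": {
            "model": "gpt-3.5-turbo-instruct",
            "prompt": f"List {n} names of famous {sport} players: ",
            "max_tokens": max_tokens,
        },
    }
    for i, (n, sport, max_tokens) in enumerate(
        [(3, "soccer", 200), (6, "basketball", 400), (6, "basketball", 400)], start=1
    )
]
with open("input.jsonl", "w") as f:
    for req in batch_requests:
        f.write(json.dumps(req) + "\n")
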
sglang/examples/runtime/openai_chat_with_response_prefill.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ Usage:
3
+ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
4
+ python openai_chat_with_response_prefill.py
5
+ """
6
+
7
+ import openai
8
+ from openai import OpenAI
9
+
10
+ client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
11
+
12
+ response = client.chat.completions.create(
13
+ model="meta-llama/Llama-3.1-8B-Instruct",
14
+ messages=[
15
+ {"role": "system", "content": "You are a helpful AI assistant"},
16
+ {
17
+ "role": "user",
18
+ "content": """
19
+ Extract the name, size, price, and color from this product description as a JSON object:
20
+
21
+ <description>
22
+ The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
23
+ </description>
24
+ """,
25
+ },
26
+ {
27
+ "role": "assistant",
28
+ "content": "{\n",
29
+ },
30
+ ],
31
+ temperature=0,
32
+ )
33
+
34
+ print(response.choices[0].message.content)
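
The trailing assistant message ("{\n") prefills the start of the reply, so the model continues the JSON object instead of opening with prose; the returned content therefore typically continues from, rather than repeats, that prefix. A self-contained sketch of stitching the prefill back on before parsing (the completion string here is an illustrative, hypothetical model output):

# Sketch: re-attach the prefilled "{\n" before parsing the model's continuation.
import json

prefill = "{\n"
completion = (
    ' "name": "SmartHome Mini",\n'
    ' "size": "5 inches",\n'
    ' "price": 49.99,\n'
    ' "color": "black or white"\n'
    "}"
)  # hypothetical continuation returned by the model
product = json.loads(prefill + completion)
print(product["name"], product["price"])
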
sglang/scripts/deprecated/convert_yi_vl.py ADDED
@@ -0,0 +1,38 @@
1
+ """
2
+ Convert the Yi-VL config into a format usable with SGLang
3
+
4
+ Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
5
+ """
6
+
7
+ import argparse
8
+ import json
9
+ import os
10
+
11
+ from transformers import AutoConfig, AutoTokenizer
12
+
13
+
14
+ def add_image_token(model_path: str):
15
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
16
+ tokenizer.add_tokens(["<image_placeholder>"], special_tokens=True)
17
+
18
+ print(tokenizer)
19
+ tokenizer.save_pretrained(model_path)
20
+
21
+
22
+ def edit_model_config(model_path):
23
+ config = AutoConfig.from_pretrained(model_path)
24
+
25
+ setattr(config, "architectures", ["YiVLForCausalLM"])
26
+ setattr(config, "image_token_index", 64002)
27
+
28
+ print(config)
29
+ config.save_pretrained(model_path)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("--model-path", type=str)
35
+ args = parser.parse_args()
36
+
37
+ add_image_token(args.model_path)
38
+ edit_model_config(args.model_path)
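
After conversion, the tokenizer should know the new <image_placeholder> token and the config's architectures/image_token_index fields should be updated. A verification sketch (the model path is hypothetical, and whether the token id actually equals 64002 depends on the original vocabulary size, so treat the assert as a sanity check):

# Sketch: confirm the added token id matches the image_token_index in the config.
from transformers import AutoConfig, AutoTokenizer

model_path = "/path/to/Yi-VL"  # hypothetical path
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)

token_id = tokenizer.convert_tokens_to_ids("<image_placeholder>")
assert token_id == config.image_token_index, (token_id, config.image_token_index)
print("architectures:", config.architectures, "image token id:", token_id)
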
sglang/scripts/deprecated/test_httpserver_classify.py ADDED
@@ -0,0 +1,85 @@
1
+ """
2
+ Usage:
3
+ python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
4
+
5
+ python3 test_httpserver_classify.py
6
+ """
7
+
8
+ import argparse
9
+
10
+ import numpy as np
11
+ import requests
12
+
13
+
14
+ def get_logits_deprecated(url: str, prompt: str):
15
+ response = requests.post(
16
+ url + "/generate",
17
+ json={
18
+ "text": prompt,
19
+ "sampling_params": {
20
+ "max_new_tokens": 0,
21
+ },
22
+ "return_logprob": True,
23
+ },
24
+ )
25
+ return response.json()["meta_info"]["normalized_prompt_logprob"]
26
+
27
+
28
+ def get_logits_batch_deprecated(url: str, prompts: list[str]):
29
+ response = requests.post(
30
+ url + "/generate",
31
+ json={
32
+ "text": prompts,
33
+ "sampling_params": {
34
+ "max_new_tokens": 0,
35
+ },
36
+ "return_logprob": True,
37
+ },
38
+ )
39
+ ret = response.json()
40
+ logits = np.array(
41
+ list(
42
+ ret[i]["meta_info"]["normalized_prompt_logprob"]
43
+ for i in range(len(prompts))
44
+ )
45
+ )
46
+ return logits
47
+
48
+
49
+ def get_logits(url: str, prompt: str):
50
+ response = requests.post(
51
+ url + "/classify",
52
+ json={"text": prompt},
53
+ )
54
+ return response.json()["embedding"]
55
+
56
+
57
+ def get_logits_batch(url: str, prompts: list[str]):
58
+ response = requests.post(
59
+ url + "/classify",
60
+ json={"text": prompts},
61
+ )
62
+ return np.array([x["embedding"] for x in response.json()])
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument("--host", type=str, default="http://127.0.0.1")
68
+ parser.add_argument("--port", type=int, default=30000)
69
+ args = parser.parse_args()
70
+
71
+ url = f"{args.host}:{args.port}"
72
+
73
+ # A single request
74
+ prompt = "This is a test prompt.<|eot_id|>"
75
+ logits = get_logits(url, prompt)
76
+ print(f"{logits=}")
77
+
78
+ # A batch of requests
79
+ prompts = [
80
+ "This is a test prompt.<|eot_id|>",
81
+ "This is another test prompt.<|eot_id|>",
82
+ "This is a long long long long test prompt.<|eot_id|>",
83
+ ]
84
+ logits = get_logits_batch(url, prompts)
85
+ print(f"{logits=}")
sglang/scripts/deprecated/test_httpserver_decode_stream.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ Usage:
3
+ python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
4
+ python3 test_httpserver_decode_stream.py
5
+
6
+ Output:
7
+ The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
8
+ """
9
+
10
+ import argparse
11
+ import json
12
+
13
+ import requests
14
+
15
+
16
+ def test_decode_stream(url, return_logprob, top_logprobs_num):
17
+ response = requests.post(
18
+ url + "/generate",
19
+ json={
20
+ "text": "The capital of France is",
21
+ "sampling_params": {
22
+ "temperature": 0,
23
+ "max_new_tokens": 128,
24
+ },
25
+ "stream": True,
26
+ "return_logprob": return_logprob,
27
+ "top_logprobs_num": top_logprobs_num,
28
+ "return_text_in_logprobs": True,
29
+ "logprob_start_len": 0,
30
+ },
31
+ stream=True,
32
+ )
33
+
34
+ prev = 0
35
+ for chunk in response.iter_lines(decode_unicode=False):
36
+ chunk = chunk.decode("utf-8")
37
+ if chunk and chunk.startswith("data:"):
38
+ if chunk == "data: [DONE]":
39
+ break
40
+ data = json.loads(chunk[5:].strip("\n"))
41
+
42
+ if return_logprob:
43
+ assert data["meta_info"]["input_token_logprobs"] is not None
44
+ assert data["meta_info"]["output_token_logprobs"] is not None
45
+ assert data["meta_info"]["normalized_prompt_logprob"] is not None
46
+ for logprob, token_id, token_text in data["meta_info"][
47
+ "output_token_logprobs"
48
+ ][prev:]:
49
+ print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
50
+ prev = len(data["meta_info"]["output_token_logprobs"])
51
+ else:
52
+ output = data["text"].strip()
53
+ print(output[prev:], end="", flush=True)
54
+ prev = len(output)
55
+
56
+ print("=" * 100)
57
+
58
+
59
+ if __name__ == "__main__":
60
+ parser = argparse.ArgumentParser()
61
+ parser.add_argument("--host", type=str, default="http://127.0.0.1")
62
+ parser.add_argument("--port", type=int, default=30000)
63
+ args = parser.parse_args()
64
+
65
+ url = f"{args.host}:{args.port}"
66
+
67
+ test_decode_stream(url, False, 0)
68
+ test_decode_stream(url, True, 0)
69
+ test_decode_stream(url, True, 3)
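
When return_logprob is set, each streamed chunk's meta_info carries output_token_logprobs as (logprob, token_id, token_text) triples, exactly as the loop above unpacks them. A small sketch that folds such triples into an average log-likelihood and a perplexity figure for the generated text:

# Sketch: summarize accumulated output_token_logprobs triples.
import math

def summarize_logprobs(output_token_logprobs):
    logprobs = [lp for lp, _token_id, _token_text in output_token_logprobs]
    avg = sum(logprobs) / max(len(logprobs), 1)
    return {"avg_logprob": avg, "perplexity": math.exp(-avg)}

print(summarize_logprobs([(-0.1, 450, "Paris"), (-0.05, 13, ".")]))  # example triples
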
sglang/scripts/deprecated/test_httpserver_llava.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Usage:
3
+ python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
4
+ python3 test_httpserver_llava.py
5
+
6
+ Output:
7
+ The image features a man standing on the back of a yellow taxi cab, holding
8
+ """
9
+
10
+ import argparse
11
+ import asyncio
12
+ import json
13
+
14
+ import aiohttp
15
+ import requests
16
+
17
+
18
+ async def send_request(url, data, delay=0):
19
+ await asyncio.sleep(delay)
20
+ async with aiohttp.ClientSession() as session:
21
+ async with session.post(url, json=data) as resp:
22
+ output = await resp.json()
23
+ return output
24
+
25
+
26
+ async def test_concurrent(args):
27
+ url = f"{args.host}:{args.port}"
28
+
29
+ response = []
30
+ for i in range(8):
31
+ response.append(
32
+ send_request(
33
+ url + "/generate",
34
+ {
35
+ "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
36
+ "image_data": "example_image.png",
37
+ "sampling_params": {
38
+ "temperature": 0,
39
+ "max_new_tokens": 64,
40
+ },
41
+ },
42
+ )
43
+ )
44
+
45
+ rets = await asyncio.gather(*response)
46
+ for ret in rets:
47
+ print(ret["text"])
48
+
49
+
50
+ def test_streaming(args):
51
+ url = f"{args.host}:{args.port}"
52
+
53
+ response = requests.post(
54
+ url + "/generate",
55
+ json={
56
+ "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
57
+ "image_data": "example_image.png",
58
+ "sampling_params": {
59
+ "temperature": 0,
60
+ "max_new_tokens": 128,
61
+ },
62
+ "stream": True,
63
+ },
64
+ stream=True,
65
+ )
66
+
67
+ prev = 0
68
+ for chunk in response.iter_lines(decode_unicode=False):
69
+ chunk = chunk.decode("utf-8")
70
+ if chunk and chunk.startswith("data:"):
71
+ if chunk == "data: [DONE]":
72
+ break
73
+ data = json.loads(chunk[5:].strip("\n"))
74
+ output = data["text"].strip()
75
+ print(output[prev:], end="", flush=True)
76
+ prev = len(output)
77
+ print("")
78
+
79
+
80
+ if __name__ == "__main__":
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument("--host", type=str, default="http://127.0.0.1")
83
+ parser.add_argument("--port", type=int, default=30000)
84
+ args = parser.parse_args()
85
+
86
+ asyncio.run(test_concurrent(args))
87
+
88
+ test_streaming(args)
sglang/scripts/deprecated/test_httpserver_reuse.py ADDED
@@ -0,0 +1,42 @@
1
+ """
2
+ python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
3
+
4
+ Output:
5
+ The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
6
+ """
7
+
8
+ import argparse
9
+
10
+ import requests
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--host", type=str, default="http://127.0.0.1")
15
+ parser.add_argument("--port", type=int, default=30000)
16
+ args = parser.parse_args()
17
+
18
+ url = f"{args.host}:{args.port}"
19
+
20
+ response = requests.post(
21
+ url + "/generate",
22
+ json={
23
+ "text": "The capital of France is",
24
+ "sampling_params": {
25
+ "temperature": 0,
26
+ "max_new_tokens": 32,
27
+ },
28
+ },
29
+ )
30
+ print(response.json())
31
+
32
+ response = requests.post(
33
+ url + "/generate",
34
+ json={
35
+ "text": "The capital of France is Paris.\nThe capital of the United States is",
36
+ "sampling_params": {
37
+ "temperature": 0,
38
+ "max_new_tokens": 32,
39
+ },
40
+ },
41
+ )
42
+ print(response.json())
sglang/scripts/deprecated/test_jump_forward.py ADDED
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ from enum import Enum
3
+
4
+ from pydantic import BaseModel, constr
5
+
6
+ import sglang as sgl
7
+ from sglang.srt.constrained import build_regex_from_object
8
+ from sglang.test.test_utils import (
9
+ add_common_sglang_args_and_parse,
10
+ select_sglang_backend,
11
+ )
12
+
13
+ IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
14
+
15
+ ip_jump_forward = (
16
+ r"The google's DNS sever address is "
17
+ + IP_REGEX
18
+ + r" and "
19
+ + IP_REGEX
20
+ + r". "
21
+ + r"The google's website domain name is "
22
+ + r"www\.(\w)+\.(\w)+"
23
+ + r"."
24
+ )
25
+
26
+
27
+ # fmt: off
28
+ @sgl.function
29
+ def regex_gen(s):
30
+ s += "Q: What is the IP address of the Google DNS servers?\n"
31
+ s += "A: " + sgl.gen(
32
+ "answer",
33
+ max_tokens=128,
34
+ temperature=0,
35
+ regex=ip_jump_forward,
36
+ )
37
+ # fmt: on
38
+
39
+ json_jump_forward = (
40
+ r"""The information about Hogwarts is in the following JSON format\.\n"""
41
+ + r"""\n\{\n"""
42
+ + r""" "name": "[\w\d\s]*",\n"""
43
+ + r""" "country": "[\w\d\s]*",\n"""
44
+ + r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
45
+ + r""" "population": [-+]?[0-9]+,\n"""
46
+ + r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
47
+ + r"""\}\n"""
48
+ )
49
+
50
+ # fmt: off
51
+ @sgl.function
52
+ def json_gen(s):
53
+ s += sgl.gen(
54
+ "json",
55
+ max_tokens=128,
56
+ temperature=0,
57
+ regex=json_jump_forward,
58
+ )
59
+ # fmt: on
60
+
61
+
62
+ class Weapon(str, Enum):
63
+ sword = "sword"
64
+ axe = "axe"
65
+ mace = "mace"
66
+ spear = "spear"
67
+ bow = "bow"
68
+ crossbow = "crossbow"
69
+
70
+
71
+ class Armor(str, Enum):
72
+ leather = "leather"
73
+ chainmail = "chainmail"
74
+ plate = "plate"
75
+
76
+
77
+ class Character(BaseModel):
78
+ name: constr(max_length=10)
79
+ age: int
80
+ armor: Armor
81
+ weapon: Weapon
82
+ strength: int
83
+
84
+
85
+ @sgl.function
86
+ def character_gen(s):
87
+ s += "Give me a character description who is a wizard.\n"
88
+ s += sgl.gen(
89
+ "character",
90
+ max_tokens=128,
91
+ temperature=0,
92
+ regex=build_regex_from_object(Character),
93
+ )
94
+
95
+
96
+ def main(args):
97
+ # Select backend
98
+ backend = select_sglang_backend(args)
99
+ sgl.set_default_backend(backend)
100
+
101
+ state = regex_gen.run(temperature=0)
102
+
103
+ print("=" * 20, "IP TEST", "=" * 20)
104
+ print(state.text())
105
+
106
+ state = json_gen.run(temperature=0)
107
+
108
+ print("=" * 20, "JSON TEST", "=" * 20)
109
+ print(state.text())
110
+
111
+ state = character_gen.run(temperature=0)
112
+
113
+ print("=" * 20, "CHARACTER TEST", "=" * 20)
114
+ print(state.text())
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+ args = add_common_sglang_args_and_parse(parser)
120
+ main(args)
121
+
122
+ # ==================== IP TEST ====================
123
+ # Q: What is the IP address of the Google DNS servers?
124
+ # A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
125
+ # ==================== JSON TEST ====================
126
+ # The information about Hogwarts is in the following JSON format.
127
+
128
+ # {
129
+ # "name": "Hogwarts School of Witchcraft and Wizardry",
130
+ # "country": "Scotland",
131
+ # "latitude": 55.566667,
132
+ # "population": 1000,
133
+ # "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
134
+ # }
135
+
136
+ # ==================== CHARACTER TEST ====================
137
+ # Give me a character description who is a wizard.
138
+ # { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
sglang/scripts/deprecated/test_robust.py ADDED
@@ -0,0 +1,132 @@
1
+ import argparse
2
+ import random
3
+ import string
4
+
5
+ from vllm.transformers_utils.tokenizer import get_tokenizer
6
+
7
+ import sglang as sgl
8
+ from sglang.test.test_utils import (
9
+ add_common_sglang_args_and_parse,
10
+ select_sglang_backend,
11
+ )
12
+
13
+ TOKENIZER = None
14
+ RANDOM_PREFILL_LEN = None
15
+ RANDOM_DECODE_LEN = None
16
+
17
+
18
+ def gen_prompt(token_num):
19
+ if RANDOM_PREFILL_LEN:
20
+ token_num = random.randint(1, token_num)
21
+
22
+ cha_set = string.ascii_letters + string.digits
23
+ ret = "".join(random.choices(cha_set, k=token_num))
24
+ while len(TOKENIZER(ret).input_ids) < token_num:
25
+ ret += random.choice(cha_set)
26
+
27
+ return ret
28
+
29
+
30
+ def robust_test_dfs(s, d, args, leaf_states):
31
+ if d == 0:
32
+ s += "END"
33
+ leaf_states.append(s)
34
+ return
35
+
36
+ s += gen_prompt(args.len_prefill)
37
+ forks = s.fork(args.num_fork)
38
+ for fork_s in forks:
39
+ fork_s += gen_prompt(args.len_prefill)
40
+ new_tokens = (
41
+ args.len_decode
42
+ if not RANDOM_DECODE_LEN
43
+ else random.randint(1, args.len_decode)
44
+ )
45
+ fork_s += sgl.gen(
46
+ max_tokens=new_tokens,
47
+ ignore_eos=True,
48
+ )
49
+
50
+ for fork_s in forks:
51
+ robust_test_dfs(fork_s, d - 1, args, leaf_states)
52
+
53
+
54
+ def robust_test_bfs(s, args, leaf_states):
55
+ old_forks = [s]
56
+ new_forks = []
57
+ for _ in range(args.depth):
58
+ for old_fork in old_forks:
59
+ old_fork += gen_prompt(args.len_prefill)
60
+ forks = old_fork.fork(args.num_fork)
61
+ for fork_s in forks:
62
+ fork_s += gen_prompt(args.len_prefill)
63
+ new_tokens = (
64
+ args.len_decode
65
+ if not RANDOM_DECODE_LEN
66
+ else random.randint(1, args.len_decode)
67
+ )
68
+ fork_s += sgl.gen(
69
+ max_tokens=new_tokens,
70
+ ignore_eos=True,
71
+ )
72
+ new_forks.extend(forks)
73
+
74
+ old_forks = new_forks
75
+ new_forks = []
76
+
77
+ for old_fork in old_forks:
78
+ old_fork += "END"
79
+ leaf_states.append(old_fork)
80
+
81
+
82
+ @sgl.function
83
+ def robust_test(s, args):
84
+ leaf_states = []
85
+ if args.mode == "bfs":
86
+ robust_test_bfs(s, args, leaf_states)
87
+ else:
88
+ robust_test_dfs(s, args.depth, args, leaf_states)
89
+ return leaf_states
90
+
91
+
92
+ def main(args):
93
+ backend = select_sglang_backend(args)
94
+
95
+ arguments = [{"args": args} for _ in range(args.num_req)]
96
+
97
+ states = robust_test.run_batch(
98
+ arguments, temperature=0, backend=backend, num_threads=args.parallel
99
+ )
100
+
101
+ with open(f"tmp_robust_{args.mode}.txt", "w") as f:
102
+ for state in states:
103
+ leaf_states = state.ret_value
104
+ for leaf_state in leaf_states:
105
+ assert leaf_state.text()[-3:] == "END"
106
+ f.write(leaf_state.text()[:-3] + "\n")
107
+
108
+
109
+ if __name__ == "__main__":
110
+ # fmt: off
111
+ parser = argparse.ArgumentParser()
112
+ parser.add_argument("--num-req", type=int, default=2)
113
+ parser.add_argument("--depth", type=int, default=3)
114
+ parser.add_argument("--num-fork", type=int, default=2)
115
+ parser.add_argument("--len-prefill", type=int, default=128)
116
+ parser.add_argument("--len-decode", type=int, default=128)
117
+ parser.add_argument("--random-prefill-len", action="store_true")
118
+ parser.add_argument("--random-decode-len", action="store_true")
119
+ parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
120
+ parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
121
+ parser.add_argument("--trust-remote-code", action="store_true")
122
+ parser.add_argument("--seed", type=int, default=42)
123
+ args = add_common_sglang_args_and_parse(parser)
124
+ # fmt: on
125
+
126
+ RANDOM_PREFILL_LEN = args.random_prefill_len
127
+ RANDOM_DECODE_LEN = args.random_decode_len
128
+ TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
129
+
130
+ random.seed(args.seed)
131
+
132
+ main(args)