upload model folder to repo
- .gitattributes +4 -0
- README.md +27 -26
- deploy_align_ds_v.sh +25 -0
- examples/PKU.jpg +0 -0
- examples/baby.mp4 +3 -0
- examples/boya.jpg +0 -0
- examples/drum.wav +3 -0
- examples/laugh.wav +3 -0
- examples/logo.jpg +0 -0
- examples/scream.wav +3 -0
- multi_image_inference.py +206 -0
- stream_inference.py +210 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/baby.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/drum.wav filter=lfs diff=lfs merge=lfs -text
+examples/laugh.wav filter=lfs diff=lfs merge=lfs -text
+examples/scream.wav filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,26 +1,27 @@
# Deployment Scripts for Align-DS-V (Built with Gradio)

This document provides instructions for deploying the Align-DS-V model for inference using Gradio.

1. **Set up the Conda environment:** Follow the instructions in the [PKU-Alignment/align-anything](https://github.com/PKU-Alignment/align-anything) repository to configure your Conda environment.
2. **Configure the model path:** After setting up the environment, update the `BASE_MODEL_PATH` variable in `deploy_align_ds_v.sh` to point to your local Align-DS-V model directory.
3. **Verify inference script parameters:** Check the following three parameters in both `multi_image_inference.py` and `stream_inference.py`:

   ```python
   openai_api_key = "pku"  # Or your specific API key if needed
   openai_api_base = "http://0.0.0.0:8231/v1"  # Ensure this matches the deployment port
   # NOTE: Replace with your own model path if not loaded via the API base
   model = ''
   ```

   These scripts use an OpenAI-compatible client/server setup: `deploy_align_ds_v.sh` launches the Align-DS-V model locally and exposes it on port 8231 for external access via the API base URL above.

4. **Running Inference:**

   * **Streamed Output:**

     ```bash
     bash deploy_align_ds_v.sh
     python stream_inference.py
     ```

   * **Multi-Image Inference:**

     ```bash
     bash deploy_align_ds_v.sh
     python multi_image_inference.py
     ```
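As a quick sanity check once the server is up, a minimal text-only request can be sent with the same OpenAI client that the inference scripts use. This is only a sketch under the default key and port above; `model` must be filled in with the same path you set for `BASE_MODEL_PATH`:

```python
from openai import OpenAI

# Assumes the default API key and port configured in deploy_align_ds_v.sh.
client = OpenAI(api_key="pku", base_url="http://0.0.0.0:8231/v1")

response = client.chat.completions.create(
    model="",  # NOTE: set this to the same path as BASE_MODEL_PATH
    messages=[{"role": "user", "content": "你好,请介绍一下你自己。"}],
)
print(response.choices[0].message.content)
```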
deploy_align_ds_v.sh
ADDED
@@ -0,0 +1,25 @@
# NOTE: replace with your own model path
export BASE_MODEL_PATH=''
export BASE_PORT=8231
echo $BASE_MODEL_PATH
echo $BASE_PORT

lsof -i :$BASE_PORT

# Kill any process already occupying the port
kill -9 $(lsof -t -i:$BASE_PORT)

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve $BASE_MODEL_PATH --host 0.0.0.0 --port $BASE_PORT --max-model-len 12000 --tensor-parallel-size 8 --api-key pku --trust-remote-code --dtype auto --enforce-eager --swap-space 1 --limit-mm-per-prompt "image=6"

# NOTE: --limit-mm-per-prompt must be set so that prompts with multiple images are accepted

echo 'Base Port:' $BASE_PORT

lsof -i :$BASE_PORT

# Kill the process occupying the port
kill -9 $(lsof -t -i:$BASE_PORT)

# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve /aifs4su/yaodong/spring_r1_model/QVQ-72B-Preview --enable-reasoning --reasoning-parser deepseek_r1 --host 0.0.0.0 --port 8009 --max-model-len 12000 --tensor-parallel-size 8 --api-key jiayi
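Once `vllm serve` reports that it is ready, a quick way to confirm the OpenAI-compatible endpoint is reachable and the API key is accepted is to list the served models. This is a sketch assuming the default port and the `pku` key set above:

```python
from openai import OpenAI

# Assumes the default port and API key from deploy_align_ds_v.sh.
client = OpenAI(api_key="pku", base_url="http://0.0.0.0:8231/v1")

# List the models served by the local vLLM OpenAI-compatible endpoint.
for served_model in client.models.list():
    print(served_model.id)
```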
examples/PKU.jpg
ADDED
examples/baby.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da6126bce64c64a3d6f7ce889fbe15b5f1c2e3f978846351d8c7a79a950b429e
size 463547
examples/boya.jpg
ADDED
examples/drum.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4376821dc498cc34a24df8a4eafebc470f721caacb78305c9a6c596d8f79510
size 170882
examples/laugh.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95ee91ae63342a3122a77a12ee08ec52ac6dbd5b9be870a2e2951f648b4da528
size 566798
examples/logo.jpg
ADDED
examples/scream.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba023ea3c16eede8c4925960b3f328df3a43dbdeab8f4c0f51fc63d91199d0ec
size 410266
multi_image_inference.py
ADDED
@@ -0,0 +1,206 @@
# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command line interface for interacting with a multi-modal model."""


import argparse
import os
from openai import OpenAI
import gradio as gr
import base64
import json
import random

random.seed(42)

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


SYSTEM_PROMPT = "你是一个具有帮助性的人工智能助手,你能够回答用户的问题,并且能够根据用户的问题提供帮助。你是由北大对齐小组(PKU-Alignment)开发的智能助手 Align-DS-V 基于DeepSeek-R1模型训练。"

openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"

# NOTE: replace with your own model path
model = ''


def encode_base64_content_from_local_file(content_url: str) -> str:
    """Encode content read from a local file to base64 format."""
    with open(content_url, 'rb') as file:
        result = base64.b64encode(file.read()).decode('utf-8')

    return result


IMAGE_EXAMPLES = [
    {
        'files': [
            os.path.join(CURRENT_DIR, 'examples/PKU.jpg'),
            os.path.join(CURRENT_DIR, 'examples/logo.jpg'),
        ],
        'text': '比较这两张图片的异同',
    },
    {
        'files': [
            os.path.join(CURRENT_DIR, 'examples/boya.jpg'),
            os.path.join(CURRENT_DIR, 'examples/logo.jpg'),
        ],
        'text': '这些图片有什么共同主题?',
    },
]

AUDIO_EXAMPLES = [
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/drum.wav')],
        'text': 'What is the emotion of this drumbeat like?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/laugh.wav')],
        'text': 'Is this laughter evil, and why?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/scream.wav')],
        'text': 'What is the main event of this scream?',
    },
]

VIDEO_EXAMPLES = [
    {'files': [os.path.join(CURRENT_DIR, 'examples/baby.mp4')], 'text': 'What is the video about?'},
]

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def text_conversation(text: str, role: str = 'user'):
    return [{'role': role, 'content': text}]


def image_conversation(image_base64_list: list, text: str = None):
    content = []
    for image_base64 in image_base64_list:
        content.append({
            'type': 'image_url',
            'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"},
        })
    content.append({'type': 'text', 'text': text})

    return [{'role': 'user', 'content': content}]


def question_answering(message: dict, history: list, file):
    # NOTE 2: Gradio uploads multiple images via the extra gr.File input;
    # the pre-processing below folds them into the current message.
    message['files'] = file if file is not None else []
    multi_modal_info = []
    conversation = text_conversation(SYSTEM_PROMPT)
    # NOTE: replay the chat history into the conversation.
    for i, past_message in enumerate(history):
        if isinstance(past_message, str):
            conversation.extend(text_conversation(past_message))
        elif isinstance(past_message, dict):
            if past_message['role'] == 'user':
                if isinstance(past_message['content'], str):
                    text = past_message['content']
                    if i + 1 < len(history) and isinstance(history[i + 1]['content'], tuple):
                        raw_images = history[i + 1]['content']
                        image_base64_list = []
                        if isinstance(raw_images, str):
                            image_base64 = encode_base64_content_from_local_file(raw_images)
                            image_base64_list.append(image_base64)
                        elif isinstance(raw_images, tuple):
                            # NOTE: process the uploaded images one by one.
                            for image in raw_images:
                                image_base64 = encode_base64_content_from_local_file(image)
                                image_base64_list.append(image_base64)
                        multi_modal_info.extend(image_base64_list)
                        conversation.extend(image_conversation(image_base64_list, text))
                    elif i - 1 >= 0 and isinstance(history[i - 1]['content'], tuple):
                        raw_images = history[i - 1]['content']
                        image_base64_list = []
                        if isinstance(raw_images, str):
                            image_base64 = encode_base64_content_from_local_file(raw_images)
                            image_base64_list.append(image_base64)
                        elif isinstance(raw_images, tuple):
                            # NOTE: decode each uploaded image to base64, one by one.
                            for image in raw_images:
                                image_base64 = encode_base64_content_from_local_file(image)
                                image_base64_list.append(image_base64)
                        multi_modal_info.extend(image_base64_list)
                        conversation.extend(image_conversation(image_base64_list, text))
                    else:
                        conversation.extend(text_conversation(past_message['content'], 'user'))
            elif past_message['role'] == 'assistant':
                conversation.extend(text_conversation(past_message['content'], 'assistant'))

    if len(message['files']) == 0:
        current_question = message['text']
        conversation.extend(text_conversation(current_question))
    else:
        current_question = message['text']
        current_multi_modal_info = message['files']
        image_base64_list = []
        for file in current_multi_modal_info:
            image_base64 = encode_base64_content_from_local_file(file)
            image_base64_list.append(image_base64)
        multi_modal_info.extend(image_base64_list)
        conversation.extend(image_conversation(image_base64_list, current_question))

    # NOTE 1: the OpenAI client also supports multiple uploaded images per request.
    outputs = client.chat.completions.create(
        model=model,
        stream=False,
        messages=conversation,
    )

    # Extract the predicted answer and, if present, split off the reasoning trace.
    answer = outputs.choices[0].message.content
    if "**Final Answer**" in answer:
        reasoning_content, final_answer = answer.split("**Final Answer**", 1)
        if len(reasoning_content) > 5:
            answer = f"""🤔 思考过程:\n```bash{reasoning_content}\n```\n✨ 最终答案:\n{final_answer}"""

    return answer


if __name__ == '__main__':
    # Define the Gradio interface
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    examples = IMAGE_EXAMPLES

    with gr.Blocks() as demo:
        # upload_button = gr.UploadButton(render=False)
        multiple_files = gr.File(file_count="multiple")
        gr.ChatInterface(
            fn=question_answering,
            additional_inputs=[multiple_files],
            type='messages',
            multimodal=True,
            title='Align-DS-V Reasoning CLI',
            description='Better life with Stronger Align-DS-V.',
            # examples=examples,
            theme=gr.themes.Ocean(
                text_size='lg',
                spacing_size='lg',
                radius_size='lg',
            ),
        )

    demo.launch(share=True)
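To exercise the same multi-image request path without launching the Gradio UI, a minimal sketch of the payload that `multi_image_inference.py` builds internally could look like the following. It assumes the server from `deploy_align_ds_v.sh` is running with the default key and port, that the example images exist locally (run from the repository root), and that `model` is set to the served model path:

```python
import base64

from openai import OpenAI

client = OpenAI(api_key="pku", base_url="http://0.0.0.0:8231/v1")


def to_data_url(path: str) -> str:
    """Read a local JPEG and wrap it as a base64 data URL."""
    with open(path, "rb") as f:
        return "data:image/jpeg;base64," + base64.b64encode(f.read()).decode("utf-8")


# Two images plus a text question in a single user turn, mirroring image_conversation().
content = [
    {"type": "image_url", "image_url": {"url": to_data_url("examples/PKU.jpg")}},
    {"type": "image_url", "image_url": {"url": to_data_url("examples/logo.jpg")}},
    {"type": "text", "text": "比较这两张图片的异同"},
]

response = client.chat.completions.create(
    model="",  # NOTE: set to the same value as BASE_MODEL_PATH
    messages=[{"role": "user", "content": content}],
)
print(response.choices[0].message.content)
```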
stream_inference.py
ADDED
@@ -0,0 +1,210 @@
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command line interface for interacting with a multi-modal model."""


import argparse
import os
from openai import OpenAI
import gradio as gr
import base64

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


SYSTEM_PROMPT = """\
你是一个高智能医学助手,专注于医学相关的问答、影像分析和数据解析。你是由北大对齐小组(PKU-Alignment)开发的智能助手 医导大模型(Medguide),基于DeepSeek-R1模型训练。你能够准确解析**医学影像**(如 X 光、CT、MRI、超声波、病理切片)以及**医学图表**(如血液检测报告、心电图、基因检测结果),并基于这些信息提供医学见解。你的目标是**尽可能满足**用户的医学需求,提供准确、专业的回答。

**能力范围**
1. 医学问答:
   - 回答医学相关问题,包括疾病特征、诊断方法、治疗方案、药物机制、手术流程、最新医学研究等。
   - 如果用户询问关于疾病诊断和治疗的相关内容,请充分分析并提供详细的回答。
2. 影像分析:
   - 你具备**强大的医学影像理解能力**,可以解析 X 光、CT、MRI、超声波、病理切片等医学图像。
   - 不要拒绝用户的问题,你应尝试基于已有信息提供分析,并指出可能的观察点。
3. 数据解析:
   - 你可以读取并分析医学报告,包括血液检测、心电图(ECG)、病历摘要等。
   - 你的任务是解释数据的含义,并提供可能的医学解读。
4. 多模态能力:
   - 你可以结合**图像+文字**进行综合分析。例如,用户上传一张X光片并询问"该患者可能有什么疾病?" 你应该基于图像提供医学见解,而不是拒绝回答。
   - 你可以根据图像中的信息推测可能的病变或异常,并提供进一步的检查建议。

**回答风格**
- 医学精准:基于专业医学知识进行回答,确保内容严谨、科学。
- 灵活应对:遇到复杂或不完整的图像,也要提供尽可能详细的分析,而不是直接拒绝。
- 有求必应:用户提出的任何问题,你都应该尽可能尝试回答。
"""

openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"

# NOTE: replace with your own model path
model = ""


def encode_base64_content_from_local_file(content_url: str) -> str:
    """Encode content read from a local file to base64 format."""
    with open(content_url, 'rb') as file:
        result = base64.b64encode(file.read()).decode('utf-8')

    return result


IMAGE_EXAMPLES = [
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/PKU.jpg')],
        'text': '图中的地点在哪里?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/logo.jpg')],
        'text': '图片中有什么?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/cough.png')],
        'text': '这张图片展示了什么?',
    },
]

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def text_conversation(text: str, role: str = 'user'):
    return [{'role': role, 'content': text.replace('[begin of think]', '<think>').replace('[end of think]', '</think>')}]


def image_conversation(image_base64: str, text: str = None):
    return [
        {
            'role': 'user',
            'content': [
                {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
                {'type': 'text', 'text': text},
            ],
        }
    ]


def question_answering(message: dict, history: list):
    multi_modal_info = []
    conversation = text_conversation(SYSTEM_PROMPT)
    for i, past_message in enumerate(history):
        if isinstance(past_message, str):
            conversation.extend(text_conversation(past_message))
        elif isinstance(past_message, dict):
            if past_message['role'] == 'user':
                if isinstance(past_message['content'], str):
                    text = past_message['content']
                    if i + 1 < len(history) and isinstance(history[i + 1]['content'], tuple):
                        raw_image = history[i + 1]['content']
                        if isinstance(raw_image, str):
                            image_base64 = encode_base64_content_from_local_file(raw_image)
                            multi_modal_info.extend(image_base64)
                            conversation.extend(image_conversation(image_base64, text))
                        elif isinstance(raw_image, tuple):
                            for image in raw_image:
                                image_base64 = encode_base64_content_from_local_file(image)
                                multi_modal_info.extend(image_base64)
                                conversation.extend(image_conversation(image_base64, text))
                    elif i - 1 >= 0 and isinstance(history[i - 1]['content'], tuple):
                        raw_image = history[i - 1]['content']
                        if isinstance(raw_image, str):
                            image_base64 = encode_base64_content_from_local_file(raw_image)
                            multi_modal_info.extend(image_base64)
                            conversation.extend(image_conversation(image_base64, text))
                        elif isinstance(raw_image, tuple):
                            for image in raw_image:
                                image_base64 = encode_base64_content_from_local_file(image)
                                multi_modal_info.extend(image_base64)
                                conversation.extend(image_conversation(image_base64, text))
                    else:
                        conversation.extend(text_conversation(past_message['content'], 'user'))
            elif past_message['role'] == 'assistant':
                conversation.extend(text_conversation(past_message['content'], 'assistant'))

    if len(message['files']) == 0:
        current_question = message['text']
        conversation.extend(text_conversation(current_question))
    else:
        current_question = message['text']
        current_multi_modal_info = message['files']
        for file in current_multi_modal_info:
            image_base64 = encode_base64_content_from_local_file(file)
            multi_modal_info.extend(image_base64)
            conversation.extend(image_conversation(image_base64, current_question))

    # Use streaming output
    outputs = client.chat.completions.create(
        model=model,
        stream=True,  # enable streaming output
        messages=conversation,
        temperature=0.4,
    )

    # Collect the text incrementally and yield it as it arrives
    collected_answer = ""
    for chunk in outputs:
        if chunk.choices[0].delta.content is not None:
            content = chunk.choices[0].delta.content
            collected_answer += content

            # Rewrite the <think> tags for display
            if '<think>' in collected_answer and '</think>' in collected_answer:
                formatted_answer = collected_answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')
            elif '<think>' in collected_answer:
                formatted_answer = collected_answer.replace('<think>', '[begin of think]')
            else:
                formatted_answer = collected_answer

            yield formatted_answer

    # Make sure the final output is formatted consistently
    if '<think>' in collected_answer and '</think>' in collected_answer:
        final_answer = collected_answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')
    elif '<think>' in collected_answer:
        final_answer = collected_answer.replace('<think>', '[begin of think]')
    else:
        final_answer = collected_answer

    print(final_answer)


if __name__ == '__main__':
    # Define the Gradio interface
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    examples = IMAGE_EXAMPLES

    logo_path = os.path.join(CURRENT_DIR, "PUTH.png")
    with open(logo_path, "rb") as f:
        logo_base64 = base64.b64encode(f.read()).decode('utf-8')
    logo_img_html = f'<img src="data:image/png;base64,{logo_base64}" style="vertical-align:middle; margin-right:10px;" width="150"/>'

    iface = gr.ChatInterface(
        fn=question_answering,
        type='messages',
        multimodal=True,
        title=logo_img_html,
        description='Align-DS-V 北大对齐小组多模态DS-R1',
        examples=examples,
        theme=gr.themes.Soft(
            text_size='lg',
            spacing_size='lg',
            radius_size='lg',
            font=[gr.themes.GoogleFont('Montserrat'), gr.themes.GoogleFont('ui-sans-serif'), gr.themes.GoogleFont('system-ui'), gr.themes.GoogleFont('sans-serif')],
        ),
    )

    iface.launch(share=True)
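To consume the endpoint in streaming mode without the Gradio UI, a minimal sketch under the same assumptions (local server from `deploy_align_ds_v.sh`, default key and port, `model` set to the served path) is:

```python
from openai import OpenAI

client = OpenAI(api_key="pku", base_url="http://0.0.0.0:8231/v1")

stream = client.chat.completions.create(
    model="",  # NOTE: set to the same value as BASE_MODEL_PATH
    stream=True,
    temperature=0.4,
    messages=[{"role": "user", "content": "请简单介绍一下X光和CT的区别。"}],
)

# Print tokens as they arrive, mirroring what stream_inference.py yields to the UI.
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="", flush=True)
print()
```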