0407_align_ds_v / stream_inference.py
alignmentforever's picture
upload model folder to repo
ddfa14d verified
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command line interface for interacting with a multi-modal model."""
import argparse
import os
from openai import OpenAI
import gradio as gr
import base64
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
SYSTEM_PROMPT = """\
你是一个高智能医学助手,专注于医学相关的问答、影像分析和数据解析。你是由北大对齐小组(PKU-Alignment)开发的智能助手 医导大模型(Medguide),基于DeepSeek-R1模型训练。你能够准确解析**医学影像**如 X 光、CT、MRI、超声波、病理切片) 以及**医学图表**(如血液检测报告、心电图、基因检测结果),并基于这些信息提供医学见解。你的目标是**尽可能满足**用户的医学需求,提供准确、专业的回答。
**能力范围**
1. 医学问答:
- 回答医学相关问题,包括疾病特征、诊断方法、治疗方案、药物机制、手术流程、最新医学研究等。
- 如果用户询问关于疾病诊断和治疗的相关内容,请充分分析并提供详细的回答。
2. 影像分析:
- 你具备**强大的医学影像理解能力**,可以解析 X 光、CT、MRI、超声波、病理切片等医学图像。
- 不要拒绝用户的问题,你应尝试基于已有信息提供分析,并指出可能的观察点。
3. 数据解析:
- 你可以读取并分析医学报告,包括血液检测、心电图(ECG)、病历摘要等。
- 你的任务是解释数据的含义,并提供可能的医学解读。
4. 多模态能力:
- 你可以结合**图像+文字**进行综合分析。例如,用户上传一张X光片并询问"该患者可能有什么疾病?" 你应该基于图像提供医学见解,而不是拒绝回答。
- 你可以根据图像中的信息推测可能的病变或异常,并提供进一步的检查建议。
**回答风格**
- 医学精准:基于专业医学知识进行回答,确保内容严谨、科学。
- 灵活应对:遇到复杂或不完整的图像,也要提供尽可能详细的分析,而不是直接拒绝。
- 有求必应:用户提出的任何问题,你都应该尽可能尝试回答。
"""
openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"
# NOTE replace with your own model path
model = ""
def encode_base64_content_from_local_file(content_url: str) -> str:
"""Encode a content retrieved from a local file to base64 format."""
with open(content_url, 'rb') as file:
result = base64.b64encode(file.read()).decode('utf-8')
return result
IMAGE_EXAMPLES = [
{
'files': [os.path.join(CURRENT_DIR, 'examples/PKU.jpg')],
'text': '图中的地点在哪里?',
},
{
'files': [os.path.join(CURRENT_DIR, 'examples/logo.jpg')],
'text': '图片中有什么?',
},
{
'files': [os.path.join(CURRENT_DIR, 'examples/cough.png')],
'text': '这张图片展示了什么?',
},
]
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
def text_conversation(text: str, role: str = 'user'):
return [{'role': role, 'content': text.replace('[begin of think]', '<think>').replace('[end of think]', '</think>')}]
def image_conversation(image_base64: str, text: str = None):
return [
{
'role': 'user',
'content': [
{'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
{'type': 'text', 'text': text}
]
}
]
def question_answering(message: dict, history: list):
multi_modal_info = []
conversation = text_conversation(SYSTEM_PROMPT)
for i, past_message in enumerate(history):
if isinstance(past_message, str):
conversation.extend(text_conversation(past_message))
elif isinstance(past_message, dict):
if past_message['role'] == 'user':
if isinstance(past_message['content'], str):
text = past_message['content']
if i + 1 < len(history) and isinstance(history[i + 1]['content'], tuple):
raw_image = history[i + 1]['content']
if isinstance(raw_image, str):
image_base64 = encode_base64_content_from_local_file(raw_image)
multi_modal_info.extend(image_base64)
conversation.extend(image_conversation(image_base64, text))
elif isinstance(raw_image, tuple):
for image in raw_image:
image_base64 = encode_base64_content_from_local_file(image)
multi_modal_info.extend(image_base64)
conversation.extend(image_conversation(image_base64, text))
elif i - 1 >= 0 and isinstance(history[i - 1]['content'], tuple):
raw_image = history[i - 1]['content']
if isinstance(raw_image, str):
image_base64 = encode_base64_content_from_local_file(raw_image)
multi_modal_info.extend(image_base64)
conversation.extend(image_conversation(image_base64, text))
elif isinstance(raw_image, tuple):
for image in raw_image:
image_base64 = encode_base64_content_from_local_file(image)
multi_modal_info.extend(image_base64)
conversation.extend(image_conversation(image_base64, text))
else:
conversation.extend(text_conversation(past_message['content'], 'user'))
elif past_message['role'] == 'assistant':
conversation.extend(text_conversation(past_message['content'], 'assistant'))
if len(message['files']) == 0:
current_question = message['text']
conversation.extend(text_conversation(current_question))
else:
current_question = message['text']
current_multi_modal_info = message['files']
for file in current_multi_modal_info:
image_base64 = encode_base64_content_from_local_file(file)
multi_modal_info.extend(image_base64)
conversation.extend(image_conversation(image_base64, current_question))
# 修改为流式输出
outputs = client.chat.completions.create(
model=model,
stream=True, # 启用流式输出
messages=conversation,
temperature=0.4
)
# 逐步收集并返回文本
collected_answer = ""
for chunk in outputs:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
collected_answer += content
# 处理思考标签
if '<think>' in collected_answer and '</think>' in collected_answer:
formatted_answer = collected_answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')
elif '<think>' in collected_answer:
formatted_answer = collected_answer.replace('<think>', '[begin of think]')
else:
formatted_answer = collected_answer
yield formatted_answer
# 确保最终输出格式正确
if '<think>' in collected_answer and '</think>' in collected_answer:
final_answer = collected_answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')
elif '<think>' in collected_answer:
final_answer = collected_answer.replace('<think>', '[begin of think]')
else:
final_answer = collected_answer
print(final_answer)
if __name__ == '__main__':
# Define the Gradio interface
parser = argparse.ArgumentParser()
args = parser.parse_args()
examples = IMAGE_EXAMPLES
logo_path = os.path.join(CURRENT_DIR, "PUTH.png")
with open(logo_path, "rb") as f:
logo_base64 = base64.b64encode(f.read()).decode('utf-8')
logo_img_html = f'<img src="data:image/png;base64,{logo_base64}" style="vertical-align:middle; margin-right:10px;" width="150"/>'
iface = gr.ChatInterface(
fn=question_answering,
type='messages',
multimodal=True,
title=logo_img_html,
description='Align-DS-V 北大对齐小组多模态DS-R1',
examples=examples,
theme=gr.themes.Soft(
text_size='lg',
spacing_size='lg',
radius_size='lg',
font=[gr.themes.GoogleFont('Montserrat'), gr.themes.GoogleFont('ui-sans-serif'), gr.themes.GoogleFont('system-ui'), gr.themes.GoogleFont('sans-serif')],
),
)
iface.launch(share=True)