"""Command line interface for interacting with a multi-modal model.""" |
|
|
|
|
|
import argparse |
|
import os |
|
from openai import OpenAI |
|
import gradio as gr |
|
import base64 |
|
|
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
|
|
SYSTEM_PROMPT = """\
You are a highly capable medical assistant specializing in medical question answering, imaging analysis, and data interpretation. You are Medguide (医导大模型), an intelligent assistant developed by the PKU-Alignment team, trained on the DeepSeek-R1 model. You can accurately interpret **medical images** (such as X-rays, CT, MRI, ultrasound, and pathology slides) as well as **medical charts** (such as blood test reports, ECGs, and genetic test results), and you provide medical insights based on this information. Your goal is to **satisfy the user's medical needs as fully as possible** with accurate, professional answers.

**Scope of abilities**
1. Medical Q&A:
   - Answer medical questions, including disease characteristics, diagnostic methods, treatment options, drug mechanisms, surgical procedures, and recent medical research.
   - If the user asks about diagnosis or treatment, analyze the question thoroughly and give a detailed answer.
2. Imaging analysis:
   - You have **strong medical imaging comprehension** and can interpret X-rays, CT, MRI, ultrasound, pathology slides, and other medical images.
   - Do not refuse the user's question; try to provide an analysis based on the available information and point out possible observations.
3. Data interpretation:
   - You can read and analyze medical reports, including blood tests, electrocardiograms (ECG), and medical record summaries.
   - Your task is to explain what the data means and offer possible medical interpretations.
4. Multi-modal ability:
   - You can combine **images and text** for integrated analysis. For example, if a user uploads an X-ray and asks "What disease might this patient have?", you should provide medical insight based on the image rather than refuse to answer.
   - You can infer possible lesions or abnormalities from the image and suggest further examinations.

**Answer style**
- Medically precise: ground answers in professional medical knowledge and keep the content rigorous and scientific.
- Flexible: for complex or incomplete images, still give as detailed an analysis as possible instead of refusing outright.
- Responsive: make an attempt at every question the user raises.
"""

# Endpoint of the locally served OpenAI-compatible API (e.g. a vLLM server).
openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"

# Name of the served model; it must match the name the server advertises
# (an empty string only works if the server ignores the model field).
model = ""


def encode_base64_content_from_local_file(content_url: str) -> str:
    """Read a local file and return its contents encoded as base64."""
    with open(content_url, 'rb') as file:
        return base64.b64encode(file.read()).decode('utf-8')

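# Example usage (assuming the bundled sample image exists next to this script):
#   encode_base64_content_from_local_file(
#       os.path.join(CURRENT_DIR, 'examples/PKU.jpg'))[:16]
#   -> the first base64 characters of the JPEG bytes
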
IMAGE_EXAMPLES = [
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/PKU.jpg')],
        'text': 'Where was this picture taken?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/logo.jpg')],
        'text': 'What is in this picture?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/cough.png')],
        'text': 'What does this picture show?',
    },
]

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

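
# A minimal connectivity check for the endpoint configured above. This is a
# sketch: it assumes the server is already running and that `model` is set to
# the name the server advertises. `check_server_connection` is not part of the
# original flow and is never called automatically.
def check_server_connection() -> bool:
    """Return True if the chat endpoint answers a trivial one-token request."""
    try:
        client.chat.completions.create(
            model=model,
            messages=[{'role': 'user', 'content': 'ping'}],
            max_tokens=1,
        )
        return True
    except Exception:
        return False

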
def text_conversation(text: str, role: str = 'user'):
    """Wrap plain text in a one-message conversation, mapping the UI's
    think markers back to the model's native <think> tags."""
    return [{'role': role, 'content': text.replace('[begin of think]', '<think>').replace('[end of think]', '</think>')}]


def image_conversation(image_base64: str, text: str | None = None):
    """Wrap a base64-encoded image (plus optional text) in an OpenAI-style user message."""
    return [
        {
            'role': 'user',
            'content': [
                {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
                {'type': 'text', 'text': text},
            ],
        }
    ]

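# For reference, image_conversation('<B64>', 'What does this show?') produces
# the OpenAI-compatible multi-modal message shape:
#   [{'role': 'user',
#     'content': [{'type': 'image_url',
#                  'image_url': {'url': 'data:image/jpeg;base64,<B64>'}},
#                 {'type': 'text', 'text': 'What does this show?'}]}]

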
def question_answering(message: dict, history: list):
    """Build an OpenAI-style conversation from the chat history plus the new
    message, then stream the model's answer back to Gradio."""
    multi_modal_info = []
    conversation = text_conversation(SYSTEM_PROMPT)
    for i, past_message in enumerate(history):
        if isinstance(past_message, str):
            conversation.extend(text_conversation(past_message))
        elif isinstance(past_message, dict):
            if past_message['role'] == 'user':
                content = past_message['content']
                if isinstance(content, tuple):
                    # Uploaded files arrive as a separate history entry whose
                    # content is a tuple of paths; they are attached to the
                    # neighbouring text message below, so skip them here.
                    continue
                # Look for file uploads adjacent to this text message,
                # whichever side of it they landed on.
                raw_image = None
                if i + 1 < len(history) and isinstance(history[i + 1], dict) and isinstance(history[i + 1].get('content'), tuple):
                    raw_image = history[i + 1]['content']
                elif i - 1 >= 0 and isinstance(history[i - 1], dict) and isinstance(history[i - 1].get('content'), tuple):
                    raw_image = history[i - 1]['content']
                if raw_image is not None:
                    for image in raw_image:
                        image_base64 = encode_base64_content_from_local_file(image)
                        multi_modal_info.append(image_base64)
                        conversation.extend(image_conversation(image_base64, content))
                else:
                    conversation.extend(text_conversation(content, 'user'))
            elif past_message['role'] == 'assistant':
                conversation.extend(text_conversation(past_message['content'], 'assistant'))

    if len(message['files']) == 0:
        conversation.extend(text_conversation(message['text']))
    else:
        current_question = message['text']
        for file in message['files']:
            image_base64 = encode_base64_content_from_local_file(file)
            multi_modal_info.append(image_base64)
            conversation.extend(image_conversation(image_base64, current_question))

    outputs = client.chat.completions.create(
        model=model,
        stream=True,
        messages=conversation,
        temperature=0.4,
    )

    def format_think_tags(answer: str) -> str:
        # Surface the model's reasoning markers in the display form used by the UI.
        return answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')

    collected_answer = ""
    for chunk in outputs:
        if chunk.choices[0].delta.content is not None:
            collected_answer += chunk.choices[0].delta.content
        # Yield the partial answer after every chunk so the UI streams it.
        yield format_think_tags(collected_answer)

    # Log the final answer on the server side.
    print(format_think_tags(collected_answer))
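
# Note: gr.ChatInterface consumes question_answering as a generator; each
# yielded string replaces the assistant's in-progress message, which produces
# the streaming effect in the UI.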

if __name__ == '__main__':
    # Placeholder parser: no command-line options are defined yet.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    examples = IMAGE_EXAMPLES

    # Inline the logo as a base64 data URI so the title renders without a
    # static file route.
    logo_path = os.path.join(CURRENT_DIR, "PUTH.png")
    with open(logo_path, "rb") as f:
        logo_base64 = base64.b64encode(f.read()).decode('utf-8')
    logo_img_html = f'<img src="data:image/png;base64,{logo_base64}" style="vertical-align:middle; margin-right:10px;" width="150"/>'

    iface = gr.ChatInterface(
        fn=question_answering,
        type='messages',
        multimodal=True,
        title=logo_img_html,
        description='Align-DS-V: multi-modal DS-R1 from the PKU-Alignment team',
        examples=examples,
        theme=gr.themes.Soft(
            text_size='lg',
            spacing_size='lg',
            radius_size='lg',
            # Only Montserrat is fetched from Google Fonts; the generic
            # families are CSS fallbacks and are passed as plain strings.
            font=[gr.themes.GoogleFont('Montserrat'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
        ),
    )

    iface.launch(share=True)
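    # Backend sketch (assumption: a vLLM OpenAI-compatible server backs the
    # endpoint configured at the top of this file), roughly:
    #   vllm serve <MODEL_PATH> --port 8231
    # then set `model` above to the served model name.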