# Source: 0407_align_ds_v/multi_image_inference.py
# Uploaded by alignmentforever ("upload model folder to repo"), commit ddfa14d (verified).
# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command line interface for interacting with a multi-modal model."""
import argparse
import os
from openai import OpenAI
import gradio as gr
import base64
import json
import random
# Seed the global `random` generator for reproducibility (no use of `random` is visible in this file view).
random.seed(42)
# Absolute directory of this script; the example asset paths below are built relative to it.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Instruction prompt prepended to every conversation. In English: "You are a helpful AI assistant
# that can answer users' questions and help based on them. You are Align-DS-V, an assistant
# developed by the PKU-Alignment group, trained from the DeepSeek-R1 model."
SYSTEM_PROMPT = "你是一个具有帮助性的人工智能助手,你能够回答用户的问题,并且能够根据用户的问题提供帮助。你是由北大对齐小组(PKU-Alignment)开发的智能助手 Align-DS-V 基于DeepSeek-R1模型训练。"
# API key and base URL of the OpenAI-compatible server that the `client` below talks to.
# NOTE(review): presumably a locally hosted server that does not validate the key — confirm.
openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"
# Model name sent with every chat completion request.
# NOTE replace with your own model path (the empty default is only a placeholder).
model = ''
def encode_base64_content_from_local_file(content_url: str) -> str:
    """Return the base64 encoding, as a ``str``, of the bytes stored in a local file.

    Despite its name, ``content_url`` is a filesystem path that is opened in
    binary mode, not a URL. Errors from ``open`` (e.g. a missing file)
    propagate to the caller.
    """
    with open(content_url, 'rb') as source:
        payload = source.read()
    # b64encode yields ASCII bytes; decode them so callers can embed the result in text.
    return base64.b64encode(payload).decode('utf-8')
# Sample prompts shipped with the demo: each dict pairs local example files ('files') with a
# question ('text'), the same keys used by the chat message dicts in question_answering.
# Image prompts (Chinese): "Compare the similarities and differences of these two images" and
# "What common theme do these images share?".
IMAGE_EXAMPLES = [
    {
        'files': [
            os.path.join(CURRENT_DIR, 'examples/PKU.jpg'),
            os.path.join(CURRENT_DIR, 'examples/logo.jpg')
        ],
        'text': '比较这两张图片的异同',
    },
    {
        'files': [
            os.path.join(CURRENT_DIR, 'examples/boya.jpg'),
            os.path.join(CURRENT_DIR, 'examples/logo.jpg')
        ],
        'text': '这些图片有什么共同主题?',
    },
]
# NOTE(review): AUDIO_EXAMPLES and VIDEO_EXAMPLES are not referenced anywhere in the visible code,
# and image_conversation only builds image parts — presumably leftovers from another demo.
AUDIO_EXAMPLES = [
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/drum.wav')],
        'text': 'What is the emotion of this drumbeat like?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/laugh.wav')],
        'text': 'Is this laughter evil, and why?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/scream.wav')],
        'text': 'What is the main event of this scream?',
    },
]
VIDEO_EXAMPLES = [
    {'files': [os.path.join(CURRENT_DIR, 'examples/baby.mp4')], 'text': 'What is the video about?'},
]
# Single OpenAI-compatible client shared by every chat request (see question_answering).
client = OpenAI(base_url=openai_api_base, api_key=openai_api_key)
def text_conversation(text: str, role: str = 'user'):
    """Wrap ``text`` as a one-message chat list: ``[{'role': role, 'content': text}]``."""
    single_message = dict(role=role, content=text)
    return [single_message]
def image_conversation(image_base64_list: list, text: str = None):
    """Build a single user message holding several images followed by one text part.

    Args:
        image_base64_list: Base64-encoded image payloads, emitted in the given order
            as ``image_url`` parts using ``data:`` URLs.
        text: Text appended as the final ``text`` part (passed through unchanged,
            even when ``None``).

    Returns:
        A one-element list containing the OpenAI-style user message.
    """
    # NOTE(review): every image is labelled image/jpeg regardless of its real format;
    # presumably the server decodes by content rather than by MIME type — confirm.
    parts = [
        {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{encoded}"}}
        for encoded in image_base64_list
    ]
    parts.append({'type': 'text', 'text': text})
    return [{'role': 'user', 'content': parts}]
def _file_tuple_run(history: list, start: int, step: int) -> list:
    """Collect file paths from consecutive tuple-content messages beginning at ``history[start]``.

    Walks forward (``step=1``) or backward (``step=-1``) for as long as entries are
    dicts whose ``'content'`` is a tuple, and returns their elements flattened in
    chronological order. Returns ``[]`` when ``start`` is out of range or the entry
    there is not a tuple-content dict.
    """
    chunks = []
    index = start
    while 0 <= index < len(history):
        neighbour = history[index]
        # Guard the type first: plain-str history entries have no 'content' key.
        if not isinstance(neighbour, dict) or not isinstance(neighbour.get('content'), tuple):
            break
        chunks.append(neighbour['content'])
        index += step
    if step < 0:
        chunks.reverse()  # restore chronological order after walking backwards
    # NOTE(review): every tuple element is treated as an image path; if a tuple ever
    # carries alt text as a second element, opening it will fail — confirm.
    return [path for chunk in chunks for path in chunk]


def _encode_image_files(paths) -> list:
    """Base64-encode each local file in ``paths``, preserving order."""
    return [encode_base64_content_from_local_file(path) for path in paths]


def _history_to_conversation(history: list) -> list:
    """Convert Gradio ``type='messages'`` chat history into OpenAI chat messages.

    - Plain ``str`` entries become user text messages.
    - A user dict with ``str`` content is paired with the run of tuple-content
      (file) messages right after it or, failing that, right before it. That run
      may span several messages, one per uploaded file. With a run, an image
      message is emitted; without one, a text message.
    - Assistant dicts become assistant text messages.
    - Tuple-content user messages are consumed through their paired text. A file
      run with no adjacent user text (an images-only turn) is not sent.
    """
    conversation = []
    for i, past_message in enumerate(history):
        if isinstance(past_message, str):
            conversation.extend(text_conversation(past_message))
            continue
        if not isinstance(past_message, dict):
            continue
        role = past_message.get('role')
        content = past_message.get('content')
        if role == 'user' and isinstance(content, str):
            raw_images = _file_tuple_run(history, i + 1, 1) or _file_tuple_run(history, i - 1, -1)
            if raw_images:
                conversation.extend(image_conversation(_encode_image_files(raw_images), content))
            else:
                conversation.extend(text_conversation(content, 'user'))
        elif role == 'assistant':
            conversation.extend(text_conversation(content, 'assistant'))
    return conversation


def _format_answer(answer) -> str:
    """Render a raw model reply for the chat UI.

    If the reply contains ``**Final Answer**`` and more than 5 characters precede it,
    the preceding part is shown in a fenced "thinking" block followed by the final
    answer. Otherwise the reply is returned unchanged. ``None`` is treated as ``''``.
    """
    if answer is None:
        # message.content may be null; avoid a TypeError on the `in` test below.
        answer = ''
    marker = "**Final Answer**"
    if marker not in answer:
        return answer
    reasoning_content, final_answer = answer.split(marker, 1)
    if len(reasoning_content) <= 5:
        return answer
    # A fence's info string runs to the end of its line, so reasoning that did not
    # start with a newline lost its first line into "```bash...". Force a line break.
    if not reasoning_content.startswith('\n'):
        reasoning_content = '\n' + reasoning_content
    return f"""🤔 思考过程:\n```bash{reasoning_content}\n```\n✨ 最终答案:\n{final_answer}"""


def question_answering(message: dict, history: list, file):
    """Gradio ChatInterface callback: query the model with the full multi-image conversation.

    Args:
        message: Current turn as a dict with a ``'text'`` key. Its ``'files'`` entry
            is overwritten in place with the paths taken from ``file``.
        history: Previous turns in Gradio "messages" format (see
            ``_history_to_conversation``).
        file: Value of the extra ``gr.File(file_count="multiple")`` input, presumably
            a list of local file paths, or ``None`` when nothing is uploaded.

    Returns:
        The model reply, formatted by ``_format_answer``.
    """
    # NOTE(review): files attached through the multimodal textbox are replaced here by the
    # separate File component's files for the current turn — confirm this is intended.
    message['files'] = file if file is not None else []

    # NOTE(review): the system prompt is sent with the 'user' role rather than 'system';
    # left as-is in case the model was tuned that way.
    conversation = text_conversation(SYSTEM_PROMPT)
    conversation.extend(_history_to_conversation(history))

    current_question = message['text']
    if message['files']:
        image_base64_list = _encode_image_files(message['files'])
        conversation.extend(image_conversation(image_base64_list, current_question))
    else:
        conversation.extend(text_conversation(current_question))

    outputs = client.chat.completions.create(
        model=model,
        stream=False,
        messages=conversation,
    )
    return _format_answer(outputs.choices[0].message.content)
if __name__ == '__main__':
    # Parse launch options; the defaults reproduce the original behaviour (public share link).
    parser = argparse.ArgumentParser(description='Launch the Align-DS-V multi-image chat web UI.')
    parser.add_argument(
        '--share',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Create a public Gradio share link (default: on; use --no-share to disable).',
    )
    parser.add_argument('--server-name', default=None, help='Interface to bind (Gradio default if omitted).')
    parser.add_argument('--server-port', type=int, default=None, help='Port to listen on (Gradio default if omitted).')
    args = parser.parse_args()
    examples = IMAGE_EXAMPLES  # available for the currently disabled `examples=` option below

    # Gradio styles the page with the theme of the Blocks that is launched, so the theme
    # belongs on the outer Blocks rather than on the nested ChatInterface.
    theme = gr.themes.Ocean(
        text_size='lg',
        spacing_size='lg',
        radius_size='lg',
    )
    with gr.Blocks(theme=theme) as demo:
        # Separate multi-file uploader; its value is passed to question_answering as `file`.
        multiple_files = gr.File(file_count="multiple")
        gr.ChatInterface(
            fn=question_answering,
            additional_inputs=[multiple_files],
            type='messages',
            multimodal=True,
            title='Align-DS-V Reasoning CLI',
            description='Better life with Stronger Align-DS-V.',
            # examples=examples,
        )
    demo.launch(share=args.share, server_name=args.server_name, server_port=args.server_port)