KaiChen1998 committed
Commit 5f870ca · 1 Parent(s): 6f2ad11

initial commit

app.py ADDED
@@ -0,0 +1,311 @@
+ import os
+ import argparse
+ import traceback
+ import logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+
+ import spaces
+ import gradio as gr
+ from conversation_public import default_conversation
+
+ auth_token = os.environ.get("TOKEN_FROM_SECRET")
+
+ ##########################################
+ # LLM part
+ ##########################################
+ from transformers import AutoProcessor, AutoTokenizer, TextIteratorStreamer
+ from vllm import LLM, SamplingParams
+ from qwen_vl_utils import process_vision_info
+ from threading import Thread
+
+ # === Prompts ===
+ SYSTEM_PROMPT_LLM = "You are a helpful assistant."
+ SYSTEM_PROMPT_CAP = "You are given an image and a relevant question. Based on the query, please describe the image in detail. Do not try to answer the question."
+
+ CAPTION_PROMPT = "Question: {}\nPlease describe the image. DO NOT try to answer the question!"
+ LLM_PROMPT = """In the following text, you will receive a detailed caption of an image and a relevant question. In addition, you will be provided with a tentative model response. Your goal is to answer the question using this information.\n\n### The detailed caption of the provided image: {}\n\n### Note that the caption might contain incorrect solutions, do not be misguided by them.\n\n### A problem to be solved: {}\n\n### A tentative model response: {}\n\n### Note that the above tentative response might be inaccurate (due to calculation errors, incorrect logic/reasoning and so on), under such a case, please ignore it and give your own solutions. However, if you do not have enough evidence to show it is wrong, please output the tentative response."""
+
+ # === Initialize Models ===
+ MLLM_MODEL_PATH = "KaiChen1998/RACRO-7B-CRO-GRPO"
+ LLM_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
+
+ processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
+
+ mllm = LLM(model=MLLM_MODEL_PATH, tensor_parallel_size=1, gpu_memory_utilization=0.8,
+            device='cuda:0', dtype="bfloat16", limit_mm_per_prompt={"image": 1})
+
+ llm = LLM(model=LLM_MODEL_PATH, tensor_parallel_size=1, gpu_memory_utilization=0.8,
+           device='cuda:0', dtype="bfloat16")
+
+ mllm_sampling = SamplingParams(temperature=0, max_tokens=8192)
+ llm_sampling = SamplingParams(temperature=0.6, top_p=0.95, max_tokens=8192)
+
+ # === Build Prompts ===
+ def build_messages(image_path, question):
+     cap_msgs = [
+         {"role": "system", "content": SYSTEM_PROMPT_CAP},
+         {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": CAPTION_PROMPT.format(question)}]}
+     ]
+     qa_msgs = [
+         {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": question + " Please think step by step. The final answer MUST BE put in \\boxed{}."}]}
+     ]
+     return cap_msgs, qa_msgs
+
+ # === Run Captioning and QA ===
+ def run_mllm_tentative(image_tensor, cap_prompt, qa_prompt):
+     qa_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": qa_prompt[0]}], sampling_params=mllm_sampling)
+     return qa_output[0].outputs[0].text
+
+ def run_mllm_caption(image_tensor, cap_prompt, qa_prompt):
+     cap_output = mllm.generate([{"multi_modal_data": {"image": image_tensor}, "prompt": cap_prompt[0]}], sampling_params=mllm_sampling)
+     return cap_output[0].outputs[0].text
+
+ # === Final Reasoning Step ===
+ def run_llm_reasoning(caption, question, answer):
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT_LLM},
+         {"role": "user", "content": LLM_PROMPT.format(caption, question, answer)}
+     ]
+     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     output = llm.generate([{"prompt": prompt}], sampling_params=llm_sampling)
+     return output[0].outputs[0].text
+
+ ##########################################
+ # Gradio part
+ ##########################################
+ no_change_btn = gr.Button()
+ enable_btn = gr.Button(interactive=True)
+ disable_btn = gr.Button(interactive=False)
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS AND REGENERATE.**"
+
+ def load_demo_refresh_model_list():
+     logging.info(f"load_demo.")
+     state = default_conversation.copy()
+     return state
+
+ def regenerate(state, image_process_mode):
+     logging.info(f"regenerate.")
+     state.messages[-1][-1] = None
+     prev_human_msg = state.messages[-2]
+     if type(prev_human_msg[1]) in (tuple, list):
+         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
+
+ def clear_history():
+     logging.info(f"clear_history.")
+     state = default_conversation.copy()
+     return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
+
+ ############
+ # Show prompt in the chatbot
+ # Input: [state, textbox, imagebox, image_process_mode]
+ # Return: [state, chatbot, textbox, imagebox] + btn_list
+ ############
+ def add_text(state, text, image, image_process_mode):
+     # Input legality checking
+     logging.info(f"add_text. len: {len(text)}")
+     if len(text) <= 0 or image is None:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2
+
+     # Deal with image inputs
+     if image is not None:
+         text = (text, image, image_process_mode, None)
+
+     # Single round only
+     state = default_conversation.copy()
+     state.append_message(state.roles[0], text)
+     state.skip_next = False
+     logging.info(str(state.messages))
+     return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
+
+ ############
+ # Get response
+ # Input: [state]
+ # Return: [state, chatbot] + btn_list
+ ############
+ @spaces.GPU
+ def http_bot(state):
+     logging.info(f"http_bot.")
+
+     if state.skip_next:
+         yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
+         return
+
+     # Retrieve prompt
+     prompt = state.messages[-1][1][0]  # add_text stores a (text, image, image_process_mode, audio) tuple
+     all_images = state.get_images(return_pil=True)[0]
+     pload = {"prompt": prompt, "images": f'List of {len(state.get_images())} images: {all_images}'}
+     logging.info(f"==== request ====\n{pload}")
+
+     # Construct prompt
+     cap_msgs, qa_msgs = build_messages(all_images, prompt)
+     cap_prompt = processor.apply_chat_template([cap_msgs], tokenize=False, add_generation_prompt=True)
+     qa_prompt = processor.apply_chat_template([qa_msgs], tokenize=False, add_generation_prompt=True)
+
+     image_tensor, _ = process_vision_info(cap_msgs)
+     tentative_answer = run_mllm_tentative(image_tensor, cap_prompt, qa_prompt)
+     state.append_message(state.roles[1], "# Tentative Response\n\n" + tentative_answer)
+     logging.info("# Tentative Response\n\n" + tentative_answer)
+     yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
+
+     caption_text = run_mllm_caption(image_tensor, cap_prompt, qa_prompt)
+     state.append_message(state.roles[1], "# Caption\n\n" + caption_text)
+     logging.info("# Caption\n\n" + caption_text)
+     yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
+
+     final_answer = run_llm_reasoning(caption_text, prompt, tentative_answer)
+     state.append_message(state.roles[1], "# Final Response\n\n" + final_answer)
+     logging.info("# Final Response\n\n" + final_answer)
+     yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
+
+ ############
+ # Layout Markdown
+ ############
+ title_markdown = ("""
+ <div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
+   <div>
+     <h1 style="margin: 0;">RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning</h1>
+     <h2 style="margin: 10px 0;">📃 <a href="https://www.arxiv.org/abs/2506.04559" style="font-weight: 400;">Paper</a> | 💻 <a href="https://github.com/gyhdog99/RACRO2" style="font-weight: 400;">Code</a> | 🤗 <a href="https://huggingface.co/collections/KaiChen1998/racro-6848ec8c65b3a0bf33d0fbdb" style="font-weight: 400;">HuggingFace</a></h2>
+     <p style="margin: 20px 0;">
+       <strong>1. RACRO is designed for multi-modal reasoning, and thus, image inputs are <mark>ALWAYS</mark> necessary!</strong><br/>
+       <strong>2. Models are deployed with vLLM, which unfortunately, still does not support streaming outputs for MLLMs.</strong>
+     </p>
+   </div>
+ </div>
+ """)
+
+ learn_more_markdown = ("""
+ ## Citation
+ <pre><code>@article{gou2025perceptual,
+   author  = {Gou, Yunhao and Chen, Kai and Liu, Zhili and Hong, Lanqing and Jin, Xin and Li, Zhenguo and Kwok, James T. and Zhang, Yu},
+   title   = {Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning},
+   journal = {arXiv preprint arXiv:2506.04559},
+   year    = {2025},
+ }</code></pre>
+ """)
+
+ block_css = """
+ #buttons button {
+     min-width: min(120px,100%);
+ }
+ .message-row img {
+     margin: 0px !important;
+ }
+ .avatar-container img {
+     padding: 0px !important;
+ }
+ """
+
+ ############
+ # Layout Demo
+ ############
+ def build_demo(embed_mode):
+     textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text and then click 💬 Chat to talk with me ^v^", container=False)
+     with gr.Blocks(title="RACRO", theme=gr.themes.Default(), css=block_css) as demo:
+         state = gr.State()
+         if not embed_mode:
+             gr.HTML(title_markdown)
+
+         ##############
+         # Chatbot
+         ##############
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=1):
+                 imagebox = gr.Image(type="pil", label="Image")
+                 image_process_mode = gr.Radio(
+                     ["Crop", "Resize", "Pad", "Default"],
+                     value="Default",
+                     label="Preprocess for non-square image", visible=False)
+
+                 gr.Examples(examples=[
+                     ["./examples/demo_example.png", "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"],
+                 ], inputs=[imagebox, textbox], label='Examples')
+
+             with gr.Column(scale=8):
+                 chatbot = gr.Chatbot(
+                     elem_id="chatbot",
+                     label="RACRO Chatbot",
+                     layout="bubble",
+                     avatar_images=["examples/user_avator.png", "examples/icon_256.png"]
+                 )
+                 textbox.render()
+                 with gr.Row(elem_id="buttons") as button_row:
+                     submit_btn = gr.Button(value="💬 Chat", variant="primary")
+                     # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+         if not embed_mode:
+             gr.Markdown(learn_more_markdown)
+
+         # Register listeners
+         btn_list = [regenerate_btn, clear_btn]
+         regenerate_btn.click(
+             regenerate,
+             [state, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list
+         ).then(
+             http_bot,
+             [state],
+             [state, chatbot] + btn_list,
+         )
+
+         clear_btn.click(
+             clear_history,
+             None,
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         )
+
+         # Pressing Enter in the textbox submits, same as clicking 💬 Chat
+         textbox.submit(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list,
+             queue=False
+         ).then(
+             http_bot,
+             [state],
+             [state, chatbot] + btn_list,
+         )
+
+         submit_btn.click(
+             add_text,
+             [state, textbox, imagebox, image_process_mode],
+             [state, chatbot, textbox, imagebox] + btn_list
+         ).then(
+             http_bot,
+             [state],
+             [state, chatbot] + btn_list,
+         )
+
+         ##############
+         # Demo loading
+         ##############
+         demo.load(
+             load_demo_refresh_model_list,
+             None,
+             [state],
+             queue=False
+         )
+     return demo
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--embed", action="store_true")
+ args = parser.parse_args()
+
+ demo = build_demo(args.embed)
+ demo.queue(
+     max_size=10,
+     api_open=False
+ ).launch(
+     favicon_path="./examples/icon_256.png",
+     allowed_paths=["/"],
+     share=args.share
+ )
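
For reference, the three-stage flow that `http_bot` runs above (tentative answer → query-conditioned caption → text-only reasoning) can also be exercised without the Gradio UI. The sketch below is illustrative only: it assumes the models, `processor`, and helper functions defined in app.py are already loaded in the current process, and the image path is a placeholder taken from the Examples widget.

    # Minimal sketch of the RACRO pipeline using the helpers defined in app.py (illustrative only).
    from PIL import Image

    question = "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"
    image = Image.open("./examples/demo_example.png")  # placeholder path

    # 1) Build the caption / QA chat messages and render them with the MLLM chat template.
    cap_msgs, qa_msgs = build_messages(image, question)
    cap_prompt = processor.apply_chat_template([cap_msgs], tokenize=False, add_generation_prompt=True)
    qa_prompt = processor.apply_chat_template([qa_msgs], tokenize=False, add_generation_prompt=True)
    image_tensor, _ = process_vision_info(cap_msgs)

    # 2) MLLM: tentative answer and query-conditioned caption.
    tentative_answer = run_mllm_tentative(image_tensor, cap_prompt, qa_prompt)
    caption_text = run_mllm_caption(image_tensor, cap_prompt, qa_prompt)

    # 3) Reasoning LLM: combine caption, question, and tentative answer.
    final_answer = run_llm_reasoning(caption_text, question, tentative_answer)
    print(final_answer)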
conversation_public.py ADDED
@@ -0,0 +1,517 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+ import base64
+ tts_format = "Please synthesize the speech corresponding to the following text.\n"
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+     GLM4 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple and messages[0][1][1] is not None:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message[:3]
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message[:3]
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message[:3]
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message[:3]
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message[:3]
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         elif self.sep_style == SeparatorStyle.GLM4:
+             role = ("<|user|>", "<|assistant|>")
+             ret = self.system + role[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message[:3]
+                     ret += self.sep + message + role[(i+1) % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         if isinstance(self.messages, tuple):
+             self.messages += ([role, message],)
+         else:
+             self.messages.append([role, message])
+
+     def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
+         if image_process_mode == "Pad":
+             def expand2square(pil_img, background_color=(122, 116, 104)):
+                 width, height = pil_img.size
+                 if width == height:
+                     return pil_img
+                 elif width > height:
+                     result = Image.new(pil_img.mode, (width, width), background_color)
+                     result.paste(pil_img, (0, (width - height) // 2))
+                     return result
+                 else:
+                     result = Image.new(pil_img.mode, (height, height), background_color)
+                     result.paste(pil_img, ((height - width) // 2, 0))
+                     return result
+             image = expand2square(image)
+         elif image_process_mode in ["Default", "Crop"]:
+             pass
+         elif image_process_mode == "Resize":
+             image = image.resize((336, 336))
+         else:
+             raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+         if max(image.size) > max_len:
+             max_hw, min_hw = max(image.size), min(image.size)
+             aspect_ratio = max_hw / min_hw
+             shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+             longest_edge = int(shortest_edge * aspect_ratio)
+             W, H = image.size
+             if H > W:
+                 H, W = longest_edge, shortest_edge
+             else:
+                 H, W = shortest_edge, longest_edge
+             image = image.resize((W, H))
+         if return_pil:
+             return image
+         else:
+             buffered = BytesIO()
+             image.save(buffered, format=image_format)
+             img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+             return img_b64_str
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple and msg[1] is not None:
+                     msg, image, image_process_mode = msg[:3]
+                     image = self.process_image(image, image_process_mode, return_pil=return_pil)
+                     images.append(image)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg, image, image_process_mode = msg
+                     img_b64_str = self.process_image(
+                         image, "Default", return_pil=False,
+                         image_format='JPEG')
+                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def to_gradio_chatbot_public(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     msg, image, image_process_mode, audio_input = msg
+                     ret_msg = ""
+                     if image is not None:
+                         img_b64_str = self.process_image(
+                             image, "Default", return_pil=False,
+                             image_format='JPEG')
+                         img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                         ret_msg += img_str
+                     if audio_input is not None:
+                         audio_b64_str = base64.b64encode(open(audio_input, "rb").read()).decode("utf-8")
+                         audio_str = f'<audio src="data:audio/wav;base64,{audio_b64_str}" controls ></audio>'
+                         ret_msg += audio_str
+                     else:
+                         ret_msg += msg.replace('<image>', '').replace(tts_format, '').strip()
+                     ret.append([ret_msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 if type(msg) is tuple:
+                     audio_b64_str = base64.b64encode(open(msg[1], "rb").read()).decode("utf-8")
+                     msg = f'<audio src="data:audio/wav;base64,{audio_b64_str}" controls autoplay></audio>'
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+             "Renewable energy sources are those that can be replenished naturally in a relatively "
+             "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+             "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+             "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+             "renewable and non-renewable energy sources:\n"
+             "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+             "energy sources are finite and will eventually run out.\n"
+             "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+             "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+             "and other negative effects.\n"
+             "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+             "have lower operational costs than non-renewable sources.\n"
+             "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+             "locations than non-renewable sources.\n"
+             "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+             "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+             "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+             "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_llava_llama_2 = Conversation(
+     system="You are a helpful language and vision assistant. "
+            "You are able to understand the visual content that the user provides, "
+            "and assist the user with a variety of tasks using natural language.",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_mpt = Conversation(
+     system="""<|im_start|>system
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_llava_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ conv_llava_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_llava_v0_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("Human", "Assistant"),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+     version="v0_mmtag",
+ )
+
+ conv_llava_v1 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llava_v1_mmtag = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+            "The visual content will be provided with the following format: <Image>visual content</Image>.",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+     version="v1_mmtag",
+ )
+
+ conv_mistral_instruct = Conversation(
+     system="",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="",
+     sep2="</s>",
+ )
+
+ conv_chatml_direct = Conversation(
+     system="""<|im_start|>system
+ Answer the questions.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_llama3 = Conversation(
+     system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
+     roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+     version="llama3",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|eot_id|>",
+ )
+
+ conv_llama3_demo = Conversation(
+     system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Your name is emova, and you are purely developed by the emova Team.""",
+     roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+     version="llama3_demo",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|eot_id|>",
+ )
+
+ conv_llama3_without_system = Conversation(
+     system="",
+     roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+     version="llama3_without_system",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|eot_id|>",
+ )
+
+ conv_llama3_without_systemV2 = Conversation(
+     system="",
+     roles=("user:", "assistant:"),
+     version="llama3_without_systemv2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="\n\n",
+ )
+
+ conv_qwen2 = Conversation(
+     system='<|im_start|>system\nYou are a helpful assistant.',
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="qwen2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>\n",
+ )
+
+ conv_qwen2_demo = Conversation(
+     system='<|im_start|>system\nYou are a helpful assistant. Your name is emova, and you are purely developed by the emova Team.',
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="qwen2_demo",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>\n",
+ )
+
+ conv_glm4 = Conversation(
+     system='[gMASK]<sop>',
+     roles=("<|user|>", "<|assistant|>"),
+     version="glm4",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.GLM4,
+     sep="\n",
+ )
+
+
+ default_conversation = conv_vicuna_v1
+ conv_templates = {
+     "default": conv_vicuna_v0,
+     "v0": conv_vicuna_v0,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "llama_2": conv_llama_2,
+     "mistral_instruct": conv_mistral_instruct,
+     "chatml_direct": conv_chatml_direct,
+     "mistral_direct": conv_chatml_direct,
+
+     "plain": conv_llava_plain,
+     "v0_plain": conv_llava_plain,
+     "llava_v0": conv_llava_v0,
+     "v0_mmtag": conv_llava_v0_mmtag,
+     "llava_v1": conv_llava_v1,
+     "v1_mmtag": conv_llava_v1_mmtag,
+     "llava_llama_2": conv_llava_llama_2,
+     "llama3": conv_llama3,
+     "llama3_demo": conv_llama3_demo,
+     "llama3_without_system": conv_llama3_without_system,
+     "conv_llama3_without_systemV2": conv_llama3_without_systemV2,
+
+     "mpt": conv_mpt,
+     "qwen2": conv_qwen2,
+     "qwen2_demo": conv_qwen2_demo,
+     "glm4": conv_glm4,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
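
As used by app.py above, the demo keeps a single-round `Conversation` and renders it with `to_gradio_chatbot_public`. A minimal illustration of that flow is below; the values are placeholders, and the 4-tuple mirrors what `add_text` stores for a user turn.

    # Illustrative use of the Conversation state, mirroring app.py (placeholder values).
    from PIL import Image

    pil_image = Image.new("RGB", (336, 336))  # placeholder image
    state = default_conversation.copy()       # conv_vicuna_v1 template
    state.append_message(state.roles[0], ("What is shown here?", pil_image, "Default", None))  # (text, image, process_mode, audio)
    state.append_message(state.roles[1], "# Tentative Response\n\n...")
    pairs = state.to_gradio_chatbot_public()  # [[user_html, assistant_text], ...] pairs for gr.Chatbot
    question = state.messages[0][1][0]        # recover the raw question text from the stored tuple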
examples/icon_256.png ADDED
examples/image-text/demo_example.jpg ADDED
examples/user_avator.png ADDED
gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
requirements.txt ADDED
@@ -0,0 +1,30 @@
+ # requirements.txt records the full set of dependencies for development
+ torch==2.6.0
+ accelerate
+ codetiming
+ datasets
+ dill
+ # flash-attn
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+ hydra-core
+ liger-kernel
+ numpy
+ pandas
+ datasets
+ peft
+ pyarrow>=15.0.0
+ pybind11
+ pylatexenc
+ pylint==3.3.6
+ qwen_vl_utils
+ ray[default]
+ tensordict<=0.6.2
+ torchdata
+ transformers
+ vllm==0.8.2
+ wandb
+ word2number
+ math_verify
+ mathruler
+ tensorboard
+ transformers==4.51.0