import os
import argparse
import traceback
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logging.getLogger("http").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

import spaces
import gradio as gr
from conversation_public import default_conversation

auth_token = os.environ.get("TOKEN_FROM_SECRET")

##########################################
# LLM part
##########################################
import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
from transformers import Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread

# === Prompts ===
SYSTEM_PROMPT_LLM = "You are a helpful assistant."
SYSTEM_PROMPT_CAP = "You are given an image and a relevant question. Based on the query, please describe the image in details. Do not try to answer the question."
CAPTION_PROMPT = "Question: {}\nPlease describe the image. DO NOT try to answer the question!"
LLM_PROMPT = """In the following text, you will receive a detailed caption of an image and a relevant question. In addition, you will be provided with a tentative model response. Your goal is to answer the question using this information.\n\n### The detailed caption of the provided image: {}\n\n### Note that the caption might contain incorrect solutions, do not be misguided by them.\n\n### A problem to be solved: {}\n\n### A tentative model response: {}\n\n### Note that the above tentative response might be inaccurate (due to calculation errors, incorrect logic/reasoning and so on), in such a case, please ignore it and give your own solutions. However, if you do not have enough evidence to show it is wrong, please output the tentative response."""

# === Initialize Models ===
MLLM_MODEL_PATH = "KaiChen1998/RACRO-7B-CRO-GRPO"
LLM_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(MLLM_MODEL_PATH, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto")
llm = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto")

# Greedy decoding for the MLLM; sampling for the reasoning LLM
# (do_sample=True is required, otherwise temperature/top_p would be silently ignored)
mllm_sampling = dict(do_sample=False, max_new_tokens=8192)
llm_sampling = dict(do_sample=True, temperature=0.6, top_p=0.95, max_new_tokens=8192)

# === Build Prompts ===
def build_messages(image_path, question):
    cap_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_CAP},
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": CAPTION_PROMPT.format(question)}]}
    ]
    qa_msgs = [
        {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": question + " Please think step by step. The final answer MUST BE put in \\boxed{}."}]}
    ]
    return cap_msgs, qa_msgs

##########################################
# Streaming
##########################################
mllm_streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
llm_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

def stream_response(model, inputs, streamer, prompt, gen_kwargs):
    # Run generation in a background thread; the streamer yields partial text as it is produced
    thread = Thread(target=model.generate, kwargs=dict(
        streamer=streamer,
        **inputs,
        **gen_kwargs
    ))
    thread.start()
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

##########################################
# Gradio part
##########################################
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS AND REGENERATE.**"

def load_demo_refresh_model_list():
    logging.info("load_demo.")
    state = default_conversation.copy()
    return state

def regenerate(state, image_process_mode):
    logging.info("regenerate.")
    # Drop the three bot messages of the last round; the human message stays last
    state.messages = state.messages[:-3]
    prev_human_msg = state.messages[-1]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
    state.skip_next = False
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

def clear_history():
    logging.info("clear_history.")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

############
# Show prompt in the chatbot
# Input: [state, textbox, imagebox, image_process_mode]
# Return: [state, chatbot, textbox, imagebox] + btn_list
############
def add_text(state, text, image, image_process_mode):
    # Input legality checking
    logging.info(f"add_text. len: {len(text)}")
    if len(text) <= 0 or image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2

    # Deal with image inputs
    if image is not None:
        text = (text, image, image_process_mode, None)

    # Single round only
    state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.skip_next = False
    logging.info(str(state.messages))
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2

############
# Get response
# Input: [state]
# Return: [state, chatbot] + btn_list
############
@spaces.GPU
def http_bot(state):
    logging.info("http_bot.")
    if state.skip_next:
        yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
        return

    # Retrieve prompt
    prompt = state.messages[-1][-1][0]
    all_images = state.get_images(return_pil=True)[0]
    pload = {"prompt": prompt, "images": f'List of {len(state.get_images())} images: {all_images}'}
    logging.info(f"==== request ====\n{pload}")

    # Construct prompts
    cap_msgs, qa_msgs = build_messages(all_images, prompt)
    cap_prompt = processor.apply_chat_template(cap_msgs, tokenize=False, add_generation_prompt=True)
    qa_prompt = processor.apply_chat_template(qa_msgs, tokenize=False, add_generation_prompt=True)
    image_tensor, _ = process_vision_info(cap_msgs)
    cap_inputs = processor(text=cap_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)
    qa_inputs = processor(text=qa_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)

    # Step 1: Tentative Response
    state.append_message(state.roles[1], "# Tentative Response\n\n▌")
    try:
        for generated_text in stream_response(mllm, qa_inputs, mllm_streamer, qa_prompt, mllm_sampling):
            output = generated_text[len(qa_prompt):].strip()
            state.messages[-1][-1] = "# Tentative Response\n\n" + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception as e:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return
    tentative_answer = output
    logging.info(f"Tentative Response: {tentative_answer}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2

    # Step 2: Query-conditioned Caption
    state.append_message(state.roles[1], "# Query-conditioned Caption\n\n▌")
    try:
        for generated_text in stream_response(mllm, cap_inputs, mllm_streamer, cap_prompt, mllm_sampling):
            output = generated_text[len(cap_prompt):].strip()
            state.messages[-1][-1] = "# Query-conditioned Caption\n\n" + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception as e:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return
    caption_text = output
    logging.info(f"Query-conditioned Caption: {caption_text}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2

    # Step 3: Text-only Reasoning
    reason_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_LLM},
        {"role": "user", "content": LLM_PROMPT.format(caption_text, prompt, tentative_answer)}
    ]
    reason_prompt = tokenizer.apply_chat_template(reason_msgs, tokenize=False, add_generation_prompt=True)
    reason_inputs = tokenizer(reason_prompt, return_tensors="pt").to(llm.device)

    state.append_message(state.roles[1], "# Text-only Reasoning\n\n▌")
    try:
        for generated_text in stream_response(llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling):
            output = generated_text[len(reason_prompt):].strip()
            state.messages[-1][-1] = "# Text-only Reasoning\n\n" + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception as e:
        os.system("nvidia-smi")
        logging.error(traceback.format_exc())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return
    final_response = output
    logging.info(f"Text-only Reasoning: {final_response}")
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2

############
# Layout Markdown
############
title_markdown = ("""

RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning

📃 Paper | 💻 Code | 🤗 HuggingFace

1. RACRO is designed for multi-modal reasoning, and thus image inputs are ALWAYS necessary!

""") learn_more_markdown = (""" ## Citation
@article{gou2025perceptual,
  author    = {Gou, Yunhao and Chen, Kai and Liu, Zhili and Hong, Lanqing and Jin, Xin and Li, Zhenguo and Kwok, James T. and Zhang, Yu}, 
  title     = {Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning},
  journal   = {arXiv preprint arXiv:2506.04559},
  year      = {2025},
}
""") block_css = """ #buttons button { min-width: min(120px,100%); } .message-row img { margin: 0px !important; } .avatar-container img { padding: 0px !important; } """ ############ # Layout Demo ############ def build_demo(embed_mode): textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text and then click 💬 Chat to talk with me ^v^", container=False) with gr.Blocks(title="RACRO", theme=gr.themes.Default(), css=block_css) as demo: state = gr.State() if not embed_mode: gr.HTML(title_markdown) ############## # Chatbot ############## with gr.Row(equal_height=True): with gr.Column(scale=1): imagebox = gr.Image(type="pil", label="Image") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False) gr.Examples(examples=[ ["./examples/image-text/demo_example.jpg", "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"], ], inputs=[imagebox, textbox], label='Examples') with gr.Column(scale=8): chatbot = gr.Chatbot( type="messages", elem_id="chatbot", label="RACRO Chatbot", layout="bubble", avatar_images=["examples/user_avator.png", "examples/icon_256.png"] ) textbox.render() with gr.Row(elem_id="buttons") as button_row: submit_btn = gr.Button(value="💬 Chat", variant="primary") # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = gr.Button(value="🗑️ Clear", interactive=False) if not embed_mode: gr.Markdown(learn_more_markdown) # Register listeners btn_list = [regenerate_btn, clear_btn] regenerate_btn.click( regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list ).then( http_bot, [state], [state, chatbot] + btn_list, ) clear_btn.click( clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False ) # probably mean press enter textbox.submit( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False ).then( http_bot, [state], [state, chatbot] + btn_list, ) submit_btn.click( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list ).then( http_bot, [state], [state, chatbot] + btn_list, ) ############## # Demo loading ############## demo.load( load_demo_refresh_model_list, None, [state], queue=False ) return demo parser = argparse.ArgumentParser() parser.add_argument("--share", action="store_true") parser.add_argument("--embed", action="store_true") args = parser.parse_args() demo = build_demo(args.embed) demo.queue( max_size=10, api_open=False ).launch( favicon_path="./examples/icon_256.png", allowed_paths=["/"], share=args.share )