# RACRO-demo / app.py
# Author: KaiChen1998
# Last commit: 694f7e2 "use flash attention" (verified)
import argparse
import logging
import os
import tempfile
import traceback
# App-wide logging: INFO level with a timestamped, logger-named line format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# The HTTP client libraries are chatty at INFO; only let their warnings/errors through.
for _noisy_logger in ("http", "httpx"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger
import spaces
import gradio as gr
from conversation_public import default_conversation
# Secret injected through the Space's environment; None when the variable is unset.
# NOTE(review): not referenced anywhere in the visible part of this file.
auth_token = os.getenv("TOKEN_FROM_SECRET")
##########################################
# LLM part
##########################################
import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
from transformers import Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread
# === Prompts ===
# System message for the text-only reasoning LLM (final stage of the pipeline).
SYSTEM_PROMPT_LLM = "You are a helpful assistant."
# System message for the MLLM's query-conditioned captioning request.
SYSTEM_PROMPT_CAP = "You are given an image and a relevant question. Based on the query, please describe the image in details. Do not try to answer the question."
# User text for the captioning request; the single `{}` receives the user's question.
CAPTION_PROMPT = "Question: {}\nPlease describe the image. DO NOT try to answer the question!"
# User message for the reasoning LLM. The three `{}` placeholders are filled, in
# order, with: the query-conditioned caption, the user's question, and the MLLM's
# tentative response.
# NOTE(review): the text contains typos ("You goal", "these information"). It is left
# untouched because it may need to match the prompt used in training/evaluation —
# confirm before editing.
LLM_PROMPT = """In the following text, you will receive a detailed caption of an image and a relevant question. In addition, you will be provided with a tentative model response. You goal is to answer the question using these information.\n\n### The detailed caption of the provided image: {}\n\n### Note that the caption might contain incorrect solutions, do not be misguided by them.\n\n### A problem to be solved: {}\n\n### A tentative model response: {}\n\n### Note that the above tentative response might be inaccurate (due to calculation errors, incorrect logic/reasoning and so on), under such a case, please ignore it and give your own solutions. However, if you do not have enough evidence to show it is wrong, please output the tentative response."""
# === Initialize Models ===
# Two models drive the RACRO pipeline:
#   * mllm — vision-language model that produces the tentative answer and the
#            query-conditioned caption;
#   * llm  — text-only reasoning model that produces the final answer from the
#            caption, the question and the tentative answer.
MLLM_MODEL_PATH = "KaiChen1998/RACRO-7B-CRO-GRPO"
LLM_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
processor = AutoProcessor.from_pretrained(MLLM_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MLLM_MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
llm = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# Greedy decoding for both MLLM stages (temperature is not used when do_sample=False).
mllm_sampling = dict(do_sample=False, temperature=0, max_new_tokens=8192)
# Nucleus sampling for the reasoning LLM. do_sample=True is set explicitly:
# without it, temperature/top_p only apply if the checkpoint's generation_config
# happens to enable sampling. Otherwise generate() ignores them and decodes greedily.
llm_sampling = dict(do_sample=True, temperature=0.6, top_p=0.95, max_new_tokens=8192)
# === Build Prompts ===
def build_messages(image_path, question):
    """Build the two MLLM chat conversations for one image/question pair.

    Args:
        image_path: value placed in each ``{"type": "image"}`` entry. http_bot
            passes ``state.get_images(return_pil=True)[0]`` here, presumably a
            PIL image rather than a path.
        question: the user's question text.

    Returns:
        ``(cap_msgs, qa_msgs)``:
        - the captioning conversation (SYSTEM_PROMPT_CAP + image + CAPTION_PROMPT);
        - the direct-answer conversation (image + question + a step-by-step /
          boxed-answer instruction).
    """
    def image_entry():
        # Fresh dict for each conversation so the two message lists share no objects.
        return {"type": "image", "image": image_path}

    answer_suffix = " Please think step by step. The final answer MUST BE put in \\boxed{}."

    caption_text = {"type": "text", "text": CAPTION_PROMPT.format(question)}
    caption_conversation = [
        {"role": "system", "content": SYSTEM_PROMPT_CAP},
        {"role": "user", "content": [image_entry(), caption_text]},
    ]

    answer_text = {"type": "text", "text": question + answer_suffix}
    answer_conversation = [
        {"role": "user", "content": [image_entry(), answer_text]},
    ]
    return caption_conversation, answer_conversation
##########################################
# Streaming
##########################################
# Decode options shared by both streamers: skip the echoed prompt and special
# tokens, and stop waiting after 15 seconds without new text.
_STREAMER_KWARGS = dict(skip_prompt=True, skip_special_tokens=True, timeout=15)
# NOTE(review): each TextIteratorStreamer owns a single text queue. Feeding one
# instance from overlapping or abandoned generations would mix their output —
# confirm stream_response() does not consume these instances directly.
mllm_streamer = TextIteratorStreamer(processor.tokenizer, **_STREAMER_KWARGS)
llm_streamer = TextIteratorStreamer(tokenizer, **_STREAMER_KWARGS)
def stream_response(model, inputs, streamer, prompt, gen_kwargs):
    """Run ``model.generate`` in a background thread and stream the decoded text.

    Args:
        model: model whose ``generate`` method is called.
        inputs: mapping of tensors unpacked into ``generate`` (a processor or
            tokenizer output already moved to the model's device).
        streamer: a ``TextIteratorStreamer`` used only as a configuration
            template (tokenizer, ``skip_prompt``, ``timeout``, decode options).
            A fresh streamer is built for every call. One streamer owns a
            single queue, so reusing an instance would let text from an earlier
            failed or abandoned generation leak into this one, and overlapping
            requests would interleave.
        prompt: text prepended to every yielded value; callers slice it off
            with ``len(prompt)`` to recover the model output.
        gen_kwargs: extra keyword arguments for ``generate`` (sampling
            settings, ``max_new_tokens``, ...).

    Yields:
        ``prompt`` followed by all text decoded so far.

    Raises:
        Whatever iterating the streamer raises, e.g. ``queue.Empty`` when no
        new text arrives within the streamer's timeout.

    If the consumer stops iterating early, the background generation is told
    to stop at its next decoding step instead of running on to
    ``max_new_tokens``. This covers an exception, a streamer timeout, or the
    generator being closed.
    """
    import threading

    from transformers import StoppingCriteria, StoppingCriteriaList

    stop_requested = threading.Event()

    class _StopWhenAbandoned(StoppingCriteria):
        # Marks every sequence in the batch as finished once the consumer has gone away.
        def __call__(self, input_ids, scores, **kwargs):
            return torch.full(
                (input_ids.shape[0],),
                stop_requested.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    request_streamer = TextIteratorStreamer(
        streamer.tokenizer,
        skip_prompt=streamer.skip_prompt,
        timeout=streamer.timeout,
        **streamer.decode_kwargs,
    )
    # Keep any caller-supplied stopping criteria and add the abandonment check.
    stopping = StoppingCriteriaList(gen_kwargs.get("stopping_criteria") or [])
    stopping.append(_StopWhenAbandoned())
    generate_kwargs = {
        **inputs,
        **gen_kwargs,
        "streamer": request_streamer,
        "stopping_criteria": stopping,
    }
    # Daemon thread: a generation that is still winding down never blocks interpreter exit.
    thread = Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
    thread.start()

    generated_text = prompt
    try:
        for new_text in request_streamer:
            generated_text += new_text
            yield generated_text
    finally:
        # Reached on normal completion too, where it is harmless.
        stop_requested.set()
##########################################
# Gradio part
##########################################
# Button updates returned by the handlers for the regenerate/clear buttons.
no_change_btn = gr.Button()  # leaves the button untouched
enable_btn, disable_btn = gr.Button(interactive=True), gr.Button(interactive=False)

# User-facing messages shown in place of a failed response.
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
# NOTE(review): not referenced in the visible part of this file.
server_oom_msg = "**OUT OF GPU MEMORY DETECTED. PLEASE DECREASE THE MAX OUTPUT TOKENS AND REGENERATE.**"
def load_demo_refresh_model_list():
    """Page-load handler: give each new browser session its own fresh conversation state."""
    logging.info("load_demo.")
    return default_conversation.copy()
def regenerate(state, image_process_mode):
    """Rewind ``state`` to its last user turn so http_bot can answer it again.

    Args:
        state: the session's conversation state.
        image_process_mode: value written at index 2 of the last user message
            when that message is a tuple/list (text, image, mode, ...).

    Returns:
        Updates for ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.
    """
    logging.info("regenerate.")
    human_role = state.roles[0]
    # Drop every trailing assistant turn rather than a fixed three. A run that
    # failed part-way leaves only one or two, and slicing off three would also
    # remove the user message and then index an empty list.
    # Assumes entries are [role, message] pairs (index 1 holds the message, as used below).
    messages = list(state.messages)
    while messages and messages[-1][0] != human_role:
        messages.pop()
    state.messages = messages

    if not messages:
        # Nothing to regenerate: make the chained http_bot call a no-op.
        state.skip_next = True
        return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2

    prev_human_msg = messages[-1]
    if isinstance(prev_human_msg[1], (tuple, list)):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode, *prev_human_msg[1][3:])
    state.skip_next = False
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
def clear_history():
    """Reset the session: new empty conversation, cleared inputs, both buttons disabled.

    Returns:
        Updates for ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.
    """
    logging.info("clear_history.")
    fresh_state = default_conversation.copy()
    chatbot_value = fresh_state.to_gradio_chatbot_public()
    button_updates = (disable_btn, disable_btn)
    return (fresh_state, chatbot_value, "", None, *button_updates)
############
# Show prompt in the chatbot
# Input: [state, textbox, imagebox, image_process_mode]
# Return: [state, chatbot, textbox, imagebox] + btn_list
############
def add_text(state, text, image, image_process_mode):
    """Start a new single-round conversation from the submitted question and image.

    A non-blank question AND an image are both required, since RACRO always
    needs an image. When either is missing, ``state.skip_next`` is set so the
    chained http_bot call does nothing, and the buttons are left unchanged.

    Returns:
        Updates for ``[state, chatbot, textbox, imagebox, regenerate_btn, clear_btn]``.
    """
    text = text or ""
    logging.info(f"add_text. len: {len(text)}")
    if state is None:
        # The page-load handler normally provides the state; fall back if it has not run.
        state = default_conversation.copy()

    if not text.strip() or image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot_public(), "", None) + (no_change_btn,) * 2

    # Single round only: every submission replaces the previous conversation.
    state = default_conversation.copy()
    state.append_message(state.roles[0], (text, image, image_process_mode, None))
    state.skip_next = False
    logging.info(str(state.messages))
    return (state, state.to_gradio_chatbot_public(), "", None) + (disable_btn,) * 2
############
# Get response
# Input: [state]
# Return: [state, chatbot] + btn_list
############
@spaces.GPU
def http_bot(state):
    """Run the three-stage RACRO pipeline on the last user turn, streaming each stage.

    1. Tentative response: the MLLM answers the question directly.
    2. Query-conditioned caption: the MLLM describes the image with respect to the question.
    3. Text-only reasoning: the LLM answers from the caption, the question and
       the tentative response.

    Each stage is shown as its own assistant message. If a stage fails, its
    message is replaced by ``server_error_msg``, the buttons are re-enabled
    and the remaining stages are skipped.

    Yields:
        Updates for ``[state, chatbot, regenerate_btn, clear_btn]``.
    """
    logging.info("http_bot.")
    if state.skip_next:
        yield (state, state.to_gradio_chatbot_public()) + (no_change_btn,) * 2
        return

    # Retrieve the question text and the first image of the conversation.
    prompt = state.messages[-1][-1][0]
    all_images = state.get_images(return_pil=True)[0]
    pload = {"prompt": prompt, "images": f'List of {len(state.get_images())} images: {all_images}'}
    logging.info(f"==== request ====\n{pload}")

    # Build both MLLM inputs; they share the same image.
    cap_msgs, qa_msgs = build_messages(all_images, prompt)
    cap_prompt = processor.apply_chat_template(cap_msgs, tokenize=False, add_generation_prompt=True)
    qa_prompt = processor.apply_chat_template(qa_msgs, tokenize=False, add_generation_prompt=True)
    image_tensor, _ = process_vision_info(cap_msgs)
    cap_inputs = processor(text=cap_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)
    qa_inputs = processor(text=qa_prompt, images=image_tensor, return_tensors="pt").to(mllm.device)

    # Step 1: Tentative Response
    tentative_answer = yield from _stream_stage(
        state, "# Tentative Response\n\n", mllm, qa_inputs, mllm_streamer, qa_prompt, mllm_sampling
    )
    if tentative_answer is None:
        return
    logging.info(f"Tentative Response: {tentative_answer}")
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2

    # Step 2: Query-conditioned Caption
    caption_text = yield from _stream_stage(
        state, "# Query-conditioned Caption\n\n", mllm, cap_inputs, mllm_streamer, cap_prompt, mllm_sampling
    )
    if caption_text is None:
        return
    logging.info(f"Query-conditioned Caption: {caption_text}")
    yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2

    # Step 3: Text-only Reasoning
    reason_msgs = [
        {"role": "system", "content": SYSTEM_PROMPT_LLM},
        {"role": "user", "content": LLM_PROMPT.format(caption_text, prompt, tentative_answer)},
    ]
    reason_prompt = tokenizer.apply_chat_template(reason_msgs, tokenize=False, add_generation_prompt=True)
    reason_inputs = tokenizer(reason_prompt, return_tensors="pt").to(llm.device)
    final_response = yield from _stream_stage(
        state, "# Text-only Reasoning\n\n", llm, reason_inputs, llm_streamer, reason_prompt, llm_sampling
    )
    if final_response is None:
        return
    logging.info(f"Text-only Reasoning: {final_response}")
    yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2


def _stream_stage(state, header, model, inputs, streamer, prompt, gen_kwargs):
    """Stream one pipeline stage into a new assistant message of ``state``.

    While text arrives, yields chatbot updates with the buttons disabled. On
    success, removes the trailing "▌" cursor from the message and returns
    the stage's stripped output.

    On failure, replaces the message with ``server_error_msg``, yields one
    update that re-enables the buttons, and returns ``None``.
    """
    state.append_message(state.roles[1], header + "▌")
    output = ""
    try:
        for generated_text in stream_response(model, inputs, streamer, prompt, gen_kwargs):
            output = generated_text[len(prompt):].strip()
            state.messages[-1][-1] = header + output + "▌"
            yield (state, state.to_gradio_chatbot_public()) + (disable_btn,) * 2
    except Exception:
        os.system("nvidia-smi")  # dump GPU state into the Space logs for diagnosis
        # logging.exception records the traceback itself; traceback.print_exc()
        # only prints to stderr and returns None.
        logging.exception("Generation failed during stage %r", header.strip())
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot_public()) + (enable_btn,) * 2
        return None
    # Drop the trailing "▌" streaming cursor now that the stage is complete.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    return output
############
# Layout Markdown
############
# HTML banner (title, links, usage note) rendered with gr.HTML above the chat
# when the demo is not in embed mode.
title_markdown = ("""
<div style="display: flex; align-items: center; padding: 20px; border-radius: 10px; background-color: #f0f0f0;">
<div>
<h1 style="margin: 0;">RACRO: Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning</h1>
<h2 style="margin: 10px 0;">📃 <a href="https://www.arxiv.org/abs/2506.04559" style="font-weight: 400;">Paper</a> | 💻 <a href="https://github.com/gyhdog99/RACRO2" style="font-weight: 400;">Code</a> | 🤗 <a href="https://huggingface.co/collections/KaiChen1998/racro-6848ec8c65b3a0bf33d0fbdb" style="font-weight: 400;">HuggingFace</a></h2>
<p style="margin: 20px 0;">
<strong>1. RACRO is designed for multi-modal reasoning, and thus, image inputs are <mark>ALWAYS</mark> necessary!</strong>
</p>
</div>
</div>
""")
# Citation section rendered with gr.Markdown below the chat when not in embed mode.
learn_more_markdown = ("""
## Citation
<pre><code>@article{gou2025perceptual,
author = {Gou, Yunhao and Chen, Kai and Liu, Zhili and Hong, Lanqing and Jin, Xin and Li, Zhenguo and Kwok, James T. and Zhang, Yu},
title = {Perceptual Decoupling for Scalable Multi-modal Reasoning via Reward-Optimized Captioning},
journal = {arXiv preprint arXiv:2506.04559},
year = {2025},
}</code></pre>
""")
# Custom CSS passed to gr.Blocks. It sets a minimum width for buttons in the
# "buttons" row and removes margin/padding from images in chat messages and avatars.
block_css = """
#buttons button {
min-width: min(120px,100%);
}
.message-row img {
margin: 0px !important;
}
.avatar-container img {
padding: 0px !important;
}
"""
############
# Layout Demo
############
def build_demo(embed_mode):
    """Assemble the RACRO Gradio UI and register its event handlers.

    Args:
        embed_mode: when true, the title banner and the citation section are omitted.

    Returns:
        The ``gr.Blocks`` app, not yet queued or launched.
    """
    # Created outside the Blocks context and rendered later, presumably so
    # gr.Examples (laid out before it) can reference it.
    textbox = gr.Textbox(label="Text", show_label=False, placeholder="Enter text and then click 💬 Chat to talk with me ^v^", container=False)
    with gr.Blocks(title="RACRO", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.HTML(title_markdown)

        ##############
        # Chatbot
        ##############
        with gr.Row(equal_height=True):
            # Left column: image input, hidden preprocessing option, examples.
            with gr.Column(scale=1):
                imagebox = gr.Image(type="pil", label="Image")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image",
                    visible=False,
                )
                gr.Examples(
                    examples=[
                        ["./examples/image-text/demo_example.jpg", "When the canister is momentarily stopped by the spring, by what distance $d$ is the spring compressed?"],
                    ],
                    inputs=[imagebox, textbox],
                    label='Examples',
                )
            # Right column: chat transcript, question box and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    type="messages",
                    elem_id="chatbot",
                    label="RACRO Chatbot",
                    layout="bubble",
                    avatar_images=["examples/user_avator.png", "examples/icon_256.png"],
                )
                textbox.render()
                with gr.Row(elem_id="buttons"):
                    submit_btn = gr.Button(value="💬 Chat", variant="primary")
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(learn_more_markdown)

        # Register listeners
        btn_list = [regenerate_btn, clear_btn]
        prompt_outputs = [state, chatbot, textbox, imagebox] + btn_list
        add_text_inputs = [state, textbox, imagebox, image_process_mode]

        def then_answer(event):
            # Every path that (re)starts a question ends by streaming the pipeline output.
            return event.then(http_bot, [state], [state, chatbot] + btn_list)

        then_answer(regenerate_btn.click(regenerate, [state, image_process_mode], prompt_outputs))
        clear_btn.click(clear_history, None, prompt_outputs, queue=False)
        # Pressing Enter in the textbox.
        then_answer(textbox.submit(add_text, add_text_inputs, prompt_outputs, queue=False))
        then_answer(submit_btn.click(add_text, add_text_inputs, prompt_outputs))

        ##############
        # Demo loading
        ##############
        demo.load(load_demo_refresh_model_list, None, [state], queue=False)
    return demo
# Command-line flags: --share requests a public Gradio share link; --embed hides
# the title banner and citation section (see build_demo).
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()
demo = build_demo(args.embed)

# Only the app's own directory (example images, avatars, favicon) and the
# system temp directory may be served as files. allowed_paths=["/"] exposed
# the entire filesystem to anonymous visitors, including /proc/self/environ,
# which holds secrets such as TOKEN_FROM_SECRET.
_app_dir = os.path.dirname(os.path.abspath(__file__))
demo.queue(
    max_size=10,
    api_open=False,
).launch(
    favicon_path="./examples/icon_256.png",
    allowed_paths=[_app_dir, tempfile.gettempdir()],
    share=args.share,
)