Med-R1
Med-R1 is a reinforcement learning (RL)-enhanced vision-language model (VLM) designed for medical reasoning across 8 imaging modalities (CT, MRI, Ultrasound, Dermoscopy, Fundus Photography, Optical Coherence Tomography (OCT), Microscopy, and X-ray) and 5 key tasks (modality recognition, anatomy identification, disease diagnosis, lesion grading, and biological attribute analysis). Trained with Group Relative Policy Optimization (GRPO), Med-R1 improves generalization and trustworthiness, surpassing its Qwen2-VL-2B base model by 29.94% on average and even outperforming the much larger Qwen2-VL-72B. Our model checkpoints provide researchers with a powerful tool for advancing medical AI with RL-driven enhancements.
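GRPO drops the learned value critic used in PPO-style training: for each question it samples a group of candidate answers, scores them with rule-based rewards (e.g., answer correctness and output format), and normalizes each reward within its group to obtain advantages. The snippet below is only a minimal illustrative sketch of that group-relative advantage computation, not the Med-R1 training code, and the example rewards are hypothetical.

import numpy as np

def group_relative_advantages(rewards, eps=1e-6):
    # GRPO-style advantage: normalize each sampled answer's reward within its group.
    r = np.asarray(rewards, dtype=np.float32)
    return (r - r.mean()) / (r.std() + eps)

# Hypothetical example: 4 sampled answers to one question, reward 1.0 if the
# <answer> tag holds the correct choice, 0.0 otherwise.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))  # -> approx. [ 1. -1. -1.  1.]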
Description of Models
Cross-Modality: We provide checkpoints trained separately on the following modalities:
- CT, MRI, X-Ray, Fundus (FP), Dermoscopy (Der), Microscopy (Micro), OCT, and Ultrasound (US).
Cross-Task: We provide checkpoints trained separately on the following tasks:
- Anatomy Identification, Disease Diagnosis, Lesion Grading, Modality Recognition, and Biological Attribute Analysis.
Use of Models
Load Checkpoint
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_PATH = "..."  # path to the downloaded Med-R1 checkpoint

# Load the model in bfloat16 with FlashAttention-2 and automatic device placement.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
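If GPU memory is limited, the Qwen2-VL processor also accepts optional min_pixels and max_pixels bounds that cap the per-image visual token budget; the values below are illustrative values from the Qwen2-VL documentation, not Med-R1-specific settings.

# Optional: bound the number of visual tokens per image (illustrative values).
processor = AutoProcessor.from_pretrained(
    MODEL_PATH,
    min_pixels=256 * 28 * 28,
    max_pixels=1280 * 28 * 28,
)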
Data Organization
[
{
"image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000139 (9).png",
"problem": "What imaging technique is employed for obtaining this image? A)Mammogram, B)Positron emission tomography (PET), C)CT, D)Fluoroscopy",
"solution": "<answer> C </answer>"
},
{
"image": "Images/Chest CT Scan/test/squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa/000127 (2).png",
"problem": "What imaging technique was utilized for obtaining this image? A)CT, B)Angiography, C)X-ray, D)Ultrasound",
"solution": "<answer> A </answer>"
},
{
"image": "Images/Chest CT Scan/test/normal/10 (2).png",
"problem": "What imaging technique was used for this image acquisition? A)CT, B)Ultrasound, C)Fluoroscopy, D)X-ray",
"solution": "<answer> A </answer>"
},
{
"image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000142.png",
"problem": "What is the specific diagnosis of the cancer shown in the image? A)Neuroendocrine tumor of the left upper lobe, T3 N0 M1, Stage III, B)Mesothelioma of the left lower lobe, T2 N0 M0, Stage Ib, C)Adenocarcinoma of the left lower lobe, T2 N0 M0, Stage Ib, D)Non-Hodgkin lymphoma of the right lower lobe, T2 N1 M0, Stage II",
"solution": "<answer> C </answer>"
}
...
]
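Each entry pairs an image path ("image") with a multiple-choice question ("problem") and a gold answer wrapped in <answer> </answer> tags ("solution"). Below is a minimal sketch for loading such a file and extracting the gold choice letters; the file name eval_data.json and the helper gold_answer are ours (reused in the evaluation sketch after the Inference section), not part of the released code.

import json
import re

with open("eval_data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

def gold_answer(entry):
    # Pull the choice letter out of a string like "<answer> C </answer>".
    match = re.search(r"<answer>\s*([A-Z])\s*</answer>", entry["solution"])
    return match.group(1) if match else None

print(len(data), "questions; first gold answer:", gold_answer(data[0]))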
Inference
import json

from qwen_vl_utils import process_vision_info
from tqdm import tqdm

PROMPT_PATH = "..."  # path to the JSON file described in Data Organization
BSZ = 8              # batch size; adjust to your GPU memory
all_outputs = []

with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."

# Build one chat message per question, attaching the image and the templated prompt.
messages = []
for i in data:
    message = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": f"file://{i['image']}"
            },
            {
                "type": "text",
                "text": QUESTION_TEMPLATE.format(Question=i['problem'])
            }
        ]
    }]
    messages.append(message)

for i in tqdm(range(0, len(messages), BSZ)):
    batch_messages = messages[i:i + BSZ]

    # Preparation for inference
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: generation of the output (greedy decoding)
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    all_outputs.extend(batch_output_text)

    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
Acknowledgements
We thank the authors of OmniMedVQA and R1-V for their open-source contributions.
- R1-V GitHub Repository
- OmniMedVQA GitHub Repository
Citation
@misc{lai2025medr1reinforcementlearninggeneralizable,
title={Med-R1: Reinforcement Learning for Generalizable Medical Reasoning in Vision-Language Models},
author={Yuxiang Lai and Jike Zhong and Ming Li and Shitian Zhao and Xiaofeng Yang},
year={2025},
eprint={2503.13939},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2503.13939},
}