---
license: apache-2.0
datasets:
- foreverbeliever/OmniMedVQA
language:
- en
metrics:
- accuracy
base_model:
- Qwen/Qwen2-VL-2B-Instruct
pipeline_tag: visual-question-answering
---
# Med-R1
Med-R1 is a reinforcement learning (RL)-enhanced vision-language model (VLM) designed for medical reasoning across 8 imaging modalities (CT, MRI, Ultrasound, Dermoscopy, Fundus Photography, Optical Coherence Tomography (OCT), Microscopy, and X-ray) and 5 key tasks (modality recognition, anatomy identification, disease diagnosis, lesion grading, and biological attribute analysis). Using Group Relative Policy Optimization (GRPO), Med-R1 improves generalization and trustworthiness, surpassing Qwen2-VL-2B by 29.94% and even outperforming the much larger Qwen2-VL-72B. Our model checkpoints provide researchers with a powerful tool for advancing medical AI with RL-driven enhancements.
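For intuition, GRPO samples a group of completions for each question, scores them with rule-based rewards (answer correctness and output format), and normalizes each reward against the group's mean and standard deviation instead of using a learned value critic. The sketch below is illustrative only and is not the training code used for Med-R1; the reward weights and helper names are assumptions.

```python
import re

import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Group-relative advantage: normalize each completion's reward against
    the mean and std of the group sampled for the same prompt."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def rule_based_reward(output: str, ground_truth_choice: str) -> float:
    """Toy reward: 0.5 for following the <think>/<answer> format,
    plus 1.0 if the extracted choice matches the ground truth.
    (Illustrative weights, not the paper's exact reward design.)"""
    reward = 0.0
    if re.search(r"<think>.*</think>\s*<answer>.*</answer>", output, re.DOTALL):
        reward += 0.5
    m = re.search(r"<answer>\s*([A-Z])\s*</answer>", output)
    if m and m.group(1) == ground_truth_choice.strip():
        reward += 1.0
    return reward

# Example: a group of 4 sampled answers to one question with gold answer "C"
group_outputs = [
    "<think>Axial CT slice of the chest.</think> <answer> C </answer>",
    "<answer> A </answer>",
    "<think>Looks like an MRI.</think> <answer> B </answer>",
    "<think>Axial CT.</think> <answer> C </answer>",
]
rewards = torch.tensor([rule_based_reward(o, "C") for o in group_outputs])
print(grpo_advantages(rewards))  # correct, well-formatted answers get higher advantage
```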
## Description of Models
- **Cross-Modality**: We provide checkpoints trained separately on the following modalities:
- **CT**, **MRI**, **X-Ray**, **Fundus (FP)**, **Dermoscopy (Der)**, **Microscopy (Micro)**, **OCT**, and **Ultrasound (US)**.
- **Cross-Task Learning**: We provide checkpoints trained separately on the following tasks:
- **Anatomy Identification**, **Disease Diagnosis**, **Lesion Grading**, **Modality Recognition**, and **Biological Attribute Analysis**.
## Use of Models
### Load Checkpoint
```python
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_PATH = "..."  # path to the downloaded Med-R1 checkpoint

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
```
### Data Organization
```json
[
{
"image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000139 (9).png",
"problem": "What imaging technique is employed for obtaining this image? A)Mammogram, B)Positron emission tomography (PET), C)CT, D)Fluoroscopy",
"solution": "<answer> C </answer>"
},
{
"image": "Images/Chest CT Scan/test/squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa/000127 (2).png",
"problem": "What imaging technique was utilized for obtaining this image? A)CT, B)Angiography, C)X-ray, D)Ultrasound",
"solution": "<answer> A </answer>"
},
{
"image": "Images/Chest CT Scan/test/normal/10 (2).png",
"problem": "What imaging technique was used for this image acquisition? A)CT, B)Ultrasound, C)Fluoroscopy, D)X-ray",
"solution": "<answer> A </answer>"
},
{
"image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000142.png",
"problem": "What is the specific diagnosis of the cancer shown in the image? A)Neuroendocrine tumor of the left upper lobe, T3 N0 M1, Stage III, B)Mesothelioma of the left lower lobe, T2 N0 M0, Stage Ib, C)Adenocarcinoma of the left lower lobe, T2 N0 M0, Stage Ib, D)Non-Hodgkin lymphoma of the right lower lobe, T2 N1 M0, Stage II",
"solution": "<answer> C </answer>"
}
...
]
```
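Each entry pairs an image path with a multiple-choice question and a ground-truth answer wrapped in `<answer>` tags. The snippet below is a minimal sketch of reading such a file; the file name `eval_ct.json` and the `extract_choice` helper are illustrative, not part of the released code.

```python
import json
import re

def extract_choice(tagged: str):
    """Return the choice letter inside <answer> ... </answer>, or None."""
    m = re.search(r"<answer>\s*([A-Z])\s*</answer>", tagged)
    return m.group(1) if m else None

with open("eval_ct.json", "r", encoding="utf-8") as f:  # hypothetical file in the format above
    entries = json.load(f)

for entry in entries[:3]:
    # Every entry carries an image path, a multiple-choice question, and a tagged solution.
    print(entry["image"], "->", extract_choice(entry["solution"]))
```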
### Inference
```python
import json

from qwen_vl_utils import process_vision_info
from tqdm import tqdm

PROMPT_PATH = "..."  # path to an evaluation JSON in the format shown above
BSZ = 8              # batch size; adjust to your GPU memory

with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."

# Build one chat message per question, attaching the image and the templated prompt.
messages = []
for i in data:
    message = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": f"file://{i['image']}"
            },
            {
                "type": "text",
                "text": QUESTION_TEMPLATE.format(Question=i['problem'])
            }
        ]
    }]
    messages.append(message)

all_outputs = []
for i in tqdm(range(0, len(messages), BSZ)):
    batch_messages = messages[i:i + BSZ]

    # Preparation for inference
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    all_outputs.extend(batch_output_text)
    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
```
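Because the ground-truth `solution` field uses the same `<answer>` tags the model is prompted to produce, accuracy (the metric reported for Med-R1) can be scored by comparing extracted letters. A minimal sketch, assuming `data` and `all_outputs` from the loop above and the illustrative `extract_choice` helper defined earlier:

```python
correct = 0
for entry, output in zip(data, all_outputs):
    pred = extract_choice(output)              # model's choice letter, or None
    gold = extract_choice(entry["solution"])   # ground-truth choice letter
    correct += int(pred is not None and pred == gold)

print(f"Accuracy: {correct / len(data):.4f} ({correct}/{len(data)})")
```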
## Acknowledgements
We thank the authors of **OmniMedVQA** and **R1-V** for their open-source contributions.
- [R1-V GitHub Repository](https://github.com/Deep-Agent/R1-V)
- [OmniMedVQA GitHub Repository](https://github.com/OpenGVLab/Multi-Modality-Arena)
## Citation
```
@article{lai2025med,
title={Med-R1: Reinforcement Learning for Generalizable Medical Reasoning in Vision-Language Models},
author={Lai, Yuxiang and Zhong, Jike and Li, Ming and Zhao, Shitian and Yang, Xiaofeng},
journal={arXiv preprint arXiv:2503.13939},
year={2025}
}
```