---
license: apache-2.0
datasets:
- foreverbeliever/OmniMedVQA
language:
- en
metrics:
- accuracy
base_model:
- Qwen/Qwen2-VL-2B-Instruct
pipeline_tag: visual-question-answering
---
# Med-R1
Med-R1 is a reinforcement learning (RL)-enhanced vision-language model (VLM) for medical reasoning across 8 imaging modalities (CT, MRI, Ultrasound, Dermoscopy, Fundus Photography, Optical Coherence Tomography (OCT), Microscopy, and X-ray) and 5 key tasks (modality recognition, anatomy identification, disease diagnosis, lesion grading, and biological attribute analysis). Trained with Group Relative Policy Optimization (GRPO), Med-R1 improves generalization and trustworthiness, surpassing its Qwen2-VL-2B base model by 29.94% and even outperforming the much larger Qwen2-VL-72B. Our model checkpoints provide researchers with a powerful tool for advancing medical AI with RL-driven enhancements.

## Description of Models

- **Cross-Modality**: We provide checkpoints trained separately on the following modalities:
  - **CT**, **MRI**, **X-Ray**, **Fundus (FP)**, **Dermoscopy (Der)**, **Microscopy (Micro)**, **OCT**, and **Ultrasound (US)**.

- **Cross-Task Learning**: We provide checkpoints trained separately on the following tasks:
  - **Anatomy Identification**, **Disease Diagnosis**, **Lesion Grading**, **Modality Recognition**, and **Biological Attribute Analysis**.

## Use of Models

### Load Checkpoint

```python
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_PATH = "..."

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
```
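The snippet above assumes the `flash-attn` package is installed. If it is not available in your environment, a minimal fallback (our suggestion, not part of the original card) is to load the checkpoint with the default attention implementation:

```python
# Fallback sketch: load without FlashAttention-2 if flash-attn is not installed.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```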
### Data Organization
```json
[
    {
        "image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000139 (9).png",
        "problem": "What imaging technique is employed for obtaining this image? A)Mammogram, B)Positron emission tomography (PET), C)CT, D)Fluoroscopy",
        "solution": "<answer> C </answer>"
    },
    {
        "image": "Images/Chest CT Scan/test/squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa/000127 (2).png",
        "problem": "What imaging technique was utilized for obtaining this image? A)CT, B)Angiography, C)X-ray, D)Ultrasound",
        "solution": "<answer> A </answer>"
    },
    {
        "image": "Images/Chest CT Scan/test/normal/10 (2).png",
        "problem": "What imaging technique was used for this image acquisition? A)CT, B)Ultrasound, C)Fluoroscopy, D)X-ray",
        "solution": "<answer> A </answer>"
    },
    {
        "image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000142.png",
        "problem": "What is the specific diagnosis of the cancer shown in the image? A)Neuroendocrine tumor of the left upper lobe, T3 N0 M1, Stage III, B)Mesothelioma of the left lower lobe, T2 N0 M0, Stage Ib, C)Adenocarcinoma of the left lower lobe, T2 N0 M0, Stage Ib, D)Non-Hodgkin lymphoma of the right lower lobe, T2 N1 M0, Stage II",
        "solution": "<answer> C </answer>"
    }
  ...
]
```
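The `image` fields are paths relative to the OmniMedVQA image folder, and the inference code below references them as `file://` URIs, so they must resolve on your machine. A minimal sketch for rewriting them against a local copy of the dataset (`DATA_ROOT` and the annotation path here are hypothetical placeholders, not part of the original card):

```python
import json
import os

DATA_ROOT = "/path/to/OmniMedVQA"          # assumed local image root (hypothetical)
ANNOTATION_FILE = "/path/to/test.json"     # assumed annotation file (hypothetical)

with open(ANNOTATION_FILE, "r", encoding="utf-8") as f:
    data = json.load(f)

# Turn each relative image path into an absolute one so the
# "file://{...}" URIs used at inference time resolve correctly.
for item in data:
    item["image"] = os.path.join(DATA_ROOT, item["image"])
```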

### Inference 
```python
import json

from qwen_vl_utils import process_vision_info
from tqdm import tqdm

PROMPT_PATH = "..."  # annotation file in the format shown above
BSZ = 32             # batch size; choose a value that fits your GPU
all_outputs = []

with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."

messages = []

for i in data:
    message = [{
        "role": "user",
        "content": [
            {
                "type": "image", 
                "image": f"file://{i['image']}"
            },
            {
                "type": "text",
                "text": QUESTION_TEMPLATE.format(Question=i['problem'])
            }
        ]
    }]
    messages.append(message)


for i in tqdm(range(0, len(messages), BSZ)):
    batch_messages = messages[i:i + BSZ]
    
    # Preparation for inference
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
    
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
    
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    all_outputs.extend(batch_output_text)
    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")

```
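Because each record stores its ground truth in the `solution` field using the same `<answer> ... </answer>` format the prompt template requests, accuracy can be computed by extracting the chosen letter from both the prediction and the reference. A minimal sketch (the `extract_choice` helper is ours, not from the original card):

```python
import re

def extract_choice(text):
    """Pull the answer letter out of '<answer> C </answer>'-style text."""
    match = re.search(r"<answer>\s*([A-Z])\s*</answer>", text)
    return match.group(1) if match else None

# Compare each decoded output against the ground-truth solution.
correct = sum(
    extract_choice(pred) == extract_choice(item["solution"])
    for pred, item in zip(all_outputs, data)
)
print(f"Accuracy: {correct / len(data):.4f}")
```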

## Acknowledgements

We thank the authors of **OmniMedVQA** and **R1-V** for their open-source contributions.  
🔗 [R1-V GitHub Repository](https://github.com/Deep-Agent/R1-V)  
🔗 [OmniMedVQA GitHub Repository](https://github.com/OpenGVLab/Multi-Modality-Arena)

## Citation
```
@article{lai2025med,
  title={Med-R1: Reinforcement Learning for Generalizable Medical Reasoning in Vision-Language Models},
  author={Lai, Yuxiang and Zhong, Jike and Li, Ming and Zhao, Shitian and Yang, Xiaofeng},
  journal={arXiv preprint arXiv:2503.13939},
  year={2025}
}
```