Med-R1

Med-R1 is a reinforcement learning (RL)-enhanced vision-language model (VLM) designed for medical reasoning across 8 imaging modalities (CT, MRI, Ultrasound, Dermoscopy, Fundus Photography, Optical Coherence Tomography (OCT), Microscopy, and X-ray) and 5 key tasks (modality recognition, anatomy identification, disease diagnosis, lesion grading, and biological attribute analysis). Using Group Relative Policy Optimization (GRPO), Med-R1 improves generalization and trustworthiness, surpassing Qwen2-VL-2B by 29.94% and even outperforming the much larger Qwen2-VL-72B. Our model checkpoints provide researchers with a powerful tool for advancing medical AI with RL-driven enhancements.
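
As a rough illustration of the group-relative signal at the core of GRPO (a minimal sketch with a placeholder reward scheme, not the actual Med-R1 training code):

import numpy as np

def group_relative_advantages(rewards):
    """Normalize rewards across a group of completions sampled for the same prompt.

    GRPO scores each completion relative to its group, so no separate
    value/critic model is needed.
    """
    rewards = np.asarray(rewards, dtype=np.float32)
    return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

# Example: four sampled answers to one question, rewarded 1.0 when the
# letter inside <answer> ... </answer> is correct, 0.0 otherwise.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))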

Description of Models

  • Cross-Modality: We provide checkpoints trained separately on each of the following modalities:
    • CT, MRI, X-ray, Fundus (FP), Dermoscopy (Der), Microscopy (Micro), OCT, and Ultrasound (US).
  • Cross-Task Learning: We provide checkpoints trained separately on each of the following tasks:
    • Anatomy Identification, Disease Diagnosis, Lesion Grading, Modality Recognition, and Biological Attribute Analysis.

Use of Models

Load Checkpoint

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_PATH = "..."  # path to the downloaded Med-R1 checkpoint

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)
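
If GPU memory is limited, the per-image visual token budget can be capped when loading the processor. min_pixels and max_pixels are standard Qwen2-VL processor arguments; the values below are only illustrative:

# Optional: bound the number of visual tokens per image (illustrative values)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)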

Data Organization

[
    {
        "image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000139 (9).png",
        "problem": "What imaging technique is employed for obtaining this image? A)Mammogram, B)Positron emission tomography (PET), C)CT, D)Fluoroscopy",
        "solution": "<answer> C </answer>"
    },
    {
        "image": "Images/Chest CT Scan/test/squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa/000127 (2).png",
        "problem": "What imaging technique was utilized for obtaining this image? A)CT, B)Angiography, C)X-ray, D)Ultrasound",
        "solution": "<answer> A </answer>"
    },
    {
        "image": "Images/Chest CT Scan/test/normal/10 (2).png",
        "problem": "What imaging technique was used for this image acquisition? A)CT, B)Ultrasound, C)Fluoroscopy, D)X-ray",
        "solution": "<answer> A </answer>"
    },
    {
        "image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000142.png",
        "problem": "What is the specific diagnosis of the cancer shown in the image? A)Neuroendocrine tumor of the left upper lobe, T3 N0 M1, Stage III, B)Mesothelioma of the left lower lobe, T2 N0 M0, Stage Ib, C)Adenocarcinoma of the left lower lobe, T2 N0 M0, Stage Ib, D)Non-Hodgkin lymphoma of the right lower lobe, T2 N1 M0, Stage II",
        "solution": "<answer> C </answer>"
    }
  ...
]
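
Each entry pairs an image path with a multiple-choice problem and a solution holding the correct letter inside <answer> </answer> tags. For scoring, the letter can be recovered with a small regex helper (a hypothetical utility, not part of the released code):

import re

def extract_answer(text):
    """Return the choice letter inside <answer> ... </answer>, or None if absent."""
    match = re.search(r"<answer>\s*([A-Z])\s*</answer>", text)
    return match.group(1) if match else None

assert extract_answer("<answer> C </answer>") == "C"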

Inference

import json
from tqdm import tqdm
from qwen_vl_utils import process_vision_info

PROMPT_PATH = "..."  # path to the evaluation JSON described above
BSZ = 8              # batch size; adjust to the available GPU memory

with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."

messages = []
all_outputs = []

for i in data:
    message = [{
        "role": "user",
        "content": [
            {
                "type": "image", 
                "image": f"file://{i['image']}"
            },
            {
                "type": "text",
                "text": QUESTION_TEMPLATE.format(Question=i['problem'])
            }
        ]
    }]
    messages.append(message)


for i in tqdm(range(0, len(messages), BSZ)):
    batch_messages = messages[i:i + BSZ]
    
    # Preparation for inference
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
    
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
    
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    all_outputs.extend(batch_output_text)
    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")

Acknowledgements

We thank the authors of OmniMedVQA and R1-V for their open-source contributions.
🔗 R1-V GitHub Repository · 🔗 OmniMedVQA GitHub Repository

Citation

@misc{lai2025medr1reinforcementlearninggeneralizable,
      title={Med-R1: Reinforcement Learning for Generalizable Medical Reasoning in Vision-Language Models}, 
      author={Yuxiang Lai and Jike Zhong and Ming Li and Shitian Zhao and Xiaofeng Yang},
      year={2025},
      eprint={2503.13939},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2503.13939}, 
}