---
license: apache-2.0
base_model:
- microsoft/llava-med-v1.5-mistral-7b
new_version: microsoft/llava-med-v1.5-mistral-7b
pipeline_tag: visual-question-answering
language:
- en
tags:
- medical
- biology
---
# llava-med-v1.5-mistral-7b-hf

This repository contains a **drop-in, Hugging Face–compatible** checkpoint converted from
[https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b](https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b).
You can load it with the **standard `transformers` LLaVA code path**, the same one used for any `llava-hf`-style checkpoint; no extra conversion steps are required.

---

## Quick Start

```python
from transformers import LlavaForConditionalGeneration, AutoProcessor
from PIL import Image
import torch

model_path = "chaoyinshe/llava-med-v1.5-mistral-7b-hf"

model = LlavaForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",   # requires FA2
    device_map="auto"                          # multi-GPU ready
)

processor = AutoProcessor.from_pretrained(model_path)

# Example inference (the image path is a placeholder; use your own file)
image = Image.open("chest_xray.png").convert("RGB")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is the main finding in this chest X-ray?"}
        ]
    }
]

# Build the prompt via the processor's chat template, which places the
# <image> token for you
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

inputs = processor(
    images=[image], text=prompt, return_tensors="pt"
).to(model.device, torch.bfloat16)

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=256)

# `generate` echoes the prompt tokens; decode only the newly generated ones
answer_ids = out[0][inputs["input_ids"].shape[-1]:]
print(processor.decode(answer_ids, skip_special_tokens=True))
```
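
Loading with `attn_implementation="flash_attention_2"` fails if the `flash-attn` package is not installed. As a minimal sketch (not part of the original card), you can pick the attention backend at load time; `"sdpa"` uses PyTorch's built-in scaled-dot-product attention and needs no extra dependencies:

```python
import torch
from transformers import LlavaForConditionalGeneration

# Prefer FlashAttention 2 when available, otherwise fall back to SDPA.
try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = "sdpa"  # built into PyTorch >= 2.0, no extra install

model = LlavaForConditionalGeneration.from_pretrained(
    "chaoyinshe/llava-med-v1.5-mistral-7b-hf",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
    device_map="auto",
)
```

If the bf16 weights (roughly 15 GB for a 7B model) do not fit on your GPU, a 4-bit quantized load is one option, assuming the `bitsandbytes` package is installed:

```python
import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

# 4-bit NF4 quantization; trades some accuracy for a much smaller footprint.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = LlavaForConditionalGeneration.from_pretrained(
    "chaoyinshe/llava-med-v1.5-mistral-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
```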