Upload LLM2Vec4CXR fine-tuned model
Browse files- README.md +180 -46
- config.json +9 -4
- pytorch_model.bin +3 -0
- requirements.txt +6 -0
- special_tokens_map.json +0 -1
- tokenizer.json +14 -11
- tokenizer_config.json +0 -9
- usage_example.py +236 -0
README.md
CHANGED
@@ -1,60 +1,194 @@
|
|
1 |
-
# llm2vec4cxr
|
2 |
-
|
3 |
---
|
4 |
license: mit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
---
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
## Usage
|
8 |
|
9 |
-
###
|
10 |
-
pip install llm2vec
|
11 |
|
12 |
-
```
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
15 |
|
|
|
16 |
import torch
|
17 |
-
|
18 |
-
from
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
```
|
55 |
|
56 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
```
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: mit
|
3 |
+
base_model: microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned
|
4 |
+
tags:
|
5 |
+
- text-embeddings
|
6 |
+
- sentence-transformers
|
7 |
+
- llm2vec
|
8 |
+
- medical
|
9 |
+
- chest-xray
|
10 |
+
- radiology
|
11 |
+
- clinical-nlp
|
12 |
+
language:
|
13 |
+
- en
|
14 |
+
pipeline_tag: feature-extraction
|
15 |
+
library_name: transformers
|
16 |
---
|
17 |
|
18 |
+
# LLM2Vec4CXR - Fine-tuned Model for Chest X-ray Report Analysis
|
19 |
+
|
20 |
+
This model is a fine-tuned version of [microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned](https://huggingface.co/microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned) specifically optimized for chest X-ray report analysis and medical text understanding.
|
21 |
+
|
22 |
+
## Model Description
|
23 |
+
|
24 |
+
LLM2Vec4CXR is a bidirectional language model that converts the base decoder-only LLM into a text encoder optimized for medical text embeddings. The model has been fully fine-tuned with modified pooling strategy (`latent_attention`) to better capture semantic relationships in chest X-ray reports.
|
25 |
+
|
26 |
+
### Key Features
|
27 |
+
|
28 |
+
- **Base Architecture**: LLM2CLIP-Llama-3.2-1B-Instruct
|
29 |
+
- **Pooling Mode**: Latent Attention (modified from original)
|
30 |
+
- **Bidirectional Processing**: Enabled for better context understanding
|
31 |
+
- **Medical Domain**: Specialized for chest X-ray report analysis
|
32 |
+
- **Max Length**: 512 tokens
|
33 |
+
- **Precision**: bfloat16
|
34 |
+
|
35 |
+
## Training Details
|
36 |
+
|
37 |
+
### Training Data
|
38 |
+
- Fully fine-tuned on chest X-ray reports and medical text data
|
39 |
+
- Training focused on understanding pleural effusion status and other chest X-ray findings
|
40 |
+
|
41 |
+
### Training Configuration
|
42 |
+
- **Pooling Mode**: `latent_attention` (modified from base model)
|
43 |
+
- **Enable Bidirectional**: True
|
44 |
+
- **Max Length**: 512
|
45 |
+
- **Torch Dtype**: bfloat16
|
46 |
+
- **Full Fine-tuning**: All model weights were updated during training
|
47 |
+
|
48 |
## Usage
|
49 |
|
50 |
+
### Installation
|
|
|
51 |
|
52 |
+
```bash
|
53 |
+
pip install torch transformers
|
54 |
+
# Also requires the LLM2Vec wrapper - see the original repository for installation
|
55 |
+
```
|
56 |
+
|
57 |
+
### Basic Usage
|
58 |
|
59 |
+
```python
|
60 |
import torch
|
61 |
+
import torch.nn.functional as F
|
62 |
+
from llm2vec_wrapper import LLM2VecWrapper as LLM2Vec
|
63 |
+
|
64 |
+
# Load the model
|
65 |
+
model = LLM2Vec.from_pretrained(
|
66 |
+
base_model_name_or_path='lukeingawesome/llm2vec4cxr',
|
67 |
+
enable_bidirectional=True,
|
68 |
+
pooling_mode="latent_attention",
|
69 |
+
max_length=512,
|
70 |
+
torch_dtype=torch.bfloat16,
|
71 |
+
)
|
72 |
+
|
73 |
+
# Configure tokenizer
|
74 |
+
tokenizer = model.tokenizer
|
75 |
+
tokenizer.padding_side = 'left'
|
76 |
+
|
77 |
+
# Example usage for chest X-ray report analysis
|
78 |
+
def encode_text(text):
|
79 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
80 |
+
with torch.no_grad():
|
81 |
+
embeddings = model(inputs)
|
82 |
+
return embeddings
|
83 |
+
|
84 |
+
# Example with medical text
|
85 |
+
report = "There is a small increase in the left-sided effusion. There continues to be volume loss at both bases."
|
86 |
+
embedding = encode_text(report)
|
87 |
+
```
|
88 |
+
|
89 |
+
### Advanced Usage with Separator-based Processing
|
90 |
+
|
91 |
+
The model supports special separator-based processing for instruction-following tasks:
|
92 |
+
|
93 |
+
```python
|
94 |
+
def tokenize_with_separator(texts, tokenizer, max_length):
|
95 |
+
"""Tokenize texts with special handling for separator-based splitting."""
|
96 |
+
texts_2 = []
|
97 |
+
original_texts = []
|
98 |
+
separator = '!@#$%^&*()'
|
99 |
+
|
100 |
+
for text in texts:
|
101 |
+
parts = text.split(separator)
|
102 |
+
texts_2.append(parts[1] if len(parts) > 1 else "")
|
103 |
+
original_texts.append("".join(parts))
|
104 |
+
|
105 |
+
tokenized = tokenizer(
|
106 |
+
original_texts,
|
107 |
+
return_tensors="pt",
|
108 |
+
padding=True,
|
109 |
+
truncation=True,
|
110 |
+
max_length=max_length,
|
111 |
+
)
|
112 |
|
113 |
+
# Create embedding masks for the separated parts
|
114 |
+
embed_mask = None
|
115 |
+
for t_i, t in enumerate(texts_2):
|
116 |
+
ids = tokenizer([t], return_tensors="pt", padding=True, truncation=True,
|
117 |
+
max_length=max_length, add_special_tokens=False)
|
118 |
+
e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
|
119 |
+
if len(ids["input_ids"][0]) > 0:
|
120 |
+
e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
|
121 |
+
if embed_mask is None:
|
122 |
+
embed_mask = e_m.unsqueeze(0)
|
123 |
+
else:
|
124 |
+
embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
|
125 |
+
|
126 |
+
tokenized["embed_mask"] = embed_mask
|
127 |
+
return tokenized
|
128 |
+
|
129 |
+
# Example with instruction and report
|
130 |
+
separator = '!@#$%^&*()'
|
131 |
+
instruction = 'Determine the change or the status of the pleural effusion.'
|
132 |
+
report = 'There is a small increase in the left-sided effusion.'
|
133 |
+
text = instruction + separator + report
|
134 |
+
|
135 |
+
tokenized = tokenize_with_separator([text], tokenizer, 512)
|
136 |
+
embedding = model(tokenized)
|
137 |
```
|
138 |
|
139 |
+
## Evaluation
|
140 |
+
|
141 |
+
The model has been evaluated on chest X-ray report analysis tasks, particularly for:
|
142 |
+
- Pleural effusion status determination
|
143 |
+
- Medical text similarity comparison
|
144 |
+
- Clinical finding extraction
|
145 |
+
|
146 |
+
### Sample Performance
|
147 |
+
|
148 |
+
The model shows improved performance compared to the base model on medical text understanding tasks, particularly in distinguishing between different pleural effusion states and medical abbreviations.
|
149 |
+
|
150 |
+
## Intended Use
|
151 |
|
152 |
+
### Primary Use Cases
|
153 |
+
- **Medical Text Embeddings**: Generate embeddings for chest X-ray reports
|
154 |
+
- **Clinical Text Similarity**: Compare medical texts for semantic similarity
|
155 |
+
- **Medical Information Retrieval**: Find relevant medical reports or findings
|
156 |
+
- **Clinical NLP Research**: Foundation model for medical text analysis
|
157 |
+
|
158 |
+
### Limitations
|
159 |
+
- Specialized for chest X-ray reports - may not generalize to other medical domains
|
160 |
+
- Requires careful preprocessing for optimal performance
|
161 |
+
- Should be used as part of a larger clinical decision support system, not for standalone diagnosis
|
162 |
+
|
163 |
+
## Technical Specifications
|
164 |
+
|
165 |
+
- **Model Type**: Bidirectional Language Model (LLM2Vec)
|
166 |
+
- **Architecture**: LlamaBiModel (modified Llama 3.2)
|
167 |
+
- **Parameters**: ~1B parameters
|
168 |
+
- **Input Length**: Up to 512 tokens
|
169 |
+
- **Output**: Dense embeddings
|
170 |
+
- **Precision**: bfloat16
|
171 |
+
|
172 |
+
## Citation
|
173 |
+
|
174 |
+
If you use this model in your research, please cite:
|
175 |
+
|
176 |
+
```bibtex
|
177 |
+
@misc{llm2vec4cxr,
|
178 |
+
title={LLM2Vec4CXR: Fine-tuned Language Model for Chest X-ray Report Analysis},
|
179 |
+
author={[Your Name]},
|
180 |
+
year={2024},
|
181 |
+
howpublished={\\url{https://huggingface.co/lukeingawesome/llm2vec4cxr}},
|
182 |
+
}
|
183 |
```
|
184 |
+
|
185 |
+
## Acknowledgments
|
186 |
+
|
187 |
+
This model is built upon:
|
188 |
+
- [LLM2Vec](https://github.com/McGill-NLP/llm2vec) - Framework for converting decoder-only LLMs into text encoders
|
189 |
+
- [LLM2CLIP](https://github.com/microsoft/LLM2CLIP) - Microsoft's implementation for connecting LLMs with CLIP models
|
190 |
+
- [microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned](https://huggingface.co/microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned) - Base model
|
191 |
+
|
192 |
+
## License
|
193 |
+
|
194 |
+
This model is licensed under the MIT License.
|
config.json
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"architectures": [
|
4 |
-
"
|
5 |
],
|
6 |
"attention_bias": false,
|
7 |
"attention_dropout": 0.0,
|
8 |
"auto_map": {
|
9 |
"AutoModel": "microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned--modeling_bidirectional_llama_encoder.LlamaBiModel"
|
10 |
},
|
|
|
11 |
"bos_token_id": 128000,
|
12 |
"eos_token_id": [
|
13 |
128001,
|
@@ -25,6 +26,7 @@
|
|
25 |
"num_attention_heads": 32,
|
26 |
"num_hidden_layers": 16,
|
27 |
"num_key_value_heads": 8,
|
|
|
28 |
"pretraining_tp": 1,
|
29 |
"rms_norm_eps": 1e-05,
|
30 |
"rope_scaling": {
|
@@ -39,5 +41,8 @@
|
|
39 |
"torch_dtype": "bfloat16",
|
40 |
"transformers_version": "4.44.2",
|
41 |
"use_cache": true,
|
42 |
-
"vocab_size": 128256
|
43 |
-
|
|
|
|
|
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "lukeingawesome/llm2vec4cxr",
|
3 |
"architectures": [
|
4 |
+
"LlamaBiModel"
|
5 |
],
|
6 |
"attention_bias": false,
|
7 |
"attention_dropout": 0.0,
|
8 |
"auto_map": {
|
9 |
"AutoModel": "microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned--modeling_bidirectional_llama_encoder.LlamaBiModel"
|
10 |
},
|
11 |
+
"base_model": "microsoft/LLM2CLIP-Llama-3.2-1B-Instruct-CC-Finetuned",
|
12 |
"bos_token_id": 128000,
|
13 |
"eos_token_id": [
|
14 |
128001,
|
|
|
26 |
"num_attention_heads": 32,
|
27 |
"num_hidden_layers": 16,
|
28 |
"num_key_value_heads": 8,
|
29 |
+
"pooling_mode": "latent_attention",
|
30 |
"pretraining_tp": 1,
|
31 |
"rms_norm_eps": 1e-05,
|
32 |
"rope_scaling": {
|
|
|
41 |
"torch_dtype": "bfloat16",
|
42 |
"transformers_version": "4.44.2",
|
43 |
"use_cache": true,
|
44 |
+
"vocab_size": 128256,
|
45 |
+
"model_description": "Fine-tuned LLM2Vec model for chest X-ray report analysis with latent attention pooling",
|
46 |
+
"domain": "medical",
|
47 |
+
"task": "text-embeddings"
|
48 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c77cfb88a4601f6ad464c6694801f11e23f1e9f8f496e794323153247fc750a4
|
3 |
+
size 2524132278
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
transformers>=4.44.0
|
3 |
+
accelerate>=0.20.0
|
4 |
+
flash-attn>=2.5.0
|
5 |
+
numpy>=1.21.0
|
6 |
+
huggingface_hub>=0.16.0
|
special_tokens_map.json
CHANGED
@@ -13,6 +13,5 @@
|
|
13 |
"rstrip": false,
|
14 |
"single_word": false
|
15 |
},
|
16 |
-
"mask_token": "_",
|
17 |
"pad_token": "<|eot_id|>"
|
18 |
}
|
|
|
13 |
"rstrip": false,
|
14 |
"single_word": false
|
15 |
},
|
|
|
16 |
"pad_token": "<|eot_id|>"
|
17 |
}
|
tokenizer.json
CHANGED
@@ -1,17 +1,20 @@
|
|
1 |
{
|
2 |
"version": "1.0",
|
3 |
-
"truncation":
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
"added_tokens": [
|
6 |
-
{
|
7 |
-
"id": 62,
|
8 |
-
"content": "_",
|
9 |
-
"single_word": false,
|
10 |
-
"lstrip": false,
|
11 |
-
"rstrip": false,
|
12 |
-
"normalized": false,
|
13 |
-
"special": true
|
14 |
-
},
|
15 |
{
|
16 |
"id": 128000,
|
17 |
"content": "<|begin_of_text|>",
|
|
|
1 |
{
|
2 |
"version": "1.0",
|
3 |
+
"truncation": {
|
4 |
+
"direction": "Right",
|
5 |
+
"max_length": 512,
|
6 |
+
"strategy": "LongestFirst",
|
7 |
+
"stride": 0
|
8 |
+
},
|
9 |
+
"padding": {
|
10 |
+
"strategy": "BatchLongest",
|
11 |
+
"direction": "Left",
|
12 |
+
"pad_to_multiple_of": null,
|
13 |
+
"pad_id": 128009,
|
14 |
+
"pad_type_id": 0,
|
15 |
+
"pad_token": "<|eot_id|>"
|
16 |
+
},
|
17 |
"added_tokens": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
{
|
19 |
"id": 128000,
|
20 |
"content": "<|begin_of_text|>",
|
tokenizer_config.json
CHANGED
@@ -1,13 +1,5 @@
|
|
1 |
{
|
2 |
"added_tokens_decoder": {
|
3 |
-
"62": {
|
4 |
-
"content": "_",
|
5 |
-
"lstrip": false,
|
6 |
-
"normalized": false,
|
7 |
-
"rstrip": false,
|
8 |
-
"single_word": false,
|
9 |
-
"special": true
|
10 |
-
},
|
11 |
"128000": {
|
12 |
"content": "<|begin_of_text|>",
|
13 |
"lstrip": false,
|
@@ -2061,7 +2053,6 @@
|
|
2061 |
"chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
|
2062 |
"clean_up_tokenization_spaces": true,
|
2063 |
"eos_token": "<|eot_id|>",
|
2064 |
-
"mask_token": "_",
|
2065 |
"model_input_names": [
|
2066 |
"input_ids",
|
2067 |
"attention_mask"
|
|
|
1 |
{
|
2 |
"added_tokens_decoder": {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
"128000": {
|
4 |
"content": "<|begin_of_text|>",
|
5 |
"lstrip": false,
|
|
|
2053 |
"chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n",
|
2054 |
"clean_up_tokenization_spaces": true,
|
2055 |
"eos_token": "<|eot_id|>",
|
|
|
2056 |
"model_input_names": [
|
2057 |
"input_ids",
|
2058 |
"attention_mask"
|
usage_example.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Example usage script for LLM2Vec4CXR model.
|
3 |
+
This demonstrates how to load and use the model for chest X-ray report analysis.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from llm2vec_wrapper import LLM2VecWrapper as LLM2Vec
|
9 |
+
|
10 |
+
def load_llm2vec4cxr_model(model_name_or_path="lukeingawesome/llm2vec4cxr"):
|
11 |
+
"""
|
12 |
+
Load the LLM2Vec4CXR model with proper configuration.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
model_name_or_path (str): Hugging Face model path or local path
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
tuple: (model, tokenizer)
|
19 |
+
"""
|
20 |
+
# Load model with the specific configuration used for LLM2Vec4CXR
|
21 |
+
model = LLM2Vec.from_pretrained(
|
22 |
+
base_model_name_or_path=model_name_or_path,
|
23 |
+
enable_bidirectional=True,
|
24 |
+
pooling_mode="latent_attention", # This is the key modification
|
25 |
+
max_length=512,
|
26 |
+
torch_dtype=torch.bfloat16,
|
27 |
+
)
|
28 |
+
|
29 |
+
# Configure tokenizer
|
30 |
+
tokenizer = model.tokenizer
|
31 |
+
tokenizer.padding_side = 'left'
|
32 |
+
|
33 |
+
return model, tokenizer
|
34 |
+
|
35 |
+
def tokenize_with_separator(texts, tokenizer, max_length=512):
|
36 |
+
"""
|
37 |
+
Tokenize texts with special handling for separator-based splitting.
|
38 |
+
This is useful for instruction-following tasks.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
texts (list): List of texts to tokenize
|
42 |
+
tokenizer: The tokenizer to use
|
43 |
+
max_length (int): Maximum sequence length
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
dict: Tokenized inputs with attention masks and embed masks
|
47 |
+
"""
|
48 |
+
texts_2 = []
|
49 |
+
original_texts = []
|
50 |
+
separator = '!@#$%^&*()'
|
51 |
+
|
52 |
+
for text in texts:
|
53 |
+
parts = text.split(separator)
|
54 |
+
texts_2.append(parts[1] if len(parts) > 1 else "")
|
55 |
+
original_texts.append("".join(parts))
|
56 |
+
|
57 |
+
# Tokenize original texts
|
58 |
+
tokenized = tokenizer(
|
59 |
+
original_texts,
|
60 |
+
return_tensors="pt",
|
61 |
+
padding=True,
|
62 |
+
truncation=True,
|
63 |
+
max_length=max_length,
|
64 |
+
)
|
65 |
+
|
66 |
+
# Create embedding masks for the separated parts
|
67 |
+
embed_mask = None
|
68 |
+
for t_i, t in enumerate(texts_2):
|
69 |
+
ids = tokenizer(
|
70 |
+
[t],
|
71 |
+
return_tensors="pt",
|
72 |
+
padding=True,
|
73 |
+
truncation=True,
|
74 |
+
max_length=max_length,
|
75 |
+
add_special_tokens=False,
|
76 |
+
)
|
77 |
+
|
78 |
+
e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
|
79 |
+
if len(ids["input_ids"][0]) > 0:
|
80 |
+
e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))
|
81 |
+
|
82 |
+
if embed_mask is None:
|
83 |
+
embed_mask = e_m.unsqueeze(0)
|
84 |
+
else:
|
85 |
+
embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
|
86 |
+
|
87 |
+
tokenized["embed_mask"] = embed_mask
|
88 |
+
return tokenized
|
89 |
+
|
90 |
+
def compute_similarities(model, tokenizer, texts, device):
|
91 |
+
"""
|
92 |
+
Compute similarity scores between the first text and all other texts.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
model: The LLM2Vec model
|
96 |
+
tokenizer: The tokenizer
|
97 |
+
texts (list): List of texts to compare (first text is the reference)
|
98 |
+
device: The device to run computations on
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
tuple: (embeddings, similarities)
|
102 |
+
"""
|
103 |
+
with torch.no_grad():
|
104 |
+
# Use separator-based tokenization if texts contain the separator
|
105 |
+
if any('!@#$%^&*()' in text for text in texts):
|
106 |
+
tokenized = tokenize_with_separator(texts, tokenizer, 512)
|
107 |
+
else:
|
108 |
+
tokenized = tokenizer(
|
109 |
+
texts,
|
110 |
+
return_tensors="pt",
|
111 |
+
padding=True,
|
112 |
+
truncation=True,
|
113 |
+
max_length=512,
|
114 |
+
)
|
115 |
+
|
116 |
+
tokenized = tokenized.to(device)
|
117 |
+
if hasattr(tokenized, 'to'):
|
118 |
+
tokenized = tokenized.to(torch.bfloat16)
|
119 |
+
else:
|
120 |
+
# Convert each tensor in the dict
|
121 |
+
for key in tokenized:
|
122 |
+
if torch.is_tensor(tokenized[key]):
|
123 |
+
tokenized[key] = tokenized[key].to(torch.bfloat16)
|
124 |
+
|
125 |
+
embeddings = model(tokenized)
|
126 |
+
|
127 |
+
# Compute cosine similarities between first embedding and all others
|
128 |
+
similarities = F.cosine_similarity(embeddings[0], embeddings[1:], dim=1)
|
129 |
+
|
130 |
+
return embeddings, similarities
|
131 |
+
|
132 |
+
def main():
|
133 |
+
"""
|
134 |
+
Example usage of the LLM2Vec4CXR model for chest X-ray report analysis.
|
135 |
+
"""
|
136 |
+
# Set device
|
137 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
138 |
+
print(f"Using device: {device}")
|
139 |
+
|
140 |
+
# Load the model
|
141 |
+
print("Loading LLM2Vec4CXR model...")
|
142 |
+
model, tokenizer = load_llm2vec4cxr_model()
|
143 |
+
model = model.to(device).to(torch.bfloat16)
|
144 |
+
model.eval()
|
145 |
+
|
146 |
+
# Example 1: Basic text embedding
|
147 |
+
print("\n" + "="*60)
|
148 |
+
print("Example 1: Basic Text Embedding")
|
149 |
+
print("="*60)
|
150 |
+
|
151 |
+
report = "There is a small increase in the left-sided effusion. There continues to be volume loss at both bases."
|
152 |
+
|
153 |
+
inputs = tokenizer(report, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
154 |
+
inputs = inputs.to(device)
|
155 |
+
|
156 |
+
with torch.no_grad():
|
157 |
+
embedding = model(inputs)
|
158 |
+
|
159 |
+
print(f"Report: {report}")
|
160 |
+
print(f"Embedding shape: {embedding.shape}")
|
161 |
+
print(f"Embedding norm: {torch.norm(embedding).item():.4f}")
|
162 |
+
|
163 |
+
# Example 2: Instruction-based similarity comparison
|
164 |
+
print("\n" + "="*60)
|
165 |
+
print("Example 2: Instruction-based Similarity Comparison")
|
166 |
+
print("="*60)
|
167 |
+
|
168 |
+
separator = '!@#$%^&*()'
|
169 |
+
instruction = 'Determine the change or the status of the pleural effusion.'
|
170 |
+
report = 'There is a small increase in the left-sided effusion. There continues to be volume loss at both bases.'
|
171 |
+
text = instruction + separator + report
|
172 |
+
|
173 |
+
comparison_options = [
|
174 |
+
'No pleural effusion',
|
175 |
+
'Pleural effusion',
|
176 |
+
'Effusion is seen in the right',
|
177 |
+
'Effusion is seen in the left',
|
178 |
+
'Pleural effusion is improving',
|
179 |
+
'Pleural effusion is stable',
|
180 |
+
'Pleural effusion is worsening'
|
181 |
+
]
|
182 |
+
|
183 |
+
all_texts = [text] + comparison_options
|
184 |
+
|
185 |
+
# Compute similarities
|
186 |
+
_, similarities = compute_similarities(model, tokenizer, all_texts, device)
|
187 |
+
|
188 |
+
print(f"Original text: {report}")
|
189 |
+
print(f"Instruction: {instruction}")
|
190 |
+
print("\nSimilarity Scores:")
|
191 |
+
print("-" * 50)
|
192 |
+
|
193 |
+
for option, score in zip(comparison_options, similarities):
|
194 |
+
print(f"{option:<35} | {score.item():.4f}")
|
195 |
+
|
196 |
+
# Find the most similar option
|
197 |
+
best_match_idx = torch.argmax(similarities).item()
|
198 |
+
print(f"\nBest match: {comparison_options[best_match_idx]} (score: {similarities[best_match_idx].item():.4f})")
|
199 |
+
|
200 |
+
# Example 3: Multiple report comparison
|
201 |
+
print("\n" + "="*60)
|
202 |
+
print("Example 3: Multiple Report Comparison")
|
203 |
+
print("="*60)
|
204 |
+
|
205 |
+
reports = [
|
206 |
+
"No acute cardiopulmonary abnormality.",
|
207 |
+
"Small bilateral pleural effusions.",
|
208 |
+
"Large left pleural effusion with compressive atelectasis.",
|
209 |
+
"Interval improvement in bilateral pleural effusions.",
|
210 |
+
"Worsening bilateral pleural effusions."
|
211 |
+
]
|
212 |
+
|
213 |
+
print("Computing embeddings for multiple reports...")
|
214 |
+
inputs = tokenizer(reports, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
215 |
+
inputs = inputs.to(device)
|
216 |
+
|
217 |
+
with torch.no_grad():
|
218 |
+
embeddings = model(inputs)
|
219 |
+
|
220 |
+
# Compute pairwise similarities
|
221 |
+
similarity_matrix = F.cosine_similarity(
|
222 |
+
embeddings.unsqueeze(1),
|
223 |
+
embeddings.unsqueeze(0),
|
224 |
+
dim=2
|
225 |
+
)
|
226 |
+
|
227 |
+
print("\nPairwise Similarity Matrix:")
|
228 |
+
print("-" * 30)
|
229 |
+
for i, report1 in enumerate(reports):
|
230 |
+
print(f"Report {i+1}: {report1[:30]}...")
|
231 |
+
for j, report2 in enumerate(reports):
|
232 |
+
print(f" vs Report {j+1}: {similarity_matrix[i][j].item():.4f}")
|
233 |
+
print()
|
234 |
+
|
235 |
+
if __name__ == "__main__":
|
236 |
+
main()
|