---
language:
- en
license: mit
library_name: transformers
tags:
- vision
- image-to-text
- image-captioning
pipeline_tag: image-to-text
base_model: Salesforce/blip2-opt-2.7b
---
# VLRM
This repository contains the weights of the BLIP-2 OPT-2.7B model fine-tuned with the reinforcement learning method introduced in the paper [VLRM: Vision-Language Models act as Reward Models for Image Captioning](https://arxiv.org/abs/2404.01911).
The RL-tuned model generates longer and more comprehensive descriptions than the original model with zero computational overhead, since it keeps the same architecture and only changes the weights.
More details will be available in the [GitHub repository (to be done)](https://github.com/papermsucode).
# Running the model
## Option 1
<details>
<summary> Load the whole model from this repo </summary>

```python
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the processor and the RL-tuned model in half precision
processor = Blip2Processor.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

# Download the demo image
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

# Preprocess the image and move the inputs to the GPU (a CUDA device is assumed)
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

# Generate a caption of up to 60 new tokens
out = model.generate(**inputs, max_new_tokens=60)
print(processor.decode(out[0], skip_special_tokens=True).strip())
# a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida
```
</details>
## Option 2
Since the fine-tuned layers make up only a small part of the whole model, you can instead load the original model first and then load the RL-tuned weights on top of it.
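
Option 2 relies on `load_state_dict(..., strict=False)`: PyTorch copies every checkpoint tensor whose name matches a model parameter and, instead of raising, returns the names that did not match as `missing_keys` and `unexpected_keys`. The toy sketch below illustrates this on a two-layer `nn.Sequential`; the models, the layer sizes, and the choice of layer `1` as the "fine-tuned" part are illustrative assumptions, not the real BLIP-2 checkpoint layout.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Toy stand-ins: `base` plays the original model, `tuned` the RL-tuned one.
# Treating layer "1" as the only fine-tuned layer is an illustrative assumption.
base = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
tuned = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
layer0_before = base[0].weight.detach().clone()

# A partial checkpoint: only the tensors of the "fine-tuned" layer.
partial_state_dict = {k: v for k, v in tuned.state_dict().items() if k.startswith("1.")}

# strict=False copies matching tensors and reports mismatches instead of raising.
result = base.load_state_dict(partial_state_dict, strict=False)
print(result.missing_keys)     # ['0.weight', '0.bias'] (kept from `base`)
print(result.unexpected_keys)  # [] (every checkpoint key matched a parameter)

# Layer 1 now holds the checkpoint values; layer 0 is unchanged.
assert torch.equal(base[1].weight, tuned[1].weight)
assert torch.equal(base[0].weight, layer0_before)
```

The same mechanism applies to the real model: `missing_keys` lists the original weights that are kept, and `unexpected_keys` is expected to be empty when every tensor in the RL-tuned checkpoint finds a matching parameter.
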
<details>
<summary> Step 1. Load the original model </summary>

```python
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the processor and the original BLIP-2 model in half precision
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

# Download the demo image
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

# Preprocess the image and move the inputs to the GPU (a CUDA device is assumed)
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

# Generate a caption with the original weights
out = model.generate(**inputs, max_new_tokens=60)
print(processor.decode(out[0], skip_special_tokens=True).strip())
# a woman sitting on the beach with a dog
```
</details>
<details>
<summary> Step 2. Load the RL-tuned weights </summary>

Available checkpoints:
- `vlrm-blip2-opt-2.7b.pt` (VLRM in the paper)
- `vlrm-rs-blip2-opt-2.7b.pt` (VLRM-RS in the paper)
```python
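from huggingface_hub import hf_hub_download

# Download the VLRM checkpoint (use filename="vlrm-rs-blip2-opt-2.7b.pt" for VLRM-RS)
checkpoint_path = hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt")
# Load on CPU first; load_state_dict copies the tensors onto the model's devices and dtype
finetuned_weights_state_dict = torch.load(checkpoint_path, map_location="cpu")
# The checkpoint holds only the fine-tuned layers, so strict=False keeps all other weights from Step 1
model.load_state_dict(finetuned_weights_state_dict, strict=False)
out = model.generate(**inputs, max_new_tokens=60)
print(processor.decode(out[0], skip_special_tokens=True).strip())
# a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida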
```
</details>
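
Because `strict=False` does not raise when checkpoint names and model parameter names disagree, it can be worth checking what the RL-tuned checkpoint actually overrides. The sketch below is a hypothetical helper (`summarize_checkpoint_coverage` is not part of `transformers` or this repository) that groups parameter names by their leading dotted components, counts how many tensors in each group the checkpoint replaces, and lists checkpoint keys that the model does not have. The key names in the demo call are made up for illustration.

```python
from collections import defaultdict


def summarize_checkpoint_coverage(model_keys, checkpoint_keys, depth=1):
    """Count, per parameter-name prefix, how many model tensors a partial
    checkpoint overrides, and list checkpoint keys the model does not have.

    Hypothetical helper for illustration; not part of transformers or this repo.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    totals = defaultdict(int)   # tensors per prefix in the model
    covered = defaultdict(int)  # tensors per prefix replaced by the checkpoint
    for key in model_keys:
        # e.g. depth=1 maps "block.sub.weight" to "block"
        prefix = ".".join(key.split(".")[:depth])
        totals[prefix] += 1
        if key in checkpoint_keys:
            covered[prefix] += 1
    coverage = {prefix: (covered[prefix], totals[prefix]) for prefix in sorted(totals)}
    # Checkpoint keys that load_state_dict(strict=False) would skip without an error
    unexpected = sorted(checkpoint_keys - model_keys)
    return coverage, unexpected


# Toy key names (hypothetical, not the real BLIP-2 parameter names)
coverage, unexpected = summarize_checkpoint_coverage(
    ["vision.proj.weight", "qformer.attn.weight", "qformer.attn.bias", "lm.head.weight"],
    ["qformer.attn.weight", "qformer.attn.bias", "extra.weight"],
)
print(coverage)    # {'lm': (0, 1), 'qformer': (2, 2), 'vision': (0, 1)}
print(unexpected)  # ['extra.weight']
```

After Step 2, a call such as `summarize_checkpoint_coverage(model.state_dict().keys(), finetuned_weights_state_dict.keys())` would show which top-level submodules the checkpoint touches; a non-empty `unexpected` list means those tensors were ignored. The helper counts tensors rather than individual parameter values and uses only the standard library.
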