# LLaVA-LoRA Adapter
This is a LoRA adapter for the LLaVA model, fine-tuned for spatial description tasks.
## Base Model
This adapter is trained on top of [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf).
## Training
The model was fine-tuned using LoRA with the following configuration:
- Rank: 8
- Alpha: 32
- Target modules: q_proj, v_proj, k_proj
- Dataset: PersReFex validation set
## Usage
```python
import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the base model and its processor
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(
    base_model,
    "ZinengTang/llava-lora-spatial",
)

# Build the chat-format prompt
init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            {"type": "image"},  # placeholder; the processor inserts the image features here
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# print(prompt)

# Process the input image and prompt
speaker_image = Image.open("your_image_path")
inputs = processor(
    images=speaker_image,
    text=prompt,
    return_tensors="pt",
    max_length=256,
).to("cuda")

# Generate a description
with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        max_length=512,
        num_beams=1,
        do_sample=True,
        temperature=0.7,
    )

# Decode, then keep only the assistant's reply (truncated to 100 characters)
generated_message = processor.batch_decode(
    generated,
    skip_special_tokens=True,
)
print(generated_message)
generated_message = generated_message[0].split("ASSISTANT: ")[-1][:100]
```