ZinengTang committed
Commit 646928f · verified · Parent(s): 3147f67

Update README.md

Files changed (1):
  1. README.md +41 -1
README.md CHANGED
````diff
@@ -24,7 +24,7 @@ import torch
 base_model = LlavaForConditionalGeneration.from_pretrained(
     "llava-hf/llava-1.5-7b-hf",
     torch_dtype=torch.bfloat16
-)
+).to('cuda')
 processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
 
 # Load LoRA adapter
@@ -32,4 +32,44 @@ model = PeftModel.from_pretrained(
     base_model,
     "ZinengTang/llava-lora-spatial"
 )
+
+from PIL import Image
+init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": init_prompt_instruct},
+            {"type": "image"},  # This will be replaced with the actual image
+        ],
+    },
+]
+speaker_image = Image.open('your_image_path')
+prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+# print(prompt)
+# Process the input image and prompt
+inputs = processor(
+    images=speaker_image,
+    text=prompt,
+    return_tensors="pt",
+    max_length=256,
+).to('cuda')
+
+with torch.no_grad():
+    generated = model.generate(
+        input_ids=inputs["input_ids"],
+        attention_mask=inputs["attention_mask"],
+        pixel_values=inputs["pixel_values"],
+        max_length=512,
+        num_beams=1,
+        do_sample=True,
+        temperature=0.7
+    )
+generated_message = processor.batch_decode(
+    generated,
+    skip_special_tokens=True
+)
+print(generated_message)
+generated_message = generated_message[0].split('ASSISTANT: ')[-1][:100]
+
 ```
````
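
For reference, here is the post-commit README snippet assembled into one self-contained script. This is a minimal sketch, not the authoritative README: the `transformers` and `peft` import lines are assumptions (only `import torch` is visible in the hunk context above), the inputs are cast to bfloat16 to match the model where the README uses a plain `.to('cuda')`, and `'your_image_path'` is the README's own placeholder, left as-is.

```python
# Assembled end-to-end sketch of the updated README code. Assumption:
# the transformers/peft imports below are inferred, since only
# `import torch` appears in the diff context.
import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Base LLaVA-1.5-7B in bfloat16, moved to the GPU (the commit's one-line change)
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Attach the LoRA adapter on top of the frozen base model
model = PeftModel.from_pretrained(base_model, "ZinengTang/llava-lora-spatial")

# Single-turn chat prompt with one image slot, rendered via the chat template
init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            {"type": "image"},  # the processor pairs this slot with the image below
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

speaker_image = Image.open("your_image_path")  # README placeholder; point this at a real image
# Cast floating-point tensors (pixel_values) to bf16 to match the model dtype;
# the README itself uses a plain .to('cuda') here.
inputs = processor(images=speaker_image, text=prompt, return_tensors="pt").to(
    "cuda", torch.bfloat16
)

# Sample a response (temperature 0.7, single beam, as in the README)
with torch.no_grad():
    generated = model.generate(
        **inputs,
        max_length=512,
        num_beams=1,
        do_sample=True,
        temperature=0.7,
    )
generated_message = processor.batch_decode(generated, skip_special_tokens=True)
# Keep only the assistant's reply; the README also trims it to 100 characters
generated_message = generated_message[0].split("ASSISTANT: ")[-1][:100]
print(generated_message)
```

On a machine without a GPU, swapping the hard-coded `"cuda"` for `"cuda" if torch.cuda.is_available() else "cpu"` keeps the script portable; and since `do_sample=True`, the generated description will vary from run to run.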