jadechoghari committed on
Commit
3d9619d
1 Parent(s): 0fb0588

Update README.md

Files changed (1): README.md (+134, -75)
README.md CHANGED
@@ -8,97 +8,156 @@ library_name: transformers
 
 Please download and save `builder.py`, `conversation.py` locally.
 
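If it helps, one way to fetch both files is with `huggingface_hub` (a minimal sketch; it assumes the two files sit at the root of this repo):

```python
# Minimal sketch: download the helper files so the imports below resolve.
# Assumes builder.py and conversation.py sit at the root of this repo.
from huggingface_hub import hf_hub_download

for filename in ["builder.py", "conversation.py"]:
    hf_hub_download(
        repo_id="jadechoghari/ferret-gemma",
        filename=filename,
        local_dir=".",  # save into the working directory
    )
```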
- ### Basic Text Generation
- ```python
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- # load the model and tokenizer
- model_name = "jadechoghari/ferret-gemma"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
-
- # give input text
- input_text = "The United States of America is a country situated on earth"
-
- # tokenize the input text
- inputs = tokenizer(input_text, return_tensors="pt", padding=True).to("cuda" if torch.cuda.is_available() else "cpu")
-
- model = model.to("cuda" if torch.cuda.is_available() else "cpu")
-
- output = model.generate(inputs['input_ids'], max_length=50, num_return_sequences=1)
-
- # decode and print the output
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
- print(generated_text)
- ```
-
- ### Image and Text Generation
 ```python
 import torch
 from PIL import Image
 from conversation import conv_templates
- from builder import load_pretrained_model # custom model loader
-
- # load model and tokenizer, then preprocess an image
- def infer_single_prompt(image_path, prompt, model_path):
-     img = Image.open(image_path).convert('RGB')
-     tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, "ferret_gemma")
-     image_tensor = image_processor.preprocess(img, return_tensors='pt', size=(336, 336))['pixel_values'][0].unsqueeze(0).half()
-
-     # prepare prompt
-     conv = conv_templates["ferret_gemma_instruct"].copy()
-     conv.append_message(conv.roles[0], f"Image and prompt: {prompt}")
-     input_ids = tokenizer(conv.get_prompt(), return_tensors='pt')['input_ids'].cuda()
-
-     image_tensor = image_tensor.cuda()
-
-     # generate text output
-     with torch.inference_mode():
-         output_ids = model.generate(input_ids, images=image_tensor, max_new_tokens=1024)
-
-     # decode the output
-     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return output_text.strip()
-
- # Usage
- result = infer_single_prompt("image.jpg", "Describe the contents of the image.", "jadechoghari/ferret-gemma")
- print(result)
- ```
-
- ### Text, Image, and Bounding Box
- ```python
- import torch
- from PIL import Image
- from functools import partial
- from builder import load_pretrained_model
-
- # generates a bounding box mask
- def generate_mask_for_feature(coor, img_w, img_h):
-     coor_mask = torch.zeros((img_w, img_h))
-     coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
-     return coor_mask
-
- def infer_with_bounding_box(image_path, prompt, model_path, region):
-     img = Image.open(image_path).convert('RGB')
-     tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, "ferret_gemma")
-     image_tensor = image_processor.preprocess(img, return_tensors='pt', size=(336, 336))['pixel_values'][0].unsqueeze(0).half().cuda()
-
-     input_ids = tokenizer(f"Image and prompt: {prompt}", return_tensors='pt')['input_ids'].cuda()
-
-     # create region mask
-     mask = generate_mask_for_feature(region, *img.size).unsqueeze(0).half().cuda()
-
-     # generate output with region mask
-     with torch.inference_mode():
-         model.orig_forward = model.forward
-         model.forward = partial(model.orig_forward, region_masks=[[mask]])
-         output_ids = model.generate(input_ids, images=image_tensor, max_new_tokens=1024)
-
-     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return output_text.strip()
-
- # Usage
- result = infer_with_bounding_box("image.jpg", "Describe the contents of the box.", "jadechoghari/ferret-gemma", (50, 50, 200, 200))
- print(result)
- ```
 
 ```python
 import torch
 from PIL import Image
 from conversation import conv_templates
+ from builder import load_pretrained_model # Assuming this is your custom model loader
+ from functools import partial
+ import numpy as np
+
+ # define the task categories
+ box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
+ box_out_tasks = ['widget_listing', 'find_text', 'find_icons', 'find_widget', 'conversation_interaction']
+ no_box_tasks = ['screen2words', 'detailed_description', 'conversation_perception', 'gpt4']
+
+ # function to generate the region mask
+ def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
+     """
+     Generates a region mask based on provided coordinates.
+     Handles both point input (2 values) and box input (4 values).
+     """
+     if mask is not None:
+         assert mask.shape[0] == raw_w and mask.shape[1] == raw_h
+     coor_mask = np.zeros((raw_w, raw_h))
+
+     # if it's a point (2 coordinates)
+     if len(coor) == 2:
+         span = 5  # define the span around the point
+         x_min = max(0, coor[0] - span)
+         x_max = min(raw_w, coor[0] + span + 1)
+         y_min = max(0, coor[1] - span)
+         y_max = min(raw_h, coor[1] + span + 1)
+         coor_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
+         assert (coor_mask == 1).any(), f"coor: {coor}, raw_w: {raw_w}, raw_h: {raw_h}"
+
+     # if it's a box (4 coordinates)
+     elif len(coor) == 4:
+         coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
+         if mask is not None:
+             coor_mask = coor_mask * mask
+
+     # convert to a torch tensor and ensure it contains non-zero values
+     coor_mask = torch.from_numpy(coor_mask)
+     assert len(coor_mask.nonzero()) != 0, "Generated mask is empty :("
+
+     return coor_mask
+ ```
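As a quick sanity check of the mask helper above (a hypothetical example, not part of this repo's code):

```python
# Hypothetical sanity check for generate_mask_for_feature.
# A 4-tuple is treated as a box (x1, y1, x2, y2); a 2-tuple as a point padded by a 5-pixel span.
box_mask = generate_mask_for_feature((50, 50, 200, 200), 336, 336)
print(box_mask.shape, int(box_mask.sum()))  # torch.Size([336, 336]) 22801 (151 x 151 ones)

point_mask = generate_mask_for_feature((120, 80), 336, 336)
print(int(point_mask.sum()))                # 121 (an 11 x 11 patch around the point)
```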
+ ### Now, define the inference function
+ ```python
+ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name="ferret_gemma", conv_mode="ferret_gemma_instruct"):
+     img = Image.open(image_path).convert('RGB')
+
+     # load the model, image processor, and tokenizer
+     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
+
+     # define the image size (e.g., 224x224 or 336x336)
+     image_size = {"height": 336, "width": 336}
+
+     # process the image
+     image_tensor = image_processor.preprocess(
+         img,
+         return_tensors='pt',
+         do_resize=True,
+         do_center_crop=False,
+         size=(image_size['height'], image_size['width'])
+     )['pixel_values'][0].unsqueeze(0)
+
+     image_tensor = image_tensor.half().cuda()
+
+     # build the prompt per the template requirement
+     conv = conv_templates[conv_mode].copy()
+     conv.append_message(conv.roles[0], prompt)
+     conv.append_message(conv.roles[1], None)
+     prompt_input = conv.get_prompt()
+
+     # tokenize the prompt
+     input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
+
+     # region mask logic (if a region is provided)
+     region_masks = None
+     if region is not None:
+         raw_w, raw_h = img.size
+         region_masks = generate_mask_for_feature(region, raw_w, raw_h).unsqueeze(0).cuda().half()
+         region_masks = [[region_masks]]  # wrap the mask in lists as expected by the model
+
+     # generate the model output
+     with torch.inference_mode():
+         # patch forward so region_masks reaches the model's forward call
+         model.orig_forward = model.forward
+         model.forward = partial(
+             model.orig_forward,
+             region_masks=region_masks
+         )
+         output_ids = model.generate(
+             input_ids,
+             images=image_tensor,
+             max_new_tokens=1024,
+             num_beams=1,
+             region_masks=region_masks,  # pass the region mask to the model
+             image_sizes=[img.size]
+         )
+         model.forward = model.orig_forward
+
+     # decode the output
+     output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+     return output_text.strip()
+ ```
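`infer_single_prompt` can also be called directly, without the task router defined next (a hypothetical call; `image.jpg` is a placeholder path and `region` is optional):

```python
# Hypothetical direct call, bypassing the task router below.
text = infer_single_prompt(
    "image.jpg",
    "Describe the contents of the box.",
    "jadechoghari/ferret-gemma",
    region=(50, 50, 200, 200),  # optional (x1, y1, x2, y2) box in pixel coordinates
)
print(text)
```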

+ ### We also define a task-specific inference function
+ ```python
+ def infer_ui_task(image_path, prompt, model_path, task, region=None):
+     """
+     Handles the three task types: box_in_tasks, box_out_tasks, no_box_tasks.
+     """
+     if task in box_in_tasks and region is None:
+         raise ValueError(f"Task {task} requires a bounding box region.")
+
+     if task in box_in_tasks:
+         print(f"Processing {task} with bounding box region.")
+         return infer_single_prompt(image_path, prompt, model_path, region)
+
+     elif task in box_out_tasks:
+         print(f"Processing {task} without bounding box region.")
+         return infer_single_prompt(image_path, prompt, model_path)
+
+     elif task in no_box_tasks:
+         print(f"Processing {task} with no region input.")
+         return infer_single_prompt(image_path, prompt, model_path)
+
+     else:
+         raise ValueError(f"Unknown task type: {task}")
+ ```
+
+ ### Usage
+ ```python
+ # Example image and model paths
+ image_path = 'image.jpg'
+ model_path = 'jadechoghari/ferret-gemma'
+
+ # Task requiring a bounding box
+ task = 'widgetcaptions'
+ region = (50, 50, 200, 200)
+ result = infer_ui_task(image_path, "Describe the contents of the box.", model_path, task, region=region)
+ print("Result:", result)
+
+ # Task not requiring a bounding box
+ task = 'conversation_interaction'
+ result = infer_ui_task(image_path, "How do I navigate to the Games tab?", model_path, task)
+ print("Result:", result)
+
+ # Task with no region input
+ task = 'detailed_description'
+ result = infer_ui_task(image_path, "Please describe the screen in detail.", model_path, task)
+ print("Result:", result)
 ```
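Note that `region` is given in pixel coordinates: a 4-tuple `(x1, y1, x2, y2)` selects a box, while a 2-tuple `(x, y)` selects a point (see `generate_mask_for_feature` above). A hypothetical point-based call, with an illustrative prompt:

```python
# Hypothetical example: a 2-tuple region is treated as a point rather than a box.
result = infer_ui_task(image_path, "What is at this location?", model_path, 'example_0', region=(120, 80))
print("Result:", result)
```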