WUBIAO committed on
Commit 4b515a5 · verified · 1 Parent(s): 7b50e56

Upload cogagent_infer_batch.py with huggingface_hub

Files changed (1)
  1. cogagent_infer_batch.py +243 -0
cogagent_infer_batch.py ADDED
@@ -0,0 +1,243 @@
+ import argparse
+ import json
+ import os
+ import re
+ from typing import List
+
+ import torch
+ from PIL import Image, ImageDraw
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+
+ class AITM_Dataset(Dataset):
+     """Dataset over a JSON list of test items, yielding (image_path, task) pairs."""
+
+     def __init__(self, json_path):
+         with open(json_path, "r") as f:
+             self.data = json.load(f)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         x = self.data[idx]
+         img_path = x["image"]
+         task = x["conversations"][0]["value"]
+         return img_path, task
+
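+ # Note: each dataset entry is assumed to be shaped like the (hypothetical)
+ # example below; only the "image" path and the first conversation turn are
+ # read by __getitem__:
+ # {"image": "/path/to/screenshot.png",
+ #  "conversations": [{"from": "human", "value": "Open the settings menu"}]}
+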
+ def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
+     """
+     Draws red bounding boxes on the given image and saves it.
+
+     Parameters:
+     - image (PIL.Image.Image): The image on which to draw the bounding boxes.
+     - boxes (List[List[float]]): A list of bounding boxes, each defined as
+       [x_min, y_min, x_max, y_max]. Coordinates are expected to be normalized (0 to 1).
+     - save_path (str): The path to save the annotated image.
+
+     Description:
+     Each box coordinate is a fraction of the image dimension. This function converts
+     them to actual pixel coordinates, draws a red rectangle to mark each area, and
+     saves the annotated image to the specified path.
+     """
+     draw = ImageDraw.Draw(image)
+     for box in boxes:
+         x_min = int(box[0] * image.width)
+         y_min = int(box[1] * image.height)
+         x_max = int(box[2] * image.width)
+         y_max = int(box[3] * image.height)
+         draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
+     image.save(save_path)
+
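+ # Example (illustrative only): draw one box over the top-left quadrant.
+ # img = Image.open("screenshot.png").convert("RGB")  # hypothetical path
+ # draw_boxes_on_image(img, [[0.0, 0.0, 0.5, 0.5]], "annotated.png")
+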
+ def main():
+     """
+     Batch inference with the CogAgent1.5 model using selectable format prompts.
+     The output_image_path is interpreted as a directory. For each processed item,
+     the annotated image is saved in that directory with the filename:
+     {original_image_name_without_extension}_{round_number}_{batch_index}.png
+
+     Example:
+     python cogagent_infer_batch.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
+         --output_image_path ./results --format_key status_action_op_sensitive
+     """
+
+     parser = argparse.ArgumentParser(
+         description="Batch inference with the CogAgent model and selectable format."
+     )
+     parser.add_argument(
+         "--model_dir", required=True, help="Path or identifier of the model."
+     )
+     parser.add_argument(
+         "--platform",
+         default="Mac",
+         help="Platform information string (e.g., 'Mac', 'WIN').",
+     )
+     parser.add_argument(
+         "--max_length", type=int, default=4096, help="Maximum generation length."
+     )
+     parser.add_argument(
+         "--top_k", type=int, default=1, help="Top-k sampling parameter."
+     )
+     parser.add_argument(
+         "--output_image_path",
+         default="results",
+         help="Directory to save the annotated images.",
+     )
+     parser.add_argument(
+         "--input_json",
+         default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
+         help="Path to the input JSON file with test items.",
+     )
+     parser.add_argument(
+         "--output_json",
+         # Note: the default points at the input file; pass a different path to
+         # avoid overwriting the test data.
+         default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
+         help="Path to write the JSON inference results.",
+     )
+     parser.add_argument(
+         "--format_key",
+         default="action_op_sensitive",
+         help="Key to select the prompt format.",
+     )
+     args = parser.parse_args()
+
+     # Dictionary mapping format keys to format strings
+     format_dict = {
+         "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
+         "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
+         "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
+         "status_action_op": "(Answer in Status-Action-Operation format.)",
+         "action_op": "(Answer in Action-Operation format.)",
+     }
+
+     # Ensure the provided format_key is valid
+     if args.format_key not in format_dict:
+         raise ValueError(
+             f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
+         )
+
+     # Ensure the output directory exists
+     os.makedirs(args.output_image_path, exist_ok=True)
+
+     # Load the tokenizer and model
+     tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_dir,
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True,
+         device_map="auto",
+         # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # For INT8 quantization
+         # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # For INT4 quantization
+     ).eval()
+
+     # Initialize platform and selected format strings
+     platform_str = f"(Platform: {args.platform})\n"
+     format_str = format_dict[args.format_key]
+
+     # Initialize history lists (shared across items; steps and actions parsed
+     # from earlier responses are appended as inference proceeds)
+     history_step = []
+     history_action = []
+
+     round_num = 1
+     dataset = AITM_Dataset(args.input_json)
+     data_loader = DataLoader(dataset, batch_size=16, shuffle=False)
+
+     res = []
+     for img_path, task in tqdm(data_loader, desc="Processing items"):
+         # Each batch is (list of image paths, list of tasks); load images up front
+         images = [Image.open(path).convert("RGB") for path in img_path]
+
+         # Verify history lengths match
+         if len(history_step) != len(history_action):
+             raise ValueError("Mismatch in lengths of history_step and history_action.")
+
+         # Format history steps for the prompt
+         history_str = "\nHistory steps: "
+         for index, (step, action) in enumerate(zip(history_step, history_action)):
+             history_str += f"\n{index}. {step}\t{action}"
+
+         # Compose one query per sample with task, platform, and format instructions
+         queries = [f"Task: {t}{history_str}\n{platform_str}{format_str}" for t in task]
+
+         # Generation parameters
+         gen_kwargs = {
+             "max_length": args.max_length,
+             "do_sample": True,
+             "top_k": args.top_k,
+         }
+
+         # The chat template expects a single conversation with a single image,
+         # so run generation sample by sample within the batch
+         for i, (path, image, query) in enumerate(zip(img_path, images, queries)):
+             inputs = tokenizer.apply_chat_template(
+                 [{"role": "user", "image": image, "content": query}],
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_tensors="pt",
+                 return_dict=True,
+             ).to(model.device)
+
+             # Generate a response and strip the prompt tokens from the output
+             with torch.no_grad():
+                 outputs = model.generate(**inputs, **gen_kwargs)
+                 outputs = outputs[:, inputs["input_ids"].shape[1]:]
+                 response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Extract the grounded operation and action, i.e. lines of the form
+             # "Grounded Operation: ..." and "Action: ..." in the response
+             grounded_pattern = r"Grounded Operation:\s*(.*)"
+             action_pattern = r"Action:\s*(.*)"
+             matches_history = re.search(grounded_pattern, response)
+             matches_actions = re.search(action_pattern, response)
+
+             if matches_history:
+                 history_step.append(matches_history.group(1))
+             if matches_actions:
+                 history_action.append(matches_actions.group(1))
+
+             # Extract bounding boxes; coordinates are emitted on a 0-1000 scale
+             box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
+             matches = re.findall(box_pattern, response)
+
+             output_path = None
+             if matches:
+                 boxes = [[int(v) / 1000 for v in match] for match in matches]
+
+                 # Name the annotated image after the input image, round number,
+                 # and position in the batch
+                 base_name = os.path.splitext(os.path.basename(path))[0]
+                 output_file_name = f"{base_name}_{round_num}_{i}.png"
+                 output_path = os.path.join(args.output_image_path, output_file_name)
+                 draw_boxes_on_image(image, boxes, output_path)
+
+             ans = {
+                 "query": f"Round {round_num} query:\n{query}",
+                 "response": response,
+                 "output_path": output_path,
+             }
+             res.append(ans)
+         round_num += 1
+
+     print("Writing to json file")
+     with open(args.output_json, "w") as file:
+         json.dump(res, file, indent=4)
+     print("Done")
+
+
+ if __name__ == "__main__":
+     main()
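+
+ # A minimal invocation sketch (all paths below are placeholders, assuming the
+ # model weights and the AITM test JSON are available locally):
+ # python cogagent_infer_batch.py --model_dir THUDM/cogagent-9b-20241220 \
+ #     --input_json ./AITM_Test_General_BBox_v0.json \
+ #     --output_json ./predictions.json \
+ #     --output_image_path ./results --format_key action_op_sensitive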