WUBIAO committed on
Commit
5388ba3
·
verified ·
1 Parent(s): 1be9398

Upload test_max_token.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_max_token.py +164 -0
test_max_token.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ import torch
5
+ from PIL import Image, ImageDraw
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
7
+ from typing import List
8
+ import json
9
+ from tqdm import tqdm
10
+
11
+ #class
12
def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """
    Outline each bounding box on ``image`` in red and write the result to disk.

    Parameters:
    - image (PIL.Image.Image): Image to annotate. It is drawn on in place.
    - boxes (List[List[float]]): Boxes given as [x_min, y_min, x_max, y_max], where each
      value is a fraction of the corresponding image dimension (expected range 0 to 1).
    - save_path (str): Destination path for the annotated image.

    Description:
    Every fractional coordinate is multiplied by the image width (x values) or height
    (y values) and truncated to an integer pixel position. A red rectangle, 3 pixels
    wide, is drawn for each box, and the image is saved once all boxes are drawn.
    """
    canvas = ImageDraw.Draw(image)
    for box in boxes:
        # Scale factor for each of the four coordinates: x, y, x, y.
        scales = (image.width, image.height, image.width, image.height)
        pixel_box = [int(box[idx] * scale) for idx, scale in enumerate(scales)]
        canvas.rectangle(pixel_box, outline="red", width=3)
    image.save(save_path)
34
+
35
+
36
def main():
    """
    Measure the prompt token count of every sample in a JSON dataset.

    The file given by --input_json is loaded as a list of samples. For each sample the
    image at sample["image"] is opened and converted to RGB, and
    sample["conversations"][0]["value"] is used as the task text. A query of the form
    "Task: ... History steps: ... (Platform: ...) (Answer in ... format.)" is built and
    passed, together with the image, through the tokenizer's chat template. The length
    of the resulting ``input_ids`` (dimension 1) is recorded.

    The collected lengths are written as a JSON list to ``max_token_nums.json`` in the
    current working directory.

    No text is generated. The model is loaded only so that the tokenized inputs can be
    moved to ``model.device``. The options --max_length, --top_k and --output_json are
    accepted for command-line compatibility but are not used by this script.

    Example:
        python test_max_token.py --model_dir THUDM/cogagent-9b-20241220 --platform "Mac" --input_json ./samples.json --format_key status_action_op_sensitive

    Raises:
        ValueError: If --format_key is not one of the known prompt-format keys.
    """

    parser = argparse.ArgumentParser(
        description="Measure prompt token counts for a dataset with a CogAgent model and selectable format."
    )
    parser.add_argument(
        "--model_dir", required=True, help="Path or identifier of the model."
    )
    parser.add_argument(
        "--platform",
        default="Mac",
        help="Platform information string (e.g., 'Mac', 'WIN').",
    )
    parser.add_argument(
        "--max_length", type=int, default=4096, help="Maximum generation length."
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter."
    )
    parser.add_argument(
        "--output_image_path",
        default="results",
        help="Directory to save the annotated images.",
    )
    parser.add_argument(
        "--input_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to the input JSON file: a list of samples with 'image' and 'conversations' entries.",
    )
    parser.add_argument(
        "--output_json",
        default="/Users/baixuehai/Downloads/2025_2/AITM_Test_General_BBox_v0.json",
        help="Path to an output JSON file (currently unused; results go to max_token_nums.json).",
    )
    parser.add_argument(
        "--format_key",
        default="action_op_sensitive",
        help="Key to select the prompt format.",
    )
    args = parser.parse_args()

    # Dictionary mapping format keys to format strings
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }

    # Ensure the provided format_key is valid
    if args.format_key not in format_dict:
        raise ValueError(
            f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
        )

    # Ensure the output directory exists (kept for compatibility; no images are written here)
    os.makedirs(args.output_image_path, exist_ok=True)

    # Load the tokenizer and model.
    # For quantized loading, pass quantization_config=BitsAndBytesConfig(load_in_8bit=True)
    # (INT8) or BitsAndBytesConfig(load_in_4bit=True) (INT4) to from_pretrained.
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    ).eval()

    # Initialize platform and selected format strings
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]

    # History lists; they stay empty in this script, so every query carries an
    # empty "History steps: " section.
    history_step = []
    history_action = []

    with open(args.input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    max_len_val = []
    for sample in tqdm(data):
        # Verify history lengths match; actually skip the sample, as the warning says.
        if len(history_step) != len(history_action):
            print("警告: Mismatch in lengths of history_step and history_action. - 跳过当前案例")
            continue

        # convert() returns a new image, so the file handle can be released right away.
        with Image.open(sample['image']) as raw_image:
            image = raw_image.convert("RGB")
        task = sample['conversations'][0]['value']

        # Format history steps for output
        history_str = "\nHistory steps: " + "".join(
            f"\n{index}. {step}\t{action}"
            for index, (step, action) in enumerate(zip(history_step, history_action))
        )

        # Compose the query with task, platform, and selected format instructions
        query = f"Task: {task}{history_str}\n{platform_str}{format_str}"

        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "image": image, "content": query}],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)

        # Dimension 1 of input_ids is the prompt length in tokens.
        max_len_val.append(inputs['input_ids'].shape[1])

    # Persist the per-sample token counts once all samples are processed.
    with open('max_token_nums.json', 'w', encoding="utf-8") as f:
        json.dump(max_len_val, f, ensure_ascii=False, indent=4)
162
+
163
# Script entry point: run the measurement only when executed directly, not on import.
if __name__ == "__main__":
    main()