WUBIAO committed
Commit edd4e08 · verified · 1 Parent(s): 5cddfe8

Upload utils_data_.py with huggingface_hub

Files changed (1):
  1. utils_data_.py +278 -0
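For reference, an upload like this can be reproduced with the huggingface_hub client. The sketch below is illustrative only; the repo_id is a placeholder, since the target repository name is not shown on this page.

# Illustrative sketch of pushing utils_data_.py with huggingface_hub.
# "WUBIAO/your-repo-name" is a placeholder; substitute the actual repo id.
from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="utils_data_.py",  # local file to upload
    path_in_repo="utils_data_.py",     # destination path inside the repo
    repo_id="WUBIAO/your-repo-name",
    commit_message="Upload utils_data_.py with huggingface_hub",
)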
utils_data_.py CHANGED
@@ -0,0 +1,278 @@
from torch.utils.data import Dataset
import torch
import pickle
from tqdm import tqdm
import action_matching, action_type
import numpy as np
import jax.numpy as jnp
import random
import re

# Shapes of the precomputed visual features, keyed by vision backbone.
img_shape = {
    "resnet": (512, 2048),
    "clip": (49, 2048),
    "detr": (100, 256),
    "vit": (577, 768),
    "vit-large": (145, 1024),
    "vit-global": (1, 768),
    "vit-merge": (578, 768),
}


def load_data(args, split):
    target_text = []
    source_text = []
    source_image = []
    anno_positions = []

    if args.all_data:
        if split == "train":
            data = []
            for subdir in ["general", "google_apps", "install", "single", "web_shopping"]:
                print(f"loading {subdir}", len(data))
                with open(f"dataset/blip/{subdir}_{args.data_root}_{split}.obj", "rb") as rp:
                    sub_data = pickle.load(rp)
                if subdir == "google_apps":
                    # subsample google_apps (args.all_data is used as the keep ratio)
                    sub_data = random.sample(sub_data, int(len(sub_data) * args.all_data))
                data.extend(sub_data)
        else:
            # we use the general subset for dev/test
            with open(f"{args.eval_subset}_{split}.obj", "rb") as rp:
                data = pickle.load(rp)
    else:
        with open(f"{args.data_root}_{split}.obj", "rb") as rp:
            data = pickle.load(rp)
        if args.data_ratio:
            data = random.sample(data, int(len(data) * args.data_ratio))

    for qid, episode in enumerate(tqdm(data)):
        episode_id = episode["episode_id"]
        episode_data = episode["data"]
        if args.use_history:
            history_action = []
            if args.use_img_history:
                history_image = [torch.zeros(args.img_dim)] * args.use_history

        for step_idx, step_data in enumerate(episode_data):
            question = step_data["goal"]
            question = f"Goal: {question}"

            image = step_data["image"]

            ui_positions = step_data["ui_positions"]
            ui_text = step_data["ui_text"]
            ui_type = step_data["ui_type"]

            if args.use_layout:
                icon_string = ""
                for ui_idx, ui_type_i in enumerate(ui_type):
                    ui_axis = ui_positions[ui_idx]
                    top, left, height, width = ui_axis
                    # The y-axis is inverted for AndroidEnv, so bottom = top + height.
                    bottom, right = top + height, left + width
                    ui_axis = [top, left, bottom, right]
                    ui_axis = ["{:.4f}".format(axis) for axis in ui_axis]
                    ui_axis = f"({ui_axis[0]}, {ui_axis[1]}, {ui_axis[2]}, {ui_axis[3]})"
                    if ui_type_i == "TEXT":
                        icon_string += f'<p id={ui_idx} class="text" alt="{ui_axis}">{ui_text[ui_idx]}</p>\n'
                    elif "ICON" in ui_type_i:
                        icon_string += f'<img id={ui_idx} class={ui_type_i} alt="{ui_axis}">{ui_text[ui_idx]}</p>\n'
                    else:
                        print(icon_string)
                        # a bare `assert "<string>"` can never fail; raise instead
                        raise ValueError("parsing ui failed!!!")

                question = f"{question}\nScreen: {icon_string}"

            result_touch_yx = step_data["result_touch_yx"]
            result_lift_yx = step_data["result_lift_yx"]
            result_action = step_data["result_action"][0]
            result_text = step_data["result_action"][1]

            result_text = result_text.replace("\\", "").replace('"', '').replace("'", "")

            if args.transform_axis:
                # canonical touch/lift points for the four scroll directions
                scroll_map = {
                    "up": [[0.8000, 0.5000], [0.2000, 0.5000]],
                    "down": [[0.2000, 0.5000], [0.8000, 0.5000]],
                    "left": [[0.5000, 0.8000], [0.5000, 0.2000]],
                    "right": [[0.5000, 0.2000], [0.5000, 0.8000]],
                }
                action_touch_yx = jnp.asarray(result_touch_yx)
                action_lift_yx = jnp.asarray(result_lift_yx)
                if result_action == "DUAL_POINT":
                    if is_tap_action(action_touch_yx, action_lift_yx):
                        result_touch_yx = [round(axis, 4) for axis in result_touch_yx]
                        # for a tap, the lift point can be the same as the touch point
                        result_lift_yx = result_touch_yx
                    else:
                        drags_match = _check_drag_actions_match(
                            action_touch_yx, action_lift_yx
                        )
                        result_touch_yx, result_lift_yx = scroll_map[drags_match]

            target_action = f'"action_type": "{result_action}", "touch_point": "{result_touch_yx}", "lift_point": "{result_lift_yx}", "typed_text": "{result_text}"'

            if args.use_history:
                prev_actions = "\n".join(history_action)
                question = f"Previous Actions: {prev_actions}\n{question}"
                if args.use_img_history:
                    image = history_image + [image]
                    image = torch.stack(image)

            if args.use_future:
                future_actions = episode_data[step_idx:]
                if len(future_actions) > args.use_future:
                    future_actions = future_actions[:args.use_future]
                future_actions = "[" + ",".join([action_t["result_action"][0] for action_t in future_actions]) + "]\n"
                target_action_label = "Action Plan: " + future_actions + "; Action Decision: " + target_action
            else:
                # without future-action planning, the label carries only the action decision
                target_action_label = "Action Decision: " + target_action

            source_text.append(question)
            source_image.append(image)
            target_text.append(target_action_label)
            anno_positions.append(ui_positions)

            if args.use_history:
                history_action.append(target_action)
                if args.use_img_history:
                    history_image.append(image[-1])
                    history_image.pop(0)
                if len(history_action) > args.use_history:
                    history_action.pop(0)

        if args.debug_num:
            if int(qid) > args.debug_num:
                break

    # cap every split at the first 2000 examples
    block = 2000
    return source_text[:block], source_image[:block], target_text[:block], anno_positions[:block]


_SWIPE_DISTANCE_THRESHOLD = 0.04


def is_tap_action(normalized_start_yx, normalized_end_yx):
    distance = jnp.linalg.norm(
        jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
    return distance <= _SWIPE_DISTANCE_THRESHOLD


def _check_drag_actions_match(
    drag_touch_yx,
    drag_lift_yx,
):
    """Returns the dominant scroll direction ("up"/"down"/"left"/"right") of a drag."""
    # Store drag deltas (the change in the y and x coordinates from touch to
    # lift), magnitudes, and the index of the main axis, which is the axis with
    # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
    # ending at (0.3, 0.5) has a main axis index of 1).
    drag_1_deltas = drag_lift_yx - drag_touch_yx
    drag_1_magnitudes = jnp.abs(drag_1_deltas)
    drag_1_main_axis = np.argmax(drag_1_magnitudes)

    # y axis
    if drag_1_main_axis == 0:
        if drag_1_deltas[0] < 0:
            scroll = "up"
        else:
            scroll = "down"
    # x axis
    else:
        if drag_1_deltas[1] < 0:
            scroll = "left"
        else:
            scroll = "right"

    return scroll

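A quick illustration of the two helpers above (not part of the uploaded file, and assuming the definitions above are in scope): a short touch-to-lift distance is treated as a tap, while a longer drag is bucketed into one of four scroll directions by its dominant axis.

# Illustration only, not part of utils_data_.py.
touch_yx = jnp.array([0.8, 0.5])
lift_yx = jnp.array([0.2, 0.5])

print(is_tap_action(touch_yx, lift_yx))              # False: distance 0.6 > 0.04
print(_check_drag_actions_match(touch_yx, lift_yx))  # "up": y shrinks and y is the main axis
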
class AITWDatasetImg(Dataset):
    """
    Custom dataset that wraps the preprocessed AITW examples so they can be
    fed through a DataLoader to the model for fine-tuning.
    """

    def __init__(
        self, data, tokenizer, source_len, target_len
    ):
        """
        Initializes the Dataset.

        Args:
            data (tuple): (source_text, source_image, target_text, anno_positions)
                as returned by load_data
            tokenizer (transformers.PreTrainedTokenizer): Transformers tokenizer
            source_len (int): Max length of source text
            target_len (int): Max length of target text
        """
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.summ_len = target_len
        self.source_text = data[0]
        self.source_image = data[1]
        self.target_text = data[2]
        self.anno_positions = data[3]

    def __len__(self):
        """Returns the number of examples."""
        return len(self.target_text)

    def __getitem__(self, index):
        """Returns the input ids, attention masks and target ids for one example."""

        source_text = str(self.source_text[index])
        source_image = self.source_image[index]
        target_text_org = str(self.target_text[index])

        # pull the action decision out of the full label
        # ("Action Plan: ...; Action Decision: ...")
        pattern = r'(?<=Action Decision:\s).*'
        result = re.search(pattern, target_text_org)
        target_text = result.group(0).strip()

        # parse the action string into structured supervision targets
        target_dict = eval("{" + target_text + "}")
        action = action_type.ActionType[target_dict["action_type"]].value

        touch_point = eval(target_dict["touch_point"])
        lift_point = eval(target_dict["lift_point"])

        # normalize whitespace so the text is a clean single-line string
        source_text = " ".join(source_text.split())
        target_text_org = " ".join(target_text_org.split())

        source = self.tokenizer.batch_encode_plus(
            [source_text],
            max_length=self.source_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        target = self.tokenizer.batch_encode_plus(
            [target_text_org],
            max_length=self.summ_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        target_ids = target["input_ids"].squeeze()

        image_ids = torch.tensor(source_image).squeeze()
        vis_attention_mask = torch.tensor([1]).squeeze()

        act_ids = torch.tensor(action).squeeze()
        touch_point = torch.tensor(touch_point).squeeze()
        lift_point = torch.tensor(lift_point).squeeze()

        return {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "image_ids": image_ids,
            "labels": target_ids,
            "target_act": act_ids,
            "target_touch": touch_point,
            "target_lift": lift_point,
        }
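Finally, a minimal usage sketch, assuming a T5-style tokenizer and an args namespace carrying the flags that load_data reads; the concrete values, file paths, and tokenizer choice below are placeholders, not anything prescribed by the file.

# Hypothetical usage; every value below is an assumption for illustration.
from argparse import Namespace
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

args = Namespace(
    all_data=None, data_root="general_blip", data_ratio=None,
    eval_subset="dataset/blip/general_blip", debug_num=None,
    use_history=8, use_img_history=False, img_dim=1408,
    use_layout=True, use_future=4, transform_axis=True,
)

tokenizer = AutoTokenizer.from_pretrained("t5-base")
train_data = load_data(args, "train")  # expects general_blip_train.obj to exist
train_set = AITWDatasetImg(train_data, tokenizer, source_len=512, target_len=128)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

batch = next(iter(train_loader))
print(batch["input_ids"].shape, batch["labels"].shape)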