Delete data/vqa_dataset.py
data/vqa_dataset.py    DELETED    +0 -115
@@ -1,115 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import Callable, List, Tuple, Union

import torch

from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset


class VQADataset(Dataset):
    """
    Create the dataset for the VQA task.

    Args:
        ann_file (List[str]): The paths to annotation json files.
        vqa_root (str): The path to the vqa data directory.
        vg_root (str): The path to the vg data directory.
        image_transform (Callable[[Image.Image], Tensor]): image data transform.
        question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
        answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
        split (str): Indicates train or test. Default is train.
        answer_list (str): The path to the answers list. Required for test split.

    Dataset Outputs:
        if split is train:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed question token input ids.
            answers (List[Tensor]): List of transformed answers token input ids.
            answer_weights (List[float]): List of answer weights.
                answer_weights[i] is proportional to the number of occurrences of answers[i]
        if split is test:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed text token input ids.
            question_id (int): The question sample id.
    """

    def __init__(
        self,
        ann_file: List[str],
        vqa_root: str,
        vg_root: str,
        image_transform: Callable[[Image.Image], Tensor],
        question_transform: Callable[[Union[List[str], str]], Tensor],
        answer_transform: Callable[[Union[List[str], str]], Tensor],
        split: str = "train",
        answer_list: str = None,
    ) -> None:
        self.ann = []
        for f in ann_file:
            self.ann += json.load(open(f, "r"))

        self.vqa_root = vqa_root
        self.vg_root = vg_root
        self.image_transform = image_transform
        self.question_transform = question_transform
        self.answer_transform = answer_transform
        self.split = split

        if split == "test":
            self.answer_list = json.load(open(answer_list, "r"))
            self.answer_input_ids = self.answer_transform(self.answer_list)
            self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)

    def __len__(self) -> int:
        return len(self.ann)

    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
    ]:
        ann = self.ann[index]

        image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
        image_path = os.path.join(image_root, ann["image"])
        image = Image.open(image_path).convert("RGB")
        image = self.image_transform(image)
        question = self.question_transform(ann["question"])

        if self.split == "test":
            return image, question, ann["question_id"]

        elif self.split == "train":
            if ann["dataset"] == "vqa":
                # Each VQA sample question has a list of answers (with potential repeats)
                # answer_weight[answer] = count(answer) / len(answers for the question)
                answer_weights = {}
                for answer in ann["answer"]:
                    if answer in answer_weights.keys():
                        answer_weights[answer] += 1 / len(ann["answer"])
                    else:
                        answer_weights[answer] = 1 / len(ann["answer"])

                answers = list(answer_weights.keys())
                answer_weights = list(answer_weights.values())

            elif ann["dataset"] == "vg":
                # A VG sample question has one answer so assign it a constant weight (0.5)
                answers = [ann["answer"]]
                answer_weights = [0.5]

            answers = list(self.answer_transform(answers))

            return image, question, answers, answer_weights

        else:
            raise ValueError("dataset split should be train or test")
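To show how the pieces fit together, here is a minimal usage sketch. The annotation paths, image roots, and text transform are placeholders, not part of the deleted file; any callables matching the documented signatures would do, and a real setup would use a proper tokenizer rather than the character-id stand-in below.

import torch
from torchvision import transforms

# Placeholder paths -- assumed locations, replace with real annotation files and image roots.
ANN_FILES = ["annotations/vqa_train.json"]
VQA_ROOT = "images/vqa"
VG_ROOT = "images/vg"

image_transform = transforms.Compose(
    [transforms.Resize((384, 384)), transforms.ToTensor()]
)

def dummy_text_transform(text):
    # Stand-in for a real tokenizer: maps each character to an integer id.
    if isinstance(text, str):
        return torch.tensor([ord(c) for c in text])
    ids = [torch.tensor([ord(c) for c in t]) for t in text]
    return torch.nn.utils.rnn.pad_sequence(ids, batch_first=True)

dataset = VQADataset(
    ann_file=ANN_FILES,
    vqa_root=VQA_ROOT,
    vg_root=VG_ROOT,
    image_transform=image_transform,
    question_transform=dummy_text_transform,
    answer_transform=dummy_text_transform,
    split="train",
)
image, question, answers, answer_weights = dataset[0]

Indexing the dataset only works once the placeholder annotation files and image directories exist; the sketch is meant to show the expected call shape, not a runnable end-to-end pipeline.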