ryanramos committed on
Commit
991b6ea
·
1 Parent(s): 3f22931

Delete data/vqa_dataset.py

Browse files
Files changed (1) hide show
  1. data/vqa_dataset.py +0 -115
data/vqa_dataset.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import json
8
- import os
9
- from typing import Callable, List, Tuple, Union
10
-
11
- import torch
12
-
13
- from PIL import Image
14
- from torch import Tensor
15
- from torch.utils.data import Dataset
16
-
17
-
18
class VQADataset(Dataset):
    """
    Dataset for the VQA (visual question answering) task, combining VQA and
    Visual Genome (VG) annotation files.

    Args:
        ann_file (List[str]): The paths to annotation json files.
        vqa_root (str): The path to vqa data directory.
        vg_root (str): The path to vg data directory.
        image_transform (Callable[[Image.Image], Tensor]): image data transform.
        question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
        answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
        split (str): Indicates train or test. Default is train.
        answer_list (Union[str, None]): The path to the answers list.
            Required for test split.

    Dataset Outputs:
        if split is train:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed question token input ids.
            answers (List[Tensor]): List of transformed answers token input ids.
            answer_weights (List[float]): List of answer weights.
                answer_weights[i] is proportional to the number of occurrences
                of answers[i]
        if split is test:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed text token input ids.
            question_id (int): The question sample id.
    """

    def __init__(
        self,
        ann_file: List[str],
        vqa_root: str,
        vg_root: str,
        image_transform: Callable[[Image.Image], Tensor],
        question_transform: Callable[[Union[List[str], str]], Tensor],
        answer_transform: Callable[[Union[List[str], str]], Tensor],
        split: str = "train",
        answer_list: Union[str, None] = None,
    ) -> None:
        # Concatenate the annotation entries from every provided file.
        self.ann: List[dict] = []
        for path in ann_file:
            # Context manager closes the file promptly (the previous version
            # left the handle open until garbage collection).
            with open(path, "r") as f:
                self.ann += json.load(f)

        self.vqa_root = vqa_root
        self.vg_root = vg_root
        self.image_transform = image_transform
        self.question_transform = question_transform
        self.answer_transform = answer_transform
        self.split = split

        if split == "test":
            # Fail fast with a clear message instead of a TypeError from open(None).
            if answer_list is None:
                raise ValueError("answer_list is required for the test split")
            with open(answer_list, "r") as f:
                self.answer_list = json.load(f)
            self.answer_input_ids = self.answer_transform(self.answer_list)
            # Token id 0 is treated as padding; mask out pad positions.
            # NOTE(review): assumes the tokenizer's pad id is 0 — confirm.
            self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)

    def __len__(self) -> int:
        """Return the number of annotation samples."""
        return len(self.ann)

    @staticmethod
    def _compute_answer_weights(ann: dict) -> Tuple[List[str], List[float]]:
        """Return (unique answers, weights) for a train-split annotation.

        For a VQA sample, each question has a list of answers (with potential
        repeats) and weight[i] = count(answers[i]) / len(all answers).
        A VG sample has a single answer, assigned a constant weight of 0.5.

        Raises:
            ValueError: if ann["dataset"] is neither "vqa" nor "vg".
                (The previous version fell through to a NameError on an
                undefined local in that case.)
        """
        if ann["dataset"] == "vqa":
            weights: dict = {}
            per_answer = 1 / len(ann["answer"])
            for answer in ann["answer"]:
                weights[answer] = weights.get(answer, 0.0) + per_answer
            return list(weights.keys()), list(weights.values())
        if ann["dataset"] == "vg":
            return [ann["answer"]], [0.5]
        raise ValueError(
            f"unknown dataset {ann['dataset']!r}; expected 'vqa' or 'vg'"
        )

    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
    ]:
        """Return one transformed sample; output shape depends on self.split."""
        ann = self.ann[index]

        # VQA and VG images live under different root directories.
        image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
        image_path = os.path.join(image_root, ann["image"])
        image = Image.open(image_path).convert("RGB")
        image = self.image_transform(image)
        question = self.question_transform(ann["question"])

        if self.split == "test":
            return image, question, ann["question_id"]

        if self.split == "train":
            answers, answer_weights = self._compute_answer_weights(ann)
            answers = list(self.answer_transform(answers))
            return image, question, answers, answer_weights

        raise ValueError("dataset split should be train or test")