Tao11 committed on
Commit 9c73e47 · 1 Parent(s): a8a06b7

Add baize dataset (#25)


* add baize dataset


.gitignore CHANGED
@@ -3,6 +3,7 @@
 wandb/
 
 checkpoints/
+tests/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
README.md CHANGED
@@ -145,6 +145,10 @@ conda env create -f environment.yml
 
 You can also customize the data path in the [configs/dataset_config.py](configs/dataset_config.py).
 
+8. [Baize](https://github.com/project-baize/baize-chatbot)
+
+    Download it from [this link](https://github.com/project-baize/baize-chatbot/blob/main/data/quora_chat_data.json) and place it in `data/baize/quora_chat_data.json`.
+
 
 ## Start training
 
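After downloading, a quick way to confirm the file is where the new config entry expects it. This is a standalone sanity check, not part of the commit:

```python
import json

# Path matches the ann_path registered in configs/dataset_config.py below.
with open("data/baize/quora_chat_data.json", "r") as f:
    data = json.load(f)
print(f"{len(data)} dialogues; first record keys: {list(data[0].keys())}")
```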
configs/dataset_config.py CHANGED
@@ -57,4 +57,8 @@ language_datasets = [
         type="alpaca_gpt4",
         ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
     ),
+    dict(
+        type="baize",
+        ann_path="data/baize/quora_chat_data.json",
+    ),
 ]
mmgpt/datasets/baize_dataset.py ADDED
@@ -0,0 +1,100 @@
+import json
+
+from mmgpt.datasets.dolly_dataset import DollyDataset
+
+TEMPLATE = {
+    "description": "Template used by Alpaca-LoRA.",
+    "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
+    "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n",
+    "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
+    "response_split": "### Response:",
+}
+
+
+class LangDialPrompter:
+    """Formats dialogue turns with the Alpaca-style templates above."""
+
+    def __call__(self, question, options=None):
+        if options:
+            options = ", ".join(options)
+            res = TEMPLATE["prompt_choice"].format(question=question, options=options)
+        else:
+            res = TEMPLATE["prompt_dial"].format(question=question)
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(TEMPLATE["response_split"])[-1].strip()
+
+
+class BaiZeDataset(DollyDataset):
+    """Dataset for Baize chat data, where each annotation stores a whole
+    multi-turn dialogue in its ``input`` field:
+
+    ```json
+    [
+        {
+            "topic": "...",
+            "input": "The conversation between human and AI assistant.\n[|Human|] <question>\n[|AI|] <answer>\n...\n[|Human|] "
+        }
+    ]
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(BaiZeDataset, self).__init__(*args, **kwargs)
+        self.prompter = LangDialPrompter()
+
+    def load_annotation(self, ann_path):
+        with open(ann_path, "r") as f:
+            self.annotation = json.load(f)
+
+    def process_text(self, anns):
+        # TODO remove this
+        begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+        # Each element after the first begins with a human question followed
+        # by the matching AI answer.
+        convs = anns["input"].split("[|Human|] ")
+        conv_list = []
+        # convs[0] is the preamble and the last element is a dangling
+        # "[|Human|] " marker with no answer, hence the [1:-1] slice.
+        for conv_id, one_conv in enumerate(convs[1:-1]):
+            # Split on the first "[|AI|] " only, in case it recurs in the answer.
+            question, answer = one_conv.split("[|AI|] ", 1)
+            question = question.replace("\n", "")
+            answer = answer.replace("\n", "")
+            instruction = self.prompter(question)
+            if conv_id == 0:
+                # Prepend the task preamble to the first turn only.
+                single_conv = dict(instruction=begin_string + instruction, answer=answer)
+            else:
+                single_conv = dict(instruction=instruction, answer=answer)
+            conv_list.append(single_conv)
+        return conv_list
+
+    def __getitem__(self, index):
+        ann = self.annotation[index]
+        text_list = self.process_text(ann)
+        # Tokenize each turn separately, then concatenate into one sequence.
+        res_list = []
+        for text in text_list:
+            single_res = self.tokenize(text)
+            single_res["instruction"] = text["instruction"]
+            single_res["answer"] = text["answer"]
+            res_list.append(single_res)
+
+        input_ids = []
+        attention_mask = []
+        labels = []
+        instruction = []
+        answer = []
+        for res in res_list:
+            input_ids.extend(res["input_ids"])
+            attention_mask.extend(res["attention_mask"])
+            labels.extend(res["labels"])
+            instruction.append(res["instruction"])
+            answer.append(res["answer"])
+
+        res = dict(
+            input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
+        )
+        return res
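To illustrate what `process_text` does with a Baize record, here is a minimal standalone sketch of the same turn-splitting logic. The sample record is made up, but follows the `[|Human|]`/`[|AI|]` layout of `quora_chat_data.json`; the dataset class itself additionally needs the `DollyDataset` constructor arguments, which are outside this diff:

```python
sample = {
    "topic": "example",
    "input": (
        "The conversation between human and AI assistant.\n"
        "[|Human|] What is Baize?\n"
        "[|AI|] Baize is an open-source chat model.\n"
        "[|Human|] "
    ),
}

def split_turns(record):
    """Replicates the turn-splitting in BaiZeDataset.process_text."""
    turns = record["input"].split("[|Human|] ")
    # turns[0] is the preamble; the last element is the trailing
    # "[|Human|] " marker with no answer, hence the [1:-1] slice.
    pairs = []
    for turn in turns[1:-1]:
        question, answer = turn.split("[|AI|] ", 1)
        pairs.append((question.strip(), answer.strip()))
    return pairs

print(split_turns(sample))
# [('What is Baize?', 'Baize is an open-source chat model.')]
```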
mmgpt/datasets/builder.py CHANGED
@@ -15,6 +15,7 @@ from .ocr_vqa_dataset import OCRVQADataset  # noqa: F401
 from .snli_ve_datasets import SNLIVEDataset  # noqa: F401
 from .text_ocr_dataset import TextOCRDataset  # noqa: F401
 from .vqa_dataset import ConcatDataset, VQADataset  # noqa: F401
+from .baize_dataset import BaiZeDataset  # noqa: F401
 
 
 def build_dataset(dataset_config, **kwargs):
@@ -108,6 +109,11 @@ def build_dataset(dataset_config, **kwargs):
             **dataset_config,
             **kwargs,
         )
+    elif dataset_type == "baize":
+        dataset = BaiZeDataset(
+            **dataset_config,
+            **kwargs,
+        )
     else:
         raise NotImplementedError
 
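With the import and the `elif` branch in place, a config entry of `type="baize"` is routed to `BaiZeDataset`. A usage sketch; the extra keyword arguments (e.g. the tokenizer) are whatever `DollyDataset` requires, which is defined outside this diff:

```python
from mmgpt.datasets.builder import build_dataset

# The config dict mirrors the entry added to configs/dataset_config.py.
# `tokenizer` is assumed to be whatever the training script already passes
# to the other language datasets; its construction is not shown here.
dataset = build_dataset(
    dict(type="baize", ann_path="data/baize/quora_chat_data.json"),
    tokenizer=tokenizer,
)
sample = dataset[0]
print(sample["instruction"][0])  # Alpaca-style prompt for the first turn
```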