Update README.md
Browse files
README.md
CHANGED
@@ -83,3 +83,115 @@ the output:
|
|
83 |
- SyntheticWhichIsGreater5k
|
84 |
- 機械的に合成した、二つの小数のどちらが大きいかを回答する問題 5,000問
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
- SyntheticWhichIsGreater5k
|
84 |
- 機械的に合成した、二つの小数のどちらが大きいかを回答する問題 5,000問
|
85 |
|
86 |
+
下の二つのデータは、[math_problem.py](https://huggingface.co/p1atdev/llm-jp-3-3.7b-instruct2-R27/blob/main/math_problem.py) の関数を使って以下のように合成しました。
|
87 |
+
|
88 |
+
```py
|
89 |
+
def generate_int_problem(
|
90 |
+
num_generation: int,
|
91 |
+
max_int: int,
|
92 |
+
min_int: int,
|
93 |
+
max_terms: int,
|
94 |
+
):
|
95 |
+
for i in range(num_generation):
|
96 |
+
text, tex, result = create_integer_arithmetic_problem(
|
97 |
+
max_val=max_int,
|
98 |
+
min_val=min_int,
|
99 |
+
max_terms=max_terms,
|
100 |
+
)
|
101 |
+
formula = random.choice([text, tex])
|
102 |
+
templates = [
|
103 |
+
"{formula} = ?",
|
104 |
+
"{formula} を計算してください。",
|
105 |
+
"次の式を計算し、計算結果を解答してください。\n{formula}",
|
106 |
+
"計算して\n{formula}",
|
107 |
+
"次の式を計算してください。\n{formula}",
|
108 |
+
"次の式の答えは何ですか?\n{formula}",
|
109 |
+
"? に当てはまる数字を答えてください。\n{formula}",
|
110 |
+
"{formula}\n計算して",
|
111 |
+
"{formula}\n↑の答えを求めてください。",
|
112 |
+
]
|
113 |
+
instruction = random.choice(templates).format(formula=formula)
|
114 |
+
|
115 |
+
yield {
|
116 |
+
"ground_truth": str(result),
|
117 |
+
"instruction": instruction,
|
118 |
+
"source": "synthetic_int_problem",
|
119 |
+
"answer_dtype": "int",
|
120 |
+
"skip_check": False,
|
121 |
+
}
|
122 |
+
|
123 |
+
def generate_wig_problem(
|
124 |
+
num_generation: int,
|
125 |
+
max_num: float,
|
126 |
+
min_num: float,
|
127 |
+
precision: int,
|
128 |
+
):
|
129 |
+
for i in range(num_generation):
|
130 |
+
num_1, num_2, greater = create_two_decimals(
|
131 |
+
min_val=min_num,
|
132 |
+
max_val=max_num,
|
133 |
+
precision=precision,
|
134 |
+
)
|
135 |
+
templates = [
|
136 |
+
"次の数字のうち、どちらが大きいですか?\n{num_1}\n{num_2}",
|
137 |
+
"{num_1} と {num_2} のうちどちらが大きいですか?",
|
138 |
+
"{num_1} と {num_2} はどっちが大きい?",
|
139 |
+
"大きいほうを選んで: {num_1} {num_2}",
|
140 |
+
"次の数値を比較し、大きい方を選んでください。\n{num_1} {num_2}",
|
141 |
+
]
|
142 |
+
instruction = random.choice(templates).format(num_1=num_1, num_2=num_2)
|
143 |
+
|
144 |
+
yield {
|
145 |
+
"ground_truth": str(greater),
|
146 |
+
"instruction": instruction,
|
147 |
+
"source": "synthetic_which_is_greater",
|
148 |
+
"answer_dtype": "float",
|
149 |
+
"skip_check": False,
|
150 |
+
}
|
151 |
+
|
152 |
+
|
153 |
+
# generate dataset
|
154 |
+
ds_easy_int = Dataset.from_generator(
|
155 |
+
generate_int_problem,
|
156 |
+
gen_kwargs={
|
157 |
+
"num_generation": 5000,
|
158 |
+
"max_int": 10,
|
159 |
+
"min_int": -10,
|
160 |
+
"max_terms": 5,
|
161 |
+
},
|
162 |
+
)
|
163 |
+
assert isinstance(ds_easy_int, Dataset)
|
164 |
+
print("easy_int:", ds_easy_int)
|
165 |
+
|
166 |
+
ds_wig = Dataset.from_generator(
|
167 |
+
generate_wig_problem,
|
168 |
+
gen_kwargs={
|
169 |
+
"num_generation": 5000,
|
170 |
+
"max_num": 100.0,
|
171 |
+
"min_num": -100.0,
|
172 |
+
"precision": 3,
|
173 |
+
},
|
174 |
+
)
|
175 |
+
assert isinstance(ds_wig, Dataset)
|
176 |
+
print("wig:", ds_wig)
|
177 |
+
|
178 |
+
# japanese gsm8k
|
179 |
+
ds_gsm8k = load_dataset("p1atdev/gsm8k-ja-slim", split="train")
|
180 |
+
assert isinstance(ds_gsm8k, Dataset)
|
181 |
+
ds_gsm8k = ds_gsm8k.filter(filter_gsm8k_ja, batched=True)
|
182 |
+
ds_gsm8k = ds_gsm8k.map(
|
183 |
+
map_gsm8k_ja_instruction, remove_columns=ds_gsm8k.column_names
|
184 |
+
)
|
185 |
+
print("gsm8k:", ds_gsm8k)
|
186 |
+
|
187 |
+
# concat
|
188 |
+
ds = concatenate_datasets( # from datasets import concatenate_datasets
|
189 |
+
[
|
190 |
+
ds_gsm8k,
|
191 |
+
ds_easy_int,
|
192 |
+
ds_wig,
|
193 |
+
]
|
194 |
+
)
|
195 |
+
print("total:", ds)
|
196 |
+
```
|
197 |
+
|