Keiran Paster committed on
Commit
8a6689a
1 Parent(s): 1a33fb4

add example code and readme

Browse files
README.md CHANGED
@@ -1,3 +1,7 @@
1
  ---
2
  license: apache-2.0
3
  ---
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
  ---
4
+
5
+ This repository stores the MathScore and KenLM models used in the generation of OpenWebMath.
6
+
7
+ To test the models, please `git clone` this repository and run `python perplexity.py` to test the KenLM model and `python math_score.py` to test the MathScore model.
example/__pycache__/text_normalizer.cpython-39.pyc ADDED
Binary file (4.47 kB). View file
 
example/math_score.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fasttext
2
+ from text_normalizer import normalize
3
+
4
def score_text(model, text):
    """Score `text` with the MathScore fasttext classifier.

    Parameters
    ----------
    model : a loaded fasttext model whose labels include
        '__label__positive' (presumably "math-like" — confirm against
        training setup) and a second, negative label.
    text : str
        Raw document text; it is normalized before prediction.

    Returns
    -------
    The model's probability for the '__label__positive' label.
    """
    # fasttext predicts on a single line, so flatten newlines after normalizing.
    cleaned = normalize(text).replace('\n', ' ')
    # The normalizer may emit [EQUATION] placeholders; strip them before scoring.
    cleaned = cleaned.replace('[EQUATION]', '')
    labels, probs = model.predict(cleaned, k=2)
    # With k=2 both labels are returned; pick whichever slot is the positive one.
    positive_slot = 0 if labels[0] == '__label__positive' else 1
    return probs[positive_slot]
15
+
16
# Load the fasttext model.
# Path is relative to this script's directory (example/), so run it from there.
model = fasttext.load_model('../math_score.bin')

# Test the model on a sample passage containing inline LaTeX.
TEXT = """I thought I’d add a little bit of background. The previous discussion started from the result $P(B|AC) = K^{-1}P(B|C)P(A|BC) = K^{-1} P(AB|C)$ where $K=P(A|C).$ Although this is called Bayes’ theorem, the general form of it as stated here was actually first written down, not by Bayes but by Laplace."""

print(score_text(model, TEXT)) # Should print out 0.912
example/perplexity.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import kenlm
2
+ from text_normalizer import normalize
3
+
4
def document_perplexity(model, text):
    """Compute the per-word perplexity of `text` under a KenLM model.

    Parameters
    ----------
    model : kenlm.Model
        A loaded KenLM language model.
    text : str
        Raw document text; it is normalized before scoring.

    Returns
    -------
    float
        Word-level perplexity (lower means the model finds the text
        more likely).

    Raises
    ------
    ValueError
        If the normalized text contains no tokens. (The previous
        version raised an opaque ZeroDivisionError in this case.)
    """
    text = normalize(text)
    tokens = text.split()
    if not tokens:
        raise ValueError("cannot compute perplexity of an empty document")
    # model.score returns the total log10 probability of the text;
    # convert it to a per-word perplexity.
    score = model.score(text)
    return 10 ** (-score / len(tokens))
8
+
9
# Load the language model.
# Path is relative to this script's directory (example/), so run it from there.
model = kenlm.Model('../lm-v2.binary')

# Test the model on a sample passage containing inline LaTeX.
TEXT = """I thought I’d add a little bit of background. The previous discussion started from the result $P(B|AC) = K^{-1}P(B|C)P(A|BC) = K^{-1} P(AB|C)$ where $K=P(A|C).$ Although this is called Bayes’ theorem, the general form of it as stated here was actually first written down, not by Bayes but by Laplace."""

print(document_perplexity(model, TEXT)) # Should print out ~239
example/text_normalizer.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # From https://github.com/facebookresearch/cc_net/blob/main/cc_net/text_normalizer.py
7
+
8
+ import re
9
+ import unicodedata
10
+
11
# Mapping from full-width / typographic Unicode punctuation to plain ASCII
# equivalents. Used by replace_unicode_punct() (translate) and, via
# UNICODE_PUNCT_RE, by remove_unicode_punct() (delete).
UNICODE_PUNCT = {
    ",": ",",
    "。": ".",
    "、": ",",
    "„": '"',
    "”": '"',
    "“": '"',
    "«": '"',
    "»": '"',
    # NOTE(review): this entry maps the digit "1" to a double quote — it looks
    # like a mojibake'd quote character inherited verbatim from upstream
    # cc_net. Confirm before "fixing": the shipped models were trained with
    # this exact normalization.
    "1": '"',
    "」": '"',
    "「": '"',
    "《": '"',
    "》": '"',
    "´": "'",
    "∶": ":",
    ":": ":",
    "?": "?",
    "!": "!",
    "(": "(",
    ")": ")",
    ";": ";",
    "–": "-",
    "—": " - ",
    ".": ". ",
    "~": "~",
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "%": "%",
    "►": "-",
}

# Single-character class matching any key of UNICODE_PUNCT.
UNICODE_PUNCT_RE = re.compile(f"[{''.join(UNICODE_PUNCT.keys())}]")

# $...$ or $$...$$ spans not preceded by a backslash (inline/display LaTeX).
MATH_RE = r"(?<!\\)(\$\$?.+?\$\$?)"
# `...`, ``...`` or ```...``` spans (inline/fenced code).
CODE_RE = r'\`{1,3}.*?\`{1,3}'
52
+
53
+
54
def replace_unicode_punct(text: str) -> str:
    """Translate each character via UNICODE_PUNCT; unknown chars pass through."""
    translated = (UNICODE_PUNCT.get(ch, ch) for ch in text)
    return "".join(translated)
56
+
57
+
58
def remove_unicode_punct(text: str) -> str:
    """More aggressive version of replace_unicode_punct but also faster."""
    # Deletes (rather than translates) every character listed in UNICODE_PUNCT.
    return UNICODE_PUNCT_RE.sub("", text)
61
+
62
+
63
def strip_accents(line: str) -> str:
    """Strips accents from a piece of text.

    Decomposes the string (NFD) so accents become separate combining
    marks (Unicode category "Mn"), then drops those marks.
    """
    nfd = unicodedata.normalize("NFD", line)
    output = [c for c in nfd if unicodedata.category(c) != "Mn"]
    # The upstream cc_net version had a fast path `if len(output) == line:
    # return line`, but that compares an int to a str and is always False,
    # so the branch was dead code — removed here (behavior unchanged).
    return "".join(output)
70
+
71
+
72
# Build a regex matching all control characters
# (C0 controls 0x00-0x1F plus DEL and the C1 range 0x7F-0x9F).
NON_PRINTING_CHARS_RE = re.compile(
    f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
)
# Any single ASCII digit.
DIGIT_RE = re.compile(r"\d")
# Union of the punctuation and control-character classes above, built by
# splicing the two "[...]" patterns together (the "][" seam is removed).
PUNCT_OR_NON_PRINTING_CHARS_RE = re.compile(
    (UNICODE_PUNCT_RE.pattern + NON_PRINTING_CHARS_RE.pattern).replace("][", "")
)
80
+
81
+
82
def remove_non_printing_char(text: str) -> str:
    """Delete ASCII/Latin-1 control characters from `text`."""
    return NON_PRINTING_CHARS_RE.sub("", text)
84
+
85
+
86
def normalize_spacing_for_tok(text: str, language: str = "en") -> str:
    """Normalize punctuation spacing ahead of tokenization.

    Port of the Moses `normalize-punctuation.perl` rules (via cc_net):
    fixes spacing around brackets and punctuation, unifies quote
    characters, and applies per-language quote/number conventions.
    Not called by normalize() below; kept for parity with upstream.
    """
    res = (
        text.replace("\r", "")
        # remove extra spaces
        .replace("(", " (")
        .replace(")", ") ")
        .replace(" +", " ")
    )
    # No space between ")" and trailing sentence punctuation.
    res = re.sub(r"\) ([\.\!\:\?\;\,])", r"\)\1", res)
    res = res.replace("( ", "(").replace(" )", ")")
    res = re.sub(r"(\d) \%", r"\1\%", res)
    res = res.replace(" :", ":").replace(" ;", ";")
    res = res.replace("`", "'").replace("''", ' " ')

    res = (
        res.replace("„", '"')
        .replace("“", '"')
        .replace("”", '"')
        .replace("–", "-")
        .replace("—", " - ")
        .replace(" +", " ")
        .replace("´", "'")
        # NOTE(review): str.replace matches literally, so the next two calls
        # can never fire — they look like Perl regexes ported verbatim from
        # Moses. Confirm intent before converting them to re.sub (changing
        # them would alter behavior relative to upstream cc_net).
        .replace("([a-z])‘([a-z])", r"\1'\2/")
        .replace("([a-z])’([a-z])", r"\1'\2/")
        .replace("‘", '"')
        .replace("‚", '"')
        .replace("’", '"')
        .replace("''", '"')
        .replace("´´", '"')
        .replace("…", "...")
        # French quotes
        .replace(" « ", ' "')
        .replace("« ", '"')
        .replace("«", '"')
        .replace(" » ", '" ')
        .replace(" »", '"')
        .replace("»", '"')
        # handle pseudo-spaces
        # NOTE(review): in Moses these pairs targeted non-breaking spaces
        # (U+00A0) on the left-hand side; several pairs here look identical
        # (no-ops) — verify the NBSPs survived copy/transcription.
        .replace(" %", "%")
        .replace("nº ", "nº ")
        .replace(" :", ":")
        .replace(" ºC", " ºC")
        .replace(" cm", " cm")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ;", ";")
        .replace(", ", ", ")
        .replace(" +", " ")
        .replace(".", ". ")
    )
    # English "quotation," followed by comma, style
    if language == "en":
        res = re.sub(r"\"([,\.]+)", r"\1\"", res)
    # Czech is confused
    elif language == "cs" or language == "cz":
        pass
    # German/Spanish/French "quotation", followed by comma, style
    else:
        res = res.replace(',"', '",')
        res = re.sub(
            r"(\.+)\"(\s*[^<])", r"\"\1\2", res
        )  # don't fix period at end of sentence

    # Digit-group separator: comma for de/es/cz/cs/fr, period otherwise.
    if (
        language == "de"
        or language == "es"
        or language == "cz"
        or language == "cs"
        or language == "fr"
    ):
        res = re.sub(r"(\d) (\d)", r"\1,\2", res)
    else:
        res = re.sub(r"(\d) (\d)", r"\1.\2", res)
    return res
160
+
161
+
162
def normalize(line: str, accent=True, case=True, numbers=True, math=True, code=True, punct=1) -> str:
    """Normalize one line of text for model input.

    Steps, each gated by its flag and applied in this fixed order:
    lowercasing, accent stripping, digit squashing (every digit -> "0"),
    punctuation handling (punct=1 maps Unicode punctuation to ASCII,
    punct=2 deletes it, any other value leaves it untouched), replacement
    of $...$ / $$...$$ spans with "[EQUATION]" and backtick spans with
    "[CODE]", removal of literal <s>/</s> tags, and finally removal of
    control characters.
    """
    text = line.strip()
    if not text:
        return text

    if case:
        text = text.lower()
    if accent:
        text = strip_accents(text)
    if numbers:
        text = DIGIT_RE.sub("0", text)

    # Choose the punctuation strategy (None means "leave as-is").
    punct_handler = {1: replace_unicode_punct, 2: remove_unicode_punct}.get(punct)
    if punct_handler is not None:
        text = punct_handler(text)

    if math:
        text = re.sub(MATH_RE, "[EQUATION]", text, flags=re.DOTALL)
    if code:
        text = re.sub(CODE_RE, "[CODE]", text, flags=re.DOTALL)

    # Replace any <s> or </s> explicitly written in the text with nothing.
    text = text.replace("<s>", "").replace("</s>", "")
    return remove_non_printing_char(text)
184
+
185
+
186
def slow_normalize_for_dedup(line: str) -> str:
    """Normalize for deduplication via the full normalize() pipeline.

    Runs normalize() with accents kept (accent=False), lowercasing,
    digit squashing, and punctuation removal (punct=2); math/code
    placeholder substitution stays at its defaults.
    """
    return normalize(line, accent=False, case=True, numbers=True, punct=2)
188
+
189
+
190
def normalize_for_dedup(line: str) -> str:
    """Fast normalization for deduplication.

    Lowercases, maps every digit to "0", and strips punctuation plus
    control characters in a single combined regex pass.
    """
    stripped = line.strip()
    if not stripped:
        return stripped
    lowered = stripped.lower()
    digits_squashed = DIGIT_RE.sub("0", lowered)
    return PUNCT_OR_NON_PRINTING_CHARS_RE.sub("", digits_squashed)