Spaces:
Runtime error
Runtime error
Peter Szemraj
commited on
Commit
•
f28496d
1
Parent(s):
31f09ae
:sparkles: upgrade to v2 code
Browse files- ai_single_response.py +0 -324
- app.py +97 -71
- grammar_improve.py +34 -15
- requirements.txt +1 -1
- symspell_rsc/frequency_bigramdictionary_en_243_342.txt +0 -0
- symspell_rsc/frequency_dictionary_en_82_765.txt +0 -0
- utils.py +20 -1
ai_single_response.py
DELETED
@@ -1,324 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
"""
|
4 |
-
ai_single_response.py
|
5 |
-
|
6 |
-
An executable way to call the model. example:
|
7 |
-
*\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time
|
8 |
-
|
9 |
-
query_gpt_model is used throughout the code, and is the "fundamental" building block of the bot and how everything works. Test this function with a few different models.
|
10 |
-
|
11 |
-
"""
|
12 |
-
from aitextgen import aitextgen
|
13 |
-
import argparse
|
14 |
-
import pprint as pp
|
15 |
-
import sys
|
16 |
-
import time
|
17 |
-
import warnings
|
18 |
-
from datetime import datetime
|
19 |
-
from pathlib import Path
|
20 |
-
from grammar_improve import remove_trailing_punctuation
|
21 |
-
from utils import print_spacer, cleantxt_wrap
|
22 |
-
|
23 |
-
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
|
24 |
-
|
25 |
-
|
26 |
-
def extract_response(full_resp: list, plist: list, verbose: bool = False):
|
27 |
-
"""
|
28 |
-
extract_response - helper fn for ai_single_response.py. By default aitextgen returns the prompt and the response, we just want the response
|
29 |
-
|
30 |
-
Args:
|
31 |
-
full_resp (list): a list of strings, each string is a response
|
32 |
-
plist (list): a list of strings, each string is a prompt
|
33 |
-
|
34 |
-
verbose (bool, optional): 4 debug. Defaults to False.
|
35 |
-
"""
|
36 |
-
full_resp = [cleantxt_wrap(ele) for ele in full_resp]
|
37 |
-
plist = [cleantxt_wrap(pr) for pr in plist]
|
38 |
-
p_len = len(plist)
|
39 |
-
assert (
|
40 |
-
len(full_resp) >= p_len
|
41 |
-
), "model output should have as many lines or longer as the input."
|
42 |
-
|
43 |
-
if set(plist).issubset(full_resp):
|
44 |
-
|
45 |
-
del full_resp[:p_len] # remove the prompts from the responses
|
46 |
-
else:
|
47 |
-
print("the isolated responses are:\n")
|
48 |
-
pp.pprint(full_resp)
|
49 |
-
print_spacer()
|
50 |
-
print("the input prompt was:\n")
|
51 |
-
pp.pprint(plist)
|
52 |
-
print_spacer()
|
53 |
-
sys.exit("Exiting: some prompts not found in the responses")
|
54 |
-
if verbose:
|
55 |
-
print("the isolated responses are:\n")
|
56 |
-
pp.pprint(full_resp)
|
57 |
-
print_spacer()
|
58 |
-
print("the input prompt was:\n")
|
59 |
-
pp.pprint(plist)
|
60 |
-
print_spacer()
|
61 |
-
return full_resp # list of only the model generated responses
|
62 |
-
|
63 |
-
|
64 |
-
def get_bot_response(
|
65 |
-
name_resp: str, model_resp: str, name_spk: str, verbose: bool = False
|
66 |
-
):
|
67 |
-
|
68 |
-
"""
|
69 |
-
|
70 |
-
get_bot_response - from the model response, extract the bot response. This is needed because depending on the generation length the model may return more than one response.
|
71 |
-
|
72 |
-
Args: name_resp (str): the name of the responder
|
73 |
-
model_resp (str): the model response
|
74 |
-
verbose (bool, optional): 4 debug. Defaults to False.
|
75 |
-
|
76 |
-
returns: fn_resp (list of str)
|
77 |
-
"""
|
78 |
-
|
79 |
-
fn_resp = []
|
80 |
-
|
81 |
-
name_counter = 0
|
82 |
-
break_safe = False
|
83 |
-
for resline in model_resp:
|
84 |
-
if resline.startswith(name_resp):
|
85 |
-
name_counter += 1
|
86 |
-
break_safe = True # know the line is from bot as this line starts with the name of the bot
|
87 |
-
continue
|
88 |
-
if name_spk is not None and name_spk.lower() in resline.lower():
|
89 |
-
break
|
90 |
-
if ":" in resline and name_counter > 0:
|
91 |
-
if break_safe:
|
92 |
-
# we know this is a response from the bot even tho ':' is in the line
|
93 |
-
fn_resp.append(resline)
|
94 |
-
break_safe = False
|
95 |
-
else:
|
96 |
-
# we do not know this is a response from the bot. could be name of another person.. bot is "finished" response
|
97 |
-
break
|
98 |
-
else:
|
99 |
-
fn_resp.append(resline)
|
100 |
-
break_safe = False
|
101 |
-
if verbose:
|
102 |
-
print("the full response is:\n")
|
103 |
-
print("\n".join(fn_resp))
|
104 |
-
|
105 |
-
return fn_resp
|
106 |
-
|
107 |
-
|
108 |
-
def query_gpt_model(
|
109 |
-
prompt_msg: str,
|
110 |
-
speaker=None,
|
111 |
-
responder=None,
|
112 |
-
resp_length=64,
|
113 |
-
resp_min=10,
|
114 |
-
kparam=50,
|
115 |
-
temp=0.75,
|
116 |
-
top_p=0.90,
|
117 |
-
batch_size=1,
|
118 |
-
verbose=False,
|
119 |
-
use_gpu=False,
|
120 |
-
nbeams=1,
|
121 |
-
):
|
122 |
-
"""
|
123 |
-
query_gpt_model - the main function that calls the model.
|
124 |
-
|
125 |
-
Parameters:
|
126 |
-
-----------
|
127 |
-
prompt_msg (str): the prompt to be sent to the model
|
128 |
-
speaker (str, optional): the name of the speaker. Defaults to None.
|
129 |
-
responder (str, optional): the name of the responder. Defaults to None.
|
130 |
-
resp_length (int, optional): the length of the response. Defaults to 64.
|
131 |
-
resp_min (int, optional): the minimum length of the response. Defaults to 4.
|
132 |
-
kparam (int, optional): the k parameter for the top_p. Defaults to 150.
|
133 |
-
temp (float, optional): the temperature for the top_p. Defaults to 0.75.
|
134 |
-
top_p (float, optional): the top_p parameter for the top_p. Defaults to 0.65.
|
135 |
-
verbose (bool, optional): 4 debug. Defaults to False.
|
136 |
-
use_gpu (bool, optional): use gpu. Defaults to False.
|
137 |
-
nbeams (int, optional): the number of beams to search and return best value. Defaults to 1.
|
138 |
-
"""
|
139 |
-
from aitextgen.utils import GPT2ConfigCPU
|
140 |
-
|
141 |
-
ai = aitextgen(
|
142 |
-
model="pszemraj/Ballpark-Trivia-L", # THIS WORKS. XL is not working
|
143 |
-
to_gpu=use_gpu,
|
144 |
-
)
|
145 |
-
|
146 |
-
p_list = [] # track conversation
|
147 |
-
p_list.append(speaker.lower() + ":" + "\n")
|
148 |
-
p_list.append(prompt_msg.lower() + "\n")
|
149 |
-
p_list.append("\n")
|
150 |
-
p_list.append(responder.lower() + ":" + "\n")
|
151 |
-
this_prompt = "".join(p_list)
|
152 |
-
pr_len = len(this_prompt)
|
153 |
-
if verbose:
|
154 |
-
print("overall prompt:\n")
|
155 |
-
pp.pprint(this_prompt, indent=4)
|
156 |
-
# call the model
|
157 |
-
print("\n... generating...")
|
158 |
-
this_result = ai.generate(
|
159 |
-
n=1,
|
160 |
-
batch_size=batch_size,
|
161 |
-
# the prompt input counts for text length constraints
|
162 |
-
max_length=resp_length + pr_len,
|
163 |
-
min_length=resp_min + pr_len,
|
164 |
-
prompt=this_prompt,
|
165 |
-
top_k=kparam,
|
166 |
-
top_p=top_p,
|
167 |
-
do_sample=True,
|
168 |
-
return_as_list=True,
|
169 |
-
n_beams=nbeams,
|
170 |
-
temperature=temp,
|
171 |
-
verbose=True, # in this case verbose is just to enable huggingface logging
|
172 |
-
use_cache=True,
|
173 |
-
)
|
174 |
-
if verbose:
|
175 |
-
print("\n... generated:\n")
|
176 |
-
pp.pprint(this_result) # for debugging
|
177 |
-
# process the full result to get the ~bot response~ piece
|
178 |
-
this_result = str(this_result[0]).split(
|
179 |
-
"\n"
|
180 |
-
) # TODO: adjust hardcoded value for index to dynamic (if n>1)
|
181 |
-
og_res = this_result.copy()
|
182 |
-
og_prompt = p_list.copy()
|
183 |
-
diff_list = extract_response(
|
184 |
-
this_result, p_list, verbose=verbose
|
185 |
-
) # isolate the responses from the prompts
|
186 |
-
# extract the bot response from the model generated text
|
187 |
-
bot_dialogue = get_bot_response(
|
188 |
-
name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
|
189 |
-
)
|
190 |
-
print(f"DEBUG: {bot_dialogue} was original response pre-SC")
|
191 |
-
bot_resp = ", ".join(bot_dialogue)
|
192 |
-
bot_resp = bot_resp.strip()
|
193 |
-
# remove the last ',' '.' chars
|
194 |
-
bot_resp = remove_trailing_punctuation(bot_resp)
|
195 |
-
if verbose:
|
196 |
-
print("\n... bot response:\n")
|
197 |
-
pp.pprint(bot_resp)
|
198 |
-
og_prompt.append(bot_resp + "\n")
|
199 |
-
og_prompt.append("\n")
|
200 |
-
|
201 |
-
print("\nfinished!")
|
202 |
-
# return the bot response and the full conversation
|
203 |
-
|
204 |
-
return {"out_text": bot_resp, "full_conv": og_prompt} # model responses
|
205 |
-
|
206 |
-
|
207 |
-
# Set up the parsing of command-line arguments
|
208 |
-
def get_parser():
|
209 |
-
"""
|
210 |
-
get_parser a helper function for the argparse module, relevant if this is run as a script.
|
211 |
-
"""
|
212 |
-
|
213 |
-
parser = argparse.ArgumentParser(
|
214 |
-
description="submit a message and have a pretrained GPT model respond"
|
215 |
-
)
|
216 |
-
parser.add_argument(
|
217 |
-
"--prompt",
|
218 |
-
required=True, # MUST HAVE A PROMPT
|
219 |
-
type=str,
|
220 |
-
help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
|
221 |
-
)
|
222 |
-
parser.add_argument(
|
223 |
-
"--model",
|
224 |
-
required=False,
|
225 |
-
type=str,
|
226 |
-
default="GPT2_trivNatQAdailydia_774M_175Ksteps",
|
227 |
-
help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
|
228 |
-
"config.json). No models? Run the script download_models.py",
|
229 |
-
)
|
230 |
-
|
231 |
-
parser.add_argument(
|
232 |
-
"--speaker",
|
233 |
-
required=False,
|
234 |
-
default=None,
|
235 |
-
help="Who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
|
236 |
-
)
|
237 |
-
parser.add_argument(
|
238 |
-
"--responder",
|
239 |
-
required=False,
|
240 |
-
default="person beta",
|
241 |
-
help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
|
242 |
-
)
|
243 |
-
|
244 |
-
parser.add_argument(
|
245 |
-
"--topk",
|
246 |
-
required=False,
|
247 |
-
type=int,
|
248 |
-
default=150,
|
249 |
-
help="how many responses to sample (positive integer). lower = more random responses",
|
250 |
-
)
|
251 |
-
|
252 |
-
parser.add_argument(
|
253 |
-
"--temp",
|
254 |
-
required=False,
|
255 |
-
type=float,
|
256 |
-
default=0.75,
|
257 |
-
help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
|
258 |
-
)
|
259 |
-
|
260 |
-
parser.add_argument(
|
261 |
-
"--topp",
|
262 |
-
required=False,
|
263 |
-
type=float,
|
264 |
-
default=0.65,
|
265 |
-
help="nucleus sampling frac (0-1). aka: what fraction of possible options are considered?",
|
266 |
-
)
|
267 |
-
|
268 |
-
parser.add_argument(
|
269 |
-
"--verbose",
|
270 |
-
default=False,
|
271 |
-
action="store_true",
|
272 |
-
help="pass this argument if you want all the printouts",
|
273 |
-
)
|
274 |
-
parser.add_argument(
|
275 |
-
"--time",
|
276 |
-
default=False,
|
277 |
-
action="store_true",
|
278 |
-
help="pass this argument if you want to know runtime",
|
279 |
-
)
|
280 |
-
return parser
|
281 |
-
|
282 |
-
|
283 |
-
if __name__ == "__main__":
|
284 |
-
# parse the command line arguments
|
285 |
-
args = get_parser().parse_args()
|
286 |
-
query = args.prompt
|
287 |
-
model_dir = str(args.model)
|
288 |
-
model_loc = Path.cwd() / model_dir
|
289 |
-
spkr = args.speaker
|
290 |
-
rspndr = args.responder
|
291 |
-
k_results = args.topk
|
292 |
-
my_temp = args.temp
|
293 |
-
my_top_p = args.topp
|
294 |
-
want_verbose = args.verbose
|
295 |
-
want_rt = args.time
|
296 |
-
|
297 |
-
st = time.perf_counter()
|
298 |
-
|
299 |
-
resp = query_gpt_model(
|
300 |
-
folder_path=model_loc,
|
301 |
-
prompt_msg=query,
|
302 |
-
speaker=spkr,
|
303 |
-
responder=rspndr,
|
304 |
-
kparam=k_results,
|
305 |
-
temp=my_temp,
|
306 |
-
top_p=my_top_p,
|
307 |
-
verbose=want_verbose,
|
308 |
-
use_gpu=False,
|
309 |
-
)
|
310 |
-
|
311 |
-
output = resp["out_text"]
|
312 |
-
pp.pprint(output, indent=4)
|
313 |
-
|
314 |
-
rt = round(time.perf_counter() - st, 1)
|
315 |
-
|
316 |
-
if want_rt:
|
317 |
-
print("took {runtime} seconds to generate. \n".format(runtime=rt))
|
318 |
-
|
319 |
-
if want_verbose:
|
320 |
-
print("finished - ", datetime.now())
|
321 |
-
p_list = resp["full_conv"]
|
322 |
-
print("A transcript of your chat is as follows: \n")
|
323 |
-
p_list = [item.strip() for item in p_list]
|
324 |
-
pp.pprint(p_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,19 +1,9 @@
|
|
1 |
"""
|
2 |
-
|
3 |
-
|
4 |
"""
|
5 |
-
|
6 |
-
|
7 |
-
request,
|
8 |
-
session,
|
9 |
-
jsonify,
|
10 |
-
abort,
|
11 |
-
send_file,
|
12 |
-
render_template,
|
13 |
-
redirect,
|
14 |
-
)
|
15 |
-
from ai_single_response import query_gpt_model
|
16 |
-
from datetime import datetime
|
17 |
from transformers import pipeline
|
18 |
from cleantext import clean
|
19 |
from pathlib import Path
|
@@ -26,7 +16,7 @@ import os
|
|
26 |
import sys
|
27 |
from os.path import dirname
|
28 |
import nltk
|
29 |
-
from
|
30 |
from grammar_improve import (
|
31 |
detect_propers,
|
32 |
load_ns_checker,
|
@@ -35,6 +25,7 @@ from grammar_improve import (
|
|
35 |
remove_trailing_punctuation,
|
36 |
build_symspell_obj,
|
37 |
symspeller,
|
|
|
38 |
)
|
39 |
|
40 |
from utils import (
|
@@ -46,46 +37,78 @@ nltk.download("stopwords") # TODO: find where this requirement originates from
|
|
46 |
|
47 |
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
|
48 |
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
|
|
|
|
|
|
|
49 |
logging.basicConfig()
|
50 |
cwd = Path.cwd()
|
51 |
my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
|
52 |
|
53 |
|
54 |
-
def
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
)
|
|
|
|
|
|
|
|
|
61 |
|
|
|
62 |
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
"""
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
message (str):
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
72 |
"""
|
|
|
73 |
st = time.perf_counter()
|
74 |
prompt = clean(message) # clean user input
|
75 |
prompt = prompt.strip() # get rid of any extra whitespace
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
89 |
rawtxt = resp["out_text"]
|
90 |
# check for proper nouns
|
91 |
if basic_sc and not detect_propers(rawtxt):
|
@@ -95,26 +118,16 @@ def ask_gpt(message: str):
|
|
95 |
else:
|
96 |
# no correction needed
|
97 |
cln_resp = rawtxt.strip()
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
102 |
return remove_trailing_punctuation(bot_resp)
|
103 |
|
104 |
|
105 |
-
def chat(trivia_query):
|
106 |
-
history = []
|
107 |
-
response = ask_gpt(trivia_query)
|
108 |
-
history = [trivia_query, response]
|
109 |
-
html = ""
|
110 |
-
for item in history:
|
111 |
-
html += f"<b>{item}</b> <br><br>"
|
112 |
-
|
113 |
-
html += ""
|
114 |
-
|
115 |
-
return html
|
116 |
-
|
117 |
-
|
118 |
def get_parser():
|
119 |
"""
|
120 |
get_parser - a helper function for the argparse module
|
@@ -127,9 +140,8 @@ def get_parser():
|
|
127 |
"--model",
|
128 |
required=False,
|
129 |
type=str,
|
130 |
-
default="pszemraj/Ballpark-Trivia-
|
131 |
-
help="
|
132 |
-
"config.json)",
|
133 |
)
|
134 |
parser.add_argument(
|
135 |
"--basic-sc",
|
@@ -139,21 +151,35 @@ def get_parser():
|
|
139 |
help="turn on symspell (baseline) correction instead of the more advanced neural net models",
|
140 |
)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
return parser
|
143 |
|
144 |
|
145 |
if __name__ == "__main__":
|
146 |
args = get_parser().parse_args()
|
147 |
default_model = str(args.model)
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
if basic_sc:
|
153 |
-
print("
|
154 |
schnellspell = build_symspell_obj()
|
155 |
else:
|
156 |
-
print("using
|
157 |
ns_checker = load_ns_checker(fast=False)
|
158 |
|
159 |
print(f"using model stored here: \n {model_loc} \n")
|
@@ -188,7 +214,7 @@ if __name__ == "__main__":
|
|
188 |
],
|
189 |
title=f"Ballpark Trivia: {default_model} Model",
|
190 |
description=f"Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.",
|
191 |
-
article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-
|
192 |
"**Important Notes & About:**\n\n"
|
193 |
"1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
|
194 |
"2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n"
|
@@ -209,4 +235,4 @@ if __name__ == "__main__":
|
|
209 |
# prevent_thread_lock=True,
|
210 |
# share=True,
|
211 |
enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
|
212 |
-
)
|
|
|
1 |
"""
|
2 |
+
app.py - the main file for the app. This creates the flask app and handles the routes.
|
3 |
+
|
4 |
"""
|
5 |
+
|
6 |
+
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from transformers import pipeline
|
8 |
from cleantext import clean
|
9 |
from pathlib import Path
|
|
|
16 |
import sys
|
17 |
from os.path import dirname
|
18 |
import nltk
|
19 |
+
from converse import discussion
|
20 |
from grammar_improve import (
|
21 |
detect_propers,
|
22 |
load_ns_checker,
|
|
|
25 |
remove_trailing_punctuation,
|
26 |
build_symspell_obj,
|
27 |
symspeller,
|
28 |
+
fix_punct_spacing,
|
29 |
)
|
30 |
|
31 |
from utils import (
|
|
|
37 |
|
38 |
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
|
39 |
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
|
40 |
+
import transformers
|
41 |
+
|
42 |
+
transformers.logging.set_verbosity_error()
|
43 |
logging.basicConfig()
|
44 |
cwd = Path.cwd()
|
45 |
my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
|
46 |
|
47 |
|
48 |
+
def chat(trivia_query):
|
49 |
+
"""
|
50 |
+
chat - a function that takes in a trivia query and returns a response
|
51 |
+
|
52 |
+
"""
|
53 |
+
history = []
|
54 |
+
response = ask_gpt(trivia_query)
|
55 |
+
history = [trivia_query, response]
|
56 |
+
html = ""
|
57 |
+
for item in history:
|
58 |
+
html += f"<b>{item}</b> <br><br>"
|
59 |
|
60 |
+
html += ""
|
61 |
|
62 |
+
return html
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
def ask_gpt(
|
67 |
+
message: str,
|
68 |
+
chat_pipe,
|
69 |
+
speaker="person alpha",
|
70 |
+
responder="person beta",
|
71 |
+
max_len=196,
|
72 |
+
top_p=0.95,
|
73 |
+
top_k=50,
|
74 |
+
temperature=0.6,
|
75 |
+
):
|
76 |
"""
|
77 |
+
|
78 |
+
ask_gpt - a function that takes in a prompt and generates a response using the pipeline. This interacts the discussion function.
|
79 |
+
|
80 |
+
Parameters:
|
81 |
+
message (str): the question to ask the bot
|
82 |
+
chat_pipe (str): the chat_pipe to use for the bot (default: "pszemraj/Ballpark-Trivia-XL")
|
83 |
+
speaker (str): the name of the speaker (default: "person alpha")
|
84 |
+
responder (str): the name of the responder (default: "person beta")
|
85 |
+
max_len (int): the maximum length of the response (default: 128)
|
86 |
+
top_p (float): the top probability threshold (default: 0.95)
|
87 |
+
top_k (int): the top k threshold (default: 50)
|
88 |
+
temperature (float): the temperature of the response (default: 0.7)
|
89 |
"""
|
90 |
+
|
91 |
st = time.perf_counter()
|
92 |
prompt = clean(message) # clean user input
|
93 |
prompt = prompt.strip() # get rid of any extra whitespace
|
94 |
+
in_len = len(prompt)
|
95 |
+
if in_len > 512:
|
96 |
+
prompt = prompt[-512:] # truncate to 512 chars
|
97 |
+
print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
|
98 |
+
max_len = min(max_len, 512)
|
99 |
+
|
100 |
+
resp = discussion(
|
101 |
+
prompt_text=prompt,
|
102 |
+
pipeline=chat_pipe,
|
103 |
+
speaker=speaker,
|
104 |
+
responder=responder,
|
105 |
+
top_p=top_p,
|
106 |
+
top_k=top_k,
|
107 |
+
temperature=temperature,
|
108 |
+
max_length=max_len,
|
109 |
+
)
|
110 |
+
gpt_et = time.perf_counter()
|
111 |
+
gpt_rt = round(gpt_et - st, 2)
|
112 |
rawtxt = resp["out_text"]
|
113 |
# check for proper nouns
|
114 |
if basic_sc and not detect_propers(rawtxt):
|
|
|
118 |
else:
|
119 |
# no correction needed
|
120 |
cln_resp = rawtxt.strip()
|
121 |
+
bot_resp_a = corr(remove_repeated_words(cln_resp))
|
122 |
+
bot_resp = fix_punct_spacing(bot_resp_a)
|
123 |
+
print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
|
124 |
+
corr_rt = round(time.perf_counter() - gpt_et, 4)
|
125 |
+
print(
|
126 |
+
f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
|
127 |
+
)
|
128 |
return remove_trailing_punctuation(bot_resp)
|
129 |
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
def get_parser():
|
132 |
"""
|
133 |
get_parser - a helper function for the argparse module
|
|
|
140 |
"--model",
|
141 |
required=False,
|
142 |
type=str,
|
143 |
+
default="pszemraj/Ballpark-Trivia-XL", # default model
|
144 |
+
help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
|
|
|
145 |
)
|
146 |
parser.add_argument(
|
147 |
"--basic-sc",
|
|
|
151 |
help="turn on symspell (baseline) correction instead of the more advanced neural net models",
|
152 |
)
|
153 |
|
154 |
+
parser.add_argument(
|
155 |
+
"--verbose",
|
156 |
+
action="store_true",
|
157 |
+
default=False,
|
158 |
+
help="turn on verbose logging",
|
159 |
+
)
|
160 |
return parser
|
161 |
|
162 |
|
163 |
if __name__ == "__main__":
|
164 |
args = get_parser().parse_args()
|
165 |
default_model = str(args.model)
|
166 |
+
model_loc = Path(default_model) # if the model is a path, use it
|
167 |
+
basic_sc = args.basic_sc # whether to use the baseline spellchecker
|
168 |
+
device = 0 if torch.cuda.is_available() else -1
|
169 |
+
print(f"CUDA avail is {torch.cuda.is_available()}")
|
170 |
+
|
171 |
+
my_chatbot = (
|
172 |
+
pipeline("text-generation", model=model_loc.resolve(), device=device)
|
173 |
+
if model_loc.exists() and model_loc.is_dir()
|
174 |
+
else pipeline("text-generation", model=default_model, device=device)
|
175 |
+
) # if the model is a name, use it. stays on CPU if no GPU available
|
176 |
+
print(f"using model {my_chatbot.model}")
|
177 |
+
|
178 |
if basic_sc:
|
179 |
+
print("Using the baseline spellchecker")
|
180 |
schnellspell = build_symspell_obj()
|
181 |
else:
|
182 |
+
print("using Neuspell spell checker")
|
183 |
ns_checker = load_ns_checker(fast=False)
|
184 |
|
185 |
print(f"using model stored here: \n {model_loc} \n")
|
|
|
214 |
],
|
215 |
title=f"Ballpark Trivia: {default_model} Model",
|
216 |
description=f"Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.",
|
217 |
+
article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-XL).\n\n"
|
218 |
"**Important Notes & About:**\n\n"
|
219 |
"1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
|
220 |
"2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n"
|
|
|
235 |
# prevent_thread_lock=True,
|
236 |
# share=True,
|
237 |
enable_queue=True, # also allows for dealing with multiple users simultaneously (per newer gradio version)
|
238 |
+
)
|
grammar_improve.py
CHANGED
@@ -4,6 +4,7 @@ grammar_improve.py - this .py script contains functions to improve the grammar o
|
|
4 |
"""
|
5 |
|
6 |
from datetime import datetime
|
|
|
7 |
import pprint as pp
|
8 |
from neuspell import BertChecker, SclstmChecker
|
9 |
import neuspell
|
@@ -11,9 +12,11 @@ import math
|
|
11 |
from cleantext import clean
|
12 |
import time
|
13 |
import re
|
14 |
-
|
15 |
from symspellpy.symspellpy import SymSpell
|
16 |
|
|
|
|
|
17 |
|
18 |
def detect_propers(text: str):
|
19 |
"""
|
@@ -98,6 +101,14 @@ def remove_trailing_punctuation(text: str, fuLL_strip=False):
|
|
98 |
return text.strip(".,;:")
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
"""
|
102 |
start of SymSpell code
|
103 |
"""
|
@@ -126,6 +137,11 @@ def symspeller(
|
|
126 |
dictionary_path : str, optional, default=None, the path to the dictionary file
|
127 |
bigram_path : str, optional, default=None, the path to the bigram dictionary file
|
128 |
verbose : bool, optional, default=False, whether to print the results
|
|
|
|
|
|
|
|
|
|
|
129 |
"""
|
130 |
|
131 |
assert len(my_string) > 0, "entered string for correction is empty"
|
@@ -202,7 +218,8 @@ def build_symspell_obj(
|
|
202 |
|
203 |
|
204 |
"""
|
205 |
-
|
|
|
206 |
import torch
|
207 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
208 |
|
@@ -282,19 +299,21 @@ def load_ns_checker(customckr=None, fast=False):
|
|
282 |
[neuspell.NeuSpell]: [neuspell checker object]
|
283 |
"""
|
284 |
st = time.perf_counter()
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
|
|
297 |
rt_min = (time.perf_counter() - st) / 60
|
|
|
298 |
print(f"\n\nloaded checker in {rt_min} minutes")
|
299 |
|
300 |
return checker
|
@@ -320,7 +339,7 @@ def neuspell_correct(input_text: str, checker=None, verbose=False):
|
|
320 |
return input_text
|
321 |
|
322 |
if checker is None:
|
323 |
-
print("NOTE - no checker provided,
|
324 |
checker = SclstmChecker(pretrained=True)
|
325 |
|
326 |
corrected = checker.correct(input_text)
|
|
|
4 |
"""
|
5 |
|
6 |
from datetime import datetime
|
7 |
+
import os
|
8 |
import pprint as pp
|
9 |
from neuspell import BertChecker, SclstmChecker
|
10 |
import neuspell
|
|
|
12 |
from cleantext import clean
|
13 |
import time
|
14 |
import re
|
15 |
+
import sys
|
16 |
from symspellpy.symspellpy import SymSpell
|
17 |
|
18 |
+
from utils import suppress_stdout
|
19 |
+
|
20 |
|
21 |
def detect_propers(text: str):
|
22 |
"""
|
|
|
101 |
return text.strip(".,;:")
|
102 |
|
103 |
|
104 |
+
def fix_punct_spacing(text: str):
|
105 |
+
fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
|
106 |
+
spc_text = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), text)
|
107 |
+
cln_text = re.sub(r"(\W)(?=\1)", "", spc_text)
|
108 |
+
|
109 |
+
return cln_text
|
110 |
+
|
111 |
+
|
112 |
"""
|
113 |
start of SymSpell code
|
114 |
"""
|
|
|
137 |
dictionary_path : str, optional, default=None, the path to the dictionary file
|
138 |
bigram_path : str, optional, default=None, the path to the bigram dictionary file
|
139 |
verbose : bool, optional, default=False, whether to print the results
|
140 |
+
|
141 |
+
Returns
|
142 |
+
-------
|
143 |
+
list,
|
144 |
+
|
145 |
"""
|
146 |
|
147 |
assert len(my_string) > 0, "entered string for correction is empty"
|
|
|
218 |
|
219 |
|
220 |
"""
|
221 |
+
# if using t5b_correction to check for spelling errors, use this code to initialize the objects
|
222 |
+
|
223 |
import torch
|
224 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
225 |
|
|
|
299 |
[neuspell.NeuSpell]: [neuspell checker object]
|
300 |
"""
|
301 |
st = time.perf_counter()
|
302 |
+
# stop all printing to the console
|
303 |
+
with suppress_stdout():
|
304 |
+
if customckr is None and not fast:
|
305 |
+
|
306 |
+
checker = BertChecker(
|
307 |
+
pretrained=True
|
308 |
+
) # load the default checker, has the best balance
|
309 |
+
elif customckr is None and fast:
|
310 |
+
checker = SclstmChecker(
|
311 |
+
pretrained=True
|
312 |
+
) # this one is faster but not as accurate
|
313 |
+
else:
|
314 |
+
checker = customckr(pretrained=True)
|
315 |
rt_min = (time.perf_counter() - st) / 60
|
316 |
+
# return to standard logging level
|
317 |
print(f"\n\nloaded checker in {rt_min} minutes")
|
318 |
|
319 |
return checker
|
|
|
339 |
return input_text
|
340 |
|
341 |
if checker is None:
|
342 |
+
print("NOTE - no checker provided, loading default checker")
|
343 |
checker = SclstmChecker(pretrained=True)
|
344 |
|
345 |
corrected = checker.correct(input_text)
|
requirements.txt
CHANGED
@@ -3,7 +3,7 @@ sentencepiece>=0.1.96
|
|
3 |
tqdm>=4.43.0
|
4 |
symspellpy>=6.7.0
|
5 |
requests>=2.24.0
|
6 |
-
gradio>=2.
|
7 |
natsort>=7.1.1
|
8 |
pandas>=1.3.0
|
9 |
aitextgen>=0.5.2
|
|
|
3 |
tqdm>=4.43.0
|
4 |
symspellpy>=6.7.0
|
5 |
requests>=2.24.0
|
6 |
+
gradio>=2.4.6
|
7 |
natsort>=7.1.1
|
8 |
pandas>=1.3.0
|
9 |
aitextgen>=0.5.2
|
symspell_rsc/frequency_bigramdictionary_en_243_342.txt
DELETED
The diff for this file is too large to render.
See raw diff
|
|
symspell_rsc/frequency_dictionary_en_82_765.txt
DELETED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"""
|
2 |
-
general utility functions for loading, saving, and manipulating data
|
3 |
"""
|
4 |
|
5 |
import os
|
@@ -19,6 +19,25 @@ import pandas as pd
|
|
19 |
from tqdm.auto import tqdm
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
def remove_string_extras(mytext):
|
23 |
# removes everything from a string except A-Za-z0-9 .,;
|
24 |
return re.sub(r"[^A-Za-z0-9 .,;]+", "", mytext)
|
|
|
1 |
"""
|
2 |
+
utils - general utility functions for loading, saving, and manipulating data
|
3 |
"""
|
4 |
|
5 |
import os
|
|
|
19 |
from tqdm.auto import tqdm
|
20 |
|
21 |
|
22 |
+
from contextlib import contextmanager
|
23 |
+
import sys
|
24 |
+
import os
|
25 |
+
|
26 |
+
|
27 |
+
@contextmanager
|
28 |
+
def suppress_stdout():
|
29 |
+
"""
|
30 |
+
suppress_stdout - suppress stdout for a given block of code. credit to https://newbedev.com/how-to-suppress-console-output-in-python
|
31 |
+
"""
|
32 |
+
with open(os.devnull, "w") as devnull:
|
33 |
+
old_stdout = sys.stdout
|
34 |
+
sys.stdout = devnull
|
35 |
+
try:
|
36 |
+
yield
|
37 |
+
finally:
|
38 |
+
sys.stdout = old_stdout
|
39 |
+
|
40 |
+
|
41 |
def remove_string_extras(mytext):
|
42 |
# removes everything from a string except A-Za-z0-9 .,;
|
43 |
return re.sub(r"[^A-Za-z0-9 .,;]+", "", mytext)
|