Peter Szemraj committed on
Commit f28496d
1 Parent(s): 31f09ae

:sparkles: upgrade to v2 code

ai_single_response.py DELETED
@@ -1,324 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- ai_single_response.py
-
- An executable way to call the model. Example:
- *\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time
-
- query_gpt_model is used throughout the code and is the fundamental building block of the bot. Test this function with a few different models.
-
- """
- from aitextgen import aitextgen
- import argparse
- import pprint as pp
- import sys
- import time
- import warnings
- from datetime import datetime
- from pathlib import Path
- from grammar_improve import remove_trailing_punctuation
- from utils import print_spacer, cleantxt_wrap
-
- warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
-
-
- def extract_response(full_resp: list, plist: list, verbose: bool = False):
-     """
-     extract_response - helper for ai_single_response.py. By default aitextgen returns the prompt and the response; we only want the response.
-
-     Args:
-         full_resp (list): a list of strings, each string is a response
-         plist (list): a list of strings, each string is a prompt
-         verbose (bool, optional): for debugging. Defaults to False.
-     """
-     full_resp = [cleantxt_wrap(ele) for ele in full_resp]
-     plist = [cleantxt_wrap(pr) for pr in plist]
-     p_len = len(plist)
-     assert (
-         len(full_resp) >= p_len
-     ), "model output should have at least as many lines as the input."
-
-     if set(plist).issubset(full_resp):
-         del full_resp[:p_len]  # remove the prompts from the responses
-     else:
-         print("the isolated responses are:\n")
-         pp.pprint(full_resp)
-         print_spacer()
-         print("the input prompt was:\n")
-         pp.pprint(plist)
-         print_spacer()
-         sys.exit("Exiting: some prompts not found in the responses")
-     if verbose:
-         print("the isolated responses are:\n")
-         pp.pprint(full_resp)
-         print_spacer()
-         print("the input prompt was:\n")
-         pp.pprint(plist)
-         print_spacer()
-     return full_resp  # list of only the model-generated responses
-
-
- def get_bot_response(
-     name_resp: str, model_resp: str, name_spk: str, verbose: bool = False
- ):
-     """
-     get_bot_response - extract the bot response from the model output. This is needed because, depending on the generation length, the model may return more than one response.
-
-     Args:
-         name_resp (str): the name of the responder
-         model_resp (str): the model response
-         name_spk (str): the name of the speaker
-         verbose (bool, optional): for debugging. Defaults to False.
-
-     Returns:
-         fn_resp (list of str)
-     """
-
-     fn_resp = []
-
-     name_counter = 0
-     break_safe = False
-     for resline in model_resp:
-         if resline.startswith(name_resp):
-             name_counter += 1
-             break_safe = True  # the line is from the bot, as it starts with the bot's name
-             continue
-         if name_spk is not None and name_spk.lower() in resline.lower():
-             break
-         if ":" in resline and name_counter > 0:
-             if break_safe:
-                 # we know this is a response from the bot even though ':' is in the line
-                 fn_resp.append(resline)
-                 break_safe = False
-             else:
-                 # this could be the name of another person, so the bot's response is finished
-                 break
-         else:
-             fn_resp.append(resline)
-             break_safe = False
-     if verbose:
-         print("the full response is:\n")
-         print("\n".join(fn_resp))
-
-     return fn_resp
-
-
- def query_gpt_model(
-     prompt_msg: str,
-     speaker=None,
-     responder=None,
-     resp_length=64,
-     resp_min=10,
-     kparam=50,
-     temp=0.75,
-     top_p=0.90,
-     batch_size=1,
-     verbose=False,
-     use_gpu=False,
-     nbeams=1,
- ):
-     """
-     query_gpt_model - the main function that calls the model.
-
-     Parameters:
-     -----------
-     prompt_msg (str): the prompt to be sent to the model
-     speaker (str, optional): the name of the speaker. Defaults to None.
-     responder (str, optional): the name of the responder. Defaults to None.
-     resp_length (int, optional): the maximum length of the response. Defaults to 64.
-     resp_min (int, optional): the minimum length of the response. Defaults to 10.
-     kparam (int, optional): the top-k sampling parameter. Defaults to 50.
-     temp (float, optional): the sampling temperature. Defaults to 0.75.
-     top_p (float, optional): the nucleus sampling parameter. Defaults to 0.90.
-     batch_size (int, optional): the generation batch size. Defaults to 1.
-     verbose (bool, optional): for debugging. Defaults to False.
-     use_gpu (bool, optional): use the GPU. Defaults to False.
-     nbeams (int, optional): the number of beams to search, returning the best. Defaults to 1.
-     """
-
-     ai = aitextgen(
-         model="pszemraj/Ballpark-Trivia-L",  # THIS WORKS. XL is not working
-         to_gpu=use_gpu,
-     )
-
-     p_list = []  # track conversation
-     p_list.append(speaker.lower() + ":" + "\n")
-     p_list.append(prompt_msg.lower() + "\n")
-     p_list.append("\n")
-     p_list.append(responder.lower() + ":" + "\n")
-     this_prompt = "".join(p_list)
-     pr_len = len(this_prompt)
-     if verbose:
-         print("overall prompt:\n")
-         pp.pprint(this_prompt, indent=4)
-     # call the model
-     print("\n... generating...")
-     this_result = ai.generate(
-         n=1,
-         batch_size=batch_size,
-         # the prompt input counts toward the text-length constraints
-         max_length=resp_length + pr_len,
-         min_length=resp_min + pr_len,
-         prompt=this_prompt,
-         top_k=kparam,
-         top_p=top_p,
-         do_sample=True,
-         return_as_list=True,
-         n_beams=nbeams,
-         temperature=temp,
-         verbose=True,  # here verbose just enables huggingface logging
-         use_cache=True,
-     )
-     if verbose:
-         print("\n... generated:\n")
-         pp.pprint(this_result)  # for debugging
-     # process the full result to get the ~bot response~ piece
-     this_result = str(this_result[0]).split(
-         "\n"
-     )  # TODO: adjust the hardcoded index to be dynamic (if n > 1)
-     og_res = this_result.copy()
-     og_prompt = p_list.copy()
-     diff_list = extract_response(
-         this_result, p_list, verbose=verbose
-     )  # isolate the responses from the prompts
-     # extract the bot response from the model-generated text
-     bot_dialogue = get_bot_response(
-         name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
-     )
-     print(f"DEBUG: {bot_dialogue} was the original response pre-correction")
-     bot_resp = ", ".join(bot_dialogue)
-     bot_resp = bot_resp.strip()
-     # remove trailing ',' and '.' chars
-     bot_resp = remove_trailing_punctuation(bot_resp)
-     if verbose:
-         print("\n... bot response:\n")
-         pp.pprint(bot_resp)
-     og_prompt.append(bot_resp + "\n")
-     og_prompt.append("\n")
-
-     print("\nfinished!")
-
-     # return the bot response and the full conversation
-     return {"out_text": bot_resp, "full_conv": og_prompt}  # model responses
-
-
- # Set up the parsing of command-line arguments
- def get_parser():
-     """
-     get_parser - a helper function for the argparse module, relevant if this is run as a script.
-     """
-
-     parser = argparse.ArgumentParser(
-         description="submit a message and have a pretrained GPT model respond"
-     )
-     parser.add_argument(
-         "--prompt",
-         required=True,  # MUST HAVE A PROMPT
-         type=str,
-         help="the message the bot is supposed to respond to. The prompt is said by the speaker and answered by the responder.",
-     )
-     parser.add_argument(
-         "--model",
-         required=False,
-         type=str,
-         default="GPT2_trivNatQAdailydia_774M_175Ksteps",
-         help="folder - relative to the git directory of your repo - that has the model files in it (pytorch.bin + "
-         "config.json). No models? Run the script download_models.py",
-     )
-
-     parser.add_argument(
-         "--speaker",
-         required=False,
-         default=None,
-         help="who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
-     )
-     parser.add_argument(
-         "--responder",
-         required=False,
-         default="person beta",
-         help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
-     )
-
-     parser.add_argument(
-         "--topk",
-         required=False,
-         type=int,
-         default=150,
-         help="how many tokens to sample from at each step (positive integer). lower = less random responses",
-     )
-
-     parser.add_argument(
-         "--temp",
-         required=False,
-         type=float,
-         default=0.75,
-         help="the temperature hyperparameter (0-1), roughly considered as 'model creativity'",
-     )
-
-     parser.add_argument(
-         "--topp",
-         required=False,
-         type=float,
-         default=0.65,
-         help="nucleus sampling fraction (0-1): what fraction of possible options is considered?",
-     )
-
-     parser.add_argument(
-         "--verbose",
-         default=False,
-         action="store_true",
-         help="pass this argument if you want all the printouts",
-     )
-     parser.add_argument(
-         "--time",
-         default=False,
-         action="store_true",
-         help="pass this argument if you want to know the runtime",
-     )
-     return parser
-
-
- if __name__ == "__main__":
-     # parse the command-line arguments
-     args = get_parser().parse_args()
-     query = args.prompt
-     model_dir = str(args.model)
-     model_loc = Path.cwd() / model_dir  # NOTE: currently unused; query_gpt_model loads a fixed hub model
-     spkr = args.speaker
-     rspndr = args.responder
-     k_results = args.topk
-     my_temp = args.temp
-     my_top_p = args.topp
-     want_verbose = args.verbose
-     want_rt = args.time
-
-     st = time.perf_counter()
-
-     resp = query_gpt_model(
-         prompt_msg=query,
-         speaker=spkr,
-         responder=rspndr,
-         kparam=k_results,
-         temp=my_temp,
-         top_p=my_top_p,
-         verbose=want_verbose,
-         use_gpu=False,
-     )
-
-     output = resp["out_text"]
-     pp.pprint(output, indent=4)
-
-     rt = round(time.perf_counter() - st, 1)
-
-     if want_rt:
-         print("took {runtime} seconds to generate. \n".format(runtime=rt))
-
-     if want_verbose:
-         print("finished - ", datetime.now())
-         p_list = resp["full_conv"]
-         print("A transcript of your chat is as follows: \n")
-         p_list = [item.strip() for item in p_list]
-         pp.pprint(p_list)
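For reference, the deleted module exposed query_gpt_model as the single entry point that v2 replaces. A minimal sketch of how it was invoked (assuming the v1 module and its aitextgen dependency are still available; the example prompt text is hypothetical):

from ai_single_response import query_gpt_model

# query_gpt_model returns a dict with "out_text" and "full_conv" keys (see the deleted code above)
resp = query_gpt_model(
    prompt_msg="what is the capital of france?",  # hypothetical example prompt
    speaker="person alpha",
    responder="person beta",
)
print(resp["out_text"])   # the isolated bot reply
print(resp["full_conv"])  # the running transcript, as a list of lines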
app.py CHANGED
@@ -1,19 +1,9 @@
  """
- deploy-as-bot\gradio_chatbot.py
- A system / method for deploying to Gradio. Gradio is a basic "deploy" interface that lets other users test your model from a web URL. It also enables some basic functionality, like user flagging of weird responses.
+ app.py - the main file for the app. This creates the Gradio app and its interface.
+
  """
- from flask import (
-     Flask,
-     request,
-     session,
-     jsonify,
-     abort,
-     send_file,
-     render_template,
-     redirect,
- )
- from ai_single_response import query_gpt_model
- from datetime import datetime
+
+ import torch
  from transformers import pipeline
  from cleantext import clean
  from pathlib import Path
@@ -26,7 +16,7 @@ import os
  import sys
  from os.path import dirname
  import nltk
- from aitextgen import aitextgen
+ from converse import discussion
  from grammar_improve import (
      detect_propers,
      load_ns_checker,
@@ -35,6 +25,7 @@ from grammar_improve import (
      remove_trailing_punctuation,
      build_symspell_obj,
      symspeller,
+     fix_punct_spacing,
  )

  from utils import (
@@ -46,46 +37,78 @@ nltk.download("stopwords")  # TODO: find where this requirement originates from

  sys.path.append(dirname(dirname(os.path.abspath(__file__))))
  warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
+ import transformers
+
+ transformers.logging.set_verbosity_error()
  logging.basicConfig()
  cwd = Path.cwd()
  my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects


- def load_model(model_name=None, use_gpu=False):
-     _name = "pszemraj/Ballpark-Trivia-L" if model_name is None else model_name
-     print(f"\nloading model: {_name}\n")
-     ai = aitextgen(
-         model=_name,
-         to_gpu=use_gpu,
-     )
-
-
- def ask_gpt(message: str):
+ def chat(trivia_query):
+     """
+     chat - take a trivia query and return the bot's response
+     """
+     history = []
+     response = ask_gpt(trivia_query)
+     history = [trivia_query, response]
+     html = ""
+     for item in history:
+         html += f"<b>{item}</b> <br><br>"
+
+     html += ""
+
+     return html
+
+
+ def ask_gpt(
+     message: str,
+     chat_pipe,
+     speaker="person alpha",
+     responder="person beta",
+     max_len=196,
+     top_p=0.95,
+     top_k=50,
+     temperature=0.6,
+ ):
      """
-     ask_gpt - queries the relevant model with a prompt message and returns the response.
-     NOTE: because this is for models trained with person alpha and person beta,
-     there is no need to customize / change the name settings and so on
-     Args:
-         message (str): prompt message to respond to, usually a question
-     Returns:
-         [str]: [model response as a string]
+     ask_gpt - take a prompt and generate a response using the pipeline. This wraps the discussion function.
+
+     Parameters:
+         message (str): the question to ask the bot
+         chat_pipe: the transformers pipeline to use for the bot
+         speaker (str): the name of the speaker (default: "person alpha")
+         responder (str): the name of the responder (default: "person beta")
+         max_len (int): the maximum length of the response (default: 196)
+         top_p (float): the nucleus sampling threshold (default: 0.95)
+         top_k (int): the top-k sampling threshold (default: 50)
+         temperature (float): the sampling temperature (default: 0.6)
      """
+
      st = time.perf_counter()
      prompt = clean(message)  # clean user input
      prompt = prompt.strip()  # get rid of any extra whitespace
-     if len(prompt) > 200:
-         prompt = prompt[-200:]  # truncate
-
-     resp = query_gpt_model(
-         prompt_msg=prompt,
-         speaker="person alpha",
-         responder="person beta",
-         kparam=30,
-         top_p=0.9,
-         batch_size=1,
-         nbeams=1,
-         # TODO - allow users to adjust these
-     )  # using top_p and top_k to avoid the "too many hypotheses" error; temperature unused
+     in_len = len(prompt)
+     if in_len > 512:
+         prompt = prompt[-512:]  # truncate to the last 512 chars
+         print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
+     max_len = min(max_len, 512)
+
+     resp = discussion(
+         prompt_text=prompt,
+         pipeline=chat_pipe,
+         speaker=speaker,
+         responder=responder,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         max_length=max_len,
+     )
+     gpt_et = time.perf_counter()
+     gpt_rt = round(gpt_et - st, 2)
      rawtxt = resp["out_text"]
      # check for proper nouns
      if basic_sc and not detect_propers(rawtxt):
@@ -95,26 +118,16 @@ def ask_gpt(message: str):
      else:
          # no correction needed
          cln_resp = rawtxt.strip()
-     bot_resp = corr(remove_repeated_words(cln_resp))
-     print(f"the prompt was:\n {message} and the response was:\n {bot_resp}\n")
-     rt = round(time.perf_counter() - st, 2)
-     print(f"took {rt} sec to respond\n")
+     bot_resp_a = corr(remove_repeated_words(cln_resp))
+     bot_resp = fix_punct_spacing(bot_resp_a)
+     print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
+     corr_rt = round(time.perf_counter() - gpt_et, 4)
+     print(
+         f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
+     )
      return remove_trailing_punctuation(bot_resp)


- def chat(trivia_query):
-     history = []
-     response = ask_gpt(trivia_query)
-     history = [trivia_query, response]
-     html = ""
-     for item in history:
-         html += f"<b>{item}</b> <br><br>"
-
-     html += ""
-
-     return html
-
-
  def get_parser():
      """
      get_parser - a helper function for the argparse module
@@ -127,9 +140,8 @@ def get_parser():
          "--model",
          required=False,
          type=str,
-         default="pszemraj/Ballpark-Trivia-L",
-         help="folder - relative to the git directory of your repo - that has the model files in it (pytorch.bin + "
-         "config.json)",
+         default="pszemraj/Ballpark-Trivia-XL",  # default model
+         help="the model to use for the chatbot: a name on https://huggingface.co/models OR a path to a local model",
      )
      parser.add_argument(
          "--basic-sc",
@@ -139,21 +151,35 @@ def get_parser():
          help="turn on symspell (baseline) correction instead of the more advanced neural net models",
      )

+     parser.add_argument(
+         "--verbose",
+         action="store_true",
+         default=False,
+         help="turn on verbose logging",
+     )
      return parser


  if __name__ == "__main__":
      args = get_parser().parse_args()
      default_model = str(args.model)
-     load_model(default_model)
-     model_loc = cwd.parent / default_model
-     model_loc = str(model_loc.resolve())
-     basic_sc = args.basic_sc
+     model_loc = Path(default_model)  # in case the model is a local path
+     basic_sc = args.basic_sc  # whether to use the baseline spellchecker
+     device = 0 if torch.cuda.is_available() else -1
+     print(f"CUDA avail is {torch.cuda.is_available()}")
+
+     my_chatbot = (
+         pipeline("text-generation", model=model_loc.resolve(), device=device)
+         if model_loc.exists() and model_loc.is_dir()
+         else pipeline("text-generation", model=default_model, device=device)
+     )  # load from the local dir if it exists; otherwise treat the name as a hub model. stays on CPU if no GPU is available
+     print(f"using model {my_chatbot.model}")
+
      if basic_sc:
-         print("defaulting to symspell for spell checking")
+         print("Using the baseline spellchecker")
          schnellspell = build_symspell_obj()
      else:
-         print("using advanced spell checker (Neuspell)")
+         print("using Neuspell spell checker")
          ns_checker = load_ns_checker(fast=False)

      print(f"using model stored here: \n {model_loc} \n")
@@ -188,7 +214,7 @@ if __name__ == "__main__":
          ],
          title=f"Ballpark Trivia: {default_model} Model",
          description=f"Are you frequently asked google-able Trivia questions and annoyed by it? Well, this is the app for you! Ballpark Trivia Bot answers any trivia question with something that sounds plausible but is probably not 100% correct. \n\n One might say.. the answers are in the right ballpark.",
-         article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-L). If you are interested in a more deceptively incorrect model, there is also [an XL version](https://huggingface.co/pszemraj/Ballpark-Trivia-XL) on my page.\n\n"
+         article="Further details can be found in the [model card](https://huggingface.co/pszemraj/Ballpark-Trivia-XL).\n\n"
          "**Important Notes & About:**\n\n"
          "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
          "2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n"
@@ -209,4 +235,4 @@ if __name__ == "__main__":
          # prevent_thread_lock=True,
          # share=True,
          enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio versions)
-     )
+     )
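The v2 flow replaces aitextgen + query_gpt_model with a transformers pipeline handed to converse.discussion. The converse module is added elsewhere in this commit, so its internals are assumed here; per ask_gpt above, discussion accepts these keyword arguments and returns a dict with an "out_text" key. A minimal sketch of the new call path:

import torch
from transformers import pipeline
from converse import discussion  # added in v2; interface assumed from ask_gpt above

device = 0 if torch.cuda.is_available() else -1  # GPU if available, else CPU
chat_pipe = pipeline("text-generation", model="pszemraj/Ballpark-Trivia-XL", device=device)

resp = discussion(
    prompt_text="what is the capital of france?",  # hypothetical example prompt
    pipeline=chat_pipe,
    speaker="person alpha",
    responder="person beta",
    top_p=0.95,
    top_k=50,
    temperature=0.6,
    max_length=196,
)
print(resp["out_text"])  # the bot's reply, per the "out_text" key used in ask_gpt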
grammar_improve.py CHANGED
@@ -4,6 +4,7 @@ grammar_improve.py - this .py script contains functions to improve the grammar o
  """

  from datetime import datetime
+ import os
  import pprint as pp
  from neuspell import BertChecker, SclstmChecker
  import neuspell
@@ -11,9 +12,11 @@ import math
  from cleantext import clean
  import time
  import re
-
+ import sys
  from symspellpy.symspellpy import SymSpell

+ from utils import suppress_stdout
+

  def detect_propers(text: str):
      """
@@ -98,6 +101,14 @@ def remove_trailing_punctuation(text: str, fuLL_strip=False):
      return text.strip(".,;:")


+ def fix_punct_spacing(text: str):
+     fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*")
+     spc_text = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), text)
+     cln_text = re.sub(r"(\W)(?=\1)", "", spc_text)
+
+     return cln_text
+
+
  """
  start of SymSpell code
  """
@@ -126,6 +137,11 @@ def symspeller(
      dictionary_path : str, optional, default=None, the path to the dictionary file
      bigram_path : str, optional, default=None, the path to the bigram dictionary file
      verbose : bool, optional, default=False, whether to print the results
+
+     Returns
+     -------
+     list,
+
      """

      assert len(my_string) > 0, "entered string for correction is empty"
@@ -202,7 +218,8 @@ def build_symspell_obj(


  """
- NEEDED FOR T5
+ # if using t5b_correction to check for spelling errors, use this code to initialize the objects
+
  import torch
  from transformers import T5Tokenizer, T5ForConditionalGeneration

@@ -282,19 +299,21 @@ def load_ns_checker(customckr=None, fast=False):
      [neuspell.NeuSpell]: [neuspell checker object]
      """
      st = time.perf_counter()
-
-     if customckr is None and not fast:
-
-         checker = BertChecker(
-             pretrained=True
-         )  # load the default checker, has the best balance
-     elif customckr is None and fast:
-         checker = SclstmChecker(
-             pretrained=True
-         )  # this one is faster but not as accurate
-     else:
-         checker = customckr(pretrained=True)
+     # suppress all printing to the console while the checker loads
+     with suppress_stdout():
+         if customckr is None and not fast:
+             checker = BertChecker(
+                 pretrained=True
+             )  # load the default checker, has the best balance
+         elif customckr is None and fast:
+             checker = SclstmChecker(
+                 pretrained=True
+             )  # this one is faster but not as accurate
+         else:
+             checker = customckr(pretrained=True)
      rt_min = (time.perf_counter() - st) / 60
+     # stdout is restored once the with-block exits
      print(f"\n\nloaded checker in {rt_min} minutes")

      return checker
@@ -320,7 +339,7 @@ def neuspell_correct(input_text: str, checker=None, verbose=False):
          return input_text

      if checker is None:
-         print("NOTE - no checker provided, using default checker")
+         print("NOTE - no checker provided, loading default checker")
          checker = SclstmChecker(pretrained=True)

      corrected = checker.correct(input_text)
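The new fix_punct_spacing helper works in two passes: the first regex glues each run of [?!.,] punctuation to the preceding word (stripping the spaces before and inside the run, then appending a single trailing space), and the second regex collapses consecutive repeats of the same non-word character. A quick illustration of the expected behavior:

from grammar_improve import fix_punct_spacing

print(repr(fix_punct_spacing("hey , how are you ? ?")))
# expected: 'hey, how are you? '  (note: one trailing space survives the substitution)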
requirements.txt CHANGED
@@ -3,7 +3,7 @@ sentencepiece>=0.1.96
  tqdm>=4.43.0
  symspellpy>=6.7.0
  requests>=2.24.0
- gradio>=2.5.0
+ gradio>=2.4.6
  natsort>=7.1.1
  pandas>=1.3.0
  aitextgen>=0.5.2
symspell_rsc/frequency_bigramdictionary_en_243_342.txt DELETED
The diff for this file is too large to render. See raw diff
 
symspell_rsc/frequency_dictionary_en_82_765.txt DELETED
The diff for this file is too large to render. See raw diff
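These two files match the frequency dictionaries that symspellpy (pinned above as symspellpy>=6.7.0) ships as package data, which is presumably why the local copies could be dropped. A sketch of loading the bundled copies instead of the deleted files, following symspellpy's documented usage:

import pkg_resources
from symspellpy import SymSpell

sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
# these filenames match the deleted symspell_rsc/ files; symspellpy bundles the same data
dictionary_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_dictionary_en_82_765.txt"
)
bigram_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_bigramdictionary_en_243_342.txt"
)
sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)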
 
utils.py CHANGED
@@ -1,5 +1,5 @@
  """
- general utility functions for loading, saving, and manipulating data
+ utils - general utility functions for loading, saving, and manipulating data
  """

  import os
@@ -19,6 +19,25 @@ import pandas as pd
  from tqdm.auto import tqdm


+ from contextlib import contextmanager
+ import sys
+ import os
+
+
+ @contextmanager
+ def suppress_stdout():
+     """
+     suppress_stdout - suppress stdout for a given block of code. credit to https://newbedev.com/how-to-suppress-console-output-in-python
+     """
+     with open(os.devnull, "w") as devnull:
+         old_stdout = sys.stdout
+         sys.stdout = devnull
+         try:
+             yield
+         finally:
+             sys.stdout = old_stdout
+
+
  def remove_string_extras(mytext):
      # removes everything from a string except A-Za-z0-9 .,;
      return re.sub(r"[^A-Za-z0-9 .,;]+", "", mytext)
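The new suppress_stdout context manager is what load_ns_checker in grammar_improve.py uses to hide neuspell's load-time chatter. A minimal usage sketch:

from utils import suppress_stdout

with suppress_stdout():
    print("swallowed - this goes to os.devnull")
print("printed normally again")

Note that only sys.stdout is redirected, so warnings and log messages written to stderr still appear.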