Spaces:
Runtime error
Runtime error
Commit
·
df5eeb7
1
Parent(s):
935bf6f
Update with h2oGPT hash 23aaa9c9839867b3f0c86e7722cc7fbdae414fc4
Browse files- src/db_utils.py +54 -0
- src/gpt_langchain.py +2 -51
- src/gradio_runner.py +5 -2
src/db_utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
|
| 3 |
+
from enums import LangChainMode
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def set_userid(db1s, requests_state1, get_userid_auth):
|
| 7 |
+
db1 = db1s[LangChainMode.MY_DATA.value]
|
| 8 |
+
assert db1 is not None and len(db1) == length_db1()
|
| 9 |
+
if not db1[1]:
|
| 10 |
+
db1[1] = get_userid_auth(requests_state1)
|
| 11 |
+
if not db1[2]:
|
| 12 |
+
username1 = None
|
| 13 |
+
if 'username' in requests_state1:
|
| 14 |
+
username1 = requests_state1['username']
|
| 15 |
+
db1[2] = username1
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def set_userid_direct(db1s, userid, username):
|
| 19 |
+
db1 = db1s[LangChainMode.MY_DATA.value]
|
| 20 |
+
db1[1] = userid
|
| 21 |
+
db1[2] = username
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_userid_direct(db1s):
|
| 25 |
+
return db1s[LangChainMode.MY_DATA.value][1] if db1s is not None else ''
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_username_direct(db1s):
|
| 29 |
+
return db1s[LangChainMode.MY_DATA.value][2] if db1s is not None else ''
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_dbid(db1):
|
| 33 |
+
return db1[1]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def set_dbid(db1):
|
| 37 |
+
# can only call this after function called so for specific user, not in gr.State() that occurs during app init
|
| 38 |
+
assert db1 is not None and len(db1) == length_db1()
|
| 39 |
+
if db1[1] is None:
|
| 40 |
+
# uuid in db is used as user ID
|
| 41 |
+
db1[1] = str(uuid.uuid4())
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def length_db1():
|
| 45 |
+
# For MyData:
|
| 46 |
+
# 0: db
|
| 47 |
+
# 1: userid and dbid
|
| 48 |
+
# 2: username
|
| 49 |
+
|
| 50 |
+
# For others:
|
| 51 |
+
# 0: db
|
| 52 |
+
# 1: dbid
|
| 53 |
+
# 2: None
|
| 54 |
+
return 3
|
src/gpt_langchain.py
CHANGED
|
@@ -37,6 +37,8 @@ from langchain.tools import PythonREPLTool
|
|
| 37 |
from langchain.tools.json.tool import JsonSpec
|
| 38 |
from tqdm import tqdm
|
| 39 |
|
|
|
|
|
|
|
| 40 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
| 41 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
| 42 |
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
|
|
@@ -4655,57 +4657,6 @@ def get_sources_answer(query, docs, answer, scores, show_rank,
|
|
| 4655 |
return ret, extra
|
| 4656 |
|
| 4657 |
|
| 4658 |
-
def set_userid(db1s, requests_state1, get_userid_auth):
|
| 4659 |
-
db1 = db1s[LangChainMode.MY_DATA.value]
|
| 4660 |
-
assert db1 is not None and len(db1) == length_db1()
|
| 4661 |
-
if not db1[1]:
|
| 4662 |
-
db1[1] = get_userid_auth(requests_state1)
|
| 4663 |
-
if not db1[2]:
|
| 4664 |
-
username1 = None
|
| 4665 |
-
if 'username' in requests_state1:
|
| 4666 |
-
username1 = requests_state1['username']
|
| 4667 |
-
db1[2] = username1
|
| 4668 |
-
|
| 4669 |
-
|
| 4670 |
-
def set_userid_direct(db1s, userid, username):
|
| 4671 |
-
db1 = db1s[LangChainMode.MY_DATA.value]
|
| 4672 |
-
db1[1] = userid
|
| 4673 |
-
db1[2] = username
|
| 4674 |
-
|
| 4675 |
-
|
| 4676 |
-
def get_userid_direct(db1s):
|
| 4677 |
-
return db1s[LangChainMode.MY_DATA.value][1] if db1s is not None else ''
|
| 4678 |
-
|
| 4679 |
-
|
| 4680 |
-
def get_username_direct(db1s):
|
| 4681 |
-
return db1s[LangChainMode.MY_DATA.value][2] if db1s is not None else ''
|
| 4682 |
-
|
| 4683 |
-
|
| 4684 |
-
def get_dbid(db1):
|
| 4685 |
-
return db1[1]
|
| 4686 |
-
|
| 4687 |
-
|
| 4688 |
-
def set_dbid(db1):
|
| 4689 |
-
# can only call this after function called so for specific user, not in gr.State() that occurs during app init
|
| 4690 |
-
assert db1 is not None and len(db1) == length_db1()
|
| 4691 |
-
if db1[1] is None:
|
| 4692 |
-
# uuid in db is used as user ID
|
| 4693 |
-
db1[1] = str(uuid.uuid4())
|
| 4694 |
-
|
| 4695 |
-
|
| 4696 |
-
def length_db1():
|
| 4697 |
-
# For MyData:
|
| 4698 |
-
# 0: db
|
| 4699 |
-
# 1: userid and dbid
|
| 4700 |
-
# 2: username
|
| 4701 |
-
|
| 4702 |
-
# For others:
|
| 4703 |
-
# 0: db
|
| 4704 |
-
# 1: dbid
|
| 4705 |
-
# 2: None
|
| 4706 |
-
return 3
|
| 4707 |
-
|
| 4708 |
-
|
| 4709 |
def get_any_db(db1s, langchain_mode, langchain_mode_paths, langchain_mode_types,
|
| 4710 |
dbs=None,
|
| 4711 |
load_db_if_exists=None, db_type=None,
|
|
|
|
| 37 |
from langchain.tools.json.tool import JsonSpec
|
| 38 |
from tqdm import tqdm
|
| 39 |
|
| 40 |
+
from src.db_utils import length_db1, set_dbid, set_userid, get_dbid, get_userid_direct, get_username_direct, \
|
| 41 |
+
set_userid_direct
|
| 42 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
| 43 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
| 44 |
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
|
|
|
|
| 4657 |
return ret, extra
|
| 4658 |
|
| 4659 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4660 |
def get_any_db(db1s, langchain_mode, langchain_mode_paths, langchain_mode_types,
|
| 4661 |
dbs=None,
|
| 4662 |
load_db_if_exists=None, db_type=None,
|
src/gradio_runner.py
CHANGED
|
@@ -20,6 +20,7 @@ from iterators import TimeoutIterator
|
|
| 20 |
|
| 21 |
from gradio_utils.css import get_css
|
| 22 |
from gradio_utils.prompt_form import make_chatbots
|
|
|
|
| 23 |
|
| 24 |
# This is a hack to prevent Gradio from phoning home when it gets imported
|
| 25 |
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
|
@@ -459,7 +460,6 @@ def go_gradio(**kwargs):
|
|
| 459 |
if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
|
| 460 |
requests_state1.update(dict(host2=request.client.host))
|
| 461 |
if not requests_state1.get('username', '') and hasattr(request, 'username'):
|
| 462 |
-
from src.gpt_langchain import get_username_direct
|
| 463 |
# use already-defined username instead of keep changing to new uuid
|
| 464 |
# should be same as in requests_state1
|
| 465 |
db_username = get_username_direct(db1s)
|
|
@@ -469,7 +469,6 @@ def go_gradio(**kwargs):
|
|
| 469 |
|
| 470 |
def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
|
| 471 |
requests_state1 = get_request_state(requests_state1, request, db1s)
|
| 472 |
-
from src.gpt_langchain import set_userid
|
| 473 |
set_userid(db1s, requests_state1, get_userid_auth)
|
| 474 |
args_list = [db1s, requests_state1] + list(args)
|
| 475 |
return tuple(args_list)
|
|
@@ -500,6 +499,8 @@ def go_gradio(**kwargs):
|
|
| 500 |
inference_server=kwargs['inference_server'],
|
| 501 |
prompt_type=kwargs['prompt_type'],
|
| 502 |
prompt_dict=kwargs['prompt_dict'],
|
|
|
|
|
|
|
| 503 |
)
|
| 504 |
)
|
| 505 |
|
|
@@ -3746,6 +3747,8 @@ def go_gradio(**kwargs):
|
|
| 3746 |
base_model=model_name, tokenizer_base_model=tokenizer_base_model,
|
| 3747 |
lora_weights=lora_weights, inference_server=server_name,
|
| 3748 |
prompt_type=prompt_type1, prompt_dict=prompt_dict1,
|
|
|
|
|
|
|
| 3749 |
)
|
| 3750 |
|
| 3751 |
max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs)
|
|
|
|
| 20 |
|
| 21 |
from gradio_utils.css import get_css
|
| 22 |
from gradio_utils.prompt_form import make_chatbots
|
| 23 |
+
from src.db_utils import set_userid, get_username_direct
|
| 24 |
|
| 25 |
# This is a hack to prevent Gradio from phoning home when it gets imported
|
| 26 |
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
|
|
|
| 460 |
if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
|
| 461 |
requests_state1.update(dict(host2=request.client.host))
|
| 462 |
if not requests_state1.get('username', '') and hasattr(request, 'username'):
|
|
|
|
| 463 |
# use already-defined username instead of keep changing to new uuid
|
| 464 |
# should be same as in requests_state1
|
| 465 |
db_username = get_username_direct(db1s)
|
|
|
|
| 469 |
|
| 470 |
def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
|
| 471 |
requests_state1 = get_request_state(requests_state1, request, db1s)
|
|
|
|
| 472 |
set_userid(db1s, requests_state1, get_userid_auth)
|
| 473 |
args_list = [db1s, requests_state1] + list(args)
|
| 474 |
return tuple(args_list)
|
|
|
|
| 499 |
inference_server=kwargs['inference_server'],
|
| 500 |
prompt_type=kwargs['prompt_type'],
|
| 501 |
prompt_dict=kwargs['prompt_dict'],
|
| 502 |
+
visible_models=kwargs['visible_models'],
|
| 503 |
+
h2ogpt_key=kwargs['h2ogpt_key'],
|
| 504 |
)
|
| 505 |
)
|
| 506 |
|
|
|
|
| 3747 |
base_model=model_name, tokenizer_base_model=tokenizer_base_model,
|
| 3748 |
lora_weights=lora_weights, inference_server=server_name,
|
| 3749 |
prompt_type=prompt_type1, prompt_dict=prompt_dict1,
|
| 3750 |
+
# FIXME: not typically required, unless want to expose adding h2ogpt endpoint in UI
|
| 3751 |
+
visible_models=None, h2ogpt_key=None,
|
| 3752 |
)
|
| 3753 |
|
| 3754 |
max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs)
|