bstraehle commited on
Commit
6a5cc80
·
verified ·
1 Parent(s): c58f54a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -16
app.py CHANGED
@@ -3,7 +3,14 @@
3
  import gradio as gr
4
  import logging, os, sys, threading
5
 
6
- from custom_utils import connect_to_database, rag_ingestion, rag_retrieval, rag_inference
 
 
 
 
 
 
 
7
 
8
  lock = threading.Lock()
9
 
@@ -31,22 +38,42 @@ def invoke(openai_api_key,
31
  with lock:
32
  db, collection = connect_to_database()
33
 
 
 
34
  if (RAG_INGESTION):
35
  return rag_ingestion(collection)
36
- else:
37
- retrieval_result = rag_retrieval(openai_api_key,
38
- prompt,
39
- accomodates,
40
- bedrooms,
41
- db,
42
- collection)
43
- inference_result = rag_inference(openai_api_key,
44
- prompt,
45
- retrieval_result)
46
- print("###")
47
- print(inference_result)
48
- print("###")
49
- return inference_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  gr.close_all()
52
 
@@ -56,7 +83,7 @@ demo = gr.Interface(
56
  gr.Textbox(label = "Prompt", value = os.environ["PROMPT"], lines = 1),
57
  gr.Number(label = "Accomodates", value = 2),
58
  gr.Number(label = "Bedrooms", value = 1),
59
- gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_NAIVE)],
60
  outputs = [gr.Markdown(label = "Completion", value = os.environ["COMPLETION"], line_breaks = True, sanitize_html = False)],
61
  title = "Context-Aware Reasoning Application",
62
  description = os.environ["DESCRIPTION"]
 
3
  import gradio as gr
4
  import logging, os, sys, threading
5
 
6
+ from custom_utils import (
7
+ connect_to_database,
8
+ inference,
9
+ rag_ingestion,
10
+ rag_retrieval_naive,
11
+ rag_retrieval_advanced,
12
+ rag_inference
13
+ )
14
 
15
  lock = threading.Lock()
16
 
 
38
  with lock:
39
  db, collection = connect_to_database()
40
 
41
+ inference_result = ""
42
+
43
  if (RAG_INGESTION):
44
  return rag_ingestion(collection)
45
+ elif rag_option == RAGOFF:
46
+ inference_result = inference(
47
+ openai_api_key,
48
+ prompt)
49
+ elif rag_option == RAG_NAIVE:
50
+ retrieval_result = rag_retrieval_naive(
51
+ openai_api_key,
52
+ prompt,
53
+ db,
54
+ collection)
55
+ inference_result = rag_inference(
56
+ openai_api_key,
57
+ prompt,
58
+ retrieval_result)
59
+ elif rag_option == RAG_ADVANCED:
60
+ retrieval_result = rag_retrieval_advanced(
61
+ openai_api_key,
62
+ prompt,
63
+ accomodates,
64
+ bedrooms,
65
+ db,
66
+ collection)
67
+ inference_result = rag_inference(
68
+ openai_api_key,
69
+ prompt,
70
+ retrieval_result)
71
+
72
+ #print("###")
73
+ #print(inference_result)
74
+ #print("###")
75
+
76
+ return inference_result
77
 
78
  gr.close_all()
79
 
 
83
  gr.Textbox(label = "Prompt", value = os.environ["PROMPT"], lines = 1),
84
  gr.Number(label = "Accomodates", value = 2),
85
  gr.Number(label = "Bedrooms", value = 1),
86
+ gr.Radio([RAG_OFF, RAG_NAIVE, RAG_ADVANCED], label = "Retrieval-Augmented Generation", value = RAG_ADVANCED)],
87
  outputs = [gr.Markdown(label = "Completion", value = os.environ["COMPLETION"], line_breaks = True, sanitize_html = False)],
88
  title = "Context-Aware Reasoning Application",
89
  description = os.environ["DESCRIPTION"]