|
|
|
import os |
|
import gradio as gr |
|
|
|
from azure.storage.fileshare import ShareServiceClient |
|
|
|
|
|
from climateqa.engine.embeddings import get_embeddings_function |
|
from climateqa.engine.llm import get_llm |
|
from climateqa.engine.vectorstore import get_pinecone_vectorstore |
|
from climateqa.engine.reranker import get_reranker |
|
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc |
|
from climateqa.engine.chains.retrieve_papers import find_papers |
|
from climateqa.chat import start_chat, chat_stream, finish_chat |
|
from climateqa.engine.talk_to_data.main import ask_vanna |
|
from climateqa.engine.talk_to_data.myVanna import MyVanna |
|
|
|
from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab) |
|
from front.utils import process_figures |
|
from gradio_modal import Modal |
|
|
|
|
|
from utils import create_user_id |
|
import logging |
|
|
|
# NOTE(review): presumably meant to quiet TensorFlow's native (C++) log output.
# It can only take effect if TensorFlow is imported after this line; the
# project imports above may already have loaded it — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Only WARNING and above reach the console. basicConfig() does nothing when the
# root logger already has handlers, so the root level is also set explicitly.
logging.basicConfig(level=logging.WARNING)
logging.getLogger().setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
# Best-effort: load variables from a local `.env` file when python-dotenv is
# installed. Environments that already export the variables work without it.
try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional; rely on the process environment as-is.
    pass
else:
    try:
        load_dotenv()
    except Exception:
        # Still non-fatal (as before), but leave a trace instead of silently
        # continuing with a possibly incomplete environment.
        logging.getLogger(__name__).warning("Could not load .env file", exc_info=True)
|
|
|
|
|
|
|
# Visual theme shared by every tab: Gradio's Base theme with blue primary and
# red secondary hues, using the Poppins Google font followed by generic
# sans-serif fallbacks.
theme = gr.themes.Base(
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
    primary_hue="blue",
    secondary_hue="red",
)
|
|
|
|
|
def _restore_base64_padding(key):
    """Return *key* with up to two stripped trailing ``=`` characters restored.

    The account key is presumably a base64 string whose padding can get lost
    when it is stored as a secret (the previous code re-appended "==" to
    86-character keys). Padding is only added when 1 or 2 characters are
    missing; keys of any other length are returned unchanged.
    """
    missing = -len(key) % 4
    if missing in (1, 2):
        return key + "=" * missing
    return key


# Azure file-share client. All three BLOB_* variables are required: a missing
# one raises KeyError at import time.
account_key = _restore_base64_padding(os.environ["BLOB_ACCOUNT_KEY"])

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
# Handed to chat_stream by chat()/chat_poc().
share_client = service.get_share_client(file_share_name)

# NOTE(review): created once at import time, so chat() and chat_poc() pass the
# same id to chat_stream for every request from every user. Confirm whether a
# per-session id (e.g. via gr.State) was intended.
user_id = create_user_id()
|
|
|
|
|
|
|
|
|
# Retrieval back-ends. Index names come from the environment; os.getenv passes
# None through when a variable is unset.
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

# GRADIO_ENV=local selects the "nano" reranker; anything else, including an
# unset variable (which previously raised KeyError at startup), selects "large".
reranker = get_reranker("nano" if os.getenv("GRADIO_ENV") == "local" else "large")

# Both agents share the same LLM, vectorstores and reranker; they differ only
# in the graph builder, the document threshold and the POC version.
_shared_agent_kwargs = dict(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
)
agent = make_graph_agent(**_shared_agent_kwargs, threshold_docs=0.2)
agent_poc = make_graph_agent_poc(**_shared_agent_kwargs, threshold_docs=0, version="v4")
|
|
|
|
|
|
|
# Text-to-SQL client used by the "Talk to DRIAS" tab. Credentials, model and
# Pinecone index are read with os.getenv, so unset values arrive as None.
vn = MyVanna(
    config={
        "temperature": 0,
        "api_key": os.getenv('THEO_API_KEY'),
        "model": os.getenv('VANNA_MODEL'),
        "pc_api_key": os.getenv('VANNA_PINECONE_API_KEY'),
        "index_name": os.getenv('VANNA_INDEX_NAME'),
        "top_k": 4,
    }
)

# SQLite database resolved relative to the current working directory.
db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
vn.connect_to_sqlite(db_vanna_path)
|
|
|
def ask_vanna_query(query):
    """Answer *query* against the DRIAS database via the module-level Vanna client.

    Adapter for the DRIAS tab's submit event. It binds ``vn`` and
    ``db_vanna_path`` so the UI only supplies the question. It returns whatever
    ``ask_vanna`` returns; the submit listener maps that onto the SQL query box,
    the results table and the plot.
    """
    answer = ask_vanna(vn, db_vanna_path, query)
    return answer
|
|
|
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
    """Stream UI updates for one message answered by the main ClimateQ&A agent.

    Forwards every event produced by ``chat_stream`` unchanged, adding the
    module-level ``agent``, ``share_client`` and ``user_id``.
    """
    print("chat cqa - message received")
    events = chat_stream(
        agent,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for update in events:
        yield update
|
|
|
async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
    """Stream UI updates for one message answered by the POC agent (``agent_poc``).

    Same contract as ``chat`` but backed by the Adapt'Action proof-of-concept
    graph; events from ``chat_stream`` are forwarded unchanged.
    """
    print("chat poc - message received")
    events = chat_stream(
        agent_poc,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for update in events:
        yield update
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_config_modal_visibility(config_open):
    """Flip the configuration modal's visibility.

    Args:
        config_open: Current value of the ``config_open`` state flag.

    Returns:
        A ``Modal`` update carrying the new visibility, and the new flag value
        to write back into the state.
    """
    print(config_open)
    show_modal = not config_open
    return Modal(visible=show_modal), show_modal
|
|
|
|
|
def _count_marker(content, marker):
    """Return how many times *marker* occurs in a component value.

    Values are normally HTML strings. ``current_graphs`` is a ``gr.State``
    initialised to ``[]`` (see ``cqa_tab``), and ``list.count`` would count
    elements *equal to* the marker rather than substrings, so lists/tuples are
    joined into one string first. ``None`` is treated as empty, defensively.
    """
    if content is None:
        return 0
    if isinstance(content, (list, tuple)):
        content = "".join(str(item) for item in content)
    return str(content).count(marker)


def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
    """Refresh tab labels with the number of items each panel currently shows.

    Sources, figures and papers are counted by their ``<h2>`` headings and
    graphs by their ``<iframe`` embeds. "Recommended content" shows the sum of
    figures, graphs and papers (sources are not included).

    Returns:
        Five ``gr.update`` label updates in this order: recommended content,
        sources, figures, graphs, papers.
    """
    sources_number = _count_marker(sources_textbox, "<h2>")
    figures_number = _count_marker(figures_cards, "<h2>")
    graphs_number = _count_marker(current_graphs, "<iframe")
    papers_number = _count_marker(papers_html, "<h2>")
    recommended_number = figures_number + graphs_number + papers_number

    return (
        gr.update(label=f"Recommended content ({recommended_number})"),
        gr.update(label=f"Sources ({sources_number})"),
        gr.update(label=f"Figures ({figures_number})"),
        gr.update(label=f"Graphs ({graphs_number})"),
        gr.update(label=f"Papers ({papers_number})"),
    )
|
|
|
def create_drias_tab():
    """Build the "Beta - Talk to DRIAS" tab.

    The user types a question; ``ask_vanna_query`` fills a collapsible
    "Details" section (the SQL used plus a modal holding the result table) and
    a plot shown below it.
    """
    with gr.Tab("Beta - Talk to DRIAS", elem_id="tab-vanna", id=6):
        question_box = gr.Textbox(
            label="Direct Question",
            placeholder="You can write direct question here",
            elem_id="direct-question",
            interactive=True,
        )

        with gr.Accordion("Details", elem_id='vanna-details', open=False):
            sql_box = gr.Textbox(label="SQL Query Used", elem_id="sql-query", interactive=False)
            open_table_button = gr.Button("Show Table", elem_id="show-table")

            # The result table lives in a modal, opened and closed by the two buttons.
            with Modal(visible=False) as table_modal:
                result_table = gr.DataFrame([], elem_id="vanna-table")
                close_table_button = gr.Button("Close", elem_id="close-vanna-modal")
                close_table_button.click(lambda: Modal(visible=False), None, [table_modal])

            open_table_button.click(lambda: Modal(visible=True), None, [table_modal])

        plot_output = gr.Plot()
        question_box.submit(ask_vanna_query, [question_box], [sql_box, result_table, plot_output])
|
|
|
|
|
def cqa_tab(tab_name):
    """Lay out one chat tab and return its components keyed by name.

    Creates a ``gr.State`` for the current graphs, then a tab titled
    *tab_name* with the chat interface on the left. The right-hand panel holds
    the sub-tabs Examples, Sources and "Recommended content" (which contains
    Figures / Papers / Graphs). The returned dict is the contract consumed by
    ``event_handling`` and ``config_event_handling``.
    """
    current_graphs = gr.State([])

    with gr.Tab(tab_name):
        with gr.Row(elem_id="chatbot-row"):
            # Left half: the conversation itself.
            with gr.Column(scale=2):
                chatbot, textbox, config_button = create_chat_interface(tab_name)

            # Right half: supporting material for the latest answer.
            with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
                with gr.Tabs(elem_id="right_panel_tab") as tabs:
                    with gr.TabItem("Examples", elem_id="tab-examples", id=0):
                        examples_hidden = create_examples_tab(tab_name)

                    with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")

                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
                        with gr.Tabs(elem_id="group-subtabs"):
                            with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
                                (
                                    sources_raw,
                                    new_figures,
                                    _used_figures,
                                    gallery_component,
                                    figures_cards,
                                    _figure_modal,
                                ) = create_figures_tab()

                            with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
                                (
                                    papers_direct_search,
                                    papers_summary,
                                    papers_html,
                                    citations_network,
                                    _papers_modal,
                                ) = create_papers_tab()

                            with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
                                graphs_container = gr.HTML(
                                    "<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
                                    elem_id="graphs-container",
                                )

    return dict(
        chatbot=chatbot,
        textbox=textbox,
        tabs=tabs,
        sources_raw=sources_raw,
        new_figures=new_figures,
        current_graphs=current_graphs,
        examples_hidden=examples_hidden,
        sources_textbox=sources_textbox,
        figures_cards=figures_cards,
        gallery_component=gallery_component,
        config_button=config_button,
        papers_direct_search=papers_direct_search,
        papers_html=papers_html,
        citations_network=citations_network,
        papers_summary=papers_summary,
        tab_recommended_content=tab_recommended_content,
        tab_sources=tab_sources,
        tab_figures=tab_figures,
        tab_graphs=tab_graphs,
        tab_papers=tab_papers,
        graph_container=graphs_container,
    )
|
|
|
def config_event_handling(main_tabs_components : list[dict], config_componenets : dict):
    """Make every "config" button and the modal's close button toggle the config modal.

    Args:
        main_tabs_components: Component dicts returned by ``cqa_tab``; each
            contributes its ``"config_button"``.
        config_componenets: Dict from ``create_config_modal`` providing
            ``"config_open"`` (state flag), ``"config_modal"`` and
            ``"close_config_modal_button"``.
    """
    open_state = config_componenets["config_open"]
    modal = config_componenets["config_modal"]
    toggle_buttons = [config_componenets["close_config_modal_button"]]
    toggle_buttons.extend(tab["config_button"] for tab in main_tabs_components)

    for toggle in toggle_buttons:
        toggle.click(
            fn=update_config_modal_visibility,
            inputs=[open_state],
            outputs=[modal, open_state],
        )
|
|
|
def event_handling(
    main_tab_components,
    config_components,
    tab_name="ClimateQ&A"
):
    """Register the Gradio event listeners for one chat tab.

    Args:
        main_tab_components: Component dict returned by ``cqa_tab``.
        config_components: Component dict returned by ``create_config_modal``
            (source/report/audience dropdowns, ``search_only``, ``after``,
            ``output_query``, ``output_language``).
        tab_name: Selects the streaming handler for the chat pipeline:
            ``"ClimateQ&A"`` uses ``chat`` and ``"Beta - POC Adapt'Action"``
            uses ``chat_poc``. Any other name wires no chat pipeline, only the
            display and paper-search listeners.

    ``main_ui`` calls this inside its ``gr.Blocks`` context.
    """
    chatbot = main_tab_components["chatbot"]
    textbox = main_tab_components["textbox"]
    tabs = main_tab_components["tabs"]
    sources_raw = main_tab_components["sources_raw"]
    new_figures = main_tab_components["new_figures"]
    current_graphs = main_tab_components["current_graphs"]
    examples_hidden = main_tab_components["examples_hidden"]
    sources_textbox = main_tab_components["sources_textbox"]
    figures_cards = main_tab_components["figures_cards"]
    gallery_component = main_tab_components["gallery_component"]

    papers_direct_search = main_tab_components["papers_direct_search"]
    papers_html = main_tab_components["papers_html"]
    citations_network = main_tab_components["citations_network"]
    papers_summary = main_tab_components["papers_summary"]
    tab_recommended_content = main_tab_components["tab_recommended_content"]
    tab_sources = main_tab_components["tab_sources"]
    tab_figures = main_tab_components["tab_figures"]
    tab_graphs = main_tab_components["tab_graphs"]
    tab_papers = main_tab_components["tab_papers"]
    graphs_container = main_tab_components["graph_container"]

    dropdown_sources = config_components["dropdown_sources"]
    dropdown_reports = config_components["dropdown_reports"]
    dropdown_external_sources = config_components["dropdown_external_sources"]
    search_only = config_components["search_only"]
    dropdown_audience = config_components["dropdown_audience"]
    after = config_components["after"]
    output_query = config_components["output_query"]
    output_language = config_components["output_language"]

    # The chat stream writes the sources HTML here; a listener below copies it
    # into the visible sources panel.
    new_sources_html = gr.State([])
    # NOTE(review): this state (previously bound to `ttd_data`) is never used;
    # it is still created so the component set of the tab is unchanged.
    gr.State([])

    # Both tabs share the same start -> stream -> finish pipeline and differ
    # only in the streaming handler.
    chat_handlers = {
        "ClimateQ&A": chat,
        "Beta - POC Adapt'Action": chat_poc,
    }
    chat_fn = chat_handlers.get(tab_name)

    if chat_fn is not None:
        # Runs once while wiring the UI (the old print here claimed a message
        # had been sent, which was misleading).
        logging.getLogger(__name__).info("Registering chat events for tab %r", tab_name)

        # The free-text box triggers on submit; the hidden examples field on change.
        triggers = ((textbox, textbox.submit), (examples_hidden, examples_hidden.change))
        for trigger, listen in triggers:
            suffix = trigger.elem_id
            (
                listen(start_chat, [trigger, chatbot, search_only], [trigger, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{suffix}")
                .then(chat_fn, [trigger, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_html, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{suffix}")
                # finish_chat always targets the free-text box, for both triggers.
                .then(finish_chat, None, [textbox], api_name=f"finish_chat_{suffix}")
            )

    # Mirror intermediate states into their display components.
    new_sources_html.change(lambda x : x, inputs = [new_sources_html], outputs = [sources_textbox])
    current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
    new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])

    # Any content change refreshes the item counts shown in the tab labels.
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])

    # Paper search runs for chat questions as well as direct paper searches.
    for component in [textbox, examples_hidden, papers_direct_search]:
        component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
|
|
|
|
|
|
|
|
|
|
|
|
def main_ui():
    """Assemble the whole Climate Q&A Gradio app and return it queued (not launched).

    Layout: config modal, then tabs for the two chat variants, the DRIAS
    text-to-SQL tab and the About tab. Event listeners are attached inside the
    same ``gr.Blocks`` context.
    """
    stylesheet = os.getcwd() + "/style.css"
    chat_tab_names = ("ClimateQ&A", "Beta - POC Adapt'Action")

    with gr.Blocks(title="Climate Q&A", css_paths=stylesheet, theme=theme, elem_id="main-component") as demo:
        config_components = create_config_modal()

        with gr.Tabs():
            chat_tabs = [cqa_tab(tab_name=name) for name in chat_tab_names]
            create_drias_tab()
            create_about_tab()

        for name, components in zip(chat_tab_names, chat_tabs):
            event_handling(components, config_components, tab_name=name)

        config_event_handling(chat_tabs, config_components)

        demo.queue()

    return demo
|
|
|
|
|
# Build the app at import time so `demo` stays available as a module attribute
# (Gradio's reload mode looks for it). Start the server only when this file
# is run as a script, so importing the module does not block on launch().
demo = main_ui()

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
|
|