talk_to_ipcc (#29)
- change local examples (3fa35bfc2ec4e3397b571eecc44a68b1a1d1152c)
- fix (c5c15575b0a549578448fa4ae916af294cd4bdc4)
- Upload drias.db (7ea3b8dc6dcc209c74ac86140c23c3f35356b271)
- feature/add_talk_to_data (#23) (dac62987820b2d821cda7ee86bca058bcf66a66b)
- Change openai_key (680ed1c172d20c1280a46077a72c2eb7be02bd10)
- update OpenAI usage from Vanna (3e75ed8cd2d2f444a43cc93c814d4392f991d8dc)
- Small clean POC Local (0327516552f1b9af184361e8594b7377a4e0bcb2)
- Clean configs (ecab0c839ce7fc8100121a0174409451dbb5dd50)
- fix : fix gradio component (1e74d1402fb3c834dc9c0163aeb5028dbd6764bd)
- take the last question as history to understand the question (676d17bd25bb4cbd8cf9396ee353aac7a7028af0)
- Add follow up questions (aaa4bbedbbdee9951a33c8d7bbfd7583fa441bc5)
- Fix : Dynamic follow up examples (a984d58b3e8cf8458eec39ed29d0267590f8e75a)
- Merged in feature/dynamic_conversation (pull request #1) (f684aff68309e2f14164cef7fd02a6da804dcef9)
- Update style.css (e8258aea3313f558bb0afa650464d740fa4b2da4)
- Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev (03a8baf608fd8a639c96a661725c181660122aa1)
- feat: implemented talk to drias v1 (4df74e4b75096daf90548ae7eefe6c32df002bcd)
- feat: added 2 new talk to data plots (170018666950c1b7433f6baba17ad36042282a2e)
- feat: added drias model choice and changed TTD UI (1eae86b617a8336ee3e4b7b11e39dc0769731812)
- fix: fixed bugs and errors (e6e652c91a221b56eb508cc98b13d27b5785e803)
- Merge branch 'dev' into feat/talk_to_data_graph (561caf15f2fd2954b43a49064ac11a085d343151)
- update css for follow up examples (54e2358ad0444473be0ac8d6c436ef73cfcceb3b)
- ensure correct output language (723be3263bfd1a048f671b7ef11219de2c10b736)
- feat: model filtering and UI upgrade for TTD (6155a631718e808fb6622083848a5fc3bf231937)
- feat: added list of tables and reduced execution time of TTD (0bdf2f6f90901ed5cd65fc99d4de1da8eab1788a)
- add documentation (161aa8c18cb4be2d0c5df0dc64f3c202e55c582d)
- Move hardcoded configuration in a config file (45a93206d1c3507a622ba4f40df17a339d5a8bad)
- make ask drias asynchronous (5fe15430ebbb47868615f715c00d6274fad0e7a7)
- Add drias indicators (989d387afc62cb38ae239e5cc97664f702e3fa54)
- Add examples (d4fa76b2cb0a8ed1ae0f482254acfbd0242fdc73)
- UI improvment (1d710d6428b17b25b42e14fa494c2823b7e74e81)
- split front element and event listening (ca2c42949b3de84aeaaabf5879f107cd829523a0)
- Merged in feat/talk_to_data_graph (pull request #3) (f2baf8741a4c6b9712f2cd9cec2f86ddb4ca4274)
- Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev (5c04812e26e5faa6ee2a85867f0155c1e56b96a0)
- add logs of drias interactions (bc43b45669dfd4740dc90d6daa1128278857d5d3)
- Merged in dev (pull request #4) (6af9e984b3a852db3414a96307040d2bbc031871)
- rename tabs for prod (b09184847c18b284fd0d08532a1572722f9159d5)
- V1.7 - Dynamic conversation & France local QA (#24) (6f13540b2041360bcd5d6eab4d699df28d96abce)
- Fix for python 3.10 (b142280be035c6ee3a6496f98cac37b23b38e18d)
- Fix requirements (f1b3e893c4b744376eec5ae30e3ed2e6c11fd8cf)
- update requirements (f46a24ac65bd569cad69d2281c5b5943a9a0830e)
- fixed requiremets (9ab8c871a2c9a8d6bf54d51d32666e973e46b4ee)
- Merge branch 'main' of https://huggingface.co/spaces/Ekimetrics/climate-question-answering (584bae4c2796688e9ecfdd393c693ce0e0d18fc5)
- fix tab name (0a2f32b117ce6289b5c49ab68a7aaedc502f5eca)
- Merge remote-tracking branch 'hf-origin/main' (0afa3dbe6042da779ac4911e23ce5cf7a4d40065)
- TTD : remove first unrelevant points (535d47236a99ebdc1e297eeb35b8833498214af2)
- Fixed follow up examples for local qa (c56c4345fd9484d8d43fd4ee36d8a4a0c3f6bd67)
- change init prompt to english (8c7a7fed7bddc7e672a257ab80f0840a6977de17)
- log to huggingface (f9c4c84a71d320c6db05ee099b9e35492ba7b184)
- Merged in feat/logs_on_huggingface (pull request #5) (261632833e7d939c321f208edfd6576e68947b4b)
- feat: added multithreading to run sql queries in talk to drias (705ccece7775c65a9c7b73b091cdddbc4246f2e7)
- chore: remove prints in talk to drias workflow (a967134f90c70d87bb3d786f2feb81f2e56fdb9f)
- Merged in feat/improve_drias_exeuction_time (pull request #6) (7c38528636ac19c1efa240a122e523ed0c34706a)
- fix import (05b8df9c9b74926459da70797b6852ff07a4d838)
- Merge branch 'main' into dev (8fb231c8beabf8a6406f05cf4cac564c5d81c7ce)
- Merged in dev (pull request #7) (6b9f71b1cf216eef0fd7973412f675d8633a5f4a)
- fix import (b35df2a8160723e43f74a040aa94983069066213)
- Merge branch 'main' of https://bitbucket.org/ekimetrics/climate_qa (f96cfd0715ec2b1ed7a78775ea7f8722f5793d8f)
- feature/drias_parallelization (#25) (47a68b0f6974850c8ec2465bdeda32c731e94568)
- feat: changed talk to drias UI (14d7085602602feef3ce51430f34221dadaa21df)
- temp fix : reranker switch to nano (99302fc182b3b15e3c15a03a633b46adf31e088f)
- Merged in dev (pull request #9) (bbc7806f08a3754842b0c228035c5fe985d45cc3)
- Merge remote-tracking branch 'hf-origin/main' (1b9b1eed6d3a82a9b23791df13f180181e6baa31)
- refactor: modularized talk to data (46c1e346e695ba31e91a6495a57a9063cf4dad54)
- Merge remote-tracking branch 'origin' into refactor/ttd_modularization (c25f6b1dd1d359574c77810933bcd01119f8a5b3)
- Standardize loggin (5b1b83bccbc4b9ae9021e29b43d822cf5b013e7d)
- Merge branch 'dev' into feat/logs_on_huggingface (32d594b10afd7196a7a83f9d9cb6a169a1c5ef51)
- Update constants.py (7a891a71dd288d985b940a61e3df26b34bb496e3)
- correct logs formating for dataviewer (2de91ee28e1168621896119cf65e44608f03000f)
- change logs repo (5f2bda40adc7025b6828daa3c8e65604f10cffaa)
- Merged in feat/logs_on_huggingface (pull request #11) (f52ad179f835a04bc5f6b9880699c7889cf0fc6c)
- Update logging.py (3ff380ff39b5417cda20a8f953301b6f1b833800)
- Merged in feat/logs_on_huggingface (pull request #12) (7c8dc1752e212554fba0a9bbf38d2552af5f72e1)
- Merged in main (pull request #13) (9a9ab642ce3a6361946f3d6ef5f8282d893d6a24)
- Merged in dev (pull request #14) (97d80b660352d1e09bfd697e86d4a7b4f6386571)
- add logs fallback on azure (58f35b70200236190bc8f57f5ec75a72573b31a1)
- Merged in feat/logs_on_huggingface (pull request #15) (e92e8dc1b629f6e9c61657683419e6585fd54261)
- Merged in dev (pull request #16) (eae39a6274f762fdfdfc49e86680d27ce7c6c8fc)
- feat: updated common talk to data for talk to ipcc and drias (c0fd277fca5091c51086af4aeaaa36f625531cef)
- feat: created queries for talk to ipcc (e8d5bc9410f6777c62dcacb77160ea08ac172e31)
- feat: created plots for talk to ipcc (3f85f963aef03ebf32f77836652ce1926778958b)
- feat: created config constants for talk to ipcc (c3398f4493e79a39702edcfa436a860312edf3e5)
- implemented talk to ipcc workflow and updated talk to data state object (45e1dba85d72fed8767b1e369934cf13639a3242)
- Merged in feature/talk_to_ipcc (pull request #17) (ac63459d9126fed1121796aaee4cb3adfbc955d9)
- feat: updated talk to drias based on talk to ipcc (c3024c3df719b25b6aee66dda3aed3927bb83077)
- Merge remote-tracking branch 'origin' into feature/talk_to_data (4af247227e85f9b72b159b213d3a1eccee7b6733)
- chore: added geojson to requirements (1a4f8680944e738da3237b7d6d03b0fd430f186c)
- fix: fix markdown bug in IPCC_UI_TEXT (d3040d364345ed73e9ac1657c5613e1062e728b7)
- chore: added geojson polygon function and custom colorscale per indicator (30f401b002885a71bb62069ec17bd2619f98961e)
- feat: added plot informations for each plot (DRIAS & IPCC) (bdeb1e525e4c80ce990329052a2c172ad9847113)
- feat: updated TTD UI (DRIAS & IPCC) (fdad3be2df6ab94da4b7847a081c3c0c1b67eaf2)
- fix: fixed graphs display bugs (a4206dc31d5e238ee26876faf91473a202eafec1)
- feat: updated sql query for not macro countries in talk to ipcc (ff653468b4f2bfe2e166ad572ca3c91cb5fb986f)
- refactor: changed choropleth map into map (f9508fce0c7db4822b33a8ba95fb048818fd0840)
- feat: changed marker opacity on map plots (7374d87d5c8521f7ff418eb04636558ee6dcfc14)
- chore: added information about grid point mapping to country inside plot information for map (29cf97a04f9def87a065f05139cdb2374b47d07d)
- feat: changed queries for macro countries in talk to ipcc (d0f045617aefd9a28811c567ecda0d1ff70ce548)
- fix: fixed submit bug in talk to drias (01a4939d78358d57f94232827cfe197c22678258)
- feat: added queries on huge countries in talk to ipcc (19905de878ff990f1beff50d76a8eb11541e46b9)
- chore: changed talk to ipcc how to use text (5bd3f8cdd51ef4bbd47b4524ab5406434801a048)
- replace eval by litteral eval (2d6b3b9f22416c157a60ecc2616346dc975d6861)
- fix retrieve documents error for search without reranker (0536821aa63e34759510b5d55505da8df32bae8f)
- Move vanna files in dedicated folder (fa2765ac9a08e5d00439ff752159cf2d22b10d1f)
- move path into config (0695fe4d36575aa5881f981e16897a7bc3a4d45d)
- remove unnecessary imports (819e3c05730ca9e4b1a6526f00d961f1509e3b0d)
- Merged in feature/talk_to_data (pull request #19) (11ab5fbc82c47ac8524b62d4c693eaa376dd149f)
- rename hf_token name (d78ea762bb16187b8755d5c07d90827c10ccf127)
- add lfs images (a9b4fcc0bd5fafd39104eae0074cad2b51bf1341)
- Add PNG assets (11910c77dff6e0795b7cbecb2e692645ef4e14f0)
- Merge hf-origin/main into main (1ed4e25589d82fe195eab13e7249e3889cecac4b)
Co-authored-by: Armand Demasson <[email protected]>
Files changed:
- .gitattributes +1 -0
- app.py +9 -0
- climateqa/engine/chains/retrieve_documents.py +6 -4
- climateqa/engine/talk_to_data/config.py +8 -96
- climateqa/engine/talk_to_data/drias/config.py +124 -0
- climateqa/engine/talk_to_data/drias/plot_informations.py +88 -0
- climateqa/engine/talk_to_data/drias/plots.py +434 -0
- climateqa/engine/talk_to_data/drias/queries.py +83 -0
- climateqa/engine/talk_to_data/input_processing.py +257 -0
- climateqa/engine/talk_to_data/ipcc/config.py +98 -0
- climateqa/engine/talk_to_data/ipcc/plot_informations.py +50 -0
- climateqa/engine/talk_to_data/ipcc/plots.py +189 -0
- climateqa/engine/talk_to_data/ipcc/queries.py +143 -0
- climateqa/engine/talk_to_data/main.py +77 -71
- climateqa/engine/talk_to_data/objects/llm_outputs.py +13 -0
- climateqa/engine/talk_to_data/objects/location.py +12 -0
- climateqa/engine/talk_to_data/objects/plot.py +23 -0
- climateqa/engine/talk_to_data/objects/states.py +19 -0
- climateqa/engine/talk_to_data/prompt.py +44 -0
- climateqa/engine/talk_to_data/query.py +57 -0
- climateqa/engine/talk_to_data/ui_config.py +27 -0
- climateqa/engine/talk_to_data/vanna/myVanna.py +13 -0
- climateqa/engine/talk_to_data/vanna/vanna_class.py +325 -0
- climateqa/engine/talk_to_data/workflow/drias.py +163 -0
- climateqa/engine/talk_to_data/workflow/ipcc.py +161 -0
- front/assets/talk_to_drias_annual_temperature_france_example.png +3 -0
- front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png +3 -0
- front/assets/talk_to_drias_winter_temp_paris_example.png +3 -0
- front/assets/talk_to_ipcc_china_example.png +3 -0
- front/assets/talk_to_ipcc_france_example.png +3 -0
- front/assets/talk_to_ipcc_new_york_example.png +3 -0
- front/tabs/tab_drias.py +60 -149
- front/tabs/tab_ipcc.py +300 -0
- requirements.txt +2 -1
- style.css +45 -4
.gitattributes:

```diff
@@ -45,3 +45,4 @@ documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
 data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
+front/assets/*.png filter=lfs diff=lfs merge=lfs -text
```
app.py:

```diff
@@ -16,6 +16,10 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
 from front.tabs import create_config_modal, cqa_tab, create_about_tab
 from front.tabs import MainTabPanel, ConfigPanel
 from front.tabs.tab_drias import create_drias_tab
+<<<<<<< HEAD
+from front.tabs.tab_ipcc import create_ipcc_tab
+=======
+>>>>>>> hf-origin/main
 from front.utils import process_figures
 from gradio_modal import Modal
 
@@ -532,8 +536,13 @@ def main_ui():
     with gr.Tabs():
         cqa_components = cqa_tab(tab_name="ClimateQ&A")
         local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
+<<<<<<< HEAD
+        drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
+        ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
+=======
         create_drias_tab(share_client=share_client, user_id=user_id)
 
+>>>>>>> hf-origin/main
         create_about_tab()
 
     event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
```
climateqa/engine/chains/retrieve_documents.py:

```diff
@@ -21,7 +21,7 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from ..vectorstore import get_pinecone_vectorstore
 from ..embeddings import get_embeddings_function
-
+import ast
 
 import asyncio
 
@@ -477,8 +477,10 @@ async def retrieve_documents(
             docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
         else:
             # Add a default reranking score
-            for
-
+            for key in docs_question_dict.keys():
+                if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
+                    for doc in docs_question_dict[key]:
+                        doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
 
         # Keep the right number of documents
         docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
@@ -580,7 +582,7 @@ async def get_relevant_toc_level_for_query(
     response = chain.invoke({"query": query, "doc_list": doc_list})
 
     try:
-        relevant_tocs =
+        relevant_tocs = ast.literal_eval(response)
    except Exception as e:
         print(f" Failed to parse the result because of : {e}")
 
```
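The `ast.literal_eval` change above (see the commit "replace eval by litteral eval") is the standard hardening for parsing Python-literal output from an LLM: it only accepts literals, so a malformed or malicious response raises instead of executing. A minimal sketch, with a hypothetical `response` string:

```python
import ast

# A well-formed LLM answer parses into the expected structure.
response = "['Chapter 2', 'Chapter 5']"
relevant_tocs = ast.literal_eval(response)  # ['Chapter 2', 'Chapter 5']

# Anything that is not a pure Python literal is rejected, not executed.
try:
    ast.literal_eval("__import__('os').system('echo unsafe')")
except (ValueError, SyntaxError) as exc:
    print(f"rejected: {exc}")
```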
climateqa/engine/talk_to_data/config.py:

```diff
@@ -1,99 +1,11 @@
-
-    "total_winter_precipitation",
-    "total_summer_precipiation",
-    "total_annual_precipitation",
-    "total_remarkable_daily_precipitation",
-    "frequency_of_remarkable_daily_precipitation",
-    "extreme_precipitation_intensity",
-    "mean_winter_temperature",
-    "mean_summer_temperature",
-    "mean_annual_temperature",
-    "number_of_tropical_nights",
-    "maximum_summer_temperature",
-    "number_of_days_with_tx_above_30",
-    "number_of_days_with_tx_above_35",
-    "number_of_days_with_a_dry_ground",
-]
+# Path configuration for climateqa project
 
-
-
-    "total_summer_precipiation": "total_summer_precipitation",
-    "total_annual_precipitation": "total_annual_precipitation",
-    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
-    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
-    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
-    "mean_winter_temperature": "mean_winter_temperature",
-    "mean_summer_temperature": "mean_summer_temperature",
-    "mean_annual_temperature": "mean_annual_temperature",
-    "number_of_tropical_nights": "number_tropical_nights",
-    "maximum_summer_temperature": "maximum_summer_temperature",
-    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
-    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
-    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
-}
+# IPCC dataset path
+IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
 
-
-
-    'RegCM4-6_MPI-ESM-LR',
-    'RACMO22E_EC-EARTH',
-    'RegCM4-6_HadGEM2-ES',
-    'HadREM3-GA7_EC-EARTH',
-    'HadREM3-GA7_CNRM-CM5',
-    'REMO2015_NorESM1-M',
-    'SMHI-RCA4_EC-EARTH',
-    'WRF381P_NorESM1-M',
-    'ALADIN63_CNRM-CM5',
-    'CCLM4-8-17_MPI-ESM-LR',
-    'HIRHAM5_IPSL-CM5A-MR',
-    'HadREM3-GA7_HadGEM2-ES',
-    'SMHI-RCA4_IPSL-CM5A-MR',
-    'HIRHAM5_NorESM1-M',
-    'REMO2009_MPI-ESM-LR',
-    'CCLM4-8-17_HadGEM2-ES'
-]
-# Mapping between indicator columns and their units
-INDICATOR_TO_UNIT = {
-    "total_winter_precipitation": "mm",
-    "total_summer_precipitation": "mm",
-    "total_annual_precipitation": "mm",
-    "total_remarkable_daily_precipitation": "mm",
-    "frequency_of_remarkable_daily_precipitation": "days",
-    "extreme_precipitation_intensity": "mm",
-    "mean_winter_temperature": "°C",
-    "mean_summer_temperature": "°C",
-    "mean_annual_temperature": "°C",
-    "number_tropical_nights": "days",
-    "maximum_summer_temperature": "°C",
-    "number_of_days_with_tx_above_30": "days",
-    "number_of_days_with_tx_above_35": "days",
-    "number_of_days_with_dry_ground": "days"
-}
+# DRIAS dataset paths
+DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"
 
-
-
-
-
-❓ **How to use?**
-You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
-You can specify **location** and/or **year**.
-You can choose from a list of climate models. By default, we take the **average of each model**.
-
-For example, you can ask:
-- What will the temperature be like in Paris?
-- What will be the total rainfall in France in 2030?
-- How frequent will extreme events be in Lyon?
-
-**Example of indicators in the data**:
-- Mean temperature (annual, winter, summer)
-- Total precipitation (annual, winter, summer)
-- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
-
-⚠️ **Limitations**:
-- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
-- You can only ask about **locations in France**.
-- If you specify a year, there may be **no data for that year for some models**.
-- You **cannot compare two models**.
-
-🛈 **Information**
-Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
-"""
+# Table paths
+DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH = f"{DRIAS_DATASET_URL}/mean_annual_temperature.parquet"
+IPCC_COORDINATES_PATH = f"{IPCC_DATASET_URL}/coordinates.parquet"
```
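With the local `drias.db` replaced by parquet files on the Hub, any consumer can scan these paths directly. A sketch of what the query layer does with them (assumes DuckDB with its `hf://`/httpfs support available; the column name mirrors the DRIAS config below):

```python
import duckdb

# Same value as DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH above
path = "hf://datasets/timeki/drias_db/mean_annual_temperature.parquet"

# DuckDB scans the remote parquet file directly from the Hugging Face Hub
df = duckdb.sql(
    f"SELECT year, mean_annual_temperature, model FROM '{path}' LIMIT 5"
).df()
print(df)
```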
climateqa/engine/talk_to_data/drias/config.py (new file):

```diff
@@ -0,0 +1,124 @@
+
+from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
+
+
+DRIAS_TABLES = [
+    "total_winter_precipitation",
+    "total_summer_precipitation",
+    "total_annual_precipitation",
+    "total_remarkable_daily_precipitation",
+    "frequency_of_remarkable_daily_precipitation",
+    "extreme_precipitation_intensity",
+    "mean_winter_temperature",
+    "mean_summer_temperature",
+    "mean_annual_temperature",
+    "number_of_tropical_nights",
+    "maximum_summer_temperature",
+    "number_of_days_with_tx_above_30",
+    "number_of_days_with_tx_above_35",
+    "number_of_days_with_a_dry_ground",
+]
+
+DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
+    "total_winter_precipitation": "total_winter_precipitation",
+    "total_summer_precipitation": "total_summer_precipitation",
+    "total_annual_precipitation": "total_annual_precipitation",
+    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
+    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
+    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
+    "mean_winter_temperature": "mean_winter_temperature",
+    "mean_summer_temperature": "mean_summer_temperature",
+    "mean_annual_temperature": "mean_annual_temperature",
+    "number_of_tropical_nights": "number_tropical_nights",
+    "maximum_summer_temperature": "maximum_summer_temperature",
+    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
+    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
+    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
+}
+
+DRIAS_MODELS = [
+    'ALL',
+    'RegCM4-6_MPI-ESM-LR',
+    'RACMO22E_EC-EARTH',
+    'RegCM4-6_HadGEM2-ES',
+    'HadREM3-GA7_EC-EARTH',
+    'HadREM3-GA7_CNRM-CM5',
+    'REMO2015_NorESM1-M',
+    'SMHI-RCA4_EC-EARTH',
+    'WRF381P_NorESM1-M',
+    'ALADIN63_CNRM-CM5',
+    'CCLM4-8-17_MPI-ESM-LR',
+    'HIRHAM5_IPSL-CM5A-MR',
+    'HadREM3-GA7_HadGEM2-ES',
+    'SMHI-RCA4_IPSL-CM5A-MR',
+    'HIRHAM5_NorESM1-M',
+    'REMO2009_MPI-ESM-LR',
+    'CCLM4-8-17_HadGEM2-ES'
+]
+# Mapping between indicator columns and their units
+DRIAS_INDICATOR_TO_UNIT = {
+    "total_winter_precipitation": "mm",
+    "total_summer_precipitation": "mm",
+    "total_annual_precipitation": "mm",
+    "total_remarkable_daily_precipitation": "mm",
+    "frequency_of_remarkable_daily_precipitation": "days",
+    "extreme_precipitation_intensity": "mm",
+    "mean_winter_temperature": "°C",
+    "mean_summer_temperature": "°C",
+    "mean_annual_temperature": "°C",
+    "number_tropical_nights": "days",
+    "maximum_summer_temperature": "°C",
+    "number_of_days_with_tx_above_30": "days",
+    "number_of_days_with_tx_above_35": "days",
+    "number_of_days_with_dry_ground": "days"
+}
+
+DRIAS_PLOT_PARAMETERS = [
+    'year',
+    'location'
+]
+
+DRIAS_INDICATOR_TO_COLORSCALE = {
+    "total_winter_precipitation": PRECIPITATION_COLORSCALE,
+    "total_summer_precipitation": PRECIPITATION_COLORSCALE,
+    "total_annual_precipitation": PRECIPITATION_COLORSCALE,
+    "total_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
+    "frequency_of_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
+    "extreme_precipitation_intensity": PRECIPITATION_COLORSCALE,
+    "mean_winter_temperature":TEMPERATURE_COLORSCALE,
+    "mean_summer_temperature":TEMPERATURE_COLORSCALE,
+    "mean_annual_temperature":TEMPERATURE_COLORSCALE,
+    "number_tropical_nights": TEMPERATURE_COLORSCALE,
+    "maximum_summer_temperature":TEMPERATURE_COLORSCALE,
+    "number_of_days_with_tx_above_30": TEMPERATURE_COLORSCALE,
+    "number_of_days_with_tx_above_35": TEMPERATURE_COLORSCALE,
+    "number_of_days_with_dry_ground": TEMPERATURE_COLORSCALE
+}
+
+DRIAS_UI_TEXT = """
+Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
+I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
+
+You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
+You can specify **location** and/or **year**.
+You can choose from a list of climate models. By default, we take the **average of each model**.
+
+For example, you can ask:
+- What will the temperature be like in Paris?
+- What will be the total rainfall in France in 2030?
+- How frequent will extreme events be in Lyon?
+
+**Example of indicators in the data**:
+- Mean temperature (annual, winter, summer)
+- Total precipitation (annual, winter, summer)
+- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
+
+⚠️ **Limitations**:
+- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
+- You can only ask about **locations in France**.
+- If you specify a year, there may be **no data for that year for some models**.
+- You **cannot compare two models**.
+
+🛈 **Information**
+Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
+"""
```
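The table name doubles as the key into the column, unit, and colorscale mappings, which is what keeps the plot layer generic. A small lookup sketch using the constants above:

```python
from climateqa.engine.talk_to_data.drias.config import (
    DRIAS_INDICATOR_COLUMNS_PER_TABLE,
    DRIAS_INDICATOR_TO_UNIT,
)

table = "number_of_tropical_nights"
column = DRIAS_INDICATOR_COLUMNS_PER_TABLE[table]  # "number_tropical_nights"
unit = DRIAS_INDICATOR_TO_UNIT[column]             # "days"
print(f"column={column}, unit={unit}")
```

Note that the mapping absorbs table/column name mismatches such as `number_of_tropical_nights` → `number_tropical_nights`.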
climateqa/engine/talk_to_data/drias/plot_informations.py (new file):

```diff
@@ -0,0 +1,88 @@
+from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
+
+def indicator_evolution_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    if "location" not in params:
+        raise ValueError('"location" must be provided in params')
+    location = params["location"]
+    return f"""
+This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
+
+It combines both historical observations and future projections according to the climate scenario RCP8.5.
+
+The x-axis represents the years, and the y-axis shows the value of the indicator ({unit}).
+
+A 10-year rolling average curve is displayed to give a better idea of the overall trend.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
+- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
+
+def indicator_number_of_days_per_year_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    if "location" not in params:
+        raise ValueError('"location" must be provided in params')
+    location = params["location"]
+    return f"""
+This plot displays a bar chart showing the yearly frequency of the climate indicator **{indicator}** in **{location}**.
+
+The x-axis represents the years, and the y-axis shows the frequency of {indicator} ({unit}) per year.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
+- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
+
+def distribution_of_indicator_for_given_year_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    year = params["year"]
+    if year is None:
+        year = 2030
+    return f"""
+This plot shows a histogram of the distribution of the climate indicator **{indicator}** across all locations for the year **{year}**.
+
+It allows you to visualize how the values of {indicator} ({unit}) are spread for a given year.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each grid point in the dataset and climate model, the value of {indicator} for the year {year} is extracted.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
+
+def map_of_france_of_indicator_for_given_year_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    year = params["year"]
+    if year is None:
+        year = 2030
+    return f"""
+This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of France for the year **{year}**.
+
+Each region is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within France for the selected year and climate model.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each region of France, the value of {indicator} in {year} and for the selected climate model is extracted and mapped to its geographic coordinates.
+- The regions correspond to 8 km squares centered on the grid points of the DRIAS dataset.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
```
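For illustration, how one of these helpers is meant to be called (the arguments are hypothetical; the return value is the Markdown caption shown under the plot):

```python
from climateqa.engine.talk_to_data.drias.plot_informations import (
    indicator_evolution_informations,
)

caption = indicator_evolution_informations(
    "mean_annual_temperature",  # must be a key of DRIAS_INDICATOR_TO_UNIT
    {"location": "Paris"},      # "location" is required; ValueError otherwise
)
print(caption)  # Markdown text with the unit "°C" interpolated
```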
climateqa/engine/talk_to_data/drias/plots.py (new file):

```diff
@@ -0,0 +1,434 @@
+import os
+import geojson
+from math import cos, radians
+from typing import Callable
+import pandas as pd
+from plotly.graph_objects import Figure
+import plotly.graph_objects as go
+from climateqa.engine.talk_to_data.drias.plot_informations import distribution_of_indicator_for_given_year_informations, indicator_evolution_informations, indicator_number_of_days_per_year_informations, map_of_france_of_indicator_for_given_year_informations
+from climateqa.engine.talk_to_data.objects.plot import Plot
+from climateqa.engine.talk_to_data.drias.queries import (
+    indicator_for_given_year_query,
+    indicator_per_year_at_location_query,
+)
+from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
+
+def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
+    side_km = 8
+    delta_lat = side_km / 111
+    features = []
+    for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
+        delta_lon = side_km / (111 * cos(radians(lat)))
+        half_lat = delta_lat / 2
+        half_lon = delta_lon / 2
+        features.append(geojson.Feature(
+            geometry=geojson.Polygon([[
+                [lon - half_lon, lat - half_lat],
+                [lon + half_lon, lat - half_lat],
+                [lon + half_lon, lat + half_lat],
+                [lon - half_lon, lat + half_lat],
+                [lon - half_lon, lat - half_lat]
+            ]]),
+            properties={"value": val},
+            id=str(idx)
+        ))
+
+    return geojson.FeatureCollection(features)
+
+def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
+    """Generates a function to plot indicator evolution over time at a location.
+
+    This function creates a line plot showing how a climate indicator changes
+    over time at a specific location. It handles temperature, precipitation,
+    and other climate indicators.
+
+    Args:
+        params (dict): Dictionary containing:
+            - indicator_column (str): The column name for the indicator
+            - location (str): The location to plot
+            - model (str): The climate model to use
+
+    Returns:
+        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+
+    Example:
+        >>> plot_func = plot_indicator_evolution_at_location({
+        ...     'indicator_column': 'mean_temperature',
+        ...     'location': 'Paris',
+        ...     'model': 'ALL'
+        ... })
+        >>> fig = plot_func(df)
+    """
+    indicator = params["indicator_column"]
+    location = params["location"]
+    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+    def plot_data(df: pd.DataFrame) -> Figure:
+        """Generates the actual plot from the data.
+
+        Args:
+            df (pd.DataFrame): DataFrame containing the data to plot
+
+        Returns:
+            Figure: A plotly Figure object showing the indicator evolution
+        """
+        fig = go.Figure()
+        if df['model'].nunique() != 1:
+            df_avg = df.groupby("year", as_index=False)[indicator].mean()
+
+            # Transform to list to avoid pandas encoding
+            indicators = df_avg[indicator].astype(float).tolist()
+            years = df_avg["year"].astype(int).tolist()
+
+            # Compute the 10-year rolling average
+            rolling_window = 10
+            sliding_averages = (
+                df_avg[indicator]
+                .rolling(window=rolling_window, min_periods=rolling_window)
+                .mean()
+                .astype(float)
+                .tolist()
+            )
+            model_label = "Model Average"
+
+            # Only add rolling average if we have enough data points
+            if len([x for x in sliding_averages if pd.notna(x)]) > 0:
+                # Sliding average dashed line
+                fig.add_scatter(
+                    x=years,
+                    y=sliding_averages,
+                    mode="lines",
+                    name="10 years rolling average",
+                    line=dict(dash="dash"),
+                    marker=dict(color="#d62728"),
+                    hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+                )
+
+        else:
+            df_model = df
+
+            # Transform to list to avoid pandas encoding
+            indicators = df_model[indicator].astype(float).tolist()
+            years = df_model["year"].astype(int).tolist()
+
+            # Compute the 10-year rolling average
+            rolling_window = 10
+            sliding_averages = (
+                df_model[indicator]
+                .rolling(window=rolling_window, min_periods=rolling_window)
+                .mean()
+                .astype(float)
+                .tolist()
+            )
+            model_label = f"Model : {df['model'].unique()[0]}"
+
+            # Only add rolling average if we have enough data points
+            if len([x for x in sliding_averages if pd.notna(x)]) > 0:
+                # Sliding average dashed line
+                fig.add_scatter(
+                    x=years,
+                    y=sliding_averages,
+                    mode="lines",
+                    name="10 years rolling average",
+                    line=dict(dash="dash"),
+                    marker=dict(color="#d62728"),
+                    hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+                )
+
+        # Indicator per year plot
+        fig.add_scatter(
+            x=years,
+            y=indicators,
+            name=f"Yearly {indicator_label}",
+            mode="lines",
+            marker=dict(color="#1f77b4"),
+            hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+        )
+        fig.update_layout(
+            title=f"Evolution of {indicator_label} in {location} ({model_label})",
+            xaxis_title="Year",
+            yaxis_title=f"{indicator_label} ({unit})",
+            template="plotly_white",
+            height=900,
+        )
+        return fig
+
+    return plot_data
+
+
+indicator_evolution_at_location: Plot = {
+    "name": "Indicator evolution at location",
+    "description": "Plot an evolution of the indicator at a certain location",
+    "params": ["indicator_column", "location", "model"],
+    "plot_function": plot_indicator_evolution_at_location,
+    "sql_query": indicator_per_year_at_location_query,
+    "plot_information": indicator_evolution_informations,
+    'short_name': 'Evolution'
+}
+
+
+def plot_indicator_number_of_days_per_year_at_location(
+    params: dict,
+) -> Callable[..., Figure]:
+    """Generates a function to plot the number of days per year for an indicator.
+
+    This function creates a bar chart showing the frequency of certain climate
+    events (like days above a temperature threshold) per year at a specific location.
+
+    Args:
+        params (dict): Dictionary containing:
+            - indicator_column (str): The column name for the indicator
+            - location (str): The location to plot
+            - model (str): The climate model to use
+
+    Returns:
+        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+    """
+    indicator = params["indicator_column"]
+    location = params["location"]
+    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+    def plot_data(df: pd.DataFrame) -> Figure:
+        """Generate the figure thanks to the dataframe
+
+        Args:
+            df (pd.DataFrame): pandas dataframe with the required data
+
+        Returns:
+            Figure: Plotly figure
+        """
+        fig = go.Figure()
+        if df['model'].nunique() != 1:
+            df_avg = df.groupby("year", as_index=False)[indicator].mean()
+
+            # Transform to list to avoid pandas encoding
+            indicators = df_avg[indicator].astype(float).tolist()
+            years = df_avg["year"].astype(int).tolist()
+            model_label = "Model Average"
+
+        else:
+            df_model = df
+            # Transform to list to avoid pandas encoding
+            indicators = df_model[indicator].astype(float).tolist()
+            years = df_model["year"].astype(int).tolist()
+            model_label = f"Model : {df['model'].unique()[0]}"
+
+
+        # Bar plot
+        fig.add_trace(
+            go.Bar(
+                x=years,
+                y=indicators,
+                width=0.5,
+                marker=dict(color="#1f77b4"),
+                hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
+            )
+        )
+
+        fig.update_layout(
+            title=f"{indicator_label} in {location} ({model_label})",
+            xaxis_title="Year",
+            yaxis_title=f"{indicator_label} ({unit})",
+            yaxis=dict(range=[0, max(indicators)]),
+            bargap=0.5,
+            height=900,
+            template="plotly_white",
+        )
+
+        return fig
+
+    return plot_data
+
+
+indicator_number_of_days_per_year_at_location: Plot = {
+    "name": "Indicator number of days per year at location",
+    "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
+    "params": ["indicator_column", "location", "model"],
+    "plot_function": plot_indicator_number_of_days_per_year_at_location,
+    "sql_query": indicator_per_year_at_location_query,
+    "plot_information": indicator_number_of_days_per_year_informations,
+    "short_name": "Yearly Frequency",
+}
+
+
+def plot_distribution_of_indicator_for_given_year(
+    params: dict,
+) -> Callable[..., Figure]:
+    """Generates a function to plot the distribution of an indicator for a year.
+
+    This function creates a histogram showing the distribution of a climate
+    indicator across different locations for a specific year.
+
+    Args:
+        params (dict): Dictionary containing:
+            - indicator_column (str): The column name for the indicator
+            - year (str): The year to plot
+            - model (str): The climate model to use
+
+    Returns:
+        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+    """
+    indicator = params["indicator_column"]
+    year = params["year"]
+    if year is None:
+        year = 2030
+    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+    def plot_data(df: pd.DataFrame) -> Figure:
+        """Generate the figure thanks to the dataframe
+
+        Args:
+            df (pd.DataFrame): pandas dataframe with the required data
+
+        Returns:
+            Figure: Plotly figure
+        """
+        fig = go.Figure()
+        if df['model'].nunique() != 1:
+            df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
+                indicator
+            ].mean()
+
+            # Transform to list to avoid pandas encoding
+            indicators = df_avg[indicator].astype(float).tolist()
+            model_label = "Model Average"
+
+        else:
+            df_model = df
+
+            # Transform to list to avoid pandas encoding
+            indicators = df_model[indicator].astype(float).tolist()
+            model_label = f"Model : {df['model'].unique()[0]}"
+
+
+        fig.add_trace(
+            go.Histogram(
+                x=indicators,
+                opacity=0.8,
+                histnorm="percent",
+                marker=dict(color="#1f77b4"),
+                hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
+            )
+        )
+
+        fig.update_layout(
+            title=f"Distribution of {indicator_label} in {year} ({model_label})",
+            xaxis_title=f"{indicator_label} ({unit})",
+            yaxis_title="Frequency (%)",
+            plot_bgcolor="rgba(0, 0, 0, 0)",
+            showlegend=False,
+            height=900,
+        )
+
+        return fig
+
+    return plot_data
+
+
+distribution_of_indicator_for_given_year: Plot = {
+    "name": "Distribution of an indicator for a given year",
+    "description": "Plot an histogram of the distribution for a given year of the values of an indicator",
+    "params": ["indicator_column", "model", "year"],
+    "plot_function": plot_distribution_of_indicator_for_given_year,
+    "sql_query": indicator_for_given_year_query,
+    "plot_information": distribution_of_indicator_for_given_year_informations,
+    'short_name': 'Distribution'
+}
+
+
+def plot_map_of_france_of_indicator_for_given_year(
+    params: dict,
+) -> Callable[..., Figure]:
+    """Generates a function to plot a map of France for an indicator.
+
+    This function creates a choropleth map of France showing the spatial
+    distribution of a climate indicator for a specific year.
+
+    Args:
+        params (dict): Dictionary containing:
+            - indicator_column (str): The column name for the indicator
+            - year (str): The year to plot
+            - model (str): The climate model to use
+
+    Returns:
+        Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+    """
+    indicator = params["indicator_column"]
+    year = params["year"]
+    if year is None:
+        year = 2030
+    indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
+    unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+    def plot_data(df: pd.DataFrame) -> Figure:
+        fig = go.Figure()
+        if df['model'].nunique() != 1:
+            df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
+                indicator
+            ].mean()
+
+            indicators = df_avg[indicator].astype(float).tolist()
+            latitudes = df_avg["latitude"].astype(float).tolist()
+            longitudes = df_avg["longitude"].astype(float).tolist()
+            model_label = "Model Average"
+
+        else:
+            df_model = df
+
+            # Transform to list to avoid pandas encoding
+            indicators = df_model[indicator].astype(float).tolist()
+            latitudes = df_model["latitude"].astype(float).tolist()
+            longitudes = df_model["longitude"].astype(float).tolist()
+            model_label = f"Model : {df['model'].unique()[0]}"
+
+
+
+        geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
+
+        fig = go.Figure(go.Choroplethmapbox(
+            geojson=geojson_data,
+            locations=[str(i) for i in range(len(indicators))],
+            featureidkey="id",
+            z=indicators,
+            colorscale=DRIAS_INDICATOR_TO_COLORSCALE[indicator],
+            zmin=min(indicators),
+            zmax=max(indicators),
+            marker_opacity=0.7,
+            marker_line_width=0,
+            colorbar_title=f"{indicator_label} ({unit})",
+            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
+            hoverinfo="text"
+        ))
+
+        fig.update_layout(
+            mapbox_style="open-street-map", # Use OpenStreetMap
+            mapbox_zoom=5,
+            height=900,
+            mapbox_center={"lat": 46.6, "lon": 2.0},
+            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
+            title=f"{indicator_label} in {year} in France ({model_label}) " # Title
+        )
+        return fig
+
+    return plot_data
+
+
+map_of_france_of_indicator_for_given_year: Plot = {
+    "name": "Map of France of an indicator for a given year",
+    "description": "Heatmap on the map of France of the values of an indicator for a given year",
+    "params": ["indicator_column", "year", "model"],
+    "plot_function": plot_map_of_france_of_indicator_for_given_year,
+    "sql_query": indicator_for_given_year_query,
+    "plot_information": map_of_france_of_indicator_for_given_year_informations,
+    'short_name': 'Map of France'
+}
+
+DRIAS_PLOTS = [
+    indicator_evolution_at_location,
+    indicator_number_of_days_per_year_at_location,
+    distribution_of_indicator_for_given_year,
+    map_of_france_of_indicator_for_given_year,
+]
```
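Each `Plot` entry pairs a SQL builder with a two-step plot function: calling it with params returns a `plot_data` closure, which is then applied to the query result. A sketch with a hypothetical DataFrame shaped like the SQL output:

```python
import pandas as pd
from climateqa.engine.talk_to_data.drias.plots import indicator_evolution_at_location

# Hypothetical query result: one row per (year, model)
df = pd.DataFrame({
    "year": [2020, 2021, 2022] * 2,
    "mean_annual_temperature": [12.1, 12.4, 12.8, 11.9, 12.2, 12.6],
    "model": ["RACMO22E_EC-EARTH"] * 3 + ["ALADIN63_CNRM-CM5"] * 3,
})

params = {"indicator_column": "mean_annual_temperature", "location": "Paris", "model": "ALL"}
plot_data = indicator_evolution_at_location["plot_function"](params)  # binds params
fig = plot_data(df)  # several models present -> averaged as "Model Average"
fig.show()
```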
climateqa/engine/talk_to_data/drias/queries.py (new file):

```diff
@@ -0,0 +1,83 @@
+from typing import TypedDict
+from climateqa.engine.talk_to_data.config import DRIAS_DATASET_URL
+
+class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
+    """Parameters for querying an indicator's values over time at a location.
+
+    This class defines the parameters needed to query climate indicator data
+    for a specific location over multiple years.
+
+    Attributes:
+        indicator_column (str): The column name for the climate indicator
+        latitude (str): The latitude coordinate of the location
+        longitude (str): The longitude coordinate of the location
+        model (str): The climate model to use (optional)
+    """
+    indicator_column: str
+    latitude: str
+    longitude: str
+    model: str
+
+def indicator_per_year_at_location_query(
+    table: str, params: IndicatorPerYearAtLocationQueryParams
+) -> str:
+    """SQL Query to get the evolution of an indicator per year at a certain location
+
+    Args:
+        table (str): sql table of the indicator
+        params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query
+
+    Returns:
+        str: the sql query
+    """
+    indicator_column = params.get("indicator_column")
+    latitude = params.get("latitude")
+    longitude = params.get("longitude")
+
+    if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
+        return ""
+
+    table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
+
+    sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
+
+    return sql_query
+
+class IndicatorForGivenYearQueryParams(TypedDict, total=False):
+    """Parameters for querying an indicator's values across locations for a year.
+
+    This class defines the parameters needed to query climate indicator data
+    across different locations for a specific year.
+
+    Attributes:
+        indicator_column (str): The column name for the climate indicator
+        year (str): The year to query
+        model (str): The climate model to use (optional)
+    """
+    indicator_column: str
+    year: str
+    model: str
+
+def indicator_for_given_year_query(
+    table:str, params: IndicatorForGivenYearQueryParams
+) -> str:
+    """SQL Query to get the values of an indicator with their latitudes, longitudes and models for a given year
+
+    Args:
+        table (str): sql table of the indicator
+        params (IndicatorForGivenYearQueryParams): dictionarry with the required params for the query
+
+    Returns:
+        str: the sql query
+    """
+    indicator_column = params.get("indicator_column")
+    year = params.get('year')
+    if year is None:
+        year = 2050
+    if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
+        return ""
+
+    table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
+
+    sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
+    return sql_query
```
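A sketch of building and running one of these queries end to end (the DuckDB call mirrors how `input_processing.py` below executes SQL; the coordinates are illustrative and, in the real workflow, come from `nearest_neighbour_sql` so they match the 8 km grid exactly):

```python
import duckdb
from climateqa.engine.talk_to_data.drias.queries import (
    indicator_per_year_at_location_query,
)

sql = indicator_per_year_at_location_query(
    "mean_annual_temperature",
    {
        "indicator_column": "mean_annual_temperature",
        "latitude": "48.85",   # illustrative; must match a grid point exactly
        "longitude": "2.35",
    },
)
if sql:  # an empty string means a required parameter was missing
    df = duckdb.sql(sql).df()
```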
@@ -0,0 +1,257 @@
from typing import Any, Literal, Optional, cast
import ast
from langchain_core.prompts import ChatPromptTemplate
from geopy.geocoders import Nominatim
from climateqa.engine.llm import get_llm
import duckdb
import os
from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
from climateqa.engine.talk_to_data.objects.location import Location
from climateqa.engine.talk_to_data.objects.plot import Plot
from climateqa.engine.talk_to_data.objects.states import State

async def detect_location_with_openai(sentence: str) -> str:
    """
    Detects locations in a sentence using OpenAI's API via LangChain.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    location_list = ast.literal_eval(response.content.strip("```python\n").strip())
    if location_list:
        return location_list[0]
    else:
        return ""

def loc_to_coords(location: str) -> tuple[float, float]:
    """Converts a location name to geographic coordinates.

    This function uses the Nominatim geocoding service to convert
    a location name (e.g., city name) to its latitude and longitude.

    Args:
        location (str): The name of the location to geocode

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude)

    Raises:
        AttributeError: If the location cannot be found
    """
    geolocator = Nominatim(user_agent="city_to_latlong")
    coords = geolocator.geocode(location)
    return (coords.latitude, coords.longitude)

def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Converts geographic coordinates to a country name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to a country name.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: A tuple containing (country_code, country_name)

    Raises:
        AttributeError: If the coordinates cannot be found
    """
    geolocator = Nominatim(user_agent="latlong_to_country")
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']

def nearest_neighbour_sql(location: tuple, mode: Literal['DRIAS', 'IPCC']) -> tuple[str, str, Optional[str]]:
    """Finds the closest available grid point to the given coordinates in the DRIAS or IPCC dataset."""
    long = round(location[1], 3)
    lat = round(location[0], 3)
    conn = duckdb.connect()

    if mode == 'DRIAS':
        table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
        results = conn.sql(
            f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
        ).fetchdf()
    else:
        table_path = f"'{IPCC_COORDINATES_PATH}'"
        results = conn.sql(
            f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
        ).fetchdf()

    if len(results) == 0:
        return "", "", ""

    if 'admin1' in results.columns:
        admin1 = results['admin1'].iloc[0]
    else:
        admin1 = None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1

async def detect_year_with_openai(sentence: str) -> str:
    """
    Detects years in a sentence using OpenAI's API via LangChain.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no years are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    years_list = ast.literal_eval(response['array'])
    if len(years_list) > 0:
        return years_list[0]
    else:
        return ""


async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Names of the candidate tables

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, decide which tables are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer with only a Python list of table names."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table names : "
    )

    table_names = ast.literal_eval(
        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
    )
    return table_names

async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    plots_description = ""
    for plot in plot_list:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the least relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    plot_names = ast.literal_eval(
        (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
    )
    return plot_names

async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    print("---- Find location in user input ----")
    location = await detect_location_with_openai(user_input)
    output: Location = {
        'location': location,
        'longitude': None,
        'latitude': None,
        'country_code': None,
        'country_name': None,
        'admin1': None
    }

    if location:
        coords = loc_to_coords(location)
        country_code, country_name = coords_to_country(coords)
        neighbour = nearest_neighbour_sql(coords, mode)
        output.update({
            "latitude": neighbour[0],
            "longitude": neighbour[1],
            "country_code": country_code,
            "country_name": country_name,
            "admin1": neighbour[2]
        })
    output = cast(Location, output)
    return output

async def find_year(user_input: str) -> str | None:
    """Extracts year information from user input using an LLM.

    This function uses an LLM to identify and extract year information from the
    user's query, which is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text

    Returns:
        str | None: The extracted year, or None if no year was found
    """
    print("---- Find year ----")
    year = await detect_year_with_openai(user_input)
    if year == "":
        return None
    return year

async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    print("---- Find relevant plots ----")
    relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
    return relevant_plots

async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    print(f"---- Find relevant tables for {plot['name']} ----")
    relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
    return relevant_tables

async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
    """Dispatches to the appropriate method to retrieve the desired parameter.

    Args:
        state (State): state of the workflow
        param_name (str): name of the desired parameter
        mode (Literal['DRIAS', 'IPCC']): dataset used for the nearest-grid-point lookup

    Returns:
        dict[str, Any] | None: the retrieved parameter(s), or None if the parameter is not supported
    """
    if param_name == 'location':
        location = await find_location(state['user_input'], mode)
        return location
    if param_name == 'year':
        year = await find_year(state['user_input'])
        return {'year': year}
    return None
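A hedged end-to-end sketch of the location-resolution chain defined above; it calls the LLM, Nominatim and DuckDB, so it only runs with network access and the project's credentials configured:

    import asyncio

    async def demo():
        # detect_location_with_openai -> loc_to_coords -> coords_to_country -> nearest_neighbour_sql
        location = await find_location("What will the temperature be like in Paris?", mode="IPCC")
        print(location)  # Location dict with the nearest grid point and country metadata

    asyncio.run(demo())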
@@ -0,0 +1,98 @@
from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL


# IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
IPCC_TABLES = [
    "mean_temperature",
    "total_precipitation",
]

IPCC_INDICATOR_COLUMNS_PER_TABLE = {
    "mean_temperature": "mean_temperature",
    "total_precipitation": "total_precipitation"
}

IPCC_INDICATOR_TO_UNIT = {
    "mean_temperature": "°C",
    "total_precipitation": "mm/day"
}

IPCC_SCENARIO = [
    "historical",
    "ssp126",
    "ssp245",
    "ssp370",
    "ssp585",
]

IPCC_MODELS = []

IPCC_PLOT_PARAMETERS = [
    'year',
    'location'
]

MACRO_COUNTRIES = [
    'JP', 'IN', 'MH', 'PT', 'ID', 'SJ', 'MX', 'CN', 'GL', 'PN',
    'AR', 'AQ', 'PF', 'BR', 'SH', 'GS', 'ZA', 'NZ', 'TF',
]

HUGE_MACRO_COUNTRIES = [
    'CL', 'CA', 'AU', 'US', 'RU'
]

IPCC_INDICATOR_TO_COLORSCALE = {
    "mean_temperature": TEMPERATURE_COLORSCALE,
    "total_precipitation": PRECIPITATION_COLORSCALE
}

IPCC_UI_TEXT = """
Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.

You can ask me anything about these climate indicators: **temperature** or **precipitation**.
You can specify **location** and/or **year**.
By default, we take the **median across climate models**.

Currently available charts:
- Yearly evolution of an indicator at a specific location (historical + SSP projections)
- Yearly spatial distribution of an indicator in a specific country

Currently available indicators:
- Mean temperature
- Total precipitation

For example, you can ask:
- What will the temperature be like in Paris?
- What will be the total rainfall in the USA in 2030?
- How will the average temperature evolve in China?

⚠️ **Limitations**:
- You can't ask anything that isn't related to **IPCC - ATLAS** data.
- You cannot ask about **several locations at the same time**.
- If you specify a year **before 1850 or after 2100**, there will be **no data**.
- You **cannot compare two models**.

🛈 **Information**
Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
"""
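A small illustration of how these mappings are consumed downstream (pure lookups, no I/O):

    table = "total_precipitation"
    indicator = IPCC_INDICATOR_COLUMNS_PER_TABLE[table]    # "total_precipitation"
    unit = IPCC_INDICATOR_TO_UNIT[indicator]               # "mm/day"
    colorscale = IPCC_INDICATOR_TO_COLORSCALE[indicator]   # Plotly colorscale used by the map plot
    print(f"{indicator} is reported in {unit}")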
@@ -0,0 +1,50 @@
from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT

def indicator_evolution_informations(
    indicator: str,
    params: dict[str, str],
) -> str:
    if "location" not in params:
        raise ValueError('"location" must be provided in params')
    location = params["location"]

    unit = IPCC_INDICATOR_TO_UNIT[indicator]
    return f"""
This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.

It combines both historical observations (from 1950 to 2015) and future projections (from 2016 to 2100) for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).

The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}).

Each line corresponds to a different scenario, allowing you to compare how {indicator} might change under various future conditions.

**Data source:**
- The data come from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
- The underlying data is retrieved by aggregating yearly values of {indicator} for the selected location, across all available scenarios. This means the system collects, for each year, the value of {indicator} in {location}, both for the historical period and for each scenario, to build the time series.
- The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
"""

def choropleth_map_informations(
    indicator: str,
    params: dict[str, str],
) -> str:
    unit = IPCC_INDICATOR_TO_UNIT[indicator]
    if "location" not in params:
        raise ValueError('"location" must be provided in params')
    location = params["location"]
    country_name = params["country_name"]
    year = params["year"]
    if year is None:
        year = 2050

    return f"""
This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of the country of **{location}** ({country_name}) for the year **{year}** and the chosen scenario.

Each grid point is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within the country for the selected year and scenario.

**Data source:**
- The data come from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
- For each grid point of the country of {location} ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
- The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
- The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
"""
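A sketch of a call to one of these helpers; the indicator and params are placeholders:

    text = indicator_evolution_informations("mean_temperature", {"location": "Paris"})
    print(text.splitlines()[1])
    # This plot shows how the climate indicator **mean_temperature** evolves over time in **Paris**.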
@@ -0,0 +1,189 @@
from typing import Callable
from plotly.graph_objects import Figure
import plotly.graph_objects as go
import pandas as pd
import geojson

from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
from climateqa.engine.talk_to_data.objects.plot import Plot

def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
    # Build one 1-degree square polygon per grid point, carrying the indicator value.
    features = [
        geojson.Feature(
            geometry=geojson.Polygon([[
                [lon - 0.5, lat - 0.5],
                [lon + 0.5, lat - 0.5],
                [lon + 0.5, lat + 0.5],
                [lon - 0.5, lat + 0.5],
                [lon - 0.5, lat - 0.5]
            ]]),
            properties={"value": val},
            id=str(idx)
        )
        for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators))
    ]

    geojson_data = geojson.FeatureCollection(features)
    return geojson_data

def plot_indicator_evolution_at_location_historical_and_projections(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """
    Returns a function that generates a line plot showing the evolution of a climate indicator
    (e.g., temperature, rainfall) over time at a specific location, including both historical data
    and future projections for different climate scenarios.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - location (str): Location (e.g., country, city) for which to plot the indicator.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
        showing the indicator's evolution over time, with scenario lines and historical data.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        df = df.sort_values(by='year')
        years = df['year'].astype(int).tolist()
        indicators = df[indicator].astype(float).tolist()
        scenarios = df['scenario'].astype(str).tolist()

        # Find the last historical value for continuity
        last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
        last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)

        fig = go.Figure()
        for scenario in IPCC_SCENARIO:
            x = [y for y, s in zip(years, scenarios) if s == scenario]
            y = [v for v, s in zip(indicators, scenarios) if s == scenario]
            # Connect historical to scenario
            if scenario != 'historical' and last_historical_indicator is not None:
                x = [last_historical_year] + x
                y = [last_historical_indicator] + y
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                name=scenario
            ))

        fig.update_layout(
            title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
            xaxis_title='Year',
            yaxis_title=f'{indicator_label} ({unit})',
            legend_title='Scenario',
            height=800,
        )
        return fig

    return plot_data

indicator_evolution_at_location_historical_and_projections: Plot = {
    "name": "Indicator Evolution at Location (Historical + Projections)",
    "description": (
        "Shows how a climate indicator (e.g., rainfall, temperature) changes over time at a specific location, "
        "including historical data and future projections. "
        "Useful for questions about the value or trend of an indicator at a location for any year, "
        "such as 'What will be the total rainfall in China in 2050?' or 'How does rainfall evolve in China over time?'. "
        "Parameters: indicator_column (the climate variable), location (e.g., country, city)."
    ),
    "params": ["indicator_column", "location"],
    "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
    "sql_query": indicator_per_year_at_location_query,
    "plot_information": indicator_evolution_informations,
    "short_name": "Evolution"
}

def plot_choropleth_map_of_country_indicator_for_specific_year(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """
    Returns a function that generates a choropleth map (heatmap) showing the spatial distribution
    of a climate indicator (e.g., temperature, rainfall) across all regions of a country for a specific year.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - year (str or int, optional): Year for which to plot the indicator (default: 2050).
            - country_name (str): Name of the country.
            - location (str): Location (country or region) for the map.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
        showing the indicator's spatial distribution as a choropleth map for the specified year.
    """
    indicator = params["indicator_column"]
    year = params.get('year')
    if year is None:
        year = 2050
    country_name = params['country_name']
    location = params['location']
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        indicators = df[indicator].astype(float).tolist()
        latitudes = df["latitude"].astype(float).tolist()
        longitudes = df["longitude"].astype(float).tolist()

        geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)

        fig = go.Figure(go.Choroplethmapbox(
            geojson=geojson_data,
            locations=[str(i) for i in range(len(indicators))],
            featureidkey="id",
            z=indicators,
            colorscale=IPCC_INDICATOR_TO_COLORSCALE[indicator],
            zmin=min(indicators),
            zmax=max(indicators),
            marker_opacity=0.7,
            marker_line_width=0,
            colorbar_title=f"{indicator_label} ({unit})",
            text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # Hover text showing the indicator value
            hoverinfo="text"
        ))

        fig.update_layout(
            mapbox_style="open-street-map",
            mapbox_zoom=2,
            height=800,
            mapbox_center={
                "lat": latitudes[len(latitudes) // 2] if latitudes else 0,
                "lon": longitudes[len(longitudes) // 2] if longitudes else 0
            },
            coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
            title=f"{indicator_label} in {year} in {location} ({country_name})"
        )
        return fig

    return plot_data

choropleth_map_of_country_indicator_for_specific_year: Plot = {
    "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
    "description": (
        "Displays a map showing the spatial distribution of a climate indicator (e.g., rainfall, temperature) "
        "across all regions of a country for a specific year. "
        "Can answer questions about the value of an indicator in a country or region for a given year, "
        "such as 'What will be the total rainfall in China in 2050?' or 'How is rainfall distributed across China in 2050?'. "
        "Parameters: indicator_column (the climate variable), year, location (country name)."
    ),
    "params": ["indicator_column", "year", "location"],
    "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
    "sql_query": indicator_for_given_year_query,
    "plot_information": choropleth_map_informations,
    "short_name": "Map",
}

IPCC_PLOTS = [
    indicator_evolution_at_location_historical_and_projections,
    choropleth_map_of_country_indicator_for_specific_year
]
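A hedged sketch of how a Plot entry is wired together at runtime; the params are placeholders and `execute_sql_query` is the DuckDB helper that appears later in this diff:

    plot = indicator_evolution_at_location_historical_and_projections
    params = {"indicator_column": "mean_temperature", "location": "France",
              "latitude": "48.5", "longitude": "2.5", "country_code": "FR", "admin1": None}
    sql = plot["sql_query"]("mean_temperature", params)  # builds the SQL string
    # df = await execute_sql_query(sql)                  # fetches a DataFrame from the parquet files
    # fig = plot["plot_function"](params)(df)            # returns the Plotly Figure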
@@ -0,0 +1,143 @@
from typing import TypedDict, Optional

from climateqa.engine.talk_to_data.ipcc.config import HUGE_MACRO_COUNTRIES, MACRO_COUNTRIES
from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL

class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
    """
    Parameters for querying the evolution of an indicator per year at a specific location.

    Attributes:
        indicator_column (str): Name of the climate indicator column.
        latitude (str): Latitude of the location.
        longitude (str): Longitude of the location.
        country_code (str): Country code.
        admin1 (str): Administrative region (optional).
    """
    indicator_column: str
    latitude: str
    longitude: str
    country_code: str
    admin1: Optional[str]

def indicator_per_year_at_location_query(
    table: str, params: IndicatorPerYearAtLocationQueryParams
) -> str:
    """
    Builds an SQL query to get the evolution of an indicator per year at a specific location.

    Args:
        table (str): SQL table of the indicator.
        params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.

    Returns:
        str: The SQL query string, or an empty string if required parameters are missing.
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")
    country_code = params.get("country_code")
    admin1 = params.get("admin1")

    if not all([indicator_column, latitude, longitude, country_code]):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    elif country_code in HUGE_MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT year, scenario, {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
        ORDER BY year, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
        WITH medians_per_month AS (
            SELECT year, scenario, month, MEDIAN({indicator_column}) AS median_value
            FROM {table_path}
            WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
            GROUP BY scenario, year, month
        )
        SELECT year, scenario, AVG(median_value) AS {indicator_column}
        FROM medians_per_month
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    return sql_query.strip()

class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """
    Parameters for querying an indicator's values across locations for a specific year.

    Attributes:
        indicator_column (str): The column name for the climate indicator.
        year (str): The year to query.
        country_code (str): The country code.
    """
    indicator_column: str
    year: str
    country_code: str

def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """
    Builds an SQL query to get the values of an indicator with their latitudes, longitudes,
    and scenarios for a given year.

    Args:
        table (str): SQL table of the indicator.
        params (IndicatorForGivenYearQueryParams): Dictionary with the required params for the query.

    Returns:
        str: The SQL query string, or an empty string if required parameters are missing.
    """
    indicator_column = params.get("indicator_column")
    year = params.get("year") or 2050
    country_code = params.get("country_code")

    if not all([indicator_column, year, country_code]):
        return ""

    if country_code in MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
        FROM {table_path}
        WHERE year = {year}
        GROUP BY latitude, longitude, scenario
        ORDER BY latitude, longitude, scenario
        """
    elif country_code in HUGE_MACRO_COUNTRIES:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
        sql_query = f"""
        SELECT latitude, longitude, scenario, {indicator_column}
        FROM {table_path}
        WHERE year = {year}
        ORDER BY latitude, longitude, scenario
        """
    else:
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
        WITH medians_per_month AS (
            SELECT latitude, longitude, scenario, month, MEDIAN({indicator_column}) AS median_value
            FROM {table_path}
            WHERE year = {year}
            GROUP BY latitude, longitude, scenario, month
        )
        SELECT latitude, longitude, scenario, AVG(median_value) AS {indicator_column}
        FROM medians_per_month
        GROUP BY latitude, longitude, scenario
        ORDER BY latitude, longitude, scenario
        """

    return sql_query.strip()
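For a country that appears in neither macro list, the builder emits the median-across-models query; a sketch with placeholder coordinates:

    params = {"indicator_column": "mean_temperature", "latitude": "48.5",
              "longitude": "2.5", "country_code": "FR"}
    sql = indicator_per_year_at_location_query("mean_temperature", params)
    # -> the WITH medians_per_month ... query above, pointed at '<IPCC_DATASET_URL>/mean_temperature/FR.parquet'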
@@ -1,44 +1,70 @@
-from climateqa.engine.talk_to_data.
-from climateqa.engine.
+from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
+from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
 from climateqa.logging import log_drias_interaction_to_huggingface
-import ast
 
-def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
-    """Adds table names to the SQL query result rows using LLM.
-
-    This function
+async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
+    """Main function to process a DRIAS query and return results.
+
+    This function orchestrates the DRIAS workflow, processing a user query to generate
+    SQL queries, dataframes, and visualizations. It handles multiple results and allows
+    pagination through them.
 
     Args:
+        query (str): The user's question about climate data
+        index_state (int, optional): The index of the result to return. Defaults to 0.
 
     Returns:
+        tuple: A tuple containing:
+            - sql_query (str): The SQL query used
+            - dataframe (pd.DataFrame): The resulting data
+            - figure (Callable): Function to generate the visualization
+            - sql_queries (list): All generated SQL queries
+            - result_dataframes (list): All resulting dataframes
+            - figures (list): All figure generation functions
+            - index_state (int): Current result index
+            - table_list (list): List of table names used
+            - error (str): Error message if any
     """
-
-    This function analyzes a SQL query to identify which columns are being selected
-    in the result set.
-
-    Returns:
-        list[str]: A list of column names being selected in the query
-    """
-    columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
-    columns_list = ast.literal_eval(columns.strip("```python\n").strip())
-    return columns_list
+    final_state = await drias_workflow(query)
+    sql_queries = []
+    result_dataframes = []
+    figures = []
+    plot_title_list = []
+    plot_informations = []
+
+    for output_title, output in final_state['outputs'].items():
+        if output['status'] == 'OK':
+            if output['table'] is not None:
+                plot_title_list.append(output_title)
+
+            if output['plot_information'] is not None:
+                plot_informations.append(output['plot_information'])
+
+            if output['sql_query'] is not None:
+                sql_queries.append(output['sql_query'])
+
+            if output['dataframe'] is not None:
+                result_dataframes.append(output['dataframe'])
+            if output['figure'] is not None:
+                figures.append(output['figure'])
+
+    if "error" in final_state and final_state["error"] != "":
+        # No SQL query, no dataframe, no figure, no plot information; empty lists; index_state = 0; empty table list; error message
+        return None, None, None, None, [], [], [], 0, [], final_state["error"]
+
+    sql_query = sql_queries[index_state]
+    dataframe = result_dataframes[index_state]
+    figure = figures[index_state](dataframe)
+    plot_information = plot_informations[index_state]
+
+    log_drias_interaction_to_huggingface(query, sql_query, user_id)
+
+    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
+
+
+async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
     """Main function to process an IPCC query and return results.
 
     This function orchestrates the IPCC workflow, processing a user query to generate
@@ -61,58 +87,38 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
-        table_list (list): List of table names used
-        error (str): Error message if any
     """
-    final_state = await
+    final_state = await ipcc_workflow(query)
     sql_queries = []
     result_dataframes = []
     figures = []
+    plot_title_list = []
+    plot_informations = []
 
-    for
-    if
+    for output_title, output in final_state['outputs'].items():
+        if output['status'] == 'OK':
+            if output['table'] is not None:
+                plot_title_list.append(output_title)
+
+            if output['plot_information'] is not None:
+                plot_informations.append(output['plot_information'])
+
+            if output['sql_query'] is not None:
+                sql_queries.append(output['sql_query'])
+
+            if output['dataframe'] is not None:
+                result_dataframes.append(output['dataframe'])
+            if output['figure'] is not None:
+                figures.append(output['figure'])
 
     if "error" in final_state and final_state["error"] != "":
+        # No SQL query, no dataframe, no figure, no plot information; empty lists; index_state = 0; empty table list; error message
+        return None, None, None, None, [], [], [], 0, [], final_state["error"]
 
     sql_query = sql_queries[index_state]
     dataframe = result_dataframes[index_state]
    figure = figures[index_state](dataframe)
+    plot_information = plot_informations[index_state]
 
     log_drias_interaction_to_huggingface(query, sql_query, user_id)
 
-    return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state,
+    return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
-
-# def ask_vanna(vn,db_vanna_path, query):
-
-#     try :
-#         location = detect_location_with_openai(query)
-#         if location:
-
-#             coords = loc2coords(location)
-#             user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
-
-#             relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
-#             coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
-#             user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
-
-#             sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
-
-#             return sql_query, result_dataframe, figure
-#         else :
-#             empty_df = pd.DataFrame()
-#             empty_fig = None
-#             return "", empty_df, empty_fig
-#     except Exception as e:
-#         print(f"Error: {e}")
-#         empty_df = pd.DataFrame()
-#         empty_fig = None
-#         return "", empty_df, empty_fig
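A hedged usage sketch of the paginated API; the question is illustrative and the call only runs with the workflow's backing services (LLM, parquet access, logging) configured:

    import asyncio

    (sql, df, fig, info,
     sqls, dfs, figs, infos,
     idx, titles, err) = asyncio.run(ask_drias("How will rainfall evolve in Toulouse?", index_state=0))
    # Calling again with index_state=1 selects the second (query, dataframe, figure) triple from the same lists.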
@@ -0,0 +1,13 @@
from typing import Annotated, TypedDict


class ArrayOutput(TypedDict):
    """Represents the output of a function that returns an array.

    This class is used to type-hint functions that return arrays,
    ensuring consistent return types across the codebase.

    Attributes:
        array (str): A syntactically valid Python array string
    """
    array: Annotated[str, "Syntactically valid python array."]
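ArrayOutput is the schema handed to `llm.with_structured_output` in `detect_year_with_openai` above; a minimal sketch of the parsing step, with a hard-coded stand-in for the model response:

    import ast

    response: ArrayOutput = {"array": "['2050']"}  # stand-in for the structured LLM output
    years = ast.literal_eval(response["array"])
    print(years[0])  # 2050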
@@ -0,0 +1,12 @@
from typing import Optional, TypedDict


class Location(TypedDict):
    location: str
    latitude: Optional[str]
    longitude: Optional[str]
    country_code: Optional[str]
    country_name: Optional[str]
    admin1: Optional[str]
@@ -0,0 +1,23 @@
from typing import Callable, TypedDict
from plotly.graph_objects import Figure

class Plot(TypedDict):
    """Represents a plot configuration in the DRIAS system.

    This class defines the structure for configuring different types of plots
    that can be generated from climate data.

    Attributes:
        name (str): The name of the plot type
        description (str): A description of what the plot shows
        params (list[str]): List of required parameters for the plot
        plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
        sql_query (Callable[..., str]): Function to generate the SQL query for the plot
        plot_information (Callable[..., str]): Function to generate the plot's explanation text
        short_name (str): A short display name for the plot
    """
    name: str
    description: str
    params: list[str]
    plot_function: Callable[..., Callable[..., Figure]]
    sql_query: Callable[..., str]
    plot_information: Callable[..., str]
    short_name: str
@@ -0,0 +1,19 @@
from typing import Any, Callable, Optional, TypedDict
from plotly.graph_objects import Figure
import pandas as pd
from climateqa.engine.talk_to_data.objects.plot import Plot

class TTDOutput(TypedDict):
    status: str
    plot: Plot
    table: str
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    plot_information: Optional[str]

class State(TypedDict):
    user_input: str
    plots: list[str]
    outputs: dict[str, TTDOutput]
    error: Optional[str]
@@ -0,0 +1,44 @@
query_prompt_template = """You are an expert SQL query generator. Given an input question, database schema, SQL dialect and relevant tables to answer the question, generate an optimized and syntactically correct SQL query which can provide useful insights to the question.

### Instructions:
1. **Use only relevant tables**: The following tables are relevant to answering the question: {relevant_tables}. Do not use any other tables.
2. **Relevant columns only**: Never select `*`. Only include necessary columns based on the input question.
3. **Schema Awareness**:
- Use only columns present in the given schema.
- **If a column name appears in multiple tables, always use the format `table_name.column_name` to avoid ambiguity.**
- Select only the columns which are insightful for the question.
4. **Dialect Compliance**: Follow `{dialect}` syntax rules.
5. **Ordering**: Order the results by a relevant column if applicable (e.g., timestamp for recent records).
6. **Valid query**: Make sure the query is syntactically and functionally correct.
7. **Conditions**: For the common columns, the same condition should be applied to all the tables (e.g. latitude, longitude, model, year...)
8. **Join tables**: If you need to join tables, join them on the year feature.
9. **Model**: For each table, add a condition requiring the model to be equal to {model}

### Provided Database Schema:
{table_info}

### Relevant Tables:
{relevant_tables}

**Question:** {input}

**SQL Query:**"""

plot_prompt_template = """You are a data visualization expert. Given an input question and an SQL Query, generate an insightful plot according to the question.

### Instructions
1. **Use only the column names provided**. The data will be provided as a Pandas DataFrame `df` with the columns present in the SELECT.
2. Generate the Python Plotly code to chart the results using `df` and the column names.
3. Make as complete a graph as possible to answer the question, and make it as easy to understand as possible.
4. **Respond with only Python code**. Do not answer with any explanations -- just the code.
5. **Specific cases**:
- For a question about the evolution of something, it is also relevant to plot a sliding average (over a period of 20 years, for example) alongside the data.

### SQL Query:
{sql_query}

**Question:** {input}

**Python code:**
"""
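A sketch of how the SQL template is typically filled before being sent to the LLM; every value below is a placeholder:

    prompt = query_prompt_template.format(
        relevant_tables=["mean_annual_temperature"],  # placeholder table list
        dialect="duckdb",
        model="<model-name>",                         # placeholder model identifier
        table_info="mean_annual_temperature(year, latitude, longitude, model, ...)",  # placeholder schema
        input="How does the mean annual temperature evolve in Toulouse?",
    )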
@@ -0,0 +1,57 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import duckdb
import pandas as pd
import os

def find_indicator_column(table: str, indicator_columns_per_table: dict[str, str]) -> str:
    """Retrieves the name of the indicator column within a table.

    This function maps table names to their corresponding indicator columns
    using the mapping provided in indicator_columns_per_table.

    Args:
        table (str): Name of the table in the database
        indicator_columns_per_table (dict[str, str]): Mapping from table name to indicator column

    Returns:
        str: Name of the indicator column for the specified table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    return indicator_columns_per_table[table]

async def execute_sql_query(sql_query: str) -> pd.DataFrame:
    """Executes a SQL query on the DRIAS database and returns the results.

    This function connects to the DuckDB database containing DRIAS climate data
    and executes the provided SQL query. It handles the database connection and
    returns the results as a pandas DataFrame.

    Args:
        sql_query (str): The SQL query to execute

    Returns:
        pd.DataFrame: A DataFrame containing the query results

    Raises:
        duckdb.Error: If there is an error executing the SQL query
    """
    def _execute_query():
        # Execute the query
        con = duckdb.connect()
        HF_TTD_TOKEN = os.getenv("HF_TTD_TOKEN")
        con.execute(f"""CREATE SECRET hf_token (
            TYPE huggingface,
            TOKEN '{HF_TTD_TOKEN}'
        );""")
        results = con.execute(sql_query).fetchdf()
        # Return the fetched data
        return results

    # Run the query in a thread pool to avoid blocking
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor() as executor:
        return await loop.run_in_executor(executor, _execute_query)
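A hedged usage sketch; the Hugging Face secret only matters when the query touches gated parquet files, and HF_TTD_TOKEN must be set in the environment for those:

    import asyncio

    df = asyncio.run(execute_sql_query("SELECT 1 AS answer"))  # trivial placeholder query
    print(df)  # single-row DataFrame with column 'answer'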
@@ -0,0 +1,27 @@
TEMPERATURE_COLORSCALE = [
    [0.0, "rgb(5, 48, 97)"],
    [0.10, "rgb(33, 102, 172)"],
    [0.20, "rgb(67, 147, 195)"],
    [0.30, "rgb(146, 197, 222)"],
    [0.40, "rgb(209, 229, 240)"],
    [0.50, "rgb(247, 247, 247)"],
    [0.60, "rgb(253, 219, 199)"],
    [0.75, "rgb(244, 165, 130)"],
    [0.85, "rgb(214, 96, 77)"],
    [0.90, "rgb(178, 24, 43)"],
    [1.0, "rgb(103, 0, 31)"]
]

PRECIPITATION_COLORSCALE = [
    [0.0, "rgb(84, 48, 5)"],
    [0.10, "rgb(140, 81, 10)"],
    [0.20, "rgb(191, 129, 45)"],
    [0.30, "rgb(223, 194, 125)"],
    [0.40, "rgb(246, 232, 195)"],
    [0.50, "rgb(245, 245, 245)"],
    [0.60, "rgb(199, 234, 229)"],
    [0.75, "rgb(128, 205, 193)"],
    [0.85, "rgb(53, 151, 143)"],
    [0.90, "rgb(1, 102, 94)"],
    [1.0, "rgb(0, 60, 48)"]
]
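Both lists follow Plotly's `[position, color]` colorscale convention, with positions running from 0.0 to 1.0; a minimal sketch applying one:

    import plotly.graph_objects as go
    from climateqa.engine.talk_to_data.ui_config import TEMPERATURE_COLORSCALE

    fig = go.Figure(go.Heatmap(z=[[10.0, 12.0], [14.0, 18.0]], colorscale=TEMPERATURE_COLORSCALE))
    # fig.show()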
@@ -0,0 +1,13 @@
from dotenv import load_dotenv
from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
from vanna.openai import OpenAI_Chat
import os

load_dotenv()

OPENAI_API_KEY = os.getenv('THEO_API_KEY')

class MyVanna(MyCustomVectorDB, OpenAI_Chat):
    def __init__(self, config=None):
        MyCustomVectorDB.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)
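A hedged instantiation sketch: the `pc_api_key`, `index_name` and `top_k` keys follow the `MyCustomVectorDB` docstring in the next file, `api_key`/`model` are the usual `OpenAI_Chat` config keys, and every value here is a placeholder:

    vn = MyVanna(config={
        "api_key": OPENAI_API_KEY,                    # OpenAI key loaded above
        "model": "<openai-model-name>",               # placeholder
        "pc_api_key": os.getenv("PINECONE_API_KEY"),  # placeholder env var
        "index_name": "<pinecone-index>",             # placeholder
        "top_k": 2,
    })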
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from vanna.base import VannaBase
+from pinecone import Pinecone
+from climateqa.engine.embeddings import get_embeddings_function
+import pandas as pd
+import hashlib
+
+class MyCustomVectorDB(VannaBase):
+
+    """
+    VectorDB class for storing and retrieving vectors from Pinecone.
+
+    Args:
+        config (dict): Configuration dictionary containing the Pinecone API key and the index name:
+            - pc_api_key (str): Pinecone API key
+            - index_name (str): Pinecone index name
+            - top_k (int): Number of top results to return (default = 2)
+    """
+
+    def __init__(self, config):
+        super().__init__(config=config)
+        try:
+            self.api_key = config.get('pc_api_key')
+            self.index_name = config.get('index_name')
+        except AttributeError:
+            raise Exception("Please provide the Pinecone API key and the index name")
+
+        self.pc = Pinecone(api_key=self.api_key)
+        self.index = self.pc.Index(self.index_name)
+        self.top_k = config.get('top_k', 2)
+        self.embeddings = get_embeddings_function()
+
+    def check_embedding(self, id, namespace):
+        fetched = self.index.fetch(ids=[id], namespace=namespace)
+        if fetched['vectors'] == {}:
+            return False
+        return True
+
+    def generate_hash_id(self, data: str) -> str:
+        """
+        Generate a unique hash ID for the given data.
+
+        Args:
+            data (str): The input data to hash (e.g., a concatenated string of user attributes).
+
+        Returns:
+            str: A unique hash ID as a hexadecimal string.
+        """
+        data_bytes = data.encode('utf-8')
+        hash_object = hashlib.sha256(data_bytes)
+        hash_id = hash_object.hexdigest()
+        return hash_id
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        id = self.generate_hash_id(ddl) + '_ddl'
+
+        if self.check_embedding(id, 'ddl'):
+            print(f"DDL having id {id} already exists")
+            return id
+
+        self.index.upsert(
+            vectors=[(id, self.embeddings.embed_query(ddl), {'ddl': ddl})],
+            namespace='ddl'
+        )
+
+        return id
+
+    def add_documentation(self, doc: str, **kwargs) -> str:
+        id = self.generate_hash_id(doc) + '_doc'
+
+        if self.check_embedding(id, 'documentation'):
+            print(f"Documentation having id {id} already exists")
+            return id
+
+        self.index.upsert(
+            vectors=[(id, self.embeddings.embed_query(doc), {'doc': doc})],
+            namespace='documentation'
+        )
+
+        return id
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        id = self.generate_hash_id(question) + '_sql'
+
+        if self.check_embedding(id, 'question_sql'):
+            print(f"Question-SQL pair having id {id} already exists")
+            return id
+
+        self.index.upsert(
+            vectors=[(id, self.embeddings.embed_query(question + sql), {'question': question, 'sql': sql})],
+            namespace='question_sql'
+        )
+
+        return id
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        res = self.index.query(
+            vector=self.embeddings.embed_query(question),
+            top_k=self.top_k,
+            namespace='ddl',
+            include_metadata=True
+        )
+        return [match['metadata']['ddl'] for match in res['matches']]
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        res = self.index.query(
+            vector=self.embeddings.embed_query(question),
+            top_k=self.top_k,
+            namespace='documentation',
+            include_metadata=True
+        )
+        return [match['metadata']['doc'] for match in res['matches']]
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        res = self.index.query(
+            vector=self.embeddings.embed_query(question),
+            top_k=self.top_k,
+            namespace='question_sql',
+            include_metadata=True
+        )
+        return [(match['metadata']['question'], match['metadata']['sql']) for match in res['matches']]
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        list_of_data = []
+        namespaces = ['ddl', 'documentation', 'question_sql']
+
+        for namespace in namespaces:
+            data = self.index.query(
+                top_k=10000,
+                namespace=namespace,
+                include_metadata=True,
+                include_values=False
+            )
+            for match in data['matches']:
+                list_of_data.append(match['metadata'])
+
+        return pd.DataFrame(list_of_data)
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        # Delete from the same namespaces the add_* methods upsert into
+        # (the original used self.Index, which does not exist, and
+        # namespaces "_ddl"/"_sql"/"_doc" that never match the upserts).
+        if id.endswith("_ddl"):
+            self.index.delete(ids=[id], namespace="ddl")
+            return True
+        if id.endswith("_sql"):
+            self.index.delete(ids=[id], namespace="question_sql")
+            return True
+        if id.endswith("_doc"):
+            self.index.delete(ids=[id], namespace="documentation")
+            return True
+        return False
+
+    def generate_embedding(self, text, **kwargs):
+        # Not needed here: embeddings are computed via self.embeddings
+        pass
+
+    def get_sql_prompt(
+        self,
+        initial_prompt: str,
+        question: str,
+        question_sql_list: list,
+        ddl_list: list,
+        doc_list: list,
+        **kwargs,
+    ):
+        """
+        Example:
+        ```python
+        vn.get_sql_prompt(
+            question="What are the top 10 customers by sales?",
+            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
+            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
+            doc_list=["The customers table contains information about customers and their sales."],
+        )
+        ```
+
+        This method is used to generate a prompt for the LLM to generate SQL.
+
+        Args:
+            question (str): The question to generate SQL for.
+            question_sql_list (list): A list of questions and their corresponding SQL statements.
+            ddl_list (list): A list of DDL statements.
+            doc_list (list): A list of documentation.
+
+        Returns:
+            any: The prompt for the LLM to generate SQL.
+        """
+
+        if initial_prompt is None:
+            initial_prompt = f"You are a {self.dialect} expert. " + \
+                "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
+
+        initial_prompt = self.add_ddl_to_prompt(
+            initial_prompt, ddl_list, max_tokens=self.max_tokens
+        )
+
+        if self.static_documentation != "":
+            doc_list.append(self.static_documentation)
+
+        initial_prompt = self.add_documentation_to_prompt(
+            initial_prompt, doc_list, max_tokens=self.max_tokens
+        )
+
+        initial_prompt += (
+            "===Response Guidelines \n"
+            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
+            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
+            "3. If the provided context is insufficient, please give a SQL query based on your knowledge and the context provided. \n"
+            "4. Please use the most relevant table(s). \n"
+            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
+            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
+            "7. Add a description of the table in the result of the SQL query, if relevant. \n"
+            "8. Make sure to include the relevant KPI in the SQL query. The query should return impactful data. \n"
+        )
+
+        message_log = [self.system_message(initial_prompt)]
+
+        for example in question_sql_list:
+            if example is None:
+                print("example is None")
+            elif "question" in example and "sql" in example:
+                message_log.append(self.user_message(example["question"]))
+                message_log.append(self.assistant_message(example["sql"]))
+
+        message_log.append(self.user_message(question))
+
+        return message_log
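A minimal sketch (not from the PR) of training and querying the store through the combined class above; the config values, DDL, documentation, and question are illustrative placeholders:

```python
# Hypothetical setup reusing the MyVanna class defined earlier.
vn = MyVanna(config={
    "pc_api_key": "...",        # Pinecone key (placeholder)
    "index_name": "ttd-index",  # placeholder index name
    "api_key": "...",           # OpenAI key (placeholder)
})

# Each add_* call hashes its input, skips it when the id already
# exists in the namespace, and upserts the embedding otherwise.
vn.add_ddl("CREATE TABLE mean_annual_temperature (year INT, value DECIMAL)")
vn.add_documentation("Temperatures are in degrees Celsius, averaged per year.")
vn.add_question_sql(
    question="How warm was 2020 on average?",
    sql="SELECT value FROM mean_annual_temperature WHERE year = 2020",
)

# Retrieval pulls the top_k nearest neighbours from each namespace.
print(vn.get_related_ddl("How will the mean temperature evolve?"))
```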
@@ -0,0 +1,163 @@
+import os
+
+from typing import Any
+import asyncio
+from climateqa.engine.llm import get_llm
+from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
+from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
+from climateqa.engine.talk_to_data.objects.plot import Plot
+from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
+from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
+from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
+
+async def process_output(
+    output_title: str,
+    table: str,
+    plot: Plot,
+    params: dict[str, Any]
+) -> tuple[str, TTDOutput, dict[str, bool]]:
+    """
+    Processes a table for a given plot and parameters: builds the SQL query, executes it,
+    and generates the corresponding figure.
+
+    Args:
+        output_title (str): Title for the output (used as key in the outputs dict).
+        table (str): The name of the table to process.
+        plot (Plot): The plot object containing the SQL query builder and visualization function.
+        params (dict[str, Any]): Parameters used for querying the table.
+
+    Returns:
+        tuple: (output_title, results dict, errors dict)
+    """
+    results: TTDOutput = {
+        'status': 'OK',
+        'plot': plot,
+        'table': table,
+        'sql_query': None,
+        'dataframe': None,
+        'figure': None,
+        'plot_information': None
+    }
+    errors = {
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+
+    # Find the indicator column for this table
+    indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
+    if indicator_column:
+        params['indicator_column'] = indicator_column
+
+    # Build the SQL query
+    sql_query = plot['sql_query'](table, params)
+    if not sql_query:
+        results['status'] = 'ERROR'
+        return output_title, results, errors
+
+    results['plot_information'] = plot['plot_information'](table, params)
+
+    results['sql_query'] = sql_query
+    errors['have_sql_query'] = True
+
+    # Execute the SQL query
+    df = await execute_sql_query(sql_query)
+    if df is not None and len(df) > 0:
+        results['dataframe'] = df
+        errors['have_dataframe'] = True
+    else:
+        results['status'] = 'NO_DATA'
+
+    # Generate the figure (always, even if df is empty, for consistency)
+    results['figure'] = plot['plot_function'](params)
+
+    return output_title, results, errors
+
+async def drias_workflow(user_input: str) -> State:
+    """
+    Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.
+
+    Args:
+        user_input (str): The user's question.
+
+    Returns:
+        State: Final state with all results and error messages if any.
+    """
+    state: State = {
+        'user_input': user_input,
+        'plots': [],
+        'outputs': {},
+        'error': ''
+    }
+
+    llm = get_llm(provider="openai")
+    plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
+
+    if not plots:
+        state['error'] = 'There is no plot that can answer the question'
+        return state
+
+    plots = plots[:2]  # limit to 2 types of plots
+    state['plots'] = plots
+
+    errors = {
+        'have_relevant_table': False,
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+    outputs = {}
+
+    # Find relevant tables for each plot and prepare outputs
+    for plot_name in plots:
+        plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
+        if plot is None:
+            continue
+
+        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
+        if relevant_tables:
+            errors['have_relevant_table'] = True
+
+        for table in relevant_tables:
+            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
+            outputs[output_title] = {
+                'table': table,
+                'plot': plot,
+                'status': 'OK'
+            }
+
+    # Gather all required parameters
+    params = {}
+    for param_name in DRIAS_PLOT_PARAMETERS:
+        param = await find_param(state, param_name, mode='DRIAS')
+        if param:
+            params.update(param)
+
+    # Process all outputs in parallel using process_output
+    tasks = [
+        process_output(output_title, output['table'], output['plot'], params.copy())
+        for output_title, output in outputs.items()
+    ]
+    results = await asyncio.gather(*tasks)
+
+    # Update outputs with results and error flags
+    for output_title, task_results, task_errors in results:
+        outputs[output_title]['sql_query'] = task_results['sql_query']
+        outputs[output_title]['dataframe'] = task_results['dataframe']
+        outputs[output_title]['figure'] = task_results['figure']
+        outputs[output_title]['plot_information'] = task_results['plot_information']
+        outputs[output_title]['status'] = task_results['status']
+        errors['have_sql_query'] |= task_errors['have_sql_query']
+        errors['have_dataframe'] |= task_errors['have_dataframe']
+
+    state['outputs'] = outputs
+
+    # Set error messages if needed
+    if not errors['have_relevant_table']:
+        state['error'] = "There is no relevant table in our database to answer your question"
+    elif not errors['have_sql_query']:
+        state['error'] = "There is no relevant SQL query on our database that can help answer your question"
+    elif not errors['have_dataframe']:
+        state['error'] = "There is no data in our tables that can answer your question"
+
+    return state
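A minimal sketch (not from the PR) of driving the workflow end to end from a script; the question is illustrative:

```python
import asyncio

async def main():
    state = await drias_workflow("How will winter temperature evolve in Paris?")
    if state['error']:
        print(state['error'])
        return
    for title, output in state['outputs'].items():
        # Each output carries its status, SQL query, dataframe and figure.
        print(title, output['status'])
        print(output['sql_query'])

asyncio.run(main())
```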
@@ -0,0 +1,161 @@
+import os
+
+from typing import Any
+import asyncio
+from climateqa.engine.llm import get_llm
+from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
+from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
+from climateqa.engine.talk_to_data.objects.plot import Plot
+from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
+from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
+from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
+
+async def process_output(
+    output_title: str,
+    table: str,
+    plot: Plot,
+    params: dict[str, Any]
+) -> tuple[str, TTDOutput, dict[str, bool]]:
+    """
+    Processes a table for a given plot and parameters: builds the SQL query, executes it,
+    and generates the corresponding figure.
+
+    Args:
+        output_title (str): Title for the output (used as key in the outputs dict).
+        table (str): The name of the table to process.
+        plot (Plot): The plot object containing the SQL query builder and visualization function.
+        params (dict[str, Any]): Parameters used for querying the table.
+
+    Returns:
+        tuple: (output_title, results dict, errors dict)
+    """
+    results: TTDOutput = {
+        'status': 'OK',
+        'plot': plot,
+        'table': table,
+        'sql_query': None,
+        'dataframe': None,
+        'figure': None,
+        'plot_information': None,
+    }
+    errors = {
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+
+    # Find the indicator column for this table
+    indicator_column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
+    if indicator_column:
+        params['indicator_column'] = indicator_column
+
+    # Build the SQL query
+    sql_query = plot['sql_query'](table, params)
+    if not sql_query:
+        results['status'] = 'ERROR'
+        return output_title, results, errors
+
+    results['plot_information'] = plot['plot_information'](table, params)
+
+    results['sql_query'] = sql_query
+    errors['have_sql_query'] = True
+
+    # Execute the SQL query
+    df = await execute_sql_query(sql_query)
+    if df is not None and not df.empty:
+        results['dataframe'] = df
+        errors['have_dataframe'] = True
+    else:
+        results['status'] = 'NO_DATA'
+
+    # Generate the figure (always, even if df is empty, for consistency)
+    results['figure'] = plot['plot_function'](params)
+
+    return output_title, results, errors
+
+async def ipcc_workflow(user_input: str) -> State:
+    """
+    Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.
+
+    Args:
+        user_input (str): The user's question.
+
+    Returns:
+        State: Final state with all the results and error messages if any.
+    """
+    state: State = {
+        'user_input': user_input,
+        'plots': [],
+        'outputs': {},
+        'error': ''
+    }
+
+    llm = get_llm(provider="openai")
+    plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
+    state['plots'] = plots
+
+    if not plots:
+        state['error'] = 'There is no plot that can answer the question'
+        return state
+
+    errors = {
+        'have_relevant_table': False,
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+    outputs = {}
+
+    # Find relevant tables for each plot and prepare outputs
+    for plot_name in plots:
+        plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
+        if plot is None:
+            continue
+
+        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
+        if relevant_tables:
+            errors['have_relevant_table'] = True
+
+        for table in relevant_tables:
+            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
+            outputs[output_title] = {
+                'table': table,
+                'plot': plot,
+                'status': 'OK'
+            }
+
+    # Gather all required parameters
+    params = {}
+    for param_name in IPCC_PLOT_PARAMETERS:
+        param = await find_param(state, param_name, mode='IPCC')
+        if param:
+            params.update(param)
+
+    # Process all outputs in parallel using process_output
+    tasks = [
+        process_output(output_title, output['table'], output['plot'], params.copy())
+        for output_title, output in outputs.items()
+    ]
+    results = await asyncio.gather(*tasks)
+
+    # Update outputs with results and error flags
+    for output_title, task_results, task_errors in results:
+        outputs[output_title]['sql_query'] = task_results['sql_query']
+        outputs[output_title]['dataframe'] = task_results['dataframe']
+        outputs[output_title]['figure'] = task_results['figure']
+        outputs[output_title]['plot_information'] = task_results['plot_information']
+        outputs[output_title]['status'] = task_results['status']
+        errors['have_sql_query'] |= task_errors['have_sql_query']
+        errors['have_dataframe'] |= task_errors['have_dataframe']
+
+    state['outputs'] = outputs
+
+    # Set error messages if needed
+    if not errors['have_relevant_table']:
+        state['error'] = "There is no relevant table in our database to answer your question"
+    elif not errors['have_sql_query']:
+        state['error'] = "There is no relevant SQL query on our database that can help answer your question"
+    elif not errors['have_dataframe']:
+        state['error'] = "There is no data in our tables that can answer your question"
+
+    return state
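Both workflow files fan the per-table work out with asyncio.gather and then OR-combine each task's error flags into section-level flags; a stripped-down sketch of that aggregation pattern (all names here are illustrative, not from the codebase):

```python
import asyncio

async def worker(title: str) -> tuple[str, dict[str, bool]]:
    # Stand-in for process_output: pretend only "A" yields a dataframe.
    return title, {"have_dataframe": title == "A"}

async def run() -> None:
    results = await asyncio.gather(*(worker(t) for t in ["A", "B"]))
    errors = {"have_dataframe": False}
    for _title, flags in results:
        errors["have_dataframe"] |= flags["have_dataframe"]  # any success flips the flag
    print(errors)  # {'have_dataframe': True}

asyncio.run(run())
```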
[Six binary files added via Git LFS — presumably the example screenshots under ./front/assets/ referenced by the UI code below.]
@@ -4,26 +4,25 @@ import os
 import pandas as pd
 
 from climateqa.engine.talk_to_data.main import ask_drias
-from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
+from climateqa.engine.talk_to_data.drias.config import DRIAS_MODELS, DRIAS_UI_TEXT
 
 class DriasUIElements(TypedDict):
     tab: gr.Tab
     details_accordion: gr.Accordion
     examples_hidden: gr.Textbox
     examples: gr.Examples
+    image_examples: gr.Row
     drias_direct_question: gr.Textbox
     result_text: gr.Textbox
-    table_names_display: gr.
+    table_names_display: gr.Radio
     query_accordion: gr.Accordion
     drias_sql_query: gr.Textbox
     chart_accordion: gr.Accordion
+    plot_information: gr.Markdown
     model_selection: gr.Dropdown
     drias_display: gr.Plot
     table_accordion: gr.Accordion
     drias_table: gr.DataFrame
-    pagination_display: gr.Markdown
-    prev_button: gr.Button
-    next_button: gr.Button
 
 
 async def ask_drias_query(query: str, index_state: int, user_id: str):
@@ -31,7 +30,7 @@ async def ask_drias_query(query: str, index_state: int, user_id: str):
     return result
 
 
-def show_results(sql_queries_state, dataframes_state, plots_state):
+def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
     if not sql_queries_state or not dataframes_state or not plots_state:
         # If all results are empty, show "No result"
         return (
@@ -40,9 +39,6 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
             gr.update(visible=False),
             gr.update(visible=False),
             gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
         )
     else:
         # Show the appropriate components with their data
@@ -51,10 +47,7 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
             gr.update(visible=True),
             gr.update(visible=True),
             gr.update(visible=True),
-            gr.update(visible=True),
-            gr.update(visible=True),
-            gr.update(visible=True),
-            gr.update(visible=True),
+            gr.update(choices=table_names, value=table_names[0], visible=True),
         )
 
 
@@ -72,44 +65,14 @@ def filter_by_model(dataframes, figures, index_state, model_selection):
     return df, figure
 
 
-def
-
-    return pagination
-
-
-def show_previous(index, sql_queries, dataframes, plots):
-    if index > 0:
-        index -= 1
-    return (
-        sql_queries[index],
-        dataframes[index],
-        plots[index](dataframes[index]),
-        index,
-    )
-
-
-def show_next(index, sql_queries, dataframes, plots):
-    if index < len(sql_queries) - 1:
-        index += 1
-    return (
-        sql_queries[index],
-        dataframes[index],
-        plots[index](dataframes[index]),
-        index,
-    )
-
-
-def display_table_names(table_names):
-    return [table_names]
-
-
-def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
-    index = evt.index[1]
+def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
+    index = table_names.index(selected_label)
     figure = plots[index](dataframes[index])
     return (
         sql_queries[index],
         dataframes[index],
         figure,
+        plot_informations[index],
         index,
     )
 
@@ -117,7 +80,7 @@ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plo
 def create_drias_ui() -> DriasUIElements:
     """Create and return all UI elements for the DRIAS tab."""
     with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
-        with gr.Accordion(label="
+        with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
             gr.Markdown(DRIAS_UI_TEXT)
 
         # Add examples for common questions
@@ -141,24 +104,43 @@ def create_drias_ui() -> DriasUIElements:
                 elem_id="direct-question",
                 interactive=True,
             )
+
+        with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
+            gr.Markdown("### Examples of possible visualizations")
+
+            with gr.Row():
+                gr.Image("./front/assets/talk_to_drias_winter_temp_paris_example.png", label="Evolution of Mean Winter Temperature in Paris", elem_classes=["example-img"])
+                gr.Image("./front/assets/talk_to_drias_annual_temperature_france_example.png", label="Mean Annual Temperature in 2030 in France", elem_classes=["example-img"])
+                gr.Image("./front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png", label="Frequency of Remarkable Daily Precipitation in Lyon", elem_classes=["example-img"])
 
         result_text = gr.Textbox(
             label="", elem_id="no-result-label", interactive=False, visible=True
         )
-
-
-
-
-
-
-
-
+
+        with gr.Row():
+            table_names_display = gr.Radio(
+                choices=[],
+                label="Relevant figures created",
+                interactive=True,
+                elem_id="table-names",
+                visible=False
             )
 
+        with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
+            drias_sql_query = gr.Textbox(
+                label="", elem_id="sql-query", interactive=False
+            )
+
+
         with gr.Accordion(label="Chart", visible=False) as chart_accordion:
-
-
-
+            with gr.Row():
+                model_selection = gr.Dropdown(
+                    label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
+                )
+            with gr.Accordion(label="Information about the plot", open=False):
+                plot_information = gr.Markdown(value="")
+
             drias_display = gr.Plot(elem_id="vanna-plot")
 
         with gr.Accordion(
@@ -166,32 +148,23 @@ def create_drias_ui() -> DriasUIElements:
         ) as table_accordion:
             drias_table = gr.DataFrame([], elem_id="vanna-table")
 
-        pagination_display = gr.Markdown(
-            value="", visible=False, elem_id="pagination-display"
-        )
-
-        with gr.Row():
-            prev_button = gr.Button("Previous", visible=False)
-            next_button = gr.Button("Next", visible=False)
-
     return DriasUIElements(
         tab=tab,
        details_accordion=details_accordion,
         examples_hidden=examples_hidden,
         examples=examples,
+        image_examples=image_examples,
         drias_direct_question=drias_direct_question,
         result_text=result_text,
         table_names_display=table_names_display,
         query_accordion=query_accordion,
         drias_sql_query=drias_sql_query,
         chart_accordion=chart_accordion,
+        plot_information=plot_information,
        model_selection=model_selection,
         drias_display=drias_display,
         table_accordion=table_accordion,
         drias_table=drias_table,
-        pagination_display=pagination_display,
-        prev_button=prev_button,
-        next_button=next_button
     )
 
 
@@ -202,94 +175,56 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
     sql_queries_state = gr.State([])
     dataframes_state = gr.State([])
     plots_state = gr.State([])
+    plot_informations_state = gr.State([])
     index_state = gr.State(0)
     table_names_list = gr.State([])
    user_id = gr.State(user_id)
 
+    # Handle direct question submission - trigger the same workflow by setting examples_hidden
+    ui_elements["drias_direct_question"].submit(
+        lambda x: gr.update(value=x),
+        inputs=[ui_elements["drias_direct_question"]],
+        outputs=[ui_elements["examples_hidden"]],
+    )
+
     # Handle example selection
     ui_elements["examples_hidden"].change(
         lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
         inputs=[ui_elements["examples_hidden"]],
         outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
     ).then(
-
-        inputs=[ui_elements["examples_hidden"], index_state, user_id],
-        outputs=[
-            ui_elements["drias_sql_query"],
-            ui_elements["drias_table"],
-            ui_elements["drias_display"],
-            sql_queries_state,
-            dataframes_state,
-            plots_state,
-            index_state,
-            table_names_list,
-            ui_elements["result_text"],
-        ],
-    ).then(
-        show_results,
-        inputs=[sql_queries_state, dataframes_state, plots_state],
-        outputs=[
-            ui_elements["result_text"],
-            ui_elements["query_accordion"],
-            ui_elements["table_accordion"],
-            ui_elements["chart_accordion"],
-            ui_elements["prev_button"],
-            ui_elements["next_button"],
-            ui_elements["pagination_display"],
-            ui_elements["table_names_display"],
-        ],
-    ).then(
-        update_pagination,
-        inputs=[index_state, sql_queries_state],
-        outputs=[ui_elements["pagination_display"]],
-    ).then(
-        display_table_names,
-        inputs=[table_names_list],
-        outputs=[ui_elements["table_names_display"]],
-    )
-
-    # Handle direct question submission
-    ui_elements["drias_direct_question"].submit(
-        lambda: gr.Accordion(open=False),
+        lambda: gr.update(visible=False),
         inputs=None,
-        outputs=
+        outputs=ui_elements["image_examples"]
     ).then(
         ask_drias_query,
-        inputs=[ui_elements["
+        inputs=[ui_elements["examples_hidden"], index_state, user_id],
         outputs=[
             ui_elements["drias_sql_query"],
             ui_elements["drias_table"],
             ui_elements["drias_display"],
+            ui_elements["plot_information"],
             sql_queries_state,
             dataframes_state,
             plots_state,
+            plot_informations_state,
             index_state,
             table_names_list,
             ui_elements["result_text"],
         ],
     ).then(
         show_results,
-        inputs=[sql_queries_state, dataframes_state, plots_state],
+        inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
         outputs=[
             ui_elements["result_text"],
             ui_elements["query_accordion"],
             ui_elements["table_accordion"],
             ui_elements["chart_accordion"],
-            ui_elements["prev_button"],
-            ui_elements["next_button"],
-            ui_elements["pagination_display"],
             ui_elements["table_names_display"],
         ],
-    ).then(
-        update_pagination,
-        inputs=[index_state, sql_queries_state],
-        outputs=[ui_elements["pagination_display"]],
-    ).then(
-        display_table_names,
-        inputs=[table_names_list],
-        outputs=[ui_elements["table_names_display"]],
     )
 
+
     # Handle model selection change
     ui_elements["model_selection"].change(
         filter_by_model,
@@ -297,36 +232,12 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
         outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
     )
 
-    # Handle pagination buttons
-    ui_elements["prev_button"].click(
-        show_previous,
-        inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
-        outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
-    ).then(
-        update_pagination,
-        inputs=[index_state, sql_queries_state],
-        outputs=[ui_elements["pagination_display"]],
-    )
-
-    ui_elements["next_button"].click(
-        show_next,
-        inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
-        outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
-    ).then(
-        update_pagination,
-        inputs=[index_state, sql_queries_state],
-        outputs=[ui_elements["pagination_display"]],
-    )
 
     # Handle table selection
-    ui_elements["table_names_display"].
+    ui_elements["table_names_display"].change(
         fn=on_table_click,
-        inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
-        outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
-    ).then(
-        update_pagination,
-        inputs=[index_state, sql_queries_state],
-        outputs=[ui_elements["pagination_display"]],
+        inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
+        outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], ui_elements["plot_information"], index_state],
     )
 
 def create_drias_tab(share_client=None, user_id=None):
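The tab wiring above leans on one Gradio pattern: a hidden Textbox acts as the single trigger, and .then() chains the successive steps. A minimal self-contained sketch of that pattern (the echo function is a placeholder, not the real workflow):

```python
import gradio as gr

with gr.Blocks() as demo:
    hidden = gr.Textbox(visible=False)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer")

    # Submitting the visible box only copies its value into the hidden one...
    question.submit(lambda x: gr.update(value=x), inputs=[question], outputs=[hidden])

    # ...so the hidden box's change event owns the whole workflow chain.
    hidden.change(lambda q: f"echo: {q}", inputs=[hidden], outputs=[answer])

demo.launch()
```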
@@ -0,0 +1,300 @@
1 |
+
from random import choices
|
2 |
+
import gradio as gr
|
3 |
+
from typing import TypedDict
|
4 |
+
from climateqa.engine.talk_to_data.main import ask_ipcc
|
5 |
+
from climateqa.engine.talk_to_data.ipcc.config import IPCC_MODELS, IPCC_SCENARIO, IPCC_UI_TEXT
|
6 |
+
import uuid
|
7 |
+
|
8 |
+
class ipccUIElements(TypedDict):
|
9 |
+
tab: gr.Tab
|
10 |
+
details_accordion: gr.Accordion
|
11 |
+
examples_hidden: gr.Textbox
|
12 |
+
examples: gr.Examples
|
13 |
+
image_examples: gr.Row
|
14 |
+
ipcc_direct_question: gr.Textbox
|
15 |
+
result_text: gr.Textbox
|
16 |
+
table_names_display: gr.Radio
|
17 |
+
query_accordion: gr.Accordion
|
18 |
+
ipcc_sql_query: gr.Textbox
|
19 |
+
chart_accordion: gr.Accordion
|
20 |
+
plot_information: gr.Markdown
|
21 |
+
scenario_selection: gr.Dropdown
|
22 |
+
ipcc_display: gr.Plot
|
23 |
+
table_accordion: gr.Accordion
|
24 |
+
ipcc_table: gr.DataFrame
|
25 |
+
|
26 |
+
|
27 |
+
async def ask_ipcc_query(query: str, index_state: int, user_id: str):
|
28 |
+
result = await ask_ipcc(query, index_state, user_id)
|
29 |
+
return result
|
30 |
+
|
31 |
+
def hide_outputs():
|
32 |
+
"""Hide all outputs initially."""
|
33 |
+
return (
|
34 |
+
gr.update(visible=True), # Show the result text
|
35 |
+
gr.update(visible=False), # Hide the query accordion
|
36 |
+
gr.update(visible=False), # Hide the table accordion
|
37 |
+
gr.update(visible=False), # Hide the chart accordion
|
38 |
+
gr.update(visible=False), # Hide table names
|
39 |
+
)
|
40 |
+
|
41 |
+
def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
|
42 |
+
if not sql_queries_state or not dataframes_state or not plots_state:
|
43 |
+
# If all results are empty, show "No result"
|
44 |
+
return (
|
45 |
+
gr.update(visible=True),
|
46 |
+
gr.update(visible=False),
|
47 |
+
gr.update(visible=False),
|
48 |
+
gr.update(visible=False),
|
49 |
+
gr.update(visible=False),
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
# Show the appropriate components with their data
|
53 |
+
return (
|
54 |
+
gr.update(visible=False),
|
55 |
+
gr.update(visible=True),
|
56 |
+
gr.update(visible=True),
|
57 |
+
gr.update(visible=True),
|
58 |
+
gr.update(choices=table_names, value=table_names[0], visible=True),
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def show_filter_by_scenario(table_names, index_state, dataframes):
|
63 |
+
if len(table_names) > 0 and table_names[index_state].startswith("Map"):
|
64 |
+
df = dataframes[index_state]
|
65 |
+
scenarios = sorted(df["scenario"].unique())
|
66 |
+
return gr.update(visible=True, choices=scenarios, value=scenarios[0])
|
67 |
+
else:
|
68 |
+
return gr.update(visible=False)
|
69 |
+
|
70 |
+
def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
|
71 |
+
df = dataframes[index_state]
|
72 |
+
if not table_names[index_state].startswith("Map"):
|
73 |
+
return df, figures[index_state](df)
|
74 |
+
if df.empty:
|
75 |
+
return df, None
|
76 |
+
if "scenario" not in df.columns:
|
77 |
+
return df, figures[index_state](df)
|
78 |
+
else:
|
79 |
+
df = df[df["scenario"] == scenario]
|
80 |
+
if df.empty:
|
81 |
+
return df, None
|
82 |
+
figure = figures[index_state](df)
|
83 |
+
return df, figure
|
84 |
+
|
85 |
+
|
86 |
+
def display_table_names(table_names, index_state):
|
87 |
+
return [
|
88 |
+
[name]
|
89 |
+
for name in table_names
|
90 |
+
]
|
91 |
+
|
92 |
+
def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
|
93 |
+
index = table_names.index(selected_label)
|
94 |
+
figure = plots[index](dataframes[index])
|
95 |
+
|
96 |
+
return (
|
97 |
+
sql_queries[index],
|
98 |
+
dataframes[index],
|
99 |
+
figure,
|
100 |
+
plot_informations[index],
|
101 |
+
index,
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def create_ipcc_ui() -> ipccUIElements:
|
106 |
+
|
107 |
+
"""Create and return all UI elements for the ipcc tab."""
|
108 |
+
with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
|
109 |
+
with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
|
110 |
+
gr.Markdown(IPCC_UI_TEXT)
|
111 |
+
|
112 |
+
# Add examples for common questions
|
113 |
+
examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
|
114 |
+
examples = gr.Examples(
|
115 |
+
examples=[
|
116 |
+
["What will the temperature be like in Paris?"],
|
117 |
+
["What will be the total rainfall in the USA in 2030?"],
|
118 |
+
["How will the average temperature evolve in China?"],
|
119 |
+
["What will be the average total precipitation in London ?"]
|
120 |
+
],
|
121 |
+
label="Example Questions",
|
122 |
+
inputs=[examples_hidden],
|
123 |
+
outputs=[examples_hidden],
|
124 |
+
)
|
125 |
+
|
126 |
+
with gr.Row():
|
127 |
+
ipcc_direct_question = gr.Textbox(
|
128 |
+
label="Direct Question",
|
129 |
+
placeholder="You can write direct question here",
|
130 |
+
elem_id="direct-question",
|
131 |
+
interactive=True,
|
132 |
+
)
|
133 |
+
|
134 |
+
with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
|
135 |
+
gr.Markdown("### Examples of possible visualizations")
|
136 |
+
|
137 |
+
with gr.Row():
|
138 |
+
gr.Image("./front/assets/talk_to_ipcc_france_example.png", label="Total Precipitation in 2030 in France", elem_classes=["example-img"])
|
139 |
+
gr.Image("./front/assets/talk_to_ipcc_new_york_example.png", label="Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)", elem_classes=["example-img"])
|
140 |
+
gr.Image("./front/assets/talk_to_ipcc_china_example.png", label="Mean Temperature in 2050 in China", elem_classes=["example-img"])
|
141 |
+
|
142 |
+
result_text = gr.Textbox(
|
143 |
+
label="", elem_id="no-result-label", interactive=False, visible=True
|
144 |
+
)
|
145 |
+
with gr.Row():
|
146 |
+
table_names_display = gr.Radio(
|
147 |
+
choices=[],
|
148 |
+
label="Relevant figures created",
|
149 |
+
interactive=True,
|
150 |
+
elem_id="table-names",
|
151 |
+
visible=False
|
152 |
+
)
|
153 |
+
|
154 |
+
with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
|
155 |
+
ipcc_sql_query = gr.Textbox(
|
156 |
+
label="", elem_id="sql-query", interactive=False
|
157 |
+
)
|
158 |
+
|
159 |
+
with gr.Accordion(label="Chart", visible=False) as chart_accordion:
|
160 |
+
|
161 |
+
with gr.Row():
|
162 |
+
scenario_selection = gr.Dropdown(
|
163 |
+
label="Scenario", choices=IPCC_SCENARIO, value=IPCC_SCENARIO[0], interactive=True, visible=False
|
164 |
+
)
|
165 |
+
|
166 |
+
with gr.Accordion(label="Informations about the plot", open=False):
|
167 |
+
plot_information = gr.Markdown(value = "")
|
168 |
+
|
169 |
+
ipcc_display = gr.Plot(elem_id="vanna-plot")
|
170 |
+
|
171 |
+
with gr.Accordion(
|
172 |
+
label="Data used", open=False, visible=False
|
173 |
+
) as table_accordion:
|
174 |
+
ipcc_table = gr.DataFrame([], elem_id="vanna-table")
|
175 |
+
|
176 |
+
|
177 |
+
return ipccUIElements(
|
178 |
+
tab=tab,
|
179 |
+
details_accordion=details_accordion,
|
180 |
+
examples_hidden=examples_hidden,
|
181 |
+
+        examples=examples,
+        image_examples=image_examples,
+        ipcc_direct_question=ipcc_direct_question,
+        result_text=result_text,
+        table_names_display=table_names_display,
+        query_accordion=query_accordion,
+        ipcc_sql_query=ipcc_sql_query,
+        chart_accordion=chart_accordion,
+        plot_information=plot_information,
+        scenario_selection=scenario_selection,
+        ipcc_display=ipcc_display,
+        table_accordion=table_accordion,
+        ipcc_table=ipcc_table,
+    )
+
+
+
+def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
+    """Set up all event handlers for the ipcc tab."""
+    # Create state variables
+    sql_queries_state = gr.State([])
+    dataframes_state = gr.State([])
+    plots_state = gr.State([])
+    plot_informations_state = gr.State([])
+    index_state = gr.State(0)
+    table_names_list = gr.State([])
+    user_id = gr.State(user_id)
+
+    # Handle direct question submission - trigger the same workflow by setting examples_hidden
+    ui_elements["ipcc_direct_question"].submit(
+        lambda x: gr.update(value=x),
+        inputs=[ui_elements["ipcc_direct_question"]],
+        outputs=[ui_elements["examples_hidden"]],
+    )
+
+    # Handle example selection
+    ui_elements["examples_hidden"].change(
+        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
+        inputs=[ui_elements["examples_hidden"]],
+        outputs=[ui_elements["details_accordion"], ui_elements["ipcc_direct_question"]]
+    ).then(
+        lambda: gr.update(visible=False),
+        inputs=None,
+        outputs=ui_elements["image_examples"]
+    ).then(
+        hide_outputs,
+        inputs=None,
+        outputs=[
+            ui_elements["result_text"],
+            ui_elements["query_accordion"],
+            ui_elements["table_accordion"],
+            ui_elements["chart_accordion"],
+            ui_elements["table_names_display"],
+        ]
+    ).then(
+        ask_ipcc_query,
+        inputs=[ui_elements["examples_hidden"], index_state, user_id],
+        outputs=[
+            ui_elements["ipcc_sql_query"],
+            ui_elements["ipcc_table"],
+            ui_elements["ipcc_display"],
+            ui_elements["plot_information"],
+            sql_queries_state,
+            dataframes_state,
+            plots_state,
+            plot_informations_state,
+            index_state,
+            table_names_list,
+            ui_elements["result_text"],
+        ],
+    ).then(
+        show_results,
+        inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
+        outputs=[
+            ui_elements["result_text"],
+            ui_elements["query_accordion"],
+            ui_elements["table_accordion"],
+            ui_elements["chart_accordion"],
+            ui_elements["table_names_display"],
+        ],
+    ).then(
+        show_filter_by_scenario,
+        inputs=[table_names_list, index_state, dataframes_state],
+        outputs=[ui_elements["scenario_selection"]],
+    ).then(
+        filter_by_scenario,
+        inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
+        outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
+    )
+
+
+    # Handle scenario selection change
+    ui_elements["scenario_selection"].change(
+        filter_by_scenario,
+        inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
+        outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
+    )
+
+    # Handle table selection
+    ui_elements["table_names_display"].change(
+        fn=on_table_click,
+        inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
+        outputs=[ui_elements["ipcc_sql_query"], ui_elements["ipcc_table"], ui_elements["ipcc_display"], ui_elements["plot_information"], index_state],
+    ).then(
+        show_filter_by_scenario,
+        inputs=[table_names_list, index_state, dataframes_state],
+        outputs=[ui_elements["scenario_selection"]],
+    ).then(
+        filter_by_scenario,
+        inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
+        outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
+    )
+
+
+def create_ipcc_tab(share_client=None, user_id=None):
+    """Create the ipcc tab with all its components and event handlers."""
+    ui_elements = create_ipcc_ui()
+    setup_ipcc_events(ui_elements, share_client=share_client, user_id=user_id)
+
+
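The handler wiring above leans on two Gradio idioms: per-session gr.State holders threaded through every step, and .then() chains so each stage (hide the old outputs, run ask_ipcc_query, show results, populate the scenario filter) only fires after the previous one returns. Both entry points, typing a question or clicking an example, are funnelled into the hidden examples_hidden textbox, so the whole pipeline is wired once, on that component's .change event. A minimal self-contained sketch of the same pattern (component and function names below are illustrative, not from this repo):

import gradio as gr

def run_pipeline(question, history):
    # Hypothetical stand-in for a real query function: it reads the
    # per-session history held in gr.State and returns updated values.
    history = history + [question]
    return f"(answer for: {question})", history

with gr.Blocks() as demo:
    history_state = gr.State([])                 # per-session, survives across events
    question_box = gr.Textbox(label="Ask a question")
    examples_hidden = gr.Textbox(visible=False)  # single funnel for all entry points
    answer_box = gr.Textbox(label="Answer")

    # Entry point 1: pressing Enter copies the question into the funnel...
    question_box.submit(
        lambda q: gr.update(value=q),
        inputs=[question_box],
        outputs=[examples_hidden],
    )

    # ...and the funnel's .change() drives the pipeline. .then() chains a
    # follow-up step that runs only after the first handler has returned.
    examples_hidden.change(
        run_pipeline,
        inputs=[examples_hidden, history_state],
        outputs=[answer_box, history_state],
    ).then(
        lambda: gr.update(value=""),             # e.g. tidy the UI afterwards
        inputs=None,
        outputs=question_box,
    )

demo.launch()

Because .change fires on programmatic updates as well as user edits, any entry point that pushes a value into the hidden textbox re-runs the entire chain.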
@@ -25,4 +25,5 @@ geopy==2.4.1
 duckdb==1.2.1
 openai==1.61.1
 pydantic==2.9.2
-pydantic-settings==2.2.1
+pydantic-settings==2.2.1
+geojson==3.2.0
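The requirements change pins geojson==3.2.0; the pydantic-settings line is removed and re-added unchanged, which typically just reflects a missing trailing newline at the end of the file. The geojson package is a thin helper for building and serializing GeoJSON objects; a minimal sketch of its core API, assuming it backs the map/location features added elsewhere in this changeset:

import geojson

# Build a Feature for a single location (GeoJSON coordinates are lon, lat)
point = geojson.Point((2.3522, 48.8566))
feature = geojson.Feature(geometry=point, properties={"name": "Paris"})
collection = geojson.FeatureCollection([feature])

# Serialize to a JSON string and parse it back
payload = geojson.dumps(collection)
parsed = geojson.loads(payload)
assert parsed["features"][0]["properties"]["name"] == "Paris"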
@@ -656,12 +656,20 @@ a {
     /* overflow-y: scroll; */
 }
 #sql-query{
+<<<<<<< HEAD
+    max-height: 100%;
+}
+
+#sql-query textarea{
+    min-height: 200px !important;
+=======
     max-height: 300px;
     overflow-y:scroll;
 }
 
 #sql-query textarea{
     min-height: 100px !important;
+>>>>>>> hf-origin/main
 }
 
 #sql-query span{
@@ -671,8 +679,11 @@ div#tab-vanna{
     max-height: 100¨vh;
     overflow-y: hidden;
 }
+#details button span{
+    font-weight: bold;
+}
 #vanna-plot{
-    max-height:
+    max-height:1000px
 }
 
 #pagination-display{
@@ -681,13 +692,33 @@ div#tab-vanna{
     font-size: 16px;
 }
 
-
-
+
+#table-names label {
+    display: block;
+    width: 100%;
+    box-sizing: border-box;
+    padding: 8px 12px;
+    margin-bottom: 4px;
+    border: 1px solid #ccc;
+    border-radius: 6px;
+    background-color: white;
+    cursor: pointer;
+    text-align: center;
 }
-
+
+#table-names label:hover {
+    background-color: #f0f8ff;
+}
+
+#table-names input[type="radio"] {
     display: none;
 }
 
+#table-names input[type="radio"]:checked + label {
+    background-color: #d0eaff;
+    border-color: #2196f3;
+}
+
 /* DRIAS Data Table Styles */
 #vanna-table {
     height: 400px !important;
@@ -710,3 +741,13 @@ div#tab-vanna{
     background: white;
     z-index: 1;
 }
+
+.example-img{
+    height: 250px;
+    object-fit: contain;
+}
+
+#example-img-container {
+    flex-direction: column;
+    align-items: left;
+}
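Two things are worth flagging in the stylesheet diff. First, the opening hunk commits literal Git conflict markers (<<<<<<< HEAD … >>>>>>> hf-origin/main) into style.css, so the #sql-query rules will not parse cleanly until the conflict is resolved in a follow-up. Second, the new #table-names rules turn a radio group into clickable card-style buttons by hiding the native inputs and styling their labels; in Gradio this kind of hook works because the component is created with a matching elem_id and the rules are injected via Blocks(css=...). A minimal sketch (selectors from the diff; component names and choices are illustrative):

import gradio as gr

# Trimmed-down version of the rules added above; the full set also
# styles the hover and checked states.
css = """
#table-names label { display: block; width: 100%; text-align: center; }
#table-names input[type="radio"] { display: none; }
"""

with gr.Blocks(css=css) as demo:
    # elem_id="table-names" is what makes the #table-names selectors apply
    table_names = gr.Radio(
        choices=["Table A", "Table B"],
        label="Tables",
        elem_id="table-names",
    )

demo.launch()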
|