timeki and armanddemasson committed
Commit 711bc31 · verified · 1 Parent(s): c6723cc

talk_to_ipcc (#29)


- change local examples (3fa35bfc2ec4e3397b571eecc44a68b1a1d1152c)
- fix (c5c15575b0a549578448fa4ae916af294cd4bdc4)
- Upload drias.db (7ea3b8dc6dcc209c74ac86140c23c3f35356b271)
- feature/add_talk_to_data (#23) (dac62987820b2d821cda7ee86bca058bcf66a66b)
- Change openai_key (680ed1c172d20c1280a46077a72c2eb7be02bd10)
- update OpenAI usage from Vanna (3e75ed8cd2d2f444a43cc93c814d4392f991d8dc)
- Small clean POC Local (0327516552f1b9af184361e8594b7377a4e0bcb2)
- Clean configs (ecab0c839ce7fc8100121a0174409451dbb5dd50)
- fix : fix gradio component (1e74d1402fb3c834dc9c0163aeb5028dbd6764bd)
- take the last question as history to understand the question (676d17bd25bb4cbd8cf9396ee353aac7a7028af0)
- Add follow up questions (aaa4bbedbbdee9951a33c8d7bbfd7583fa441bc5)
- Fix : Dynamic follow up examples (a984d58b3e8cf8458eec39ed29d0267590f8e75a)
- Merged in feature/dynamic_conversation (pull request #1) (f684aff68309e2f14164cef7fd02a6da804dcef9)
- Update style.css (e8258aea3313f558bb0afa650464d740fa4b2da4)
- Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev (03a8baf608fd8a639c96a661725c181660122aa1)
- feat: implemented talk to drias v1 (4df74e4b75096daf90548ae7eefe6c32df002bcd)
- feat: added 2 new talk to data plots (170018666950c1b7433f6baba17ad36042282a2e)
- feat: added drias model choice and changed TTD UI (1eae86b617a8336ee3e4b7b11e39dc0769731812)
- fix: fixed bugs and errors (e6e652c91a221b56eb508cc98b13d27b5785e803)
- Merge branch 'dev' into feat/talk_to_data_graph (561caf15f2fd2954b43a49064ac11a085d343151)
- update css for follow up examples (54e2358ad0444473be0ac8d6c436ef73cfcceb3b)
- ensure correct output language (723be3263bfd1a048f671b7ef11219de2c10b736)
- feat: model filtering and UI upgrade for TTD (6155a631718e808fb6622083848a5fc3bf231937)
- feat: added list of tables and reduced execution time of TTD (0bdf2f6f90901ed5cd65fc99d4de1da8eab1788a)
- add documentation (161aa8c18cb4be2d0c5df0dc64f3c202e55c582d)
- Move hardcoded configuration in a config file (45a93206d1c3507a622ba4f40df17a339d5a8bad)
- make ask drias asynchronous (5fe15430ebbb47868615f715c00d6274fad0e7a7)
- Add drias indicators (989d387afc62cb38ae239e5cc97664f702e3fa54)
- Add examples (d4fa76b2cb0a8ed1ae0f482254acfbd0242fdc73)
- UI improvment (1d710d6428b17b25b42e14fa494c2823b7e74e81)
- split front element and event listening (ca2c42949b3de84aeaaabf5879f107cd829523a0)
- Merged in feat/talk_to_data_graph (pull request #3) (f2baf8741a4c6b9712f2cd9cec2f86ddb4ca4274)
- Merge branch 'dev' of https://bitbucket.org/ekimetrics/climate_qa into dev (5c04812e26e5faa6ee2a85867f0155c1e56b96a0)
- add logs of drias interactions (bc43b45669dfd4740dc90d6daa1128278857d5d3)
- Merged in dev (pull request #4) (6af9e984b3a852db3414a96307040d2bbc031871)
- rename tabs for prod (b09184847c18b284fd0d08532a1572722f9159d5)
- V1.7 - Dynamic conversation & France local QA (#24) (6f13540b2041360bcd5d6eab4d699df28d96abce)
- Fix for python 3.10 (b142280be035c6ee3a6496f98cac37b23b38e18d)
- Fix requirements (f1b3e893c4b744376eec5ae30e3ed2e6c11fd8cf)
- update requirements (f46a24ac65bd569cad69d2281c5b5943a9a0830e)
- fixed requiremets (9ab8c871a2c9a8d6bf54d51d32666e973e46b4ee)
- Merge branch 'main' of https://huggingface.co/spaces/Ekimetrics/climate-question-answering (584bae4c2796688e9ecfdd393c693ce0e0d18fc5)
- fix tab name (0a2f32b117ce6289b5c49ab68a7aaedc502f5eca)
- Merge remote-tracking branch 'hf-origin/main' (0afa3dbe6042da779ac4911e23ce5cf7a4d40065)
- TTD : remove first unrelevant points (535d47236a99ebdc1e297eeb35b8833498214af2)
- Fixed follow up examples for local qa (c56c4345fd9484d8d43fd4ee36d8a4a0c3f6bd67)
- change init prompt to english (8c7a7fed7bddc7e672a257ab80f0840a6977de17)
- log to huggingface (f9c4c84a71d320c6db05ee099b9e35492ba7b184)
- Merged in feat/logs_on_huggingface (pull request #5) (261632833e7d939c321f208edfd6576e68947b4b)
- feat: added multithreading to run sql queries in talk to drias (705ccece7775c65a9c7b73b091cdddbc4246f2e7)
- chore: remove prints in talk to drias workflow (a967134f90c70d87bb3d786f2feb81f2e56fdb9f)
- Merged in feat/improve_drias_exeuction_time (pull request #6) (7c38528636ac19c1efa240a122e523ed0c34706a)
- fix import (05b8df9c9b74926459da70797b6852ff07a4d838)
- Merge branch 'main' into dev (8fb231c8beabf8a6406f05cf4cac564c5d81c7ce)
- Merged in dev (pull request #7) (6b9f71b1cf216eef0fd7973412f675d8633a5f4a)
- fix import (b35df2a8160723e43f74a040aa94983069066213)
- Merge branch 'main' of https://bitbucket.org/ekimetrics/climate_qa (f96cfd0715ec2b1ed7a78775ea7f8722f5793d8f)
- feature/drias_parallelization (#25) (47a68b0f6974850c8ec2465bdeda32c731e94568)
- feat: changed talk to drias UI (14d7085602602feef3ce51430f34221dadaa21df)
- temp fix : reranker switch to nano (99302fc182b3b15e3c15a03a633b46adf31e088f)
- Merged in dev (pull request #9) (bbc7806f08a3754842b0c228035c5fe985d45cc3)
- Merge remote-tracking branch 'hf-origin/main' (1b9b1eed6d3a82a9b23791df13f180181e6baa31)
- refactor: modularized talk to data (46c1e346e695ba31e91a6495a57a9063cf4dad54)
- Merge remote-tracking branch 'origin' into refactor/ttd_modularization (c25f6b1dd1d359574c77810933bcd01119f8a5b3)
- Standardize loggin (5b1b83bccbc4b9ae9021e29b43d822cf5b013e7d)
- Merge branch 'dev' into feat/logs_on_huggingface (32d594b10afd7196a7a83f9d9cb6a169a1c5ef51)
- Update constants.py (7a891a71dd288d985b940a61e3df26b34bb496e3)
- correct logs formating for dataviewer (2de91ee28e1168621896119cf65e44608f03000f)
- change logs repo (5f2bda40adc7025b6828daa3c8e65604f10cffaa)
- Merged in feat/logs_on_huggingface (pull request #11) (f52ad179f835a04bc5f6b9880699c7889cf0fc6c)
- Update logging.py (3ff380ff39b5417cda20a8f953301b6f1b833800)
- Merged in feat/logs_on_huggingface (pull request #12) (7c8dc1752e212554fba0a9bbf38d2552af5f72e1)
- Merged in main (pull request #13) (9a9ab642ce3a6361946f3d6ef5f8282d893d6a24)
- Merged in dev (pull request #14) (97d80b660352d1e09bfd697e86d4a7b4f6386571)
- add logs fallback on azure (58f35b70200236190bc8f57f5ec75a72573b31a1)
- Merged in feat/logs_on_huggingface (pull request #15) (e92e8dc1b629f6e9c61657683419e6585fd54261)
- Merged in dev (pull request #16) (eae39a6274f762fdfdfc49e86680d27ce7c6c8fc)
- feat: updated common talk to data for talk to ipcc and drias (c0fd277fca5091c51086af4aeaaa36f625531cef)
- feat: created queries for talk to ipcc (e8d5bc9410f6777c62dcacb77160ea08ac172e31)
- feat: created plots for talk to ipcc (3f85f963aef03ebf32f77836652ce1926778958b)
- feat: created config constants for talk to ipcc (c3398f4493e79a39702edcfa436a860312edf3e5)
- implemented talk to ipcc workflow and updated talk to data state object (45e1dba85d72fed8767b1e369934cf13639a3242)
- Merged in feature/talk_to_ipcc (pull request #17) (ac63459d9126fed1121796aaee4cb3adfbc955d9)
- feat: updated talk to drias based on talk to ipcc (c3024c3df719b25b6aee66dda3aed3927bb83077)
- Merge remote-tracking branch 'origin' into feature/talk_to_data (4af247227e85f9b72b159b213d3a1eccee7b6733)
- chore: added geojson to requirements (1a4f8680944e738da3237b7d6d03b0fd430f186c)
- fix: fix markdown bug in IPCC_UI_TEXT (d3040d364345ed73e9ac1657c5613e1062e728b7)
- chore: added geojson polygon function and custom colorscale per indicator (30f401b002885a71bb62069ec17bd2619f98961e)
- feat: added plot informations for each plot (DRIAS & IPCC) (bdeb1e525e4c80ce990329052a2c172ad9847113)
- feat: updated TTD UI (DRIAS & IPCC) (fdad3be2df6ab94da4b7847a081c3c0c1b67eaf2)
- fix: fixed graphs display bugs (a4206dc31d5e238ee26876faf91473a202eafec1)
- feat: updated sql query for not macro countries in talk to ipcc (ff653468b4f2bfe2e166ad572ca3c91cb5fb986f)
- refactor: changed choropleth map into map (f9508fce0c7db4822b33a8ba95fb048818fd0840)
- feat: changed marker opacity on map plots (7374d87d5c8521f7ff418eb04636558ee6dcfc14)
- chore: added information about grid point mapping to country inside plot information for map (29cf97a04f9def87a065f05139cdb2374b47d07d)
- feat: changed queries for macro countries in talk to ipcc (d0f045617aefd9a28811c567ecda0d1ff70ce548)
- fix: fixed submit bug in talk to drias (01a4939d78358d57f94232827cfe197c22678258)
- feat: added queries on huge countries in talk to ipcc (19905de878ff990f1beff50d76a8eb11541e46b9)
- chore: changed talk to ipcc how to use text (5bd3f8cdd51ef4bbd47b4524ab5406434801a048)
- replace eval by litteral eval (2d6b3b9f22416c157a60ecc2616346dc975d6861)
- fix retrieve documents error for search without reranker (0536821aa63e34759510b5d55505da8df32bae8f)
- Move vanna files in dedicated folder (fa2765ac9a08e5d00439ff752159cf2d22b10d1f)
- move path into config (0695fe4d36575aa5881f981e16897a7bc3a4d45d)
- remove unnecessary imports (819e3c05730ca9e4b1a6526f00d961f1509e3b0d)
- Merged in feature/talk_to_data (pull request #19) (11ab5fbc82c47ac8524b62d4c693eaa376dd149f)
- rename hf_token name (d78ea762bb16187b8755d5c07d90827c10ccf127)
- add lfs images (a9b4fcc0bd5fafd39104eae0074cad2b51bf1341)
- Add PNG assets (11910c77dff6e0795b7cbecb2e692645ef4e14f0)
- Merge hf-origin/main into main (1ed4e25589d82fe195eab13e7249e3889cecac4b)


Co-authored-by: Armand Demasson <[email protected]>

Files changed (35)
  1. .gitattributes +1 -0
  2. app.py +9 -0
  3. climateqa/engine/chains/retrieve_documents.py +6 -4
  4. climateqa/engine/talk_to_data/config.py +8 -96
  5. climateqa/engine/talk_to_data/drias/config.py +124 -0
  6. climateqa/engine/talk_to_data/drias/plot_informations.py +88 -0
  7. climateqa/engine/talk_to_data/drias/plots.py +434 -0
  8. climateqa/engine/talk_to_data/drias/queries.py +83 -0
  9. climateqa/engine/talk_to_data/input_processing.py +257 -0
  10. climateqa/engine/talk_to_data/ipcc/config.py +98 -0
  11. climateqa/engine/talk_to_data/ipcc/plot_informations.py +50 -0
  12. climateqa/engine/talk_to_data/ipcc/plots.py +189 -0
  13. climateqa/engine/talk_to_data/ipcc/queries.py +143 -0
  14. climateqa/engine/talk_to_data/main.py +77 -71
  15. climateqa/engine/talk_to_data/objects/llm_outputs.py +13 -0
  16. climateqa/engine/talk_to_data/objects/location.py +12 -0
  17. climateqa/engine/talk_to_data/objects/plot.py +23 -0
  18. climateqa/engine/talk_to_data/objects/states.py +19 -0
  19. climateqa/engine/talk_to_data/prompt.py +44 -0
  20. climateqa/engine/talk_to_data/query.py +57 -0
  21. climateqa/engine/talk_to_data/ui_config.py +27 -0
  22. climateqa/engine/talk_to_data/vanna/myVanna.py +13 -0
  23. climateqa/engine/talk_to_data/vanna/vanna_class.py +325 -0
  24. climateqa/engine/talk_to_data/workflow/drias.py +163 -0
  25. climateqa/engine/talk_to_data/workflow/ipcc.py +161 -0
  26. front/assets/talk_to_drias_annual_temperature_france_example.png +3 -0
  27. front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png +3 -0
  28. front/assets/talk_to_drias_winter_temp_paris_example.png +3 -0
  29. front/assets/talk_to_ipcc_china_example.png +3 -0
  30. front/assets/talk_to_ipcc_france_example.png +3 -0
  31. front/assets/talk_to_ipcc_new_york_example.png +3 -0
  32. front/tabs/tab_drias.py +60 -149
  33. front/tabs/tab_ipcc.py +300 -0
  34. requirements.txt +2 -1
  35. style.css +45 -4
.gitattributes CHANGED

@@ -45,3 +45,4 @@ documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
 data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
+front/assets/*.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED

@@ -16,6 +16,10 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
 from front.tabs import create_config_modal, cqa_tab, create_about_tab
 from front.tabs import MainTabPanel, ConfigPanel
 from front.tabs.tab_drias import create_drias_tab
+<<<<<<< HEAD
+from front.tabs.tab_ipcc import create_ipcc_tab
+=======
+>>>>>>> hf-origin/main
 from front.utils import process_figures
 from gradio_modal import Modal
 
@@ -532,8 +536,13 @@ def main_ui():
         with gr.Tabs():
             cqa_components = cqa_tab(tab_name="ClimateQ&A")
             local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
+<<<<<<< HEAD
+            drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
+            ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
+=======
             create_drias_tab(share_client=share_client, user_id=user_id)
 
+>>>>>>> hf-origin/main
             create_about_tab()
 
         event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
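Note that the app.py hunks above commit unresolved merge conflict markers (`<<<<<<< HEAD` / `=======` / `>>>>>>> hf-origin/main`) into the file. A minimal sketch showing why this matters: Python treats conflict markers as a syntax error, so the module cannot even be imported until the conflict is resolved by hand (the sample source below is illustrative, not taken from the repository).

```python
# Build a tiny module source containing git conflict markers, as the
# committed app.py does, and show that it fails to compile.
conflicted = (
    "x = 1\n"
    "<<<<<<< HEAD\n"
    "y = 2\n"
    "=======\n"
    "y = 3\n"
    ">>>>>>> hf-origin/main\n"
)

try:
    compile(conflicted, "app.py", "exec")
    outcome = "compiled"
except SyntaxError:
    outcome = "SyntaxError"

print(outcome)  # → SyntaxError
```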
climateqa/engine/chains/retrieve_documents.py CHANGED

@@ -21,7 +21,7 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from ..vectorstore import get_pinecone_vectorstore
 from ..embeddings import get_embeddings_function
-
+import ast
 
 import asyncio
 
@@ -477,8 +477,10 @@ async def retrieve_documents(
                 docs_question_dict[key] = rerank_and_sort_docs(reranker,docs_question_dict[key],question)
             else:
                 # Add a default reranking score
-                for doc in docs_question:
-                    doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
+                for key in docs_question_dict.keys():
+                    if isinstance(docs_question_dict[key], list) and len(docs_question_dict[key]) > 0:
+                        for doc in docs_question_dict[key]:
+                            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
 
             # Keep the right number of documents
             docs_question, images_question = concatenate_documents(index, source_type, docs_question_dict, k_by_question, k_summary_by_question, k_images_by_question)
 
@@ -580,7 +582,7 @@ async def get_relevant_toc_level_for_query(
     response = chain.invoke({"query": query, "doc_list": doc_list})
 
     try:
-        relevant_tocs = eval(response)
+        relevant_tocs = ast.literal_eval(response)
     except Exception as e:
         print(f" Failed to parse the result because of : {e}")
 
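The last hunk above replaces `eval()` with `ast.literal_eval()` when parsing the LLM's response (matching the "replace eval by litteral eval" commit in the list). A minimal sketch of why that swap matters: `literal_eval` only accepts Python literals (lists, dicts, strings, numbers, tuples), so a malformed or malicious model response raises instead of executing code. The sample response string is illustrative, not taken from the repository.

```python
import ast

# Parsing a well-formed literal works as before:
response = "['Chapter 1: Observed Changes', 'Chapter 3: Projections']"
relevant_tocs = ast.literal_eval(response)
print(relevant_tocs[0])

# Anything that is not a pure literal is rejected instead of executed:
try:
    ast.literal_eval("__import__('os').getcwd()")
except ValueError:
    print("rejected non-literal input")
```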
climateqa/engine/talk_to_data/config.py CHANGED

@@ -1,99 +1,11 @@
-DRIAS_TABLES = [
-    "total_winter_precipitation",
-    "total_summer_precipiation",
-    "total_annual_precipitation",
-    "total_remarkable_daily_precipitation",
-    "frequency_of_remarkable_daily_precipitation",
-    "extreme_precipitation_intensity",
-    "mean_winter_temperature",
-    "mean_summer_temperature",
-    "mean_annual_temperature",
-    "number_of_tropical_nights",
-    "maximum_summer_temperature",
-    "number_of_days_with_tx_above_30",
-    "number_of_days_with_tx_above_35",
-    "number_of_days_with_a_dry_ground",
-]
-
-INDICATOR_COLUMNS_PER_TABLE = {
-    "total_winter_precipitation": "total_winter_precipitation",
-    "total_summer_precipiation": "total_summer_precipitation",
-    "total_annual_precipitation": "total_annual_precipitation",
-    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
-    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
-    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
-    "mean_winter_temperature": "mean_winter_temperature",
-    "mean_summer_temperature": "mean_summer_temperature",
-    "mean_annual_temperature": "mean_annual_temperature",
-    "number_of_tropical_nights": "number_tropical_nights",
-    "maximum_summer_temperature": "maximum_summer_temperature",
-    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
-    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
-    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
-}
-
-DRIAS_MODELS = [
-    'ALL',
-    'RegCM4-6_MPI-ESM-LR',
-    'RACMO22E_EC-EARTH',
-    'RegCM4-6_HadGEM2-ES',
-    'HadREM3-GA7_EC-EARTH',
-    'HadREM3-GA7_CNRM-CM5',
-    'REMO2015_NorESM1-M',
-    'SMHI-RCA4_EC-EARTH',
-    'WRF381P_NorESM1-M',
-    'ALADIN63_CNRM-CM5',
-    'CCLM4-8-17_MPI-ESM-LR',
-    'HIRHAM5_IPSL-CM5A-MR',
-    'HadREM3-GA7_HadGEM2-ES',
-    'SMHI-RCA4_IPSL-CM5A-MR',
-    'HIRHAM5_NorESM1-M',
-    'REMO2009_MPI-ESM-LR',
-    'CCLM4-8-17_HadGEM2-ES'
-]
-# Mapping between indicator columns and their units
-INDICATOR_TO_UNIT = {
-    "total_winter_precipitation": "mm",
-    "total_summer_precipitation": "mm",
-    "total_annual_precipitation": "mm",
-    "total_remarkable_daily_precipitation": "mm",
-    "frequency_of_remarkable_daily_precipitation": "days",
-    "extreme_precipitation_intensity": "mm",
-    "mean_winter_temperature": "°C",
-    "mean_summer_temperature": "°C",
-    "mean_annual_temperature": "°C",
-    "number_tropical_nights": "days",
-    "maximum_summer_temperature": "°C",
-    "number_of_days_with_tx_above_30": "days",
-    "number_of_days_with_tx_above_35": "days",
-    "number_of_days_with_dry_ground": "days"
-}
-
-DRIAS_UI_TEXT = """
-Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
-I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
-
-❓ **How to use?**
-You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
-You can specify **location** and/or **year**.
-You can choose from a list of climate models. By default, we take the **average of each model**.
-
-For example, you can ask:
-- What will the temperature be like in Paris?
-- What will be the total rainfall in France in 2030?
-- How frequent will extreme events be in Lyon?
-
-**Example of indicators in the data**:
-- Mean temperature (annual, winter, summer)
-- Total precipitation (annual, winter, summer)
-- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
-
-⚠️ **Limitations**:
-- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
-- You can only ask about **locations in France**.
-- If you specify a year, there may be **no data for that year for some models**.
-- You **cannot compare two models**.
-
-🛈 **Information**
-Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
-"""
+# Path configuration for climateqa project
+
+# IPCC dataset path
+IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
+
+# DRIAS dataset paths
+DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"
+
+# Table paths
+DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH = f"{DRIAS_DATASET_URL}/mean_annual_temperature.parquet"
+IPCC_COORDINATES_PATH = f"{IPCC_DATASET_URL}/coordinates.parquet"
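The rewritten config reduces to a handful of path constants: dataset URLs using the Hugging Face `hf://` filesystem scheme, with per-table parquet paths derived from them. A minimal sketch of how such paths compose (`table_path` is a hypothetical helper, not a function from the repository):

```python
DRIAS_DATASET_URL = "hf://datasets/timeki/drias_db"
IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"

def table_path(dataset_url: str, table: str) -> str:
    """Build the parquet path for one table of a dataset."""
    return f"{dataset_url}/{table}.parquet"

print(table_path(DRIAS_DATASET_URL, "mean_annual_temperature"))
# Such hf:// paths can be read directly, e.g. with
# pandas.read_parquet(...) when huggingface_hub is installed.
```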
climateqa/engine/talk_to_data/drias/config.py ADDED

@@ -0,0 +1,124 @@
+
+from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
+
+
+DRIAS_TABLES = [
+    "total_winter_precipitation",
+    "total_summer_precipitation",
+    "total_annual_precipitation",
+    "total_remarkable_daily_precipitation",
+    "frequency_of_remarkable_daily_precipitation",
+    "extreme_precipitation_intensity",
+    "mean_winter_temperature",
+    "mean_summer_temperature",
+    "mean_annual_temperature",
+    "number_of_tropical_nights",
+    "maximum_summer_temperature",
+    "number_of_days_with_tx_above_30",
+    "number_of_days_with_tx_above_35",
+    "number_of_days_with_a_dry_ground",
+]
+
+DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
+    "total_winter_precipitation": "total_winter_precipitation",
+    "total_summer_precipitation": "total_summer_precipitation",
+    "total_annual_precipitation": "total_annual_precipitation",
+    "total_remarkable_daily_precipitation": "total_remarkable_daily_precipitation",
+    "frequency_of_remarkable_daily_precipitation": "frequency_of_remarkable_daily_precipitation",
+    "extreme_precipitation_intensity": "extreme_precipitation_intensity",
+    "mean_winter_temperature": "mean_winter_temperature",
+    "mean_summer_temperature": "mean_summer_temperature",
+    "mean_annual_temperature": "mean_annual_temperature",
+    "number_of_tropical_nights": "number_tropical_nights",
+    "maximum_summer_temperature": "maximum_summer_temperature",
+    "number_of_days_with_tx_above_30": "number_of_days_with_tx_above_30",
+    "number_of_days_with_tx_above_35": "number_of_days_with_tx_above_35",
+    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground"
+}
+
+DRIAS_MODELS = [
+    'ALL',
+    'RegCM4-6_MPI-ESM-LR',
+    'RACMO22E_EC-EARTH',
+    'RegCM4-6_HadGEM2-ES',
+    'HadREM3-GA7_EC-EARTH',
+    'HadREM3-GA7_CNRM-CM5',
+    'REMO2015_NorESM1-M',
+    'SMHI-RCA4_EC-EARTH',
+    'WRF381P_NorESM1-M',
+    'ALADIN63_CNRM-CM5',
+    'CCLM4-8-17_MPI-ESM-LR',
+    'HIRHAM5_IPSL-CM5A-MR',
+    'HadREM3-GA7_HadGEM2-ES',
+    'SMHI-RCA4_IPSL-CM5A-MR',
+    'HIRHAM5_NorESM1-M',
+    'REMO2009_MPI-ESM-LR',
+    'CCLM4-8-17_HadGEM2-ES'
+]
+# Mapping between indicator columns and their units
+DRIAS_INDICATOR_TO_UNIT = {
+    "total_winter_precipitation": "mm",
+    "total_summer_precipitation": "mm",
+    "total_annual_precipitation": "mm",
+    "total_remarkable_daily_precipitation": "mm",
+    "frequency_of_remarkable_daily_precipitation": "days",
+    "extreme_precipitation_intensity": "mm",
+    "mean_winter_temperature": "°C",
+    "mean_summer_temperature": "°C",
+    "mean_annual_temperature": "°C",
+    "number_tropical_nights": "days",
+    "maximum_summer_temperature": "°C",
+    "number_of_days_with_tx_above_30": "days",
+    "number_of_days_with_tx_above_35": "days",
+    "number_of_days_with_dry_ground": "days"
+}
+
+DRIAS_PLOT_PARAMETERS = [
+    'year',
+    'location'
+]
+
+DRIAS_INDICATOR_TO_COLORSCALE = {
+    "total_winter_precipitation": PRECIPITATION_COLORSCALE,
+    "total_summer_precipitation": PRECIPITATION_COLORSCALE,
+    "total_annual_precipitation": PRECIPITATION_COLORSCALE,
+    "total_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
+    "frequency_of_remarkable_daily_precipitation": PRECIPITATION_COLORSCALE,
+    "extreme_precipitation_intensity": PRECIPITATION_COLORSCALE,
+    "mean_winter_temperature": TEMPERATURE_COLORSCALE,
+    "mean_summer_temperature": TEMPERATURE_COLORSCALE,
+    "mean_annual_temperature": TEMPERATURE_COLORSCALE,
+    "number_tropical_nights": TEMPERATURE_COLORSCALE,
+    "maximum_summer_temperature": TEMPERATURE_COLORSCALE,
+    "number_of_days_with_tx_above_30": TEMPERATURE_COLORSCALE,
+    "number_of_days_with_tx_above_35": TEMPERATURE_COLORSCALE,
+    "number_of_days_with_dry_ground": TEMPERATURE_COLORSCALE
+}
+
+DRIAS_UI_TEXT = """
+Hi, I'm **Talk to Drias**, designed to answer your questions using [**DRIAS - TRACC 2023**](https://www.drias-climat.fr/accompagnement/sections/401) data.
+I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
+
+You can ask me anything about these climate indicators: **temperature**, **precipitation** or **drought**.
+You can specify **location** and/or **year**.
+You can choose from a list of climate models. By default, we take the **average of each model**.
+
+For example, you can ask:
+- What will the temperature be like in Paris?
+- What will be the total rainfall in France in 2030?
+- How frequent will extreme events be in Lyon?
+
+**Example of indicators in the data**:
+- Mean temperature (annual, winter, summer)
+- Total precipitation (annual, winter, summer)
+- Number of days with remarkable precipitations, with dry ground, with temperature above 30°C
+
+⚠️ **Limitations**:
+- You can't ask anything that isn't related to **DRIAS - TRACC 2023** data.
+- You can only ask about **locations in France**.
+- If you specify a year, there may be **no data for that year for some models**.
+- You **cannot compare two models**.
+
+🛈 **Information**
+Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
+"""
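The new DRIAS config keeps two separate mappings: table name to indicator column, and indicator column to unit. A few table names differ from their column names (e.g. table `number_of_days_with_a_dry_ground` stores column `number_of_days_with_dry_ground`), so looking up a unit for a table must chain the two dicts. A minimal sketch (`unit_for_table` is a hypothetical helper; the dicts below are excerpts of the config above):

```python
# Excerpt of the two mappings from drias/config.py:
DRIAS_INDICATOR_COLUMNS_PER_TABLE = {
    "number_of_days_with_a_dry_ground": "number_of_days_with_dry_ground",
    "mean_annual_temperature": "mean_annual_temperature",
}
DRIAS_INDICATOR_TO_UNIT = {
    "number_of_days_with_dry_ground": "days",
    "mean_annual_temperature": "°C",
}

def unit_for_table(table: str) -> str:
    """Resolve a table name to its unit via its indicator column."""
    column = DRIAS_INDICATOR_COLUMNS_PER_TABLE[table]
    return DRIAS_INDICATOR_TO_UNIT[column]

print(unit_for_table("number_of_days_with_a_dry_ground"))  # → days
```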
climateqa/engine/talk_to_data/drias/plot_informations.py ADDED

@@ -0,0 +1,88 @@
+from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_UNIT
+
+def indicator_evolution_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    if "location" not in params:
+        raise ValueError('"location" must be provided in params')
+    location = params["location"]
+    return f"""
+This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
+
+It combines both historical observations and future projections according to the climate scenario RCP8.5.
+
+The x-axis represents the years, and the y-axis shows the value of the indicator ({unit}).
+
+A 10-year rolling average curve is displayed to give a better idea of the overall trend.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
+- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
+
+def indicator_number_of_days_per_year_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    if "location" not in params:
+        raise ValueError('"location" must be provided in params')
+    location = params["location"]
+    return f"""
+This plot displays a bar chart showing the yearly frequency of the climate indicator **{indicator}** in **{location}**.
+
+The x-axis represents the years, and the y-axis shows the frequency of {indicator} ({unit}) per year.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each year and climate model, the value of {indicator} in {location} is collected, to build the time series.
+- The coordinates used for {location} correspond to the closest available point in the DRIAS database, which uses a regular grid with a spatial resolution of 8 km.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
+
+def distribution_of_indicator_for_given_year_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    year = params["year"]
+    if year is None:
+        year = 2030
+    return f"""
+This plot shows a histogram of the distribution of the climate indicator **{indicator}** across all locations for the year **{year}**.
+
+It allows you to visualize how the values of {indicator} ({unit}) are spread for a given year.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each grid point in the dataset and climate model, the value of {indicator} for the year {year} is extracted.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
+
+def map_of_france_of_indicator_for_given_year_informations(
+    indicator: str,
+    params: dict[str, str]
+) -> str:
+    unit = DRIAS_INDICATOR_TO_UNIT[indicator]
+    year = params["year"]
+    if year is None:
+        year = 2030
+    return f"""
+This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of France for the year **{year}**.
+
+Each region is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within France for the selected year and climate model.
+
+**Data source:**
+- The data come from the DRIAS TRACC data. The data were initially extracted from [the DRIAS website](https://www.drias-climat.fr/drias_prod/accueil/okapiWebDrias/index.jsp?iddrias=climat) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/timeki/drias_db).
+- For each region of France, the value of {indicator} in {year} and for the selected climate model is extracted and mapped to its geographic coordinates.
+- The regions correspond to 8 km squares centered on the grid points of the DRIAS dataset.
+- The indicator values shown are those for the selected climate model.
+- If ALL climate model is selected, the average value of the indicator between all the climate models is used.
+"""
climateqa/engine/talk_to_data/drias/plots.py ADDED
@@ -0,0 +1,434 @@
+ import geojson
+ from math import cos, radians
+ from typing import Callable
+ import pandas as pd
+ from plotly.graph_objects import Figure
+ import plotly.graph_objects as go
+ from climateqa.engine.talk_to_data.drias.plot_informations import (
+     distribution_of_indicator_for_given_year_informations,
+     indicator_evolution_informations,
+     indicator_number_of_days_per_year_informations,
+     map_of_france_of_indicator_for_given_year_informations,
+ )
+ from climateqa.engine.talk_to_data.objects.plot import Plot
+ from climateqa.engine.talk_to_data.drias.queries import (
+     indicator_for_given_year_query,
+     indicator_per_year_at_location_query,
+ )
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_INDICATOR_TO_COLORSCALE, DRIAS_INDICATOR_TO_UNIT
+
+ def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
+     """Builds one 8 km x 8 km square polygon per grid point, carrying its indicator value."""
+     side_km = 8
+     delta_lat = side_km / 111  # one degree of latitude is roughly 111 km
+     features = []
+     for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators)):
+         delta_lon = side_km / (111 * cos(radians(lat)))  # longitude degrees shrink with latitude
+         half_lat = delta_lat / 2
+         half_lon = delta_lon / 2
+         features.append(geojson.Feature(
+             geometry=geojson.Polygon([[
+                 [lon - half_lon, lat - half_lat],
+                 [lon + half_lon, lat - half_lat],
+                 [lon + half_lon, lat + half_lat],
+                 [lon - half_lon, lat + half_lat],
+                 [lon - half_lon, lat - half_lat],
+             ]]),
+             properties={"value": val},
+             id=str(idx),
+         ))
+
+     return geojson.FeatureCollection(features)
+
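The polygon construction above can be sketched without the `geojson` dependency: each grid point becomes a closed ring of five `[lon, lat]` pairs whose width in longitude degrees is stretched by `1 / cos(latitude)`, so the square stays roughly 8 km on each side at any latitude. A minimal stdlib-only sketch (plain lists instead of `geojson.Feature`):

```python
from math import cos, radians, isclose

def square_ring(lat: float, lon: float, side_km: float = 8.0) -> list[list[float]]:
    """Closed ring of a side_km x side_km square centered on (lat, lon)."""
    half_lat = side_km / 111 / 2                         # ~111 km per degree of latitude
    half_lon = side_km / (111 * cos(radians(lat))) / 2   # wider in degrees at high latitudes
    return [
        [lon - half_lon, lat - half_lat],
        [lon + half_lon, lat - half_lat],
        [lon + half_lon, lat + half_lat],
        [lon - half_lon, lat + half_lat],
        [lon - half_lon, lat - half_lat],  # repeat the first point to close the ring
    ]

ring = square_ring(48.85, 2.35)  # a DRIAS-like grid point near Paris
assert len(ring) == 5 and ring[0] == ring[-1]
# At ~49°N a longitude degree is shorter than a latitude degree,
# so the box spans more degrees horizontally than vertically
assert (ring[1][0] - ring[0][0]) > (ring[2][1] - ring[1][1])
assert isclose(ring[2][1] - ring[1][1], 8 / 111)
```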
+ def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
+     """Generates a function to plot indicator evolution over time at a location.
+
+     This function creates a line plot showing how a climate indicator changes
+     over time at a specific location. It handles temperature, precipitation,
+     and other climate indicators.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - location (str): The location to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+
+     Example:
+         >>> plot_func = plot_indicator_evolution_at_location({
+         ...     'indicator_column': 'mean_temperature',
+         ...     'location': 'Paris',
+         ...     'model': 'ALL'
+         ... })
+         >>> fig = plot_func(df)
+     """
+     indicator = params["indicator_column"]
+     location = params["location"]
+     indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
+     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         """Generates the actual plot from the data.
+
+         Args:
+             df (pd.DataFrame): DataFrame containing the data to plot
+
+         Returns:
+             Figure: A plotly Figure object showing the indicator evolution
+         """
+         fig = go.Figure()
+         if df["model"].nunique() != 1:
+             # Several models: average them per year
+             df_plot = df.groupby("year", as_index=False)[indicator].mean()
+             model_label = "Model Average"
+         else:
+             df_plot = df
+             model_label = f"Model: {df['model'].unique()[0]}"
+
+         # Transform to lists to avoid pandas encoding issues in plotly
+         indicators = df_plot[indicator].astype(float).tolist()
+         years = df_plot["year"].astype(int).tolist()
+
+         # Compute the 10-year rolling average
+         rolling_window = 10
+         sliding_averages = (
+             df_plot[indicator]
+             .rolling(window=rolling_window, min_periods=rolling_window)
+             .mean()
+             .astype(float)
+             .tolist()
+         )
+
+         # Only add the rolling average if there are enough data points
+         if any(pd.notna(x) for x in sliding_averages):
+             # Sliding average dashed line
+             fig.add_scatter(
+                 x=years,
+                 y=sliding_averages,
+                 mode="lines",
+                 name="10-year rolling average",
+                 line=dict(dash="dash"),
+                 marker=dict(color="#d62728"),
+                 hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
+             )
+
+         # Indicator per year plot
+         fig.add_scatter(
+             x=years,
+             y=indicators,
+             name=f"Yearly {indicator_label}",
+             mode="lines",
+             marker=dict(color="#1f77b4"),
+             hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
+         )
+         fig.update_layout(
+             title=f"Evolution of {indicator_label} in {location} ({model_label})",
+             xaxis_title="Year",
+             yaxis_title=f"{indicator_label} ({unit})",
+             template="plotly_white",
+             height=900,
+         )
+         return fig
+
+     return plot_data
+
+
+ indicator_evolution_at_location: Plot = {
+     "name": "Indicator evolution at location",
+     "description": "Plot the evolution of an indicator at a given location",
+     "params": ["indicator_column", "location", "model"],
+     "plot_function": plot_indicator_evolution_at_location,
+     "sql_query": indicator_per_year_at_location_query,
+     "plot_information": indicator_evolution_informations,
+     "short_name": "Evolution",
+ }
+
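The `min_periods=rolling_window` choice above means the smoothed series stays undefined until a full decade of data is available, instead of averaging partial windows. A stdlib-only sketch of that windowing rule (hypothetical yearly values, no pandas):

```python
def rolling_mean(values: list[float], window: int) -> list:
    """Mean over the trailing `window` values; None until the window is full,
    mirroring pandas' rolling(window, min_periods=window).mean()."""
    out = []
    for i in range(len(values)):
        if i + 1 < window:
            out.append(None)  # not enough history yet
        else:
            chunk = values[i + 1 - window : i + 1]
            out.append(sum(chunk) / window)
    return out

yearly = [10.0, 11.0, 12.0, 13.0]  # hypothetical indicator values
smoothed = rolling_mean(yearly, window=3)
assert smoothed == [None, None, 11.0, 12.0]
```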
+
+ def plot_indicator_number_of_days_per_year_at_location(
+     params: dict,
+ ) -> Callable[..., Figure]:
+     """Generates a function to plot the number of days per year for an indicator.
+
+     This function creates a bar chart showing the frequency of certain climate
+     events (like days above a temperature threshold) per year at a specific location.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - location (str): The location to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+     """
+     indicator = params["indicator_column"]
+     location = params["location"]
+     indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
+     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         """Generates the figure from the dataframe.
+
+         Args:
+             df (pd.DataFrame): pandas dataframe with the required data
+
+         Returns:
+             Figure: Plotly figure
+         """
+         fig = go.Figure()
+         if df["model"].nunique() != 1:
+             # Several models: average them per year
+             df_plot = df.groupby("year", as_index=False)[indicator].mean()
+             model_label = "Model Average"
+         else:
+             df_plot = df
+             model_label = f"Model: {df['model'].unique()[0]}"
+
+         # Transform to lists to avoid pandas encoding issues in plotly
+         indicators = df_plot[indicator].astype(float).tolist()
+         years = df_plot["year"].astype(int).tolist()
+
+         # Bar plot
+         fig.add_trace(
+             go.Bar(
+                 x=years,
+                 y=indicators,
+                 width=0.5,
+                 marker=dict(color="#1f77b4"),
+                 hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>",
+             )
+         )
+
+         fig.update_layout(
+             title=f"{indicator_label} in {location} ({model_label})",
+             xaxis_title="Year",
+             yaxis_title=f"{indicator_label} ({unit})",
+             yaxis=dict(range=[0, max(indicators)]),
+             bargap=0.5,
+             height=900,
+             template="plotly_white",
+         )
+
+         return fig
+
+     return plot_data
+
+
+ indicator_number_of_days_per_year_at_location: Plot = {
+     "name": "Indicator number of days per year at location",
+     "description": "Plot a bar chart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicators.",
+     "params": ["indicator_column", "location", "model"],
+     "plot_function": plot_indicator_number_of_days_per_year_at_location,
+     "sql_query": indicator_per_year_at_location_query,
+     "plot_information": indicator_number_of_days_per_year_informations,
+     "short_name": "Yearly Frequency",
+ }
+
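Both plot builders above collapse the multi-model case with `df.groupby("year")[indicator].mean()`: for each year, the values of all climate models are averaged into a single series. The grouping rule can be sketched with the stdlib alone (model names and values are hypothetical):

```python
from collections import defaultdict
from statistics import mean

rows = [  # (year, model, value) -- hypothetical query output
    (2030, "model_a", 12.0),
    (2030, "model_b", 14.0),
    (2031, "model_a", 13.0),
    (2031, "model_b", 15.0),
]

by_year = defaultdict(list)
for year, _model, value in rows:
    by_year[year].append(value)

# One averaged value per year, like groupby("year")[indicator].mean()
yearly_average = {year: mean(values) for year, values in sorted(by_year.items())}
assert yearly_average == {2030: 13.0, 2031: 14.0}
```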
+
+ def plot_distribution_of_indicator_for_given_year(
+     params: dict,
+ ) -> Callable[..., Figure]:
+     """Generates a function to plot the distribution of an indicator for a year.
+
+     This function creates a histogram showing the distribution of a climate
+     indicator across different locations for a specific year.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - year (str): The year to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+     """
+     indicator = params["indicator_column"]
+     year = params["year"]
+     if year is None:
+         year = 2030
+     indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
+     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         """Generates the figure from the dataframe.
+
+         Args:
+             df (pd.DataFrame): pandas dataframe with the required data
+
+         Returns:
+             Figure: Plotly figure
+         """
+         fig = go.Figure()
+         if df["model"].nunique() != 1:
+             # Several models: average them per grid point
+             df_plot = df.groupby(["latitude", "longitude"], as_index=False)[indicator].mean()
+             model_label = "Model Average"
+         else:
+             df_plot = df
+             model_label = f"Model: {df['model'].unique()[0]}"
+
+         # Transform to a list to avoid pandas encoding issues in plotly
+         indicators = df_plot[indicator].astype(float).tolist()
+
+         fig.add_trace(
+             go.Histogram(
+                 x=indicators,
+                 opacity=0.8,
+                 histnorm="percent",
+                 marker=dict(color="#1f77b4"),
+                 hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>",
+             )
+         )
+
+         fig.update_layout(
+             title=f"Distribution of {indicator_label} in {year} ({model_label})",
+             xaxis_title=f"{indicator_label} ({unit})",
+             yaxis_title="Frequency (%)",
+             plot_bgcolor="rgba(0, 0, 0, 0)",
+             showlegend=False,
+             height=900,
+         )
+
+         return fig
+
+     return plot_data
+
+
+ distribution_of_indicator_for_given_year: Plot = {
+     "name": "Distribution of an indicator for a given year",
+     "description": "Plot a histogram of the distribution of an indicator's values for a given year",
+     "params": ["indicator_column", "model", "year"],
+     "plot_function": plot_distribution_of_indicator_for_given_year,
+     "sql_query": indicator_for_given_year_query,
+     "plot_information": distribution_of_indicator_for_given_year_informations,
+     "short_name": "Distribution",
+ }
+
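The histogram uses `histnorm="percent"`, i.e. each bin's count is divided by the total number of points and multiplied by 100, so the bars sum to 100%. That normalization can be sketched with the stdlib (bin edges and values are hypothetical):

```python
def percent_histogram(values: list[float], edges: list[float]) -> list[float]:
    """Share of values falling in [edges[i], edges[i+1]), in percent of the total."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        for i in range(len(edges) - 1):
            if edges[i] <= v < edges[i + 1]:
                counts[i] += 1
                break
    total = len(values)
    return [100 * c / total for c in counts]

values = [10.2, 10.8, 11.5, 12.1]  # hypothetical indicator values
percents = percent_histogram(values, [10, 11, 12, 13])
assert percents == [50.0, 25.0, 25.0]
assert sum(percents) == 100.0
```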
+
+ def plot_map_of_france_of_indicator_for_given_year(
+     params: dict,
+ ) -> Callable[..., Figure]:
+     """Generates a function to plot a map of France for an indicator.
+
+     This function creates a choropleth map of France showing the spatial
+     distribution of a climate indicator for a specific year.
+
+     Args:
+         params (dict): Dictionary containing:
+             - indicator_column (str): The column name for the indicator
+             - year (str): The year to plot
+             - model (str): The climate model to use
+
+     Returns:
+         Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
+     """
+     indicator = params["indicator_column"]
+     year = params["year"]
+     if year is None:
+         year = 2030
+     indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
+     unit = DRIAS_INDICATOR_TO_UNIT.get(indicator, "")
+
+     def plot_data(df: pd.DataFrame) -> Figure:
+         if df["model"].nunique() != 1:
+             # Several models: average them per grid point
+             df_plot = df.groupby(["latitude", "longitude"], as_index=False)[indicator].mean()
+             model_label = "Model Average"
+         else:
+             df_plot = df
+             model_label = f"Model: {df['model'].unique()[0]}"
+
+         # Transform to lists to avoid pandas encoding issues in plotly
+         indicators = df_plot[indicator].astype(float).tolist()
+         latitudes = df_plot["latitude"].astype(float).tolist()
+         longitudes = df_plot["longitude"].astype(float).tolist()
+
+         geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
+
+         fig = go.Figure(go.Choroplethmapbox(
+             geojson=geojson_data,
+             locations=[str(i) for i in range(len(indicators))],
+             featureidkey="id",
+             z=indicators,
+             colorscale=DRIAS_INDICATOR_TO_COLORSCALE[indicator],
+             zmin=min(indicators),
+             zmax=max(indicators),
+             marker_opacity=0.7,
+             marker_line_width=0,
+             colorbar_title=f"{indicator_label} ({unit})",
+             text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators],  # hover text showing the indicator value
+             hoverinfo="text",
+         ))
+
+         fig.update_layout(
+             mapbox_style="open-street-map",  # use OpenStreetMap tiles
+             mapbox_zoom=5,
+             height=900,
+             mapbox_center={"lat": 46.6, "lon": 2.0},
+             coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
+             title=f"{indicator_label} in {year} in France ({model_label})",
+         )
+         return fig
+
+     return plot_data
+
+
+ map_of_france_of_indicator_for_given_year: Plot = {
+     "name": "Map of France of an indicator for a given year",
+     "description": "Heatmap on the map of France of the values of an indicator for a given year",
+     "params": ["indicator_column", "year", "model"],
+     "plot_function": plot_map_of_france_of_indicator_for_given_year,
+     "sql_query": indicator_for_given_year_query,
+     "plot_information": map_of_france_of_indicator_for_given_year_informations,
+     "short_name": "Map of France",
+ }
+
+ DRIAS_PLOTS = [
+     indicator_evolution_at_location,
+     indicator_number_of_days_per_year_at_location,
+     distribution_of_indicator_for_given_year,
+     map_of_france_of_indicator_for_given_year,
+ ]
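Each entry in `DRIAS_PLOTS` bundles a SQL-query builder, a plot builder, and an explanation builder behind the same keys, so the workflow can drive any plot generically. A stdlib sketch of that dispatch, using hypothetical stand-in builders rather than the real ones:

```python
# Hypothetical stand-ins for the real query/plot builders
def build_query(table: str, params: dict) -> str:
    return f"SELECT * FROM {table} WHERE year = {params['year']}"

def build_plot(params: dict):
    # Mirrors the two-step pattern above: configure with params, then call with data
    return lambda rows: f"figure of {params['indicator_column']} over {len(rows)} rows"

plot_spec = {  # same shape as the Plot entries collected in DRIAS_PLOTS
    "name": "Demo plot",
    "params": ["indicator_column", "year"],
    "sql_query": build_query,
    "plot_function": build_plot,
}

params = {"indicator_column": "mean_temperature", "year": 2030}
sql = plot_spec["sql_query"]("mean_temperature", params)
figure = plot_spec["plot_function"](params)([(2030, 12.3)])
assert sql == "SELECT * FROM mean_temperature WHERE year = 2030"
```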
climateqa/engine/talk_to_data/drias/queries.py ADDED
@@ -0,0 +1,83 @@
+ from typing import TypedDict
+ from climateqa.engine.talk_to_data.config import DRIAS_DATASET_URL
+
+ class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
+     """Parameters for querying an indicator's values over time at a location.
+
+     This class defines the parameters needed to query climate indicator data
+     for a specific location over multiple years.
+
+     Attributes:
+         indicator_column (str): The column name for the climate indicator
+         latitude (str): The latitude coordinate of the location
+         longitude (str): The longitude coordinate of the location
+         model (str): The climate model to use (optional)
+     """
+     indicator_column: str
+     latitude: str
+     longitude: str
+     model: str
+
+ def indicator_per_year_at_location_query(
+     table: str, params: IndicatorPerYearAtLocationQueryParams
+ ) -> str:
+     """SQL query to get the evolution of an indicator per year at a given location.
+
+     Args:
+         table (str): SQL table of the indicator
+         params (IndicatorPerYearAtLocationQueryParams): dictionary with the required params for the query
+
+     Returns:
+         str: the SQL query
+     """
+     indicator_column = params.get("indicator_column")
+     latitude = params.get("latitude")
+     longitude = params.get("longitude")
+
+     if indicator_column is None or latitude is None or longitude is None:  # if a parameter is missing, return an empty query
+         return ""
+
+     table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
+
+     return (
+         f"SELECT year, {indicator_column}, model\n"
+         f"FROM {table}\n"
+         f"WHERE latitude = {latitude}\n"
+         f"AND longitude = {longitude}\n"
+         f"ORDER BY year"
+     )
+
+ class IndicatorForGivenYearQueryParams(TypedDict, total=False):
+     """Parameters for querying an indicator's values across locations for a year.
+
+     This class defines the parameters needed to query climate indicator data
+     across different locations for a specific year.
+
+     Attributes:
+         indicator_column (str): The column name for the climate indicator
+         year (str): The year to query
+         model (str): The climate model to use (optional)
+     """
+     indicator_column: str
+     year: str
+     model: str
+
+ def indicator_for_given_year_query(
+     table: str, params: IndicatorForGivenYearQueryParams
+ ) -> str:
+     """SQL query to get the values of an indicator, with their latitudes, longitudes and models, for a given year.
+
+     Args:
+         table (str): SQL table of the indicator
+         params (IndicatorForGivenYearQueryParams): dictionary with the required params for the query
+
+     Returns:
+         str: the SQL query
+     """
+     indicator_column = params.get("indicator_column")
+     if indicator_column is None:  # if the indicator is missing, return an empty query
+         return ""
+     year = params.get("year")
+     if year is None:
+         year = 2050  # default year when none was detected in the question
+
+     table = f"'{DRIAS_DATASET_URL}/{table.lower()}.parquet'"
+
+     return f"SELECT {indicator_column}, latitude, longitude, model\nFROM {table}\nWHERE year = {year}"
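The query builders above do plain string interpolation into a DuckDB-readable Parquet path. A stdlib sketch of the same pattern, with a placeholder dataset URL standing in for the real `DRIAS_DATASET_URL`:

```python
DATASET_URL = "https://example.org/drias"  # placeholder, not the real DRIAS_DATASET_URL

def per_year_at_location_query(table: str, params: dict) -> str:
    """Mirror of the builder above: missing parameter -> empty query."""
    if any(params.get(k) is None for k in ("indicator_column", "latitude", "longitude")):
        return ""
    path = f"'{DATASET_URL}/{table.lower()}.parquet'"
    return (
        f"SELECT year, {params['indicator_column']}, model\n"
        f"FROM {path}\n"
        f"WHERE latitude = {params['latitude']}\n"
        f"AND longitude = {params['longitude']}\n"
        f"ORDER BY year"
    )

sql = per_year_at_location_query(
    "Mean_Annual_Temperature",
    {"indicator_column": "mean_temperature", "latitude": 48.86, "longitude": 2.35},
)
assert "FROM 'https://example.org/drias/mean_annual_temperature.parquet'" in sql
assert per_year_at_location_query("t", {"indicator_column": None}) == ""
```

Note that table names are lower-cased before being turned into file paths, so callers can pass the human-readable table name directly.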
climateqa/engine/talk_to_data/input_processing.py ADDED
@@ -0,0 +1,257 @@
+ from typing import Literal, Optional, cast
+ import ast
+ from langchain_core.prompts import ChatPromptTemplate
+ from geopy.geocoders import Nominatim
+ import duckdb
+ from climateqa.engine.llm import get_llm
+ from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH
+ from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
+ from climateqa.engine.talk_to_data.objects.location import Location
+ from climateqa.engine.talk_to_data.objects.plot import Plot
+ from climateqa.engine.talk_to_data.objects.states import State
+
+ async def detect_location_with_openai(sentence: str) -> str:
+     """Detects locations in a sentence using OpenAI's API via LangChain."""
+     llm = get_llm()
+
+     prompt = f"""
+     Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
+     Return the result as a Python list. If no locations are mentioned, return an empty list.
+
+     Sentence: "{sentence}"
+     """
+
+     response = await llm.ainvoke(prompt)
+     location_list = ast.literal_eval(response.content.strip("```python\n").strip())
+     if location_list:
+         return location_list[0]
+     return ""
+
+ def loc_to_coords(location: str) -> tuple[float, float]:
+     """Converts a location name to geographic coordinates.
+
+     This function uses the Nominatim geocoding service to convert
+     a location name (e.g., a city name) to its latitude and longitude.
+
+     Args:
+         location (str): The name of the location to geocode
+
+     Returns:
+         tuple[float, float]: A tuple containing (latitude, longitude)
+
+     Raises:
+         AttributeError: If the location cannot be found
+     """
+     geolocator = Nominatim(user_agent="city_to_latlong")
+     coords = geolocator.geocode(location)
+     return (coords.latitude, coords.longitude)
+
+ def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
+     """Converts geographic coordinates to a country name.
+
+     This function uses the Nominatim reverse geocoding service to convert
+     latitude and longitude coordinates to a country name.
+
+     Args:
+         coords (tuple[float, float]): A tuple containing (latitude, longitude)
+
+     Returns:
+         tuple[str, str]: A tuple containing (country_code, country_name)
+
+     Raises:
+         AttributeError: If the coordinates cannot be found
+     """
+     geolocator = Nominatim(user_agent="latlong_to_country")
+     location = geolocator.reverse(coords)
+     address = location.raw["address"]
+     return address["country_code"].upper(), address["country"]
+
+ def nearest_neighbour_sql(location: tuple, mode: Literal["DRIAS", "IPCC"]) -> tuple[str, str, Optional[str]]:
+     """Finds a dataset grid point close to the given coordinates (plus its admin1 region for IPCC data)."""
+     long = round(location[1], 3)
+     lat = round(location[0], 3)
+     conn = duckdb.connect()
+
+     if mode == "DRIAS":
+         table_path = f"'{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}'"
+         results = conn.sql(
+             f"SELECT latitude, longitude FROM {table_path} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
+         ).fetchdf()
+     else:
+         table_path = f"'{IPCC_COORDINATES_PATH}'"
+         results = conn.sql(
+             f"SELECT latitude, longitude, admin1 FROM {table_path} WHERE latitude BETWEEN {lat - 0.5} AND {lat + 0.5} AND longitude BETWEEN {long - 0.5} AND {long + 0.5}"
+         ).fetchdf()
+
+     if len(results) == 0:
+         return "", "", ""
+
+     if "admin1" in results.columns:
+         admin1 = results["admin1"].iloc[0]
+     else:
+         admin1 = None
+     return results["latitude"].iloc[0], results["longitude"].iloc[0], admin1
+
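The lookup above really selects any grid point inside a ±0.3° (DRIAS) or ±0.5° (IPCC) bounding box around the requested coordinates, then takes the first match, rather than computing a true nearest neighbour. The filter can be sketched with the stdlib (grid points are hypothetical):

```python
def points_in_box(points, lat, lon, half_side):
    """Grid points whose latitude and longitude both fall within +/- half_side degrees."""
    return [
        (p_lat, p_lon)
        for p_lat, p_lon in points
        if lat - half_side <= p_lat <= lat + half_side
        and lon - half_side <= p_lon <= lon + half_side
    ]

grid = [(48.7, 2.2), (48.9, 2.4), (50.0, 3.0)]  # hypothetical DRIAS grid points
matches = points_in_box(grid, lat=48.857, lon=2.352, half_side=0.3)
assert matches == [(48.7, 2.2), (48.9, 2.4)]
assert points_in_box(grid, lat=45.0, lon=5.0, half_side=0.3) == []
```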
+ async def detect_year_with_openai(sentence: str) -> str:
+     """Detects years in a sentence using OpenAI's API via LangChain."""
+     llm = get_llm()
+
+     prompt = """
+     Extract all years mentioned in the following sentence.
+     Return the result as a Python list. If no years are mentioned, return an empty list.
+
+     Sentence: "{sentence}"
+     """
+
+     prompt = ChatPromptTemplate.from_template(prompt)
+     structured_llm = llm.with_structured_output(ArrayOutput)
+     chain = prompt | structured_llm
+     response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
+     years_list = ast.literal_eval(response["array"])
+     if len(years_list) > 0:
+         return years_list[0]
+     return ""
+
+
+ async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
+     """Identifies relevant tables for a plot based on user input.
+
+     This function uses an LLM to analyze the user's question and the plot
+     description to determine which tables in the DRIAS database would be
+     most relevant for generating the requested visualization.
+
+     Args:
+         user_question (str): The user's question about climate data
+         plot (Plot): The plot configuration object
+         llm: The language model instance to use for analysis
+         table_names_list (list[str]): The names of the available tables
+
+     Returns:
+         list[str]: A list of table names that are relevant for the plot
+
+     Example:
+         >>> detect_relevant_tables(
+         ...     "What will the temperature be like in Paris?",
+         ...     indicator_evolution_at_location,
+         ...     llm
+         ... )
+         ['mean_annual_temperature', 'mean_summer_temperature']
+     """
+     prompt = (
+         f"You are helping to build a plot following this description: {plot['description']}. "
+         f"You are given a list of tables and a user question. "
+         f"Based on the description of the plot, decide which tables are appropriate for that kind of plot. "
+         f"Write the 3 most relevant tables to use. Answer only with a Python list of table names."
+         f"### List of tables: {table_names_list}"
+         f"### User question: {user_question}"
+         f"### List of table names: "
+     )
+
+     table_names = ast.literal_eval(
+         (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
+     )
+     return table_names
+
+ async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
+     """Uses an LLM to select every plot whose description could help answer the user's question."""
+     plots_description = ""
+     for plot in plot_list:
+         plots_description += "Name: " + plot["name"]
+         plots_description += " - Description: " + plot["description"] + "\n"
+
+     prompt = (
+         "You are helping to answer a question with insightful visualizations.\n"
+         "You are given a user question and a list of plots with their names and descriptions.\n"
+         "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
+         "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
+         "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
+         "Return only a Python list of plot names sorted from the most relevant to the least relevant.\n"
+         f"### Descriptions of the plots: {plots_description}"
+         f"### User question: {user_question}\n"
+         f"### Names of the plots: "
+     )
+
+     plot_names = ast.literal_eval(
+         (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
+     )
+     return plot_names
+
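Both helpers above parse the model's reply by stripping a possible ```python fence and feeding the rest to `ast.literal_eval`, which safely evaluates a list literal without executing arbitrary code. A stdlib sketch of that parsing step on hypothetical replies:

```python
import ast

def parse_list_reply(reply: str) -> list:
    """Strip an optional ```python fence, then safely evaluate the list literal."""
    return ast.literal_eval(reply.strip("```python\n").strip())

fenced = "```python\n['mean_temperature', 'total_precipitation']\n```"
bare = "['mean_temperature']"
assert parse_list_reply(fenced) == ["mean_temperature", "total_precipitation"]
assert parse_list_reply(bare) == ["mean_temperature"]
```

One caveat of this pattern: `str.strip` removes any of the characters `` `pythona`` (treated as a set) from both ends, so it works here only because list literals begin with `[` and end with `]`.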
+ async def find_location(user_input: str, mode: Literal["DRIAS", "IPCC"] = "DRIAS") -> Location:
+     """Detects a location in the user input and resolves it to the nearest dataset grid point."""
+     print("---- Find location in user input ----")
+     location = await detect_location_with_openai(user_input)
+     output: Location = {
+         "location": location,
+         "longitude": None,
+         "latitude": None,
+         "country_code": None,
+         "country_name": None,
+         "admin1": None,
+     }
+
+     if location:
+         coords = loc_to_coords(location)
+         country_code, country_name = coords_to_country(coords)
+         neighbour = nearest_neighbour_sql(coords, mode)
+         output.update({
+             "latitude": neighbour[0],
+             "longitude": neighbour[1],
+             "country_code": country_code,
+             "country_name": country_name,
+             "admin1": neighbour[2],
+         })
+     return cast(Location, output)
+
+ async def find_year(user_input: str) -> str | None:
+     """Extracts year information from user input using an LLM.
+
+     This function uses an LLM to identify and extract year information from the
+     user's query, which is used to filter data in subsequent queries.
+
+     Args:
+         user_input (str): The user's query text
+
+     Returns:
+         str | None: The extracted year, or None if no year was found
+     """
+     print("---- Find year ----")
+     year = await detect_year_with_openai(user_input)
+     if year == "":
+         return None
+     return year
+
+ async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
+     print("---- Find relevant plots ----")
+     return await detect_relevant_plots(state["user_input"], llm, plots)
+
+ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
+     print(f"---- Find relevant tables for {plot['name']} ----")
+     return await detect_relevant_tables(state["user_input"], plot, llm, tables)
+
+ async def find_param(state: State, param_name: str, mode: Literal["DRIAS", "IPCC"] = "DRIAS") -> dict[str, Optional[str]] | Location | None:
+     """Dispatches to the right extraction method for the requested parameter.
+
+     Args:
+         state (State): state of the workflow
+         param_name (str): name of the desired parameter
+         mode (Literal["DRIAS", "IPCC"]): which dataset the location should be resolved against
+
+     Returns:
+         dict[str, Optional[str]] | Location | None: the extracted parameter, or None if unsupported
+     """
+     if param_name == "location":
+         return await find_location(state["user_input"], mode)
+     if param_name == "year":
+         year = await find_year(state["user_input"])
+         return {"year": year}
+     return None
climateqa/engine/talk_to_data/ipcc/config.py ADDED
@@ -0,0 +1,98 @@
1
+ from climateqa.engine.talk_to_data.ui_config import PRECIPITATION_COLORSCALE, TEMPERATURE_COLORSCALE
2
+ from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
3
+
4
+
5
+ # IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
6
+ IPCC_TABLES = [
7
+ "mean_temperature",
8
+ "total_precipitation",
9
+ ]
10
+
11
+ IPCC_INDICATOR_COLUMNS_PER_TABLE = {
12
+ "mean_temperature": "mean_temperature",
13
+ "total_precipitation": "total_precipitation"
14
+ }
15
+
16
+ IPCC_INDICATOR_TO_UNIT = {
17
+ "mean_temperature": "°C",
18
+ "total_precipitation": "mm/day"
19
+ }
20
+
21
+ IPCC_SCENARIO = [
22
+ "historical",
23
+ "ssp126",
24
+ "ssp245",
25
+ "ssp370",
26
+ "ssp585",
27
+ ]
28
+
29
+ IPCC_MODELS = []
30
+
31
+ IPCC_PLOT_PARAMETERS = [
32
+ 'year',
33
+ 'location'
34
+ ]
35
+
36
+ MACRO_COUNTRIES = ['JP',
37
+ 'IN',
38
+ 'MH',
39
+ 'PT',
40
+ 'ID',
41
+ 'SJ',
42
+ 'MX',
43
+ 'CN',
44
+ 'GL',
45
+ 'PN',
46
+ 'AR',
47
+ 'AQ',
48
+ 'PF',
49
+ 'BR',
50
+ 'SH',
51
+ 'GS',
52
+ 'ZA',
53
+ 'NZ',
54
+ 'TF',
55
+ ]
56
+
57
+ HUGE_MACRO_COUNTRIES = ['CL',
58
+ 'CA',
59
+ 'AU',
60
+ 'US',
61
+ 'RU'
62
+ ]
63
+
64
+ IPCC_INDICATOR_TO_COLORSCALE = {
65
+ "mean_temperature": TEMPERATURE_COLORSCALE,
66
+ "total_precipitation": PRECIPITATION_COLORSCALE
67
+ }
68
+
69
+ IPCC_UI_TEXT = """
70
+ Hi, I'm **Talk to IPCC**, designed to answer your questions using [**IPCC - ATLAS**](https://interactive-atlas.ipcc.ch/regional-information#eyJ0eXBlIjoiQVRMQVMiLCJjb21tb25zIjp7ImxhdCI6OTc3MiwibG5nIjo0MDA2OTIsInpvb20iOjQsInByb2oiOiJFUFNHOjU0MDMwIiwibW9kZSI6ImNvbXBsZXRlX2F0bGFzIn0sInByaW1hcnkiOnsic2NlbmFyaW8iOiJzc3A1ODUiLCJwZXJpb2QiOiIyIiwic2Vhc29uIjoieWVhciIsImRhdGFzZXQiOiJDTUlQNiIsInZhcmlhYmxlIjoidGFzIiwidmFsdWVUeXBlIjoiQU5PTUFMWSIsImhhdGNoaW5nIjoiU0lNUExFIiwicmVnaW9uU2V0IjoiYXI2IiwiYmFzZWxpbmUiOiJwcmVJbmR1c3RyaWFsIiwicmVnaW9uc1NlbGVjdGVkIjpbXX0sInBsb3QiOnsiYWN0aXZlVGFiIjoicGx1bWUiLCJtYXNrIjoibm9uZSIsInNjYXR0ZXJZTWFnIjpudWxsLCJzY2F0dGVyWVZhciI6bnVsbCwic2hvd2luZyI6ZmFsc2V9fQ==) data.
71
+ I'll answer by displaying a list of SQL queries, graphs and data most relevant to your question.
72
+
73
+ You can ask me anything about these climate indicators: **temperature** or **precipitation**.
74
+ You can specify **location** and/or **year**.
75
+ By default, we take the **median across all climate models**.
76
+
77
+ Currently available charts:
78
+ - Yearly evolution of an indicator at a specific location (historical + SSP Projections)
79
+ - Yearly spatial distribution of an indicator in a specific country
80
+
81
+ Currently available indicators:
82
+ - Mean temperature
83
+ - Total precipitation
84
+
85
+ For example, you can ask:
86
+ - What will the temperature be like in Paris?
87
+ - What will be the total rainfall in the USA in 2030?
88
+ - How will the average temperature evolve in China?
89
+
90
+ ⚠️ **Limitations**:
91
+ - You can't ask anything that isn't related to **IPCC - ATLAS** data.
92
+ - You cannot ask about **several locations at the same time**.
93
+ - If you specify a year **before 1850 or over 2100**, there will be **no data**.
94
+ - You **cannot compare two models**.
95
+
96
+ 🛈 **Information**
97
+ Please note that we **log your questions for meta-analysis purposes**, so avoid sharing any sensitive or personal information.
98
+ """
climateqa/engine/talk_to_data/ipcc/plot_informations.py ADDED
@@ -0,0 +1,50 @@
1
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_UNIT
2
+
3
+ def indicator_evolution_informations(
4
+ indicator: str,
5
+ params: dict[str,str],
6
+ ) -> str:
7
+ if "location" not in params:
8
+ raise ValueError('"location" must be provided in params')
9
+ location = params["location"]
10
+
11
+ unit = IPCC_INDICATOR_TO_UNIT[indicator]
12
+ return f"""
13
+ This plot shows how the climate indicator **{indicator}** evolves over time in **{location}**.
14
+
15
+ It combines both historical (from 1950 to 2015) observations and future (from 2016 to 2100) projections for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).
16
+
17
+ The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}).
18
+
19
+ Each line corresponds to a different scenario, allowing you to compare how {indicator} might change under various future conditions.
20
+
21
+ **Data source:**
22
+ - The data comes from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
23
+ - The underlying data is retrieved by aggregating yearly values of {indicator} for the selected location, across all available scenarios. This means the system collects, for each year, the value of {indicator} in {location}, both for the historical period and for each scenario, to build the time series.
24
+ - The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
25
+ """
26
+
27
+ def choropleth_map_informations(
28
+ indicator: str,
29
+ params: dict[str, str],
30
+ ) -> str:
31
+ unit = IPCC_INDICATOR_TO_UNIT[indicator]
32
+ if "location" not in params:
33
+ raise ValueError('"location" must be provided in params')
34
+ location = params["location"]
35
+ country_name = params["country_name"]
36
+ year = params["year"]
37
+ if year is None:
38
+ year = 2050
39
+
40
+ return f"""
41
+ This plot displays a choropleth map showing the spatial distribution of **{indicator}** across all regions of **{location}** country ({country_name}) for the year **{year}** and the chosen scenario.
42
+
43
+ Each grid point is colored according to the value of the indicator ({unit}), allowing you to visually compare how {indicator} varies geographically within the country for the selected year and scenario.
44
+
45
+ **Data source:**
46
+ - The data come from the CMIP6 IPCC ATLAS data. The data were initially extracted from [this referenced website](https://digital.csic.es/handle/10261/332744) and then preprocessed to a tabular format and uploaded as parquet in this [Hugging Face dataset](https://huggingface.co/datasets/Ekimetrics/ipcc-atlas).
47
+ - For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
48
+ - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
49
+ - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
50
+ """
climateqa/engine/talk_to_data/ipcc/plots.py ADDED
@@ -0,0 +1,189 @@
1
+ from typing import Callable
2
+ from plotly.graph_objects import Figure
3
+ import plotly.graph_objects as go
4
+ import pandas as pd
5
+ import geojson
6
+
7
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
8
+ from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
9
+ from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
10
+ from climateqa.engine.talk_to_data.objects.plot import Plot
11
+
12
+ def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
13
+ features = [
14
+ geojson.Feature(
15
+ geometry=geojson.Polygon([[
16
+ [lon - 0.5, lat - 0.5],
17
+ [lon + 0.5, lat - 0.5],
18
+ [lon + 0.5, lat + 0.5],
19
+ [lon - 0.5, lat + 0.5],
20
+ [lon - 0.5, lat - 0.5]
21
+ ]]),
22
+ properties={"value": val},
23
+ id=str(idx)
24
+ )
25
+ for idx, (lat, lon, val) in enumerate(zip(latitudes, longitudes, indicators))
26
+ ]
27
+
28
+ geojson_data = geojson.FeatureCollection(features)
29
+ return geojson_data
30
+
31
+ def plot_indicator_evolution_at_location_historical_and_projections(
32
+ params: dict,
33
+ ) -> Callable[[pd.DataFrame], Figure]:
34
+ """
35
+ Returns a function that generates a line plot showing the evolution of a climate indicator
36
+ (e.g., temperature, rainfall) over time at a specific location, including both historical data
37
+ and future projections for different climate scenarios.
38
+
39
+ Args:
40
+ params (dict): Dictionary with:
41
+ - indicator_column (str): Name of the climate indicator column to plot.
42
+ - location (str): Location (e.g., country, city) for which to plot the indicator.
43
+
44
+ Returns:
45
+ Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
46
+ showing the indicator's evolution over time, with scenario lines and historical data.
47
+ """
48
+ indicator = params["indicator_column"]
49
+ location = params["location"]
50
+ indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
51
+ unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
52
+
53
+ def plot_data(df: pd.DataFrame) -> Figure:
54
+ df = df.sort_values(by='year')
55
+ years = df['year'].astype(int).tolist()
56
+ indicators = df[indicator].astype(float).tolist()
57
+ scenarios = df['scenario'].astype(str).tolist()
58
+
59
+ # Find last historical value for continuity
60
+ last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
61
+ last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)
62
+
63
+ fig = go.Figure()
64
+ for scenario in IPCC_SCENARIO:
65
+ x = [y for y, s in zip(years, scenarios) if s == scenario]
66
+ y = [v for v, s in zip(indicators, scenarios) if s == scenario]
67
+ # Connect historical to scenario
68
+ if scenario != 'historical' and last_historical_indicator is not None:
69
+ x = [last_historical_year] + x
70
+ y = [last_historical_indicator] + y
71
+ fig.add_trace(go.Scatter(
72
+ x=x,
73
+ y=y,
74
+ mode='lines',
75
+ name=scenario
76
+ ))
77
+
78
+ fig.update_layout(
79
+ title=f'Yearly Evolution of {indicator_label} in {location} (Historical + SSP Scenarios)',
80
+ xaxis_title='Year',
81
+ yaxis_title=f'{indicator_label} ({unit})',
82
+ legend_title='Scenario',
83
+ height=800,
84
+ )
85
+ return fig
86
+
87
+ return plot_data
88
+
89
+ indicator_evolution_at_location_historical_and_projections: Plot = {
90
+ "name": "Indicator Evolution at Location (Historical + Projections)",
91
+ "description": (
92
+ "Shows how a climate indicator (e.g., rainfall, temperature) changes over time at a specific location, "
93
+ "including historical data and future projections. "
94
+ "Useful for questions about the value or trend of an indicator at a location for any year, "
95
+ "such as 'What will be the total rainfall in China in 2050?' or 'How does rainfall evolve in China over time?'. "
96
+ "Parameters: indicator_column (the climate variable), location (e.g., country, city)."
97
+ ),
98
+ "params": ["indicator_column", "location"],
99
+ "plot_function": plot_indicator_evolution_at_location_historical_and_projections,
100
+ "sql_query": indicator_per_year_at_location_query,
101
+ "plot_information": indicator_evolution_informations,
102
+ "short_name": "Evolution"
103
+ }
104
+
105
+ def plot_choropleth_map_of_country_indicator_for_specific_year(
106
+ params: dict,
107
+ ) -> Callable[[pd.DataFrame], Figure]:
108
+ """
109
+ Returns a function that generates a choropleth map (heatmap) showing the spatial distribution
110
+ of a climate indicator (e.g., temperature, rainfall) across all regions of a country for a specific year.
111
+
112
+ Args:
113
+ params (dict): Dictionary with:
114
+ - indicator_column (str): Name of the climate indicator column to plot.
115
+ - year (str or int, optional): Year for which to plot the indicator (default: 2050).
116
+ - country_name (str): Name of the country.
117
+ - location (str): Location (country or region) for the map.
118
+
119
+ Returns:
120
+ Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure
121
+ showing the indicator's spatial distribution as a choropleth map for the specified year.
122
+ """
123
+ indicator = params["indicator_column"]
124
+ year = params.get('year')
125
+ if year is None:
126
+ year = 2050
127
+ country_name = params['country_name']
128
+ location = params['location']
129
+ indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
130
+ unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
131
+
132
+ def plot_data(df: pd.DataFrame) -> Figure:
133
+
134
+ indicators = df[indicator].astype(float).tolist()
135
+ latitudes = df["latitude"].astype(float).tolist()
136
+ longitudes = df["longitude"].astype(float).tolist()
137
+
138
+ geojson_data = generate_geojson_polygons(latitudes, longitudes, indicators)
139
+
140
+ fig = go.Figure(go.Choroplethmapbox(
141
+ geojson=geojson_data,
142
+ locations=[str(i) for i in range(len(indicators))],
143
+ featureidkey="id",
144
+ z=indicators,
145
+ colorscale=IPCC_INDICATOR_TO_COLORSCALE[indicator],
146
+ zmin=min(indicators),
147
+ zmax=max(indicators),
148
+ marker_opacity=0.7,
149
+ marker_line_width=0,
150
+ colorbar_title=f"{indicator_label} ({unit})",
151
+ text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
152
+ hoverinfo="text"
153
+ ))
154
+
155
+ fig.update_layout(
156
+ mapbox_style="open-street-map",
157
+ mapbox_zoom=2,
158
+ height=800,
159
+ mapbox_center={
160
+ "lat": latitudes[len(latitudes)//2] if latitudes else 0,
161
+ "lon": longitudes[len(longitudes)//2] if longitudes else 0
162
+ },
163
+ coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"),
164
+ title=f"{indicator_label} in {year} in {location} ({country_name})"
165
+ )
166
+ return fig
167
+
168
+ return plot_data
169
+
170
+ choropleth_map_of_country_indicator_for_specific_year: Plot = {
171
+ "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
172
+ "description": (
173
+ "Displays a map showing the spatial distribution of a climate indicator (e.g., rainfall, temperature) "
174
+ "across all regions of a country for a specific year. "
175
+ "Can answer questions about the value of an indicator in a country or region for a given year, "
176
+ "such as 'What will be the total rainfall in China in 2050?' or 'How is rainfall distributed across China in 2050?'. "
177
+ "Parameters: indicator_column (the climate variable), year, location (country name)."
178
+ ),
179
+ "params": ["indicator_column", "year", "location"],
180
+ "plot_function": plot_choropleth_map_of_country_indicator_for_specific_year,
181
+ "sql_query": indicator_for_given_year_query,
182
+ "plot_information": choropleth_map_informations,
183
+ "short_name": "Map",
184
+ }
185
+
186
+ IPCC_PLOTS = [
187
+ indicator_evolution_at_location_historical_and_projections,
188
+ choropleth_map_of_country_indicator_for_specific_year
189
+ ]
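`generate_geojson_polygons` above draws each data point as a 1°×1° square centred on its grid coordinates. Since GeoJSON is plain JSON, the same structure can be sketched with dicts and no third-party dependency (a standalone illustration, not the repo's implementation):

```python
def grid_cell_feature(lat: float, lon: float, value: float, idx: int) -> dict:
    """Build a GeoJSON Feature for one 1-degree grid cell centred on (lat, lon)."""
    half = 0.5  # half of the 1-degree IPCC ATLAS grid resolution
    ring = [
        [lon - half, lat - half],
        [lon + half, lat - half],
        [lon + half, lat + half],
        [lon - half, lat + half],
        [lon - half, lat - half],  # polygon rings must close on the first point
    ]
    return {
        "type": "Feature",
        "id": str(idx),
        "geometry": {"type": "Polygon", "coordinates": [ring]},
        "properties": {"value": value},
    }

feature = grid_cell_feature(48.0, 2.0, 14.2, 0)
print(feature["geometry"]["coordinates"][0][0])  # [1.5, 47.5]
```

The string `id` is what `Choroplethmapbox` matches against its `locations` list via `featureidkey="id"`.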
climateqa/engine/talk_to_data/ipcc/queries.py ADDED
@@ -0,0 +1,143 @@
1
+ from typing import TypedDict, Optional
2
+
3
+ from climateqa.engine.talk_to_data.ipcc.config import HUGE_MACRO_COUNTRIES, MACRO_COUNTRIES
4
+ from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
5
+ class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
6
+ """
7
+ Parameters for querying the evolution of an indicator per year at a specific location.
8
+
9
+ Attributes:
10
+ indicator_column (str): Name of the climate indicator column.
11
+ latitude (str): Latitude of the location.
12
+ longitude (str): Longitude of the location.
13
+ country_code (str): Country code.
14
+ admin1 (str): Administrative region (optional).
15
+ """
16
+ indicator_column: str
17
+ latitude: str
18
+ longitude: str
19
+ country_code: str
20
+ admin1: Optional[str]
21
+
22
+ def indicator_per_year_at_location_query(
23
+ table: str, params: IndicatorPerYearAtLocationQueryParams
24
+ ) -> str:
25
+ """
26
+ Builds an SQL query to get the evolution of an indicator per year at a specific location.
27
+
28
+ Args:
29
+ table (str): SQL table of the indicator.
30
+ params (IndicatorPerYearAtLocationQueryParams): Dictionary with the required params for the query.
31
+
32
+ Returns:
33
+ str: The SQL query string, or an empty string if required parameters are missing.
34
+ """
35
+ indicator_column = params.get("indicator_column")
36
+ latitude = params.get("latitude")
37
+ longitude = params.get("longitude")
38
+ country_code = params.get("country_code")
39
+ admin1 = params.get("admin1")
40
+
41
+ if not all([indicator_column, latitude, longitude, country_code]):
42
+ return ""
43
+
44
+ if country_code in MACRO_COUNTRIES:
45
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
46
+ sql_query = f"""
47
+ SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
48
+ FROM {table_path}
49
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
50
+ GROUP BY scenario, year
51
+ ORDER BY year, scenario
52
+ """
53
+ elif country_code in HUGE_MACRO_COUNTRIES:
54
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
55
+ sql_query = f"""
56
+ SELECT year, scenario, {indicator_column}
57
+ FROM {table_path}
58
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
59
+ ORDER BY year, scenario
60
+ """
61
+ else:
62
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
63
+ sql_query = f"""
64
+ WITH medians_per_month AS (
65
+ SELECT year, scenario, month, MEDIAN({indicator_column}) AS median_value
66
+ FROM {table_path}
67
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
68
+ GROUP BY scenario, year, month
69
+ )
70
+ SELECT year, scenario, AVG(median_value) AS {indicator_column}
71
+ FROM medians_per_month
72
+ GROUP BY scenario, year
73
+ ORDER BY year, scenario
74
+ """
75
+ return sql_query.strip()
76
+
77
+ class IndicatorForGivenYearQueryParams(TypedDict, total=False):
78
+ """
79
+ Parameters for querying an indicator's values across locations for a specific year.
80
+
81
+ Attributes:
82
+ indicator_column (str): The column name for the climate indicator.
83
+ year (str): The year to query.
84
+ country_code (str): The country code.
85
+ """
86
+ indicator_column: str
87
+ year: str
88
+ country_code: str
89
+
90
+ def indicator_for_given_year_query(
91
+ table: str, params: IndicatorForGivenYearQueryParams
92
+ ) -> str:
93
+ """
94
+ Builds an SQL query to get the values of an indicator with their latitudes, longitudes,
95
+ and scenarios for a given year.
96
+
97
+ Args:
98
+ table (str): SQL table of the indicator.
99
+ params (IndicatorForGivenYearQueryParams): Dictionary with the required params for the query.
100
+
101
+ Returns:
102
+ str: The SQL query string, or an empty string if required parameters are missing.
103
+ """
104
+ indicator_column = params.get("indicator_column")
105
+ year = params.get("year") or 2050
106
+ country_code = params.get("country_code")
107
+
108
+ if not all([indicator_column, year, country_code]):
109
+ return ""
110
+
111
+ if country_code in MACRO_COUNTRIES:
112
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
113
+ sql_query = f"""
114
+ SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
115
+ FROM {table_path}
116
+ WHERE year = {year}
117
+ GROUP BY latitude, longitude, scenario
118
+ ORDER BY latitude, longitude, scenario
119
+ """
120
+ elif country_code in HUGE_MACRO_COUNTRIES:
121
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
122
+ sql_query = f"""
123
+ SELECT latitude, longitude, scenario, {indicator_column}
124
+ FROM {table_path}
125
+ WHERE year = {year}
126
+ ORDER BY latitude, longitude, scenario
127
+ """
128
+ else:
129
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
130
+ sql_query = f"""
131
+ WITH medians_per_month AS (
132
+ SELECT latitude, longitude, scenario, month, MEDIAN({indicator_column}) AS median_value
133
+ FROM {table_path}
134
+ WHERE year = {year}
135
+ GROUP BY latitude, longitude, scenario, month
136
+ )
137
+ SELECT latitude, longitude, scenario, AVG(median_value) AS {indicator_column}
138
+ FROM medians_per_month
139
+ GROUP BY latitude, longitude, scenario
140
+ ORDER BY latitude, longitude, scenario
141
+ """
142
+
143
+ return sql_query.strip()
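Both query builders route to a different parquet partition depending on country size: large ("macro") countries read a pre-aggregated `*_macro.parquet` file, everything else reads the per-country file. A standalone sketch of that routing (the country sets are trimmed copies of `MACRO_COUNTRIES` / `HUGE_MACRO_COUNTRIES`; the URL mirrors `IPCC_DATASET_URL`):

```python
IPCC_DATASET_URL = "hf://datasets/ekimetrics/ipcc-atlas"
MACRO_COUNTRIES = {"JP", "IN", "CN", "BR"}          # trimmed copy for the example
HUGE_MACRO_COUNTRIES = {"CL", "CA", "AU", "US", "RU"}

def parquet_path(table: str, country_code: str) -> str:
    """Return the parquet partition a query for this table/country reads."""
    is_macro = country_code in MACRO_COUNTRIES or country_code in HUGE_MACRO_COUNTRIES
    suffix = "_macro" if is_macro else ""
    return f"{IPCC_DATASET_URL}/{table.lower()}/{country_code}{suffix}.parquet"

print(parquet_path("mean_temperature", "US"))  # ...mean_temperature/US_macro.parquet
print(parquet_path("mean_temperature", "FR"))  # ...mean_temperature/FR.parquet
```

The resulting path is interpolated into the SQL strings above, presumably executed by an engine such as DuckDB that can read remote parquet files.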
climateqa/engine/talk_to_data/main.py CHANGED
@@ -1,44 +1,70 @@
1
- from climateqa.engine.talk_to_data.talk_to_drias import drias_workflow
2
- from climateqa.engine.llm import get_llm
3
  from climateqa.logging import log_drias_interaction_to_huggingface
4
- import ast
5
 
6
- llm = get_llm(provider="openai")
7
-
8
- def ask_llm_to_add_table_names(sql_query: str, llm) -> str:
9
- """Adds table names to the SQL query result rows using LLM.
10
 
11
- This function modifies the SQL query to include the source table name in each row
12
- of the result set, making it easier to track which data comes from which table.
 
13
 
14
  Args:
15
- sql_query (str): The original SQL query to modify
16
- llm: The language model instance to use for generating the modified query
17
 
18
  Returns:
19
- str: The modified SQL query with table names included in the result rows
 
20
  """
21
- sql_with_table_names = llm.invoke(f"Make the following sql query display the source table in the rows {sql_query}. Just answer the query. The answer should not include ```sql\n").content
22
- return sql_with_table_names
23
 
24
- def ask_llm_column_names(sql_query: str, llm) -> list[str]:
25
- """Extracts column names from a SQL query using LLM.
 
 
26
 
27
- This function analyzes a SQL query to identify which columns are being selected
28
- in the result set.
29
 
30
- Args:
31
- sql_query (str): The SQL query to analyze
32
- llm: The language model instance to use for column extraction
33
-
34
- Returns:
35
- list[str]: A list of column names being selected in the query
36
- """
37
- columns = llm.invoke(f"From the given sql query, list the columns that are being selected. The answer should only be a python list. Just answer the list. The SQL query : {sql_query}").content
38
- columns_list = ast.literal_eval(columns.strip("```python\n").strip())
39
- return columns_list
40
 
41
- async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tuple:
 
 
42
  """Main function to process a DRIAS query and return results.
43
 
44
  This function orchestrates the DRIAS workflow, processing a user query to generate
@@ -61,58 +87,38 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str = None) -> tu
61
  - table_list (list): List of table names used
62
  - error (str): Error message if any
63
  """
64
- final_state = await drias_workflow(query)
65
  sql_queries = []
66
  result_dataframes = []
67
  figures = []
68
- table_list = []
 
69
 
70
- for plot_state in final_state['plot_states'].values():
71
- for table_state in plot_state['table_states'].values():
72
- if table_state['status'] == 'OK':
73
- if 'table_name' in table_state:
74
- table_list.append(' '.join(table_state['table_name'].capitalize().split('_')))
75
- if 'sql_query' in table_state and table_state['sql_query'] is not None:
76
- sql_queries.append(table_state['sql_query'])
77
-
78
- if 'dataframe' in table_state and table_state['dataframe'] is not None:
79
- result_dataframes.append(table_state['dataframe'])
80
- if 'figure' in table_state and table_state['figure'] is not None:
81
- figures.append(table_state['figure'])
 
 
 
82
 
83
  if "error" in final_state and final_state["error"] != "":
84
- return None, None, None, [], [], [], 0, final_state["error"]
 
85
 
86
  sql_query = sql_queries[index_state]
87
  dataframe = result_dataframes[index_state]
88
  figure = figures[index_state](dataframe)
 
89
 
90
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
91
 
92
- return sql_query, dataframe, figure, sql_queries, result_dataframes, figures, index_state, table_list, ""
93
-
94
- # def ask_vanna(vn,db_vanna_path, query):
95
-
96
- # try :
97
- # location = detect_location_with_openai(query)
98
- # if location:
99
-
100
- # coords = loc2coords(location)
101
- # user_input = query.lower().replace(location.lower(), f"lat, long : {coords}")
102
-
103
- # relevant_tables = detect_relevant_tables(db_vanna_path, user_input, llm)
104
- # coords_tables = [nearestNeighbourSQL(db_vanna_path, coords, relevant_tables[i]) for i in range(len(relevant_tables))]
105
- # user_input_with_coords = replace_coordonates(coords, user_input, coords_tables)
106
-
107
- # sql_query, result_dataframe, figure = vn.ask(user_input_with_coords, print_results=False, allow_llm_to_see_data=True, auto_train=False)
108
-
109
- # return sql_query, result_dataframe, figure
110
- # else :
111
- # empty_df = pd.DataFrame()
112
- # empty_fig = None
113
- # return "", empty_df, empty_fig
114
- # except Exception as e:
115
- # print(f"Error: {e}")
116
- # empty_df = pd.DataFrame()
117
- # empty_fig = None
118
- # return "", empty_df, empty_fig
 
1
+ from climateqa.engine.talk_to_data.workflow.drias import drias_workflow
2
+ from climateqa.engine.talk_to_data.workflow.ipcc import ipcc_workflow
3
  from climateqa.logging import log_drias_interaction_to_huggingface
 
4
 
5
+ async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
6
+ """Main function to process a DRIAS query and return results.
 
 
7
 
8
+ This function orchestrates the DRIAS workflow, processing a user query to generate
9
+ SQL queries, dataframes, and visualizations. It handles multiple results and allows
10
+ pagination through them.
11
 
12
  Args:
13
+ query (str): The user's question about climate data
14
+ index_state (int, optional): The index of the result to return. Defaults to 0.
15
 
16
  Returns:
17
+ tuple: A tuple containing:
18
+ - sql_query (str): The SQL query used
19
+ - dataframe (pd.DataFrame): The resulting data
20
+ - figure (Callable): Function to generate the visualization
21
+ - sql_queries (list): All generated SQL queries
22
+ - result_dataframes (list): All resulting dataframes
23
+ - figures (list): All figure generation functions
+ - plot_informations (list): Explanations for all plots
24
+ - index_state (int): Current result index
25
+ - plot_title_list (list): Titles of the available plots
26
+ - error (str): Error message if any
27
  """
28
+ final_state = await drias_workflow(query)
29
+ sql_queries = []
30
+ result_dataframes = []
31
+ figures = []
32
+ plot_title_list = []
33
+ plot_informations = []
34
+
35
+ for output_title, output in final_state['outputs'].items():
36
+ if output['status'] == 'OK':
37
+ if output['table'] is not None:
38
+ plot_title_list.append(output_title)
39
+
40
+ if output['plot_information'] is not None:
41
+ plot_informations.append(output['plot_information'])
42
+
43
+ if output['sql_query'] is not None:
44
+ sql_queries.append(output['sql_query'])
45
+
46
+ if output['dataframe'] is not None:
47
+ result_dataframes.append(output['dataframe'])
48
+ if output['figure'] is not None:
49
+ figures.append(output['figure'])
50
+
51
+ if "error" in final_state and final_state["error"] != "":
52
+ # Return placeholders for every success-path field, plus the error message
53
+ return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
54
 
55
+ sql_query = sql_queries[index_state]
56
+ dataframe = result_dataframes[index_state]
57
+ figure = figures[index_state](dataframe)
58
+ plot_information = plot_informations[index_state]
59
 
 
 
60
 
61
+ log_drias_interaction_to_huggingface(query, sql_query, user_id)
62
+
63
+ return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
64
 
65
+
66
+
67
+ async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None) -> tuple:
68
  """Main function to process a DRIAS query and return results.
69
 
70
  This function orchestrates the IPCC workflow, processing a user query to generate
 
87
  - table_list (list): List of table names used
88
  - error (str): Error message if any
89
  """
90
+ final_state = await ipcc_workflow(query)
91
  sql_queries = []
92
  result_dataframes = []
93
  figures = []
94
+ plot_title_list = []
95
+ plot_informations = []
96
 
97
+ for output_title, output in final_state['outputs'].items():
98
+ if output['status'] == 'OK':
99
+ if output['table'] is not None:
100
+ plot_title_list.append(output_title)
101
+
102
+ if output['plot_information'] is not None:
103
+ plot_informations.append(output['plot_information'])
104
+
105
+ if output['sql_query'] is not None:
106
+ sql_queries.append(output['sql_query'])
107
+
108
+ if output['dataframe'] is not None:
109
+ result_dataframes.append(output['dataframe'])
110
+ if output['figure'] is not None:
111
+ figures.append(output['figure'])
112
 
113
  if "error" in final_state and final_state["error"] != "":
114
+ # Return placeholders for every success-path field, plus the error message
115
+ return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
116
 
117
  sql_query = sql_queries[index_state]
118
  dataframe = result_dataframes[index_state]
119
  figure = figures[index_state](dataframe)
120
+ plot_information = plot_informations[index_state]
121
 
122
  log_drias_interaction_to_huggingface(query, sql_query, user_id)
123
 
124
+ return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
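Both entry points return the same 11-element tuple, with `index_state` letting the UI page through the generated plots. A hypothetical caller sketch (the `results` tuple below uses placeholder strings instead of an awaited call):

```python
# Placeholder stand-in for: results = await ask_drias("...", index_state=0)
results = (
    "SELECT ...", "dataframe-0", "figure-0", "info-0",          # current selection
    ["SELECT ..."], ["dataframe-0"], ["figure-0"], ["info-0"],  # all results
    0, ["Evolution"], "",                                       # index, titles, error
)

(sql_query, dataframe, figure, plot_information,
 sql_queries, result_dataframes, figures, plot_informations,
 index_state, plot_title_list, error) = results

if not error:
    # Wrap around so a "next plot" button cycles through the available titles.
    next_index = (index_state + 1) % max(len(plot_title_list), 1)
    print(plot_title_list[index_state], "->", next_index)
```

Keeping the error path's tuple the same length as the success path's is what makes this unconditional unpacking safe.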
climateqa/engine/talk_to_data/objects/llm_outputs.py ADDED
@@ -0,0 +1,13 @@
1
+ from typing import Annotated, TypedDict
2
+
3
+
4
+ class ArrayOutput(TypedDict):
5
+ """Represents the output of a function that returns an array.
6
+
7
+ This class is used to type-hint functions that return arrays,
8
+ ensuring consistent return types across the codebase.
9
+
10
+ Attributes:
11
+ array (str): A syntactically valid Python array string
12
+ """
13
+ array: Annotated[str, "Syntactically valid python array."]
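The `array` field holds a Python-list literal emitted by the LLM. A sketch of parsing it safely with `ast.literal_eval`, which evaluates literals only and never arbitrary code (the sample output string is illustrative; callers should still check the parsed type):

```python
import ast

# Illustrative LLM output matching the ArrayOutput shape above.
llm_output = {"array": "['mean_temperature', 'total_precipitation']"}

parsed = ast.literal_eval(llm_output["array"])
assert isinstance(parsed, list)  # guard against a non-list literal
print(parsed)  # ['mean_temperature', 'total_precipitation']
```

This mirrors the `ast.literal_eval` usage in the old `ask_llm_column_names` helper removed from `main.py` in this change.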
climateqa/engine/talk_to_data/objects/location.py ADDED
@@ -0,0 +1,12 @@
1
+ from typing import Optional, TypedDict
3
+
4
+
5
+
6
+ class Location(TypedDict):
7
+ location: str
8
+ latitude: Optional[str]
9
+ longitude: Optional[str]
10
+ country_code: Optional[str]
11
+ country_name: Optional[str]
12
+ admin1: Optional[str]
climateqa/engine/talk_to_data/objects/plot.py ADDED
@@ -0,0 +1,23 @@
1
+ from typing import Callable, TypedDict, Optional
2
+ from plotly.graph_objects import Figure
3
+
4
+ class Plot(TypedDict):
5
+ """Represents a plot configuration in the DRIAS system.
6
+
7
+ This class defines the structure for configuring different types of plots
8
+ that can be generated from climate data.
9
+
10
+ Attributes:
11
+ name (str): The name of the plot type
12
+ description (str): A description of what the plot shows
13
+ params (list[str]): List of required parameters for the plot
14
+ plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
15
+ sql_query (Callable[..., str]): Function to generate the SQL query for the plot
16
+ """
17
+ name: str
18
+ description: str
19
+ params: list[str]
20
+ plot_function: Callable[..., Callable[..., Figure]]
21
+ sql_query: Callable[..., str]
22
+ plot_information: Callable[..., str]
23
+ short_name: str
climateqa/engine/talk_to_data/objects/states.py ADDED
@@ -0,0 +1,19 @@
1
+ from typing import Any, Callable, Optional, TypedDict
2
+ from plotly.graph_objects import Figure
3
+ import pandas as pd
4
+ from climateqa.engine.talk_to_data.objects.plot import Plot
5
+
6
+ class TTDOutput(TypedDict):
7
+ status: str
8
+ plot: Plot
9
+ table: str
10
+ sql_query: Optional[str]
11
+ dataframe: Optional[pd.DataFrame]
12
+ figure: Optional[Callable[..., Figure]]
13
+ plot_information: Optional[str]
14
+ class State(TypedDict):
15
+ user_input: str
16
+ plots: list[str]
17
+ outputs: dict[str, TTDOutput]
18
+ error: Optional[str]
19
+
climateqa/engine/talk_to_data/prompt.py ADDED
@@ -0,0 +1,44 @@
+query_prompt_template = """You are an expert SQL query generator. Given an input question, database schema, SQL dialect and relevant tables to answer the question, generate an optimized and syntactically correct SQL query which can provide useful insights to the question.
+
+### Instructions:
+1. **Use only relevant tables**: The following tables are relevant to answering the question: {relevant_tables}. Do not use any other tables.
+2. **Relevant columns only**: Never select `*`. Only include necessary columns based on the input question.
+3. **Schema Awareness**:
+   - Use only columns present in the given schema.
+   - **If a column name appears in multiple tables, always use the format `table_name.column_name` to avoid ambiguity.**
+   - Select only the columns that are insightful for the question.
+4. **Dialect Compliance**: Follow `{dialect}` syntax rules.
+5. **Ordering**: Order the results by a relevant column if applicable (e.g., timestamp for recent records).
+6. **Valid query**: Make sure the query is syntactically and functionally correct.
+7. **Conditions**: For the common columns, the same condition should be applied to all the tables (e.g. latitude, longitude, model, year...).
+8. **Join tables**: If you need to join tables, join them on the year column.
+9. **Model**: For each table, add a condition on the model to be equal to {model}.
+
+### Provided Database Schema:
+{table_info}
+
+### Relevant Tables:
+{relevant_tables}
+
+**Question:** {input}
+
+**SQL Query:**"""
+
+plot_prompt_template = """You are a data visualization expert. Given an input question and an SQL Query, generate an insightful plot according to the question.
+
+### Instructions
+1. **Use only the column names provided**. The data will be provided as a Pandas DataFrame `df` with the columns present in the SELECT.
+2. Generate the Python Plotly code to chart the results using `df` and the column names.
+3. Make as complete a graph as possible to answer the question, and make it as easy to understand as possible.
+4. **Respond with only Python code**. Do not answer with any explanations -- just the code.
+5. **Specific cases**:
+   - For a question about the evolution of something, it is also relevant to plot a sliding average (e.g. over a 20-year window) alongside the raw data.
+
+### SQL Query:
+{sql_query}
+
+**Question:** {input}
+
+**Python code:**
+"""
+
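Both templates are plain Python format strings, so filling them in is a single `str.format` call. A minimal sketch with an abridged template and made-up values (the table, dialect, and model names are illustrative only):

```python
# Abridged version of query_prompt_template from this commit
query_prompt_template = """You are an expert SQL query generator.
### Relevant Tables:
{relevant_tables}
Dialect: {dialect}
Model condition: {model}
**Question:** {input}
**SQL Query:**"""

prompt = query_prompt_template.format(
    relevant_tables="mean_annual_temperature",  # hypothetical table name
    dialect="duckdb",
    model="CNRM-CM5/ALADIN63",                  # hypothetical model name
    input="How will temperature evolve in Paris?",
)
print(prompt)
```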
climateqa/engine/talk_to_data/query.py ADDED
@@ -0,0 +1,59 @@
+import asyncio
+import os
+from concurrent.futures import ThreadPoolExecutor
+
+import duckdb
+import pandas as pd
+
+
+def find_indicator_column(table: str, indicator_columns_per_table: dict[str, str]) -> str:
+    """Retrieves the name of the indicator column within a table.
+
+    This function maps table names to their corresponding indicator columns
+    using the provided mapping.
+
+    Args:
+        table (str): Name of the table in the database
+        indicator_columns_per_table (dict[str, str]): Mapping from table names to indicator columns
+
+    Returns:
+        str: Name of the indicator column for the specified table
+
+    Raises:
+        KeyError: If the table name is not found in the mapping
+    """
+    print(f"---- Find indicator column in table {table} ----")
+    return indicator_columns_per_table[table]
+
+
+async def execute_sql_query(sql_query: str) -> pd.DataFrame:
+    """Executes a SQL query on the DRIAS database and returns the results.
+
+    This function connects to the DuckDB database containing DRIAS climate data
+    and executes the provided SQL query. It handles the database connection and
+    returns the results as a pandas DataFrame.
+
+    Args:
+        sql_query (str): The SQL query to execute
+
+    Returns:
+        pd.DataFrame: A DataFrame containing the query results
+
+    Raises:
+        duckdb.Error: If there is an error executing the SQL query
+    """
+    def _execute_query():
+        # Register the Hugging Face token so remote datasets can be read
+        con = duckdb.connect()
+        HF_TTD_TOKEN = os.getenv("HF_TTD_TOKEN")
+        con.execute(f"""CREATE SECRET hf_token (
+            TYPE huggingface,
+            TOKEN '{HF_TTD_TOKEN}'
+        );""")
+        # Return the fetched data as a DataFrame
+        return con.execute(sql_query).fetchdf()
+
+    # Run the query in a thread pool to avoid blocking the event loop
+    loop = asyncio.get_running_loop()
+    with ThreadPoolExecutor() as executor:
+        return await loop.run_in_executor(executor, _execute_query)
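`execute_sql_query` keeps the event loop responsive by pushing the blocking DuckDB call into a thread-pool executor. The same pattern can be sketched without DuckDB (the `_blocking_query` stand-in replaces the real database call):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def _blocking_query():
    # Stand-in for the blocking DuckDB fetch
    return sum(range(1000))


async def run_query():
    # Same structure as execute_sql_query: hand the blocking call
    # to a worker thread and await its result
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        return await loop.run_in_executor(executor, _blocking_query)


result = asyncio.run(run_query())
print(result)  # 499500
```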
climateqa/engine/talk_to_data/ui_config.py ADDED
@@ -0,0 +1,27 @@
+TEMPERATURE_COLORSCALE = [
+    [0.0, "rgb(5, 48, 97)"],
+    [0.10, "rgb(33, 102, 172)"],
+    [0.20, "rgb(67, 147, 195)"],
+    [0.30, "rgb(146, 197, 222)"],
+    [0.40, "rgb(209, 229, 240)"],
+    [0.50, "rgb(247, 247, 247)"],
+    [0.60, "rgb(253, 219, 199)"],
+    [0.75, "rgb(244, 165, 130)"],
+    [0.85, "rgb(214, 96, 77)"],
+    [0.90, "rgb(178, 24, 43)"],
+    [1.0, "rgb(103, 0, 31)"]
+]
+
+PRECIPITATION_COLORSCALE = [
+    [0.0, "rgb(84, 48, 5)"],
+    [0.10, "rgb(140, 81, 10)"],
+    [0.20, "rgb(191, 129, 45)"],
+    [0.30, "rgb(223, 194, 125)"],
+    [0.40, "rgb(246, 232, 195)"],
+    [0.50, "rgb(245, 245, 245)"],
+    [0.60, "rgb(199, 234, 229)"],
+    [0.75, "rgb(128, 205, 193)"],
+    [0.85, "rgb(53, 151, 143)"],
+    [0.90, "rgb(1, 102, 94)"],
+    [1.0, "rgb(0, 60, 48)"]
+]
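Plotly interpolates colors between the listed stops, where the first element of each pair is a normalized position in [0, 1]. A rough sketch of how a value maps onto the scale, using an abridged version of the temperature scale (the `color_at` helper is illustrative only — it picks the nearest stop at or below the value rather than interpolating like Plotly does):

```python
# Abridged stops from TEMPERATURE_COLORSCALE in this commit
TEMPERATURE_COLORSCALE = [
    [0.0, "rgb(5, 48, 97)"],
    [0.5, "rgb(247, 247, 247)"],
    [1.0, "rgb(103, 0, 31)"],
]


def color_at(scale, value):
    # Return the color of the nearest stop at or below `value`
    chosen = scale[0][1]
    for position, color in scale:
        if position <= value:
            chosen = color
    return chosen


print(color_at(TEMPERATURE_COLORSCALE, 0.7))  # rgb(247, 247, 247)
```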
climateqa/engine/talk_to_data/vanna/myVanna.py ADDED
@@ -0,0 +1,13 @@
+from dotenv import load_dotenv
+from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
+from vanna.openai import OpenAI_Chat
+import os
+
+load_dotenv()
+
+OPENAI_API_KEY = os.getenv('THEO_API_KEY')
+
+class MyVanna(MyCustomVectorDB, OpenAI_Chat):
+    def __init__(self, config=None):
+        MyCustomVectorDB.__init__(self, config=config)
+        OpenAI_Chat.__init__(self, config=config)
climateqa/engine/talk_to_data/vanna/vanna_class.py ADDED
@@ -0,0 +1,225 @@
+from vanna.base import VannaBase
+from pinecone import Pinecone
+from climateqa.engine.embeddings import get_embeddings_function
+import pandas as pd
+import hashlib
+
+
+class MyCustomVectorDB(VannaBase):
+    """
+    VectorDB class for storing and retrieving vectors from Pinecone.
+
+    Args:
+        config (dict): Configuration dictionary containing the Pinecone API key and the index name:
+            - pc_api_key (str): Pinecone API key
+            - index_name (str): Pinecone index name
+            - top_k (int): Number of top results to return (default = 2)
+    """
+
+    def __init__(self, config):
+        super().__init__(config=config)
+        self.api_key = config.get('pc_api_key')
+        self.index_name = config.get('index_name')
+        if not self.api_key or not self.index_name:
+            raise ValueError("Please provide the Pinecone API key and the index name")
+
+        self.pc = Pinecone(api_key=self.api_key)
+        self.index = self.pc.Index(self.index_name)
+        self.top_k = config.get('top_k', 2)
+        self.embeddings = get_embeddings_function()
+
+    def check_embedding(self, id, namespace):
+        fetched = self.index.fetch(ids=[id], namespace=namespace)
+        return fetched['vectors'] != {}
+
+    def generate_hash_id(self, data: str) -> str:
+        """
+        Generate a unique hash ID for the given data.
+
+        Args:
+            data (str): The input data to hash (e.g., a concatenated string of user attributes).
+
+        Returns:
+            str: A unique hash ID as a hexadecimal string.
+        """
+        data_bytes = data.encode('utf-8')
+        return hashlib.sha256(data_bytes).hexdigest()
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        id = self.generate_hash_id(ddl) + '_ddl'
+
+        if self.check_embedding(id, 'ddl'):
+            print(f"DDL having id {id} already exists")
+            return id
+
+        self.index.upsert(
+            vectors=[(id, self.embeddings.embed_query(ddl), {'ddl': ddl})],
+            namespace='ddl'
+        )
+
+        return id
+
+    def add_documentation(self, doc: str, **kwargs) -> str:
+        id = self.generate_hash_id(doc) + '_doc'
+
+        if self.check_embedding(id, 'documentation'):
+            print(f"Documentation having id {id} already exists")
+            return id
+
+        self.index.upsert(
+            vectors=[(id, self.embeddings.embed_query(doc), {'doc': doc})],
+            namespace='documentation'
+        )
+
+        return id
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        id = self.generate_hash_id(question) + '_sql'
+
+        if self.check_embedding(id, 'question_sql'):
+            print(f"Question-SQL pair having id {id} already exists")
+            return id
+
+        self.index.upsert(
+            vectors=[(id, self.embeddings.embed_query(question + sql), {'question': question, 'sql': sql})],
+            namespace='question_sql'
+        )
+
+        return id
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        res = self.index.query(
+            vector=self.embeddings.embed_query(question),
+            top_k=self.top_k,
+            namespace='ddl',
+            include_metadata=True
+        )
+
+        return [match['metadata']['ddl'] for match in res['matches']]
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        res = self.index.query(
+            vector=self.embeddings.embed_query(question),
+            top_k=self.top_k,
+            namespace='documentation',
+            include_metadata=True
+        )
+
+        return [match['metadata']['doc'] for match in res['matches']]
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        res = self.index.query(
+            vector=self.embeddings.embed_query(question),
+            top_k=self.top_k,
+            namespace='question_sql',
+            include_metadata=True
+        )
+
+        return [(match['metadata']['question'], match['metadata']['sql']) for match in res['matches']]
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        list_of_data = []
+
+        namespaces = ['ddl', 'documentation', 'question_sql']
+
+        for namespace in namespaces:
+            data = self.index.query(
+                top_k=10000,
+                namespace=namespace,
+                include_metadata=True,
+                include_values=False
+            )
+
+            for match in data['matches']:
+                list_of_data.append(match['metadata'])
+
+        return pd.DataFrame(list_of_data)
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        # The namespaces must match the ones used by the add_* methods
+        if id.endswith("_ddl"):
+            self.index.delete(ids=[id], namespace="ddl")
+            return True
+        if id.endswith("_sql"):
+            self.index.delete(ids=[id], namespace="question_sql")
+            return True
+        if id.endswith("_doc"):
+            self.index.delete(ids=[id], namespace="documentation")
+            return True
+
+        return False
+
+    def generate_embedding(self, text, **kwargs):
+        # Implement the method here
+        pass
+
+    def get_sql_prompt(
+        self,
+        initial_prompt: str,
+        question: str,
+        question_sql_list: list,
+        ddl_list: list,
+        doc_list: list,
+        **kwargs,
+    ):
+        """
+        Example:
+        ```python
+        vn.get_sql_prompt(
+            question="What are the top 10 customers by sales?",
+            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
+            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
+            doc_list=["The customers table contains information about customers and their sales."],
+        )
+        ```
+
+        This method is used to generate a prompt for the LLM to generate SQL.
+
+        Args:
+            question (str): The question to generate SQL for.
+            question_sql_list (list): A list of questions and their corresponding SQL statements.
+            ddl_list (list): A list of DDL statements.
+            doc_list (list): A list of documentation.
+
+        Returns:
+            any: The prompt for the LLM to generate SQL.
+        """
+        if initial_prompt is None:
+            initial_prompt = f"You are a {self.dialect} expert. " + \
+                "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
+
+        initial_prompt = self.add_ddl_to_prompt(
+            initial_prompt, ddl_list, max_tokens=self.max_tokens
+        )
+
+        if self.static_documentation != "":
+            doc_list.append(self.static_documentation)
+
+        initial_prompt = self.add_documentation_to_prompt(
+            initial_prompt, doc_list, max_tokens=self.max_tokens
+        )
+
+        initial_prompt += (
+            "===Response Guidelines \n"
+            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
+            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
+            "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
+            "4. Please use the most relevant table(s). \n"
+            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
+            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
+            "7. Add a description of the table in the result of the sql query, if relevant. \n"
+            "8. Make sure to include the relevant KPI in the SQL query. The query should return impactful data. \n"
+        )
+
+        message_log = [self.system_message(initial_prompt)]
+
+        for example in question_sql_list:
+            if example is None:
+                print("example is None")
+            elif "question" in example and "sql" in example:
+                message_log.append(self.user_message(example["question"]))
+                message_log.append(self.assistant_message(example["sql"]))
+
+        message_log.append(self.user_message(question))
+
+        return message_log
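The deduplication in the `add_*` methods relies on `generate_hash_id` being deterministic: the same text always produces the same id, so re-adding existing training data is a no-op. A standalone sketch of the scheme:

```python
import hashlib


def generate_hash_id(data: str) -> str:
    # Same scheme as MyCustomVectorDB: SHA-256 hex digest of the text
    return hashlib.sha256(data.encode("utf-8")).hexdigest()


ddl = "CREATE TABLE customers (id INT, name TEXT)"
# The type suffix mirrors what add_ddl appends before upserting
id_a = generate_hash_id(ddl) + "_ddl"
id_b = generate_hash_id(ddl) + "_ddl"

# Identical inputs always map to the same id, which is what lets
# add_ddl detect and skip vectors already present in the index
print(id_a == id_b)  # True
```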
climateqa/engine/talk_to_data/workflow/drias.py ADDED
@@ -0,0 +1,163 @@
+import os
+
+from typing import Any
+import asyncio
+from climateqa.engine.llm import get_llm
+from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
+from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
+from climateqa.engine.talk_to_data.objects.plot import Plot
+from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
+from climateqa.engine.talk_to_data.drias.config import DRIAS_TABLES, DRIAS_INDICATOR_COLUMNS_PER_TABLE, DRIAS_PLOT_PARAMETERS
+from climateqa.engine.talk_to_data.drias.plots import DRIAS_PLOTS
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
+
+async def process_output(
+    output_title: str,
+    table: str,
+    plot: Plot,
+    params: dict[str, Any]
+) -> tuple[str, TTDOutput, dict[str, bool]]:
+    """
+    Processes a table for a given plot and parameters: builds the SQL query, executes it,
+    and generates the corresponding figure.
+
+    Args:
+        output_title (str): Title for the output (used as key in outputs dict).
+        table (str): The name of the table to process.
+        plot (Plot): The plot object containing SQL query and visualization function.
+        params (dict[str, Any]): Parameters used for querying the table.
+
+    Returns:
+        tuple: (output_title, results dict, errors dict)
+    """
+    results: TTDOutput = {
+        'status': 'OK',
+        'plot': plot,
+        'table': table,
+        'sql_query': None,
+        'dataframe': None,
+        'figure': None,
+        'plot_information': None
+    }
+    errors = {
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+
+    # Find the indicator column for this table
+    indicator_column = find_indicator_column(table, DRIAS_INDICATOR_COLUMNS_PER_TABLE)
+    if indicator_column:
+        params['indicator_column'] = indicator_column
+
+    # Build the SQL query
+    sql_query = plot['sql_query'](table, params)
+    if not sql_query:
+        results['status'] = 'ERROR'
+        return output_title, results, errors
+
+    results['plot_information'] = plot['plot_information'](table, params)
+
+    results['sql_query'] = sql_query
+    errors['have_sql_query'] = True
+
+    # Execute the SQL query
+    df = await execute_sql_query(sql_query)
+    if df is not None and len(df) > 0:
+        results['dataframe'] = df
+        errors['have_dataframe'] = True
+    else:
+        results['status'] = 'NO_DATA'
+
+    # Generate the figure (always, even if df is empty, for consistency)
+    results['figure'] = plot['plot_function'](params)
+
+    return output_title, results, errors
+
+async def drias_workflow(user_input: str) -> State:
+    """
+    Orchestrates the DRIAS workflow: from user input to SQL queries, dataframes, and figures.
+
+    Args:
+        user_input (str): The user's question.
+
+    Returns:
+        State: Final state with all results and error messages if any.
+    """
+    state: State = {
+        'user_input': user_input,
+        'plots': [],
+        'outputs': {},
+        'error': ''
+    }
+
+    llm = get_llm(provider="openai")
+    plots = await find_relevant_plots(state, llm, DRIAS_PLOTS)
+
+    if not plots:
+        state['error'] = 'There is no plot that can answer the question'
+        return state
+
+    plots = plots[:2]  # limit to 2 types of plots
+    state['plots'] = plots
+
+    errors = {
+        'have_relevant_table': False,
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+    outputs = {}
+
+    # Find relevant tables for each plot and prepare outputs
+    for plot_name in plots:
+        plot = next((p for p in DRIAS_PLOTS if p['name'] == plot_name), None)
+        if plot is None:
+            continue
+
+        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, DRIAS_TABLES)
+        if relevant_tables:
+            errors['have_relevant_table'] = True
+
+        for table in relevant_tables:
+            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
+            outputs[output_title] = {
+                'table': table,
+                'plot': plot,
+                'status': 'OK'
+            }
+
+    # Gather all required parameters
+    params = {}
+    for param_name in DRIAS_PLOT_PARAMETERS:
+        param = await find_param(state, param_name, mode='DRIAS')
+        if param:
+            params.update(param)
+
+    # Process all outputs in parallel using process_output
+    tasks = [
+        process_output(output_title, output['table'], output['plot'], params.copy())
+        for output_title, output in outputs.items()
+    ]
+    results = await asyncio.gather(*tasks)
+
+    # Update outputs with results and error flags
+    for output_title, task_results, task_errors in results:
+        outputs[output_title]['sql_query'] = task_results['sql_query']
+        outputs[output_title]['dataframe'] = task_results['dataframe']
+        outputs[output_title]['figure'] = task_results['figure']
+        outputs[output_title]['plot_information'] = task_results['plot_information']
+        outputs[output_title]['status'] = task_results['status']
+        errors['have_sql_query'] |= task_errors['have_sql_query']
+        errors['have_dataframe'] |= task_errors['have_dataframe']
+
+    state['outputs'] = outputs
+
+    # Set error messages if needed
+    if not errors['have_relevant_table']:
+        state['error'] = "There is no relevant table in our database to answer your question"
+    elif not errors['have_sql_query']:
+        state['error'] = "There is no relevant SQL query on our database that can help answer your question"
+    elif not errors['have_dataframe']:
+        state['error'] = "There is no data in our tables that can answer your question"
+
+    return state
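The per-table work in the workflow is fanned out with `asyncio.gather`, one task per output title, and the results are merged back into the `outputs` dict. A stripped-down sketch of that fan-out (the stand-in `process_output` here replaces the real SQL and plotting work):

```python
import asyncio


async def process_output(output_title, table):
    # Stand-in for the real per-table SQL query + figure generation
    await asyncio.sleep(0)
    return output_title, {"table": table, "status": "OK"}


async def gather_outputs(outputs):
    # One task per output title, all run concurrently
    tasks = [
        process_output(title, spec["table"])
        for title, spec in outputs.items()
    ]
    results = await asyncio.gather(*tasks)
    # Merge each task's results back into the outputs dict
    for title, res in results:
        outputs[title].update(res)
    return outputs


outputs = {
    "Evolution - Mean annual temperature": {"table": "mean_annual_temperature"},
    "Map - Mean annual temperature": {"table": "mean_annual_temperature"},
}
done = asyncio.run(gather_outputs(outputs))
print(done["Map - Mean annual temperature"]["status"])  # OK
```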
climateqa/engine/talk_to_data/workflow/ipcc.py ADDED
@@ -0,0 +1,161 @@
+import os
+
+from typing import Any
+import asyncio
+from climateqa.engine.llm import get_llm
+from climateqa.engine.talk_to_data.input_processing import find_param, find_relevant_plots, find_relevant_tables_per_plot
+from climateqa.engine.talk_to_data.query import execute_sql_query, find_indicator_column
+from climateqa.engine.talk_to_data.objects.plot import Plot
+from climateqa.engine.talk_to_data.objects.states import State, TTDOutput
+from climateqa.engine.talk_to_data.ipcc.config import IPCC_TABLES, IPCC_INDICATOR_COLUMNS_PER_TABLE, IPCC_PLOT_PARAMETERS
+from climateqa.engine.talk_to_data.ipcc.plots import IPCC_PLOTS
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
+
+async def process_output(
+    output_title: str,
+    table: str,
+    plot: Plot,
+    params: dict[str, Any]
+) -> tuple[str, TTDOutput, dict[str, bool]]:
+    """
+    Processes a table for a given plot and parameters: builds the SQL query, executes it,
+    and generates the corresponding figure.
+
+    Args:
+        output_title (str): Title for the output (used as key in outputs dict).
+        table (str): The name of the table to process.
+        plot (Plot): The plot object containing SQL query and visualization function.
+        params (dict[str, Any]): Parameters used for querying the table.
+
+    Returns:
+        tuple: (output_title, results dict, errors dict)
+    """
+    results: TTDOutput = {
+        'status': 'OK',
+        'plot': plot,
+        'table': table,
+        'sql_query': None,
+        'dataframe': None,
+        'figure': None,
+        'plot_information': None,
+    }
+    errors = {
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+
+    # Find the indicator column for this table
+    indicator_column = find_indicator_column(table, IPCC_INDICATOR_COLUMNS_PER_TABLE)
+    if indicator_column:
+        params['indicator_column'] = indicator_column
+
+    # Build the SQL query
+    sql_query = plot['sql_query'](table, params)
+    if not sql_query:
+        results['status'] = 'ERROR'
+        return output_title, results, errors
+
+    results['plot_information'] = plot['plot_information'](table, params)
+
+    results['sql_query'] = sql_query
+    errors['have_sql_query'] = True
+
+    # Execute the SQL query
+    df = await execute_sql_query(sql_query)
+    if df is not None and not df.empty:
+        results['dataframe'] = df
+        errors['have_dataframe'] = True
+    else:
+        results['status'] = 'NO_DATA'
+
+    # Generate the figure (always, even if df is empty, for consistency)
+    results['figure'] = plot['plot_function'](params)
+
+    return output_title, results, errors
+
+async def ipcc_workflow(user_input: str) -> State:
+    """
+    Performs the complete workflow of Talk To IPCC: from user input to SQL queries, dataframes, and figures.
+
+    Args:
+        user_input (str): The user's question.
+
+    Returns:
+        State: Final state with all the results and error messages if any.
+    """
+    state: State = {
+        'user_input': user_input,
+        'plots': [],
+        'outputs': {},
+        'error': ''
+    }
+
+    llm = get_llm(provider="openai")
+    plots = await find_relevant_plots(state, llm, IPCC_PLOTS)
+    state['plots'] = plots
+
+    if not plots:
+        state['error'] = 'There is no plot that can answer the question'
+        return state
+
+    errors = {
+        'have_relevant_table': False,
+        'have_sql_query': False,
+        'have_dataframe': False
+    }
+    outputs = {}
+
+    # Find relevant tables for each plot and prepare outputs
+    for plot_name in plots:
+        plot = next((p for p in IPCC_PLOTS if p['name'] == plot_name), None)
+        if plot is None:
+            continue
+
+        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm, IPCC_TABLES)
+        if relevant_tables:
+            errors['have_relevant_table'] = True
+
+        for table in relevant_tables:
+            output_title = f"{plot['short_name']} - {' '.join(table.capitalize().split('_'))}"
+            outputs[output_title] = {
+                'table': table,
+                'plot': plot,
+                'status': 'OK'
+            }
+
+    # Gather all required parameters
+    params = {}
+    for param_name in IPCC_PLOT_PARAMETERS:
+        param = await find_param(state, param_name, mode='IPCC')
+        if param:
+            params.update(param)
+
+    # Process all outputs in parallel using process_output
+    tasks = [
+        process_output(output_title, output['table'], output['plot'], params.copy())
+        for output_title, output in outputs.items()
+    ]
+    results = await asyncio.gather(*tasks)
+
+    # Update outputs with results and error flags
+    for output_title, task_results, task_errors in results:
+        outputs[output_title]['sql_query'] = task_results['sql_query']
+        outputs[output_title]['dataframe'] = task_results['dataframe']
+        outputs[output_title]['figure'] = task_results['figure']
+        outputs[output_title]['plot_information'] = task_results['plot_information']
+        outputs[output_title]['status'] = task_results['status']
+        errors['have_sql_query'] |= task_errors['have_sql_query']
+        errors['have_dataframe'] |= task_errors['have_dataframe']
+
+    state['outputs'] = outputs
+
+    # Set error messages if needed
+    if not errors['have_relevant_table']:
+        state['error'] = "There is no relevant table in our database to answer your question"
+    elif not errors['have_sql_query']:
+        state['error'] = "There is no relevant SQL query on our database that can help answer your question"
+    elif not errors['have_dataframe']:
+        state['error'] = "There is no data in our tables that can answer your question"
+
+    return state
front/assets/talk_to_drias_annual_temperature_france_example.png ADDED

Git LFS Details

  • SHA256: 7b257f8240d8d09c2843f9930194fdc21d6a7e41c86c990eac2e547d7dda5fd2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.95 MB
front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png ADDED

Git LFS Details

  • SHA256: 49b0efa2b1b071ec2e052b21b522be1ed0be959a8f113c362a69f156c01cf112
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
front/assets/talk_to_drias_winter_temp_paris_example.png ADDED

Git LFS Details

  • SHA256: 11ef57c94d0920f7c25c6624ecfa793648c383a15e3aa85cf838b3baada056a7
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
front/assets/talk_to_ipcc_china_example.png ADDED

Git LFS Details

  • SHA256: 3f8cfe2e6942352e893e76303409e5179dcc9bb398dcef3d76d0f6d5b2ac1a0b
  • Pointer size: 131 Bytes
  • Size of remote file: 918 kB
front/assets/talk_to_ipcc_france_example.png ADDED

Git LFS Details

  • SHA256: 3b50641a94f15cac929a480126fb62f0df03fa9a7db9c6fe1f9f23a9d0bb14b3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
front/assets/talk_to_ipcc_new_york_example.png ADDED

Git LFS Details

  • SHA256: 7dab3af17210c5d138db9825acffc1279ff2ad6f3649b31dbadf18ff44b3ea14
  • Pointer size: 131 Bytes
  • Size of remote file: 190 kB
front/tabs/tab_drias.py CHANGED
@@ -4,26 +4,25 @@ import os
4
  import pandas as pd
5
 
6
  from climateqa.engine.talk_to_data.main import ask_drias
7
- from climateqa.engine.talk_to_data.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
 
9
  class DriasUIElements(TypedDict):
10
  tab: gr.Tab
11
  details_accordion: gr.Accordion
12
  examples_hidden: gr.Textbox
13
  examples: gr.Examples
 
14
  drias_direct_question: gr.Textbox
15
  result_text: gr.Textbox
16
- table_names_display: gr.DataFrame
17
  query_accordion: gr.Accordion
18
  drias_sql_query: gr.Textbox
19
  chart_accordion: gr.Accordion
 
20
  model_selection: gr.Dropdown
21
  drias_display: gr.Plot
22
  table_accordion: gr.Accordion
23
  drias_table: gr.DataFrame
24
- pagination_display: gr.Markdown
25
- prev_button: gr.Button
26
- next_button: gr.Button
27
 
28
 
29
  async def ask_drias_query(query: str, index_state: int, user_id: str):
@@ -31,7 +30,7 @@ async def ask_drias_query(query: str, index_state: int, user_id: str):
31
  return result
32
 
33
 
34
- def show_results(sql_queries_state, dataframes_state, plots_state):
35
  if not sql_queries_state or not dataframes_state or not plots_state:
36
  # If all results are empty, show "No result"
37
  return (
@@ -40,9 +39,6 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
40
  gr.update(visible=False),
41
  gr.update(visible=False),
42
  gr.update(visible=False),
43
- gr.update(visible=False),
44
- gr.update(visible=False),
45
- gr.update(visible=False),
46
  )
47
  else:
48
  # Show the appropriate components with their data
@@ -51,10 +47,7 @@ def show_results(sql_queries_state, dataframes_state, plots_state):
51
  gr.update(visible=True),
52
  gr.update(visible=True),
53
  gr.update(visible=True),
54
- gr.update(visible=True),
55
- gr.update(visible=True),
56
- gr.update(visible=True),
57
- gr.update(visible=True),
58
  )
59
 
60
 
@@ -72,44 +65,14 @@ def filter_by_model(dataframes, figures, index_state, model_selection):
72
  return df, figure
73
 
74
 
75
- def update_pagination(index, sql_queries):
76
- pagination = f"{index + 1}/{len(sql_queries)}" if sql_queries else ""
77
- return pagination
78
-
79
-
80
- def show_previous(index, sql_queries, dataframes, plots):
81
- if index > 0:
82
- index -= 1
83
- return (
84
- sql_queries[index],
85
- dataframes[index],
86
- plots[index](dataframes[index]),
87
- index,
88
- )
89
-
90
-
91
- def show_next(index, sql_queries, dataframes, plots):
92
- if index < len(sql_queries) - 1:
93
- index += 1
94
- return (
95
- sql_queries[index],
96
- dataframes[index],
97
- plots[index](dataframes[index]),
98
- index,
99
- )
100
-
101
-
102
- def display_table_names(table_names):
103
- return [table_names]
104
-
105
-
106
- def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plots):
107
- index = evt.index[1]
108
  figure = plots[index](dataframes[index])
109
  return (
110
  sql_queries[index],
111
  dataframes[index],
112
  figure,
 
113
  index,
114
  )
115
 
@@ -117,7 +80,7 @@ def on_table_click(evt: gr.SelectData, table_names, sql_queries, dataframes, plo
117
  def create_drias_ui() -> DriasUIElements:
118
  """Create and return all UI elements for the DRIAS tab."""
119
  with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
120
- with gr.Accordion(label="Details") as details_accordion:
121
  gr.Markdown(DRIAS_UI_TEXT)
122
 
123
  # Add examples for common questions
@@ -141,24 +104,43 @@ def create_drias_ui() -> DriasUIElements:
141
  elem_id="direct-question",
142
  interactive=True,
143
  )
 
 
 
 
 
 
 
 
 
144
 
145
  result_text = gr.Textbox(
146
  label="", elem_id="no-result-label", interactive=False, visible=True
147
  )
148
-
149
- table_names_display = gr.DataFrame(
150
- [], label="List of relevant indicators", headers=None, interactive=False, elem_id="table-names", visible=False
151
- )
152
-
153
- with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
154
- drias_sql_query = gr.Textbox(
155
- label="", elem_id="sql-query", interactive=False
156
  )
157
 
 
 
 
 
 
 
158
  with gr.Accordion(label="Chart", visible=False) as chart_accordion:
159
- model_selection = gr.Dropdown(
160
- label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
161
- )
 
 
 
 
162
  drias_display = gr.Plot(elem_id="vanna-plot")
163
 
164
  with gr.Accordion(
@@ -166,32 +148,23 @@ def create_drias_ui() -> DriasUIElements:
166
  ) as table_accordion:
167
  drias_table = gr.DataFrame([], elem_id="vanna-table")
168
 
169
- pagination_display = gr.Markdown(
170
- value="", visible=False, elem_id="pagination-display"
171
- )
172
-
173
- with gr.Row():
174
- prev_button = gr.Button("Previous", visible=False)
175
- next_button = gr.Button("Next", visible=False)
176
-
177
  return DriasUIElements(
178
  tab=tab,
179
  details_accordion=details_accordion,
180
  examples_hidden=examples_hidden,
181
  examples=examples,
 
182
  drias_direct_question=drias_direct_question,
183
  result_text=result_text,
184
  table_names_display=table_names_display,
185
  query_accordion=query_accordion,
186
  drias_sql_query=drias_sql_query,
187
  chart_accordion=chart_accordion,
 
188
  model_selection=model_selection,
189
  drias_display=drias_display,
190
  table_accordion=table_accordion,
191
  drias_table=drias_table,
192
- pagination_display=pagination_display,
193
- prev_button=prev_button,
194
- next_button=next_button
195
  )
196
 
197
 
@@ -202,94 +175,56 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
202
  sql_queries_state = gr.State([])
203
  dataframes_state = gr.State([])
204
  plots_state = gr.State([])
 
205
  index_state = gr.State(0)
206
  table_names_list = gr.State([])
207
  user_id = gr.State(user_id)
208
 
 
 
 
 
 
 
 
209
  # Handle example selection
210
  ui_elements["examples_hidden"].change(
211
  lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
212
  inputs=[ui_elements["examples_hidden"]],
213
  outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
214
  ).then(
215
- ask_drias_query,
216
- inputs=[ui_elements["examples_hidden"], index_state, user_id],
217
- outputs=[
218
- ui_elements["drias_sql_query"],
219
- ui_elements["drias_table"],
220
- ui_elements["drias_display"],
221
- sql_queries_state,
222
- dataframes_state,
223
- plots_state,
224
- index_state,
225
- table_names_list,
226
- ui_elements["result_text"],
227
- ],
228
- ).then(
229
- show_results,
230
- inputs=[sql_queries_state, dataframes_state, plots_state],
231
- outputs=[
232
- ui_elements["result_text"],
233
- ui_elements["query_accordion"],
234
- ui_elements["table_accordion"],
235
- ui_elements["chart_accordion"],
236
- ui_elements["prev_button"],
237
- ui_elements["next_button"],
238
- ui_elements["pagination_display"],
239
- ui_elements["table_names_display"],
240
- ],
241
- ).then(
242
- update_pagination,
243
- inputs=[index_state, sql_queries_state],
244
- outputs=[ui_elements["pagination_display"]],
245
- ).then(
246
- display_table_names,
247
- inputs=[table_names_list],
248
- outputs=[ui_elements["table_names_display"]],
249
- )
250
-
251
- # Handle direct question submission
252
- ui_elements["drias_direct_question"].submit(
253
- lambda: gr.Accordion(open=False),
254
  inputs=None,
255
- outputs=[ui_elements["details_accordion"]]
256
  ).then(
257
  ask_drias_query,
258
- inputs=[ui_elements["drias_direct_question"], index_state, user_id],
259
  outputs=[
260
  ui_elements["drias_sql_query"],
261
  ui_elements["drias_table"],
262
  ui_elements["drias_display"],
 
263
  sql_queries_state,
264
  dataframes_state,
265
  plots_state,
 
266
  index_state,
267
  table_names_list,
268
  ui_elements["result_text"],
269
  ],
270
  ).then(
271
  show_results,
272
- inputs=[sql_queries_state, dataframes_state, plots_state],
273
  outputs=[
274
  ui_elements["result_text"],
275
  ui_elements["query_accordion"],
276
  ui_elements["table_accordion"],
277
  ui_elements["chart_accordion"],
278
- ui_elements["prev_button"],
279
- ui_elements["next_button"],
280
- ui_elements["pagination_display"],
281
  ui_elements["table_names_display"],
282
  ],
283
- ).then(
284
- update_pagination,
285
- inputs=[index_state, sql_queries_state],
286
- outputs=[ui_elements["pagination_display"]],
287
- ).then(
288
- display_table_names,
289
- inputs=[table_names_list],
290
- outputs=[ui_elements["table_names_display"]],
291
  )
292
 
 
293
  # Handle model selection change
294
  ui_elements["model_selection"].change(
295
  filter_by_model,
@@ -297,36 +232,12 @@ def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=
297
  outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
298
  )
299
 
300
- # Handle pagination buttons
301
- ui_elements["prev_button"].click(
302
- show_previous,
303
- inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
304
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
305
- ).then(
306
- update_pagination,
307
- inputs=[index_state, sql_queries_state],
308
- outputs=[ui_elements["pagination_display"]],
309
- )
310
-
311
- ui_elements["next_button"].click(
312
- show_next,
313
- inputs=[index_state, sql_queries_state, dataframes_state, plots_state],
314
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
315
- ).then(
316
- update_pagination,
317
- inputs=[index_state, sql_queries_state],
318
- outputs=[ui_elements["pagination_display"]],
319
- )
320
 
321
  # Handle table selection
322
- ui_elements["table_names_display"].select(
323
  fn=on_table_click,
324
- inputs=[table_names_list, sql_queries_state, dataframes_state, plots_state],
325
- outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], index_state],
326
- ).then(
327
- update_pagination,
328
- inputs=[index_state, sql_queries_state],
329
- outputs=[ui_elements["pagination_display"]],
330
  )
331
 
332
  def create_drias_tab(share_client=None, user_id=None):
 
4
  import pandas as pd
5
 
6
  from climateqa.engine.talk_to_data.main import ask_drias
7
+ from climateqa.engine.talk_to_data.drias.config import DRIAS_MODELS, DRIAS_UI_TEXT
8
 
9
  class DriasUIElements(TypedDict):
10
  tab: gr.Tab
11
  details_accordion: gr.Accordion
12
  examples_hidden: gr.Textbox
13
  examples: gr.Examples
14
+ image_examples: gr.Row
15
  drias_direct_question: gr.Textbox
16
  result_text: gr.Textbox
17
+ table_names_display: gr.Radio
18
  query_accordion: gr.Accordion
19
  drias_sql_query: gr.Textbox
20
  chart_accordion: gr.Accordion
21
+ plot_information: gr.Markdown
22
  model_selection: gr.Dropdown
23
  drias_display: gr.Plot
24
  table_accordion: gr.Accordion
25
  drias_table: gr.DataFrame
 
 
 
26
 
27
 
28
  async def ask_drias_query(query: str, index_state: int, user_id: str):
 
30
  return result
31
 
32
 
33
+ def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
34
  if not sql_queries_state or not dataframes_state or not plots_state:
35
  # If all results are empty, show "No result"
36
  return (
 
39
  gr.update(visible=False),
40
  gr.update(visible=False),
41
  gr.update(visible=False),
 
 
 
42
  )
43
  else:
44
  # Show the appropriate components with their data
 
47
  gr.update(visible=True),
48
  gr.update(visible=True),
49
  gr.update(visible=True),
50
+ gr.update(choices=table_names, value=table_names[0], visible=True),
 
 
 
51
  )
52
 
53
 
 
65
  return df, figure
66
 
67
 
68
+ def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
69
+ index = table_names.index(selected_label)
70
  figure = plots[index](dataframes[index])
71
  return (
72
  sql_queries[index],
73
  dataframes[index],
74
  figure,
75
+ plot_informations[index],
76
  index,
77
  )
78
 
 
80
  def create_drias_ui() -> DriasUIElements:
81
  """Create and return all UI elements for the DRIAS tab."""
82
  with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
83
+ with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
84
  gr.Markdown(DRIAS_UI_TEXT)
85
 
86
  # Add examples for common questions
 
104
  elem_id="direct-question",
105
  interactive=True,
106
  )
107
+
108
+
109
+ with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
110
+ gr.Markdown("### Examples of possible visualizations")
111
+
112
+ with gr.Row():
113
+ gr.Image("./front/assets/talk_to_drias_winter_temp_paris_example.png", label="Evolution of Mean Winter Temperature in Paris", elem_classes=["example-img"])
114
+ gr.Image("./front/assets/talk_to_drias_annual_temperature_france_example.png", label="Mean Annual Temperature in 2030 in France", elem_classes=["example-img"])
115
+ gr.Image("./front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png", label="Frequency of Remarkable Daily Precipitation in Lyon", elem_classes=["example-img"])
116
 
117
  result_text = gr.Textbox(
118
  label="", elem_id="no-result-label", interactive=False, visible=True
119
  )
120
+
121
+ with gr.Row():
122
+ table_names_display = gr.Radio(
123
+ choices=[],
124
+ label="Relevant figures created",
125
+ interactive=True,
126
+ elem_id="table-names",
127
+ visible=False
128
  )
129
 
130
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
131
+ drias_sql_query = gr.Textbox(
132
+ label="", elem_id="sql-query", interactive=False
133
+ )
134
+
135
+
136
  with gr.Accordion(label="Chart", visible=False) as chart_accordion:
137
+ with gr.Row():
138
+ model_selection = gr.Dropdown(
139
+ label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
140
+ )
141
+ with gr.Accordion(label="Information about the plot", open=False):
142
+ plot_information = gr.Markdown(value="")
143
+
144
  drias_display = gr.Plot(elem_id="vanna-plot")
145
 
146
  with gr.Accordion(
 
148
  ) as table_accordion:
149
  drias_table = gr.DataFrame([], elem_id="vanna-table")
150
 
 
 
 
 
 
 
 
 
151
  return DriasUIElements(
152
  tab=tab,
153
  details_accordion=details_accordion,
154
  examples_hidden=examples_hidden,
155
  examples=examples,
156
+ image_examples=image_examples,
157
  drias_direct_question=drias_direct_question,
158
  result_text=result_text,
159
  table_names_display=table_names_display,
160
  query_accordion=query_accordion,
161
  drias_sql_query=drias_sql_query,
162
  chart_accordion=chart_accordion,
163
+ plot_information=plot_information,
164
  model_selection=model_selection,
165
  drias_display=drias_display,
166
  table_accordion=table_accordion,
167
  drias_table=drias_table,
 
 
 
168
  )
169
 
170
 
 
175
  sql_queries_state = gr.State([])
176
  dataframes_state = gr.State([])
177
  plots_state = gr.State([])
178
+ plot_informations_state = gr.State([])
179
  index_state = gr.State(0)
180
  table_names_list = gr.State([])
181
  user_id = gr.State(user_id)
182
 
183
+ # Handle direct question submission - trigger the same workflow by setting examples_hidden
184
+ ui_elements["drias_direct_question"].submit(
185
+ lambda x: gr.update(value=x),
186
+ inputs=[ui_elements["drias_direct_question"]],
187
+ outputs=[ui_elements["examples_hidden"]],
188
+ )
189
+
190
  # Handle example selection
191
  ui_elements["examples_hidden"].change(
192
  lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
193
  inputs=[ui_elements["examples_hidden"]],
194
  outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
195
  ).then(
196
+ lambda: gr.update(visible=False),
197
  inputs=None,
198
+ outputs=ui_elements["image_examples"]
199
  ).then(
200
  ask_drias_query,
201
+ inputs=[ui_elements["examples_hidden"], index_state, user_id],
202
  outputs=[
203
  ui_elements["drias_sql_query"],
204
  ui_elements["drias_table"],
205
  ui_elements["drias_display"],
206
+ ui_elements["plot_information"],
207
  sql_queries_state,
208
  dataframes_state,
209
  plots_state,
210
+ plot_informations_state,
211
  index_state,
212
  table_names_list,
213
  ui_elements["result_text"],
214
  ],
215
  ).then(
216
  show_results,
217
+ inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
218
  outputs=[
219
  ui_elements["result_text"],
220
  ui_elements["query_accordion"],
221
  ui_elements["table_accordion"],
222
  ui_elements["chart_accordion"],
 
 
 
223
  ui_elements["table_names_display"],
224
  ],
 
 
 
 
 
 
 
 
225
  )
226
 
227
+
228
  # Handle model selection change
229
  ui_elements["model_selection"].change(
230
  filter_by_model,
 
232
  outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
233
  )
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # Handle table selection
237
+ ui_elements["table_names_display"].change(
238
  fn=on_table_click,
239
+ inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
240
+ outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], ui_elements["plot_information"], index_state],
 
 
 
 
241
  )
242
 
243
  def create_drias_tab(share_client=None, user_id=None):
front/tabs/tab_ipcc.py ADDED
@@ -0,0 +1,300 @@
1
2
+ import gradio as gr
3
+ from typing import TypedDict
4
+ from climateqa.engine.talk_to_data.main import ask_ipcc
5
+ from climateqa.engine.talk_to_data.ipcc.config import IPCC_SCENARIO, IPCC_UI_TEXT
6
7
+
8
+ class ipccUIElements(TypedDict):
9
+ tab: gr.Tab
10
+ details_accordion: gr.Accordion
11
+ examples_hidden: gr.Textbox
12
+ examples: gr.Examples
13
+ image_examples: gr.Row
14
+ ipcc_direct_question: gr.Textbox
15
+ result_text: gr.Textbox
16
+ table_names_display: gr.Radio
17
+ query_accordion: gr.Accordion
18
+ ipcc_sql_query: gr.Textbox
19
+ chart_accordion: gr.Accordion
20
+ plot_information: gr.Markdown
21
+ scenario_selection: gr.Dropdown
22
+ ipcc_display: gr.Plot
23
+ table_accordion: gr.Accordion
24
+ ipcc_table: gr.DataFrame
25
+
26
+
27
+ async def ask_ipcc_query(query: str, index_state: int, user_id: str):
28
+ result = await ask_ipcc(query, index_state, user_id)
29
+ return result
30
+
31
+ def hide_outputs():
32
+ """Hide all outputs initially."""
33
+ return (
34
+ gr.update(visible=True), # Show the result text
35
+ gr.update(visible=False), # Hide the query accordion
36
+ gr.update(visible=False), # Hide the table accordion
37
+ gr.update(visible=False), # Hide the chart accordion
38
+ gr.update(visible=False), # Hide table names
39
+ )
40
+
41
+ def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
42
+ if not sql_queries_state or not dataframes_state or not plots_state:
43
+ # If all results are empty, show "No result"
44
+ return (
45
+ gr.update(visible=True),
46
+ gr.update(visible=False),
47
+ gr.update(visible=False),
48
+ gr.update(visible=False),
49
+ gr.update(visible=False),
50
+ )
51
+ else:
52
+ # Show the appropriate components with their data
53
+ return (
54
+ gr.update(visible=False),
55
+ gr.update(visible=True),
56
+ gr.update(visible=True),
57
+ gr.update(visible=True),
58
+ gr.update(choices=table_names, value=table_names[0], visible=True),
59
+ )
60
+
61
+
62
+ def show_filter_by_scenario(table_names, index_state, dataframes):
63
+ if len(table_names) > 0 and table_names[index_state].startswith("Map"):
64
+ df = dataframes[index_state]
65
+ scenarios = sorted(df["scenario"].unique())
66
+ return gr.update(visible=True, choices=scenarios, value=scenarios[0])
67
+ else:
68
+ return gr.update(visible=False)
69
+
70
+ def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
71
+ df = dataframes[index_state]
72
+ if not table_names[index_state].startswith("Map"):
73
+ return df, figures[index_state](df)
74
+ if df.empty:
75
+ return df, None
76
+ if "scenario" not in df.columns:
77
+ return df, figures[index_state](df)
78
+ else:
79
+ df = df[df["scenario"] == scenario]
80
+ if df.empty:
81
+ return df, None
82
+ figure = figures[index_state](df)
83
+ return df, figure
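Stripped of the Gradio plumbing, `filter_by_scenario` reduces to a boolean mask on the `scenario` column, with pass-through for dataframes that have no such column. A minimal sketch with a toy dataframe (the SSP labels and temperature values are illustrative):

```python
import pandas as pd

def filter_scenario(df, scenario):
    # Dataframes without a scenario column (non-map figures) pass through.
    if "scenario" not in df.columns:
        return df
    # Keep only the rows belonging to the selected scenario.
    return df[df["scenario"] == scenario]

df = pd.DataFrame({
    "scenario": ["ssp126", "ssp585", "ssp126"],
    "tas": [14.2, 15.8, 14.5],
})
subset = filter_scenario(df, "ssp126")
```

An unknown scenario yields an empty frame, which is why the real handler returns `None` for the figure in that case instead of redrawing.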
84
+
85
+
86
+ def display_table_names(table_names, index_state):
87
+ return [
88
+ [name]
89
+ for name in table_names
90
+ ]
91
+
92
+ def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
93
+ index = table_names.index(selected_label)
94
+ figure = plots[index](dataframes[index])
95
+
96
+ return (
97
+ sql_queries[index],
98
+ dataframes[index],
99
+ figure,
100
+ plot_informations[index],
101
+ index,
102
+ )
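The label-to-index lookup that `on_table_click` relies on assumes the radio labels and the parallel state lists stay in the same order. A minimal reproduction (the names, queries, and rows are made up):

```python
def on_select(selected_label, table_names, sql_queries, dataframes):
    # gr.Radio passes the selected label string, not an index, so the
    # position is recovered from the parallel table_names list.
    index = table_names.index(selected_label)
    return sql_queries[index], dataframes[index], index

names = ["Line plot - Mean Temperature", "Map - Mean Temperature"]
queries = ["SELECT yearly", "SELECT gridded"]
frames = [["row-a"], ["row-b"]]
query, frame, idx = on_select(names[1], names, queries, frames)
```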
103
+
104
+
105
+ def create_ipcc_ui() -> ipccUIElements:
106
+
107
+ """Create and return all UI elements for the IPCC tab."""
108
+ with gr.Tab("(Beta) Talk to IPCC", elem_id="tab-vanna", id=7) as tab:
109
+ with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
110
+ gr.Markdown(IPCC_UI_TEXT)
111
+
112
+ # Add examples for common questions
113
+ examples_hidden = gr.Textbox(visible=False, elem_id="ipcc-examples-hidden")
114
+ examples = gr.Examples(
115
+ examples=[
116
+ ["What will the temperature be like in Paris?"],
117
+ ["What will be the total rainfall in the USA in 2030?"],
118
+ ["How will the average temperature evolve in China?"],
119
+ ["What will be the average total precipitation in London?"]
120
+ ],
121
+ label="Example Questions",
122
+ inputs=[examples_hidden],
123
+ outputs=[examples_hidden],
124
+ )
125
+
126
+ with gr.Row():
127
+ ipcc_direct_question = gr.Textbox(
128
+ label="Direct Question",
129
+ placeholder="You can write a direct question here",
130
+ elem_id="direct-question",
131
+ interactive=True,
132
+ )
133
+
134
+ with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
135
+ gr.Markdown("### Examples of possible visualizations")
136
+
137
+ with gr.Row():
138
+ gr.Image("./front/assets/talk_to_ipcc_france_example.png", label="Total Precipitation in 2030 in France", elem_classes=["example-img"])
139
+ gr.Image("./front/assets/talk_to_ipcc_new_york_example.png", label="Yearly Evolution of Mean Temperature in New York (Historical + SSP Scenarios)", elem_classes=["example-img"])
140
+ gr.Image("./front/assets/talk_to_ipcc_china_example.png", label="Mean Temperature in 2050 in China", elem_classes=["example-img"])
141
+
142
+ result_text = gr.Textbox(
143
+ label="", elem_id="no-result-label", interactive=False, visible=True
144
+ )
145
+ with gr.Row():
146
+ table_names_display = gr.Radio(
147
+ choices=[],
148
+ label="Relevant figures created",
149
+ interactive=True,
150
+ elem_id="table-names",
151
+ visible=False
152
+ )
153
+
154
+ with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
155
+ ipcc_sql_query = gr.Textbox(
156
+ label="", elem_id="sql-query", interactive=False
157
+ )
158
+
159
+ with gr.Accordion(label="Chart", visible=False) as chart_accordion:
160
+
161
+ with gr.Row():
162
+ scenario_selection = gr.Dropdown(
163
+ label="Scenario", choices=IPCC_SCENARIO, value=IPCC_SCENARIO[0], interactive=True, visible=False
164
+ )
165
+
166
+ with gr.Accordion(label="Information about the plot", open=False):
167
+ plot_information = gr.Markdown(value="")
168
+
169
+ ipcc_display = gr.Plot(elem_id="vanna-plot")
170
+
171
+ with gr.Accordion(
172
+ label="Data used", open=False, visible=False
173
+ ) as table_accordion:
174
+ ipcc_table = gr.DataFrame([], elem_id="vanna-table")
175
+
176
+
177
+ return ipccUIElements(
178
+ tab=tab,
179
+ details_accordion=details_accordion,
180
+ examples_hidden=examples_hidden,
181
+ examples=examples,
182
+ image_examples=image_examples,
183
+ ipcc_direct_question=ipcc_direct_question,
184
+ result_text=result_text,
185
+ table_names_display=table_names_display,
186
+ query_accordion=query_accordion,
187
+ ipcc_sql_query=ipcc_sql_query,
188
+ chart_accordion=chart_accordion,
189
+ plot_information=plot_information,
190
+ scenario_selection=scenario_selection,
191
+ ipcc_display=ipcc_display,
192
+ table_accordion=table_accordion,
193
+ ipcc_table=ipcc_table,
194
+ )
195
+
196
+
197
+
198
+ def setup_ipcc_events(ui_elements: ipccUIElements, share_client=None, user_id=None) -> None:
199
+ """Set up all event handlers for the IPCC tab."""
200
+ # Create state variables
201
+ sql_queries_state = gr.State([])
202
+ dataframes_state = gr.State([])
203
+ plots_state = gr.State([])
204
+ plot_informations_state = gr.State([])
205
+ index_state = gr.State(0)
206
+ table_names_list = gr.State([])
207
+ user_id = gr.State(user_id)
208
+
209
+ # Handle direct question submission - trigger the same workflow by setting examples_hidden
210
+ ui_elements["ipcc_direct_question"].submit(
211
+ lambda x: gr.update(value=x),
212
+ inputs=[ui_elements["ipcc_direct_question"]],
213
+ outputs=[ui_elements["examples_hidden"]],
214
+ )
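Routing the textbox's submit through `examples_hidden` means the examples widget and the direct question share one `.change` pipeline instead of two duplicated chains. The relay can be modeled without Gradio as a value holder that notifies subscribers (a sketch of the pattern, not the Gradio API):

```python
from typing import Callable

class Relay:
    # Minimal stand-in for the hidden textbox: setting its value fires
    # every subscribed handler, mimicking a gr.Textbox .change event.
    def __init__(self) -> None:
        self.handlers: list[Callable[[str], None]] = []
        self.value = ""

    def change(self, handler: Callable[[str], None]) -> None:
        self.handlers.append(handler)

    def set(self, value: str) -> None:
        self.value = value
        for handler in self.handlers:
            handler(value)

log: list[str] = []
examples_hidden = Relay()
# The query pipeline is wired exactly once, on the relay.
examples_hidden.change(lambda q: log.append(f"asked: {q}"))

# Both the examples widget and the direct-question submit write here.
examples_hidden.set("What will the temperature be like in Paris?")
```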
215
+
216
+ # Handle example selection
217
+ ui_elements["examples_hidden"].change(
218
+ lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
219
+ inputs=[ui_elements["examples_hidden"]],
220
+ outputs=[ui_elements["details_accordion"], ui_elements["ipcc_direct_question"]]
221
+ ).then(
222
+ lambda: gr.update(visible=False),
223
+ inputs=None,
224
+ outputs=ui_elements["image_examples"]
225
+ ).then(
226
+ hide_outputs,
227
+ inputs=None,
228
+ outputs=[
229
+ ui_elements["result_text"],
230
+ ui_elements["query_accordion"],
231
+ ui_elements["table_accordion"],
232
+ ui_elements["chart_accordion"],
233
+ ui_elements["table_names_display"],
234
+ ]
235
+ ).then(
236
+ ask_ipcc_query,
237
+ inputs=[ui_elements["examples_hidden"], index_state, user_id],
238
+ outputs=[
239
+ ui_elements["ipcc_sql_query"],
240
+ ui_elements["ipcc_table"],
241
+ ui_elements["ipcc_display"],
242
+ ui_elements["plot_information"],
243
+ sql_queries_state,
244
+ dataframes_state,
245
+ plots_state,
246
+ plot_informations_state,
247
+ index_state,
248
+ table_names_list,
249
+ ui_elements["result_text"],
250
+ ],
251
+ ).then(
252
+ show_results,
253
+ inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
254
+ outputs=[
255
+ ui_elements["result_text"],
256
+ ui_elements["query_accordion"],
257
+ ui_elements["table_accordion"],
258
+ ui_elements["chart_accordion"],
259
+ ui_elements["table_names_display"],
260
+ ],
261
+ ).then(
262
+ show_filter_by_scenario,
263
+ inputs=[table_names_list, index_state, dataframes_state],
264
+ outputs=[ui_elements["scenario_selection"]],
265
+ ).then(
266
+ filter_by_scenario,
267
+ inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
268
+ outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
269
+ )
270
+
271
+
272
+ # Handle model selection change
273
+ ui_elements["scenario_selection"].change(
274
+ filter_by_scenario,
275
+ inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
276
+ outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
277
+ )
278
+
279
+ # Handle table selection
280
+ ui_elements["table_names_display"].change(
281
+ fn=on_table_click,
282
+ inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
283
+ outputs=[ui_elements["ipcc_sql_query"], ui_elements["ipcc_table"], ui_elements["ipcc_display"], ui_elements["plot_information"], index_state],
284
+ ).then(
285
+ show_filter_by_scenario,
286
+ inputs=[table_names_list, index_state, dataframes_state],
287
+ outputs=[ui_elements["scenario_selection"]],
288
+ ).then(
289
+ filter_by_scenario,
290
+ inputs=[dataframes_state, plots_state, table_names_list, index_state, ui_elements["scenario_selection"]],
291
+ outputs=[ui_elements["ipcc_table"], ui_elements["ipcc_display"]],
292
+ )
293
+
294
+
295
+ def create_ipcc_tab(share_client=None, user_id=None):
296
+ """Create the ipcc tab with all its components and event handlers."""
297
+ ui_elements = create_ipcc_ui()
298
+ setup_ipcc_events(ui_elements, share_client=share_client, user_id=user_id)
299
+
300
+
requirements.txt CHANGED
@@ -25,4 +25,5 @@ geopy==2.4.1
25
  duckdb==1.2.1
26
  openai==1.61.1
27
  pydantic==2.9.2
28
- pydantic-settings==2.2.1
 
 
25
  duckdb==1.2.1
26
  openai==1.61.1
27
  pydantic==2.9.2
28
+ pydantic-settings==2.2.1
29
+ geojson==3.2.0
style.css CHANGED
@@ -656,12 +656,20 @@ a {
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
 
 
 
 
 
 
 
659
  max-height: 300px;
660
  overflow-y:scroll;
661
  }
662
 
663
  #sql-query textarea{
664
  min-height: 100px !important;
 
665
  }
666
 
667
  #sql-query span{
@@ -671,8 +679,11 @@ div#tab-vanna{
671
  max-height: 100vh;
672
  overflow-y: hidden;
673
  }
 
 
 
674
  #vanna-plot{
675
- max-height:500px
676
  }
677
 
678
  #pagination-display{
@@ -681,13 +692,33 @@ div#tab-vanna{
681
  font-size: 16px;
682
  }
683
 
684
- #table-names table{
685
- overflow: hidden;
 
 
 
 
 
 
 
 
 
 
686
  }
687
- #table-names thead{
 
 
 
 
 
688
  display: none;
689
  }
690
 
 
 
 
 
 
691
  /* DRIAS Data Table Styles */
692
  #vanna-table {
693
  height: 400px !important;
@@ -710,3 +741,13 @@ div#tab-vanna{
710
  background: white;
711
  z-index: 1;
712
  }
 
 
 
 
 
 
 
 
 
 
 
656
  /* overflow-y: scroll; */
657
  }
658
  #sql-query{
659
+ max-height: 100%;
+ }
+
+ #sql-query textarea{
+ min-height: 200px !important;
673
  }
674
 
675
  #sql-query span{
 
679
  max-height: 100vh;
680
  overflow-y: hidden;
681
  }
682
+ #details button span{
683
+ font-weight: bold;
684
+ }
685
  #vanna-plot{
686
+ max-height: 1000px;
687
  }
688
 
689
  #pagination-display{
 
692
  font-size: 16px;
693
  }
694
 
695
+
696
+ #table-names label {
697
+ display: block;
698
+ width: 100%;
699
+ box-sizing: border-box;
700
+ padding: 8px 12px;
701
+ margin-bottom: 4px;
702
+ border: 1px solid #ccc;
703
+ border-radius: 6px;
704
+ background-color: white;
705
+ cursor: pointer;
706
+ text-align: center;
707
  }
708
+
709
+ #table-names label:hover {
710
+ background-color: #f0f8ff;
711
+ }
712
+
713
+ #table-names input[type="radio"] {
714
  display: none;
715
  }
716
 
717
+ #table-names input[type="radio"]:checked + label {
718
+ background-color: #d0eaff;
719
+ border-color: #2196f3;
720
+ }
721
+
722
  /* DRIAS Data Table Styles */
723
  #vanna-table {
724
  height: 400px !important;
 
741
  background: white;
742
  z-index: 1;
743
  }
744
+
745
+ .example-img{
746
+ height: 250px;
747
+ object-fit: contain;
748
+ }
749
+
750
+ #example-img-container {
751
+ flex-direction: column;
752
+ align-items: flex-start;
753
+ }