armanddemasson commited on
Commit
80062dd
·
1 Parent(s): b21471a

feat: added evolution for a specific month plot

Browse files
climateqa/engine/talk_to_data/input_processing.py CHANGED
@@ -118,7 +118,28 @@ async def detect_year_with_openai(sentence: str) -> str:
118
  return years_list[0]
119
  else:
120
  return ""
121
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
124
  """Identifies relevant tables for a plot based on user input.
@@ -227,6 +248,14 @@ async def find_year(user_input: str) -> str| None:
227
  return None
228
  return year
229
 
 
 
 
 
 
 
 
 
230
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
231
  print("---- Find relevant plots ----")
232
  relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
@@ -237,16 +266,9 @@ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: l
237
  relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
238
  return relevant_tables
239
 
240
- async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
241
- """Perform the good method to retrieve the desired parameter
242
-
243
- Args:
244
- state (State): state of the workflow
245
- param_name (str): name of the desired parameter
246
- table (str): name of the table
247
-
248
- Returns:
249
- dict[str, Any] | None:
250
  """
251
  if param_name == 'location':
252
  location = await find_location(state['user_input'], mode)
@@ -254,4 +276,8 @@ async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'
254
  if param_name == 'year':
255
  year = await find_year(state['user_input'])
256
  return {'year': year}
257
- return None
 
 
 
 
 
118
  return years_list[0]
119
  else:
120
  return ""
121
+
122
+ async def detect_month_with_openai(sentence: str) -> str:
123
+ """
124
+ Detects month in a sentence using OpenAI's API via LangChain.
125
+ Returns the month as an integer string (e.g., "7" for July), or "" if not found.
126
+ """
127
+ llm = get_llm()
128
+ prompt = """
129
+ Extract the month (as a number from 1 to 12) mentioned in the following sentence.
130
+ Return the result as a Python list of integers. If no month is mentioned, return an empty list.
131
+
132
+ Sentence: "{sentence}"
133
+ """
134
+ prompt = ChatPromptTemplate.from_template(prompt)
135
+ structured_llm = llm.with_structured_output(ArrayOutput)
136
+ chain = prompt | structured_llm
137
+ response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
138
+ months_list = eval(response['array'])
139
+ if len(months_list) > 0:
140
+ return str(months_list[0])
141
+ else:
142
+ return ""
143
 
144
  async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
145
  """Identifies relevant tables for a plot based on user input.
 
248
  return None
249
  return year
250
 
251
+ async def find_month(user_input: str) -> str | None:
252
+ """Extracts month information from user input using LLM."""
253
+ print(f"---- Find month ---")
254
+ month = await detect_month_with_openai(user_input)
255
+ if month == "":
256
+ return None
257
+ return month
258
+
259
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
260
  print("---- Find relevant plots ----")
261
  relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
 
266
  relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
267
  return relevant_tables
268
 
269
+ async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
270
+ """
271
+ Perform the good method to retrieve the desired parameter.
 
 
 
 
 
 
 
272
  """
273
  if param_name == 'location':
274
  location = await find_location(state['user_input'], mode)
 
276
  if param_name == 'year':
277
  year = await find_year(state['user_input'])
278
  return {'year': year}
279
+ if param_name == 'month':
280
+ month = await find_month(state['user_input'])
281
+ print(month)
282
+ return {'month': month}
283
+ return None
climateqa/engine/talk_to_data/ipcc/config.py CHANGED
@@ -30,7 +30,8 @@ IPCC_MODELS = []
30
 
31
  IPCC_PLOT_PARAMETERS = [
32
  'year',
33
- 'location'
 
34
  ]
35
 
36
  MACRO_COUNTRIES = ['JP',
 
30
 
31
  IPCC_PLOT_PARAMETERS = [
32
  'year',
33
+ 'location',
34
+ 'month'
35
  ]
36
 
37
  MACRO_COUNTRIES = ['JP',
climateqa/engine/talk_to_data/ipcc/plot_informations.py CHANGED
@@ -47,4 +47,27 @@ Each grid point is colored according to the value of the indicator ({unit}), all
47
  - For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
48
  - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
49
  - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
 
47
  - For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
48
  - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
49
  - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
50
+ """
51
+
52
+ def indicator_specific_month_evolution_informations(
53
+ indicator: str,
54
+ params: dict[str, str]
55
+ ) -> str:
56
+ if "location" not in params:
57
+ raise ValueError('"location" must be provided in params')
58
+ location = params["location"]
59
+ if "month" not in params:
60
+ raise ValueError('"month" must be provided in params')
61
+ month = params["month"]
62
+ unit = IPCC_INDICATOR_TO_UNIT[indicator]
63
+ return f"""
64
+ This plot shows how the climate indicator **{indicator}** evolves over time in **{location}** for the month of **{month}**.
65
+ It combines both historical (from 1950 to 2015) observations and future (from 2016 to 2100) projections for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).
66
+ The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}) for the selected month.
67
+ Each line corresponds to a different scenario, allowing you to compare how {indicator} for month {month} might change under various future conditions.
68
+
69
+ **Data source:**
70
+ - The data comes from the IPCC climate datasets (Parquet files) for the relevant indicator, location, and month.
71
+ - For each year and scenario, the value of {indicator} for month {month} is extracted for the selected location.
72
+ - The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
73
  """
climateqa/engine/talk_to_data/ipcc/plots.py CHANGED
@@ -5,8 +5,8 @@ import pandas as pd
5
  import geojson
6
 
7
  from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
8
- from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
9
- from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
10
  from climateqa.engine.talk_to_data.objects.plot import Plot
11
 
12
  def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
@@ -102,6 +102,82 @@ indicator_evolution_at_location_historical_and_projections: Plot = {
102
  "short_name": "Evolution"
103
  }
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def plot_choropleth_map_of_country_indicator_for_specific_year(
106
  params: dict,
107
  ) -> Callable[[pd.DataFrame], Figure]:
@@ -167,6 +243,7 @@ def plot_choropleth_map_of_country_indicator_for_specific_year(
167
 
168
  return plot_data
169
 
 
170
  choropleth_map_of_country_indicator_for_specific_year: Plot = {
171
  "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
172
  "description": (
@@ -185,5 +262,6 @@ choropleth_map_of_country_indicator_for_specific_year: Plot = {
185
 
186
  IPCC_PLOTS = [
187
  indicator_evolution_at_location_historical_and_projections,
188
- choropleth_map_of_country_indicator_for_specific_year
 
189
  ]
 
5
  import geojson
6
 
7
  from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
8
+ from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations, indicator_specific_month_evolution_informations
9
+ from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_and_specific_month_at_location_query, indicator_per_year_at_location_query
10
  from climateqa.engine.talk_to_data.objects.plot import Plot
11
 
12
  def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
 
102
  "short_name": "Evolution"
103
  }
104
 
105
+ def plot_indicator_monthly_evolution_at_location(
106
+ params: dict,
107
+ ) -> Callable[[pd.DataFrame], Figure]:
108
+ """
109
+ Returns a function that generates a line plot showing the evolution of a climate indicator
110
+ for a specific month over time at a specific location, including both historical data
111
+ and future projections for different climate scenarios.
112
+
113
+ Args:
114
+ params (dict): Dictionary with:
115
+ - indicator_column (str): Name of the climate indicator column to plot.
116
+ - location (str): Location (e.g., country, city) for which to plot the indicator.
117
+ - month (int): Month number (1-12) to plot.
118
+
119
+ Returns:
120
+ Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure.
121
+ """
122
+ indicator = params["indicator_column"]
123
+ location = params["location"]
124
+ month = params["month"]
125
+ indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
126
+ unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
127
+
128
+ def plot_data(df: pd.DataFrame) -> Figure:
129
+ df = df.sort_values(by='year')
130
+ years = df['year'].astype(int).tolist()
131
+ indicators = df[indicator].astype(float).tolist()
132
+ scenarios = df['scenario'].astype(str).tolist()
133
+
134
+ # Find last historical value for continuity
135
+ last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
136
+ last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)
137
+
138
+ fig = go.Figure()
139
+ for scenario in IPCC_SCENARIO:
140
+ x = [y for y, s in zip(years, scenarios) if s == scenario]
141
+ y = [v for v, s in zip(indicators, scenarios) if s == scenario]
142
+ # Connect historical to scenario
143
+ if scenario != 'historical' and last_historical_indicator is not None:
144
+ x = [last_historical_year] + x
145
+ y = [last_historical_indicator] + y
146
+ fig.add_trace(go.Scatter(
147
+ x=x,
148
+ y=y,
149
+ mode='lines',
150
+ name=scenario
151
+ ))
152
+
153
+ fig.update_layout(
154
+ title=f'Evolution of {indicator_label} in {location} for Month {month} (Historical + SSP Scenarios)',
155
+ xaxis_title='Year',
156
+ yaxis_title=f'{indicator_label} ({unit})',
157
+ legend_title='Scenario',
158
+ height=800,
159
+ )
160
+ return fig
161
+
162
+ return plot_data
163
+
164
+
165
+ indicator_specific_month_evolution_at_location: Plot = {
166
+ "name": "Indicator specific month Evolution at Location (Historical + Projections)",
167
+ "description": (
168
+ "Shows how a climate indicator (e.g., rainfall, temperature) for a specific month changes over time at a specific location, "
169
+ "including historical data and future projections. "
170
+ "Useful for questions about the value or trend of an indicator for a given month at a location, "
171
+ "such as 'How does July temperature evolve in Paris over time?'. "
172
+ "Parameters: indicator_column (the climate variable), location (e.g., country, city), month (1-12)."
173
+ ),
174
+ "params": ["indicator_column", "location", "month"],
175
+ "plot_function": plot_indicator_monthly_evolution_at_location,
176
+ "sql_query": indicator_per_year_and_specific_month_at_location_query,
177
+ "plot_information": indicator_specific_month_evolution_informations,
178
+ "short_name": "Evolution for a specific month"
179
+ }
180
+
181
  def plot_choropleth_map_of_country_indicator_for_specific_year(
182
  params: dict,
183
  ) -> Callable[[pd.DataFrame], Figure]:
 
243
 
244
  return plot_data
245
 
246
+
247
  choropleth_map_of_country_indicator_for_specific_year: Plot = {
248
  "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
249
  "description": (
 
262
 
263
  IPCC_PLOTS = [
264
  indicator_evolution_at_location_historical_and_projections,
265
+ choropleth_map_of_country_indicator_for_specific_year,
266
+ indicator_specific_month_evolution_at_location
267
  ]
climateqa/engine/talk_to_data/ipcc/queries.py CHANGED
@@ -74,6 +74,74 @@ def indicator_per_year_at_location_query(
74
  """
75
  return sql_query.strip()
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class IndicatorForGivenYearQueryParams(TypedDict, total=False):
78
  """
79
  Parameters for querying an indicator's values across locations for a specific year.
@@ -140,4 +208,4 @@ def indicator_for_given_year_query(
140
  ORDER BY latitude, longitude, scenario
141
  """
142
 
143
- return sql_query.strip()
 
74
  """
75
  return sql_query.strip()
76
 
77
+ class IndicatorPerYearAndSpecificMonthAtLocationQueryParams(TypedDict, total=False):
78
+ """
79
+ Parameters for querying the evolution of an indicator per year for a specific month at a specific location.
80
+
81
+ Attributes:
82
+ indicator_column (str): Name of the climate indicator column.
83
+ latitude (str): Latitude of the location.
84
+ longitude (str): Longitude of the location.
85
+ country_code (str): Country code.
86
+ month (str): Month targeted
87
+ """
88
+ indicator_column: str
89
+ latitude: str
90
+ longitude: str
91
+ country_code: str
92
+ month: str
93
+
94
+ def indicator_per_year_and_specific_month_at_location_query(
95
+ table: str, params: IndicatorPerYearAndSpecificMonthAtLocationQueryParams
96
+ ) -> str:
97
+ """
98
+ Builds an SQL query to get the evolution of an indicator per year for a specific month at a specific location.
99
+
100
+ Args:
101
+ table (str): SQL table of the indicator.
102
+ params (dict): Dictionary with required params:
103
+ - indicator_column (str)
104
+ - latitude (str or float)
105
+ - longitude (str or float)
106
+ - month (int)
107
+
108
+ Returns:
109
+ str: The SQL query string.
110
+ """
111
+ indicator_column = params.get("indicator_column")
112
+ latitude = params.get("latitude")
113
+ longitude = params.get("longitude")
114
+ country_code = params.get("country_code")
115
+ month = params.get('month')
116
+
117
+ if not all([indicator_column, latitude, longitude, country_code, month]):
118
+ return ""
119
+
120
+ if country_code in MACRO_COUNTRIES:
121
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
122
+ sql_query = f"""
123
+ SELECT year, scenario, {indicator_column}
124
+ FROM {table_path}
125
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950 AND month={month}
126
+ ORDER BY year, scenario
127
+ """
128
+ elif country_code in HUGE_MACRO_COUNTRIES:
129
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
130
+ sql_query = f"""
131
+ SELECT year, scenario, {indicator_column}
132
+ FROM {table_path}
133
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950
134
+ ORDER year, scenario
135
+ """
136
+ else:
137
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
138
+ sql_query = f"""
139
+ SELECT year, scenario, MEDIAN({indicator_column}) AS {indicator_column}
140
+ FROM {table_path}
141
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950 AND month={month}
142
+ GROUP BY scenario, year
143
+ """
144
+ return sql_query.strip()
145
  class IndicatorForGivenYearQueryParams(TypedDict, total=False):
146
  """
147
  Parameters for querying an indicator's values across locations for a specific year.
 
208
  ORDER BY latitude, longitude, scenario
209
  """
210
 
211
+ return sql_query.strip()