armanddemasson committed on
Commit
765a122
·
1 Parent(s): 2133475

refactor: merged find month and detect month in one function and added docstring

Browse files
climateqa/engine/talk_to_data/input_processing.py CHANGED
@@ -120,36 +120,6 @@ async def detect_year_with_openai(sentence: str) -> str:
120
  else:
121
  return ""
122
 
123
- async def detect_month_with_openai(sentence: str) -> dict[str, str]:
124
- """
125
- Detects month in a sentence using OpenAI's API via LangChain.
126
- Returns the month as an integer string (e.g., "7" for July), or "" if not found.
127
- """
128
- llm = get_llm()
129
- prompt = """
130
- Extract the month (as a number from 1 to 12) mentioned in the following sentence.
131
- Return the result as a Python list of integers. If no month is mentioned, return an empty list.
132
-
133
- Sentence: "{sentence}"
134
- """
135
- prompt = ChatPromptTemplate.from_template(prompt)
136
- structured_llm = llm.with_structured_output(ArrayOutput)
137
- chain = prompt | structured_llm
138
- response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
139
- months_list = eval(response['array'])
140
- if len(months_list) > 0:
141
- month_number = int(months_list[0])
142
- month_name = calendar.month_name[month_number]
143
- return {
144
- "month_number": str(month_number),
145
- "month_name": month_name
146
- }
147
- else:
148
- return {
149
- "month_number" : "",
150
- "month_name" : ""
151
- }
152
-
153
 
154
  async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
155
  """Identifies relevant tables for a plot based on user input.
@@ -259,11 +229,53 @@ async def find_year(user_input: str) -> str| None:
259
  return year
260
 
261
  async def find_month(user_input: str) -> dict[str, str|None]:
262
- """Extracts month information from user input using LLM."""
263
- print(f"---- Find month ---")
264
- month_info = await detect_month_with_openai(user_input)
265
- month_info = {key: None if value == "" else value for key, value in month_info.items()}
266
- return month_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
269
  print("---- Find relevant plots ----")
@@ -277,7 +289,26 @@ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: l
277
 
278
  async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
279
  """
280
- Perform the good method to retrieve the desired parameter.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  """
282
  if param_name == 'location':
283
  location = await find_location(state['user_input'], mode)
 
120
  else:
121
  return ""
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
125
  """Identifies relevant tables for a plot based on user input.
 
229
  return year
230
 
231
  async def find_month(user_input: str) -> dict[str, str|None]:
232
+ """
233
+ Extracts month information from user input using an LLM.
234
+
235
+ This function analyzes the user's query to detect if a month is mentioned.
236
+ It returns both the month number (as a string, e.g. '7' for July) and the full English month name (e.g. 'July').
237
+ If no month is found, both values will be None.
238
+
239
+ Args:
240
+ user_input (str): The user's query text.
241
+
242
+ Returns:
243
+ dict[str, str|None]: A dictionary with keys:
244
+ - "month_number": the month number as a string (e.g. '7'), or None if not found
245
+ - "month_name": the full English month name (e.g. 'July'), or None if not found
246
+
247
+ Example:
248
+ >>> await find_month("Show me the temperature in Paris in July")
249
+ {'month_number': '7', 'month_name': 'July'}
250
+ >>> await find_month("Show me the temperature in Paris")
251
+ {'month_number': None, 'month_name': None}
252
+ """
253
+
254
+ llm = get_llm()
255
+ prompt = """
256
+ Extract the month (as a number from 1 to 12) mentioned in the following sentence.
257
+ Return the result as a Python list of integers. If no month is mentioned, return an empty list.
258
+
259
+ Sentence: "{sentence}"
260
+ """
261
+ prompt = ChatPromptTemplate.from_template(prompt)
262
+ structured_llm = llm.with_structured_output(ArrayOutput)
263
+ chain = prompt | structured_llm
264
+ response: ArrayOutput = await chain.ainvoke({"sentence": user_input})
265
+ months_list = ast.literal_eval(response['array'])
266
+ if len(months_list) > 0:
267
+ month_number = int(months_list[0])
268
+ month_name = calendar.month_name[month_number]
269
+ return {
270
+ "month_number": str(month_number),
271
+ "month_name": month_name
272
+ }
273
+ else:
274
+ return {
275
+ "month_number" : None,
276
+ "month_name" : None
277
+ }
278
+
279
 
280
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
281
  print("---- Find relevant plots ----")
 
289
 
290
  async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
291
  """
292
+ Retrieves a specific parameter (location, year, month, etc.) from the user's input using the appropriate extraction method.
293
+
294
+ Args:
295
+ state (State): The current state containing at least the user's input under 'user_input'.
296
+ param_name (str): The name of the parameter to extract. Supported: 'location', 'year', 'month'.
297
+ mode (Literal['DRIAS', 'IPCC']): The data mode to use for location extraction.
298
+
299
+ Returns:
300
+ - For 'location': a Location object (dict with keys like 'location', 'latitude', etc.), or None if not found.
301
+ - For 'year': a dict {'year': year or None}.
302
+ - For 'month': a dict {'month_number': str or None, 'month_name': str or None}.
303
+ - None if the parameter is not recognized or not found.
304
+
305
+ Example:
306
+ >>> await find_param(state, 'location')
307
+ {'location': 'Paris', 'latitude': ..., ...}
308
+ >>> await find_param(state, 'year')
309
+ {'year': '2050'}
310
+ >>> await find_param(state, 'month')
311
+ {'month_number': '7', 'month_name': 'July'}
312
  """
313
  if param_name == 'location':
314
  location = await find_location(state['user_input'], mode)