Soumik555 committed on
Commit 95fd2fd · 1 Parent(s): 4dee00f

Fixed langchain gemini agent underscore issue

Files changed (2):
  1. gemini_langchain_agent.py +127 -0
  2. orchestrator_functions.py +64 -56
gemini_langchain_agent.py ADDED
@@ -0,0 +1,127 @@
+ import os
+ import re
+ import uuid
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ import pandas as pd
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_experimental.tools import PythonAstREPLTool
+ from langchain_experimental.agents import create_pandas_dataframe_agent
+ from dotenv import load_dotenv
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib
+ import seaborn as sns
+ import datetime as dt
+
+ # Set the backend for matplotlib to 'Agg' to avoid GUI issues
+ matplotlib.use('Agg')
+
+ load_dotenv()
+ model_name = 'gemini-2.0-flash'  # Model name for Google Generative AI
+ google_api_keys = os.getenv("GEMINI_API_KEYS").split(",")
+ current_key_index = 0  # Global index for API keys
+
+
+ def create_agent(llm, data, tools):
+     """Create agent with tool names"""
+     return create_pandas_dataframe_agent(
+         llm,
+         data,
+         agent_type="tool-calling",
+         verbose=True,
+         allow_dangerous_code=True,
+         extra_tools=tools,
+         return_intermediate_steps=True
+     )
+
+
+ def _prompt_generator(question: str, chart_required: bool):
+     chat_prompt = f"""You are a senior data analyst working with CSV data. Adhere strictly to the following guidelines:
+
+ 1. **Data Verification:** Always inspect the data with `.sample(5).to_dict()` before performing any analysis.
+ 2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
+ 3. **Communication:** Provide concise, professional, and well-structured responses.
+ 4. **Presentation:** Do not reference internal processing details or the methods used to generate the response (e.g. phrases such as "based on the tool call" or "using the function").
+
+ **Query:** {question}
+ """
+
+     chart_prompt = f"""You are a senior data analyst working with CSV data. Follow these rules STRICTLY:
+
+ 1. Generate ONE unique identifier FIRST using: unique_id = uuid.uuid4().hex
+ 2. Visualization requirements:
+    - Adjust font sizes, rotate labels (45° if needed), truncate for readability
+    - Figure size: (12, 6)
+    - Descriptive titles (fontsize=14)
+    - Colorblind-friendly palettes
+ 3. File handling rules:
+    - Create a MAXIMUM of 2 charts, and only if absolutely necessary
+    - For multiple charts:
+      * Arrange in grid format (2x1 vertical layout preferred)
+      * Use the SAME unique_id with suffixes:
+        - f"{{unique_id}}_1.png"
+        - f"{{unique_id}}_2.png"
+    - Save EXCLUSIVELY to the "generated_charts" folder
+    - File naming: f"chart_{{unique_id}}.png" (for a single chart)
+ 4. FINAL OUTPUT MUST BE:
+    - For a single chart: f"generated_charts/chart_{{unique_id}}.png"
+    - For multiple charts: f"generated_charts/chart_{{unique_id}}.png" (combined grid image)
+    - **ONLY return this full path string, nothing else**
+
+ **Query:** {question}
+
+ IMPORTANT:
+ - Generate the unique_id FIRST, before any other operations
+ - Use the SAME unique_id throughout the entire process
+ - NEVER generate new UUIDs after the initial creation
+ - Return the EXACT filepath string of the final saved chart
+ """
+
+     if chart_required:
+         return ChatPromptTemplate.from_template(chart_prompt)
+     else:
+         return ChatPromptTemplate.from_template(chat_prompt)
+
+
+ def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bool):
+     global current_key_index
+     data = pd.read_csv(csv_url)
+
+     attempts = 0
+     total_keys = len(google_api_keys)
+     while attempts < total_keys:
+         try:
+             api_key = google_api_keys[current_key_index]
+             print(f"Using API key index {current_key_index}")
+
+             llm = ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
+
+             # Create tool with validated name
+             tool = PythonAstREPLTool(
+                 locals={
+                     "df": data,
+                     "pd": pd,
+                     "np": np,
+                     "plt": plt,
+                     "sns": sns,
+                     "matplotlib": matplotlib,
+                     "uuid": uuid,
+                     "dt": dt
+                 },
+             )
+
+             agent = create_agent(llm, data, [tool])
+
+             prompt = _prompt_generator(question, chart_required)
+             result = agent.invoke({"input": prompt})
+             return result.get("output")
+
+         except Exception as e:
+             print(f"Error using API key index {current_key_index}: {e}")
+             current_key_index = (current_key_index + 1) % total_keys
+             attempts += 1
+
+     print("All API keys have been exhausted.")
+     return None
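
For reference, a minimal sketch of calling the new handler directly; the CSV URL and question below are illustrative placeholders, not part of this commit:

    from gemini_langchain_agent import langchain_gemini_csv_handler

    # Placeholder inputs for illustration only
    answer = langchain_gemini_csv_handler(
        csv_url="https://example.com/data.csv",
        question="Which column has the most missing values?",
        chart_required=False,
    )
    print(answer)  # None if every key in GEMINI_API_KEYS has failed
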
orchestrator_functions.py CHANGED
@@ -19,6 +19,7 @@ import numpy as np
  import matplotlib.pyplot as plt
  import matplotlib
  import seaborn as sns
+ from gemini_langchain_agent import langchain_gemini_csv_handler
  from supabase_service import upload_image_to_supabase
  from util_service import _prompt_generator, process_answer
  import matplotlib
@@ -385,78 +386,85 @@ async def csv_chart(csv_url: str, query: str):


- async def csv_chat(csv_url: str, query: str):
-     try:
-         updated_query = f"{query} and Do not show any charts or graphs."
-
-         # Process with langchain_chat first
-         try:
-             lang_answer = await asyncio.to_thread(
-                 langchain_csv_chat, csv_url, query, False
-             )
-             if lang_answer is not None:
-                 return {"answer": jsonable_encoder(lang_answer)}
-             raise Exception("Langchain failed to process")
-         except Exception as langchain_error:
-             print(f"Langchain error, falling back to Groq: {str(langchain_error)}")
-
-             # Process with groq_chat if langchain fails
-             try:
-                 groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
-                 print("groq_answer:", groq_answer)
-
-                 if process_answer(groq_answer) == "Empty response received.":
-                     return {"answer": "Sorry, I couldn't find relevant data..."}
-
-                 if process_answer(groq_answer) or groq_answer is None:
-                     return {"answer": "error"}
-
-                 return {"answer": jsonable_encoder(groq_answer)}
-             except Exception as groq_error:
-                 print(f"Groq processing error: {str(groq_error)}")
-                 return {"answer": "error"}
-
-     except Exception as e:
-         print(f"Error processing request: {str(e)}")
-         return {"answer": "error"}
-
-
- # async def csv_chat(csv_url: str, query: str):
- #     try:
- #         updated_query = f"{query} and Do not show any charts or graphs."
-
- #         # Process with Groq first
- #         try:
- #             groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
- #             print("groq_answer:", groq_answer)
-
- #             if process_answer(groq_answer) == "Empty response received." or groq_answer == None:
- #                 return {"answer": "Sorry, I couldn't find relevant data..."}
-
- #             if process_answer(groq_answer) or groq_answer == None:
- #                 raise Exception("Groq response not usable, falling back to LangChain")
-
- #             return {"answer": jsonable_encoder(groq_answer)}
-
- #         except Exception as groq_error:
- #             print(f"Groq error, falling back to LangChain: {str(groq_error)}")
-
- #         # Process with LangChain if Groq fails
- #         try:
- #             lang_answer = await asyncio.to_thread(
- #                 langchain_csv_chat, csv_url, query, False
- #             )
- #             if not process_answer(lang_answer):
- #                 return {"answer": jsonable_encoder(lang_answer)}
- #             return {"answer": "Sorry, I couldn't find relevant data..."}
- #         except Exception as langchain_error:
- #             print(f"LangChain processing error: {str(langchain_error)}")
- #             return {"answer": "error"}
-
- #     except Exception as e:
- #         print(f"Error processing request: {str(e)}")
- #         return {"answer": "error"}
+ # async def csv_chat(csv_url: str, query: str):
+ #     try:
+ #         updated_query = f"{query} and Do not show any charts or graphs."
+
+ #         # Process with langchain_chat first
+ #         try:
+ #             lang_answer = await asyncio.to_thread(
+ #                 langchain_csv_chat, csv_url, query, False
+ #             )
+ #             if lang_answer is not None:
+ #                 return {"answer": jsonable_encoder(lang_answer)}
+ #             raise Exception("Langchain failed to process")
+ #         except Exception as langchain_error:
+ #             print(f"Langchain error, falling back to Groq: {str(langchain_error)}")
+
+ #             # Process with groq_chat if langchain fails
+ #             try:
+ #                 groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
+ #                 print("groq_answer:", groq_answer)
+
+ #                 if process_answer(groq_answer) == "Empty response received.":
+ #                     return {"answer": "Sorry, I couldn't find relevant data..."}
+
+ #                 if process_answer(groq_answer) or groq_answer is None:
+ #                     return {"answer": "error"}
+
+ #                 return {"answer": jsonable_encoder(groq_answer)}
+ #             except Exception as groq_error:
+ #                 print(f"Groq processing error: {str(groq_error)}")
+ #                 return {"answer": "error"}
+
+ #     except Exception as e:
+ #         print(f"Error processing request: {str(e)}")
+ #         return {"answer": "error"}
+
+
+ async def csv_chat(csv_url: str, query: str):
+     try:
+         updated_query = f"{query}"
+
+         # Process with the Gemini LangChain handler first
+         try:
+             # groq_answer = await asyncio.to_thread(groq_chat, csv_url, updated_query)
+             # print("groq_answer:", groq_answer)
+
+             # if process_answer(groq_answer) == "Empty response received." or groq_answer == None:
+             #     return {"answer": "Sorry, I couldn't find relevant data..."}
+
+             # if process_answer(groq_answer) or groq_answer == None:
+             #     raise Exception("Groq response not usable, falling back to LangChain")
+
+             # return {"answer": jsonable_encoder(groq_answer)}
+
+             lc_gemini = await asyncio.to_thread(
+                 langchain_gemini_csv_handler, csv_url, updated_query, False
+             )
+
+             if lc_gemini is not None:
+                 return {"answer": jsonable_encoder(lc_gemini)}
+
+         except Exception as gemini_error:
+             print(f"Gemini error, falling back to Groq-LangChain: {str(gemini_error)}")
+
+         # Process with the Groq-backed LangChain chat if Gemini fails
+         try:
+             lang_answer = await asyncio.to_thread(
+                 langchain_csv_chat, csv_url, query, False
+             )
+             if not process_answer(lang_answer):
+                 return {"answer": jsonable_encoder(lang_answer)}
+             return {"answer": "Sorry, I couldn't find relevant data..."}
+         except Exception as langchain_error:
+             print(f"LangChain processing error: {str(langchain_error)}")
+             return {"answer": "error"}
+
+     except Exception as e:
+         print(f"Error processing request: {str(e)}")
+         return {"answer": "error"}
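
For context, a minimal sketch of exercising the updated csv_chat fallback chain from a script; the URL and question are placeholders, and the coroutine is simply driven with asyncio.run:

    import asyncio
    from orchestrator_functions import csv_chat

    async def main():
        # The Gemini-backed LangChain handler is tried first; if it errors or returns None,
        # the existing langchain_csv_chat path runs as the fallback.
        result = await csv_chat(
            "https://example.com/data.csv",          # placeholder CSV URL
            "Summarize the top 5 rows by revenue.",  # placeholder question
        )
        print(result)  # {"answer": ...} or {"answer": "error"}

    asyncio.run(main())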