Soumik555 committed on
Commit
d7d1d4e
·
1 Parent(s): 1211e2b

added together ai agent

Browse files
controller.py CHANGED
@@ -29,6 +29,7 @@ from gemini_report_generator import generate_csv_report
29
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
30
  from orchestrator_agent import csv_orchestrator_chat
31
  from supabase_service import upload_file_to_supabase
 
32
  from util_service import _prompt_generator, process_answer
33
  from fastapi.middleware.cors import CORSMiddleware
34
  import matplotlib
@@ -363,11 +364,15 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
363
  # return {"answer": jsonable_encoder(orchestrator_answer)}
364
 
365
  # Process with groq_chat first
366
- groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
367
- logger.info("groq_answer:", groq_answer)
368
 
369
- if process_answer(groq_answer) == "Empty response received.":
370
- return {"answer": "Sorry, I couldn't find relevant data..."}
 
 
 
 
371
 
372
  # if process_answer(groq_answer):
373
  # lang_answer = await asyncio.to_thread(
@@ -377,7 +382,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
377
  # return {"answer": "error"}
378
  # return {"answer": jsonable_encoder(lang_answer)}
379
 
380
- return {"answer": jsonable_encoder(groq_answer)}
381
 
382
  except Exception as e:
383
  logger.error(f"Error processing request: {str(e)}")
 
29
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
30
  from orchestrator_agent import csv_orchestrator_chat
31
  from supabase_service import upload_file_to_supabase
32
+ from together_ai_llama_agent import query_csv_agent
33
  from util_service import _prompt_generator, process_answer
34
  from fastapi.middleware.cors import CORSMiddleware
35
  import matplotlib
 
364
  # return {"answer": jsonable_encoder(orchestrator_answer)}
365
 
366
  # Process with groq_chat first
367
+ # groq_answer = await asyncio.to_thread(groq_chat, decoded_url, query)
368
+ # logger.info("groq_answer:", groq_answer)
369
 
370
+ result = await asyncio.to_thread(query_csv_agent, decoded_url, query)
371
+ logger.info("together ai csv answer == >", result)
372
+ return {"answer": result}
373
+
374
+ # if process_answer(groq_answer) == "Empty response received.":
375
+ # return {"answer": "Sorry, I couldn't find relevant data..."}
376
 
377
  # if process_answer(groq_answer):
378
  # lang_answer = await asyncio.to_thread(
 
382
  # return {"answer": "error"}
383
  # return {"answer": jsonable_encoder(lang_answer)}
384
 
385
+ # return {"answer": jsonable_encoder(groq_answer)}
386
 
387
  except Exception as e:
388
  logger.error(f"Error processing request: {str(e)}")
python_code_executor_service.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import matplotlib.pyplot as plt
3
+ from pathlib import Path
4
+ from typing import Dict, Any, List, Optional
5
+ import pandas as pd
6
+ import json
7
+ import io
8
+ import contextlib
9
+ import traceback
10
+ from pydantic import BaseModel
11
+
12
class CodeResponse(BaseModel):
    """Container for code-related responses"""
    # Language tag of the snippet; defaults to "python".
    language: str = "python"
    # Source text of the snippet. PythonExecutor.process_response passes this
    # string to execute_code, which runs it with exec().
    code: str
16
+
17
class ChartSpecification(BaseModel):
    """Details about requested charts"""
    # Human-readable caption; process_response uses it as the heading and as
    # the alt text of the generated markdown image link.
    image_description: str
    # Optional plotting code. Charts without code are skipped by
    # process_response.
    code: Optional[str] = None
21
+
22
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""
    # Snippet to execute for this operation.
    code: CodeResponse
    # Label printed (followed by a colon) above the snippet's captured output.
    description: str
26
+
27
# NOTE(review): together_ai_llama_agent.py declares a model with the same name
# and fields (there response_type is a Literal); the two definitions have to be
# kept in sync by hand.
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""
    # Free-form here; the intended values are listed in the trailing comment.
    response_type: str  # Literal["casual", "data_analysis", "visualization", "mixed"]
    # Conversational text; always emitted first by process_response.
    casual_response: str
    # Snippets executed in order, each followed by its stdout or error message.
    analysis_operations: List[AnalysisOperation]
    # Optional charts rendered after the analysis operations.
    charts: Optional[List[ChartSpecification]] = None
33
+
34
class PythonExecutor:
    """Run generated Python snippets against a DataFrame and format the results.

    Snippets are executed with ``exec`` in a namespace that exposes ``pd``,
    ``plt``, ``json`` and the wrapped DataFrame as ``df``. Printed output is
    captured from stdout and matplotlib figures are captured as PNG bytes.

    SECURITY: the snippets run in-process with the full set of builtins. Only
    feed this class code from a source you are prepared to trust, or run it
    inside an external sandbox.

    NOTE(review): ``plt.show`` is swapped out on the shared ``pyplot`` module
    while a snippet runs, so concurrent ``execute_code`` calls from different
    threads would interfere with each other -- confirm callers serialize them.
    """

    def __init__(self, df: pd.DataFrame, charts_folder: str = "charts"):
        """
        Initialize the PythonExecutor with a DataFrame

        Args:
            df (pd.DataFrame): The DataFrame to operate on
            charts_folder (str): Folder to save charts in (created if missing,
                including any missing parent folders)
        """
        self.df = df
        self.charts_folder = Path(charts_folder)
        # parents=True so a nested folder such as "static/charts" works too.
        self.charts_folder.mkdir(parents=True, exist_ok=True)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """
        Execute Python code and return the output and any generated plots

        Args:
            code (str): Python code to execute

        Returns:
            dict: ``output`` (captured stdout, empty if the snippet raised),
            ``error`` (``None`` or a dict with ``message`` and ``traceback``)
            and ``plots`` (list of PNG images as bytes).
        """
        output = ""
        error = None
        plots: List[bytes] = []

        # Capture stdout
        stdout = io.StringIO()

        # plt.show() is replaced while the snippet runs so figures are saved
        # instead of displayed.
        original_show = plt.show

        def custom_show(*args, **kwargs):
            """Save every open figure as PNG bytes, then close them all.

            Accepts and ignores the arguments of the real ``plt.show`` (for
            example ``block=False``) so snippets that pass them do not fail.
            """
            for fig_num in plt.get_fignums():
                figure = plt.figure(fig_num)
                buf = io.BytesIO()
                figure.savefig(buf, format='png', bbox_inches='tight')
                plots.append(buf.getvalue())
            plt.close('all')

        try:
            # Create execution context with common libraries and the DataFrame
            exec_globals = {
                'pd': pd,
                'plt': plt,
                'json': json,
                'df': self.df,  # Include the DataFrame in the execution context
                '__builtins__': __builtins__,
            }

            # Replace plt.show with custom implementation
            plt.show = custom_show

            # Execute code and capture output
            with contextlib.redirect_stdout(stdout):
                exec(code, exec_globals)

            # Snippets that draw but never call plt.show() would otherwise
            # lose their figures silently; collect whatever is still open.
            custom_show()

            output = stdout.getvalue()

        except Exception as e:
            error = {
                "message": str(e),
                "traceback": traceback.format_exc()
            }
        finally:
            # Restore original plt.show
            plt.show = original_show
            # Drop figures left behind by a failed snippet so they cannot
            # show up in the results of a later execution.
            plt.close('all')

        return {
            'output': output,
            'error': error,
            'plots': plots
        }

    def save_plot_dummy(self, plot_data: bytes, description: str) -> str:
        """
        Save plot to charts folder and return a dummy URL

        Args:
            plot_data (bytes): Image data in bytes
            description (str): Description of the plot (currently unused)

        Returns:
            str: Dummy URL for the chart
        """
        # Generate unique filename
        filename = f"chart_{uuid.uuid4().hex}.png"
        filepath = self.charts_folder / filename

        # Save the plot (even though we're using dummy URLs, we still save it)
        with open(filepath, 'wb') as f:
            f.write(plot_data)

        # Return a dummy URL
        return f"https://example.com/charts/{filename}"

    def process_response(self, response: CsvChatResult) -> str:
        """
        Process the CsvChatResult response and generate formatted output

        The casual response comes first, then each analysis operation
        (description followed by its stdout or error message), then a
        "Visualizations:" section with one markdown image link per plot.

        Args:
            response (CsvChatResult): Response from CSV analysis

        Returns:
            str: Formatted output with results and dummy image URLs
        """
        output_parts = []

        # Add casual response
        output_parts.append(response.casual_response)

        # Process analysis operations
        for operation in response.analysis_operations:
            # Execute the code
            result = self.execute_code(operation.code.code)

            # Add operation description
            output_parts.append(f"\n{operation.description}:")

            # Add output or error
            if result['error']:
                output_parts.append(f"Error: {result['error']['message']}")
            else:
                output_parts.append(result['output'].strip())

        # Process charts if they exist
        if response.charts:
            output_parts.append("\nVisualizations:")

            for chart in response.charts:
                if chart.code:
                    # Execute the chart code
                    result = self.execute_code(chart.code)

                    if result['plots']:
                        # Save each generated plot and get dummy URL
                        for plot_data in result['plots']:
                            dummy_url = self.save_plot_dummy(plot_data, chart.image_description)
                            output_parts.append(f"\n{chart.image_description}")
                            output_parts.append(f"![{chart.image_description}]({dummy_url})")
                    elif result['error']:
                        output_parts.append(f"\nError generating {chart.image_description}: {result['error']['message']}")

        return "\n".join(output_parts)
together_ai_instance_provider.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # instance_provider.py
2
+ import os
3
+ import time
4
+ from typing import Dict, Optional
5
+ from pydantic_ai.models.openai import OpenAIModel
6
+ from pydantic_ai.providers.openai import OpenAIProvider
7
+
8
class InstanceProvider:
    """Manages multiple Together AI API instances with failover support.

    One ``OpenAIModel`` is built per API key found in the
    ``TOGETHER_AI_API_KEYS`` environment variable (comma separated). A key
    that is reported as failing is locked for ``LOCK_DURATION`` seconds and
    skipped by ``get_instance`` until the lock expires.
    """

    def __init__(self, lock_duration: float = 1800):
        """Create the provider and build one model per configured key.

        Args:
            lock_duration: Seconds a key stays locked after ``report_error``.
                Defaults to 1800 (30 minutes).
        """
        self.instances: Dict[str, dict] = {}
        self.locked_keys: Dict[str, float] = {}  # key: lock_time
        self.LOCK_DURATION = lock_duration
        self._initialize_instances()

    def _initialize_instances(self):
        """Load all API keys from environment and create instances"""
        api_keys = os.getenv("TOGETHER_AI_API_KEYS", "").split(",")
        base_url = os.getenv("TOGETHER_AI_BASE_URL")
        model_name = os.getenv("TOGETHER_AI_LLM_MODEL_NAME")

        for key in api_keys:
            key = key.strip()
            # Blank entries (unset variable, trailing comma) are ignored.
            if key:
                self.instances[key] = {
                    'model': OpenAIModel(
                        model_name,
                        provider=OpenAIProvider(
                            base_url=base_url,
                            api_key=key
                        )
                    ),
                    'error_count': 0
                }

    def _clean_locked_keys(self):
        """Remove keys that have been locked beyond the duration"""
        current_time = time.time()
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        expired_keys = [
            key for key, lock_time in self.locked_keys.items()
            if current_time - lock_time > self.LOCK_DURATION
        ]
        for key in expired_keys:
            del self.locked_keys[key]

    def get_instance(self) -> Optional[OpenAIModel]:
        """Return the model of the first key that is not currently locked.

        Keys are tried in configuration order, so the same key keeps being
        used until it is reported as failing (failover, not round-robin).

        Raises:
            RuntimeError: If no API key is configured, or if every configured
                key is currently locked.
        """
        if not self.instances:
            # Distinguish misconfiguration from exhaustion; the generic
            # "all keys locked" message would be misleading here.
            raise RuntimeError(
                "No Together AI API keys configured (set TOGETHER_AI_API_KEYS)"
            )

        self._clean_locked_keys()

        for key, instance_data in self.instances.items():
            if key not in self.locked_keys:
                return instance_data['model']

        # If we get here, all keys are locked
        raise RuntimeError("All API keys exhausted or temporarily locked")

    def report_error(self, api_key: str):
        """Report an error for a specific API key and lock it.

        Unknown keys are ignored.
        """
        if api_key in self.instances:
            self.instances[api_key]['error_count'] += 1
            self.locked_keys[api_key] = time.time()

    def get_api_key_for_model(self, model: OpenAIModel) -> Optional[str]:
        """Get the API key for a given model instance.

        The lookup is by object identity: ``model`` must be one of the
        objects handed out by ``get_instance``. Returns ``None`` otherwise.
        """
        for key, instance_data in self.instances.items():
            if instance_data['model'] is model:
                return key
        return None
together_ai_llama_agent.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ from typing import List, Literal, Optional
4
+ from pydantic import BaseModel
5
+ from dotenv import load_dotenv
6
+ from pydantic_ai import Agent
7
+ from csv_service import clean_data
8
+ from python_code_executor_service import PythonExecutor
9
+ from together_ai_instance_provider import InstanceProvider
10
+
11
# Must run before InstanceProvider() is constructed: the provider reads the
# TOGETHER_AI_* environment variables inside its constructor.
load_dotenv()

# Module-level provider shared by every agent created in this module.
instance_provider = InstanceProvider()
14
+
15
# Part of the CsvChatResult structure that is passed to Agent(output_type=...).
class CodeResponse(BaseModel):
    """Container for code-related responses"""
    # Language tag of the snippet; defaults to "python".
    language: str = "python"
    # Source text of the snippet.
    code: str
19
+
20
# Part of the CsvChatResult structure that is passed to Agent(output_type=...).
class ChartSpecification(BaseModel):
    """Details about requested charts"""
    # Human-readable caption for the chart.
    image_description: str
    # Optional plotting code for the chart.
    code: Optional[str] = None
24
+
25
# Part of the CsvChatResult structure that is passed to Agent(output_type=...).
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""
    # Snippet that performs the operation.
    code: CodeResponse
    # Short label describing what the snippet computes.
    description: str
29
+
30
# Used as the agent's output_type in create_csv_agent and handed to
# PythonExecutor.process_response in query_csv_agent.
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""
    # One of the four literal values below.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]

    # Casual chat response
    casual_response: str

    # Data analysis components
    analysis_operations: List[AnalysisOperation]

    # Visualization components (optional; None when no chart is requested)
    charts: Optional[List[ChartSpecification]] = None
42
+
43
+
44
def get_csv_info(df: pd.DataFrame) -> dict:
    """Summarise a DataFrame's size, schema and first rows as a plain dict.

    Returns a dict with the keys ``num_rows``, ``num_cols``, ``example_rows``
    (the first two rows as records), ``dtypes`` (column name -> dtype string),
    ``columns``, ``numeric_columns`` and ``categorical_columns`` (columns with
    a string dtype).
    """
    dtype_names = {}
    numeric = []
    textual = []

    # Single pass over the columns collecting all per-column facts.
    for name in df.columns:
        series = df[name]
        dtype_names[name] = str(series.dtype)
        if pd.api.types.is_numeric_dtype(series):
            numeric.append(name)
        if pd.api.types.is_string_dtype(series):
            textual.append(name)

    sample_records = df.head(2).to_dict('records')

    return {
        'num_rows': len(df),
        'num_cols': len(df.columns),
        'example_rows': sample_records,
        'dtypes': dtype_names,
        'columns': list(df.columns),
        'numeric_columns': numeric,
        'categorical_columns': textual,
    }
56
+
57
+
58
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Generate system prompt for CSV analysis

    The prompt embeds the row/column counts, column names, two sample rows and
    the dtypes reported by get_csv_info, followed by the rules the model must
    follow when writing code against the preloaded ``df`` variable.
    """
    csv_info = get_csv_info(df)

    # NOTE(review): the leading whitespace inside this literal could not be
    # recovered from the diff view this block was read from -- confirm the
    # indentation of the prompt lines against the repository before merging.
    prompt = f"""
    You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

    CSV Info:
    - Rows: {csv_info['num_rows']}, Cols: {csv_info['num_cols']}
    - Columns: {csv_info['columns']}
    - Sample: {csv_info['example_rows']}
    - Dtypes: {csv_info['dtypes']}

    Strict Rules:
    1. Never recreate 'df' - use the existing variable
    2. For analysis:
       - Include necessary imports (except pandas) and include complete code
       - Use df directly (e.g., print(df[...].mean()))
    3. For visualizations:
       - Specify chart type and include complete code
       - Example: plt.bar(df['x'], df['y'])
    4. For Lists and Dictionaries, return them as JSON

    Example:
    import json
    print(json.dumps(df[df['col'] == 'val'].to_dict('records'), indent=2))
    """
    return prompt
86
+
87
+
88
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create and return a CSV analysis agent with API key rotation.

    Args:
        df: DataFrame the agent will answer questions about; it is only used
            to build the system prompt.
        max_retries: Number of attempts; each attempt asks the provider for
            a currently unlocked model.

    Returns:
        An Agent whose output type is CsvChatResult.

    Raises:
        RuntimeError: If no agent could be created within ``max_retries``
            attempts. The last underlying error is attached as the cause.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error = None

    for attempt in range(max_retries):
        # Reset per attempt: if get_instance() itself raises, the handler
        # below must not see an unbound name or a model from an earlier try.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")

            csv_agent = Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )

            return csv_agent

        except Exception as e:
            last_error = e
            # Only a model that was actually obtained can be mapped to a key.
            api_key = None
            if model is not None:
                api_key = instance_provider.get_api_key_for_model(model)
            if api_key:
                print(f"Error with API key (attempt {attempt + 1}): {str(e)}")
                instance_provider.report_error(api_key)
            continue

    raise RuntimeError(f"Failed to create agent after {max_retries} attempts") from last_error
114
+
115
+
116
async def query_csv_agent(csv_url: str, question: str) -> str:
    """Query the CSV agent about the file at ``csv_url`` and return formatted output.

    This is a coroutine function: callers must ``await`` it. Handing it to
    ``asyncio.to_thread`` only produces an un-awaited coroutine object, not
    the answer.

    Args:
        csv_url: Location of the CSV, passed to ``clean_data`` to load it.
        question: The user's question for the agent.

    Returns:
        The text produced by ``PythonExecutor.process_response``.

    Raises:
        ValueError: If the agent output is not a CsvChatResult and cannot be
            converted into one. The original error is attached as the cause.
    """

    # Get the DataFrame from the CSV URL
    df = clean_data(csv_url)

    # Create agent and get response
    agent = create_csv_agent(df)
    result = await agent.run(question)

    # Process the response through PythonExecutor
    executor = PythonExecutor(df)

    # Convert the raw output to CsvChatResult if needed
    if isinstance(result.output, CsvChatResult):
        chat_result = result.output
    else:
        # Fallback for raw output: accept a dict as-is, otherwise treat the
        # output as JSON text.
        try:
            response_data = result.output if isinstance(result.output, dict) else json.loads(result.output)
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            # Chain the cause so the underlying parse/validation error is
            # kept in the traceback.
            raise ValueError(f"Could not parse agent response: {str(e)}") from e

    # Process and format the response
    formatted_output = executor.process_response(chat_result)

    return formatted_output