|
import pandas as pd |
|
import json |
|
from typing import List, Literal, Optional |
|
from pydantic import BaseModel |
|
from dotenv import load_dotenv |
|
from pydantic_ai import Agent |
|
from csv_service import clean_data |
|
from python_code_executor_service import PythonExecutor |
|
from together_ai_instance_provider import InstanceProvider |
|
|
|
# Load variables from a local .env file into the process environment before any
# service objects are constructed below.
# NOTE(review): presumably InstanceProvider reads its API keys from the
# environment, which is why this must run first — confirm.
load_dotenv()

# Module-level provider shared by every agent this module creates. create_csv_agent
# uses it to obtain model instances, map a model back to its API key, and report
# failing keys so they can be rotated.
instance_provider = InstanceProvider()
|
|
|
class CodeResponse(BaseModel):
    """Container for code-related responses"""
    # NOTE(review): pydantic normally uses the class docstring as the JSON-schema
    # description, and this model is nested inside the agent's output_type
    # (CsvChatResult), so the docstring text may reach the LLM — edit with care.

    # Language of `code`; defaults to "python".
    language: str = "python"

    # Source code produced by the agent for one analysis operation.
    code: str
|
|
|
class ChartSpecification(BaseModel):
    """Details about requested charts"""
    # NOTE(review): the docstring likely becomes part of the structured-output
    # schema sent to the model (via CsvChatResult.charts) — edit with care.

    # Free-text description of the chart / image.
    image_description: str

    # Optional code associated with the chart; defaults to None when omitted.
    code: Optional[str] = None
|
|
|
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""
    # NOTE(review): the docstring mentions a "result", but this model has no
    # result field — only the code and a description. The docstring is left
    # unchanged because it may be part of the schema sent to the LLM.

    # Code the agent proposes for this operation.
    code: CodeResponse

    # Human-readable description of what the operation does.
    description: str
|
|
|
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""
    # This model is the agent's `output_type` in create_csv_agent. It is also
    # built directly from a dict / JSON text in query_csv_agent's fallback path.

    # Which kind of reply the model is giving.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]

    # Conversational text reply; required for every response type.
    casual_response: str

    # Code-backed analysis steps. The field has no default, so it is required:
    # the model must send an empty list when there is nothing to run.
    analysis_operations: List[AnalysisOperation]

    # Optional chart specifications; defaults to None when omitted.
    charts: Optional[List[ChartSpecification]] = None
|
|
|
|
|
def get_csv_info(df: pd.DataFrame) -> dict:
    """Get metadata/info about the CSV.

    Args:
        df: The loaded CSV data.

    Returns:
        A dict with:
            - ``num_rows`` / ``num_cols``: the shape of ``df``.
            - ``example_rows``: the first two rows as a list of record dicts.
            - ``dtypes``: column name -> dtype string.
            - ``columns``: the column labels in order.
            - ``numeric_columns``: labels of numeric-dtype columns.
            - ``categorical_columns``: labels of string-like or pandas
              ``category`` columns.
    """
    columns = list(df.columns)

    # Pair each label with its Series by position. ``df[col]`` returns a
    # DataFrame (which has no ``.dtype``) when a label is duplicated, so
    # label-based access would crash on such CSVs.
    labelled_series = [(col, df.iloc[:, idx]) for idx, col in enumerate(columns)]

    numeric_columns = [
        col for col, series in labelled_series
        if pd.api.types.is_numeric_dtype(series)
    ]
    # is_string_dtype() rejects pandas ``category`` columns whose categories
    # are not strings, so they are checked explicitly.
    categorical_columns = [
        col for col, series in labelled_series
        if pd.api.types.is_string_dtype(series)
        or isinstance(series.dtype, pd.CategoricalDtype)
    ]

    return {
        'num_rows': len(df),
        'num_cols': len(columns),
        'example_rows': df.head(2).to_dict('records'),
        'dtypes': {col: str(series.dtype) for col, series in labelled_series},
        'columns': columns,
        'numeric_columns': numeric_columns,
        'categorical_columns': categorical_columns,
    }
|
|
|
|
|
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Render the system prompt that primes the agent to analyse ``df``.

    The prompt embeds the row/column counts, the column names, a two-row
    sample and the per-column dtypes (all taken from ``get_csv_info``). It then
    lists the coding rules the model must follow when it emits code.
    """
    meta = get_csv_info(df)
    n_rows, n_cols = meta['num_rows'], meta['num_cols']
    column_names = meta['columns']
    sample_rows = meta['example_rows']
    dtype_map = meta['dtypes']

    return f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

CSV Info:
- Rows: {n_rows}, Cols: {n_cols}
- Columns: {column_names}
- Sample: {sample_rows}
- Dtypes: {dtype_map}

Strict Rules:
1. Never recreate 'df' - use the existing variable
2. For analysis:
- Include necessary imports (except pandas) and include complete code
- Use df directly (e.g., print(df[...].mean()))
3. For visualizations (Only support libraries are Matplotlib and Seaborn):
- Complete code with plt.show()
- Example: plt.bar(df['x'], df['y']) \n plt.show()
4. For Lists and Dictionaries, return them as JSON

Example:
import json
print(json.dumps(df[df['col'] == 'val'].to_dict('records'), indent=2))
"""
|
|
|
|
|
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create and return a CSV analysis agent, rotating API instances on failure.

    Each attempt asks ``instance_provider`` for a model instance. It then builds
    an ``Agent`` whose structured output type is ``CsvChatResult`` and whose
    system prompt describes ``df``. If an attempt fails and the failing model's
    API key is known, that key is reported to the provider before the next
    attempt.

    Args:
        df: DataFrame the agent will analyse (used to build the system prompt).
        max_retries: Total number of attempts; values below 1 make no attempt.

    Returns:
        A configured ``Agent``.

    Raises:
        RuntimeError: If every attempt fails. The last underlying error is
            chained as ``__cause__``.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(max_retries):
        # Reset on every attempt. Otherwise a failure inside get_instance()
        # would hit an unbound name on the first attempt, or report the
        # previous attempt's (possibly healthy) key on later ones.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")

            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )

        except Exception as e:
            last_error = e
            # Only a real model instance can be mapped back to an API key.
            api_key = (
                instance_provider.get_api_key_for_model(model)
                if model is not None
                else None
            )
            if api_key:
                print(f"Error with API key (attempt {attempt + 1}): {str(e)}")
                instance_provider.report_error(api_key)
            else:
                # Previously swallowed silently; surface it so failures are visible.
                print(f"Error creating agent (attempt {attempt + 1}): {e}")

    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error
|
|
|
|
|
async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Query the CSV agent with a DataFrame and question and return formatted output.

    Steps:
        1. Load and clean the CSV via ``clean_data``.
        2. Build an agent for it and run ``question``.
        3. Coerce the agent output into a ``CsvChatResult``.
        4. Hand the result to ``PythonExecutor.process_response``.

    Args:
        csv_url: CSV location, passed straight to ``clean_data``.
        question: The user's question for the agent.
        chat_id: Identifier forwarded to ``PythonExecutor.process_response``.

    Returns:
        Whatever ``PythonExecutor.process_response`` returns.
        NOTE(review): annotated as ``str``; confirm against PythonExecutor.

    Raises:
        ValueError: If the agent output is not a ``CsvChatResult`` and cannot
            be parsed into one. The parsing error is chained as ``__cause__``.
    """
    df = clean_data(csv_url)

    agent = create_csv_agent(df)
    result = await agent.run(question)

    executor = PythonExecutor(df)

    output = result.output
    if isinstance(output, CsvChatResult):
        chat_result = output
    else:
        # Fallback for outputs that arrive as a dict or as JSON text instead
        # of a validated model instance.
        try:
            response_data = output if isinstance(output, dict) else json.loads(output)
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            # Chain the cause so the JSON / validation traceback is not lost.
            raise ValueError(f"Could not parse agent response: {str(e)}") from e

    print("Chat Result Original Object:", chat_result)

    return executor.process_response(chat_result, chat_id)
|
|