|
import pandas as pd |
|
import json |
|
from typing import List, Literal, Optional |
|
from pydantic import BaseModel |
|
from dotenv import load_dotenv |
|
from pydantic_ai import Agent |
|
from csv_service import clean_data |
|
from python_code_executor_service import PythonExecutor |
|
from cerebras_instance_provider import InstanceProvider |
|
import logging |
|
|
|
# Load settings via python-dotenv before anything below is constructed.
load_dotenv()

# Module-wide provider; create_csv_agent() asks it for model instances and
# reports failing API keys back to it.
instance_provider = InstanceProvider()

# Root logging is configured at import time; this module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
class CodeResponse(BaseModel):
    """A piece of source code together with the language it is written in."""

    # Language of `code`; defaults to "python" when not supplied.
    language: str = "python"
    # The source text itself (required).
    code: str
|
|
|
class ChartSpecification(BaseModel):
    """Description of a requested chart, optionally with code for it."""

    # Textual description of the image (required).
    image_description: str
    # Optional code associated with the chart; None when absent.
    code: Optional[str] = None
|
|
|
class AnalysisOperation(BaseModel):
    """One analysis step: the code to run and where its result lives."""

    # The code for this operation, wrapped in a CodeResponse.
    code: CodeResponse
    # Presumably the name of the variable that holds the result after the
    # code has run -- confirm against the executor that consumes it.
    result_var: str
|
|
|
class CsvChatResult(BaseModel):
    """Structured output the CSV agent is asked to produce for a question."""

    # Free-text reply (required).
    casual_response: str

    # NOTE(review): the name is plural but the type is a single
    # AnalysisOperation, not a list -- confirm this is intended.
    analysis_operations: AnalysisOperation

    # Optional chart request; None when no chart is included.
    charts: Optional[ChartSpecification] = None
|
|
|
|
|
def get_csv_info(df: pd.DataFrame) -> dict:
    """Collect summary metadata about a DataFrame.

    Args:
        df: The DataFrame to describe.

    Returns:
        A dict with the keys:
        - 'num_rows' / 'num_cols': row and column counts.
        - 'example_rows': the first two rows as a list of record dicts.
        - 'dtypes': mapping of column label to the dtype's string form.
        - 'columns': list of column labels.
        - 'numeric_columns': labels whose column has a numeric dtype.
        - 'categorical_columns': labels whose column has a string dtype.
    """
    columns = list(df.columns)

    # Columns are read by position rather than by label: with repeated
    # labels, df[label] returns a DataFrame (which has no `.dtype`) and the
    # label-based lookup would raise AttributeError.
    labelled = [(label, df.iloc[:, pos]) for pos, label in enumerate(columns)]

    return {
        'num_rows': len(df),
        'num_cols': len(columns),
        'example_rows': df.head(2).to_dict('records'),
        'dtypes': {label: str(series.dtype) for label, series in labelled},
        'columns': columns,
        'numeric_columns': [
            label for label, series in labelled
            if pd.api.types.is_numeric_dtype(series)
        ],
        'categorical_columns': [
            label for label, series in labelled
            if pd.api.types.is_string_dtype(series)
        ],
    }
|
|
|
|
|
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Build the system prompt handed to the CSV analysis agent.

    The prompt embeds the DataFrame's shape, column labels, sample rows and
    dtypes (taken from get_csv_info) followed by fixed instructions and
    examples.

    Args:
        df: The DataFrame the agent will be asked about.

    Returns:
        The fully rendered prompt text.
    """
    info = get_csv_info(df)

    # The template is kept flush-left so that no source indentation ends up
    # inside the prompt text.
    return f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

CSV Info:
- Shape: {info['num_rows']} rows × {info['num_cols']} cols
- Columns: {info['columns']}
- Sample: {info['example_rows']}
- Dtypes: {info['dtypes']}

STRICT REQUIREMENTS:
1. NEVER calculate or predict values yourself - ALWAYS return executable code that would produce the result
2. Use existing 'df' - never recreate it
3. For any data structures (Lists, Records, Tables, Dictionaries, etc.), always return them as JSON with correct indentation
4. For charts:
- Use matplotlib/seaborn only
- Professional quality: proper sizing, labels, titles
- Figure size: (14, 8) for complex, (12, 6) for simple
- Clear titles (fontsize=16), labels (fontsize=14)
- Rotate x-labels if needed (45°, fontsize=12)
- Add annotations/gridlines where helpful
- Use colorblind-friendly palettes
- Always include plt.tight_layout()

Example professional chart:
plt.figure(figsize=(14, 8))
sns.barplot(x='category', y='value', data=df, palette='muted')
plt.title('Value by Category', fontsize=16)
plt.xlabel('Category', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.xticks(rotation=45)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

Example professional response for a dataframe:
num_rows = len(df)


Return complete, executable code.
"""
|
|
|
|
|
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create a CSV analysis agent, rotating API keys on failure.

    Each attempt asks the module-level ``instance_provider`` for a model and
    wraps it in an ``Agent`` whose output type is ``CsvChatResult``. When an
    attempt fails, the API key tied to that attempt's model (if one can be
    resolved) is reported to the provider before the next attempt.

    Args:
        df: DataFrame used to render the agent's system prompt.
        max_retries: Maximum number of creation attempts.

    Returns:
        The configured agent from the first successful attempt.

    Raises:
        RuntimeError: If every attempt fails; the last underlying error is
            attached as the cause.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(max_retries):
        # Reset per attempt: if get_instance() itself raises, the handler
        # must not see an unbound name or the model of a previous attempt
        # (which would report the wrong key).
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")

            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
        except Exception as e:
            last_error = e
            # Log every failure, not only those with a resolvable key, so
            # that no attempt fails silently.
            logger.warning(
                "Error with API key (attempt %d): %s", attempt + 1, e
            )
            api_key = (
                instance_provider.get_api_key_for_model(model)
                if model is not None
                else None
            )
            if api_key:
                instance_provider.report_error(api_key)

    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error
|
|
|
|
|
async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Answer a question about a CSV and return the processed output.

    Steps: load the data with ``clean_data``, create an agent for it, run the
    question, coerce the agent output into a ``CsvChatResult`` if it is not
    one already, then hand it to ``PythonExecutor.process_response``.

    Args:
        csv_url: Location of the CSV, passed straight to ``clean_data``.
        question: The user's question, passed to the agent.
        chat_id: Identifier forwarded to the executor.

    Returns:
        Whatever ``PythonExecutor.process_response`` returns for the result.

    Raises:
        ValueError: If the agent output cannot be turned into a
            ``CsvChatResult``; the parsing error is attached as the cause.
    """
    # NOTE(review): clean_data is called synchronously inside a coroutine;
    # if it performs blocking I/O it will stall the event loop -- confirm.
    df = clean_data(csv_url)

    agent = create_csv_agent(df)
    result = await agent.run(question)

    executor = PythonExecutor(df)

    output = result.output
    if isinstance(output, CsvChatResult):
        chat_result = output
    else:
        # Fallback for outputs that arrive as a dict or a JSON string.
        try:
            response_data = output if isinstance(output, dict) else json.loads(output)
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            raise ValueError(f"Could not parse agent response: {str(e)}") from e

    # A %s placeholder is required: passing the object as a bare extra
    # argument makes logging fail to format the record and drop the message.
    logger.info("Chat Result Original Object: %s", chat_result)

    formatted_output = await executor.process_response(chat_result, chat_id)
    return formatted_output
|
|