|
import pandas as pd |
|
import json |
|
from typing import List, Literal, Optional |
|
from pydantic import BaseModel |
|
from dotenv import load_dotenv |
|
from pydantic_ai import Agent |
|
from csv_service import clean_data |
|
from python_code_executor_service import PythonExecutor |
|
from together_ai_instance_provider import InstanceProvider |
|
|
|
# Load configuration (e.g. provider API keys) from a local .env file into the environment.
load_dotenv()

# Module-level provider shared by all agent creations; used for model-instance
# selection and API-key rotation (see create_csv_agent).
instance_provider = InstanceProvider()
|
|
|
class CodeResponse(BaseModel):
    """Container for code-related responses"""

    # Language of the generated snippet; the agent is prompted to emit Python only.
    language: str = "python"
    # Executable source code produced by the model.
    code: str
|
|
|
class ChartSpecification(BaseModel):
    """Details about requested charts"""

    # Natural-language description of the chart the model intends to render.
    image_description: str
    # Matplotlib/seaborn plotting code; None when the model supplies no code.
    code: Optional[str] = None
|
|
|
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""

    # The code snippet that performs this analysis step.
    code: CodeResponse
    # Human-readable explanation of what the step computes.
    description: str
|
|
|
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""

    # Broad category of the reply; presumably drives downstream formatting in
    # PythonExecutor.process_response — confirm against that service.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]

    # Conversational text shown to the user alongside any analysis output.
    casual_response: str

    # Code + description pairs to be executed against the loaded DataFrame.
    analysis_operations: List[AnalysisOperation]

    # Optional chart requests; None when the reply needs no visualization.
    charts: Optional[List[ChartSpecification]] = None
|
|
|
|
|
def get_csv_info(df: pd.DataFrame) -> dict:
    """Summarize a DataFrame for prompt building.

    Returns a dict with row/column counts, the first two rows as records,
    per-column dtype strings, the column list, and the columns partitioned
    into numeric vs string-typed groups.
    """
    numeric_cols = []
    string_cols = []
    for column in df.columns:
        series = df[column]
        if pd.api.types.is_numeric_dtype(series):
            numeric_cols.append(column)
        if pd.api.types.is_string_dtype(series):
            string_cols.append(column)

    return {
        'num_rows': len(df),
        'num_cols': len(df.columns),
        'example_rows': df.head(2).to_dict('records'),
        'dtypes': {column: str(df[column].dtype) for column in df.columns},
        'columns': list(df.columns),
        'numeric_columns': numeric_cols,
        'categorical_columns': string_cols,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Generate system prompt for CSV analysis.

    Embeds the DataFrame's metadata (shape, column names, two sample rows,
    dtypes) into a fixed instruction template that tells the model to operate
    on the pre-loaded 'df' variable and how to format code, charts and JSON.
    """
    csv_info = get_csv_info(df)

    # NOTE: this f-string is part of the model contract — every byte reaches
    # the LLM. Braces that should appear literally in the example code are
    # doubled (e.g. {{p.get_height():.1f}}) so f-string formatting preserves
    # them as single braces in the rendered prompt.
    prompt = f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.

CSV Info:
- Rows: {csv_info['num_rows']}, Cols: {csv_info['num_cols']}
- Columns: {csv_info['columns']}
- Sample: {csv_info['example_rows']}
- Dtypes: {csv_info['dtypes']}

Strict Rules:
1. Never recreate 'df' - use the existing variable
2. For analysis:
   - Include necessary imports (except pandas) and include complete code
   - Use df directly (e.g., print(df[...].mean()))
3. For visualizations:
   - Create the most professional, publication-quality charts possible
   - Maximize descriptive elements and detail while maintaining clarity
   - Figure size: (14, 8) for complex charts, (12, 6) for simpler ones
   - Use comprehensive titles (fontsize=16) and axis labels (fontsize=14)
   - Include informative legends (fontsize=12) when appropriate
   - Add annotations for important data points where valuable
   - Rotate x-labels (45° if needed) with fontsize=12 for readability
   - Use colorblind-friendly palettes (seaborn 'deep', 'muted', or 'colorblind')
   - Add gridlines (alpha=0.3) when they improve readability
   - Include proper margins and padding to prevent label cutoff
   - For distributions, include kernel density estimates when appropriate
   - For time series, use appropriate date formatting and markers
   - Do not use any visualization library other than matplotlib or seaborn
   - Complete code with plt.tight_layout() before plt.show()
   - Example professional chart:
     plt.figure(figsize=(14, 8))
     ax = sns.barplot(x='category', y='value', data=df, palette='muted', ci=None)
     plt.title('Detailed Analysis of Values by Category', fontsize=16, pad=20)
     plt.xlabel('Category', fontsize=14)
     plt.ylabel('Average Value', fontsize=14)
     plt.xticks(rotation=45, ha='right', fontsize=12)
     plt.yticks(fontsize=12)
     ax.grid(True, linestyle='--', alpha=0.3)
     for p in ax.patches:
         ax.annotate(f'{{p.get_height():.1f}}',
                     (p.get_x() + p.get_width() / 2., p.get_height()),
                     ha='center', va='center',
                     xytext=(0, 10),
                     textcoords='offset points',
                     fontsize=12)
     plt.tight_layout()
     plt.show()
4. For Lists, Tables and Dictionaries, always return them as JSON

Example:
import json
print(json.dumps(df[df['col'] == 'val'].to_dict('records'), indent=2))
"""
    return prompt
|
|
|
|
|
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create and return a CSV analysis agent with API key rotation.

    Builds the system prompt from the DataFrame's metadata, then asks the
    instance provider for a model. On failure the offending API key is
    reported back to the provider (so it can rotate away from it) before
    the next attempt, if any.

    Args:
        df: DataFrame the agent will analyze (used only to build the prompt).
        max_retries: Number of model-acquisition attempts before giving up.

    Returns:
        A configured pydantic-ai Agent producing CsvChatResult outputs.

    Raises:
        RuntimeError: If no agent could be created after max_retries
            attempts; the last underlying error is chained as __cause__.
    """
    csv_system_prompt = get_csv_system_prompt(df)

    last_error: Optional[Exception] = None
    for attempt in range(max_retries):
        # Bound before the try so the except block can reference it safely
        # even when get_instance() itself raises (the previous code could
        # NameError here on the first attempt, masking the real failure).
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")

            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )

        except Exception as e:
            # Always record and log the failure; the previous code silently
            # swallowed it when no API key could be resolved for the model.
            last_error = e
            print(f"Error with API key (attempt {attempt + 1}): {str(e)}")
            if model is not None:
                api_key = instance_provider.get_api_key_for_model(model)
                if api_key:
                    # Flag the key so the provider rotates to another one.
                    instance_provider.report_error(api_key)

    # Chain the last error so the root cause survives for callers/logs.
    raise RuntimeError(f"Failed to create agent after {max_retries} attempts") from last_error
|
|
|
|
|
async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Query the CSV agent with a DataFrame and question and return formatted output"""
    df = clean_data(csv_url)

    agent = create_csv_agent(df)
    result = await agent.run(question)

    executor = PythonExecutor(df)

    raw_output = result.output
    if isinstance(raw_output, CsvChatResult):
        chat_result = raw_output
    else:
        # The model occasionally hands back a plain dict or a JSON string
        # instead of the structured type; coerce it before executing.
        try:
            payload = raw_output if isinstance(raw_output, dict) else json.loads(raw_output)
            chat_result = CsvChatResult(**payload)
        except Exception as e:
            raise ValueError(f"Could not parse agent response: {str(e)}")

    print("Chat Result Original Object:", chat_result)

    formatted_output = await executor.process_response(chat_result, chat_id)
    return formatted_output
|
|