# Repository page header captured with the file (not part of the program):
# FastApi / together_ai_llama_agent.py
# Soumik555's picture
# changed prompt
# 69a0b7f
# raw
# history blame
# 6.5 kB
import pandas as pd
import json
from typing import List, Literal, Optional
from pydantic import BaseModel
from dotenv import load_dotenv
from pydantic_ai import Agent
from csv_service import clean_data
from python_code_executor_service import PythonExecutor
from together_ai_instance_provider import InstanceProvider
# Load variables from a .env file into the process environment at import time
# (presumably so InstanceProvider can pick up its API keys — confirm).
load_dotenv()
# Module-wide provider shared by every call to create_csv_agent(): it hands out
# model instances and is told about API keys that produced errors.
instance_provider = InstanceProvider()
class CodeResponse(BaseModel):
    """Container for code-related responses"""
    # Language of the snippet; defaults to "python" when not supplied.
    language: str = "python"
    # The source code text itself (required).
    code: str
class ChartSpecification(BaseModel):
    """Details about requested charts"""
    # Free-text description of the chart image (required).
    image_description: str
    # Plotting code for the chart as a plain string; may be omitted (None).
    code: Optional[str] = None
class AnalysisOperation(BaseModel):
    """Container for a single analysis operation with its code and result"""
    # The code to run for this operation, wrapped in a CodeResponse.
    code: CodeResponse
    # Name of the variable associated with the operation's result
    # (presumably the variable the executor reads back — confirm in PythonExecutor).
    result_var: str
class CsvChatResult(BaseModel):
    """Structured response for CSV-related AI interactions"""
    # NOTE(review): this model is passed to the agent as its output_type, so the
    # docstring and field names may be surfaced to the LLM as part of the output
    # schema — confirm before rewording or renaming anything here.

    # Which kind of answer this is; restricted to the four literals below.
    response_type: Literal["casual", "data_analysis", "visualization", "mixed"]
    # Casual chat response (required, even for non-casual response types)
    casual_response: str
    # Data analysis components (required list; may be empty)
    analysis_operations: List[AnalysisOperation]
    # Visualization components (optional; None when no charts were requested)
    charts: Optional[List[ChartSpecification]] = None
def get_csv_info(df: pd.DataFrame) -> dict:
    """Get metadata/info about the CSV.

    Args:
        df: The DataFrame to describe.

    Returns:
        A dict with the keys:
        - 'num_rows': number of rows.
        - 'num_cols': number of columns.
        - 'example_rows': the first two rows as a list of record dicts.
        - 'dtypes': mapping of column name to the dtype's string form.
        - 'columns': list of column names in frame order.
        - 'numeric_columns': columns for which pandas reports a numeric dtype.
        - 'categorical_columns': columns that are string-typed or use the
          pandas ``category`` dtype.
    """
    dtypes = {}
    numeric_columns = []
    categorical_columns = []

    # One pass over the columns instead of three separate comprehensions.
    for col in df.columns:
        series = df[col]
        dtypes[col] = str(series.dtype)
        if pd.api.types.is_numeric_dtype(series):
            numeric_columns.append(col)
        # is_string_dtype() is False for CategoricalDtype, so columns already
        # stored as `category` would otherwise be missing from this list.
        if pd.api.types.is_string_dtype(series) or isinstance(
            series.dtype, pd.CategoricalDtype
        ):
            categorical_columns.append(col)

    return {
        'num_rows': len(df),
        'num_cols': len(df.columns),
        'example_rows': df.head(2).to_dict('records'),
        'dtypes': dtypes,
        'columns': list(df.columns),
        'numeric_columns': numeric_columns,
        'categorical_columns': categorical_columns,
    }
def get_csv_system_prompt(df: pd.DataFrame) -> str:
    """Generate system prompt for CSV analysis.

    The prompt embeds the row/column counts, column names, two sample rows and
    the dtypes reported by get_csv_info(), followed by fixed rules for analysis
    code, chart styling and JSON output.

    Args:
        df: The DataFrame the assistant will be asked about.

    Returns:
        The full system prompt text.
    """
    csv_info = get_csv_info(df)
    # The doubled braces in the example chart are f-string escapes: they render
    # as single braces so the model sees a normal f-string. The body of the
    # example's `for` loop is indented so the sample is itself valid Python.
    prompt = f"""
You're a CSV analysis assistant. The pandas DataFrame is loaded as 'df' - use this variable.
CSV Info:
- Rows: {csv_info['num_rows']}, Cols: {csv_info['num_cols']}
- Columns: {csv_info['columns']}
- Sample: {csv_info['example_rows']}
- Dtypes: {csv_info['dtypes']}
Strict Rules:
1. Never recreate 'df' - use the existing variable
2. For analysis:
- Complete code without imports
- Use df directly (e.g., print(df[...].mean()))
3. For visualizations:
- Create the most professional, publication-quality charts possible
- Maximize descriptive elements and detail while maintaining clarity
- Figure size: (14, 8) for complex charts, (12, 6) for simpler ones
- Use comprehensive titles (fontsize=16) and axis labels (fontsize=14)
- Include informative legends (fontsize=12) when appropriate
- Add annotations for important data points where valuable
- Rotate x-labels (45° if needed) with fontsize=12 for readability
- Use colorblind-friendly palettes (seaborn 'deep', 'muted', or 'colorblind')
- Add gridlines (alpha=0.3) when they improve readability
- Include proper margins and padding to prevent label cutoff
- For distributions, include kernel density estimates when appropriate
- For time series, use appropriate date formatting and markers
- Do not use any visualization library other than matplotlib or seaborn
- Complete code with plt.tight_layout() before plt.show()
- Example professional chart:
plt.figure(figsize=(14, 8))
ax = sns.barplot(x='category', y='value', data=df, palette='muted', ci=None)
plt.title('Detailed Analysis of Values by Category', fontsize=16, pad=20)
plt.xlabel('Category', fontsize=14)
plt.ylabel('Average Value', fontsize=14)
plt.xticks(rotation=45, ha='right', fontsize=12)
plt.yticks(fontsize=12)
ax.grid(True, linestyle='--', alpha=0.3)
for p in ax.patches:
    ax.annotate(f'{{p.get_height():.1f}}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='center',
                xytext=(0, 10),
                textcoords='offset points',
                fontsize=12)
plt.tight_layout()
plt.show()
4. For Lists, Records, Tables, Dictionaries...etc for any data structure, always return them as JSON with correct indentation.
IMPORTANT: Code must be syntactically perfect and executable as-is.
"""
    return prompt
def create_csv_agent(df: pd.DataFrame, max_retries: int = 1) -> Agent:
    """Create and return a CSV analysis agent with API key rotation.

    Each attempt asks the shared instance provider for a model and builds an
    Agent around it. When an attempt fails and the model's API key can be
    identified, the key is reported to the provider before the next attempt.

    Args:
        df: DataFrame used to build the agent's system prompt.
        max_retries: Maximum number of attempts to make.

    Returns:
        An Agent configured with the CSV system prompt and CsvChatResult output.

    Raises:
        RuntimeError: If no attempt succeeds; the last underlying error (if
            any) is attached as the cause.
    """
    csv_system_prompt = get_csv_system_prompt(df)
    last_error: Optional[Exception] = None

    for attempt in range(max_retries):
        # Reset per attempt: if get_instance() itself raises, the handler must
        # not see an unbound name or a model left over from an earlier attempt.
        model = None
        try:
            model = instance_provider.get_instance()
            if model is None:
                raise RuntimeError("No available API instances")
            return Agent(
                model=model,
                output_type=CsvChatResult,
                system_prompt=csv_system_prompt,
            )
        except Exception as e:
            last_error = e
            # Only a model obtained in *this* attempt can be blamed on a key.
            api_key = (
                instance_provider.get_api_key_for_model(model)
                if model is not None
                else None
            )
            if api_key:
                print(f"Error with API key (attempt {attempt + 1}): {str(e)}")
                instance_provider.report_error(api_key)

    # Chain the last failure so the real reason is visible in the traceback.
    raise RuntimeError(
        f"Failed to create agent after {max_retries} attempts"
    ) from last_error
async def query_csv_agent(csv_url: str, question: str, chat_id: str) -> str:
    """Query the CSV agent with a DataFrame and question and return formatted output.

    Args:
        csv_url: Location of the CSV; passed to clean_data() to obtain the DataFrame.
        question: The user's question, sent to the agent as-is.
        chat_id: Identifier forwarded to the executor when processing the response.

    Returns:
        Whatever PythonExecutor.process_response() produces for the agent's answer.

    Raises:
        ValueError: If the agent output is not a CsvChatResult and cannot be
            converted into one; the original error is attached as the cause.
    """
    # Get the DataFrame from the CSV URL
    df = clean_data(csv_url)

    # Create agent and get response
    agent = create_csv_agent(df)
    result = await agent.run(question)

    # Process the response through PythonExecutor
    executor = PythonExecutor(df)

    output = result.output
    if isinstance(output, CsvChatResult):
        chat_result = output
    else:
        # Output needs conversion: accept a dict directly, otherwise treat it
        # as JSON text.
        try:
            response_data = output if isinstance(output, dict) else json.loads(output)
            chat_result = CsvChatResult(**response_data)
        except Exception as e:
            # Keep the original exception as the explicit cause for debugging.
            raise ValueError(f"Could not parse agent response: {str(e)}") from e

    print("Chat Result Original Object:", chat_result)

    # Process and format the response
    formatted_output = await executor.process_response(chat_result, chat_id)
    return formatted_output