|
import re |
|
from fastapi.responses import FileResponse |
|
import numpy as np |
|
import pandas as pd |
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
|
|
|
|
|
def generate_csv_data(csv_url):
    """Load a CSV and return its rows as a list of record dicts.

    Missing cells are replaced with an empty string before conversion.

    Parameters:

        csv_url: Anything ``pandas.read_csv`` accepts (path, URL, buffer).

    Returns:

        list[dict] | dict: One dict per row on success. On any failure the
        error is printed and ``{"error": <message>}`` is returned instead.
    """
    try:
        frame = pd.read_csv(csv_url)
        # Blank out every null cell so the records carry '' instead of NaN.
        blanked = frame.where(frame.notnull(), '')
        return blanked.to_dict(orient='records')
    except Exception as exc:
        print(f"Error occurred while reading CSV: {exc}")
        return {"error": str(exc)}
|
|
|
|
|
|
|
|
|
def _is_text_dtype(dtype):
    """Return True for object columns and pandas string-dtype columns."""
    return pd.api.types.is_object_dtype(dtype) or pd.api.types.is_string_dtype(dtype)


def _strip_if_str(value):
    """Strip surrounding whitespace from *value* if it is a str; return anything else unchanged."""
    return value.strip() if isinstance(value, str) else value


def _fill_value_for(dtype):
    """Return the value used to fill missing entries of a column of *dtype*.

    Returns None when the column should be left as is. None is never passed
    to ``fillna`` because ``fillna(None)`` raises ValueError.
    """
    types = pd.api.types
    if isinstance(dtype, pd.CategoricalDtype):
        # Use the first declared category; there is no sensible default if none exist.
        return dtype.categories[0] if len(dtype.categories) > 0 else None
    if _is_text_dtype(dtype):
        return ''
    if types.is_bool_dtype(dtype):
        return False
    if types.is_integer_dtype(dtype):
        return 0
    if types.is_float_dtype(dtype):
        return 0.0
    if types.is_complex_dtype(dtype):
        return complex(0, 0)
    if types.is_timedelta64_dtype(dtype):
        return pd.Timedelta(0)
    # Datetime columns (filling with NaT would be a no-op) and any other dtype
    # are left untouched.
    return None


def clean_data(csv_url):
    """Read a CSV and return a cleaned DataFrame.

    Cleaning steps, in order:

        1. drop exact duplicate rows;
        2. strip surrounding whitespace from str values in text columns;
        3. turn +/-inf into NaN;
        4. fill missing values with a dtype-appropriate default
           ('' for text, 0 / 0.0 for numbers, False for bools, ...);
        5. drop columns holding at most one distinct value.

    NOTE: because of step 5, a CSV with a single data row comes back with
    no columns at all.

    Parameters:

        csv_url: Anything ``pandas.read_csv`` accepts (path, URL, buffer).

    Returns:

        pandas.DataFrame: The cleaned data.

    Raises:

        Whatever ``pandas.read_csv`` raises for an unreadable or malformed source.
    """
    data = pd.read_csv(csv_url)

    data = data.drop_duplicates()

    # Strip only real str values. ``Series.str`` raises on object columns that
    # hold non-strings (e.g. booleans mixed with blanks), and on mixed columns
    # it would turn every non-string value into NaN.
    for column in data.columns:
        if _is_text_dtype(data[column].dtype):
            data[column] = data[column].map(_strip_if_str)

    data = data.replace([np.inf, -np.inf], np.nan)

    # Re-read each dtype here: the strip step can change an object column's dtype.
    for column in data.columns:
        fill_value = _fill_value_for(data[column].dtype)
        if fill_value is not None:
            data[column] = data[column].fillna(fill_value)

    constant_columns = [col for col in data.columns if data[col].nunique() <= 1]

    return data.drop(columns=constant_columns)
|
|
|
|
|
|
|
def get_csv_basic_info(csv_path):
    """Summarise a CSV file after passing it through ``clean_data``.

    The counts and rows describe the cleaned frame produced by
    ``clean_data``, not necessarily the raw file.

    Parameters:

        csv_path (str): Path to the CSV file

    Returns:

        dict: On success, ``row_count``, ``col_count``, ``col_names``,
        ``first_two_rows`` (list of record dicts) and ``error`` set to None.
        On any failure, a dict containing only ``error`` with a message.
    """
    try:
        cleaned = clean_data(csv_path)
        print(f"CSV file read successfully: {csv_path}")

        n_rows, n_cols = cleaned.shape
        return {
            'row_count': n_rows,
            'col_count': n_cols,
            'col_names': cleaned.columns.tolist(),
            'first_two_rows': cleaned.head(2).to_dict('records'),
            'error': None,
        }
    except Exception as exc:
        return {'error': f"Error reading CSV file: {str(exc)}"}
|
|
|
|
|
|
|
def get_image_by_file_name(file_name):
    """Return a FastAPI ``FileResponse`` that serves the file at *file_name*."""
    # NOTE(review): file_name is used as-is. If it comes from a request, confirm
    # it is restricted to the chart/image directory to avoid path traversal.
    response = FileResponse(file_name)
    return response
|
|
|
|
|
def extract_chart_filenames(response: str) -> list:
    """Return every ``chart_<lowercase hex / dashes>.png`` name found in *response*.

    Matches are returned in order of appearance; repeated names are kept.
    """
    return [found.group(0) for found in re.finditer(r'chart_[a-f0-9-]+\.png', response)]