# FastApi / csv_service.py
import re
from fastapi.responses import FileResponse
import numpy as np
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
def generate_csv_data(csv_url):
    """Load the CSV at *csv_url* and return its rows as a list of dicts.

    Null cells are replaced with empty strings so the records are safe to
    JSON-serialize. On any failure the error is printed and a dict of the
    form ``{"error": message}`` is returned instead of raising.
    """
    try:
        frame = pd.read_csv(csv_url)
        # Swap every null cell for '' so downstream serialization never sees NaN.
        frame = frame.where(pd.notnull(frame), '')
        return frame.to_dict(orient='records')
    except Exception as e:
        print(f"Error occurred while reading CSV: {e}")
        return {"error": str(e)}
def clean_data(csv_url):
    """Read the CSV at *csv_url* and return a cleaned DataFrame.

    Cleaning steps, in order:
      1. drop duplicate rows
      2. strip surrounding whitespace from object (string) columns
      3. replace +/-inf with NaN
      4. fill remaining NaN with a per-dtype default ('' / 0 / 0.0 / False /
         Timedelta(0) / 0j / first category)
      5. drop constant columns (<= 1 unique non-null value)

    Parameters:
        csv_url (str): Path or URL readable by ``pandas.read_csv``.

    Returns:
        pandas.DataFrame: the cleaned frame.

    Raises:
        Exception: whatever pandas raises while reading or cleaning.
    """
    data = pd.read_csv(csv_url)
    # Defensive check kept from the original; read_csv always returns a
    # DataFrame, so this should never trigger.
    if not isinstance(data, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame.")
    try:
        # Remove duplicate rows
        data = data.drop_duplicates()
        # Strip whitespace from string columns.
        # NOTE(review): .str.strip() turns non-string objects in a mixed
        # column into NaN (original behavior, preserved here).
        for column in data.select_dtypes(include=['object']).columns:
            data[column] = data[column].str.strip()
        # Replace infinite values with NaN so they get a default fill below
        data.replace([np.inf, -np.inf], np.nan, inplace=True)
        # Per-dtype NaN replacement values
        fill_defaults = {
            'float64': 0.0,
            'int64': 0,          # int64 columns cannot actually hold NaN; harmless
            'bool': False,
            'timedelta64[ns]': pd.Timedelta(0),
            'complex128': complex(0, 0),
        }
        for column in data.columns:
            dtype = data[column].dtype
            if dtype == 'object':
                data[column] = data[column].fillna('')
            elif dtype.name == 'category':
                categories = data[column].cat.categories
                # BUGFIX: the original called fillna(None) when there were no
                # categories, which raises ValueError in pandas. Fill only
                # when a category exists.
                if len(categories) > 0:
                    data[column] = data[column].fillna(categories[0])
            elif dtype.name in fill_defaults:
                data[column] = data[column].fillna(fill_defaults[dtype.name])
            # datetime64[ns] and any other dtype: leave missing values as-is.
            # BUGFIX: the original fell through to fillna(None) here, which
            # raises ValueError ("Must specify a fill 'value' or 'method'");
            # filling datetimes with pd.NaT was a no-op anyway.
        # Remove constant columns (columns with only one unique value)
        constant_columns = [col for col in data.columns if data[col].nunique() <= 1]
        data = data.drop(columns=constant_columns)
        return data
    except Exception as e:
        raise e
def get_csv_basic_info(csv_path):
    """Return a summary dict for the CSV file at *csv_path*.

    The file is first passed through ``clean_data``. On success the dict
    contains 'row_count', 'col_count', 'col_names', 'first_two_rows' and
    'error' (set to None); on failure only an 'error' message is returned.

    Parameters:
        csv_path (str): Path to the CSV file.

    Returns:
        dict: basic file information, or an error message.
    """
    try:
        df = clean_data(csv_path)
        print(f"CSV file read successfully: {csv_path}")
        row_count, col_count = df.shape
        return {
            'row_count': row_count,
            'col_count': col_count,
            'col_names': list(df.columns),
            'first_two_rows': df.head(2).to_dict('records'),
            'error': None,
        }
    except Exception as e:
        return {'error': f"Error reading CSV file: {str(e)}"}
def get_image_by_file_name(file_name):
    """Serve the file at *file_name* back to the client as a FileResponse."""
    response = FileResponse(file_name)
    return response
def extract_chart_filenames(response: str) -> list:
    """Return every chart image filename found in *response*.

    Matches names of the form ``chart_<hex/dash chars>.png`` (e.g. a
    UUID-suffixed chart file), in order of appearance.
    """
    chart_pattern = re.compile(r'chart_[a-f0-9-]+\.png')
    return chart_pattern.findall(response)