import streamlit as st
import pandas as pd
import os
import json
import requests
from ctgan import CTGAN
from sklearn.preprocessing import LabelEncoder
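
# Assumed runtime dependencies (not pinned anywhere in this file): streamlit, pandas,
# requests, ctgan, and scikit-learn, e.g. `pip install streamlit pandas requests ctgan scikit-learn`.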

def generate_schema(prompt):
    """Fetches schema from the Hugging Face Spaces API."""
    API_URL = "https://infinitymatter-Synthetic_Data_Generator_SRIJAN.hf.space/run/predict"

    # Fetch API token securely
    hf_token = st.secrets["hf_token"]
    headers = {"Authorization": f"Bearer {hf_token}"}
    payload = {"data": [prompt]}

    try:
        # Timeout so a hung request doesn't block the Streamlit UI indefinitely
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        schema = response.json()

        if 'columns' not in schema or 'types' not in schema or 'size' not in schema:
            raise ValueError("Invalid schema format!")

        return schema
    except requests.exceptions.RequestException as e:
        st.error(f"❌ API request failed: {e}")
        return None
    except json.JSONDecodeError:
        st.error("❌ Failed to parse JSON response.")
        return None
    except ValueError as e:
        # Malformed schema: report it and return None so the caller can handle it
        st.error(f"❌ {e}")
        return None
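
# Hypothetical example of the schema the API is expected to return
# (only the three keys checked above are assumed; values are illustrative):
#   {"columns": ["age", "department"], "types": ["int", "string"], "size": 500}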

def train_and_generate_synthetic(real_data, schema, output_path):
    """Trains a CTGAN model and generates synthetic data."""
    categorical_cols = [col for col, dtype in zip(schema['columns'], schema['types']) if dtype == 'string']

    # Store label encoders
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        real_data[col] = le.fit_transform(real_data[col])
        label_encoders[col] = le

    # Train CTGAN
    gan = CTGAN(epochs=300)
    gan.fit(real_data, categorical_cols)

    # Generate synthetic data
    synthetic_data = gan.sample(schema['size'])

    # Decode categorical columns
    for col in categorical_cols:
        synthetic_data[col] = label_encoders[col].inverse_transform(synthetic_data[col])

    # Save to CSV
    os.makedirs('outputs', exist_ok=True)
    synthetic_data.to_csv(output_path, index=False)
    st.success(f"✅ Synthetic data saved to {output_path}")

def fetch_data(domain):
    """Fetches real data for the given domain and ensures it's a valid DataFrame."""
    data_path = f"datasets/{domain}.csv"

    if os.path.exists(data_path):
        df = pd.read_csv(data_path)
        if not isinstance(df, pd.DataFrame) or df.empty:
            raise ValueError("❌ Loaded data is invalid!")
        return df
    else:
        st.error(f"❌ Dataset for {domain} not found.")
        return None
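
# Note: fetch_data assumes real data ships with the app as local CSVs, one per domain,
# e.g. datasets/healthcare.csv, datasets/finance.csv, datasets/retail.csv, datasets/other.csv.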

st.title("✨ AI-Powered Synthetic Dataset Generator")
st.write("Give a short description of the dataset you need, and AI will generate it for you using real data + GANs!")

# User input
user_prompt = st.text_input("Describe the dataset (e.g., 'Create dataset for hospital patients')", "")
domain = st.selectbox("Select Domain for Real Data", ["healthcare", "finance", "retail", "other"])

# Persist schema and data across Streamlit reruns; plain local variables are lost
# every time the script re-executes, so the second button would never see them.
if "schema" not in st.session_state:
    st.session_state.schema = None
if "data" not in st.session_state:
    st.session_state.data = None

if st.button("Generate Schema"):
    if user_prompt.strip():
        with st.spinner("Generating schema..."):
            st.session_state.schema = generate_schema(user_prompt)
        if st.session_state.schema is None:
            st.error("❌ Schema generation failed. Please check the API response.")
        else:
            st.success("✅ Schema generated successfully!")
            st.json(st.session_state.schema)
            st.session_state.data = fetch_data(domain)
    else:
        st.warning("⚠️ Please enter a dataset description before generating the schema.")

if st.session_state.data is not None and st.session_state.schema is not None:
    output_path = "outputs/synthetic_data.csv"
    if st.button("Generate Synthetic Data"):
        with st.spinner("Training GAN and generating synthetic data..."):
            train_and_generate_synthetic(st.session_state.data, st.session_state.schema, output_path)
        with open(output_path, "rb") as file:
            st.download_button("Download Synthetic Data", file, file_name="synthetic_data.csv", mime="text/csv")
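
# Minimal way to run the app locally (assumption; a Hugging Face Space launches it automatically):
#     streamlit run app.py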