|
import streamlit as st
import pandas as pd
import os
import json
import requests
from ctgan import CTGAN
from sklearn.preprocessing import LabelEncoder
|
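# Illustrative (assumed) shape of the schema returned by the Space; only the
# 'columns', 'types', and 'size' keys are actually validated in generate_schema:
#   {"columns": ["age", "gender", "diagnosis"],
#    "types": ["int", "string", "string"],
#    "size": 1000}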
def generate_schema(prompt):
    """Fetches the dataset schema from the Hugging Face Spaces API."""
    API_URL = "https://infinitymatter-Synthetic_Data_Generator_SRIJAN.hf.space/run/predict"

    hf_token = st.secrets["hf_token"]
    headers = {"Authorization": f"Bearer {hf_token}"}

    payload = {"data": [prompt]}

    try:
        response = requests.post(API_URL, headers=headers, json=payload)
        response.raise_for_status()
        schema = response.json()

        # The app expects the Space to return 'columns', 'types', and 'size'.
        if 'columns' not in schema or 'types' not in schema or 'size' not in schema:
            raise ValueError("Invalid schema format!")

        return schema
    except requests.exceptions.RequestException as e:
        st.error(f"❌ API request failed: {e}")
        return None
    except json.JSONDecodeError:
        st.error("❌ Failed to parse JSON response.")
        return None
    except ValueError as e:
        st.error(f"❌ {e}")
        return None
|
|
def train_and_generate_synthetic(real_data, schema, output_path):
    """Trains a CTGAN model and generates synthetic data."""
    real_data = real_data.copy()  # work on a copy so the caller's DataFrame isn't mutated
    categorical_cols = [col for col, dtype in zip(schema['columns'], schema['types']) if dtype == 'string']

    # Label-encode categorical columns so CTGAN can treat them as discrete values.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        real_data[col] = le.fit_transform(real_data[col])
        label_encoders[col] = le

    # Train CTGAN, telling it which columns are discrete.
    gan = CTGAN(epochs=300)
    gan.fit(real_data, categorical_cols)

    # Sample as many synthetic rows as the schema requests.
    synthetic_data = gan.sample(schema['size'])

    # Map encoded categories back to their original labels.
    for col in categorical_cols:
        synthetic_data[col] = label_encoders[col].inverse_transform(synthetic_data[col])

    os.makedirs('outputs', exist_ok=True)
    synthetic_data.to_csv(output_path, index=False)
    st.success(f"✅ Synthetic data saved to {output_path}")
|
|
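# Real seed data is expected on disk as datasets/<domain>.csv,
# e.g. datasets/healthcare.csv for the "healthcare" domain.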
def fetch_data(domain):
    """Fetches real data for the given domain and ensures it's a valid DataFrame."""
    data_path = f"datasets/{domain}.csv"
    if os.path.exists(data_path):
        df = pd.read_csv(data_path)
        if not isinstance(df, pd.DataFrame) or df.empty:
            raise ValueError("❌ Loaded data is invalid!")
        return df
    else:
        st.error(f"❌ Dataset for {domain} not found.")
        return None
|
|
st.title("✨ AI-Powered Synthetic Dataset Generator")
st.write("Give a short description of the dataset you need, and AI will generate it for you using real data + GANs!")

user_prompt = st.text_input("Describe the dataset (e.g., 'Create dataset for hospital patients')", "")
domain = st.selectbox("Select Domain for Real Data", ["healthcare", "finance", "retail", "other"])
|
# Persist schema and data in session state so they survive Streamlit reruns
# and remain available when the "Generate Synthetic Data" button is clicked.
if "schema" not in st.session_state:
    st.session_state.schema = None
if "data" not in st.session_state:
    st.session_state.data = None

if st.button("Generate Schema"):
    if user_prompt.strip():
        with st.spinner("Generating schema..."):
            schema = generate_schema(user_prompt)

        if schema is None:
            st.error("❌ Schema generation failed. Please check API response.")
        else:
            st.success("✅ Schema generated successfully!")
            st.json(schema)
            st.session_state.schema = schema
            st.session_state.data = fetch_data(domain)
    else:
        st.warning("⚠️ Please enter a dataset description before generating the schema.")

if st.session_state.data is not None and st.session_state.schema is not None:
    output_path = "outputs/synthetic_data.csv"
    if st.button("Generate Synthetic Data"):
        with st.spinner("Training GAN and generating synthetic data..."):
            train_and_generate_synthetic(st.session_state.data, st.session_state.schema, output_path)
        with open(output_path, "rb") as file:
            st.download_button("Download Synthetic Data", file, file_name="synthetic_data.csv", mime="text/csv")
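# To launch the UI (assuming this script is saved as app.py):
#   streamlit run app.py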