Synthetic_Data_Generator / synthetic_generator.py
infinitymatter's picture
Create synthetic_generator.py
1738d47 verified
raw
history blame
2.59 kB
import pandas as pd
from ctgan import CTGAN
from sklearn.preprocessing import LabelEncoder
import os
import json
import requests
import streamlit as st
def train_and_generate_synthetic(real_data, schema, output_path):
    """Train a CTGAN model on ``real_data`` and write synthetic rows to CSV.

    Columns whose schema type is ``'string'`` are treated as categorical:
    they are label-encoded before training, passed to CTGAN as discrete
    columns, and decoded back to their original labels after sampling.

    Args:
        real_data: DataFrame of real samples. It is left unmodified; the
            label encoding is applied to a copy.
        schema: Mapping with parallel ``'columns'`` and ``'types'``
            sequences, plus a ``'size'`` entry that is passed to
            ``CTGAN.sample`` as the number of rows to generate.
        output_path: Destination of the CSV. When it is a filesystem path,
            its parent directory is created if missing.
    """
    categorical_cols = [
        col
        for col, dtype in zip(schema['columns'], schema['types'])
        if dtype == 'string'
    ]

    # Encode on a copy: overwriting columns of `real_data` itself would
    # replace the caller's string labels with integer codes.
    training_data = real_data.copy()

    # Keep each fitted encoder so sampled codes can be mapped back to labels.
    label_encoders = {}
    for col in categorical_cols:
        encoder = LabelEncoder()
        training_data[col] = encoder.fit_transform(training_data[col])
        label_encoders[col] = encoder

    # Train CTGAN
    gan = CTGAN(epochs=300)
    gan.fit(training_data, categorical_cols)

    # Generate synthetic data
    synthetic_data = gan.sample(schema['size'])

    # Decode categorical columns
    for col, encoder in label_encoders.items():
        synthetic_data[col] = encoder.inverse_transform(synthetic_data[col])

    # Create the directory the file is actually written into (rather than a
    # fixed 'outputs' folder) so any `output_path` location works. Non-path
    # targets such as file-like buffers are passed straight to `to_csv`.
    if isinstance(output_path, (str, os.PathLike)):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    # Save to CSV
    synthetic_data.to_csv(output_path, index=False)
    print(f"βœ… Synthetic data saved to {output_path}")
def generate_schema(prompt, timeout=30):
    """Request a dataset schema for ``prompt`` from the external API.

    Args:
        prompt: Text sent to the API as the ``"prompt"`` field of the JSON
            request body.
        timeout: Seconds to wait for the API before treating the request as
            failed. Without a timeout ``requests`` may block indefinitely.

    Returns:
        The decoded schema dict (containing ``'columns'``, ``'types'`` and
        ``'size'``), or ``None`` when the request fails, the server answers
        with an HTTP error status, or the body is not valid JSON.

    Raises:
        ValueError: If the response is valid JSON but is not an object with
            the keys ``'columns'``, ``'types'`` and ``'size'``.
    """
    API_URL = "https://infinitymatter-synthetic-data-generator-srijan.hf.space/run/predict"
    headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}

    try:
        response = requests.post(
            API_URL,
            json={"prompt": prompt},
            headers=headers,
            timeout=timeout,
        )
        print("πŸ” Raw API Response:", response.text)  # Debugging line
        # Report 4xx/5xx answers as request failures instead of letting an
        # error payload fall through to schema validation.
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"❌ API request failed: {e}")
        return None

    try:
        schema = response.json()
    except ValueError:
        # Every JSON decode error `requests` can raise (stdlib json,
        # simplejson, or its own wrapper) derives from ValueError.
        print("❌ Failed to parse JSON response. API might be down or returning non-JSON data.")
        return None

    # Validate required keys; a non-dict body (number, null, ...) is
    # rejected the same way as a dict with missing keys.
    required_keys = ('columns', 'types', 'size')
    if not isinstance(schema, dict) or any(key not in schema for key in required_keys):
        raise ValueError("❌ Invalid schema format! Expected keys: 'columns', 'types', 'size'")

    print("βœ… Valid Schema Received:", schema)  # Debugging line
    return schema
def fetch_data(domain):
    """Load the real dataset for ``domain`` from ``datasets/<domain>.csv``.

    Args:
        domain: Dataset name; it is interpolated into the relative path
            ``datasets/{domain}.csv``.
            NOTE(review): the value is not sanitised, so path separators in
            ``domain`` reach the filesystem unchanged - confirm callers only
            pass trusted names.

    Returns:
        A non-empty ``pandas.DataFrame`` read from the CSV file.

    Raises:
        FileNotFoundError: If no file exists at the dataset path.
        ValueError: If the file contains no data (a zero-byte file or a
            header with no rows).
    """
    data_path = f"datasets/{domain}.csv"

    # isfile() rather than exists(): a directory at this path is not a dataset.
    if not os.path.isfile(data_path):
        raise FileNotFoundError(f"❌ Dataset for {domain} not found.")

    try:
        df = pd.read_csv(data_path)
    except pd.errors.EmptyDataError as err:
        # A zero-byte file makes read_csv raise instead of returning an empty
        # frame; report it with the same error as the header-only case.
        raise ValueError("❌ Loaded data is invalid!") from err

    if df.empty:
        raise ValueError("❌ Loaded data is invalid!")
    return df