infinitymatter committed on
Commit
1738d47
·
verified ·
1 Parent(s): 130012e

Create synthetic_generator.py

Browse files
Files changed (1) hide show
  1. synthetic_generator.py +70 -0
synthetic_generator.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from ctgan import CTGAN
3
+ from sklearn.preprocessing import LabelEncoder
4
+ import os
5
+ import json
6
+ import requests
7
+ import streamlit as st
8
+
9
def train_and_generate_synthetic(real_data, schema, output_path):
    """Train a CTGAN model on ``real_data`` and write synthetic rows to CSV.

    Columns whose schema type is ``'string'`` are label-encoded before
    training and decoded back to their original values after sampling.

    Args:
        real_data: DataFrame of real rows to train on. It is left unmodified;
            the label encoding is applied to a copy.
        schema: Mapping with ``'columns'`` and ``'types'`` (read in lockstep)
            and ``'size'`` (the number of synthetic rows to sample).
        output_path: Destination of the CSV file. When it is a filesystem
            path, its parent directory is created if missing.

    Returns:
        None. The synthetic data is written to ``output_path``.

    Raises:
        KeyError: If ``schema`` lacks ``'columns'``, ``'types'`` or ``'size'``.
    """
    categorical_cols = [
        col
        for col, dtype in zip(schema['columns'], schema['types'])
        if dtype == 'string'
    ]

    # Encode on a copy: overwriting columns of the caller's DataFrame would
    # silently replace its string values with integer codes.
    training_data = real_data.copy()

    # Keep one fitted encoder per column so sampled codes can be decoded later.
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        training_data[col] = le.fit_transform(training_data[col])
        label_encoders[col] = le

    # Train CTGAN, telling it which columns are discrete.
    gan = CTGAN(epochs=300)
    gan.fit(training_data, categorical_cols)

    # Generate synthetic data.
    synthetic_data = gan.sample(schema['size'])

    # Map the sampled integer codes back to the original category values.
    for col in categorical_cols:
        synthetic_data[col] = label_encoders[col].inverse_transform(synthetic_data[col])

    # Create the directory the file is actually written to (previously a
    # hard-coded 'outputs' directory was created regardless of output_path).
    # Non-path targets such as open file objects are passed straight through.
    if isinstance(output_path, (str, os.PathLike)):
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    synthetic_data.to_csv(output_path, index=False)
    # NOTE(review): the leading characters of this message look like a
    # mis-encoded emoji; left byte-for-byte because it is runtime output.
    print(f"βœ… Synthetic data saved to {output_path}")
35
+
36
def generate_schema(prompt):
    """Request a dataset schema for ``prompt`` from the remote API and validate it.

    Args:
        prompt: Text sent to the API as the ``"prompt"`` field of the JSON body.

    Returns:
        The decoded schema dict, containing at least the keys ``'columns'``,
        ``'types'`` and ``'size'``; or ``None`` when the HTTP request failed
        (including a timeout) or the response body was not valid JSON.

    Raises:
        ValueError: If the response is valid JSON but is not an object with
            the keys ``'columns'``, ``'types'`` and ``'size'``.
    """
    API_URL = "https://infinitymatter-synthetic-data-generator-srijan.hf.space/run/predict"
    headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
    # (connect, read) seconds. Without a timeout, requests waits indefinitely
    # on an unresponsive server.
    timeout = (10, 120)

    try:
        response = requests.post(
            API_URL,
            json={"prompt": prompt},
            headers=headers,
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        print(f"❌ API request failed: {e}")
        return None

    # NOTE(review): the leading characters of the two debug messages below look
    # like mis-encoded emoji; left byte-for-byte because they are runtime output.
    print("πŸ” Raw API Response:", response.text)  # Debugging line

    try:
        schema = response.json()
    except ValueError:
        # Every JSON decode error response.json() can raise (json, simplejson,
        # or requests' own wrapper) is a ValueError subclass.
        print("❌ Failed to parse JSON response. API might be down or returning non-JSON data.")
        return None

    # Validate required keys. The dict check comes first so that a non-object
    # body (e.g. JSON null) is reported as an invalid schema, not a TypeError.
    required_keys = ('columns', 'types', 'size')
    if not isinstance(schema, dict) or any(key not in schema for key in required_keys):
        raise ValueError("❌ Invalid schema format! Expected keys: 'columns', 'types', 'size'")

    print("βœ… Valid Schema Received:", schema)  # Debugging line
    return schema
60
+
61
def fetch_data(domain):
    """Load the real dataset for ``domain`` from ``datasets/<domain>.csv``.

    Args:
        domain: Dataset name, interpolated into the CSV path.

    Returns:
        The non-empty DataFrame parsed from the CSV file.

    Raises:
        FileNotFoundError: If ``datasets/<domain>.csv`` does not exist.
        ValueError: If the file parses to an empty DataFrame.
    """
    # NOTE(review): ``domain`` goes into the path unchecked; if it can come
    # from user input, restrict it to known dataset names before calling.
    data_path = f"datasets/{domain}.csv"

    # Attempt the read directly instead of checking existence first, so the
    # file cannot disappear between the check and the open.
    try:
        df = pd.read_csv(data_path)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"❌ Dataset for {domain} not found.") from err

    # read_csv always returns a DataFrame here, so only emptiness is checked.
    if df.empty:
        raise ValueError("❌ Loaded data is invalid!")
    return df