infinitymatter commited on
Commit
8d4657a
Β·
verified Β·
1 Parent(s): f0edb7b

Update synthetic_generator.py

Browse files
Files changed (1) hide show
  1. synthetic_generator.py +69 -69
synthetic_generator.py CHANGED
@@ -1,69 +1,69 @@
1
- import pandas as pd
2
- from ctgan import CTGAN
3
- from sklearn.preprocessing import LabelEncoder
4
- import os
5
- import json
6
- import requests
7
-
8
- def train_and_generate_synthetic(real_data, schema, output_path):
9
- """Trains a CTGAN model and generates synthetic data."""
10
- categorical_cols = [col for col, dtype in zip(schema['columns'], schema['types']) if dtype == 'string']
11
-
12
- # Store label encoders
13
- label_encoders = {}
14
- for col in categorical_cols:
15
- le = LabelEncoder()
16
- real_data[col] = le.fit_transform(real_data[col])
17
- label_encoders[col] = le
18
-
19
- # Train CTGAN
20
- gan = CTGAN(epochs=300)
21
- gan.fit(real_data, categorical_cols)
22
-
23
- # Generate synthetic data
24
- synthetic_data = gan.sample(schema['size'])
25
-
26
- # Decode categorical columns
27
- for col in categorical_cols:
28
- synthetic_data[col] = label_encoders[col].inverse_transform(synthetic_data[col])
29
-
30
- # Save to CSV
31
- os.makedirs('outputs', exist_ok=True)
32
- synthetic_data.to_csv(output_path, index=False)
33
- print(f"βœ… Synthetic data saved to {output_path}")
34
-
35
- def generate_schema(prompt):
36
- """Fetches schema from an external API and validates JSON."""
37
- API_URL = "https://api.example.com/schema" # Replace with correct API URL
38
- headers = {"Authorization": f"Bearer YOUR_HUGGINGFACE_TOKEN"} # Add if needed
39
-
40
- try:
41
- response = requests.post(API_URL, json={"prompt": prompt}, headers=headers)
42
- print("πŸ” Raw API Response:", response.text) # Debugging line
43
-
44
- schema = response.json()
45
-
46
- # Validate required keys
47
- if 'columns' not in schema or 'types' not in schema or 'size' not in schema:
48
- raise ValueError("❌ Invalid schema format! Expected keys: 'columns', 'types', 'size'")
49
-
50
- print("βœ… Valid Schema Received:", schema) # Debugging line
51
- return schema
52
-
53
- except json.JSONDecodeError:
54
- print("❌ Failed to parse JSON response. API might be down or returning non-JSON data.")
55
- return None
56
- except requests.exceptions.RequestException as e:
57
- print(f"❌ API request failed: {e}")
58
- return None
59
-
60
- def fetch_data(domain):
61
- """Fetches real data for the given domain and ensures it's a valid DataFrame."""
62
- data_path = f"datasets/{domain}.csv"
63
- if os.path.exists(data_path):
64
- df = pd.read_csv(data_path)
65
- if not isinstance(df, pd.DataFrame) or df.empty:
66
- raise ValueError("❌ Loaded data is invalid!")
67
- return df
68
- else:
69
- raise FileNotFoundError(f"❌ Dataset for {domain} not found.")
 
1
+ import pandas as pd
2
+ from ctgan import CTGAN
3
+ from sklearn.preprocessing import LabelEncoder
4
+ import os
5
+ import json
6
+ import requests
7
+
8
+ def train_and_generate_synthetic(real_data, schema, output_path):
9
+ """Trains a CTGAN model and generates synthetic data."""
10
+ categorical_cols = [col for col, dtype in zip(schema['columns'], schema['types']) if dtype == 'string']
11
+
12
+ # Store label encoders
13
+ label_encoders = {}
14
+ for col in categorical_cols:
15
+ le = LabelEncoder()
16
+ real_data[col] = le.fit_transform(real_data[col])
17
+ label_encoders[col] = le
18
+
19
+ # Train CTGAN
20
+ gan = CTGAN(epochs=300)
21
+ gan.fit(real_data, categorical_cols)
22
+
23
+ # Generate synthetic data
24
+ synthetic_data = gan.sample(schema['size'])
25
+
26
+ # Decode categorical columns
27
+ for col in categorical_cols:
28
+ synthetic_data[col] = label_encoders[col].inverse_transform(synthetic_data[col])
29
+
30
+ # Save to CSV
31
+ os.makedirs('outputs', exist_ok=True)
32
+ synthetic_data.to_csv(output_path, index=False)
33
+ print(f"βœ… Synthetic data saved to {output_path}")
34
+
35
+ def generate_schema(prompt):
36
+ """Fetches schema from an external API and validates JSON."""
37
+ API_URL = "https://infinitymatter-synthetic-data-generator-srijan.hf.space/run/predict"
38
+ headers = {"Authorization": f"Bearer hf_token"} # Add if needed
39
+
40
+ try:
41
+ response = requests.post(API_URL, json={"prompt": prompt}, headers=headers)
42
+ print("πŸ” Raw API Response:", response.text) # Debugging line
43
+
44
+ schema = response.json()
45
+
46
+ # Validate required keys
47
+ if 'columns' not in schema or 'types' not in schema or 'size' not in schema:
48
+ raise ValueError("❌ Invalid schema format! Expected keys: 'columns', 'types', 'size'")
49
+
50
+ print("βœ… Valid Schema Received:", schema) # Debugging line
51
+ return schema
52
+
53
+ except json.JSONDecodeError:
54
+ print("❌ Failed to parse JSON response. API might be down or returning non-JSON data.")
55
+ return None
56
+ except requests.exceptions.RequestException as e:
57
+ print(f"❌ API request failed: {e}")
58
+ return None
59
+
60
+ def fetch_data(domain):
61
+ """Fetches real data for the given domain and ensures it's a valid DataFrame."""
62
+ data_path = f"datasets/{domain}.csv"
63
+ if os.path.exists(data_path):
64
+ df = pd.read_csv(data_path)
65
+ if not isinstance(df, pd.DataFrame) or df.empty:
66
+ raise ValueError("❌ Loaded data is invalid!")
67
+ return df
68
+ else:
69
+ raise FileNotFoundError(f"❌ Dataset for {domain} not found.")