Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|
4 |
import pandas as pd
|
5 |
import os
|
6 |
import requests
|
7 |
-
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
|
8 |
from io import StringIO
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from huggingface_hub import HfFolder
|
@@ -95,51 +95,36 @@ def generate_synthetic_data(description, columns):
|
|
95 |
|
96 |
return response_data[0]["generated_text"]
|
97 |
|
98 |
-
def extract_valid_csv(csv_data, expected_columns):
    """Return the CSV portion of ``csv_data`` starting at the header row.

    Everything before the first line whose comma-separated cells equal
    ``expected_columns`` (compared as sets, so order is ignored) is
    discarded. After the header, blank lines are dropped and every other
    line is kept verbatim.

    Args:
        csv_data: Raw text that may contain non-CSV lines before the header.
        expected_columns: Column names that identify the header line.

    Returns:
        The header and the following non-blank lines joined with ``"\\n"``,
        or an empty string when no header line is found.
    """
    # Normalise CRLF / CR line endings first; otherwise a header such as
    # "a,b\r" never equals the expected column set and all rows are lost.
    normalised = csv_data.replace('\r\n', '\n').replace('\r', '\n')
    wanted = set(expected_columns)

    valid_lines = []
    header_found = False

    for line in normalised.split('\n'):
        if header_found:
            # Skip blank / whitespace-only lines after the header.
            if line.strip():
                valid_lines.append(line)
        elif set(line.split(',')) == wanted:
            header_found = True
            valid_lines.append(line)

    return '\n'.join(valid_lines)
|
114 |
-
|
115 |
def process_generated_data(csv_data, expected_columns):
|
116 |
try:
|
117 |
-
|
118 |
-
|
|
|
119 |
|
|
|
120 |
df = pd.read_csv(data, delimiter=',')
|
121 |
|
|
|
122 |
if set(df.columns) != set(expected_columns):
|
123 |
return f"Unexpected columns in the generated data: {df.columns}"
|
124 |
|
125 |
return df
|
126 |
except pd.errors.ParserError as e:
|
127 |
-
logging.error(f"Failed to parse CSV data: {e}")
|
128 |
return f"Failed to parse CSV data: {e}"
|
129 |
|
130 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
131 |
csv_data_all = ""
|
132 |
|
133 |
-
for _ in range(num_rows // rows_per_generation):
|
134 |
generated_data = generate_synthetic_data(description, columns)
|
135 |
if "Error" in generated_data:
|
136 |
-
return generated_data
|
137 |
|
138 |
df_synthetic = process_generated_data(generated_data, columns)
|
139 |
if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
|
140 |
csv_data_all += df_synthetic.to_csv(index=False, header=False)
|
141 |
else:
|
142 |
-
|
143 |
|
144 |
if csv_data_all:
|
145 |
return csv_data_all
|
@@ -155,6 +140,7 @@ def generate_data(request: DataGenerationRequest):
|
|
155 |
if isinstance(generated_data, str) and "Error" in generated_data:
|
156 |
return JSONResponse(content={"error": generated_data}, status_code=500)
|
157 |
|
|
|
158 |
csv_buffer = StringIO(generated_data)
|
159 |
return StreamingResponse(
|
160 |
csv_buffer,
|
@@ -162,7 +148,6 @@ def generate_data(request: DataGenerationRequest):
|
|
162 |
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
163 |
)
|
164 |
|
165 |
-
|
166 |
@app.get("/")
def greet_json():
    """Return a fixed greeting mapping for the route registered at "/"."""
    payload = {"Hello": "World!"}
    return payload
|
|
|
4 |
import pandas as pd
|
5 |
import os
|
6 |
import requests
|
7 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
|
8 |
from io import StringIO
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from huggingface_hub import HfFolder
|
|
|
95 |
|
96 |
return response_data[0]["generated_text"]
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
def process_generated_data(csv_data, expected_columns):
    """Parse generated CSV text and check that it has the expected columns.

    Args:
        csv_data: CSV text with the header on the first row.
        expected_columns: Column names the header must contain. The check
            compares sets, so column order is ignored.

    Returns:
        The parsed ``pandas.DataFrame`` on success. Otherwise a string
        describing why the data was rejected: a parse failure, empty input,
        or a header that does not match ``expected_columns``.
    """
    # Normalise CRLF / CR line endings so every row ends in "\n".
    cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')

    try:
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError is not a ParserError subclass; without catching it,
        # an empty generation would raise instead of returning a message.
        return f"Failed to parse CSV data: {e}"

    # Reject the data when the header does not match what was requested.
    if set(df.columns) != set(expected_columns):
        return f"Unexpected columns in the generated data: {df.columns}"

    return df
|
114 |
|
115 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
116 |
csv_data_all = ""
|
117 |
|
118 |
+
for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
|
119 |
generated_data = generate_synthetic_data(description, columns)
|
120 |
if "Error" in generated_data:
|
121 |
+
return generated_data # Return the error message
|
122 |
|
123 |
df_synthetic = process_generated_data(generated_data, columns)
|
124 |
if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
|
125 |
csv_data_all += df_synthetic.to_csv(index=False, header=False)
|
126 |
else:
|
127 |
+
print("Skipping invalid generation.")
|
128 |
|
129 |
if csv_data_all:
|
130 |
return csv_data_all
|
|
|
140 |
if isinstance(generated_data, str) and "Error" in generated_data:
|
141 |
return JSONResponse(content={"error": generated_data}, status_code=500)
|
142 |
|
143 |
+
# Create a streaming response to return the CSV data
|
144 |
csv_buffer = StringIO(generated_data)
|
145 |
return StreamingResponse(
|
146 |
csv_buffer,
|
|
|
148 |
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
149 |
)
|
150 |
|
|
|
151 |
@app.get("/")
def greet_json():
    """Return a fixed greeting mapping for the route registered at "/"."""
    payload = {"Hello": "World!"}
    return payload
|