yakine commited on
Commit
20981ee
·
verified ·
1 Parent(s): 7758368

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -25
app.py CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
4
  import pandas as pd
5
  import os
6
  import requests
7
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
8
  from io import StringIO
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from huggingface_hub import HfFolder
@@ -95,51 +95,36 @@ def generate_synthetic_data(description, columns):
95
 
96
  return response_data[0]["generated_text"]
97
 
98
- def extract_valid_csv(csv_data, expected_columns):
99
- lines = csv_data.split('\n')
100
- valid_lines = []
101
- header_found = False
102
-
103
- for line in lines:
104
- if header_found:
105
- if line.strip() == '':
106
- continue
107
- valid_lines.append(line)
108
- elif set(line.split(',')) == set(expected_columns):
109
- header_found = True
110
- valid_lines.append(line)
111
-
112
- valid_csv_data = '\n'.join(valid_lines)
113
- return valid_csv_data
114
-
115
  def process_generated_data(csv_data, expected_columns):
116
  try:
117
- valid_csv_data = extract_valid_csv(csv_data, expected_columns)
118
- data = StringIO(valid_csv_data)
 
119
 
 
120
  df = pd.read_csv(data, delimiter=',')
121
 
 
122
  if set(df.columns) != set(expected_columns):
123
  return f"Unexpected columns in the generated data: {df.columns}"
124
 
125
  return df
126
  except pd.errors.ParserError as e:
127
- logging.error(f"Failed to parse CSV data: {e}")
128
  return f"Failed to parse CSV data: {e}"
129
 
130
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
131
  csv_data_all = ""
132
 
133
- for _ in range(num_rows // rows_per_generation):
134
  generated_data = generate_synthetic_data(description, columns)
135
  if "Error" in generated_data:
136
- return generated_data
137
 
138
  df_synthetic = process_generated_data(generated_data, columns)
139
  if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
140
  csv_data_all += df_synthetic.to_csv(index=False, header=False)
141
  else:
142
- logging.info("Skipping invalid generation.")
143
 
144
  if csv_data_all:
145
  return csv_data_all
@@ -155,6 +140,7 @@ def generate_data(request: DataGenerationRequest):
155
  if isinstance(generated_data, str) and "Error" in generated_data:
156
  return JSONResponse(content={"error": generated_data}, status_code=500)
157
 
 
158
  csv_buffer = StringIO(generated_data)
159
  return StreamingResponse(
160
  csv_buffer,
@@ -162,7 +148,6 @@ def generate_data(request: DataGenerationRequest):
162
  headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
163
  )
164
 
165
-
166
  @app.get("/")
167
  def greet_json():
168
  return {"Hello": "World!"}
 
4
  import pandas as pd
5
  import os
6
  import requests
7
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
8
  from io import StringIO
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from huggingface_hub import HfFolder
 
95
 
96
  return response_data[0]["generated_text"]
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def process_generated_data(csv_data, expected_columns):
99
  try:
100
+ # Ensure the data is cleaned and correctly formatted
101
+ cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
102
+ data = StringIO(cleaned_data)
103
 
104
+ # Read the CSV data
105
  df = pd.read_csv(data, delimiter=',')
106
 
107
+ # Check if the DataFrame has the expected columns
108
  if set(df.columns) != set(expected_columns):
109
  return f"Unexpected columns in the generated data: {df.columns}"
110
 
111
  return df
112
  except pd.errors.ParserError as e:
 
113
  return f"Failed to parse CSV data: {e}"
114
 
115
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
116
  csv_data_all = ""
117
 
118
+ for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
119
  generated_data = generate_synthetic_data(description, columns)
120
  if "Error" in generated_data:
121
+ return generated_data # Return the error message
122
 
123
  df_synthetic = process_generated_data(generated_data, columns)
124
  if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
125
  csv_data_all += df_synthetic.to_csv(index=False, header=False)
126
  else:
127
+ print("Skipping invalid generation.")
128
 
129
  if csv_data_all:
130
  return csv_data_all
 
140
  if isinstance(generated_data, str) and "Error" in generated_data:
141
  return JSONResponse(content={"error": generated_data}, status_code=500)
142
 
143
+ # Create a streaming response to return the CSV data
144
  csv_buffer = StringIO(generated_data)
145
  return StreamingResponse(
146
  csv_buffer,
 
148
  headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
149
  )
150
 
 
151
  @app.get("/")
152
  def greet_json():
153
  return {"Hello": "World!"}