yakine commited on
Commit
7758368
·
verified ·
1 Parent(s): 60a984b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -96,12 +96,10 @@ def generate_synthetic_data(description, columns):
96
  return response_data[0]["generated_text"]
97
 
98
  def extract_valid_csv(csv_data, expected_columns):
99
- # Attempt to extract valid CSV data from the response
100
  lines = csv_data.split('\n')
101
  valid_lines = []
102
-
103
- # Find the start of the valid CSV data by looking for the column headers
104
  header_found = False
 
105
  for line in lines:
106
  if header_found:
107
  if line.strip() == '':
@@ -111,7 +109,6 @@ def extract_valid_csv(csv_data, expected_columns):
111
  header_found = True
112
  valid_lines.append(line)
113
 
114
- # Combine valid lines into a single string
115
  valid_csv_data = '\n'.join(valid_lines)
116
  return valid_csv_data
117
 
@@ -120,30 +117,29 @@ def process_generated_data(csv_data, expected_columns):
120
  valid_csv_data = extract_valid_csv(csv_data, expected_columns)
121
  data = StringIO(valid_csv_data)
122
 
123
- # Read the CSV data
124
  df = pd.read_csv(data, delimiter=',')
125
 
126
- # Check if the DataFrame has the expected columns
127
  if set(df.columns) != set(expected_columns):
128
  return f"Unexpected columns in the generated data: {df.columns}"
129
 
130
  return df
131
  except pd.errors.ParserError as e:
 
132
  return f"Failed to parse CSV data: {e}"
133
 
134
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
135
  csv_data_all = ""
136
 
137
- for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
138
  generated_data = generate_synthetic_data(description, columns)
139
  if "Error" in generated_data:
140
- return generated_data # Return the error message
141
 
142
  df_synthetic = process_generated_data(generated_data, columns)
143
  if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
144
  csv_data_all += df_synthetic.to_csv(index=False, header=False)
145
  else:
146
- print("Skipping invalid generation.")
147
 
148
  if csv_data_all:
149
  return csv_data_all
@@ -159,7 +155,6 @@ def generate_data(request: DataGenerationRequest):
159
  if isinstance(generated_data, str) and "Error" in generated_data:
160
  return JSONResponse(content={"error": generated_data}, status_code=500)
161
 
162
- # Create a streaming response to return the CSV data
163
  csv_buffer = StringIO(generated_data)
164
  return StreamingResponse(
165
  csv_buffer,
 
96
  return response_data[0]["generated_text"]
97
 
98
  def extract_valid_csv(csv_data, expected_columns):
 
99
  lines = csv_data.split('\n')
100
  valid_lines = []
 
 
101
  header_found = False
102
+
103
  for line in lines:
104
  if header_found:
105
  if line.strip() == '':
 
109
  header_found = True
110
  valid_lines.append(line)
111
 
 
112
  valid_csv_data = '\n'.join(valid_lines)
113
  return valid_csv_data
114
 
 
117
  valid_csv_data = extract_valid_csv(csv_data, expected_columns)
118
  data = StringIO(valid_csv_data)
119
 
 
120
  df = pd.read_csv(data, delimiter=',')
121
 
 
122
  if set(df.columns) != set(expected_columns):
123
  return f"Unexpected columns in the generated data: {df.columns}"
124
 
125
  return df
126
  except pd.errors.ParserError as e:
127
+ logging.error(f"Failed to parse CSV data: {e}")
128
  return f"Failed to parse CSV data: {e}"
129
 
130
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
131
  csv_data_all = ""
132
 
133
+ for _ in range(num_rows // rows_per_generation):
134
  generated_data = generate_synthetic_data(description, columns)
135
  if "Error" in generated_data:
136
+ return generated_data
137
 
138
  df_synthetic = process_generated_data(generated_data, columns)
139
  if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
140
  csv_data_all += df_synthetic.to_csv(index=False, header=False)
141
  else:
142
+ logging.info("Skipping invalid generation.")
143
 
144
  if csv_data_all:
145
  return csv_data_all
 
155
  if isinstance(generated_data, str) and "Error" in generated_data:
156
  return JSONResponse(content={"error": generated_data}, status_code=500)
157
 
 
158
  csv_buffer = StringIO(generated_data)
159
  return StreamingResponse(
160
  csv_buffer,