yakine committed on
Commit
abf2c45
·
verified ·
1 Parent(s): 13454c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -26
app.py CHANGED
@@ -6,12 +6,7 @@ import os
6
  import torch
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from io import StringIO
9
- from tqdm import tqdm
10
- import accelerate
11
- from accelerate import init_empty_weights, disk_offload
12
  from fastapi.middleware.cors import CORSMiddleware
13
- import re
14
-
15
 
16
  app = FastAPI()
17
 
@@ -46,7 +41,7 @@ model_llama = AutoModelForCausalLM.from_pretrained(
46
  token=hf_token
47
  )
48
 
49
- # Define your prompt template with a check for the purpose of dataset generation
50
  prompt_template = """\
51
  You are an expert in generating synthetic data for machine learning models.
52
  Your task is to generate a synthetic tabular dataset based on the description provided below.
@@ -54,7 +49,6 @@ Description: {description}
54
  The dataset should include the following columns: {columns}
55
  Please provide the data in CSV format with a minimum of 100 rows per generation.
56
  Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
57
- If the description is not related to generating synthetic datasets, please return the message: "This request is not for generating synthetic datasets."
58
  Example Description:
59
  Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
60
  Example Output:
@@ -113,24 +107,10 @@ def generate_synthetic_data(description, columns):
113
  return generated_text
114
  except Exception as e:
115
  return f"Error: {e}"
116
-
117
- def clean_generated_text(generated_text):
118
- # Extract CSV part using a regular expression
119
- csv_match = re.search(r'(\n?([A-Za-z0-9_]+,)*[A-Za-z0-9_]+\n([^\n,]*,)*[^\n,]*\n*)+', generated_text)
120
-
121
- if csv_match:
122
- csv_text = csv_match.group(0)
123
- else:
124
- raise ValueError("No valid CSV data found in generated text.")
125
-
126
- return csv_text
127
 
128
  def process_generated_data(csv_data):
129
- # Clean the generated data
130
- cleaned_data = clean_generated_text(csv_data)
131
-
132
  # Convert to DataFrame
133
- data = StringIO(cleaned_data)
134
  df = pd.read_csv(data)
135
 
136
  return df
@@ -144,10 +124,6 @@ def generate_data(request: DataGenerationRequest):
144
  if "Error" in generated_data:
145
  return JSONResponse(content={"error": generated_data}, status_code=500)
146
 
147
- # Check if the model generated a message about incorrect purpose
148
- if "This request is not for generating synthetic datasets." in generated_data:
149
- return JSONResponse(content={"message": "This request is not for generating synthetic datasets."}, status_code=400)
150
-
151
  # Process the generated CSV data into a DataFrame
152
  df_synthetic = process_generated_data(generated_data)
153
  return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
 
6
  import torch
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from io import StringIO
 
 
 
9
  from fastapi.middleware.cors import CORSMiddleware
 
 
10
 
11
  app = FastAPI()
12
 
 
41
  token=hf_token
42
  )
43
 
44
+ # Define your prompt template
45
  prompt_template = """\
46
  You are an expert in generating synthetic data for machine learning models.
47
  Your task is to generate a synthetic tabular dataset based on the description provided below.
 
49
  The dataset should include the following columns: {columns}
50
  Please provide the data in CSV format with a minimum of 100 rows per generation.
51
  Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
 
52
  Example Description:
53
  Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
54
  Example Output:
 
107
  return generated_text
108
  except Exception as e:
109
  return f"Error: {e}"
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def process_generated_data(csv_data):
 
 
 
112
  # Convert to DataFrame
113
+ data = StringIO(csv_data)
114
  df = pd.read_csv(data)
115
 
116
  return df
 
124
  if "Error" in generated_data:
125
  return JSONResponse(content={"error": generated_data}, status_code=500)
126
 
 
 
 
 
127
  # Process the generated CSV data into a DataFrame
128
  df_synthetic = process_generated_data(generated_data)
129
  return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})