yakine committed
Commit 01c0141 · verified · 1 Parent(s): 4d35d05

Update app.py

Files changed (1)
  1. app.py +53 -59
app.py CHANGED
@@ -27,53 +27,46 @@ if not hf_token:
     raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
 
 # Load GPT-2 model and tokenizer
 tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
 model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 
 # Create a pipeline for text generation using GPT-2
 text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
 
-def preprocess_user_prompt(user_prompt):
-    # Generate a structured prompt based on the user input
-    generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
-    return generated_text
-
 # Define prompt template
 prompt_template = """\
 You are an expert in generating synthetic data for machine learning models.
+
 Your task is to generate a synthetic tabular dataset based on the description provided below.
+
 Description: {description}
+
 The dataset should include the following columns: {columns}
+
 Please provide the data in CSV format.
+
 Example Description:
 Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
+
 Example Output:
 Size,Location,Number of Bedrooms,Price
 1200,Suburban,3,250000
 900,Urban,2,200000
 1500,Rural,4,300000
 ...
+
 Description:
 {description}
 Columns:
 {columns}
 Output: """
 
-class DataGenerationRequest(BaseModel):
-    description: str
-    columns: list
-
 # Set up the Mixtral model and tokenizer
-token = hf_token # Use environment variable for the token
+token = os.getenv("HF_TOKEN")
 HfFolder.save_token(token)
 
 tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)
 
-def format_prompt(description, columns):
-    processed_description = preprocess_user_prompt(description)
-    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
-    return prompt
-
 API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 generation_params = {
@@ -84,66 +77,67 @@ generation_params = {
     "use_cache": False
 }
 
+def preprocess_user_prompt(user_prompt):
+    generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
+    return generated_text
+
+def format_prompt(description, columns):
+    processed_description = preprocess_user_prompt(description)
+    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
+    return prompt
+
 def generate_synthetic_data(description, columns):
     formatted_prompt = format_prompt(description, columns)
     payload = {"inputs": formatted_prompt, "parameters": generation_params}
-    response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
-
-    try:
-        response_data = response.json()
-    except ValueError:
-        raise HTTPException(status_code=500, detail="Failed to parse response from the API.")
-
-    if 'error' in response_data:
-        raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")
-
-    if 'generated_text' not in response_data[0]:
-        raise HTTPException(status_code=500, detail="Unexpected API response format.")
-
-    return response_data[0]["generated_text"]
+    response = requests.post(API_URL, headers={"Authorization": f"Bearer {token}"}, json=payload)
+    return response.json()[0]["generated_text"]
 
 def process_generated_data(csv_data, expected_columns):
     try:
         cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
         data = StringIO(cleaned_data)
         df = pd.read_csv(data, delimiter=',')
-
         if set(df.columns) != set(expected_columns):
-            raise ValueError("Unexpected columns in the generated data.")
-
+            print(f"Unexpected columns in the generated data: {df.columns}")
+            return None
         return df
     except pd.errors.ParserError as e:
-        raise HTTPException(status_code=500, detail=f"Failed to parse CSV data: {e}")
+        print(f"Failed to parse CSV data: {e}")
+        return None
 
 def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
-    csv_data_all = StringIO()
-
+    data_frames = []
     for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
         generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)
-
-        if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
-            df_synthetic.to_csv(csv_data_all, index=False, header=False)
+        if df_synthetic is not None and not df_synthetic.empty:
+            data_frames.append(df_synthetic)
+        else:
+            print("Skipping invalid generation.")
+    if data_frames:
+        return pd.concat(data_frames, ignore_index=True)
+    else:
+        print("No valid data frames to concatenate.")
+        return pd.DataFrame(columns=columns)
+
+@app.route('/generate', methods=['POST'])
+def generate():
+    data = request.json
+    description = data.get('description')
+    columns = data.get('columns')
+    num_rows = data.get('num_rows', 1000)
+
+    if not description or not columns:
+        return jsonify({"error": "Please provide 'description' and 'columns' in the request."}), 400
 
-    if csv_data_all.tell() > 0: # Check if there's any data in the buffer
-        csv_data_all.seek(0) # Rewind the buffer to the beginning
-        return csv_data_all
+    df_synthetic = generate_large_synthetic_data(description, columns, num_rows=num_rows)
+
+    if df_synthetic is not None and not df_synthetic.empty:
+        file_path = 'synthetic_data.csv'
+        df_synthetic.to_csv(file_path, index=False)
+        return send_file(file_path, as_attachment=True)
     else:
-        raise HTTPException(status_code=500, detail="No valid data frames generated.")
+        return jsonify({"error": "Failed to generate a valid synthetic dataset."}), 500
 
-@app.post("/generate/")
-def generate_data(request: DataGenerationRequest):
-    description = request.description.strip()
-    columns = [col.strip() for col in request.columns]
-    csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
-
-    # Return the CSV data as a downloadable file
-    return StreamingResponse(
-        csv_data,
-        media_type="text/csv",
-        headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
-    )
-
-@app.get("/")
-def greet_json():
-    return {"Hello": "World!"}
+if __name__ == "__main__":
+    app.run(host='0.0.0.0', port=8000)
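
Below is a minimal client sketch for the new /generate endpoint. It is not part of the commit; it assumes the file's import block (not shown in these hunks) was switched from FastAPI to Flask (e.g. from flask import Flask, request, jsonify, send_file) and that the app is running locally on port 8000 as configured in the __main__ block. The example payload reuses the house-price columns from the prompt template.

# Hypothetical client for POST /generate (illustration only, not from the commit).
import requests

payload = {
    "description": "Generate a dataset for predicting house prices",
    "columns": ["Size", "Location", "Number of Bedrooms", "Price"],
    "num_rows": 200,  # optional; the endpoint defaults to 1000
}

response = requests.post("http://localhost:8000/generate", json=payload)

if response.ok:
    # On success the endpoint returns the generated CSV as a file attachment.
    with open("synthetic_data.csv", "wb") as f:
        f.write(response.content)
else:
    # Missing fields return a 400, generation failures a 500, both as JSON.
    print(response.status_code, response.json())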