yakine committed on
Commit
857e937
·
verified ·
1 Parent(s): 31d4200

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, HTTPException
2
- from fastapi.responses import StreamingResponse
3
  from pydantic import BaseModel
4
  import pandas as pd
5
  import os
@@ -41,32 +41,24 @@ def preprocess_user_prompt(user_prompt):
41
  # Define prompt template
42
  prompt_template = """\
43
  You are an expert in generating synthetic data for machine learning models.
44
-
45
  Your task is to generate a synthetic tabular dataset based on the description provided below.
46
-
47
  Description: {description}
48
-
49
  The dataset should include the following columns: {columns}
50
-
51
  Please provide the data in CSV format.
52
-
53
  Example Description:
54
  Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
55
-
56
  Example Output:
57
  Size,Location,Number of Bedrooms,Price
58
  1200,Suburban,3,250000
59
  900,Urban,2,200000
60
  1500,Rural,4,300000
61
  ...
62
-
63
  Description:
64
  {description}
65
  Columns:
66
  {columns}
67
  Output: """
68
 
69
-
70
  tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
71
 
72
  def format_prompt(description, columns):
@@ -87,8 +79,17 @@ generation_params = {
87
  def generate_synthetic_data(description, columns):
88
  formatted_prompt = format_prompt(description, columns)
89
  payload = {"inputs": formatted_prompt, "parameters": generation_params}
90
- response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
91
- return response.json()[0]["generated_text"]
 
 
 
 
 
 
 
 
 
92
 
93
  def process_generated_data(csv_data, expected_columns):
94
  try:
@@ -114,12 +115,14 @@ def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_
114
 
115
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
116
  generated_data = generate_synthetic_data(description, columns)
117
- df_synthetic = process_generated_data(generated_data, columns)
118
-
119
- if df_synthetic is not None and not df_synthetic.empty:
120
- data_frames.append(df_synthetic)
 
 
121
  else:
122
- print("Skipping invalid generation.")
123
 
124
  if data_frames:
125
  return pd.concat(data_frames, ignore_index=True)
 
1
  from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
  from pydantic import BaseModel
4
  import pandas as pd
5
  import os
 
41
  # Define prompt template
42
  prompt_template = """\
43
  You are an expert in generating synthetic data for machine learning models.
 
44
  Your task is to generate a synthetic tabular dataset based on the description provided below.
 
45
  Description: {description}
 
46
  The dataset should include the following columns: {columns}
 
47
  Please provide the data in CSV format.
 
48
  Example Description:
49
  Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
 
50
  Example Output:
51
  Size,Location,Number of Bedrooms,Price
52
  1200,Suburban,3,250000
53
  900,Urban,2,200000
54
  1500,Rural,4,300000
55
  ...
 
56
  Description:
57
  {description}
58
  Columns:
59
  {columns}
60
  Output: """
61
 
 
62
  tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
63
 
64
  def format_prompt(description, columns):
 
79
  def generate_synthetic_data(description, columns):
80
  formatted_prompt = format_prompt(description, columns)
81
  payload = {"inputs": formatted_prompt, "parameters": generation_params}
82
+ try:
83
+ response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
84
+ response.raise_for_status()
85
+ data = response.json()
86
+ if 'generated_text' in data[0]:
87
+ return data[0]['generated_text']
88
+ else:
89
+ raise ValueError("Invalid response format from Hugging Face API.")
90
+ except (requests.RequestException, ValueError) as e:
91
+ print(f"Error during API request or response processing: {e}")
92
+ return ""
93
 
94
  def process_generated_data(csv_data, expected_columns):
95
  try:
 
115
 
116
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
117
  generated_data = generate_synthetic_data(description, columns)
118
+ if generated_data:
119
+ df_synthetic = process_generated_data(generated_data, columns)
120
+ if df_synthetic is not None and not df_synthetic.empty:
121
+ data_frames.append(df_synthetic)
122
+ else:
123
+ print("Skipping invalid generation.")
124
  else:
125
+ print("Skipping empty or invalid generation.")
126
 
127
  if data_frames:
128
  return pd.concat(data_frames, ignore_index=True)