Update app.py
Browse files
app.py
CHANGED
@@ -27,53 +27,46 @@ if not hf_token:
|
|
27 |
raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
28 |
|
29 |
# Load GPT-2 model and tokenizer
|
30 |
-
tokenizer_gpt2 =
|
31 |
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
|
32 |
|
33 |
# Create a pipeline for text generation using GPT-2
|
34 |
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
35 |
|
36 |
-
def preprocess_user_prompt(user_prompt):
|
37 |
-
# Generate a structured prompt based on the user input
|
38 |
-
generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
|
39 |
-
return generated_text
|
40 |
-
|
41 |
# Define prompt template
|
42 |
prompt_template = """\
|
43 |
You are an expert in generating synthetic data for machine learning models.
|
|
|
44 |
Your task is to generate a synthetic tabular dataset based on the description provided below.
|
|
|
45 |
Description: {description}
|
|
|
46 |
The dataset should include the following columns: {columns}
|
|
|
47 |
Please provide the data in CSV format.
|
|
|
48 |
Example Description:
|
49 |
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
|
|
|
50 |
Example Output:
|
51 |
Size,Location,Number of Bedrooms,Price
|
52 |
1200,Suburban,3,250000
|
53 |
900,Urban,2,200000
|
54 |
1500,Rural,4,300000
|
55 |
...
|
|
|
56 |
Description:
|
57 |
{description}
|
58 |
Columns:
|
59 |
{columns}
|
60 |
Output: """
|
61 |
|
62 |
-
class DataGenerationRequest(BaseModel):
|
63 |
-
description: str
|
64 |
-
columns: list
|
65 |
-
|
66 |
# Set up the Mixtral model and tokenizer
|
67 |
-
token =
|
68 |
HfFolder.save_token(token)
|
69 |
|
70 |
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)
|
71 |
|
72 |
-
def format_prompt(description, columns):
|
73 |
-
processed_description = preprocess_user_prompt(description)
|
74 |
-
prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
|
75 |
-
return prompt
|
76 |
-
|
77 |
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
78 |
|
79 |
generation_params = {
|
@@ -84,66 +77,67 @@ generation_params = {
|
|
84 |
"use_cache": False
|
85 |
}
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def generate_synthetic_data(description, columns):
|
88 |
formatted_prompt = format_prompt(description, columns)
|
89 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
90 |
-
response = requests.post(API_URL, headers={"Authorization": f"Bearer {
|
91 |
-
|
92 |
-
try:
|
93 |
-
response_data = response.json()
|
94 |
-
except ValueError:
|
95 |
-
raise HTTPException(status_code=500, detail="Failed to parse response from the API.")
|
96 |
-
|
97 |
-
if 'error' in response_data:
|
98 |
-
raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")
|
99 |
-
|
100 |
-
if 'generated_text' not in response_data[0]:
|
101 |
-
raise HTTPException(status_code=500, detail="Unexpected API response format.")
|
102 |
-
|
103 |
-
return response_data[0]["generated_text"]
|
104 |
|
105 |
def process_generated_data(csv_data, expected_columns):
|
106 |
try:
|
107 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
108 |
data = StringIO(cleaned_data)
|
109 |
df = pd.read_csv(data, delimiter=',')
|
110 |
-
|
111 |
if set(df.columns) != set(expected_columns):
|
112 |
-
|
113 |
-
|
114 |
return df
|
115 |
except pd.errors.ParserError as e:
|
116 |
-
|
|
|
117 |
|
118 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
|
119 |
-
|
120 |
-
|
121 |
for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
|
122 |
generated_data = generate_synthetic_data(description, columns)
|
123 |
df_synthetic = process_generated_data(generated_data, columns)
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
131 |
else:
|
132 |
-
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
description = request.description.strip()
|
137 |
-
columns = [col.strip() for col in request.columns]
|
138 |
-
csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
|
139 |
-
|
140 |
-
# Return the CSV data as a downloadable file
|
141 |
-
return StreamingResponse(
|
142 |
-
csv_data,
|
143 |
-
media_type="text/csv",
|
144 |
-
headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
|
145 |
-
)
|
146 |
-
|
147 |
-
@app.get("/")
|
148 |
-
def greet_json():
|
149 |
-
return {"Hello": "World!"}
|
|
|
27 |
raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
28 |
|
29 |
# Load GPT-2 model and tokenizer.
# Fix: the tokenizer must be loaded with a tokenizer class — the original
# loaded GPT2LMHeadModel twice and passed a second *model* instance as the
# pipeline's tokenizer, which breaks text generation. AutoTokenizer is
# already used elsewhere in this file, so it is in scope here.
tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
# Prompt template for the Mixtral text-generation request.
# Placeholders filled by format_prompt():
#   {description} — free-text dataset description (appears twice: once in the
#                   instructions, once again just before "Output:")
#   {columns}     — comma-joined column names for the requested CSV.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.

Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format.

Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...

Description:
{description}
Columns:
{columns}
Output: """
|
63 |
|
|
|
|
|
|
|
|
|
64 |
# Set up the Mixtral model and tokenizer.
# NOTE(review): the startup check earlier in this file references the
# HF_API_TOKEN environment variable, while this reads HF_TOKEN — confirm
# which variable deployments actually set.
token = os.getenv("HF_TOKEN")
if not token:
    # Fail fast with a clear message instead of letting HfFolder.save_token(None)
    # or the authenticated download below fail obscurely.
    raise ValueError("Hugging Face API token is not set. Please set the HF_TOKEN environment variable.")
HfFolder.save_token(token)

tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
70 |
# Serverless Hosted Inference API endpoint for Mixtral-8x7B-Instruct.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
71 |
|
72 |
generation_params = {
|
|
|
77 |
"use_cache": False
|
78 |
}
|
79 |
|
80 |
+
def preprocess_user_prompt(user_prompt):
    """Expand the raw user prompt via the GPT-2 text-generation pipeline.

    Returns the text of the first generated sequence (capped at 50 tokens).
    """
    outputs = text_generator(user_prompt, max_length=50, num_return_sequences=1)
    return outputs[0]["generated_text"]
|
83 |
+
|
84 |
+
def format_prompt(description, columns):
    """Build the final Mixtral prompt from a description and a column list.

    The description is first enriched by GPT-2 (preprocess_user_prompt), then
    substituted into prompt_template together with the comma-joined columns.
    """
    enriched = preprocess_user_prompt(description)
    return prompt_template.format(description=enriched, columns=",".join(columns))
|
88 |
+
|
89 |
def generate_synthetic_data(description, columns):
    """Call the hosted Mixtral endpoint and return its generated text.

    Restores the error handling this route had before the Flask rewrite: the
    bare `response.json()[0]["generated_text"]` crashes with an opaque
    KeyError/IndexError/JSONDecodeError on any API hiccup.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: when the payload is not JSON, the API reports an error,
            or the response shape is unexpected.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}
    response = requests.post(API_URL, headers={"Authorization": f"Bearer {token}"}, json=payload)
    response.raise_for_status()
    try:
        response_data = response.json()
    except ValueError as exc:
        raise RuntimeError("Failed to parse response from the inference API.") from exc
    # The API signals errors (e.g. model loading) as {"error": ...}.
    if isinstance(response_data, dict) and "error" in response_data:
        raise RuntimeError(f"API Error: {response_data['error']}")
    if not response_data or "generated_text" not in response_data[0]:
        raise RuntimeError("Unexpected API response format.")
    return response_data[0]["generated_text"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
def process_generated_data(csv_data, expected_columns):
    """Parse model-generated CSV text into a DataFrame.

    Args:
        csv_data: raw CSV text (any newline convention).
        expected_columns: column names the parsed header must contain.

    Returns:
        A DataFrame when parsing succeeds and the header matches
        expected_columns (order-insensitive); None otherwise.
    """
    try:
        # Normalize Windows/old-Mac line endings so pandas sees clean rows.
        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
        if set(df.columns) != set(expected_columns):
            print(f"Unexpected columns in the generated data: {df.columns}")
            return None
        return df
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError covers an empty/whitespace-only generation, which the
        # original ParserError-only handler let escape as an uncaught crash.
        print(f"Failed to parse CSV data: {e}")
        return None
|
107 |
|
108 |
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Accumulate several model generations into one DataFrame of ~num_rows rows.

    Args:
        description: free-text dataset description.
        columns: expected column names.
        num_rows: total number of rows requested.
        rows_per_generation: nominal batch size per API call.

    Returns:
        Concatenated DataFrame truncated to num_rows, or an empty DataFrame
        with the expected columns when every generation failed validation.
    """
    data_frames = []
    # Ceiling division: the original floor division silently produced ZERO
    # batches whenever num_rows < rows_per_generation.
    num_batches = -(-num_rows // rows_per_generation)
    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)
        if df_synthetic is not None and not df_synthetic.empty:
            data_frames.append(df_synthetic)
        else:
            print("Skipping invalid generation.")
    if data_frames:
        # Truncate in case the model returned more rows than requested.
        return pd.concat(data_frames, ignore_index=True).head(num_rows)
    print("No valid data frames to concatenate.")
    return pd.DataFrame(columns=columns)
|
122 |
+
|
123 |
+
@app.route('/generate', methods=['POST'])
def generate():
    """POST /generate — build a synthetic CSV and return it as a download.

    Expects a JSON body: {"description": str, "columns": list[str],
    "num_rows": int (optional, default 1000)}.
    """
    # Fix: `request.json` raises (or yields None) on a missing/non-JSON body
    # *before* the validation below can run; get_json(silent=True) lets us
    # always answer with a consistent 400 error payload instead.
    data = request.get_json(silent=True) or {}
    description = data.get('description')
    columns = data.get('columns')
    num_rows = data.get('num_rows', 1000)

    if not description or not columns:
        return jsonify({"error": "Please provide 'description' and 'columns' in the request."}), 400

    df_synthetic = generate_large_synthetic_data(description, columns, num_rows=num_rows)

    if df_synthetic is not None and not df_synthetic.empty:
        # NOTE(review): a fixed on-disk path is racy under concurrent
        # requests — consider tempfile or an in-memory buffer.
        file_path = 'synthetic_data.csv'
        df_synthetic.to_csv(file_path, index=False)
        return send_file(file_path, as_attachment=True)
    else:
        return jsonify({"error": "Failed to generate a valid synthetic dataset."}), 500
|
141 |
|
142 |
+
# Run the Flask development server when executed directly.
# 0.0.0.0 binds all interfaces (required inside containers / hosted Spaces).
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|