yakine committed on
Commit bad46c5 · verified · 1 Parent(s): abf2c45

Update app.py

Files changed (1)
  1. app.py +75 -49
app.py CHANGED
@@ -3,10 +3,12 @@ from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 import pandas as pd
 import os
-import torch
+import requests
 from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
 from io import StringIO
 from fastapi.middleware.cors import CORSMiddleware
+from huggingface_hub import HfFolder
+from tqdm import tqdm
 
 app = FastAPI()
 
@@ -24,39 +26,40 @@ hf_token = os.getenv('HF_API_TOKEN')
 if not hf_token:
     raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
 
-# Load the GPT-2 tokenizer and model
+# Load GPT-2 model and tokenizer
 tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
 model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 
 # Create a pipeline for text generation using GPT-2
 text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
 
-# Load the Llama-3 model and tokenizer once during startup
-device = "cuda" if torch.cuda.is_available() else "cpu"
-tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
-model_llama = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Meta-Llama-3-8B",
-    torch_dtype='auto',
-    device_map='balanced',
-    token=hf_token
-)
+def preprocess_user_prompt(user_prompt):
+    # Generate a structured prompt based on the user input
+    generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
+    return generated_text
 
-# Define your prompt template
+# Define prompt template
 prompt_template = """\
 You are an expert in generating synthetic data for machine learning models.
+
 Your task is to generate a synthetic tabular dataset based on the description provided below.
+
 Description: {description}
+
 The dataset should include the following columns: {columns}
-Please provide the data in CSV format with a minimum of 100 rows per generation.
-Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
+
+Please provide the data in CSV format.
+
 Example Description:
 Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
+
 Example Output:
 Size,Location,Number of Bedrooms,Price
 1200,Suburban,3,250000
 900,Urban,2,200000
 1500,Rural,4,300000
 ...
+
 Description:
 {description}
 Columns:
@@ -67,66 +70,89 @@ class DataGenerationRequest(BaseModel):
     description: str
     columns: list
 
-def preprocess_user_prompt(user_prompt):
-    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1, truncation=True)[0]["generated_text"]
-    return generated_text
+# Set up the Mixtral model and tokenizer
+token = hf_token  # Use environment variable for the token
+HfFolder.save_token(token)
+
+tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)
 
 def format_prompt(description, columns):
     processed_description = preprocess_user_prompt(description)
     prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
     return prompt
 
+API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
+
 generation_params = {
     "top_p": 0.90,
     "temperature": 0.8,
     "max_new_tokens": 512,
+    "return_full_text": False,
+    "use_cache": False
 }
 
 def generate_synthetic_data(description, columns):
+    formatted_prompt = format_prompt(description, columns)
+    payload = {"inputs": formatted_prompt, "parameters": generation_params}
+    response = requests.post(API_URL, headers={"Authorization": f"Bearer {token}"}, json=payload)
+    response_data = response.json()
+
+    if 'error' in response_data:
+        return f"Error: {response_data['error']}"
+
+    return response_data[0]["generated_text"]
+
+def process_generated_data(csv_data, expected_columns):
     try:
-        # Prepare the input for the Llama model
-        formatted_prompt = format_prompt(description, columns)
-
-        # Tokenize the prompt with truncation enabled
-        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model_llama.device)
-
-        # Generate synthetic data
-        with torch.no_grad():
-            outputs = model_llama.generate(
-                **inputs,
-                max_length=512,
-                top_p=generation_params["top_p"],
-                temperature=generation_params["temperature"],
-                num_return_sequences=1,
-            )
-
-        # Decode the generated output
-        generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
+        # Ensure the data is cleaned and correctly formatted
+        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
+        data = StringIO(cleaned_data)
+
+        # Read the CSV data
+        df = pd.read_csv(data, delimiter=',')
 
-        # Return the generated synthetic data
-        return generated_text
-    except Exception as e:
-        return f"Error: {e}"
-
-def process_generated_data(csv_data):
-    # Convert to DataFrame
-    data = StringIO(csv_data)
-    df = pd.read_csv(data)
-
-    return df
+        # Check if the DataFrame has the expected columns
+        if set(df.columns) != set(expected_columns):
+            return f"Unexpected columns in the generated data: {df.columns}"
+
+        return df
+    except pd.errors.ParserError as e:
+        return f"Failed to parse CSV data: {e}"
+
+def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
+    data_frames = []
+
+    for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
+        generated_data = generate_synthetic_data(description, columns)
+        if "Error" in generated_data:
+            return generated_data  # Return the error message
+
+        df_synthetic = process_generated_data(generated_data, columns)
+        if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
+            data_frames.append(df_synthetic)
+        else:
+            print("Skipping invalid generation.")
+
+    if data_frames:
+        return pd.concat(data_frames, ignore_index=True)
+    else:
+        return "No valid data frames to concatenate."
 
 @app.post("/generate/")
 def generate_data(request: DataGenerationRequest):
     description = request.description.strip()
     columns = [col.strip() for col in request.columns]
-    generated_data = generate_synthetic_data(description, columns)
+    generated_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
 
-    if "Error" in generated_data:
+    if isinstance(generated_data, str) and "Error" in generated_data:
         return JSONResponse(content={"error": generated_data}, status_code=500)
 
     # Process the generated CSV data into a DataFrame
-    df_synthetic = process_generated_data(generated_data)
-    return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
+    df_synthetic = process_generated_data(generated_data, columns)
+    if isinstance(df_synthetic, pd.DataFrame):
+        return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
+    else:
+        return JSONResponse(content={"error": "Failed to generate valid synthetic data"}, status_code=500)
 
 @app.get("/")
 def greet_json():
 
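A few follow-up notes on the new code; the sketches below are suggestions, not part of this commit.

The new generate_synthetic_data posts to the serverless Inference API and surfaces any "error" field in the JSON reply; one case that field covers is the hosted model still warming up. A caller may therefore want to retry transient failures. A minimal sketch, reusing the API_URL and token defined in app.py, with an arbitrary backoff policy that the API does not guarantee:

    import time
    import requests

    def query_with_retry(payload, max_retries=3, wait_seconds=20):
        # Retry while the API reports an error (e.g. the model is still loading).
        # The retry count and wait are illustrative choices, not API behavior.
        for _ in range(max_retries):
            response = requests.post(
                API_URL,
                headers={"Authorization": f"Bearer {token}"},
                json=payload,
                timeout=120,
            )
            data = response.json()
            if isinstance(data, dict) and "error" in data:
                time.sleep(wait_seconds)
                continue
            return data
        return {"error": f"giving up after {max_retries} attempts"}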
 
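The old prompt asked the model to avoid duplicate rows, and that instruction was dropped from the template; since generate_large_synthetic_data stitches together ten 100-row batches, repeats across batches are likely. A post-concatenation dedupe is a cheap guard; a sketch using only pandas:

    import pandas as pd

    def concat_unique(data_frames):
        # Combine per-batch frames, then drop rows repeated within or across batches.
        combined = pd.concat(data_frames, ignore_index=True)
        return combined.drop_duplicates(ignore_index=True)

This would slot in where generate_large_synthetic_data currently calls pd.concat directly.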
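Finally, for anyone exercising the endpoint after this change, a short client script; the base URL assumes a local uvicorn run, and the payload values are illustrative:

    import requests

    BASE_URL = "http://localhost:8000"  # adjust to your Space or server

    payload = {
        "description": "Generate a dataset for predicting house prices",
        "columns": ["Size", "Location", "Number of Bedrooms", "Price"],
    }

    # The body matches the DataGenerationRequest model (description + columns).
    response = requests.post(f"{BASE_URL}/generate/", json=payload)
    response.raise_for_status()

    # On success the endpoint returns {"data": [...]} with one dict per row.
    rows = response.json()["data"]
    print(f"received {len(rows)} rows")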