yakine commited on
Commit
7af38c7
·
verified ·
1 Parent(s): d7821a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -35
app.py CHANGED
@@ -14,53 +14,34 @@ app = FastAPI()
14
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
- allow_origins=["*"], # You can specify domains here
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
- # Access the Hugging Face API token from environment variables
24
  hf_token = os.getenv('HF_API_TOKEN')
25
-
26
  if not hf_token:
27
  raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
28
 
29
- # Load GPT-2 model and tokenizer
30
  tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
31
  model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
32
-
33
- # Create a pipeline for text generation using GPT-2
34
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
35
 
36
- def preprocess_user_prompt(user_prompt):
37
- # Generate a structured prompt based on the user input
38
- generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
39
- return generated_text
40
-
41
- # Define prompt template
42
  prompt_template = """\
43
  You are an expert in generating synthetic data for machine learning models.
44
  Your task is to generate a synthetic tabular dataset based on the description provided below.
45
  Description: {description}
46
  The dataset should include the following columns: {columns}
47
  Please provide the data in CSV format.
48
- Example Description:
49
- Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
50
- Example Output:
51
- Size,Location,Number of Bedrooms,Price
52
- 1200,Suburban,3,250000
53
- 900,Urban,2,200000
54
- 1500,Rural,4,300000
55
- ...
56
- Description:
57
- {description}
58
- Columns:
59
- {columns}
60
- Output: """
61
 
62
  tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
63
 
 
 
 
 
64
  def format_prompt(description, columns):
65
  processed_description = preprocess_user_prompt(description)
66
  prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
@@ -93,14 +74,10 @@ def generate_synthetic_data(description, columns):
93
 
94
  def process_generated_data(csv_data, expected_columns):
95
  try:
96
- # Ensure the data is cleaned and correctly formatted
97
  cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
98
  data = StringIO(cleaned_data)
99
-
100
- # Read the CSV data
101
  df = pd.read_csv(data, delimiter=',')
102
 
103
- # Check if the DataFrame has the expected columns
104
  if set(df.columns) != set(expected_columns):
105
  print(f"Unexpected columns in the generated data: {df.columns}")
106
  return None
@@ -115,11 +92,12 @@ def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_
115
 
116
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
117
  generated_data = generate_synthetic_data(description, columns)
 
118
 
119
- if df_synthetic is not None and not df_synthetic.empty:
120
- data_frames.append(df_synthetic)
121
- else:
122
- print("Skipping invalid generation.")
123
 
124
  if data_frames:
125
  return pd.concat(data_frames, ignore_index=True)
@@ -140,12 +118,10 @@ def generate_data(request: DataGenerationRequest):
140
  if csv_data.empty:
141
  return JSONResponse(content={"error": "No valid data generated"}, status_code=500)
142
 
143
- # Convert the DataFrame to CSV format
144
  csv_buffer = StringIO()
145
  csv_data.to_csv(csv_buffer, index=False)
146
  csv_buffer.seek(0)
147
 
148
- # Return the CSV data as a downloadable file
149
  return StreamingResponse(
150
  csv_buffer,
151
  media_type="text/csv",
@@ -155,3 +131,4 @@ def generate_data(request: DataGenerationRequest):
155
  @app.get("/")
156
  def greet_json():
157
  return {"Hello": "World!"}
 
 
14
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
+ allow_origins=["*"],
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
 
23
  hf_token = os.getenv('HF_API_TOKEN')
 
24
  if not hf_token:
25
  raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
26
 
 
27
  tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
28
  model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 
 
29
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
30
 
 
 
 
 
 
 
31
  prompt_template = """\
32
  You are an expert in generating synthetic data for machine learning models.
33
  Your task is to generate a synthetic tabular dataset based on the description provided below.
34
  Description: {description}
35
  The dataset should include the following columns: {columns}
36
  Please provide the data in CSV format.
37
+ """
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
40
 
41
+ def preprocess_user_prompt(user_prompt):
42
+ generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
43
+ return generated_text
44
+
45
  def format_prompt(description, columns):
46
  processed_description = preprocess_user_prompt(description)
47
  prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
 
74
 
75
  def process_generated_data(csv_data, expected_columns):
76
  try:
 
77
  cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
78
  data = StringIO(cleaned_data)
 
 
79
  df = pd.read_csv(data, delimiter=',')
80
 
 
81
  if set(df.columns) != set(expected_columns):
82
  print(f"Unexpected columns in the generated data: {df.columns}")
83
  return None
 
92
 
93
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
94
  generated_data = generate_synthetic_data(description, columns)
95
+ df_synthetic = process_generated_data(generated_data, columns)
96
 
97
+ if df_synthetic is not None and not df_synthetic.empty:
98
+ data_frames.append(df_synthetic)
99
+ else:
100
+ print("Skipping invalid generation.")
101
 
102
  if data_frames:
103
  return pd.concat(data_frames, ignore_index=True)
 
118
  if csv_data.empty:
119
  return JSONResponse(content={"error": "No valid data generated"}, status_code=500)
120
 
 
121
  csv_buffer = StringIO()
122
  csv_data.to_csv(csv_buffer, index=False)
123
  csv_buffer.seek(0)
124
 
 
125
  return StreamingResponse(
126
  csv_buffer,
127
  media_type="text/csv",
 
131
  @app.get("/")
132
  def greet_json():
133
  return {"Hello": "World!"}
134
+