yakine committed on
Commit
3476a0f
·
verified ·
1 Parent(s): 1743f62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -14
app.py CHANGED
@@ -27,13 +27,17 @@ if not hf_token:
27
  raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
28
 
29
  # Load GPT-2 model and tokenizer
30
- tokenizer_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
31
  model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
32
 
33
-
34
  # Create a pipeline for text generation using GPT-2
35
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
36
 
 
 
 
 
 
37
  # Define prompt template
38
  prompt_template = """\
39
  You are an expert in generating synthetic data for machine learning models.
@@ -62,12 +66,14 @@ Columns:
62
  {columns}
63
  Output: """
64
 
65
- # Set up the Mixtral model and tokenizer
66
- token = os.getenv("HF_TOKEN")
67
- HfFolder.save_token(token)
68
 
69
  tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)
70
 
 
 
 
 
 
71
  API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
72
 
73
  generation_params = {
@@ -78,15 +84,6 @@ generation_params = {
78
  "use_cache": False
79
  }
80
 
81
- def preprocess_user_prompt(user_prompt):
82
- generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
83
- return generated_text
84
-
85
- def format_prompt(description, columns):
86
- processed_description = preprocess_user_prompt(description)
87
- prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
88
- return prompt
89
-
90
  def generate_synthetic_data(description, columns):
91
  formatted_prompt = format_prompt(description, columns)
92
  payload = {"inputs": formatted_prompt, "parameters": generation_params}
@@ -95,12 +92,18 @@ def generate_synthetic_data(description, columns):
95
 
96
  def process_generated_data(csv_data, expected_columns):
97
  try:
 
98
  cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
99
  data = StringIO(cleaned_data)
 
 
100
  df = pd.read_csv(data, delimiter=',')
 
 
101
  if set(df.columns) != set(expected_columns):
102
  print(f"Unexpected columns in the generated data: {df.columns}")
103
  return None
 
104
  return df
105
  except pd.errors.ParserError as e:
106
  print(f"Failed to parse CSV data: {e}")
@@ -108,19 +111,24 @@ def process_generated_data(csv_data, expected_columns):
108
 
109
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
110
  data_frames = []
 
111
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
112
  generated_data = generate_synthetic_data(description, columns)
113
  df_synthetic = process_generated_data(generated_data, columns)
 
114
  if df_synthetic is not None and not df_synthetic.empty:
115
  data_frames.append(df_synthetic)
116
  else:
117
  print("Skipping invalid generation.")
 
118
  if data_frames:
119
  return pd.concat(data_frames, ignore_index=True)
120
  else:
121
  print("No valid data frames to concatenate.")
122
  return pd.DataFrame(columns=columns)
123
 
 
 
124
  @app.route('/generate', methods=['POST'])
125
  def generate():
126
  data = request.json
 
27
  raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
28
 
29
  # Load GPT-2 model and tokenizer
30
+ tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
31
  model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
32
 
 
33
  # Create a pipeline for text generation using GPT-2
34
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
35
 
36
+ def preprocess_user_prompt(user_prompt):
37
+ # Generate a structured prompt based on the user input
38
+ generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
39
+ return generated_text
40
+
41
  # Define prompt template
42
  prompt_template = """\
43
  You are an expert in generating synthetic data for machine learning models.
 
66
  {columns}
67
  Output: """
68
 
 
 
 
69
 
70
  tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)
71
 
72
+ def format_prompt(description, columns):
73
+ processed_description = preprocess_user_prompt(description)
74
+ prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
75
+ return prompt
76
+
77
  API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
78
 
79
  generation_params = {
 
84
  "use_cache": False
85
  }
86
 
 
 
 
 
 
 
 
 
 
87
  def generate_synthetic_data(description, columns):
88
  formatted_prompt = format_prompt(description, columns)
89
  payload = {"inputs": formatted_prompt, "parameters": generation_params}
 
92
 
93
  def process_generated_data(csv_data, expected_columns):
94
  try:
95
+ # Ensure the data is cleaned and correctly formatted
96
  cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
97
  data = StringIO(cleaned_data)
98
+
99
+ # Read the CSV data
100
  df = pd.read_csv(data, delimiter=',')
101
+
102
+ # Check if the DataFrame has the expected columns
103
  if set(df.columns) != set(expected_columns):
104
  print(f"Unexpected columns in the generated data: {df.columns}")
105
  return None
106
+
107
  return df
108
  except pd.errors.ParserError as e:
109
  print(f"Failed to parse CSV data: {e}")
 
111
 
112
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
113
  data_frames = []
114
+
115
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
116
  generated_data = generate_synthetic_data(description, columns)
117
  df_synthetic = process_generated_data(generated_data, columns)
118
+
119
  if df_synthetic is not None and not df_synthetic.empty:
120
  data_frames.append(df_synthetic)
121
  else:
122
  print("Skipping invalid generation.")
123
+
124
  if data_frames:
125
  return pd.concat(data_frames, ignore_index=True)
126
  else:
127
  print("No valid data frames to concatenate.")
128
  return pd.DataFrame(columns=columns)
129
 
130
+
131
+
132
  @app.route('/generate', methods=['POST'])
133
  def generate():
134
  data = request.json