yakine committed on
Commit
c7b1e29
·
verified ·
1 Parent(s): be8d2d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -61
app.py CHANGED
@@ -24,103 +24,94 @@ hf_token = os.getenv('HF_API_TOKEN')
24
  if not hf_token:
25
  raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
26
 
27
- tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
 
 
28
  model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 
 
29
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
30
 
 
31
  prompt_template = """\
32
- You are an expert in generating synthetic data for machine learning models.
33
 
34
- Your task is to generate a synthetic tabular dataset based on the description provided below.
35
 
36
  Description: {description}
37
 
38
- The dataset should include the following columns: {columns}
39
-
40
- Please provide the data in CSV format.
41
-
42
- Example Description:
43
- Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
44
 
45
- Example Output:
46
- Size,Location,Number of Bedrooms,Price
47
- 1200,Suburban,3,250000
48
- 900,Urban,2,200000
49
- 1500,Rural,4,300000
50
- ...
51
 
52
- Description:
53
- {description}
54
- Columns:
55
- {columns}
56
- Output: """
57
 
58
- tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
59
-
60
def preprocess_user_prompt(user_prompt):
    """Expand the user's raw description with the GPT-2 pipeline and return the generated text."""
    completions = text_generator(user_prompt, max_length=50, num_return_sequences=1)
    return completions[0]["generated_text"]
63
-
64
- def format_prompt(description, columns):
65
- processed_description = preprocess_user_prompt(description)
66
- prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
67
- return prompt
68
 
69
- API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
70
 
 
71
  generation_params = {
72
  "top_p": 0.90,
73
  "temperature": 0.8,
74
- "max_new_tokens": 512,
75
  "return_full_text": False,
76
  "use_cache": False
77
  }
78
 
 
 
 
 
79
  def generate_synthetic_data(description, columns):
80
  formatted_prompt = format_prompt(description, columns)
81
  payload = {"inputs": formatted_prompt, "parameters": generation_params}
 
 
 
 
 
 
 
 
 
 
 
 
82
  try:
83
- response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
84
- response.raise_for_status()
85
- data = response.json()
86
- if 'generated_text' in data[0]:
87
- return data[0]['generated_text']
88
- else:
89
- raise ValueError("Invalid response format from Hugging Face API.")
90
- except (requests.RequestException, ValueError) as e:
91
- print(f"Error during API request or response processing: {e}")
92
- return ""
93
-
94
- def process_generated_data(csv_data, expected_columns):
95
- try:
96
- # Replace inconsistent line endings
97
  cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
98
-
99
- # Check for common CSV formatting issues and apply corrections
100
- cleaned_data = cleaned_data.strip().replace('|', ',').replace(' ', ' ').replace(' ,', ',')
101
-
102
- # Load the cleaned data into a DataFrame
103
  data = StringIO(cleaned_data)
104
- df = pd.read_csv(data, delimiter=',')
105
-
 
 
106
  return df
107
-
108
  except pd.errors.ParserError as e:
109
  print(f"Failed to parse CSV data: {e}")
110
  return None
111
 
112
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
113
  data_frames = []
114
-
115
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
116
  generated_data = generate_synthetic_data(description, columns)
117
- df_synthetic = process_generated_data(generated_data, columns)
118
 
119
- if df_synthetic is not None and not df_synthetic.empty:
120
- data_frames.append(df_synthetic)
121
- else:
122
- print("Skipping invalid generation.")
123
-
 
 
 
 
124
  if data_frames:
125
  return pd.concat(data_frames, ignore_index=True)
126
  else:
 
24
  if not hf_token:
25
  raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
26
 
27
+
28
+ # Load GPT-2 model and tokenizer
29
+ tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
30
  model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
31
+
32
+ # Create a pipeline for text generation using GPT-2
33
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
34
 
35
+ # Define prompt template for generating the dataset
36
  prompt_template = """\
37
+ You are an AI specialized in generating synthetic tabular data specifically for machine learning purposes.
38
 
39
+ Task: Generate a synthetic dataset based on the provided description and column names.
40
 
41
  Description: {description}
42
 
43
+ Columns: {columns}
 
 
 
 
 
44
 
45
+ Instructions:
 
 
 
 
 
46
 
47
+ Output only the tabular data in valid CSV format.
48
+ Include the header row followed by the data rows.
49
+ Do not generate any additional text, explanations, comments, or code.
50
+ Ensure that the values for each column are contextually appropriate.
 
51
 
52
+ Format Example (do not include this line or the following example in your output):
53
+ Column1,Column2,Column3
54
+ Value1,Value2,Value3
55
+ Value4,Value5,Value6
56
+ """
 
 
 
 
 
57
 
 
58
 
59
# Parameters forwarded to the Hugging Face inference API for text generation.
generation_params = {
    "top_p": 0.90,               # nucleus-sampling probability mass cutoff
    "temperature": 0.8,          # sampling temperature
    "max_new_tokens": 1024,      # upper bound on generated tokens
    "return_full_text": False,   # exclude the prompt from the returned text
    "use_cache": False,          # always request a fresh generation
}
67
 
68
def format_prompt(description, columns):
    """Fill the module-level prompt template with the dataset description and a comma-joined column list."""
    joined_columns = ",".join(columns)
    return prompt_template.format(description=description, columns=joined_columns)
71
+
72
def generate_synthetic_data(description, columns):
    """Request synthetic tabular data from the hosted Mixtral model.

    Builds the prompt from *description* and *columns*, posts it to the
    HF inference API, and returns the generated text on HTTP 200, or
    None after printing the error otherwise.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}

    # Call the Mixtral model to generate data.
    # BUG FIX: the auth header must use the module-level `hf_token` (read from
    # HF_API_TOKEN at startup); the previous name `token` was undefined and
    # raised NameError on every call.
    response = requests.post(
        "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
        headers={"Authorization": f"Bearer {hf_token}"},
        json=payload,
    )

    if response.status_code == 200:
        return response.json()[0]["generated_text"]
    else:
        print(f"Error generating data: {response.status_code}, {response.text}")
        return None
85
+
86
def process_generated_data(csv_data):
    """Parse model-generated CSV text into a pandas DataFrame.

    Line endings are normalized to '\n' before parsing. Returns the
    DataFrame on success, or None when the text is empty or cannot be
    parsed as CSV.
    """
    try:
        # Ensure the data is cleaned and correctly formatted:
        # normalize Windows/old-Mac line endings so pandas sees uniform '\n'.
        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
        data = StringIO(cleaned_data)

        # Read the CSV data; header row becomes the column index.
        df = pd.read_csv(data)

        return df
    except pd.errors.EmptyDataError:
        # ROBUSTNESS FIX: the model can return nothing usable; previously this
        # exception escaped and crashed the caller. Treat it as a failed
        # generation, consistent with the ParserError path.
        print("Generated output contained no CSV data.")
        return None
    except pd.errors.ParserError as e:
        print(f"Failed to parse CSV data: {e}")
        return None
99
 
100
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
101
  data_frames = []
102
+
103
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
104
  generated_data = generate_synthetic_data(description, columns)
 
105
 
106
+ if generated_data:
107
+
108
+ df_synthetic = process_generated_data(generated_data)
109
+
110
+ if df_synthetic is not None and not df_synthetic.empty:
111
+ data_frames.append(df_synthetic)
112
+ else:
113
+ print("Skipping invalid generation.")
114
+
115
  if data_frames:
116
  return pd.concat(data_frames, ignore_index=True)
117
  else: