Spaces:
Sleeping
Sleeping
app.py creation
Browse files
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import requests
|
4 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer, LlamaTokenizer, LlamaForCausalLM, pipeline
|
5 |
+
from huggingface_hub import HfFolder, login
|
6 |
+
from io import StringIO
|
7 |
+
|
8 |
+
# Load GPT-2 model and tokenizer
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2.
# Fix: when a model *object* (rather than a model id string) is passed,
# the pipeline cannot reliably infer which tokenizer to load, so the
# tokenizer must be supplied explicitly alongside the model.
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

# Load the LLaMA tokenizer
# NOTE(review): this is a gated repo — presumably a valid HF token must
# already be configured for this download to succeed; verify at deploy time.
tokenizer_llama = LlamaTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
|
17 |
+
|
18 |
+
# Define your prompt template
# Filled by format_prompt() via str.format with two placeholders:
# {description} (appears twice — once inline, once in the trailing
# Description/Columns/Output scaffold) and {columns} (comma-joined names).
# The example block is a one-shot demonstration of the expected CSV output.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.

Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.

Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...

Description:
{description}
Columns:
{columns}
Output: """
|
46 |
+
|
47 |
+
def preprocess_user_prompt(user_prompt):
    """Expand a short user description into richer text via the GPT-2 pipeline.

    Returns the single generated sequence (capped at 50 tokens overall).
    """
    completions = text_generator(user_prompt, max_length=50, num_return_sequences=1)
    return completions[0]["generated_text"]
|
50 |
+
|
51 |
+
def format_prompt(description, columns):
    """Build the full LLaMA prompt from a dataset description and column names."""
    expanded_description = preprocess_user_prompt(description)
    column_list = ",".join(columns)
    return prompt_template.format(description=expanded_description, columns=column_list)
|
55 |
+
|
56 |
+
# Hugging Face serverless Inference API endpoint for LLaMA 3.1 8B.
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B"

# Sampling parameters sent with every generation request.
generation_params = {
    "top_p": 0.90,  # nucleus-sampling probability mass cutoff
    "temperature": 0.8,  # moderate randomness so rows vary between batches
    "max_new_tokens": 512,  # cap on tokens generated per API call
    "return_full_text": False,  # return only the completion, not the prompt
    "use_cache": False  # force a fresh generation even for a repeated prompt
}
|
65 |
+
|
66 |
+
def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of synthetic CSV data from the Inference API.

    Parameters
    ----------
    description : str
        Free-text description of the desired dataset.
    columns : list[str]
        Column names the generated CSV must contain.
    timeout : float, optional
        Seconds to wait for the HTTP response. New backward-compatible
        parameter: the original call had no timeout, so a stalled API
        connection could hang the app forever.

    Returns
    -------
    str
        The raw generated text (expected to be CSV).

    Raises
    ------
    ValueError
        If the request fails or the response shape is unexpected.
    requests.Timeout
        If no response arrives within ``timeout`` seconds.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}
    headers = {"Authorization": f"Bearer {HfFolder.get_token()}"}
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)

    if response.status_code == 200:
        response_json = response.json()
        # On success the API returns a list of {"generated_text": ...} dicts.
        if isinstance(response_json, list) and len(response_json) > 0 and "generated_text" in response_json[0]:
            return response_json[0]["generated_text"]
        else:
            raise ValueError("Unexpected response format or missing 'generated_text' key")
    else:
        raise ValueError(f"API request failed with status code {response.status_code}: {response.text}")
|
80 |
+
|
81 |
+
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate approximately ``num_rows`` rows by calling the API in batches.

    Bug fix: the original used floor division, so ``num_rows <
    rows_per_generation`` produced zero iterations and ``pd.concat([])``
    raised ``ValueError``, and any remainder rows were silently dropped.
    Ceiling division guarantees at least one batch and never under-shoots
    the requested row count (defaults 1000/100 behave exactly as before).

    Returns a single concatenated DataFrame with a fresh integer index.
    """
    # Ceiling division without math.ceil: -(-a // b).
    num_iterations = -(-num_rows // rows_per_generation)
    data_frames = []

    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)

    return pd.concat(data_frames, ignore_index=True)
|
91 |
+
|
92 |
+
def process_generated_data(csv_data):
    """Parse a CSV-formatted string into a pandas DataFrame."""
    buffer = StringIO(csv_data)
    return pd.read_csv(buffer)
|
96 |
+
|
97 |
+
def main(description, columns):
    """Gradio callback: turn the two text inputs into a generated CSV string."""
    cleaned_description = description.strip()
    column_names = [name.strip() for name in columns.split(',')]
    synthetic_df = generate_large_synthetic_data(cleaned_description, column_names)
    return synthetic_df.to_csv(index=False)
|
102 |
+
|
103 |
+
# Gradio interface
# Two free-text inputs (dataset description + comma-separated column names)
# are forwarded to main(), whose returned CSV string is shown as plain text.
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="Description", placeholder="e.g., Generate a dataset for predicting students' grades"),
        gr.Textbox(label="Columns (comma-separated)", placeholder="e.g., name, age, course, grade")
    ],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns."
)

# Run the Gradio app
iface.launch()
|