import os
from io import StringIO

import pandas as pd
import streamlit as st
import torch
from huggingface_hub import login
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    pipeline,
)

# accelerate must be installed for device_map="auto" to work below, but it does
# not need to be imported directly (its init_empty_weights / dispatch helpers
# were never called).

hf_token = os.getenv("HF_API_TOKEN")
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Authenticate against the Hugging Face Hub; login() supersedes the deprecated
# HfFolder.save_token().
login(token=hf_token)

# Disable oneDNN optimizations in any TensorFlow backend that transformers may
# pull in, to silence floating-point round-off warnings on CPU.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Small GPT-2 pipeline, used only to expand the user's short description into a
# richer prompt.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained("gpt2")
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

# Llama 3.1 8B does the actual tabular generation. device_map="auto" lets
# accelerate place the weights across the available devices.
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype="auto",
    device_map="auto",
    token=hf_token,
)

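# Streamlit reruns this whole script on every widget interaction, which would
# reload the 8B model each time "Generate" is clicked. A minimal sketch of the
# usual fix, assuming the module-level loading above is kept otherwise
# (st.cache_resource is Streamlit's documented cache for models; load_llama is
# our own name):
#
# @st.cache_resource
# def load_llama(token):
#     tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=token)
#     mdl = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Meta-Llama-3.1-8B", torch_dtype="auto", device_map="auto", token=token
#     )
#     return tok, mdl
#
# tokenizer_llama, model_llama = load_llama(hf_token)
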
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """

def preprocess_user_prompt(user_prompt):
    # Let GPT-2 flesh out the terse user description before it goes to Llama.
    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1, truncation=True)[0]["generated_text"]
    return generated_text


def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt

generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
    "return_full_text": False,
    "use_cache": False,
}

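# Note: return_full_text is a text-generation pipeline() argument, not a
# model.generate() one, so the raw generate() call below ignores it. A sketch of
# an alternative that would honor the whole dict, assuming a pipeline is
# acceptable here (pipeline() is the standard transformers API; llama_generator
# is our own name; do_sample=True would still be needed for top_p/temperature
# to take effect):
#
# llama_generator = pipeline("text-generation", model=model_llama, tokenizer=tokenizer_llama)
# csv_text = llama_generator(formatted_prompt, do_sample=True, **generation_params)[0]["generated_text"]
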
def generate_synthetic_data(description, columns):
    try:
        formatted_prompt = format_prompt(description, columns)
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
        # Move the input tensors onto the same device(s) as the model.
        inputs = {k: v.to(model_llama.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                # max_length counts the prompt tokens too; with a 512-token prompt
                # it would leave no room for output, so cap the *new* tokens instead.
                max_new_tokens=generation_params["max_new_tokens"],
                do_sample=True,  # top_p/temperature only apply when sampling is enabled
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                num_return_sequences=1,
                pad_token_id=tokenizer_llama.eos_token_id,
            )

        if outputs.is_meta:
            raise ValueError("Output tensor is in meta state, check model and input.")

        # Decode only the newly generated tokens, not the echoed prompt, so the
        # returned string is just the model's CSV output.
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer_llama.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"Error: {e}"

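# Example call, assuming the model loaded successfully (the description and
# column names here are illustrative only):
#
# csv_text = generate_synthetic_data(
#     "Generate a dataset for predicting house prices",
#     ["Size", "Location", "Number of Bedrooms", "Price"],
# )
# print(csv_text[:200])
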
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    data_frames = []
    num_iterations = num_rows // rows_per_generation
    progress_bar = st.progress(0)

    for i in tqdm(range(num_iterations)):
        generated_data = generate_synthetic_data(description, columns)
        print("Generated Data:\n", generated_data)

        # generate_synthetic_data() returns an "Error: ..." string on failure.
        if "Error" in generated_data:
            return generated_data

        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)
        progress_bar.progress((i + 1) / num_iterations)

    return pd.concat(data_frames, ignore_index=True)

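# The prompt asks the model not to repeat rows, but independent generations can
# still duplicate one another. A minimal post-processing sketch using standard
# pandas, applied to the concatenated result:
#
# df_all = generate_large_synthetic_data(description, columns)
# if isinstance(df_all, pd.DataFrame):
#     df_all = df_all.drop_duplicates().reset_index(drop=True)
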
def process_generated_data(csv_data):
    try:
        if not csv_data.strip():
            raise ValueError("Generated data is empty.")

        # Parse the model's raw text as CSV.
        data = StringIO(csv_data)
        df = pd.read_csv(data)
        print("DataFrame Shape:", df.shape)
        print("DataFrame Head:\n", df.head())

        if df.empty:
            raise ValueError("Generated DataFrame is empty.")

        return df
    except Exception as e:
        st.error(f"Error processing generated data: {e}")
        return pd.DataFrame()

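# LLM output often wraps the CSV in prose ("Here is your dataset: ..."), which
# makes pd.read_csv fail. A hedged sketch of trimming everything before the
# header row, assuming the first expected column name appears verbatim at the
# start of the header (trim_to_csv and first_column are our own names):
#
# def trim_to_csv(raw_text, first_column):
#     lines = raw_text.splitlines()
#     for idx, line in enumerate(lines):
#         if line.strip().startswith(first_column):
#             return "\n".join(lines[idx:])
#     return raw_text  # header not found; let read_csv surface the error
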
st.title("Synthetic Data Generator")
# Use placeholders rather than default values so the "e.g., ..." hint text is
# never submitted as the actual prompt.
description = st.text_input("Description", placeholder="e.g., Generate a dataset for predicting students' grades")
columns = st.text_input("Columns (comma-separated)", placeholder="e.g., name, age, course, grade")

if st.button("Generate"):
    description = description.strip()
    columns = [col.strip() for col in columns.split(",") if col.strip()]
    df_synthetic = generate_large_synthetic_data(description, columns)

    if isinstance(df_synthetic, str) and "Error" in df_synthetic:
        st.error(df_synthetic)
    else:
        st.success("Synthetic Data Generated!")
        st.dataframe(df_synthetic)
        st.download_button(
            label="Download CSV",
            data=df_synthetic.to_csv(index=False),
            file_name="synthetic_data.csv",
            mime="text/csv",
        )