# model / app.py
# yakine's picture
# Update app.py
# 52cec89 verified
import streamlit as st
import pandas as pd
import os
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub import HfFolder
from io import StringIO
from tqdm import tqdm
import accelerate
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, disk_offload
# Access the Hugging Face API token from environment variables.
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
HfFolder.save_token(hf_token)

# Set environment variable to avoid floating-point errors.
# NOTE(review): this assignment runs after `transformers` has been imported;
# if TensorFlow is already loaded by then the flag may have no effect - confirm.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'


# Streamlit re-executes this script on every widget interaction. Without
# caching, both models would be re-loaded on each rerun, so the loaders are
# wrapped in st.cache_resource and the heavy objects are built only once per
# server process.
@st.cache_resource
def _load_gpt2_generator():
    """Load the GPT-2 tokenizer and model and build a text-generation pipeline.

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple.
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return tokenizer, model, generator


@st.cache_resource
def _load_llama(_token):
    """Load the Llama 3.1 8B tokenizer and causal-LM model.

    Args:
        _token: Hugging Face access token. The leading underscore tells
            Streamlit not to hash the secret into the cache key.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=_token)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B",
        torch_dtype='auto',
        device_map='auto',
        token=_token,
    )
    return tokenizer, model


# Module-level names used by the rest of the file.
tokenizer_gpt2, model_gpt2, text_generator = _load_gpt2_generator()
tokenizer_llama, model_llama = _load_llama(hf_token)
# Prompt template for the synthetic-data request. It is filled in with
# str.format() using two fields, {description} and {columns}; each field
# appears twice (once in the instructions and once in the final
# Description/Columns section). The backslash after the opening quotes
# suppresses the leading newline, and the text ends with "Output: " so the
# model's continuation is expected to be the CSV itself.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
def preprocess_user_prompt(user_prompt):
    """Run the user's prompt through the GPT-2 text-generation pipeline.

    Args:
        user_prompt: Text passed to the module-level ``text_generator``.

    Returns:
        The ``"generated_text"`` entry of the first (and only requested)
        sequence returned by the pipeline.
    """
    sequences = text_generator(
        user_prompt,
        max_length=60,
        num_return_sequences=1,
        truncation=True,
    )
    first_sequence = sequences[0]
    return first_sequence["generated_text"]
def format_prompt(description, columns):
    """Build the full generation prompt from a description and column names.

    The description is first passed through ``preprocess_user_prompt``; the
    column names are joined with commas (no spaces) before both values are
    substituted into ``prompt_template``.

    Args:
        description: Free-text dataset description.
        columns: Iterable of column-name strings.

    Returns:
        The filled-in prompt string.
    """
    expanded_description = preprocess_user_prompt(description)
    joined_columns = ",".join(columns)
    return prompt_template.format(
        description=expanded_description,
        columns=joined_columns,
    )
# Text-generation settings for the Llama model; values are looked up by key
# where the generation call is made.
generation_params = {
    "top_p": 0.9,               # nucleus-sampling probability mass
    "temperature": 0.8,         # sampling temperature
    "max_new_tokens": 512,      # cap on newly generated tokens
    "return_full_text": False,  # whether the prompt is echoed in the output
    "use_cache": False,         # key/value cache switch
}
# Generate synthetic data
def generate_synthetic_data(description, columns):
    """Generate one batch of synthetic CSV text with the Llama model.

    Args:
        description: Free-text description of the dataset to generate.
        columns: Iterable of column-name strings.

    Returns:
        The decoded model output on success. Unless
        ``generation_params["return_full_text"]`` is truthy, only the newly
        generated text is returned (the prompt is stripped). On any failure a
        string of the form ``"Error: <message>"`` is returned instead of
        raising, which is what callers check for.
    """
    try:
        formatted_prompt = format_prompt(description, columns)

        # Tokenize the prompt with truncation enabled.
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)

        # Move inputs to the device the model lives on.
        inputs = {name: tensor.to(model_llama.device) for name, tensor in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                # max_length would count the prompt tokens too, leaving little
                # or no room for output; max_new_tokens bounds only the
                # continuation.
                max_new_tokens=generation_params["max_new_tokens"],
                # top_p / temperature are ignored unless sampling is enabled.
                do_sample=True,
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                num_return_sequences=1,
            )

        # Check for meta tensor before decoding.
        if outputs.is_meta:
            raise ValueError("Output tensor is in meta state, check model and input.")

        generated_ids = outputs[0]
        if not generation_params.get("return_full_text", False):
            # generate() returns prompt + continuation for decoder-only
            # models; drop the prompt so only the CSV continuation is decoded.
            generated_ids = generated_ids[prompt_length:]

        return tokenizer_llama.decode(generated_ids, skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure as a string.
        return f"Error: {e}"
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate a dataset by concatenating several model generations.

    Args:
        description: Free-text description of the dataset.
        columns: Iterable of column-name strings.
        num_rows: Target number of rows; determines how many batches are
            requested.
        rows_per_generation: Number of rows each batch is expected to yield.

    Returns:
        A DataFrame with all successfully parsed batches concatenated, or a
        string starting with ``"Error:"`` if a generation failed or no batch
        could be parsed.

    Raises:
        ValueError: If ``rows_per_generation`` is not positive.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer.")

    # Ceiling division so a partial final batch is still requested, and at
    # least one batch always runs (floor division gave 0 iterations for
    # num_rows < rows_per_generation, which made pd.concat fail).
    num_iterations = max(1, (num_rows + rows_per_generation - 1) // rows_per_generation)

    data_frames = []
    progress_bar = st.progress(0)
    for i in tqdm(range(num_iterations)):
        generated_data = generate_synthetic_data(description, columns)
        print("Generated Data:\n", generated_data)

        # Failures are reported as "Error: ..."; match the prefix so CSV
        # content that merely contains the word "Error" is not rejected.
        if generated_data.startswith("Error:"):
            return generated_data

        df_synthetic = process_generated_data(generated_data)
        # An empty frame signals a parse failure for this batch; skip it.
        if not df_synthetic.empty:
            data_frames.append(df_synthetic)

        progress_bar.progress((i + 1) / num_iterations)

    if not data_frames:
        return "Error: none of the generated batches could be parsed as CSV."
    return pd.concat(data_frames, ignore_index=True)
def process_generated_data(csv_data):
    """Parse generated CSV text into a DataFrame.

    Args:
        csv_data: CSV text produced by the model.

    Returns:
        The parsed DataFrame. If the text is blank, cannot be parsed, or
        parses to an empty frame, the problem is shown with ``st.error`` and
        an empty DataFrame is returned instead.
    """
    try:
        # Reject blank / whitespace-only text before handing it to pandas.
        if not csv_data.strip():
            raise ValueError("Generated data is empty.")

        frame = pd.read_csv(StringIO(csv_data))
        print("DataFrame Shape:", frame.shape)
        print("DataFrame Head:\n", frame.head())

        # A header with no data rows counts as a failure too.
        if frame.empty:
            raise ValueError("Generated DataFrame is empty.")
    except Exception as exc:
        st.error(f"Error processing generated data: {exc}")
        return pd.DataFrame()  # Empty DataFrame signals the failure to callers
    return frame
# Streamlit app interface
st.title("Synthetic Data Generator")

# The "e.g." hints are placeholders, not default values, so they are never
# sent to the model as if the user had typed them.
description = st.text_input(
    "Description",
    placeholder="e.g., Generate a dataset for predicting students' grades",
)
columns = st.text_input(
    "Columns (comma-separated)",
    placeholder="e.g., name, age, course, grade",
)

if st.button("Generate"):
    description = description.strip()
    # Drop blank entries produced by stray or trailing commas.
    columns = [col.strip() for col in columns.split(',') if col.strip()]

    if not description:
        st.error("Please enter a dataset description.")
    elif not columns:
        st.error("Please enter at least one column name.")
    else:
        df_synthetic = generate_large_synthetic_data(description, columns)
        if isinstance(df_synthetic, str):
            # The generator reports failures as an error-message string.
            st.error(df_synthetic)
        elif df_synthetic.empty:
            st.error("No rows could be generated. Please try again.")
        else:
            st.success("Synthetic Data Generated!")
            st.dataframe(df_synthetic)  # Display the generated DataFrame
            st.download_button(
                label="Download CSV",
                data=df_synthetic.to_csv(index=False),
                file_name="synthetic_data.csv",
                mime="text/csv",
            )