import gradio as gr
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
from io import StringIO
import os
import torch
from huggingface_hub import HfFolder
# Access the Hugging Face API token from environment variables
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
HfFolder.save_token(hf_token)
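# save_token persists the token locally so that later from_pretrained calls can
# authenticate against gated repositories (such as the Llama 3.1 weights).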
# Load the GPT-2 tokenizer and model
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
# Load the Llama 3.1 tokenizer and model in sharded mode
model_name = "meta-llama/Meta-Llama-3.1-8B"
try:
    tokenizer_llama = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model_llama = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",        # automatic sharding across available devices
        torch_dtype=torch.float16,
        load_in_8bit=True,        # 8-bit quantization; requires bitsandbytes
        token=hf_token,
    )
except OSError as e:
    print(f"Error loading model: {e}")
# Define your prompt template
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
def preprocess_user_prompt(user_prompt):
    # Let GPT-2 expand the short user description before it goes into the Llama prompt
    input_ids = tokenizer_gpt2.encode(user_prompt, return_tensors='pt')
    # pad_token_id is set explicitly because GPT-2 has no pad token by default
    generated_text = model_gpt2.generate(input_ids, max_length=60, pad_token_id=tokenizer_gpt2.eos_token_id)[0]
    return tokenizer_gpt2.decode(generated_text, skip_special_tokens=True)
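# Quick sanity check of the preprocessing step (example input is hypothetical):
# preprocess_user_prompt("Generate a dataset for predicting students' grades")
# -> the original description plus a GPT-2 continuation (max_length=60 counts
#    the prompt tokens as well as the newly generated ones)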
def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt
def generate_synthetic_data(description, columns):
    try:
        formatted_prompt = format_prompt(description, columns)
        # Move inputs to the model's device (required when device_map="auto" places weights on GPU)
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt").to(model_llama.device)
        generated_output = model_llama.generate(**inputs, max_new_tokens=512)
        # Decode only the newly generated tokens, skipping the echoed prompt
        generated_text = tokenizer_llama.decode(generated_output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    data_frames = []
    num_iterations = num_rows // rows_per_generation
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        if "Error" in generated_data:
            return generated_data
        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)
    # Concatenate all batches and drop duplicate rows across generations,
    # matching the "no duplicate rows" requirement in the prompt template
    return pd.concat(data_frames, ignore_index=True).drop_duplicates(ignore_index=True)
def process_generated_data(csv_data):
    data = StringIO(csv_data)
    df = pd.read_csv(data)
    return df
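# Sketch of a more tolerant parser, in case the model wraps the CSV in extra prose
# (assumption: the table starts at the first line containing a comma); not wired in:
# def extract_csv_block(text):
#     lines = text.strip().splitlines()
#     start = next((i for i, line in enumerate(lines) if "," in line), 0)
#     return "\n".join(lines[start:])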
def main(description, columns):
    description = description.strip()
    columns = [col.strip() for col in columns.split(',')]
    df_synthetic = generate_large_synthetic_data(description, columns)
    if isinstance(df_synthetic, str) and "Error" in df_synthetic:
        return df_synthetic  # Return the error message if any
    return df_synthetic.to_csv(index=False)
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="Description", placeholder="e.g., Generate a dataset for predicting students' grades"),
        gr.Textbox(label="Columns (comma-separated)", placeholder="e.g., name, age, course, grade"),
    ],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
    api_name="generate",  # passed to the constructor; setting iface.api_name after the fact has no effect
)
# Run the Gradio app
iface.launch(server_name="0.0.0.0", server_port=7860)
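# Minimal client-side usage sketch (assumes the app is reachable at this URL):
# from gradio_client import Client
# client = Client("http://0.0.0.0:7860/")
# csv_text = client.predict(
#     "Generate a dataset for predicting house prices",
#     "Size, Location, Number of Bedrooms, Price",
#     api_name="/generate",
# )
# print(csv_text)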