# Hugging Face Space (status: Running)
import os
from io import StringIO

import gradio as gr
import pandas as pd
import torch
from huggingface_hub import HfFolder
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)
# Read the Hugging Face API token from the HF_API_TOKEN environment variable.
# Fail fast when it is unset or empty; otherwise hand it to huggingface_hub's
# HfFolder so later hub calls can use it.
hf_token = os.environ.get('HF_API_TOKEN')
if hf_token is None or hf_token == "":
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
HfFolder.save_token(hf_token)
# GPT-2 is used to expand the user's description (see preprocess_user_prompt).
# The tokenizer and the language-model head are loaded from the same
# checkpoint name so they always match.
_GPT2_CHECKPOINT = 'gpt2'
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
model_gpt2 = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)
# Load the Llama 3.1 tokenizer and model.
# device_map="auto" delegates weight placement (and sharding across devices,
# when needed) to the library; load_in_8bit requests 8-bit weight loading.
model_name = "meta-llama/Meta-Llama-3.1-8B"
try:
    # The tokenizer is required by generate_synthetic_data, so it is loaded
    # alongside the model and with the same credentials.
    tokenizer_llama = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model_llama = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        load_in_8bit=True,
        token=hf_token,
    )
except OSError as e:
    # Best effort: keep the app importable so the UI can still start. The
    # names are bound to None so later use fails with a clear, catchable
    # error instead of a NameError.
    print(f"Error loading model: {e}")
    tokenizer_llama = None
    model_llama = None
# Prompt template for the synthetic-data request. It is filled in with
# str.format by format_prompt and the result is what gets tokenized for the
# Llama model. Two placeholders are used, each appearing twice:
#   {description} - the (preprocessed) dataset description
#   {columns}     - the comma-joined column names
# The template contains a single worked example (house prices) and ends with
# "Output: " so that the model's continuation is the CSV itself.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
def preprocess_user_prompt(user_prompt):
    """Expand the user's description with a short GPT-2 continuation.

    Args:
        user_prompt: The raw description text entered by the user.

    Returns:
        The prompt followed by GPT-2's continuation, decoded to text with
        special tokens removed. The prompt is returned unchanged when it is
        empty or already at least as long as the generation limit, because
        there is nothing for GPT-2 to add in either case.
    """
    # An empty prompt tokenizes to an empty tensor, which generate() cannot
    # continue from.
    if not user_prompt:
        return user_prompt

    max_length = 60  # total length limit (prompt + continuation), in tokens
    input_ids = tokenizer_gpt2.encode(user_prompt, return_tensors='pt')

    # max_length counts the prompt tokens too; a prompt that already fills
    # the budget leaves no room to generate and is rejected by generate().
    if input_ids.shape[-1] >= max_length:
        return user_prompt

    output_ids = model_gpt2.generate(
        input_ids,
        max_length=max_length,
        # GPT-2 has no pad token; name the EOS token explicitly instead of
        # relying on the library's implicit fallback.
        pad_token_id=tokenizer_gpt2.eos_token_id,
    )[0]
    return tokenizer_gpt2.decode(output_ids, skip_special_tokens=True)
def format_prompt(description, columns):
    """Build the full generation prompt from a description and column names.

    Args:
        description: Dataset description; it is first passed through
            preprocess_user_prompt.
        columns: Iterable of column-name strings, joined with commas.

    Returns:
        prompt_template with both placeholders filled in.
    """
    expanded_description = preprocess_user_prompt(description)
    joined_columns = ",".join(columns)
    return prompt_template.format(
        description=expanded_description,
        columns=joined_columns,
    )
def generate_synthetic_data(description, columns):
    """Run one Llama generation and return the text it produced.

    Args:
        description: Dataset description supplied by the user.
        columns: List of column-name strings.

    Returns:
        The newly generated text (expected to be CSV), without the prompt.
        On any failure, a string of the form ``"Error: <message>"`` is
        returned instead of raising.
    """
    try:
        formatted_prompt = format_prompt(description, columns)
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt")
        # The model is loaded with device_map="auto", so it may not be on the
        # CPU; the input tensors must be on the same device as the model.
        inputs = inputs.to(model_llama.device)
        generated_output = model_llama.generate(**inputs, max_new_tokens=512)
        # For a causal LM the output sequence begins with the prompt tokens.
        # Drop them so the caller receives only the generated CSV rather
        # than the instructions and example followed by the CSV.
        prompt_length = inputs["input_ids"].shape[-1]
        completion_ids = generated_output[0][prompt_length:]
        return tokenizer_llama.decode(completion_ids, skip_special_tokens=True)
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text.
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Build a large dataset by concatenating several model generations.

    Args:
        description: Dataset description supplied by the user.
        columns: List of column-name strings.
        num_rows: Target number of rows for the whole dataset.
        rows_per_generation: Number of rows each generation is expected to
            contribute; must be at least 1.

    Returns:
        A DataFrame with all generations concatenated, or the error string
        from generate_synthetic_data if a generation failed. When
        ``num_rows`` is zero or negative, an empty DataFrame with the
        requested columns is returned. The actual row count depends on what
        the model emits and is not trimmed to ``num_rows``.

    Raises:
        ValueError: If ``rows_per_generation`` is less than 1.
    """
    if rows_per_generation < 1:
        raise ValueError("rows_per_generation must be a positive integer.")
    if num_rows <= 0:
        # Nothing to generate; pd.concat would raise on an empty list.
        return pd.DataFrame(columns=list(columns))

    # Ceiling division: a partial final batch still needs its own generation
    # (floor division would yield zero iterations for small requests).
    num_iterations = -(-num_rows // rows_per_generation)

    data_frames = []
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        # Failures are reported as "Error: ..."; match on that prefix so that
        # valid data merely containing the word "Error" is not rejected.
        if generated_data.startswith("Error:"):
            return generated_data
        data_frames.append(process_generated_data(generated_data))
    return pd.concat(data_frames, ignore_index=True)
def process_generated_data(csv_data):
    """Parse CSV text into a pandas DataFrame.

    Args:
        csv_data: CSV content as a single string, header row first.

    Returns:
        The DataFrame produced by ``pd.read_csv`` for that text.
    """
    # read_csv expects a path or file-like object, so wrap the text in an
    # in-memory buffer.
    return pd.read_csv(StringIO(csv_data))
def main(description, columns):
    """Gradio entry point: generate a synthetic dataset and return it as CSV.

    Args:
        description: Free-text dataset description from the UI.
        columns: Comma-separated column names from the UI.

    Returns:
        CSV text (without the index column) on success, or an error message
        string starting with "Error" on failure.
    """
    description = description.strip()
    # Drop blank entries left by empty input, stray or trailing commas.
    column_names = [name.strip() for name in columns.split(',')]
    column_names = [name for name in column_names if name]
    if not column_names:
        return "Error: please provide at least one column name."

    result = generate_large_synthetic_data(description, column_names)
    # A string result is an error message, not data; pass it to the UI as is.
    if isinstance(result, str):
        return result
    return result.to_csv(index=False)
# Two free-text inputs: the dataset description and the column list.
_description_box = gr.Textbox(
    label="Description",
    placeholder="e.g., Generate a dataset for predicting students' grades",
)
_columns_box = gr.Textbox(
    label="Columns (comma-separated)",
    placeholder="e.g., name, age, course, grade",
)

# Wire the inputs to main(); the result is shown as plain text.
iface = gr.Interface(
    fn=main,
    inputs=[_description_box, _columns_box],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
)
# NOTE(review): api_name is assigned after construction; whether this renames
# the exposed endpoint depends on the Gradio version - confirm.
iface.api_name = "generate"

# Run the Gradio app, listening on all interfaces on port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)