yakine's picture
Update app.py
c6ded12 verified
raw
history blame
5.36 kB
import gradio as gr
import pandas as pd
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline, AutoModelForCausalLM
from huggingface_hub import HfFolder
from io import StringIO
import os
import torch
# Read the Hugging Face access token from the environment; it is passed to the
# Llama downloads in load_llama_model. Fail fast at startup when it is missing.
hf_token = os.environ.get("HF_API_TOKEN")
if hf_token in (None, ""):
    raise ValueError(
        "Hugging Face API token is not set. "
        "Please set the HF_API_TOKEN environment variable."
    )
# Store the token through huggingface_hub's HfFolder helper.
HfFolder.save_token(hf_token)

# TF_ENABLE_ONEDNN_OPTS=0 disables TensorFlow's oneDNN custom ops, which can
# give slightly different floating-point results due to reordered arithmetic.
# NOTE(review): TensorFlow reads this variable when it is imported. If
# TensorFlow is installed, the transformers `pipeline` import above may already
# have loaded it, so this assignment may have no effect here — consider moving
# it above the imports.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
# GPT-2 text-generation pipeline used by preprocess_user_prompt to expand the
# user's description before it is inserted into the Llama prompt.
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
text_generator = pipeline(
    task="text-generation",
    model=model_gpt2,
    tokenizer=tokenizer,
)
# Lazily loaded Llama-3.1 model and tokenizer. Both stay None until
# load_llama_model() is first called.
model_llama = None
tokenizer_llama = None


def load_llama_model():
    """Load the Llama-3.1 8B tokenizer and model into module globals on first use.

    Each object is loaded only if its global is still None, so calls after a
    successful load are no-ops. If one load fails, the next call retries just
    the missing piece. The model is loaded in float16, which halves its memory
    footprint relative to float32.
    """
    global model_llama, tokenizer_llama
    model_name = "meta-llama/Meta-Llama-3.1-8B"
    # Check each global separately. Previously only the model was checked, so a
    # tokenizer failure after a successful model load left tokenizer_llama as
    # None permanently. The tokenizer goes first because it is cheap and fails
    # fast on authentication problems.
    if tokenizer_llama is None:
        tokenizer_llama = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    if model_llama is None:
        model_llama = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # Use FP16 for reduced memory
            # `token` replaces the deprecated `use_auth_token` argument and
            # matches the tokenizer call above.
            token=hf_token,
        )
# Prompt template for the Llama model, filled in by format_prompt via
# str.format. {description} and {columns} are the only placeholders. Each
# appears twice: once in the task statement and once in the closing
# Description/Columns section. Any literal brace added to this text must be
# doubled ({{ }}) or str.format will fail.
# The backslash after the opening quotes suppresses the leading newline. The
# text ends with "Output: " (trailing space included), and the model continues
# from there.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
def preprocess_user_prompt(user_prompt, max_new_tokens=50):
    """Expand the user's description with a short GPT-2 continuation.

    Args:
        user_prompt: Free-text dataset description entered by the user.
        max_new_tokens: Maximum number of tokens GPT-2 may append. The default
            roughly matches the previous behaviour for short descriptions
            (a 60-token total limit).

    Returns:
        The ``generated_text`` of the first pipeline result. The
        text-generation pipeline returns the prompt followed by the
        continuation by default.
    """
    # `max_length` (used previously) counts the prompt tokens too, so a
    # description of 60+ tokens was rejected by generate(). `max_new_tokens`
    # bounds only the continuation, so any description length works.
    results = text_generator(
        user_prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    return results[0]["generated_text"]
def format_prompt(description, columns):
    """Build the Llama prompt from a description and a sequence of column names.

    The description is first expanded by ``preprocess_user_prompt``. The column
    names are joined with bare commas (no spaces). Both values are then
    substituted into ``prompt_template``.
    """
    expanded_description = preprocess_user_prompt(description)
    return prompt_template.format(
        columns=",".join(columns),
        description=expanded_description,
    )
# Decoding settings for the Llama generation step. generate_synthetic_data
# looks these up by key when calling model.generate.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
    return_full_text=False,
    use_cache=False,
)
def generate_synthetic_data(description, columns):
    """Generate one batch of synthetic data text with the Llama model.

    Args:
        description: Free-text description of the dataset.
        columns: Sequence of column names to request.

    Returns:
        The decoded text produced by the model (expected to be CSV). Unless
        ``generation_params["return_full_text"]`` is true, this excludes the
        prompt. On any failure, returns a string starting with ``"Error: "``.
    """
    try:
        # Load the Llama model only when generating data.
        load_llama_model()
        formatted_prompt = format_prompt(description, columns)
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt").to(model_llama.device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                # Sampling must be enabled for top_p/temperature to take effect.
                # With greedy decoding every batch is identical, which produces
                # duplicate rows when batches are concatenated.
                do_sample=True,
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                # Budget only newly generated tokens. max_length would count
                # the long prompt against the limit and leave little room for
                # rows. (use_cache is intentionally not forwarded: disabling
                # the KV cache would recompute the whole prefix per token.)
                max_new_tokens=generation_params["max_new_tokens"],
                num_return_sequences=1,
                # The Llama tokenizer may not define a pad token. Reuse EOS so
                # generate() does not have to guess (and warn) each call.
                pad_token_id=tokenizer_llama.eos_token_id,
            )
        # generate() returns prompt + continuation. Drop the prompt so the
        # CSV parser downstream sees only the model's answer.
        start = 0 if generation_params.get("return_full_text", False) else prompt_length
        return tokenizer_llama.decode(outputs[0][start:], skip_special_tokens=True)
    except Exception as e:
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Collect synthetic rows by calling the generator repeatedly.

    Args:
        description: Free-text description of the dataset.
        columns: List of column names to request.
        num_rows: Target number of rows. The result is truncated to at most
            this many rows; it may hold fewer if batches come back short.
            A non-positive value returns an empty frame with ``columns``.
        rows_per_generation: Rows expected per model call. Used only to
            decide how many calls to make, and must be positive.

    Returns:
        A DataFrame with all batches concatenated (fresh 0-based index), or an
        error string starting with ``"Error"`` if a batch fails to generate
        or parse.

    Raises:
        ValueError: If ``rows_per_generation`` is not positive.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer")
    if num_rows <= 0:
        return pd.DataFrame(columns=columns)
    # Ceiling division: a request smaller than one batch still makes one call.
    # Floor division made zero calls, and pd.concat([]) then raised.
    num_iterations = -(-num_rows // rows_per_generation)
    data_frames = []
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        # generate_synthetic_data signals failure with an "Error: " prefix.
        # A plain substring test would also reject valid CSV that merely
        # contains the word "Error".
        if generated_data.startswith("Error:"):
            return generated_data
        try:
            df_synthetic = process_generated_data(generated_data)
        except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            # Report unparseable model output the same way as generation errors.
            return f"Error: could not parse generated CSV: {e}"
        data_frames.append(df_synthetic)
    return pd.concat(data_frames, ignore_index=True).head(num_rows)
def process_generated_data(csv_data, on_bad_lines="skip"):
    """Parse CSV text produced by the model into a DataFrame.

    Args:
        csv_data: CSV text whose first line is the header row.
        on_bad_lines: Forwarded to ``pandas.read_csv``. With the default
            ``"skip"``, a row that has more fields than the header (as
            free-form model output sometimes does) is dropped instead of
            failing the whole batch. Rows with fewer fields are padded with
            NaN by pandas. Pass ``"error"`` for strict parsing.

    Returns:
        The parsed DataFrame.

    Raises:
        pandas.errors.EmptyDataError: If ``csv_data`` contains no data.
        pandas.errors.ParserError: If a malformed row is found and
            ``on_bad_lines="error"``.
    """
    return pd.read_csv(StringIO(csv_data), on_bad_lines=on_bad_lines)
def main(description, columns):
    """Gradio handler: produce synthetic CSV for a description and column list.

    Args:
        description: Dataset description from the UI.
        columns: Comma-separated column names from the UI.

    Returns:
        CSV text without the DataFrame index, or an error message starting
        with ``"Error"``.
    """
    description = description.strip()
    # Drop blank entries so inputs such as "a, ,b" or a trailing comma do not
    # request unnamed columns.
    column_names = [name.strip() for name in columns.split(",") if name.strip()]
    # Reject empty input before triggering the (expensive) model load.
    if not description:
        return "Error: please provide a dataset description."
    if not column_names:
        return "Error: please provide at least one column name."
    df_synthetic = generate_large_synthetic_data(description, column_names)
    if isinstance(df_synthetic, str):
        return df_synthetic  # Error message from the generator.
    return df_synthetic.to_csv(index=False)
# Gradio UI: two text inputs (description, comma-separated columns) and a plain
# text output that shows the generated CSV or an error message from main().
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="Description", placeholder="e.g., Generate a dataset for predicting students' grades"),
        gr.Textbox(label="Columns (comma-separated)", placeholder="e.g., name, age, course, grade"),
    ],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
    # The endpoint name must be given to the constructor. The API route is
    # registered while the Interface is built, so assigning iface.api_name
    # afterwards (as before) did not rename it.
    api_name="generate",
)
# Run the Gradio app, listening on all network interfaces on port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)