# Hugging Face Spaces page header (Space status: Sleeping)
import gradio as gr | |
import pandas as pd | |
import requests | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer, LlamaTokenizer, pipeline | |
from huggingface_hub import HfFolder | |
from io import StringIO | |
import os | |
from flask import Flask, request, jsonify | |
# Read the Hugging Face API token once at import time; it is sent as the
# bearer token on every inference request, so fail fast when it is missing.
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Set TF_ENABLE_ONEDNN_OPTS=0, intended to avoid floating-point discrepancies.
# NOTE(review): this is assigned after `transformers` has already been
# imported -- confirm it still takes effect at this point.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Load the GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer)

# Load the Llama 3.1 tokenizer. Loading is best-effort: a failure is reported
# but does not stop the app from starting.
try:
    tokenizer_llama = LlamaTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
except OSError as e:
    print(f"Error loading tokenizer: {e}")
    # Keep the name bound so later code can test for None instead of
    # hitting a NameError on an undefined variable.
    tokenizer_llama = None
# Prompt template for the synthetic-data request. It contains two
# placeholders, {description} and {columns} (each appears twice), which
# format_prompt fills in with str.format. The text ends with "Output: " so
# that the model's continuation is expected to be the CSV itself.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
def preprocess_user_prompt(user_prompt):
    """Run the user's prompt through the GPT-2 text-generation pipeline.

    Asks ``text_generator`` for a single sequence with ``max_length=60`` and
    returns the ``generated_text`` of that sequence.
    """
    sequences = text_generator(user_prompt, max_length=60, num_return_sequences=1)
    first_sequence = sequences[0]
    return first_sequence["generated_text"]
def format_prompt(description, columns):
    """Build the full model prompt from a description and column names.

    The description is first passed through ``preprocess_user_prompt``; the
    column names are joined with commas (no spaces). Both values are then
    substituted into ``prompt_template``.
    """
    expanded_description = preprocess_user_prompt(description)
    column_list = ",".join(columns)
    return prompt_template.format(
        description=expanded_description,
        columns=column_list,
    )
# Endpoint that generate_synthetic_data POSTs the formatted prompt to.
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B"

# Generation settings sent unchanged as the "parameters" field of every
# request payload built in generate_synthetic_data.
generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
    # presumably makes the API return only the completion, not the prompt --
    # confirm against the Inference API documentation
    "return_full_text": False,
    "use_cache": False
}
def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of synthetic CSV text from the inference API.

    Args:
        description: Free-text description of the dataset to generate.
        columns: Iterable of column-name strings.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled connection would block the caller
            indefinitely.

    Returns:
        The ``generated_text`` of the first response item on success.
        On any failure (prompt building, network error, timeout, non-200
        status, unexpected response shape) a string of the form
        ``"Error: <details>"`` is returned instead of raising; callers
        detect failures through that prefix.
    """
    try:
        formatted_prompt = format_prompt(description, columns)
        payload = {"inputs": formatted_prompt, "parameters": generation_params}
        headers = {"Authorization": f"Bearer {hf_token}"}
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)

        if response.status_code != 200:
            print(f"Error details: {response.text}")
            raise ValueError(f"API request failed with status code {response.status_code}: {response.text}")

        response_json = response.json()
        # Expected shape: a non-empty list whose first item is a dict
        # carrying the "generated_text" key.
        if (
            isinstance(response_json, list)
            and response_json
            and isinstance(response_json[0], dict)
            and "generated_text" in response_json[0]
        ):
            return response_json[0]["generated_text"]
        raise ValueError("Unexpected response format or missing 'generated_text' key")
    except Exception as e:
        # Boundary handler: the rest of the app reports failures as
        # "Error: ..." strings rather than exceptions.
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate a dataset by concatenating several generated batches.

    Args:
        description: Free-text description of the dataset.
        columns: List of column-name strings.
        num_rows: Target number of rows overall.
        rows_per_generation: Number of rows each API call is expected to
            yield; used only to decide how many calls to make.

    Returns:
        A ``pandas.DataFrame`` with all batches concatenated, or an
        ``"Error: ..."`` string when a batch could not be generated or
        parsed. The actual row count depends on what the model returns.
    """
    if num_rows <= 0:
        # Nothing requested: return an empty frame instead of letting
        # pd.concat fail on an empty list.
        return pd.DataFrame(columns=list(columns))
    if rows_per_generation <= 0:
        return "Error: rows_per_generation must be a positive integer"

    # Ceiling division, so a partial final batch is still requested and at
    # least one call is made for any positive num_rows.
    num_iterations = -(-num_rows // rows_per_generation)

    data_frames = []
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        # Failures are signalled by the "Error:" prefix; a substring test
        # would wrongly reject valid CSV that merely mentions "Error".
        if generated_data.startswith("Error:"):
            return generated_data
        try:
            data_frames.append(process_generated_data(generated_data))
        except (pd.errors.ParserError, pd.errors.EmptyDataError) as exc:
            return f"Error: could not parse generated CSV: {exc}"
    return pd.concat(data_frames, ignore_index=True)
def process_generated_data(csv_data):
    """Parse CSV text into a pandas DataFrame.

    The string is wrapped in an in-memory text buffer and handed to
    ``pd.read_csv`` with default options.
    """
    return pd.read_csv(StringIO(csv_data))
def main(description, columns):
    """Gradio entry point: turn the two text inputs into CSV output.

    Args:
        description: Free-text dataset description from the UI.
        columns: Comma-separated column names from the UI.

    Returns:
        The generated dataset as CSV text (without an index column), or an
        ``"Error: ..."`` message when the input is empty or generation
        failed.
    """
    description = description.strip()
    # Drop blank entries so input such as "a,,b" or a trailing comma does
    # not produce empty column names.
    stripped_names = (name.strip() for name in columns.split(','))
    column_names = [name for name in stripped_names if name]

    if not description:
        return "Error: please provide a dataset description."
    if not column_names:
        return "Error: please provide at least one column name."

    result = generate_large_synthetic_data(description, column_names)
    # A string result is always an error message from the generation step.
    if isinstance(result, str):
        return result
    return result.to_csv(index=False)
# Gradio interface: two free-text inputs are passed to `main`, and the string
# it returns is displayed as plain text.
_description_input = gr.Textbox(
    label="Description",
    placeholder="e.g., Generate a dataset for predicting students' grades",
)
_columns_input = gr.Textbox(
    label="Columns (comma-separated)",
    placeholder="e.g., name, age, course, grade",
)

iface = gr.Interface(
    fn=main,
    inputs=[_description_input, _columns_input],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
)

# Start the Gradio app
iface.launch()