import streamlit as st
import pandas as pd
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
from huggingface_hub import HfFolder
from io import StringIO
from tqdm import tqdm

# Access the Hugging Face API token from environment variables
hf_token = os.getenv('HF_API_TOKEN')

if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
HfFolder.save_token(hf_token)

# Disable TensorFlow oneDNN optimizations to avoid small numerical differences in logs
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Load the GPT-2 tokenizer and model
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

# Load the Llama-3 model and tokenizer once during startup
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype='auto',
    device_map='auto',
    token=hf_token
)
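
# Optional: on a memory-constrained GPU, the 8B model can instead be loaded in
# 4-bit precision. A minimal sketch, assuming the `bitsandbytes` package is
# installed (not required for the app as written):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#   model_llama = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Meta-Llama-3.1-8B",
#       quantization_config=quant_config,
#       device_map="auto",
#       token=hf_token,
#   )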

# Define your prompt template
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """

def preprocess_user_prompt(user_prompt):
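    # Expand the user's short description into richer text with GPT-2 before
    # it is embedded in the Llama prompt template below.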
    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1, truncation=True)[0]["generated_text"]
    return generated_text

def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt
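
# Example call (hypothetical values):
#   format_prompt("Generate a dataset of used car listings", ["make", "model", "year", "price"])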

generation_params = {
    "do_sample": True,   # sampling must be enabled for top_p/temperature to take effect
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
}

# Generate synthetic data
def generate_synthetic_data(description, columns):
    try:
        formatted_prompt = format_prompt(description, columns)

        # Tokenize the prompt with truncation enabled
        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
        
        # Move inputs to the correct device
        inputs = {k: v.to(model_llama.device) for k, v in inputs.items()}

        # Generate synthetic data; cap new tokens rather than total length so a
        # full-length prompt still leaves room for generated rows
        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                **generation_params,
                num_return_sequences=1,
                pad_token_id=tokenizer_llama.eos_token_id,  # Llama has no dedicated pad token
            )

        # Check for meta tensor before decoding
        if outputs.is_meta:
            raise ValueError("Output tensor is in meta state, check model and input.")

        # Decode only the newly generated tokens, skipping the echoed prompt
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer_llama.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        
        # Return the generated synthetic data
        return generated_text
    except Exception as e:
        return f"Error: {e}"


def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    data_frames = []
    num_iterations = num_rows // rows_per_generation

    # Create a progress bar
    progress_bar = st.progress(0)

    for i in tqdm(range(num_iterations)):
        generated_data = generate_synthetic_data(description, columns)
        print("Generated Data:\n", generated_data)  # Move the print statement here

        if "Error" in generated_data:
            return generated_data

        df_synthetic = process_generated_data(generated_data)
        data_frames.append(df_synthetic)

        # Update the progress bar
        progress_bar.progress((i + 1) / num_iterations)

    # Combine all chunks; drop duplicate rows that can appear across generations
    return pd.concat(data_frames, ignore_index=True).drop_duplicates(ignore_index=True)

def process_generated_data(csv_data):
    try:
        # Check if the data is not empty and has valid content
        if not csv_data.strip():
            raise ValueError("Generated data is empty.")

        data = StringIO(csv_data)
        # Skip malformed lines the model occasionally emits instead of failing outright
        df = pd.read_csv(data, on_bad_lines='skip')
        print("DataFrame Shape:", df.shape)
        print("DataFrame Head:\n", df.head())

        # Check if the DataFrame is empty
        if df.empty:
            raise ValueError("Generated DataFrame is empty.")

        return df
    except Exception as e:
        st.error(f"Error processing generated data: {e}")
        return pd.DataFrame()  # Return an empty DataFrame on error
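
# Optional helper (not wired in above): models sometimes wrap the CSV in
# Markdown fences or prepend prose before the header row. This sketch trims the
# raw output down to the part starting at the expected header line;
# `expected_header` is assumed to be the comma-joined column list from the prompt.
def extract_csv_block(raw_text, expected_header):
    lines = raw_text.strip().splitlines()
    for i, line in enumerate(lines):
        # Find the first line that looks like the CSV header
        if line.strip().strip('`') == expected_header:
            return "\n".join(lines[i:])
    return raw_text  # Fall back to the raw text if no header line is found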


# Streamlit app interface
st.title("Synthetic Data Generator")
description = st.text_input("Description", placeholder="e.g., Generate a dataset for predicting students' grades")
columns = st.text_input("Columns (comma-separated)", placeholder="e.g., name, age, course, grade")

if st.button("Generate"):
    description = description.strip()
    columns = [col.strip() for col in columns.split(',') if col.strip()]

    if not description or not columns:
        st.warning("Please enter a description and at least one column name.")
    else:
        df_synthetic = generate_large_synthetic_data(description, columns)

        if isinstance(df_synthetic, str) and df_synthetic.startswith("Error:"):
            st.error(df_synthetic)  # Display error message if any
        else:
            st.success("Synthetic Data Generated!")
            st.dataframe(df_synthetic)  # Display the generated DataFrame
            st.download_button(
                label="Download CSV",
                data=df_synthetic.to_csv(index=False),
                file_name="synthetic_data.csv",
                mime="text/csv"
            )
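
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py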