File size: 4,393 Bytes
91b6f17
 
65793c2
91b6f17
f6e79d6
6934761
59d15e6
 
 
 
 
65793c2
b852a37
faeede1
91b6f17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dd5eb9
91b6f17
 
 
 
 
 
faeede1
 
 
ef435e0
faeede1
91b6f17
9e25bdd
b852a37
 
65793c2
faeede1
3dd5eb9
 
 
 
 
b852a37
3dd5eb9
 
 
 
 
 
65793c2
ba22b41
3dd5eb9
 
 
c6ded12
3dd5eb9
b852a37
c6ded12
3dd5eb9
9e25bdd
 
 
91b6f17
faeede1
91b6f17
 
 
 
 
 
9e25bdd
 
91b6f17
 
 
 
 
 
 
 
 
 
 
 
 
 
9e25bdd
 
91b6f17
 
 
 
 
 
 
 
 
 
6ee338b
 
91b6f17
 
a43a385
cf03ae0
91b6f17
44b5315
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import pandas as pd
import requests
from io import StringIO
import os

# The token is sent as a Bearer credential to the inference endpoint, so the
# app refuses to start without it. An unset variable and an empty string are
# both treated as "missing".
hf_token = os.environ.get('HF_API_TOKEN')
if hf_token is None or hf_token == "":
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")




# Instruction prompt for the synthetic-data model, filled in by format_prompt().
# Both placeholders appear twice: once in the instruction paragraph and once in
# the trailing Description/Columns section. The string ends with "Output: "
# (trailing space), presumably so the model continues directly with CSV.
# NOTE(review): generate_synthetic_data() currently builds its own shorter
# prompt and does not use this template — confirm which one is intended.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """

def preprocess_user_prompt(user_prompt):
    """Return *user_prompt* unchanged.

    This is a hook for cleaning the user's description before format_prompt()
    inserts it into the prompt template. It currently does nothing.
    """
    unchanged = user_prompt
    return unchanged

def format_prompt(description, columns):
    """Fill ``prompt_template`` with a dataset description and column names.

    Args:
        description: Free-text dataset description. It is passed through
            preprocess_user_prompt() before substitution.
        columns: Iterable of column-name strings, joined with "," (no space).

    Returns:
        The fully formatted prompt string.
    """
    cleaned_description = preprocess_user_prompt(description)
    joined_columns = ",".join(columns)
    return prompt_template.format(
        description=cleaned_description,
        columns=joined_columns,
    )

import requests

# URL that generate_synthetic_data() POSTs its JSON payload to (described by
# the original author as a Streamlit Space).
# NOTE(review): this is the Space's web-page URL (huggingface.co/spaces/...),
# not the Space's directly served app URL. Confirm it actually accepts POSTed
# JSON and returns a JSON body with a "data" field.
inference_endpoint = "https://huggingface.co/spaces/yakine/model"

def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of synthetic CSV data from the remote inference endpoint.

    Args:
        description: Free-text description of the dataset to generate.
        columns: Sequence of column-name strings, joined with ", " into the prompt.
        timeout: Seconds to wait for the HTTP response. Without a timeout,
            ``requests.post`` can block forever on an unresponsive server,
            which would hang the Gradio handler.

    Returns:
        On success, the value of the ``data`` field of the JSON response.
        On failure, a string starting with ``"Error: "``. Failures are a
        non-200 status, a response without a ``data`` field, a network error
        or timeout, or an unparsable JSON body.
    """
    try:
        # NOTE(review): this builds a short ad-hoc prompt instead of using
        # format_prompt(); presumably the remote app applies its own template.
        formatted_prompt = f"{description}, with columns: {', '.join(columns)}"

        headers = {
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": 512,
                "top_p": 0.90,
                "temperature": 0.8,
            },
        }

        response = requests.post(inference_endpoint, json=payload, headers=headers, timeout=timeout)

        if response.status_code != 200:
            return f"Error: {response.status_code}, {response.text}"

        generated_text = response.json().get('data')
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"

    # Previously None was returned here, which crashed the caller's
    # "Error" check with a TypeError.
    if generated_text is None:
        return "Error: response JSON did not contain a 'data' field"
    return generated_text


def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Build a dataset of up to ``num_rows`` rows by calling the model repeatedly.

    Each call to ``generate_synthetic_data`` is treated as one batch of about
    ``rows_per_generation`` rows. Each batch is parsed and the parsed batches
    are concatenated.

    Args:
        description: Free-text dataset description forwarded to each request.
        columns: Column names forwarded to each request.
        num_rows: Target number of rows. The result is capped at this many
            rows. It can contain fewer if the model returns short batches.
        rows_per_generation: Expected rows per request. It is used only to
            decide how many requests to make (rounded up).

    Returns:
        A DataFrame on success. If a request fails, the first error string,
        which starts with ``"Error:"``; callers check for a ``str`` result.

    Raises:
        ValueError: If ``rows_per_generation`` is not positive.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer")
    if num_rows <= 0:
        # Nothing requested; pd.concat([]) would raise, so return an empty frame.
        return pd.DataFrame(columns=columns)

    # Ceiling division. With floor division, num_rows < rows_per_generation
    # made zero requests and pd.concat([]) raised, and e.g. 250 rows at
    # 100 per request made only 2 requests.
    num_iterations = -(-num_rows // rows_per_generation)

    data_frames = []
    for _ in range(num_iterations):
        generated_data = generate_synthetic_data(description, columns)
        # Only a string payload can be parsed as CSV.
        if not isinstance(generated_data, str):
            return f"Error: unexpected response payload of type {type(generated_data).__name__}"
        # Error results from generate_synthetic_data start with "Error:". The
        # old substring test also rejected valid CSV that merely contained the
        # word "Error" (e.g. a column named "ErrorRate").
        if generated_data.startswith("Error:"):
            return generated_data
        data_frames.append(process_generated_data(generated_data))

    return pd.concat(data_frames, ignore_index=True).head(num_rows)

def process_generated_data(csv_data):
    """Parse CSV text (header row first) into a pandas DataFrame.

    Args:
        csv_data: CSV-formatted text as returned by the model.

    Returns:
        A DataFrame parsed with the default ``pd.read_csv`` options.
    """
    return pd.read_csv(StringIO(csv_data))

def main(description, columns):
    """Gradio handler: generate a synthetic dataset and return it as CSV text.

    Args:
        description: Free-text dataset description entered in the UI.
        columns: Comma-separated column names entered in the UI.

    Returns:
        The generated dataset as CSV text without the index column, or an
        error message string.
    """
    description = description.strip()
    # Skip blank entries so inputs like "name, age," or "a,,b" do not put
    # empty column names into the prompt.
    columns = [name for name in (col.strip() for col in columns.split(',')) if name]
    if not columns:
        return "Error: please provide at least one column name."

    df_synthetic = generate_large_synthetic_data(description, columns)
    # generate_large_synthetic_data returns a string only on failure. Checking
    # the type (not a substring) avoids calling .to_csv on a str.
    if isinstance(df_synthetic, str):
        return df_synthetic
    return df_synthetic.to_csv(index=False)

# Gradio UI around main(). The api_name keyword names this handler's API
# endpoint ("generate"). It must be passed to the constructor: assigning
# iface.api_name afterwards (as this file used to) had no effect, because
# the endpoint is already registered during construction.
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="Description", placeholder="e.g., Generate a dataset for predicting students' grades"),
        gr.Textbox(label="Columns (comma-separated)", placeholder="e.g., name, age, course, grade"),
    ],
    outputs="text",
    title="Synthetic Data Generator",
    description="Generate synthetic tabular datasets based on a description and specified columns.",
    api_name="generate",
)

# Run the Gradio app, listening on all network interfaces on port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)