File size: 4,929 Bytes
4d35d05
857e937
170c5f1
eaf3f9a
170c5f1
bad46c5
dbb92e9
170c5f1
1ab421e
bad46c5
eaf3f9a
eaf6f50
 
 
1ab421e
 
7af38c7
1ab421e
 
 
 
170c5f1
 
 
 
 
c7b1e29
 
 
170c5f1
c7b1e29
 
170c5f1
 
6bc7f9d
c7b1e29
7ce0c46
6bc7f9d
c7b1e29
7ce0c46
c7b1e29
 
6bc7f9d
 
 
 
 
 
 
c7b1e29
 
 
 
3476a0f
6bc7f9d
c7b1e29
170c5f1
 
 
0dc54fb
bad46c5
 
170c5f1
 
c7b1e29
 
 
 
170c5f1
bad46c5
 
c7b1e29
 
 
d7ee939
c7b1e29
 
 
 
 
 
 
 
be8d2d5
c7b1e29
eaf3f9a
be8d2d5
c7b1e29
 
 
 
eaf3f9a
 
 
 
bad46c5
 
eaf3f9a
c7b1e29
eaf3f9a
bad46c5
d7821a8
c7b1e29
 
 
 
 
 
 
 
 
eaf3f9a
 
 
 
 
01c0141
31d4200
 
 
3476a0f
31d4200
 
 
 
eaf3f9a
65871c9
eaf3f9a
 
 
 
 
 
 
31d4200
 
 
 
 
 
 
 
be8d2d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import pandas as pd
import os
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
from io import StringIO
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfFolder
from tqdm import tqdm

# FastAPI application instance serving the endpoints defined below.
app = FastAPI()

# Allow cross-origin requests from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is very
# permissive — confirm this is intended for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hugging Face Inference API token used by generate_synthetic_data(); fail fast
# at startup if it is missing rather than on the first request.
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")


# Load GPT-2 model and tokenizer
# (downloaded from the Hugging Face hub on first run, then served from cache;
# this runs at import time and can take a while on a cold start).
tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
# NOTE(review): this local pipeline is never used in this file — actual
# generation goes through the hosted Mixtral endpoint in
# generate_synthetic_data(); confirm whether this block can be removed.
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)


# Define prompt template for generating the dataset.
# {description} and {columns} are filled in by format_prompt(); the template
# instructs the model to emit raw CSV only (header row followed by data rows),
# which process_generated_data() then parses with pandas.
prompt_template = """\
You are an AI designed exclusively for generating synthetic tabular datasets.
Task: Generate a synthetic dataset based on the provided description and column names.
Description: {description}
Columns: {columns}
Instructions:
- Output only the tabular data in valid CSV format.
- Include the header row followed strictly by the data rows.
- Do not include any additional text, explanations, comments, or code outside of the CSV data.
- Ensure that the values for each column are contextually appropriate based on the description.
- Do not alter the column names or add any new columns.
- Each row must contain data for all columns without any empty values unless specified.
- Format Example (do not include this line or the following example in your output):
Column1,Column2,Column3
Value1,Value2,Value3
Value4,Value5,Value6
"""


# Define generation parameters.
# Sent verbatim as the "parameters" field of the Inference API payload.
generation_params = {
    "top_p": 0.90,               # nucleus sampling cutoff
    "temperature": 0.8,          # sampling temperature
    "max_new_tokens": 2048,      # upper bound on generated tokens per call
    "return_full_text": False,   # return only the completion, not the prompt
    "use_cache": False           # ask the API not to serve a cached result
}

def format_prompt(description, columns):
    """Build the generation prompt from the module-level template.

    Substitutes the dataset description and a comma-joined column list into
    ``prompt_template`` and returns the resulting string.
    """
    joined_columns = ",".join(columns)
    return prompt_template.format(description=description, columns=joined_columns)

def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of synthetic CSV rows from the hosted Mixtral model.

    Parameters
    ----------
    description : str
        Natural-language description of the dataset to generate.
    columns : list[str]
        Column names the generated CSV must use verbatim.
    timeout : float, optional
        Seconds to wait for the Inference API before giving up (default 120).
        Backward-compatible addition: the original call had no timeout and
        could hang indefinitely.

    Returns
    -------
    str | None
        The raw generated text (expected to be CSV) on success, or None on
        any HTTP error, network failure, timeout, or malformed response.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}

    # Call Mixtral model to generate data
    try:
        response = requests.post(
            "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Connection errors / timeouts previously propagated and crashed the caller.
        print(f"Error generating data: request failed: {e}")
        return None

    if response.status_code != 200:
        print(f"Error generating data: {response.status_code}, {response.text}")
        return None

    try:
        return response.json()[0]["generated_text"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # A 200 response can still carry an error object instead of the
        # expected [{"generated_text": ...}] list.
        print(f"Error generating data: unexpected response shape: {e}")
        return None

def process_generated_data(csv_data):
    """Parse model-generated CSV text into a DataFrame.

    Normalizes line endings, then parses with pandas. Returns None instead of
    raising for unparseable or empty output so the batch loop can skip the
    bad generation and continue.

    Parameters
    ----------
    csv_data : str
        Raw text returned by the model, expected to be CSV with a header row.

    Returns
    -------
    pandas.DataFrame | None
        The parsed table, or None when the text is empty or not valid CSV.
    """
    try:
        # Ensure the data is cleaned and correctly formatted: normalize
        # Windows (\r\n) and old-Mac (\r) line endings to plain \n.
        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
        data = StringIO(cleaned_data)

        # Read the CSV data with specific parameters to handle irregularities
        df = pd.read_csv(data)

        return df
    except pd.errors.EmptyDataError:
        # The model sometimes returns no text at all; the original code let
        # this exception escape and abort the whole batch loop.
        print("Failed to parse CSV data: empty output")
        return None
    except pd.errors.ParserError as e:
        print(f"Failed to parse CSV data: {e}")
        return None

def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate up to ``num_rows`` synthetic rows by repeatedly calling the model.

    Parameters
    ----------
    description : str
        Natural-language description of the dataset.
    columns : list[str]
        Column names for the generated table.
    num_rows : int, optional
        Target number of rows (default 1000).
    rows_per_generation : int, optional
        Nominal rows requested per model call; only controls the number of
        batches — the model decides how many rows each call actually yields.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all valid batches, capped at ``num_rows`` rows, or an
        empty DataFrame with the requested columns if every batch failed.
    """
    data_frames = []

    # Ceiling division: the original floor division generated zero batches
    # whenever num_rows < rows_per_generation.
    num_batches = -(-num_rows // rows_per_generation)

    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)

        if not generated_data:
            # The model call itself failed (logged inside generate_synthetic_data).
            print("Skipping invalid generation.")
            continue

        df_synthetic = process_generated_data(generated_data)
        if df_synthetic is not None and not df_synthetic.empty:
            data_frames.append(df_synthetic)
        else:
            print("Skipping invalid generation.")

    if not data_frames:
        print("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)

    # Cap at the requested row count in case batches over-generate.
    return pd.concat(data_frames, ignore_index=True).head(num_rows)

class DataGenerationRequest(BaseModel):
    """Request body for POST /generate/."""
    # Free-text description of the dataset to synthesize.
    description: str
    # Column names to use verbatim in the generated CSV header.
    columns: list[str]

@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a downloadable CSV.

    Trims whitespace from the description and column names, generates up to
    1000 rows in batches of 100, and streams the result as a CSV attachment.
    Responds with a 500 JSON error when no valid data could be generated.
    """
    description = request.description.strip()
    columns = [column.strip() for column in request.columns]

    df = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
    if df.empty:
        # Every batch failed or parsed to nothing usable.
        return JSONResponse(content={"error": "No valid data generated"}, status_code=500)

    # Serialize into an in-memory text buffer and stream it back.
    buffer = StringIO()
    df.to_csv(buffer, index=False)
    buffer.seek(0)

    response_headers = {"Content-Disposition": "attachment; filename=generated_data.csv"}
    return StreamingResponse(buffer, media_type="text/csv", headers=response_headers)

@app.get("/")
def greet_json():
    """Root endpoint: fixed greeting, usable as a liveness check."""
    greeting = {"Hello": "World!"}
    return greeting