File size: 5,426 Bytes
4d35d05
a93106f
170c5f1
 
 
bad46c5
dbb92e9
170c5f1
1ab421e
bad46c5
 
eaf6f50
 
 
1ab421e
 
7ce0c46
1ab421e
 
 
 
170c5f1
 
 
 
 
 
 
bad46c5
170c5f1
 
 
 
 
 
bad46c5
 
 
 
170c5f1
bad46c5
7ce0c46
 
 
 
 
bad46c5
7ce0c46
 
 
 
 
 
 
 
 
 
 
 
 
170c5f1
 
 
 
 
bad46c5
 
 
 
 
170c5f1
 
 
 
 
 
bad46c5
 
170c5f1
 
 
 
bad46c5
 
170c5f1
 
 
bad46c5
 
65871c9
 
 
 
 
 
bad46c5
 
65871c9
 
 
 
bad46c5
 
 
 
170c5f1
20981ee
 
bad46c5
170c5f1
bad46c5
65871c9
bad46c5
 
 
65871c9
bad46c5
 
65871c9
021bce4
20981ee
bad46c5
 
65871c9
bad46c5
65871c9
 
 
 
a93106f
bad46c5
65871c9
170c5f1
 
 
 
 
65871c9
170c5f1
65871c9
a93106f
65871c9
a93106f
 
 
170c5f1
 
 
13454c6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
from io import StringIO
from typing import List

import pandas as pd
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from huggingface_hub import HfFolder
from pydantic import BaseModel
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline

# FastAPI application instance; the endpoints below are registered on it.
app = FastAPI()

# Cross-origin policy for browser clients.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# the most permissive combination. Browsers reject a literal "*" on
# credentialed requests, and Starlette appears to compensate by echoing the
# caller's Origin, which would let any website make credentialed calls —
# confirm this is intended, or list the real front-end domains instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # You can specify domains here
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The Hugging Face token is mandatory: it is passed to the tokenizer download
# and used as the Bearer token for every Inference API call, so the service
# refuses to start (at import time) when it is missing or empty.
hf_token = os.environ.get('HF_API_TOKEN')

if hf_token is None or hf_token == "":
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Load GPT-2 model and tokenizer once, at import time, so every request
# reuses the same objects. They are only used through `text_generator`, which
# preprocess_user_prompt calls to expand the user's description before the
# prompt is sent to Mixtral.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2, wrapping the objects
# loaded above rather than loading the checkpoint a second time by name.
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

def preprocess_user_prompt(user_prompt):
    """Expand a user's dataset description with a GPT-2 continuation.

    Runs the module-level GPT-2 ``text_generator`` pipeline on
    ``user_prompt``, asking for exactly one sequence, and returns that
    sequence's ``"generated_text"`` field.

    NOTE(review): in transformers, ``max_length`` bounds prompt plus
    continuation together, so descriptions near or above 50 tokens receive
    little or no added text; ``max_new_tokens`` may be what was intended —
    confirm before changing, since it alters the prompts sent to Mixtral.
    """
    sequences = text_generator(user_prompt, max_length=50, num_return_sequences=1)
    first = sequences[0]
    return first["generated_text"]

# Prompt sent to Mixtral for every generation call. format_prompt fills it via
# str.format, which substitutes every occurrence of a placeholder, so
# {description} and {columns} each appear twice in the final prompt: once in
# the task statement and again in the trailing "Description:/Columns:" section.
# The house-price block is a one-shot example of the expected CSV output.
# NOTE(review): the example output ends with a literal "..." line, and the
# prompt asks for neither a specific number of rows nor CSV-only output; the
# model may echo "..." or add prose around the CSV — consider tightening.
# NOTE(review): the text is not wrapped in Mixtral-Instruct's [INST] ... [/INST]
# chat format — confirm the raw prompt behaves as intended on this endpoint.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """

class DataGenerationRequest(BaseModel):
    """Request body for ``POST /generate/``.

    Attributes:
        description: Free-text description of the dataset to synthesize.
        columns: Column names the generated CSV must contain.
    """

    description: str
    # Typed as List[str] (not a bare ``list``) so non-string entries are
    # rejected by request validation instead of crashing the endpoint later
    # with an AttributeError on ``col.strip()``.
    columns: List[str]

# Authenticate with the Hugging Face Hub and load the Mixtral *tokenizer*.
# Only the tokenizer is loaded locally; text generation itself happens
# remotely through the Inference API (API_URL / generate_synthetic_data).
token = hf_token  # Use environment variable for the token
# Saves the token through huggingface_hub so later Hub calls are authenticated.
# NOTE(review): this writes the secret to huggingface_hub's on-disk token
# store as an import-time side effect, even though the token is also passed
# explicitly below; HfFolder may be deprecated in recent huggingface_hub
# releases — verify.
HfFolder.save_token(token)

# NOTE(review): tokenizer_mixtral is not referenced anywhere else in the code
# shown here; if nothing else imports it, this import-time download could go.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=token)

def format_prompt(description, columns):
    """Build the full Mixtral prompt for one generation call.

    The description is first expanded with GPT-2 (``preprocess_user_prompt``);
    the result and the comma-joined column names are then substituted into
    ``prompt_template``.
    """
    expanded_description = preprocess_user_prompt(description)
    return prompt_template.format(
        description=expanded_description,
        columns=",".join(columns),
    )

# Hosted Hugging Face Inference API endpoint for Mixtral-8x7B-Instruct.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

# Sent as the "parameters" object of every Inference API request: nucleus
# sampling (top_p 0.90) at temperature 0.8, at most 512 new tokens, and only
# the completion returned (not the prompt echoed back).
# NOTE(review): use_cache is meant to bypass the API's result cache, but the
# Inference API documents it as an "options" field (or request header), not a
# generation parameter — confirm it has the intended effect here.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
    return_full_text=False,
    use_cache=False,
)

def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of synthetic CSV text from the hosted Mixtral model.

    Args:
        description: Free-text dataset description; expanded by
            ``format_prompt`` before being sent.
        columns: Column names the CSV should contain.
        timeout: Seconds ``requests`` waits on the Inference API before
            giving up. Without a timeout a stalled connection blocks the
            request handler indefinitely.

    Returns:
        The ``generated_text`` string of the first result returned by the API.

    Raises:
        HTTPException: 500 if the request cannot be completed, the body is not
            JSON, the API reports an error, or the payload is not the expected
            non-empty ``[{"generated_text": ...}, ...]`` list.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}

    try:
        response = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=500, detail=f"Failed to reach the API: {exc}") from exc

    try:
        response_data = response.json()
    except ValueError as exc:
        # requests' JSON decode error is a ValueError subclass.
        raise HTTPException(status_code=500, detail="Failed to parse response from the API.") from exc

    # Error responses are a JSON object carrying an "error" key.
    if isinstance(response_data, dict) and 'error' in response_data:
        raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")

    # Anything other than a non-empty list whose first item is a dict with
    # "generated_text" is rejected explicitly, instead of failing with a bare
    # KeyError/IndexError on the indexing below.
    if (
        not isinstance(response_data, list)
        or not response_data
        or not isinstance(response_data[0], dict)
        or 'generated_text' not in response_data[0]
    ):
        raise HTTPException(status_code=500, detail="Unexpected API response format.")

    return response_data[0]["generated_text"]

def process_generated_data(csv_data, expected_columns):
    """Parse one batch of model-generated CSV text into a DataFrame.

    Line endings are normalised to ``\\n`` and surrounding whitespace is
    stripped from the header cells before the header is checked.

    Args:
        csv_data: Raw CSV text returned by the model, header row first.
        expected_columns: Column names the batch must contain (any order).

    Returns:
        A DataFrame whose columns are exactly ``expected_columns``, in that
        order. Callers append batches without per-batch headers, so a fixed
        column order keeps every row aligned with the output header.

    Raises:
        HTTPException: 500 if pandas cannot parse the text or it is empty.
        ValueError: If the parsed header does not match ``expected_columns``
            or contains duplicate names.
    """
    try:
        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse CSV data: {e}") from e

    # Models often emit headers like "a, b"; without stripping, " b" would
    # never match the expected name "b".
    df.columns = [str(name).strip() for name in df.columns]

    if df.columns.duplicated().any() or set(df.columns) != set(expected_columns):
        raise ValueError("Unexpected columns in the generated data.")

    # The set check ignores order; reorder so the batch matches the caller's layout.
    return df[list(expected_columns)]

def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate data in batches and concatenate it into one CSV buffer.

    Each batch is one ``generate_synthetic_data`` call followed by
    ``process_generated_data``. ``rows_per_generation`` only determines how
    many calls are made (``num_rows / rows_per_generation``, rounded up); the
    prompt does not request a row count, so the number of rows per batch is
    whatever the model returns.

    Args:
        description: Free-text description of the dataset.
        columns: Column names, in the order they should appear in the CSV.
        num_rows: Target number of rows, used to size the number of calls.
        rows_per_generation: Rows assumed per call; must be positive.

    Returns:
        A ``StringIO`` rewound to the start, holding one header row followed
        by the rows of every non-empty batch.

    Raises:
        ValueError: If ``rows_per_generation`` is not positive.
        HTTPException: 500 if no batch produced any rows. Errors raised by the
            per-batch helpers propagate unchanged.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer.")

    # Round up so a partial final batch is still requested
    # (e.g. num_rows=50 with 100 rows per call -> 1 call, not 0).
    full_batches, remainder = divmod(num_rows, rows_per_generation)
    num_batches = full_batches + (1 if remainder else 0)
    column_order = list(columns)
    csv_data_all = StringIO()

    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)

        if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
            # Write the header exactly once (with the first non-empty batch) and
            # force the requested column order, since later batches are appended
            # without their own header and must line up under it.
            df_synthetic.to_csv(
                csv_data_all,
                index=False,
                header=csv_data_all.tell() == 0,
                columns=column_order,
            )

    if csv_data_all.tell() > 0:  # Check if there's any data in the buffer
        csv_data_all.seek(0)  # Rewind the buffer to the beginning
        return csv_data_all
    raise HTTPException(status_code=500, detail="No valid data frames generated.")

@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a downloadable CSV file.

    The description and column names are whitespace-stripped and validated
    before any model is called. A column list that can never match the
    generated CSV header would otherwise only fail after all the slow
    generation calls have been made.

    Raises:
        HTTPException: 422 if the description is blank, no non-blank column
            name is given, or a column name is repeated; otherwise whatever
            ``generate_large_synthetic_data`` raises.
    """
    description = request.description.strip()
    # Blank entries ("", "  ") would become required columns that pandas can
    # never produce (it names empty header cells "Unnamed: N"), so drop them.
    columns = [col.strip() for col in request.columns if col.strip()]

    if not description:
        raise HTTPException(status_code=422, detail="Description must not be empty.")
    if not columns:
        raise HTTPException(status_code=422, detail="At least one non-empty column name is required.")
    if len(set(columns)) != len(columns):
        # pandas renames duplicate headers ("a", "a.1"), so they could never match.
        raise HTTPException(status_code=422, detail="Column names must be unique.")

    csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)

    # Return the CSV data as a downloadable file
    return StreamingResponse(
        csv_data,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"},
    )

@app.get("/")
def greet_json():
    """Root endpoint: respond with a fixed JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting