|
import os
from io import StringIO
from typing import List

import pandas as pd
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from huggingface_hub import HfFolder
from pydantic import BaseModel
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
|
|
|
# Application instance; the route decorators in this module register on it.
app = FastAPI()

# CORS policy: every origin, method and header is accepted and credentials
# are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
# Hugging Face credential, read once from the environment at import time.
hf_token = os.environ.get('HF_API_TOKEN')

# Fail fast: the module cannot authenticate its API calls without a token,
# so a missing or empty value aborts the import.
if hf_token is None or hf_token == "":
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
|
|
|
|
# GPT-2 weights and tokenizer, loaded once at import time and shared by every
# request through the pipeline below.
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')

# Text-generation pipeline used by preprocess_user_prompt().
text_generator = pipeline(
    task="text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)
|
|
|
def preprocess_user_prompt(user_prompt, max_new_tokens=50):
    """Run the user's description through the GPT-2 text-generation pipeline.

    The generation budget is expressed with ``max_new_tokens`` rather than
    ``max_length``. ``max_length`` counts the prompt tokens as well, so a
    description of roughly 50 tokens or more left no room for generation
    (an error or a warning, depending on the transformers version).

    Args:
        user_prompt: Free-text dataset description supplied by the caller.
        max_new_tokens: Upper bound on the number of tokens GPT-2 may add to
            the prompt.

    Returns:
        The ``generated_text`` field of the first (and only) sequence
        returned by the pipeline.
    """
    sequences = text_generator(
        user_prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    return sequences[0]["generated_text"]
|
|
|
|
|
prompt_template = """\ |
|
You are an expert in generating synthetic data for machine learning models. |
|
Your task is to generate a synthetic tabular dataset based on the description provided below. |
|
Description: {description} |
|
The dataset should include the following columns: {columns} |
|
Please provide the data in CSV format. |
|
Example Description: |
|
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price' |
|
Example Output: |
|
Size,Location,Number of Bedrooms,Price |
|
1200,Suburban,3,250000 |
|
900,Urban,2,200000 |
|
1500,Rural,4,300000 |
|
... |
|
Description: |
|
{description} |
|
Columns: |
|
{columns} |
|
Output: """ |
|
|
|
class DataGenerationRequest(BaseModel):
    """Request body accepted by the ``/generate/`` endpoint."""

    # Free-text description of the dataset to synthesise.
    description: str
    # Names of the columns the generated CSV must contain. Typed as strings so
    # that non-string entries are dealt with by request validation instead of
    # failing later when the endpoint calls ``col.strip()`` on each entry.
    columns: List[str]
|
|
|
|
|
# Module-level alias of the Hugging Face token read from the environment.
token = hf_token

# Hand the token to huggingface_hub's HfFolder helper so it is stored for
# later Hub access.
HfFolder.save_token(token)

# Tokenizer of the Mixtral instruct model, fetched with the same token.
# It is not referenced by the rest of the code visible in this section.
tokenizer_mixtral = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=token,
)
|
|
|
def format_prompt(description, columns):
    """Build the full model prompt for one generation request.

    Args:
        description: Dataset description; it is first passed through
            ``preprocess_user_prompt``.
        columns: Iterable of column-name strings, joined with commas.

    Returns:
        ``prompt_template`` with both placeholders filled in.
    """
    return prompt_template.format(
        description=preprocess_user_prompt(description),
        columns=",".join(columns),
    )
|
|
|
# Hosted inference endpoint for the Mixtral instruct model.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

# Sent as the "parameters" object of every generation request.
# NOTE(review): the hosted Inference API documents `use_cache` as an entry of
# the separate "options" object -- confirm it has the intended effect here.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
    return_full_text=False,
    use_cache=False,
)
|
|
|
def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of CSV text from the hosted model.

    Args:
        description: Dataset description forwarded to ``format_prompt``.
        columns: Column names forwarded to ``format_prompt``.
        timeout: Seconds to wait for the HTTP call before giving up, so a
            stalled connection cannot block the worker indefinitely.

    Returns:
        The ``generated_text`` string of the first item in the API response.

    Raises:
        HTTPException: With status 500 when the request fails, the body is
            not JSON, the API reports an error, or the payload does not have
            the expected ``[{"generated_text": ...}]`` shape.
    """
    payload = {
        "inputs": format_prompt(description, columns),
        "parameters": generation_params,
    }

    try:
        response = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts become a regular HTTP error instead
        # of an unhandled exception.
        raise HTTPException(status_code=500, detail=f"Request to the API failed: {exc}") from exc

    try:
        response_data = response.json()
    except ValueError as exc:
        raise HTTPException(status_code=500, detail="Failed to parse response from the API.") from exc

    # Only a JSON object is treated as a possible error payload.
    if isinstance(response_data, dict) and 'error' in response_data:
        raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")

    # Validate the shape before indexing, so an empty list or an unexpected
    # object yields the intended error rather than IndexError/KeyError.
    first_item = response_data[0] if isinstance(response_data, list) and response_data else None
    if not isinstance(first_item, dict) or 'generated_text' not in first_item:
        raise HTTPException(status_code=500, detail="Unexpected API response format.")

    return first_item["generated_text"]
|
|
|
def process_generated_data(csv_data, expected_columns):
    """Parse model-generated CSV text and validate its columns.

    Args:
        csv_data: Raw CSV text; any mix of ``\\r\\n``, ``\\r`` and ``\\n`` line
            endings is accepted.
        expected_columns: Column names the header must contain (in any order).

    Returns:
        A DataFrame whose columns are exactly ``expected_columns`` in that
        order, so batches can be concatenated without their headers.

    Raises:
        HTTPException: With status 500 when pandas cannot parse the text,
            including the case where it contains no data at all.
        ValueError: When the parsed header does not match
            ``expected_columns``.
    """
    normalised = csv_data.replace('\r\n', '\n').replace('\r', '\n').strip()

    try:
        df = pd.read_csv(StringIO(normalised), delimiter=',')
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError is not a ParserError subclass, so it is listed
        # explicitly to cover blank model output.
        raise HTTPException(status_code=500, detail=f"Failed to parse CSV data: {e}") from e

    # Headers such as "Size, Price" parse with a leading space in the second
    # name; strip them so they are not reported as a mismatch.
    df.columns = [str(name).strip() for name in df.columns]

    expected = list(expected_columns)
    if set(df.columns) != set(expected):
        raise ValueError("Unexpected columns in the generated data.")

    # Reorder to the requested order: set equality alone does not fix it.
    return df[expected]
|
|
|
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate several batches and concatenate them into one CSV buffer.

    One model call is made per batch. The number of rows each call returns is
    decided by the model, so ``num_rows`` and ``rows_per_generation`` only
    determine how many calls are made.

    Args:
        description: Dataset description passed to each generation call.
        columns: Expected column names; also the column order of the output.
        num_rows: Target total number of rows.
        rows_per_generation: Nominal number of rows per model call.

    Returns:
        A ``StringIO`` positioned at the start, holding one header row
        followed by the rows of every non-empty batch.

    Raises:
        ValueError: If ``rows_per_generation`` is not positive.
        HTTPException: With status 500 when no batch produced any rows.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer.")

    # Ceiling division: a remainder still gets its own batch, and a request
    # smaller than one batch is not rounded down to zero calls.
    num_batches = -(-num_rows // rows_per_generation)
    column_order = list(columns)

    csv_data_all = StringIO()
    header_written = False

    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)

        if not isinstance(df_synthetic, pd.DataFrame) or df_synthetic.empty:
            continue

        # Write the header once, and pin the column order so rows from
        # different batches line up under it.
        df_synthetic.to_csv(
            csv_data_all,
            index=False,
            header=not header_written,
            columns=column_order,
        )
        header_written = True

    if not header_written:
        raise HTTPException(status_code=500, detail="No valid data frames generated.")

    csv_data_all.seek(0)
    return csv_data_all
|
|
|
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a CSV download.

    Args:
        request: Body carrying the dataset description and column names.

    Returns:
        A ``StreamingResponse`` of type ``text/csv`` served as the attachment
        ``generated_data.csv``.

    Raises:
        HTTPException: With status 422 when the description is blank, or the
            column list is empty or contains a blank name. Errors raised by
            the generation helpers propagate unchanged.
    """
    description = request.description.strip()
    # str() keeps the strip from failing on non-string entries.
    columns = [str(col).strip() for col in request.columns]

    # Reject unusable input before spending any remote generation calls on it.
    if not description:
        raise HTTPException(status_code=422, detail="Description must not be empty.")
    if not columns or not all(columns):
        raise HTTPException(status_code=422, detail="Columns must be a non-empty list of non-blank names.")

    csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)

    return StreamingResponse(
        csv_data,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"},
    )
|
|
|
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World!"}
    return greeting
|
|