|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
from pydantic import BaseModel |
|
import pandas as pd |
|
import os |
|
import requests |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline |
|
from io import StringIO |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from huggingface_hub import HfFolder |
|
from tqdm import tqdm |
|
|
|
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# CORS policy applied to every route: any origin, method and header is
# accepted, and credentialed requests are allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True -- confirm this permissive
# policy is intended before exposing the service publicly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
# Bearer token sent with every Hugging Face Inference API request.
# Importing the module fails fast when the variable is missing or empty.
hf_token = os.environ.get('HF_API_TOKEN')
if not hf_token:
    raise ValueError(
        "Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable."
    )
|
|
|
|
|
|
|
# Local GPT-2 tokenizer, model and text-generation pipeline, loaded at import.
# NOTE(review): nothing else in the visible source references text_generator
# (generation goes through the hosted Inference API instead) -- confirm
# whether this local model load is still needed.
_GPT2_CHECKPOINT = 'gpt2'
tokenizer_gpt2 = AutoTokenizer.from_pretrained(_GPT2_CHECKPOINT)
model_gpt2 = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)

text_generator = pipeline(
    "text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)
|
|
|
|
|
|
|
# Instruction prompt sent to the hosted model. The {description} and
# {columns} placeholders are filled in by format_prompt() via str.format();
# the leading backslash keeps the text from starting with an empty line.
prompt_template = """\
You are an AI designed exclusively for generating synthetic tabular datasets.
Task: Generate a synthetic dataset based on the provided description and column names.
Description: {description}
Columns: {columns}
Instructions:
- Output only the tabular data in valid CSV format.
- Include the header row followed strictly by the data rows.
- Do not include any additional text, explanations, comments, or code outside of the CSV data.
- Ensure that the values for each column are contextually appropriate based on the description.
- Do not alter the column names or add any new columns.
- Each row must contain data for all columns without any empty values unless specified.
- Format Example (do not include this line or the following example in your output):
Column1,Column2,Column3
Value1,Value2,Value3
Value4,Value5,Value6
"""
|
|
|
|
|
|
|
# Sampling settings sent as the "parameters" field of each inference request.
# NOTE(review): use_cache may belong under the Inference API's "options"
# field rather than "parameters" -- confirm it takes effect here.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=2048,
    return_full_text=False,
    use_cache=False,
)
|
|
|
def format_prompt(description, columns):
    """Fill the module-level prompt template for one generation request.

    Args:
        description: Text substituted for the ``{description}`` placeholder.
        columns: Iterable of column-name strings; they are joined with
            commas (no spaces) and substituted for ``{columns}``.

    Returns:
        The fully formatted prompt string.
    """
    column_list = ",".join(columns)
    return prompt_template.format(description=description, columns=column_list)
|
|
|
# Hosted Inference API endpoint for the Mixtral instruct model.
HF_INFERENCE_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

# Upper bound in seconds for a single inference call. Without a timeout,
# requests waits indefinitely on a stalled connection.
HF_REQUEST_TIMEOUT_SECONDS = 120


def generate_synthetic_data(description, columns):
    """Request one batch of generated text from the hosted model.

    Builds the prompt with ``format_prompt``, posts it together with
    ``generation_params`` to the Inference API, and extracts the
    ``generated_text`` field of the first element of the JSON response.

    Args:
        description: Dataset description inserted into the prompt.
        columns: Iterable of column names inserted into the prompt.

    Returns:
        The generated text, or ``None`` when the request fails, the server
        answers with a non-200 status, or the response body does not have
        the expected shape. Failures are reported on stdout.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}

    try:
        response = requests.post(
            HF_INFERENCE_URL,
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=HF_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts are treated like any other failed
        # generation so the caller can skip this batch and carry on.
        print(f"Error generating data: request failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error generating data: {response.status_code}, {response.text}")
        return None

    try:
        return response.json()[0]["generated_text"]
    except (ValueError, LookupError, TypeError) as exc:
        # A 200 response can still carry a body that is not JSON or not the
        # expected list of {"generated_text": ...} objects.
        print(f"Error generating data: unexpected response format: {exc}")
        return None
|
|
|
def process_generated_data(csv_data):
    """Parse model-generated CSV text into a DataFrame.

    Args:
        csv_data: Raw CSV text; the first line is taken as the header row.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the input is not
        a string, is blank, or cannot be parsed as CSV. Failures are
        reported on stdout.
    """
    # Blank or non-text payloads cannot contain a table; rejecting them here
    # avoids an AttributeError on None and an EmptyDataError from pandas.
    if not isinstance(csv_data, str) or not csv_data.strip():
        print("Failed to parse CSV data: no content to parse")
        return None

    # Normalise Windows and bare carriage-return line endings to "\n".
    cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')

    try:
        return pd.read_csv(StringIO(cleaned_data))
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError is not a subclass of ParserError, so it has to be
        # listed explicitly or it would propagate to the caller.
        print(f"Failed to parse CSV data: {e}")
        return None
|
|
|
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Build a dataset by issuing repeated generation requests.

    Each request's output is parsed with ``process_generated_data``, and
    every non-empty result is concatenated into one frame.

    Args:
        description: Dataset description forwarded to each request.
        columns: Column names forwarded to each request; also used as the
            columns of the empty fallback frame.
        num_rows: Target number of rows; together with
            ``rows_per_generation`` it sets how many requests are made.
        rows_per_generation: Nominal number of rows per request. It only
            sets the request count; the prompt built by ``format_prompt``
            does not include a row count.

    Returns:
        A ``pandas.DataFrame`` with all valid batches concatenated under a
        fresh index, or an empty frame with the requested columns when no
        batch produced usable data.
    """
    # Ceiling division: a partial final batch still gets its own request, so
    # e.g. num_rows=50 with rows_per_generation=100 makes one call, not zero.
    num_batches = -(-num_rows // rows_per_generation)

    data_frames = []
    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        if not generated_data:
            # The request failed or returned nothing; skip this batch.
            continue

        df_synthetic = process_generated_data(generated_data)
        if df_synthetic is not None and not df_synthetic.empty:
            data_frames.append(df_synthetic)
        else:
            print("Skipping invalid generation.")

    if data_frames:
        return pd.concat(data_frames, ignore_index=True)

    print("No valid data frames to concatenate.")
    return pd.DataFrame(columns=columns)
|
|
|
class DataGenerationRequest(BaseModel):
    """Request body accepted by the ``/generate/`` endpoint."""

    # Free-text description of the dataset; inserted into the model prompt.
    description: str

    # Names of the columns the generated CSV should contain.
    columns: list[str]
|
|
|
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a CSV download.

    Surrounding whitespace is stripped from the description and from each
    column name before generation.

    Args:
        request: Parsed request body with the description and column names.

    Returns:
        A ``StreamingResponse`` carrying the CSV as an attachment named
        ``generated_data.csv``, or a ``JSONResponse`` with status 500 and
        an ``error`` message when no valid data was generated.
    """
    cleaned_description = request.description.strip()
    cleaned_columns = [name.strip() for name in request.columns]

    frame = generate_large_synthetic_data(
        cleaned_description,
        cleaned_columns,
        num_rows=1000,
        rows_per_generation=100,
    )

    # Nothing usable came back from any batch: report a server-side failure.
    if frame.empty:
        return JSONResponse(content={"error": "No valid data generated"}, status_code=500)

    # Serialise without the index column; the buffer starts at position 0,
    # so the response streams it from the beginning.
    csv_buffer = StringIO(frame.to_csv(index=False))
    download_headers = {"Content-Disposition": "attachment; filename=generated_data.csv"}

    return StreamingResponse(
        csv_buffer,
        media_type="text/csv",
        headers=download_headers,
    )
|
|
|
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World!"}
    return greeting