|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
from pydantic import BaseModel |
|
import os |
|
import requests |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline |
|
from io import StringIO |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from huggingface_hub import HfFolder |
|
import logging |
|
import random |
|
import csv |
|
|
|
app = FastAPI()



# Allow browser frontends hosted on other origins to call this API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — confirm this is acceptable for production.
app.add_middleware(

    CORSMiddleware,

    allow_origins=["*"],

    allow_credentials=True,

    allow_methods=["*"],

    allow_headers=["*"],

)




# Hugging Face API token, read from the environment; used both for loading
# the gated Mixtral tokenizer and for authenticating Inference API requests.
hf_token = os.getenv('HF_API_TOKEN')



# Fail fast at import time: every generation request needs this token.
if not hf_token:

    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
|
|
|
|
# Load the small GPT-2 model locally at startup; it is only used to expand
# the user's free-text description (see preprocess_user_prompt) before the
# actual data generation is delegated to the remote Mixtral model.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')

model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')




# Local text-generation pipeline wrapping the GPT-2 model/tokenizer pair.
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
|
|
def preprocess_user_prompt(user_prompt):
    """Expand a raw user prompt with the local GPT-2 pipeline.

    Runs a single sampled continuation (capped at 50 tokens total) and
    returns its generated text.
    """
    candidates = text_generator(user_prompt, max_length=50, num_return_sequences=1)
    return candidates[0]["generated_text"]
|
|
|
|
|
# Instruction prompt sent to the remote Mixtral model.  {description} appears
# twice on purpose (once in the instructions, once after the worked example);
# str.format fills both occurrences with the same value.
prompt_template = """\

You are an expert in generating synthetic data for machine learning models.

Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format.

Example Description:

Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:

Size,Location,Number of Bedrooms,Price

1200,Suburban,3,250000

900,Urban,2,200000

1500,Rural,4,300000

...

Description:

{description}

Columns:

{columns}

Output: """



# NOTE(review): this tokenizer is loaded (downloading files at startup) but
# never referenced anywhere else in this file — confirm whether it is needed.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
|
|
def format_prompt(description, columns):
    """Build the full Mixtral prompt for a dataset description.

    The description is first expanded by the local GPT-2 model, then
    substituted into ``prompt_template`` together with the comma-joined
    column names.
    """
    expanded_description = preprocess_user_prompt(description)
    joined_columns = ",".join(columns)
    return prompt_template.format(description=expanded_description, columns=joined_columns)
|
|
|
# Hugging Face Inference API endpoint for the remote generator model.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"



# Sampling parameters forwarded with every Inference API request.
generation_params = {

    "top_p": 0.90,  # nucleus sampling cutoff

    "temperature": 0.8,

    "max_new_tokens": 512,

    "return_full_text": False,  # return only the continuation, not the prompt

    "use_cache": False

}
|
|
|
def generate_synthetic_data(description, columns):
    """Request one batch of synthetic CSV rows from the Inference API.

    Parameters
    ----------
    description : str
        Free-text description of the desired dataset.
    columns : list[str]
        Column names the generated CSV should contain.

    Returns
    -------
    str
        The model's generated CSV text, or a minimal fallback CSV (the
        requested header plus one placeholder row) when the request fails
        or the response is malformed.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}
    try:
        response = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=120,  # without a timeout a stalled connection hangs the request forever
        )
        response.raise_for_status()
        data = response.json()
        # On success the API returns a list of dicts; on failure it can return
        # a bare dict such as {"error": ...}, so guard the indexing as well.
        if isinstance(data, list) and data and 'generated_text' in data[0]:
            return data[0]['generated_text']
        raise ValueError("Invalid response format from Hugging Face API.")
    # KeyError/IndexError/TypeError cover malformed payload shapes that the
    # original only-ValueError handling would let escape.
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
        logging.error(f"Error during API request or response processing: {e}")
        # Fallback uses the caller's requested columns (the original hardcoded
        # "name,age,course,grade") so downstream CSV assembly, which drops the
        # first line as a header, stays consistent with the real header row.
        return ",".join(columns) + "\n" + ",".join("0" for _ in columns)
|
|
|
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Accumulate synthetic data rows into an in-memory CSV buffer.

    Repeatedly calls :func:`generate_synthetic_data` and appends the parsed
    data rows — the first row of each generated chunk is treated as the
    model-echoed header and dropped — until ``num_rows`` rows are collected.

    Parameters
    ----------
    description : str
        Free-text description of the desired dataset.
    columns : list[str]
        Column names; written once as the output header row.
    num_rows : int
        Total number of data rows to accumulate.
    rows_per_generation : int
        Currently unused; kept for backward signature compatibility.
        NOTE(review): wire this into the generation prompt or remove it.

    Returns
    -------
    io.StringIO
        Buffer rewound to position 0, containing header + data rows.
    """
    csv_buffer = StringIO()
    writer = csv.writer(csv_buffer)

    # Single authoritative header row for the whole output file.
    writer.writerow(columns)

    rows_generated = 0
    while rows_generated < num_rows:
        generated_data = generate_synthetic_data(description, columns)
        # Normalize all newline conventions before CSV parsing.
        cleaned_data = generated_data.replace('\r\n', '\n').replace('\r', '\n')
        reader = csv.reader(StringIO(cleaned_data))

        rows_before = rows_generated
        header_skipped = False
        for row in reader:
            if not header_skipped:
                # Drop the header the model echoes at the top of each chunk.
                header_skipped = True
                continue
            if not row:
                # Blank lines parse as empty rows; the original wrote and
                # counted them, corrupting the output and the row count.
                continue
            writer.writerow(row)
            rows_generated += 1
            if rows_generated >= num_rows:
                break

        if rows_generated == rows_before:
            # The chunk contributed no usable rows; stop instead of looping
            # (and hitting the remote API) forever on degenerate output.
            logging.warning("Generation chunk yielded no data rows; stopping early.")
            break

    csv_buffer.seek(0)
    return csv_buffer
|
|
|
class DataGenerationRequest(BaseModel):
    """Request body for POST /generate/."""

    # Free-text description of the dataset to synthesize.
    description: str

    # Column names to include in the generated CSV.
    columns: list[str]
|
|
|
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a 10,000-row synthetic CSV and stream it as a file download."""
    description = request.description.strip()
    cleaned_columns = [column.strip() for column in request.columns]

    buffer = generate_large_synthetic_data(description, cleaned_columns, num_rows=10000)

    download_headers = {"Content-Disposition": "attachment; filename=generated_data.csv"}
    return StreamingResponse(buffer, media_type="text/csv", headers=download_headers)
|
|
|
@app.get("/")
def greet_json():
    """Health-check endpoint returning a static greeting."""
    greeting = {"Hello": "World!"}
    return greeting
|
|