|
import os
from io import StringIO
from typing import List

import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from huggingface_hub import HfFolder
from pydantic import BaseModel
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# A single Hugging Face token is read once and reused for both the Hub login
# and the Inference API calls below.
hf_token = os.getenv("HF_API_TOKEN")
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
HfFolder.save_token(hf_token)
|
# GPT-2 runs locally and is used only to expand the user's description before
# it is inserted into the Mixtral prompt.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained("gpt2")
model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")

text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.

Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format.

Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...

Description:
{description}
Columns:
{columns}
Output: """
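# Note: str.format fills every occurrence of {description} and {columns} above,
# so the description appears both in the instructions and in the trailing block.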
|
# The Mixtral tokenizer is loaded for reference (e.g. to inspect prompt length);
# generation itself is delegated to the hosted Inference API.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)

API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
    "return_full_text": False,
    "use_cache": False,
}
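# Note on the parameters above: sampling via top_p/temperature keeps repeated
# calls from producing identical rows, max_new_tokens bounds the length of each
# generated batch, and return_full_text=False returns only the completion
# rather than echoing the prompt.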
|
def preprocess_user_prompt(user_prompt):
    # Let GPT-2 elaborate on the raw description; max_length caps the combined
    # prompt-plus-continuation at 50 tokens.
    generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
    return generated_text
|
def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt
|
def generate_synthetic_data(description, columns):
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}
    response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
    # Surface HTTP-level failures (rate limits, model still loading) instead of
    # failing later on an unexpected JSON shape.
    response.raise_for_status()
    return response.json()[0]["generated_text"]
|
def process_generated_data(csv_data, expected_columns):
    try:
        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
        data = StringIO(cleaned_data)
        df = pd.read_csv(data, delimiter=',')
        if set(df.columns) != set(expected_columns):
            print(f"Unexpected columns in the generated data: {df.columns}")
            return None
        return df
    except pd.errors.ParserError as e:
        print(f"Failed to parse CSV data: {e}")
        return None
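# For example, a reply of "Size,Location,Number of Bedrooms,Price\n1200,Suburban,3,250000"
# parses into a one-row DataFrame, while a reply whose header row does not match
# the requested columns is discarded.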
|
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    data_frames = []
    for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)
        if df_synthetic is not None and not df_synthetic.empty:
            data_frames.append(df_synthetic)
        else:
            print("Skipping invalid generation.")
    if data_frames:
        return pd.concat(data_frames, ignore_index=True)
    print("No valid data frames to concatenate.")
    return pd.DataFrame(columns=columns)
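# Hypothetical usage sketch (illustrative values only):
# df = generate_large_synthetic_data(
#     "House listings with size, location, bedrooms, and price",
#     ["Size", "Location", "Number of Bedrooms", "Price"],
#     num_rows=200,
# )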
|
class GenerateRequest(BaseModel):
    description: str
    columns: List[str]
    num_rows: int = 1000


@app.post("/generate")
def generate(req: GenerateRequest):
    if not req.description or not req.columns:
        raise HTTPException(status_code=400, detail="Please provide 'description' and 'columns' in the request.")

    df_synthetic = generate_large_synthetic_data(req.description, req.columns, num_rows=req.num_rows)

    if df_synthetic is not None and not df_synthetic.empty:
        file_path = "synthetic_data.csv"
        df_synthetic.to_csv(file_path, index=False)
        return FileResponse(file_path, media_type="text/csv", filename="synthetic_data.csv")
    raise HTTPException(status_code=500, detail="Failed to generate a valid synthetic dataset.")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
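# Example request once the server is running (illustrative values):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"description": "House listings", "columns": ["Size", "Location", "Number of Bedrooms", "Price"], "num_rows": 100}' \
#        -o synthetic_data.csv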
|
|