from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import pandas as pd
import os
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
from io import StringIO
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfFolder
from tqdm import tqdm

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Load GPT-2 model and tokenizer (loaded at startup; note that the /generate/ endpoint
# below calls the remote Mixtral model and does not use this local pipeline)
tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

# Define prompt template for generating the dataset
prompt_template = """\
You are an AI designed exclusively for generating synthetic tabular datasets.
Task: Generate a synthetic dataset based on the provided description and column names.
Description: {description}
Columns: {columns}
Instructions:
- Output only the tabular data in valid CSV format.
- Include the header row followed strictly by the data rows.
- Do not include any additional text, explanations, comments, or code outside of the CSV data.
- Ensure that the values for each column are contextually appropriate based on the description.
- Do not alter the column names or add any new columns.
- Each row must contain data for all columns without any empty values unless specified.
- Format Example (do not include this line or the following example in your output):
Column1,Column2,Column3
Value1,Value2,Value3
Value4,Value5,Value6
"""

# Define generation parameters
generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 2048,
    "return_full_text": False,
    "use_cache": False
}


def format_prompt(description, columns):
    # Fill the prompt template with the user-supplied description and column names
    prompt = prompt_template.format(description=description, columns=",".join(columns))
    return prompt


def generate_synthetic_data(description, columns):
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}

    # Call the Mixtral model on the Hugging Face Inference API to generate data
    response = requests.post(
        "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
        headers={"Authorization": f"Bearer {hf_token}"},
        json=payload
    )

    if response.status_code == 200:
        return response.json()[0]["generated_text"]
    else:
        print(f"Error generating data: {response.status_code}, {response.text}")
        return None


def process_generated_data(csv_data):
    try:
        # Normalize line endings so the CSV parses consistently
        cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
        data = StringIO(cleaned_data)

        # Parse the cleaned CSV text into a DataFrame
        df = pd.read_csv(data)
        return df
    except pd.errors.ParserError as e:
        print(f"Failed to parse CSV data: {e}")
        return None


def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    # Generate the dataset in batches and concatenate the valid results
    data_frames = []
    for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        if generated_data:
            df_synthetic = process_generated_data(generated_data)
            if df_synthetic is not None and not df_synthetic.empty:
                data_frames.append(df_synthetic)
            else:
                print("Skipping invalid generation.")

    if data_frames:
        return pd.concat(data_frames, ignore_index=True)
    else:
        print("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)


class DataGenerationRequest(BaseModel):
    description: str
    columns: list[str]


@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]

    synthetic_df = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
    if synthetic_df.empty:
        return JSONResponse(content={"error": "No valid data generated"}, status_code=500)

    # Stream the generated DataFrame back to the client as a CSV attachment
    csv_buffer = StringIO()
    synthetic_df.to_csv(csv_buffer, index=False)
    csv_buffer.seek(0)

    return StreamingResponse(
        csv_buffer,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
    )


@app.get("/")
def greet_json():
    return {"Hello": "World!"}