# NOTE: the three lines below are residue from a web-page copy of this file
# (file size, per-line blame commit hashes, and the line-number gutter).
# File size: 4,929 Bytes
# 4d35d05 857e937 170c5f1 eaf3f9a 170c5f1 bad46c5 dbb92e9 170c5f1 1ab421e bad46c5 eaf3f9a eaf6f50 1ab421e 7af38c7 1ab421e 170c5f1 c7b1e29 170c5f1 c7b1e29 170c5f1 6bc7f9d c7b1e29 7ce0c46 6bc7f9d c7b1e29 7ce0c46 c7b1e29 6bc7f9d c7b1e29 3476a0f 6bc7f9d c7b1e29 170c5f1 0dc54fb bad46c5 170c5f1 c7b1e29 170c5f1 bad46c5 c7b1e29 d7ee939 c7b1e29 be8d2d5 c7b1e29 eaf3f9a be8d2d5 c7b1e29 eaf3f9a bad46c5 eaf3f9a c7b1e29 eaf3f9a bad46c5 d7821a8 c7b1e29 eaf3f9a 01c0141 31d4200 3476a0f 31d4200 eaf3f9a 65871c9 eaf3f9a 31d4200 be8d2d5 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import pandas as pd
import os
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
from io import StringIO
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfFolder
from tqdm import tqdm
app = FastAPI()

# Allow browser clients from any origin to call every endpoint.
# NOTE(review): combining wildcard origins with allow_credentials=True is
# reported to make Starlette echo back the caller's Origin, which lets any site
# send credentialed cross-origin requests. Confirm against the installed
# version. No endpoint in this file reads cookies or auth headers, so
# allow_credentials=False or an explicit origin list may be sufficient.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Bearer token for the hosted Inference API call. Fail at import time rather
# than on the first request when it is missing or empty.
hf_token = os.environ.get('HF_API_TOKEN') or None
if hf_token is None:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
# Local GPT-2 tokenizer, model and text-generation pipeline.
# NOTE(review): tokenizer_gpt2, model_gpt2 and text_generator are not referenced
# anywhere else in the code visible here; data generation goes through the
# hosted Mixtral endpoint. Loading them still costs a model download and memory
# at import time. Confirm no other module imports these names before removing.
tokenizer_gpt2, model_gpt2 = (
    AutoTokenizer.from_pretrained('gpt2'),
    GPT2LMHeadModel.from_pretrained('gpt2'),
)
text_generator = pipeline(
    task="text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)
# Prompt sent to the hosted model for every batch. format_prompt fills it via
# str.format: {description} is inserted verbatim and {columns} is the
# comma-joined list of column names.
# NOTE(review): the template never states how many rows to produce, so batch
# size is left to the model. Mixtral-8x7B-Instruct is usually prompted inside
# [INST] ... [/INST] tags, which this raw text omits; confirm output quality.
prompt_template = """\
You are an AI designed exclusively for generating synthetic tabular datasets.
Task: Generate a synthetic dataset based on the provided description and column names.
Description: {description}
Columns: {columns}
Instructions:
- Output only the tabular data in valid CSV format.
- Include the header row followed strictly by the data rows.
- Do not include any additional text, explanations, comments, or code outside of the CSV data.
- Ensure that the values for each column are contextually appropriate based on the description.
- Do not alter the column names or add any new columns.
- Each row must contain data for all columns without any empty values unless specified.
- Format Example (do not include this line or the following example in your output):
Column1,Column2,Column3
Value1,Value2,Value3
Value4,Value5,Value6
"""
# Generation settings sent as the "parameters" field of every Inference API
# request made by generate_synthetic_data.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=2048,
    return_full_text=False,
    # NOTE(review): the HF Inference API documents use_cache as an "options"
    # field, not a generation parameter. Confirm it takes effect when sent here.
    use_cache=False,
)
def format_prompt(description, columns):
    """Fill the module-level prompt template for one generation request.

    Args:
        description: Dataset description, inserted verbatim.
        columns: Iterable of column names; they are joined with commas.

    Returns:
        The complete prompt string.
    """
    header = ",".join(columns)
    return prompt_template.format(columns=header, description=description)
def generate_synthetic_data(description, columns, timeout=120):
    """Request one batch of synthetic CSV text from the hosted Mixtral model.

    Args:
        description: Free-text description of the dataset.
        columns: Column names the generated CSV should use.
        timeout: Seconds to wait for the Inference API before giving up.
            Without a timeout, requests can block indefinitely.

    Returns:
        The model's generated text (expected to be CSV), or None when the
        request fails, times out, returns a non-200 status, or returns a
        payload that lacks ``[0]["generated_text"]``.
    """
    formatted_prompt = format_prompt(description, columns)
    # The Inference API reads use_cache from the top-level "options" object,
    # not from the generation "parameters". Move it there so disabling the
    # cache actually applies. Otherwise the repeated identical prompts issued
    # by the batch loop could be answered from cache.
    parameters = dict(generation_params)
    use_cache = parameters.pop("use_cache", True)
    payload = {
        "inputs": formatted_prompt,
        "parameters": parameters,
        "options": {"use_cache": use_cache},
    }
    try:
        # Call Mixtral model to generate data
        response = requests.post(
            "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts are reported like any other failed
        # generation, so the caller can skip this batch.
        print(f"Error generating data: {exc}")
        return None
    if response.status_code != 200:
        print(f"Error generating data: {response.status_code}, {response.text}")
        return None
    try:
        return response.json()[0]["generated_text"]
    except (ValueError, LookupError, TypeError) as exc:
        # Non-JSON body, or JSON that is not a list of {"generated_text": ...}.
        print(f"Unexpected response payload: {exc}")
        return None
def process_generated_data(csv_data):
    """Parse model-generated CSV text into a DataFrame.

    Args:
        csv_data: Raw text returned by the model. It should be CSV with a
            header row, possibly wrapped in Markdown code fences.

    Returns:
        The parsed DataFrame, or None when the text is empty/blank or cannot
        be parsed as CSV.
    """
    if not csv_data:
        return None
    # Normalize Windows (\r\n) and old-Mac (\r) line endings.
    cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
    # Models often wrap CSV in ```csv ... ``` fences despite the prompt. Drop
    # the fence lines so the fence is not parsed as the header row.
    kept_lines = [
        line for line in cleaned_data.split('\n')
        if not line.strip().startswith('```')
    ]
    cleaned_data = '\n'.join(kept_lines).strip()
    if not cleaned_data:
        return None
    try:
        # skipinitialspace tolerates "a, b" style output, where a space after
        # the delimiter would otherwise leak into header names and values.
        return pd.read_csv(StringIO(cleaned_data), skipinitialspace=True)
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError is not a ParserError subclass; catch it too, so
        # unparseable text reports failure instead of crashing the request.
        print(f"Failed to parse CSV data: {e}")
        return None
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Build a dataset of up to ``num_rows`` rows from repeated model calls.

    The model is queried once per batch, with ``rows_per_generation`` as the
    assumed batch size. The prompt does not state a row count, so each batch's
    actual length is up to the model.

    Args:
        description: Free-text description of the dataset.
        columns: Column names the result must have, in order.
        num_rows: Target number of rows; the result is truncated to this.
        rows_per_generation: Assumed rows per model call, used to size the
            number of calls.

    Returns:
        A DataFrame with exactly ``columns``, holding at most ``num_rows`` rows.
        It is empty (with those columns) when no batch produced usable data.
    """
    # Ceiling division: with floor division, num_rows < rows_per_generation
    # produced zero model calls and always returned an empty frame.
    num_batches = -(-num_rows // rows_per_generation)
    data_frames = []
    for _ in tqdm(range(num_batches), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        if not generated_data:
            continue
        df_synthetic = process_generated_data(generated_data)
        if df_synthetic is None or df_synthetic.empty:
            print("Skipping invalid generation.")
            continue
        # Tolerate stray whitespace around header names.
        df_synthetic = df_synthetic.rename(columns=lambda name: str(name).strip())
        # The model may rename, drop, or add columns. Keep only batches that
        # contain every requested column, and only those columns, so concat
        # does not introduce NaN-padded extra columns.
        if not set(columns).issubset(df_synthetic.columns):
            print("Skipping generation with mismatched columns.")
            continue
        data_frames.append(df_synthetic[columns])
    if not data_frames:
        print("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)
    # Batch lengths vary, so cap the total at the requested row count.
    return pd.concat(data_frames, ignore_index=True).head(num_rows)
class DataGenerationRequest(BaseModel):
    # JSON request body accepted by POST /generate/.
    # Free-text description of the dataset to synthesize.
    description: str
    # Column names for the generated CSV header, in the desired order.
    columns: list[str]
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a downloadable CSV file.

    Raises:
        HTTPException: 400 when no non-blank column names are supplied.
    """
    description = request.description.strip()
    # Blank names would put empty fields into the prompt's column list and
    # could never match a parsed CSV header, so drop them.
    columns = [col.strip() for col in request.columns if col.strip()]
    if not columns:
        raise HTTPException(status_code=400, detail="At least one non-empty column name is required.")
    df = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
    if df.empty:
        return JSONResponse(content={"error": "No valid data generated"}, status_code=500)
    csv_buffer = StringIO()
    df.to_csv(csv_buffer, index=False)
    csv_buffer.seek(0)
    return StreamingResponse(
        csv_buffer,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"},
    )
@app.get("/")
def greet_json():
    """Root endpoint; returns a fixed JSON greeting."""
    return {"Hello": "World!"}