# Scraped from the Hugging Face Space file page for best/app.py (last committed by yakine).
# Commit c549aa3 "Update app.py" (verified) · 4.87 kB · page links: raw, history, blame
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import os
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
from io import StringIO
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfFolder
import logging
import random
import csv
app = FastAPI()
# CORS: any origin may call this API. Credentials are deliberately disabled:
# FastAPI's CORS documentation states that allow_origins=["*"] cannot be
# combined with allow_credentials=True. Nothing in this file reads cookies or
# client auth headers, so cross-origin credentials are not needed. To allow
# credentials, list explicit origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with specific domains to lock this down.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face API token, read from the environment. It is required: it is
# passed to the Mixtral tokenizer download and sent as the Bearer token on
# every Inference API request, so fail fast at import time when it is missing.
hf_token = os.environ.get('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
# Local GPT-2 checkpoint used by preprocess_user_prompt to expand user
# descriptions before they are inserted into the Mixtral prompt.
_GPT2_CHECKPOINT = 'gpt2'
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
model_gpt2 = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)
# Text-generation pipeline built from the already-loaded model/tokenizer objects.
text_generator = pipeline(
    task="text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)
def preprocess_user_prompt(user_prompt, max_new_tokens=40):
    """Expand the user's dataset description with a short GPT-2 continuation.

    Args:
        user_prompt: Raw description text from the request.
        max_new_tokens: Maximum number of tokens GPT-2 may append. This is
            used instead of ``max_length`` because ``max_length`` also counts
            the prompt's own tokens. Any description of that length or longer
            would leave no room to generate, and transformers warns or errors
            in that case.

    Returns:
        The ``generated_text`` of the single sequence returned by the
        text-generation pipeline. A blank (empty or whitespace-only) prompt is
        returned unchanged without calling the model.
    """
    # There is nothing to expand in a blank prompt; skip the model call rather
    # than inject GPT-2 text unrelated to any request.
    if not (user_prompt and user_prompt.strip()):
        return user_prompt
    outputs = text_generator(user_prompt, max_new_tokens=max_new_tokens, num_return_sequences=1)
    return outputs[0]["generated_text"]
# Prompt sent to the Mixtral Inference API, rendered by format_prompt() via
# str.format. It has two placeholders, each used twice:
#   {description} - the user description after GPT-2 expansion
#   {columns}     - the requested column names joined with ","
# The backslash after the opening quotes keeps the string from starting with a
# newline. The template ends with "Output: " so the model's continuation is
# expected to be the CSV. Any literal braces added here must be doubled
# ({{ }}) because of str.format.
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
# NOTE(review): tokenizer_mixtral is not referenced anywhere else in this file,
# yet loading it at import time requires network access and hf_token. Consider
# removing it if nothing outside this module imports it.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
def format_prompt(description, columns):
    """Render ``prompt_template`` for a single generation request.

    The description is first expanded by ``preprocess_user_prompt`` (GPT-2).
    The column names are joined with bare commas (no spaces) before being
    substituted into the template.
    """
    expanded_description = preprocess_user_prompt(description)
    joined_columns = ",".join(columns)
    return prompt_template.format(columns=joined_columns, description=expanded_description)
# Hosted Inference API endpoint for Mixtral-8x7B-Instruct.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
# Sampling settings sent as the request's "parameters" object.
# NOTE(review): the Inference API documents use_cache under a separate
# "options" object rather than "parameters"; confirm it has any effect here.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
    return_full_text=False,  # ask for the continuation only, not prompt + continuation
    use_cache=False,
)
def generate_synthetic_data(description, columns, *, timeout=120):
    """Request one batch of CSV text from the hosted Mixtral model.

    Args:
        description: Free-text dataset description (expanded by GPT-2 inside
            ``format_prompt``).
        columns: Column names substituted into the prompt.
        timeout: Seconds to wait for the Inference API before giving up.

    Returns:
        The ``generated_text`` of the first element of the API's JSON list.
        On any request error, timeout, non-2xx status, undecodable JSON or
        unexpected response shape, logs the error and returns a fixed
        placeholder CSV (``name,age,course,grade`` plus one sample row). Note
        that the placeholder's columns do not follow ``columns``.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}
    try:
        response = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            # Without a timeout, requests may wait indefinitely on a stalled connection.
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Expected shape: [{"generated_text": "..."}]. Anything else (an error
        # dict, an empty list, a list of strings) is a malformed response and
        # must not escape as KeyError/IndexError/TypeError.
        first = data[0] if isinstance(data, list) and data else None
        if isinstance(first, dict) and 'generated_text' in first:
            return first['generated_text']
        raise ValueError("Invalid response format from Hugging Face API.")
    except (requests.RequestException, ValueError) as e:
        logging.error("Error during API request or response processing: %s", e)
        return "name,age,course,grade\nSampleName,20,Course,0"
def _usable_rows(generated_data, expected_width):
    """Yield the data rows of one model response that fit the requested schema.

    Line endings are normalized to "\\n" before parsing. Rows with no
    non-whitespace cell (e.g. blank lines) are ignored. The first remaining row
    is treated as the model's echoed header and skipped. Rows whose field count
    differs from ``expected_width`` (e.g. a trailing "..." or prose) are
    dropped.
    """
    cleaned = generated_data.replace('\r\n', '\n').replace('\r', '\n')
    header_skipped = False
    for row in csv.reader(StringIO(cleaned)):
        # csv.reader yields [] for an empty line; without this check a leading
        # newline would be mistaken for the header and the real header kept.
        if not any(cell.strip() for cell in row):
            continue
        if not header_skipped:
            header_skipped = True
            continue
        if len(row) != expected_width:
            continue
        yield row


def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100,
                                  *, generate_fn=None, max_stalled_generations=3):
    """Collect model-generated CSV rows until ``num_rows`` data rows exist.

    Calls the generator repeatedly, parses each response as CSV, and appends
    the usable rows (see ``_usable_rows``) under a single header built from
    ``columns``.

    Args:
        description: Free-text description forwarded to the generator.
        columns: Column names. They are written as the header row, and every
            generated row must have this many fields.
        num_rows: Number of data rows to collect (header excluded).
        rows_per_generation: Not used by this implementation; kept so existing
            callers keep working.
        generate_fn: Callable ``(description, columns) -> str`` that returns
            CSV text. Defaults to ``generate_synthetic_data``.
        max_stalled_generations: Stop early once this many consecutive
            generations yield no usable rows, instead of looping forever.

    Returns:
        A ``StringIO`` rewound to offset 0, holding the header followed by at
        most ``num_rows`` data rows (fewer if generation stalled).
    """
    if generate_fn is None:
        generate_fn = generate_synthetic_data
    csv_buffer = StringIO()
    writer = csv.writer(csv_buffer)
    writer.writerow(columns)
    rows_generated = 0
    stalled = 0
    while rows_generated < num_rows:
        generated_data = generate_fn(description, columns) or ""
        rows_before = rows_generated
        for row in _usable_rows(generated_data, len(columns)):
            writer.writerow(row)
            rows_generated += 1
            if rows_generated >= num_rows:
                break
        if rows_generated > rows_before:
            stalled = 0
            continue
        # This generation produced nothing usable (e.g. header only, or the
        # API-error placeholder with a different column count).
        stalled += 1
        if stalled >= max_stalled_generations:
            logging.warning(
                "Stopping after %d consecutive generations without usable rows "
                "(%d of %d rows collected).", stalled, rows_generated, num_rows)
            break
    csv_buffer.seek(0)
    return csv_buffer
# Request body for POST /generate/. Comments (not a docstring) are used here
# because pydantic would publish a class docstring in the OpenAPI schema.
class DataGenerationRequest(BaseModel):
    # Free-text description of the dataset to synthesize; the endpoint strips
    # surrounding whitespace before use.
    description: str
    # Column names for the CSV header, in output order; the endpoint strips
    # each name before use.
    columns: list[str]
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic CSV dataset and return it as a file download.

    Strips whitespace from the description and from each column name, then
    asks ``generate_large_synthetic_data`` for 10,000 rows.

    Raises:
        HTTPException: 400 if the description is blank, if no columns were
            given, or if any column name is blank after stripping.
    """
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]
    # Reject input that cannot produce a meaningful dataset before spending
    # any GPT-2 or Inference API calls on it.
    if not description:
        raise HTTPException(status_code=400, detail="description must not be empty.")
    if not columns or not all(columns):
        raise HTTPException(status_code=400, detail="columns must be a non-empty list of non-blank names.")
    csv_buffer = generate_large_synthetic_data(description, columns, num_rows=10000)
    # Return the CSV data as a downloadable file
    return StreamingResponse(
        csv_buffer,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
    )
@app.get("/")
def greet_json():
    """Root endpoint: respond with a fixed JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting