yakine committed on
Commit
7ce0c46
·
verified ·
1 Parent(s): a686c13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -7
app.py CHANGED
@@ -6,7 +6,9 @@ import os
6
  import torch
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from io import StringIO
9
- from accelerate import Accelerator
 
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  import re
12
 
@@ -15,7 +17,7 @@ app = FastAPI()
15
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
- allow_origins=["*"],
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
@@ -39,13 +41,32 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
39
  tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
40
  model_llama = AutoModelForCausalLM.from_pretrained(
41
  "meta-llama/Meta-Llama-3-8B",
42
- torch_dtype='float16',
43
- device_map='auto',
44
  token=hf_token
45
- ).to(device)
46
 
47
  # Define your prompt template
48
- prompt_template = """...""" # Your existing prompt template here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  class DataGenerationRequest(BaseModel):
51
  description: str
@@ -93,6 +114,7 @@ def generate_synthetic_data(description, columns):
93
  return f"Error: {e}"
94
 
95
  def clean_generated_text(generated_text):
 
96
  csv_match = re.search(r'(\n?([A-Za-z0-9_]+,)*[A-Za-z0-9_]+\n([^\n,]*,)*[^\n,]*\n*)+', generated_text)
97
 
98
  if csv_match:
@@ -103,8 +125,10 @@ def clean_generated_text(generated_text):
103
  return csv_text
104
 
105
  def process_generated_data(csv_data):
 
106
  cleaned_data = clean_generated_text(csv_data)
107
 
 
108
  data = StringIO(cleaned_data)
109
  df = pd.read_csv(data)
110
 
@@ -119,9 +143,12 @@ def generate_data(request: DataGenerationRequest):
119
  if "Error" in generated_data:
120
  return JSONResponse(content={"error": generated_data}, status_code=500)
121
 
 
122
  df_synthetic = process_generated_data(generated_data)
123
  return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
124
 
 
 
125
  @app.get("/")
126
  def greet_json():
127
- return {"Hello": "World!"}
 
6
  import torch
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from io import StringIO
9
+ from tqdm import tqdm
10
+ import accelerate
11
+ from accelerate import init_empty_weights, disk_offload
12
  from fastapi.middleware.cors import CORSMiddleware
13
  import re
14
 
 
17
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
+ allow_origins=["*"], # You can specify domains here
21
  allow_credentials=True,
22
  allow_methods=["*"],
23
  allow_headers=["*"],
 
41
  tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
42
  model_llama = AutoModelForCausalLM.from_pretrained(
43
  "meta-llama/Meta-Llama-3-8B",
44
+ torch_dtype='auto',
45
+ device_map='balanced',
46
  token=hf_token
47
+ ).to(device)
48
 
49
  # Define your prompt template
50
+ prompt_template = """\
51
+ You are an expert in generating synthetic data for machine learning models.
52
+ Your task is to generate a synthetic tabular dataset based on the description provided below.
53
+ Description: {description}
54
+ The dataset should include the following columns: {columns}
55
+ Please provide the data in CSV format with a minimum of 100 rows per generation.
56
+ Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
57
+ Example Description:
58
+ Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
59
+ Example Output:
60
+ Size,Location,Number of Bedrooms,Price
61
+ 1200,Suburban,3,250000
62
+ 900,Urban,2,200000
63
+ 1500,Rural,4,300000
64
+ ...
65
+ Description:
66
+ {description}
67
+ Columns:
68
+ {columns}
69
+ Output: """
70
 
71
  class DataGenerationRequest(BaseModel):
72
  description: str
 
114
  return f"Error: {e}"
115
 
116
  def clean_generated_text(generated_text):
117
+ # Extract CSV part using a regular expression
118
  csv_match = re.search(r'(\n?([A-Za-z0-9_]+,)*[A-Za-z0-9_]+\n([^\n,]*,)*[^\n,]*\n*)+', generated_text)
119
 
120
  if csv_match:
 
125
  return csv_text
126
 
127
  def process_generated_data(csv_data):
128
+ # Clean the generated data
129
  cleaned_data = clean_generated_text(csv_data)
130
 
131
+ # Convert to DataFrame
132
  data = StringIO(cleaned_data)
133
  df = pd.read_csv(data)
134
 
 
143
  if "Error" in generated_data:
144
  return JSONResponse(content={"error": generated_data}, status_code=500)
145
 
146
+ # Process the generated CSV data into a DataFrame
147
  df_synthetic = process_generated_data(generated_data)
148
  return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
149
 
150
+
151
+
152
  @app.get("/")
153
  def greet_json():
154
+ return {"Hello": "World!"}