yakine committed · Commit bbce957 · verified · 1 Parent(s): 2f46caf

Update app.py

Files changed (1):
  1. app.py +137 -14
app.py CHANGED
@@ -1,26 +1,149 @@
 import streamlit as st
-import transformers
-from transformers import AutoTokenizer, AutoModelForCausalLM
+import pandas as pd
+import requests
+from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM
+from huggingface_hub import HfFolder
+from io import StringIO
 import os
-from transformers import pipeline
 import torch
 
+# Access the Hugging Face API token from environment variables
 hf_token = os.getenv('HF_API_TOKEN')
 
-
-
-# Load the Llama 3.1 model and tokenizer
-model_name = "meta-llama/Meta-Llama-3.1-8B"
-tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", token=hf_token)
+if not hf_token:
+    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
+HfFolder.save_token(hf_token)
+
+# Disable TensorFlow oneDNN optimizations to avoid minor floating-point discrepancies
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+# Load the GPT-2 tokenizer and model (used to expand the user's description)
+tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
+
+# Create a pipeline for text generation using GPT-2
+text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer)
+
+# Lazy loading for the Llama 3.1 model
+model_llama = None
+tokenizer_llama = None
+
+def load_llama_model():
+    global model_llama, tokenizer_llama
+    if model_llama is None:
+        model_name = "meta-llama/Meta-Llama-3.1-8B"
+        model_llama = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,  # Use FP16 for reduced memory
+            token=hf_token
+        )
+        tokenizer_llama = AutoTokenizer.from_pretrained(model_name, token=hf_token)
+
+# Define the prompt template
+prompt_template = """\
+You are an expert in generating synthetic data for machine learning models.
+Your task is to generate a synthetic tabular dataset based on the description provided below.
+Description: {description}
+The dataset should include the following columns: {columns}
+Please provide the data in CSV format with a minimum of 100 rows per generation.
+Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
+Example Description:
+Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
+Example Output:
+Size,Location,Number of Bedrooms,Price
+1200,Suburban,3,250000
+900,Urban,2,200000
+1500,Rural,4,300000
+...
+Description:
+{description}
+Columns:
+{columns}
+Output: """
+
+def preprocess_user_prompt(user_prompt):
+    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1)[0]["generated_text"]
+    return generated_text
+
+def format_prompt(description, columns):
+    processed_description = preprocess_user_prompt(description)
+    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
+    return prompt
+
+generation_params = {
+    "top_p": 0.90,
+    "temperature": 0.8,
+    "max_new_tokens": 512,
+    "return_full_text": False,
+    "use_cache": False
+}
+
+def generate_synthetic_data(description, columns):
+    try:
+        # Load the Llama model only when generating data
+        load_llama_model()
+
+        # Prepare the input for the Llama model
+        formatted_prompt = format_prompt(description, columns)
+
+        # Tokenize the prompt
+        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt").to(model_llama.device)
+
+        # Generate synthetic data (do_sample=True so top_p/temperature take effect)
+        with torch.no_grad():
+            outputs = model_llama.generate(
+                **inputs,
+                max_new_tokens=generation_params["max_new_tokens"],
+                do_sample=True,
+                top_p=generation_params["top_p"],
+                temperature=generation_params["temperature"],
+                num_return_sequences=1
+            )
+
+        # Decode only the newly generated tokens, skipping the echoed prompt
+        generated_text = tokenizer_llama.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+
+        # Return the generated synthetic data
+        return generated_text
+    except Exception as e:
+        print(f"Error in generate_synthetic_data: {e}")
+        return f"Error: {e}"
+
+def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
+    data_frames = []
+    num_iterations = num_rows // rows_per_generation
+
+    for _ in range(num_iterations):
+        generated_data = generate_synthetic_data(description, columns)
+        if "Error" in generated_data:
+            return generated_data
+        df_synthetic = process_generated_data(generated_data)
+        data_frames.append(df_synthetic)
+
+    return pd.concat(data_frames, ignore_index=True)
+
+def process_generated_data(csv_data):
+    data = StringIO(csv_data)
+    df = pd.read_csv(data)
+    return df
 
 # Streamlit app interface
-st.title("Llama 3.1 Text Generator")
-prompt = st.text_area("Enter a prompt:", "Once upon a time")
+st.title("Synthetic Data Generator")
+description = st.text_input("Description", "e.g., Generate a dataset for predicting students' grades")
+columns = st.text_input("Columns (comma-separated)", "e.g., name, age, course, grade")
 
 if st.button("Generate"):
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_length=512, top_p=0.9, temperature=0.8)
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    st.write(generated_text)
-
+    description = description.strip()
+    columns = [col.strip() for col in columns.split(',')]
+    df_synthetic = generate_large_synthetic_data(description, columns)
+
+    if isinstance(df_synthetic, str) and "Error" in df_synthetic:
+        st.error(df_synthetic)  # Display the error message if any
+    else:
+        st.success("Synthetic Data Generated!")
+        st.dataframe(df_synthetic)  # Display the generated DataFrame
+        st.download_button(
+            label="Download CSV",
+            data=df_synthetic.to_csv(index=False),
+            file_name="synthetic_data.csv",
+            mime="text/csv"
+        )
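
One review note on the parsing step: process_generated_data feeds the model's completion straight into pandas, so a run only succeeds if the decoded text is bare CSV (a header line plus rows) with no surrounding prose. The prompt also asks for no duplicate rows, but pd.concat alone does not enforce that across the ten generation batches. Below is a minimal, self-contained sketch of a more defensive variant; parse_csv_block is a hypothetical helper, not part of this commit:

from io import StringIO
import pandas as pd

def parse_csv_block(text: str) -> pd.DataFrame:
    # Keep only lines that look like CSV rows: drop stray prose or a
    # trailing "Output:" echo before handing the text to pandas.
    lines = [ln for ln in text.strip().splitlines() if "," in ln]
    return pd.read_csv(StringIO("\n".join(lines)))

# Two batches with one overlapping row, deduplicated after concatenation.
batch1 = "Size,Location,Number of Bedrooms,Price\n1200,Suburban,3,250000\n900,Urban,2,200000"
batch2 = "Size,Location,Number of Bedrooms,Price\n900,Urban,2,200000\n1500,Rural,4,300000"
df = pd.concat([parse_csv_block(batch1), parse_csv_block(batch2)], ignore_index=True)
df = df.drop_duplicates(ignore_index=True)  # enforce the prompt's "no duplicates" condition
print(df)  # 3 unique rows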
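
For quick verification outside the Space, the generation pipeline can also be exercised without the UI. This is a hypothetical smoke test, assuming HF_API_TOKEN is set and the account has access to meta-llama/Meta-Llama-3.1-8B; note that importing app loads GPT-2 at import time and runs the module-level Streamlit calls in bare mode, which only emits warnings outside a Streamlit session:

from app import generate_large_synthetic_data

# Two batches of ~100 rows each keeps the smoke test short.
df = generate_large_synthetic_data(
    "Generate a dataset for predicting students' grades",
    ["name", "age", "course", "grade"],
    num_rows=200,
    rows_per_generation=100,
)
print(df.head())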