yakine committed
Commit d3ea071 · verified · 1 Parent(s): 26ca4a0

Update app.py

Files changed (1): app.py +8 -10
app.py CHANGED

@@ -6,7 +6,7 @@ import transformers
 from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
 from huggingface_hub import HfFolder
 from io import StringIO
-from tqdm import tqdm # To display progress bar in Streamlit
+from tqdm import tqdm
 
 # Access the Hugging Face API token from environment variables
 hf_token = os.getenv('HF_API_TOKEN')
@@ -26,16 +26,14 @@ model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
 
 # Load the Llama-3 model and tokenizer once during startup
-tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
+tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
 model_llama = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Meta-Llama-3.1-8B",
-    torch_dtype= 'auto',
-    device_map= 'auto',
+    "meta-llama/Meta-Llama-3-8B",
+    torch_dtype='auto',
+    device_map='auto',
     token=hf_token
 )
 
-
-
 # Define your prompt template
 prompt_template = """\
 You are an expert in generating synthetic data for machine learning models.
@@ -59,7 +57,7 @@ Columns:
 Output: """
 
 def preprocess_user_prompt(user_prompt):
-    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1)[0]["generated_text"]
+    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1, truncation=True)[0]["generated_text"]
     return generated_text
 
 def format_prompt(description, columns):
@@ -80,8 +78,8 @@ def generate_synthetic_data(description, columns):
     # Prepare the input for the Llama model
     formatted_prompt = format_prompt(description, columns)
 
-    # Tokenize the prompt
-    inputs = tokenizer_llama(formatted_prompt, return_tensors="pt").to(model_llama.device)
+    # Tokenize the prompt with truncation enabled
+    inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True).to(model_llama.device)
 
     # Generate synthetic data
     with torch.no_grad():
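
For context, here is a minimal sketch of how the pieces touched by this commit fit together after the change. It assumes the rest of app.py is unchanged; the format_prompt() stub, the generate() call, and max_new_tokens=512 are illustrative assumptions, since those lines fall outside the hunks shown above.

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

hf_token = os.getenv('HF_API_TOKEN')

# Load the tokenizer and model once at startup (Meta-Llama-3-8B is a gated repo,
# so hf_token must belong to an account that has been granted access).
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype='auto',   # let transformers pick a suitable dtype
    device_map='auto',    # place weights on the available GPU(s)/CPU automatically
    token=hf_token
)

def format_prompt(description, columns):
    # Simplified stand-in for app.py's format_prompt(), which fills prompt_template
    return f"Description: {description}\nColumns: {', '.join(columns)}\nOutput: "

def generate_synthetic_data(description, columns):
    formatted_prompt = format_prompt(description, columns)

    # Truncate over-long prompts so they fit within the model's maximum input length
    inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True).to(model_llama.device)

    # Generate without gradient tracking; max_new_tokens=512 is an assumed value
    with torch.no_grad():
        outputs = model_llama.generate(**inputs, max_new_tokens=512)

    return tokenizer_llama.decode(outputs[0], skip_special_tokens=True)

The substantive changes in this commit are switching the checkpoint from Meta-Llama-3.1-8B to Meta-Llama-3-8B and passing truncation=True to both the GPT-2 pipeline call and the Llama tokenizer, which keeps long user prompts from overflowing the models' context windows.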