yakine committed on
Commit 65793c2 · verified · 1 Parent(s): aa5d67a

Update app.py

Files changed (1)
  1. app.py +53 -17
app.py CHANGED
@@ -1,10 +1,11 @@
 import gradio as gr
 import pandas as pd
-from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM
+import requests
+from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline, AutoModelForCausalLM
+from huggingface_hub import HfFolder
 from io import StringIO
 import os
 import torch
-from huggingface_hub import HfFolder
 
 # Access the Hugging Face API token from environment variables
 hf_token = os.getenv('HF_API_TOKEN')
@@ -12,17 +13,31 @@ hf_token = os.getenv('HF_API_TOKEN')
 if not hf_token:
     raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
 HfFolder.save_token(hf_token)
-# Load the GPT-2 tokenizer and model
-tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
+
+# Set environment variable to avoid floating-point errors
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+# Load the tokenizer and model
+tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 
-# Load the Llama3 model in sharded mode
-model_name = "meta-llama/Meta-Llama-3.1-8B"
-try:
-    model_llama = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16,
-                                                       load_in_8bit=True, token = hf_token) # use device_map for automatic sharding
-except OSError as e:
-    print(f"Error loading model: {e}")
+# Create a pipeline for text generation using GPT-2
+text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer)
+
+# Lazy loading function for Llama-3 model
+model_llama = None
+tokenizer_llama = None
+
+def load_llama_model():
+    global model_llama, tokenizer_llama
+    if model_llama is None:
+        model_name = "meta-llama/Meta-Llama-3.1-8B"
+        model_llama = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,  # Use FP16 for reduced memory
+            use_auth_token=hf_token
+        )
+        tokenizer_llama = AutoTokenizer.from_pretrained(model_name, token=hf_token)
 
 # Define your prompt template
 prompt_template = """\
@@ -47,21 +62,42 @@ Columns:
 Output: """
 
 def preprocess_user_prompt(user_prompt):
-    generated_text = model_gpt2.generate(tokenizer_gpt2.encode(user_prompt, return_tensors='pt'), max_length=60)[0]
-    return tokenizer_gpt2.decode(generated_text, skip_special_tokens=True)
+    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1)[0]["generated_text"]
+    return generated_text
 
 def format_prompt(description, columns):
     processed_description = preprocess_user_prompt(description)
     prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
     return prompt
 
+generation_params = {
+    "top_p": 0.90,
+    "temperature": 0.8,
+    "max_new_tokens": 512,
+    "return_full_text": False,
+    "use_cache": False
+}
+
 def generate_synthetic_data(description, columns):
     try:
+        # Load the Llama model only when generating data
+        load_llama_model()
+
         formatted_prompt = format_prompt(description, columns)
-        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt")
-        generated_output = model_llama.generate(**inputs, max_new_tokens=512)
-        generated_text = tokenizer_llama.decode(generated_output[0], skip_special_tokens=True)
-        return generated_text
+        payload = {"inputs": formatted_prompt, "parameters": generation_params}
+        headers = {"Authorization": f"Bearer {hf_token}"}
+
+        response = requests.post(API_URL, headers=headers, json=payload)
+
+        if response.status_code == 200:
+            response_json = response.json()
+            if isinstance(response_json, list) and len(response_json) > 0 and "generated_text" in response_json[0]:
+                return response_json[0]["generated_text"]
+            else:
+                raise ValueError("Unexpected response format or missing 'generated_text' key")
+        else:
+            print(f"Error details: {response.text}")
+            raise ValueError(f"API request failed with status code {response.status_code}: {response.text}")
     except Exception as e:
        print(f"Error in generate_synthetic_data: {e}")
        return f"Error: {e}"