yakine committed
Commit 3dd5eb9 · verified · 1 Parent(s): 6ee338b

Update app.py

Files changed (1)
  1. app.py +24 -59
app.py CHANGED
@@ -1,43 +1,17 @@
  import gradio as gr
  import pandas as pd
  import requests
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline, AutoModelForCausalLM
- from huggingface_hub import HfFolder
  from io import StringIO
  import os
- import torch

  # Access the Hugging Face API token from environment variables
  hf_token = os.getenv('HF_API_TOKEN')

  if not hf_token:
      raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
- HfFolder.save_token(hf_token)

- # Set environment variable to avoid floating-point errors
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-
- # Load the tokenizer and model
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
-
- # Create a pipeline for text generation using GPT-2
- text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer)
-
- # Lazy loading function for Llama-3 model
- model_llama = None
- tokenizer_llama = None
-
- def load_llama_model():
-     global model_llama, tokenizer_llama
-     if model_llama is None:
-         model_name = "meta-llama/Meta-Llama-3.1-8B"
-         model_llama = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float16,  # Use FP16 for reduced memory
-             use_auth_token=hf_token
-         )
-         tokenizer_llama = AutoTokenizer.from_pretrained(model_name, token=hf_token)
+ # Set the inference endpoint URL
+ inference_endpoint = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B"

  # Define your prompt template
  prompt_template = """\
@@ -62,53 +36,45 @@ Columns:
  Output: """

  def preprocess_user_prompt(user_prompt):
-     generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1)[0]["generated_text"]
-     return generated_text
+     return user_prompt

  def format_prompt(description, columns):
      processed_description = preprocess_user_prompt(description)
      prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
      return prompt

- generation_params = {
-     "top_p": 0.90,
-     "temperature": 0.8,
-     "max_new_tokens": 512,
-     "return_full_text": False,
-     "use_cache": False
- }
-
  def generate_synthetic_data(description, columns):
      try:
-         # Load the Llama model only when generating data
-         load_llama_model()
-
-         # Prepare the input for the Llama model
+         # Format the prompt
          formatted_prompt = format_prompt(description, columns)

-         # Tokenize the prompt
-         inputs = tokenizer_llama(formatted_prompt, return_tensors="pt").to(model_llama.device)
+         # Send a POST request to the Hugging Face Inference API
+         headers = {
+             "Authorization": f"Bearer {hf_token}",
+             "Content-Type": "application/json"
+         }
+         data = {
+             "inputs": formatted_prompt,
+             "parameters": {
+                 "max_new_tokens": 512,
+                 "top_p": 0.90,
+                 "temperature": 0.8
+             }
+         }

-         # Generate synthetic data
-         with torch.no_grad():
-             outputs = model_llama.generate(
-                 **inputs,
-                 max_length=512,
-                 top_p=generation_params["top_p"],
-                 temperature=generation_params["temperature"],
-                 num_return_sequences=1
-             )
-
-         # Decode the generated output
-         generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
+         response = requests.post(inference_endpoint, json=data, headers=headers)
+
+         if response.status_code != 200:
+             return f"Error: {response.status_code}, {response.text}"

-         # Return the generated synthetic data
+         # Extract the generated text from the response
+         generated_text = response.json()[0]['generated_text']
          return generated_text
+
      except Exception as e:
          print(f"Error in generate_synthetic_data: {e}")
          return f"Error: {e}"

-
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
      data_frames = []
      num_iterations = num_rows // rows_per_generation
@@ -147,7 +113,6 @@ iface = gr.Interface(
      api_name="generate"  # Set the API name directly here
  )

-
  iface.api_name = "generate"

  # Run the Gradio app
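
Since the interface keeps a named "generate" endpoint, one way to smoke-test the change once the Space is deployed is through gradio_client. The sketch below is illustrative only: the Space id is a placeholder, and the interface is assumed to take a dataset description plus a comma-separated column list as its two inputs (the gr.Interface definition is not shown in this diff).

# Illustrative sketch, not part of the commit.
# Assumptions: the Space id is hypothetical; the interface is assumed to
# accept a description string and a comma-separated column list.
from gradio_client import Client

client = Client("yakine/synthetic-data-app")  # hypothetical Space id
result = client.predict(
    "Patients admitted to an emergency department",  # description (assumed input 1)
    "patient_id,age,heart_rate,blood_pressure",      # columns (assumed input 2)
    api_name="/generate",
)
print(result)  # whatever the interface returns, e.g. generated rows as text

Calling the hosted Inference API instead of loading GPT-2 and an 8B Llama checkpoint in-process keeps the Space itself lightweight, which is the apparent motivation for the +24/-59 change.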