yakine committed on
Commit
c6ded12
·
verified ·
1 Parent(s): 65793c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -83,25 +83,32 @@ def generate_synthetic_data(description, columns):
83
  # Load the Llama model only when generating data
84
  load_llama_model()
85
 
 
86
  formatted_prompt = format_prompt(description, columns)
87
- payload = {"inputs": formatted_prompt, "parameters": generation_params}
88
- headers = {"Authorization": f"Bearer {hf_token}"}
89
 
90
- response = requests.post(API_URL, headers=headers, json=payload)
 
91
 
92
- if response.status_code == 200:
93
- response_json = response.json()
94
- if isinstance(response_json, list) and len(response_json) > 0 and "generated_text" in response_json[0]:
95
- return response_json[0]["generated_text"]
96
- else:
97
- raise ValueError("Unexpected response format or missing 'generated_text' key")
98
- else:
99
- print(f"Error details: {response.text}")
100
- raise ValueError(f"API request failed with status code {response.status_code}: {response.text}")
 
 
 
 
 
 
101
  except Exception as e:
102
  print(f"Error in generate_synthetic_data: {e}")
103
  return f"Error: {e}"
104
 
 
105
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
106
  data_frames = []
107
  num_iterations = num_rows // rows_per_generation
 
83
  # Load the Llama model only when generating data
84
  load_llama_model()
85
 
86
+ # Prepare the input for the Llama model
87
  formatted_prompt = format_prompt(description, columns)
 
 
88
 
89
+ # Tokenize the prompt
90
+ inputs = tokenizer_llama(formatted_prompt, return_tensors="pt").to(model_llama.device)
91
 
92
+ # Generate synthetic data
93
+ with torch.no_grad():
94
+ outputs = model_llama.generate(
95
+ **inputs,
96
+ max_length=512,
97
+ top_p=generation_params["top_p"],
98
+ temperature=generation_params["temperature"],
99
+ num_return_sequences=1
100
+ )
101
+
102
+ # Decode the generated output
103
+ generated_text = tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
104
+
105
+ # Return the generated synthetic data
106
+ return generated_text
107
  except Exception as e:
108
  print(f"Error in generate_synthetic_data: {e}")
109
  return f"Error: {e}"
110
 
111
+
112
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
113
  data_frames = []
114
  num_iterations = num_rows // rows_per_generation