kaikaidai commited on
Commit
1fc8c4c
·
verified ·
1 Parent(s): b477b2d

Synced repo using 'sync_with_huggingface' Github Action

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
random_sample/arena_interface.py CHANGED
@@ -6,17 +6,13 @@ from dotenv import load_dotenv
6
  load_dotenv()
7
 
8
  from .gen_api_answer import (
9
- get_atla_response,
10
- get_selene_mini_response,
11
- parse_selene_mini_response
12
  )
13
 
14
  from .prompts import (
15
  DEFAULT_EVAL_CRITERIA,
16
  DEFAULT_EVAL_PROMPT,
17
- DEFAULT_EVAL_PROMPT_EDITABLE,
18
- ATLA_PROMPT,
19
- ATLA_PROMPT_WITH_REFERENCE
20
  )
21
 
22
  from .random_sample_generation import (
@@ -255,62 +251,35 @@ def create_arena_interface():
255
  ai_response,
256
  ground_truth,
257
  ):
258
- if model_choice == "Selene Mini":
259
- # Prepare prompt based on reference mode
260
- prompt_template = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
261
- prompt = prompt_template.format(
262
- human_input=human_input,
263
- ai_response=ai_response,
264
- eval_criteria=eval_criteria_text,
265
- ground_truth=ground_truth if use_reference else ""
266
- )
267
-
268
- print("\n=== Debug: Prompt being sent to Selene Mini ===")
269
- print(prompt)
270
- print("============================================\n")
271
-
272
- # Get and parse response
273
- raw_response = get_selene_mini_response(
274
- model_name="AtlaAI/Selene-1-Mini-Llama-3.1-8B",
275
- prompt=prompt,
276
- max_tokens=500,
277
- temperature=0.01
278
- )
279
- response = parse_selene_mini_response(raw_response)
280
- else:
281
- # Selene API logic
282
- prompt_data = {
283
- 'human_input': human_input,
284
- 'ai_response': ai_response,
285
- 'ground_truth': ground_truth if use_reference else None,
286
- 'eval_criteria': eval_criteria_text,
287
- }
288
-
289
- print("\n=== Debug: Prompt data being sent to Selene API ===")
290
- print(json.dumps(prompt_data, indent=2))
291
- print("============================================\n")
292
-
293
- response = get_atla_response(
294
- model_name="AtlaAI/Selene-1-Mini-Llama-3.1-8B",
295
- prompt=prompt_data,
296
- max_tokens=500,
297
- temperature=0.01
298
- )
299
-
300
- # Response now contains score and critique directly
301
- if isinstance(response, dict) and 'score' in response and 'critique' in response:
302
- score = str(response['score'])
303
- critique = response['critique']
304
- else:
305
- score = "Error"
306
- critique = str(response)
307
-
308
- return [
309
- score,
310
- critique,
311
- gr.update(value="Regenerate evaluation", variant="secondary", interactive=True),
312
- gr.update(value="🎲"),
313
- ]
314
 
315
  # Update the send_btn click handler with new input
316
  send_btn.click(
 
6
  load_dotenv()
7
 
8
  from .gen_api_answer import (
9
+ get_atla_response
 
 
10
  )
11
 
12
  from .prompts import (
13
  DEFAULT_EVAL_CRITERIA,
14
  DEFAULT_EVAL_PROMPT,
15
+ DEFAULT_EVAL_PROMPT_EDITABLE
 
 
16
  )
17
 
18
  from .random_sample_generation import (
 
251
  ai_response,
252
  ground_truth,
253
  ):
254
+ # Prepare prompt data for both models
255
+ prompt_data = {
256
+ 'human_input': human_input,
257
+ 'ai_response': ai_response,
258
+ 'ground_truth': ground_truth if use_reference else None,
259
+ 'eval_criteria': eval_criteria_text,
260
+ }
261
+
262
+ print("\n=== Debug: Prompt data being sent to Selene API ===")
263
+ print(json.dumps(prompt_data, indent=2))
264
+ print("============================================\n")
265
+
266
+ # Use appropriate model ID based on selection
267
+ model_id = "atla-selene-mini" if model_choice == "Selene Mini" else "atla-selene"
268
+
269
+ response = get_atla_response(
270
+ model_name=model_id,
271
+ prompt=prompt_data,
272
+ max_tokens=500,
273
+ temperature=0.01
274
+ )
275
+
276
+ # Format the response for display
277
+ score_text = f"{response['score']}/5"
278
+ critique_text = f"{response['critique']}"
279
+
280
+ # Return all required values for the UI components
281
+ return score_text, critique_text, gr.update(value="Regenerate evaluation", variant="secondary", interactive=True), gr.update(value="🎲", variant="primary")
282
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  # Update the send_btn click handler with new input
285
  send_btn.click(
random_sample/gen_api_answer.py CHANGED
@@ -7,10 +7,6 @@ from dotenv import load_dotenv
7
  from .prompts import (
8
  JUDGE_SYSTEM_PROMPT
9
  )
10
- from transformers import AutoTokenizer
11
- import requests
12
- import json
13
- import re
14
 
15
  load_dotenv()
16
 
@@ -63,7 +59,7 @@ def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, te
63
  evaluation_criteria = prompt.get('eval_criteria', '')
64
 
65
  response = atla_client.evaluation.create(
66
- model_id="atla-selene",
67
  model_input=model_input,
68
  model_output=model_output,
69
  expected_model_output=expected_output if expected_output else None,
@@ -76,73 +72,4 @@ def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, te
76
  "critique": response.result.evaluation.critique
77
  }
78
  except Exception as e:
79
- return f"Error with Atla model {model_name}: {str(e)}"
80
-
81
- def get_selene_mini_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
82
- """Get response from HF endpoint for Atla model"""
83
- try:
84
- headers = {
85
- "Accept": "application/json",
86
- "Authorization": f"Bearer {hf_api_key}",
87
- "Content-Type": "application/json"
88
- }
89
-
90
- # Create messages list for chat template
91
- messages = []
92
- if system_prompt:
93
- messages.append({"role": "system", "content": system_prompt})
94
- messages.append({"role": "user", "content": prompt})
95
-
96
- # Apply chat template
97
- model_id = "AtlaAI/Selene-1-Mini-Llama-3.1-8B"
98
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
99
- formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
100
-
101
- payload = {
102
- "inputs": formatted_prompt,
103
- "parameters": {
104
- "max_new_tokens": max_tokens,
105
- "return_full_text": False,
106
- "temperature": temperature,
107
- "seed": 42,
108
- "add_generation_prompt": True
109
- }
110
- }
111
-
112
- response = requests.post(
113
- "https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
114
- headers=headers,
115
- json=payload
116
- )
117
- return response.json()[0]["generated_text"]
118
- except Exception as e:
119
- return f"Error with Atla model {model_name}: {str(e)}"
120
-
121
- def parse_selene_mini_response(response_text):
122
- """Parse the response from Selene Mini to extract score and critique"""
123
- try:
124
- # Clean up the response text
125
- response_text = response_text.strip()
126
-
127
- # More flexible regex patterns
128
- reasoning_pattern = r'\*\*Reasoning:?\*\*\s*(.*?)(?=\*\*Result|$)'
129
- result_pattern = r'\*\*Result:?\*\*\s*(\d+)'
130
-
131
- reasoning_match = re.search(reasoning_pattern, response_text, re.DOTALL | re.IGNORECASE)
132
- result_match = re.search(result_pattern, response_text, re.IGNORECASE)
133
-
134
- if reasoning_match and result_match:
135
- critique = reasoning_match.group(1).strip()
136
- score = result_match.group(1)
137
- return {"score": score, "critique": critique}
138
- else:
139
- # If we can't parse it properly, let's return the raw response as critique
140
- return {
141
- "score": "Error",
142
- "critique": f"Failed to parse response. Raw response:\n{response_text}"
143
- }
144
- except Exception as e:
145
- return {
146
- "score": "Error",
147
- "critique": f"Error parsing response: {str(e)}\nRaw response:\n{response_text}"
148
- }
 
7
  from .prompts import (
8
  JUDGE_SYSTEM_PROMPT
9
  )
 
 
 
 
10
 
11
  load_dotenv()
12
 
 
59
  evaluation_criteria = prompt.get('eval_criteria', '')
60
 
61
  response = atla_client.evaluation.create(
62
+ model_id=model_name, # Will be either "atla-selene" or "atla-selene-mini"
63
  model_input=model_input,
64
  model_output=model_output,
65
  expected_model_output=expected_output if expected_output else None,
 
72
  "critique": response.result.evaluation.critique
73
  }
74
  except Exception as e:
75
+ return f"Error with Atla model {model_name}: {str(e)}"