iyosha committed on
Commit
38c5e59
·
verified ·
1 Parent(s): 50df258

Upload 11 files

Browse files
app.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from uuid import uuid4
3
+ from datasets import load_dataset
4
+ from collections import Counter
5
+ from .configs import configs
6
+ from .clients import backend, logger
7
+ from .backend.helpers import get_random_session_samples
8
+
9
# Load the "test" split of the stressBench dataset once, at import time.
# The Hugging Face access token is taken from the settings object.
dataset = load_dataset("iyosha-huji/stressBench", token=configs.HF_API_TOKEN)["test"]
10
+
11
+
12
def human_eval_tab():
    """Build the "Evaluation" tab: login form, question pages and final page.

    Must be called inside an open ``gr.Blocks`` context. All per-user data is
    kept in ``gr.State`` objects; answers are written to ``backend`` each time
    the user submits one.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a flattened diff view -- confirm it against the original file, since
    it determines the page layout. The leading whitespace inside the
    triple-quoted final-page text is likewise reconstructed.
    """
    with gr.Tab(label="Evaluation"):
        # ==== State ====
        i = gr.State(-1)  # position in the session's sample list; -1 = no question shown
        selected_answer = gr.State(None)  # answer currently chosen for question i
        answers_dict = gr.State({})  # question position -> chosen answer
        logged_in = gr.State(False)
        session_id = gr.State(None)  # uuid4 string, stored as user_id in the backend
        session_sample_indices = gr.State([])  # dataset indices served in this session

        # === Login UI ===
        with gr.Group(visible=True) as login_group:
            gr.Markdown("### 🔐 Login to Continue")
            with gr.Row():
                username = gr.Text(label="Username", placeholder="Enter username")
                password = gr.Text(
                    label="Password", type="password", placeholder="Enter password"
                )
            login_error = gr.Markdown(
                "\u274c Incorrect login, try again.", visible=False
            )
            login_btn = gr.Button("Login")

        def login(u, p):
            """Check the shared credentials and, on success, start a session.

            Returns values for (logged_in, login_group, login_error,
            session_id, session_sample_indices), in that order.
            """
            if u == configs.USER_NAME and p == configs.USER_PASSWORD:
                new_session_id = str(uuid4())
                # Sample selection depends on the answers recorded so far.
                current_rows = backend.get_all_rows()
                sample_indices = get_random_session_samples(
                    current_rows, dataset, num_samples=30
                )
                logger.info(f"Session ID: {new_session_id}")
                return (
                    True,
                    gr.update(visible=False),
                    gr.update(visible=False),
                    new_session_id,
                    sample_indices,
                )
            else:
                # Keep the login form visible and reveal the error message.
                return False, gr.update(visible=True), gr.update(visible=True), None, []

        login_btn.click(
            fn=login,
            inputs=[username, password],
            outputs=[
                logged_in,
                login_group,
                login_error,
                session_id,
                session_sample_indices,
            ],
        )

        # === UI Elements ===
        next_btn = gr.Button("Start", visible=False)
        prev_btn = gr.Button("Previous Sample", visible=False)
        warning_msg = gr.Markdown(
            "<span style='color:red;'>\u26a0\ufe0f Please select an answer before continuing.</span>",
            visible=False,
        )

        with gr.Group(visible=False) as app_group:
            with gr.Group():
                gr.Markdown("<div align='center'><big><b>Instructions</b></big></div>")
                gr.Markdown(
                    "<div align='center'>You are given an audio sample and a question with 2 answer options.\n\nListen to the audio and select the correct answer from the options below.</div>"
                )

            with gr.Group(visible=False) as question_group:
                with gr.Row(show_progress=True):
                    with gr.Column(variant="compact"):
                        sample_info = gr.Markdown()
                        gr.Markdown("**Question:**")
                        question_md = gr.Markdown()
                        radio = gr.Radio(label="Answer:", interactive=True)
                    with gr.Column(variant="compact"):
                        audio_output = gr.Audio()

            with gr.Group(
                visible=False, elem_id="final_page"
            ) as final_group:  # Final page, not visible until the end
                gr.Markdown(
                    """
                    # 🎉 Thanks for your help!

                    You helped moving science forward 🤓

                    Your responses have been recorded.

                    You may now close this tab.
                    """
                )

        # === Logic ===
        def update_ui(i, answers, session_sample_indices):
            """Render question ``i``, or hide the question area when ``i == -1``.

            Returns values for (question_group, sample_info, question_md,
            audio_output, radio, selected_answer), in that order.
            """
            if i == -1:  # We haven't started yet
                return (
                    gr.update(visible=False),
                    "",
                    "",
                    gr.update(visible=False),
                    gr.update(visible=False),
                    None,
                )
            # show the question
            true_index = session_sample_indices[i]  # position in session -> dataset index
            sample = dataset[true_index]
            audio_data = (sample["audio"]["sampling_rate"], sample["audio"]["array"])
            # Restore an earlier choice when the user navigates back to this question.
            previous_answer = answers.get(i, None)
            return (
                gr.update(visible=True),
                f"<div align='center'>Sample <b>{i+1}</b> out of <b>{len(session_sample_indices)}</b></div>",
                "Out of the following answers, according to the speaker's stressed words, what is most likely the underlying intention of the speaker?",
                gr.Audio(value=audio_data, label="Audio:"),
                gr.Radio(
                    choices=sample["possible_answers"],
                    value=previous_answer,
                    label="Answer:",
                ),
                previous_answer,
            )

        def update_next_index(i, answer, answers, session_id, session_sample_indices):
            """Save the current answer (if any) and move to the next page.

            Returns values for (i, warning_msg, next_btn, answers_dict,
            final_group, prev_btn), in that order.
            """
            if answer is None and i != -1:  # if no answer is selected
                # show warning message
                return (
                    gr.update(),
                    gr.update(visible=True),
                    gr.update(),
                    answers,
                    gr.update(visible=False),
                    gr.update(visible=True),
                )

            if answer:  # if an answer is selected
                # save the answer to the backend
                answers[i] = answer
                true_index = session_sample_indices[i]
                sample = dataset[true_index]
                interp_id = sample["interpretation_id"]
                trans_id = sample["transcription_id"]
                user_id = session_id  # the session id doubles as the user id
                logger.info(
                    "saving answer to backend",
                    context={
                        "i": true_index,
                        "interp_id": interp_id,
                        "answer": answer,
                        "user_id": user_id,
                    },
                )
                # Update this user's existing row if there is one, otherwise append.
                if not backend.update_row(true_index, interp_id, user_id, answer):
                    backend.add_row(true_index, interp_id, trans_id, user_id, answer)

            if i + 1 == len(session_sample_indices):  # Last question just answered
                return (
                    -1,  # reset i to stop showing question
                    gr.update(visible=False),
                    gr.update(visible=False),
                    answers,
                    gr.update(visible=True),  # show final page
                    gr.update(visible=False),  # hide previous button
                )
            # go to the next question
            new_i = i + 1 if i + 1 < len(session_sample_indices) else 0
            return (
                new_i,
                gr.update(visible=False),
                gr.update(value="Submit answer and go to Next"),
                answers,
                gr.update(visible=False),
                gr.update(visible=True),
            )

        def update_prev_index(i):
            """Step back one question; returns values for (i, warning_msg)."""
            # prevent going back on the first question and on the start page
            if i <= 0:
                return i, gr.update(visible=False)
            # go back to the previous question
            else:
                return i - 1, gr.update(visible=False)

        def answer_change_callback(answer, i, answers):
            """Record the radio selection; returns (selected_answer, answers_dict)."""
            answers[i] = answer
            return answer, answers

        def login_callback(logged_in):
            """Return visibility updates for (app_group, next_btn, prev_btn, warning_msg)."""
            return (
                (
                    gr.update(visible=True),
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(visible=False),
                )
                if logged_in
                else (
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=False),
                )
            )

        # === Events ===
        next_btn.click(
            update_next_index,
            [i, selected_answer, answers_dict, session_id, session_sample_indices],
            [i, warning_msg, next_btn, answers_dict, final_group, prev_btn],
        )
        prev_btn.click(update_prev_index, i, [i, warning_msg])
        # Any change of the question position re-renders the question area.
        i.change(
            update_ui,
            [i, answers_dict, session_sample_indices],
            [
                question_group,
                sample_info,
                question_md,
                audio_output,
                radio,
                selected_answer,
            ],
        )
        radio.change(
            answer_change_callback,
            [radio, i, answers_dict],
            [selected_answer, answers_dict],
        )
        logged_in.change(
            login_callback, logged_in, [app_group, next_btn, prev_btn, warning_msg]
        )
242
+
243
+
244
# Password for the admin tab, read from the settings object.
ADMIN_PASSWORD = configs.ADMIN_PASSWORD
246
+
247
+
248
def get_admin_tab():
    """Build the password-protected "Admin Console" tab.

    Must be called inside an open ``gr.Blocks`` context. After the correct
    admin password is entered, the tab shows the majority-vote accuracy over
    the answered samples together with a few submission statistics.
    """
    # Function-scope import so this block does not depend on the file header.
    import hmac

    with gr.Tab("Admin Console"):
        admin_password = gr.Text(label="Enter Admin Password", type="password")
        check_btn = gr.Button("Enter")
        error_box = gr.Markdown("", visible=False)
        output_box = gr.Markdown("", visible=False)

        def calculate_majority_vote_accuracy(pw):
            """Validate *pw* and compute the statistics.

            Returns updates for (error_box, output_box), in that order.
            """
            # Constant-time comparison so response timing does not leak how
            # much of the password was correct.
            password_ok = hmac.compare_digest(
                str(pw or "").encode("utf-8"), ADMIN_PASSWORD.encode("utf-8")
            )
            if not password_ok:
                return (
                    gr.update(visible=True, value="\u274c Incorrect password."),
                    gr.update(visible=False),
                )

            df = backend.get_all_rows()
            if df.empty:
                return (
                    gr.update(visible=True, value="No data available."),
                    gr.update(visible=False),
                )

            # Most frequent answer per interpretation (groups are never empty).
            majority_answers = {
                interp_id: Counter(group["answer"]).most_common(1)[0][0]
                for interp_id, group in df.groupby("interpretation_id")
            }

            total = 0
            correct = 0

            # Only three label columns are needed; selecting them avoids
            # decoding the audio of every sample while scanning the dataset.
            label_rows = dataset.select_columns(
                ["interpretation_id", "label", "possible_answers"]
            )
            for sample in label_rows:
                interp_id = sample["interpretation_id"]
                if interp_id not in majority_answers:
                    continue
                correct_answer_text = sample["possible_answers"][sample["label"]]
                total += 1
                if majority_answers[interp_id] == correct_answer_text:
                    correct += 1

            acc = correct / total if total > 0 else 0
            # calculate total answers submitted
            total_answers = len(df)
            # The target is 3 answers per sample; never report a negative number.
            answers_to_go = max(0, (3 * len(dataset)) - total_answers)
            users_count = df["user_id"].nunique()

            # Markdown paragraphs are separated by blank lines.
            report = "\n\n".join(
                [
                    f"**Accuracy over answered samples:** {acc:.2%} ({correct}/{total})",
                    f"**Total answers submitted:** {total_answers}",
                    f"**Answers to go:** {answers_to_go}",
                    f"**Users count:** {users_count}",
                ]
            )
            # update the admin console
            return gr.update(visible=False), gr.update(visible=True, value=report)

        check_btn.click(
            fn=calculate_majority_vote_accuracy,
            inputs=admin_password,
            outputs=[error_box, output_box],
        )
309
+
310
+
311
# App UI: a single Blocks app holding the evaluation tab and the admin tab.
with gr.Blocks() as demo:
    human_eval_tab()
    get_admin_tab()
backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .backend import Backend
backend/backend.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gspread
3
+ import pandas as pd
4
+ from datetime import datetime
5
+ from oauth2client.service_account import ServiceAccountCredentials
6
+
7
+
8
class Backend:
    """Store and query evaluation answers in the first sheet of a Google Sheet.

    Rows are appended in the order ``index_in_dataset, interpretation_id,
    transcription_id, user_id, answer, timestamp``. Lookups rely on the header
    row (row 1) containing those column names.
    """

    def __init__(self, sheet_name: str, credentials: str):
        """Authorize with a service account and open *sheet_name*.

        Args:
            sheet_name: Title of the Google Sheet to open.
            credentials: Path to the service-account JSON key file.
        """
        # Context manager so the key file handle is closed right away.
        with open(credentials, encoding="utf-8") as key_file:
            creds_dict = json.load(key_file)
        scope = [
            "https://spreadsheets.google.com/feeds",
            "https://www.googleapis.com/auth/drive",
        ]
        service_credentials = ServiceAccountCredentials.from_json_keyfile_dict(
            creds_dict, scope
        )
        client = gspread.authorize(service_credentials)
        self.sheet = client.open(sheet_name).sheet1
        self.header = self.sheet.row_values(1)

    def get_all_rows(self) -> pd.DataFrame:
        """Return every data row as a DataFrame (no columns when the sheet is empty)."""
        records = self.sheet.get_all_records()
        return pd.DataFrame.from_records(records)

    def add_row(
        self, index_in_dataset, interpretation_id, transcription_id, user_id, answer
    ):
        """Append one answer row, stamped with the current naive-UTC ISO time."""
        timestamp = datetime.utcnow().isoformat()
        self.sheet.append_row(
            [
                index_in_dataset,
                interpretation_id,
                transcription_id,
                user_id,
                answer,
                timestamp,
            ]
        )

    def update_row(self, index_in_dataset, interpretation_id, user_id, new_answer):
        """Update the answer of the first row matching sample and user.

        Returns:
            True when a matching row exists (its answer and timestamp are
            rewritten only if the answer changed), False when there is none.
        """
        records = self.get_all_rows().to_dict("records")
        # Sheet rows are 1-indexed and row 1 is the header, so data starts at 2.
        for sheet_row, row in enumerate(records, start=2):
            is_match = (
                row["interpretation_id"] == interpretation_id
                and row["index_in_dataset"] == index_in_dataset
                and row["user_id"] == user_id
            )
            if not is_match:
                continue
            if row["answer"] != new_answer:
                self.sheet.update_cell(
                    sheet_row, self.header.index("answer") + 1, new_answer
                )
                self.sheet.update_cell(
                    sheet_row,
                    self.header.index("timestamp") + 1,
                    datetime.utcnow().isoformat(),
                )
            return True
        return False

    def get_answer_count(self, interpretation_id):
        """Return how many distinct users answered *interpretation_id*."""
        df = self.get_all_rows()
        # An empty sheet yields a frame without columns; indexing it would raise.
        if df.empty:
            return 0
        return df[df["interpretation_id"] == interpretation_id]["user_id"].nunique()
backend/helpers.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import pandas as pd
3
+
4
+
5
def get_random_session_samples(
    df: pd.DataFrame, dataset, num_samples=30, max_answers=3
):
    """Pick up to *num_samples* random dataset indices that still need answers.

    Args:
        df: Answers recorded so far. When non-empty it must provide the
            ``interpretation_id`` and ``user_id`` columns.
        dataset: Sized collection of samples. Either column access
            (``dataset["interpretation_id"]``) or row iteration yielding
            mappings with an ``"interpretation_id"`` key must work.
        num_samples: Maximum number of indices to return.
        max_answers: A sample is eligible while fewer than this many distinct
            users have answered its interpretation.

    Returns:
        A list of distinct indices into *dataset*, in random order. It is
        shorter than *num_samples* when fewer samples are eligible.
    """
    if df.empty:
        # Return any random sample from the dataset if no answers exist yet
        return random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    # Number of distinct users that answered each interpretation.
    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()

    # Read only the id column when the container supports it: iterating full
    # rows would materialise every field of every sample just to read one id.
    try:
        interpretation_ids = dataset["interpretation_id"]
    except (TypeError, KeyError, IndexError):
        interpretation_ids = [sample["interpretation_id"] for sample in dataset]

    eligible_indices = [
        index
        for index, interp_id in enumerate(interpretation_ids)
        if counts.get(interp_id, 0) < max_answers
    ]

    return random.sample(eligible_indices, min(num_samples, len(eligible_indices)))
clients.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .configs import configs
2
+ from .logger import Logger
3
+ from .backend import Backend
4
+
5
# Module-level singletons shared by the app.
# Logger whose records carry the service name in their context.
logger = Logger(context={"service": "Human Evaluation"}, use_context_var=False)
# Google Sheets backend; sheet name and credentials path come from settings.
backend = Backend(
    sheet_name=configs.GOOGLE_SHEET_NAME, credentials=configs.GOOGLE_SHEETS_CREDENTIALS
)
configs.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import Field
2
+ from pydantic_settings import BaseSettings
3
+ from pathlib import Path
4
+
5
+
6
class Settings(BaseSettings):
    """Application settings with placeholder defaults.

    NOTE(review): the password defaults are well-known placeholder strings;
    make sure real values are supplied before deployment.
    """

    HF_API_TOKEN: str = "your_hf_api_token"
    # Replace with your actual Google Sheet name
    GOOGLE_SHEET_NAME: str = "sheet name"
    # Replace with your actual Google Sheets credentials
    GOOGLE_SHEETS_CREDENTIALS: str = "path_to_creds"
    ADMIN_PASSWORD: str = "admin_password"
    USER_PASSWORD: str = "user_password"
    USER_NAME: str = "user_name"


# Single settings instance imported by the rest of the package.
configs = Settings()
logger/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .logger import Logger
logger/json_formatter.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from time import strftime, gmtime
4
+
5
+
6
class JsonFormatter(logging.Formatter):
    """Format log records as indented, ANSI-coloured JSON.

    The JSON object contains the message, level name, UTC time of the record,
    every key of the record's ``context`` attribute (if any) and, for ERROR
    records carrying exception info, the formatted traceback under ``error``.
    """

    grey = "\x1b[38;20m"
    green = "\x1b[33;32m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    # Colour prefix per standard logging level.
    FORMATS = {
        logging.DEBUG: grey,
        logging.INFO: green,
        logging.WARNING: yellow,
        logging.ERROR: red,
        logging.CRITICAL: bold_red,
    }

    def __init__(self):
        super().__init__()

    @staticmethod
    def serialize_to_json(data):
        """Return *data* as indented JSON, or a readable fallback on failure."""
        try:
            return json.dumps(data, indent=2)
        except Exception as e:  # a formatter must never raise
            return f"Failed to serialize data to JSON: {str(data)}\nError: {str(e)}"

    def format(self, record):
        """Render *record* as a coloured JSON string."""
        error_json = (
            {"error": self.formatException(record.exc_info)}
            if record.levelno == logging.ERROR and record.exc_info
            else {}
        )
        # Records emitted without ``extra={"context": ...}`` have no context
        # attribute; treat that (and None) as an empty context.
        context = getattr(record, "context", None) or {}
        json_record = {
            "message": record.getMessage(),
            "level": record.levelname,
            "logged_at": strftime("%Y-%m-%d %H:%M:%S", gmtime(record.created)),
            **context,
            **error_json,
        }
        try:
            # Custom levels have no colour; fall back to no prefix.
            color = self.FORMATS.get(record.levelno, "")
            json_log = f"{color}{json.dumps(json_record, indent=2)}{self.reset}"
            # Turn JSON escapes (\n, \uXXXX) back into real characters so
            # multi-line values such as tracebacks stay readable.
            colorful_json = json_log.encode("utf-8").decode("unicode_escape")
            # Characters outside the BMP come back as two separate surrogates,
            # which cannot be written to a UTF-8 stream; re-pair them.
            return colorful_json.encode("utf-16-le", "surrogatepass").decode(
                "utf-16-le", "replace"
            )
        except Exception as e:  # a formatter must never raise
            return (
                f"Failed to serialize data to JSON: {str(json_record)}\nError: {str(e)}"
            )
logger/logger.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import logging
3
+ import contextvars
4
+ from typing import Dict, Any
5
+ from .json_formatter import JsonFormatter
6
+
7
+
8
# Create a context variable to store request-specific information.
# No default is given, so ``context_var.get()`` without an argument raises
# LookupError until a value has been set in the current context.
context_var: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
    "context_dict"
)
12
+
13
+
14
class Logger:
    """JSON logger that attaches a context dictionary to every record.

    The context lives either on the instance (``use_context_var=False``) or in
    the module-level ``context_var`` (``use_context_var=True``). Context passed
    to a logging call is merged into the current context and stays there until
    ``reset_context`` is called.
    """

    def __init__(self, context=None, use_context_var=False):
        """Create the logger.

        Args:
            context: Base context added to every record; it is copied, so the
                caller's dictionary is never modified.
            use_context_var: Keep the context in ``context_var`` instead of on
                the instance.
        """
        self.logger = logging.getLogger("json_logger")
        # Two independent copies: the working context may grow, while the base
        # context must stay intact so that reset_context can restore it.
        self.base_context = dict(context or {})
        self.context = dict(self.base_context)
        self.use_context_var = use_context_var
        self._setup()

    def _setup(self):
        """Set the level, seed the context and install the console handler once."""
        self.logger.setLevel(logging.DEBUG)
        if self.use_context_var:
            context_var.set(dict(self.base_context))
        # The underlying logger is shared by name; avoid duplicate handlers.
        if not self.logger.handlers:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setFormatter(JsonFormatter())
            self.logger.addHandler(console_handler)

    def debug(self, data, context=None):
        """Log *data* at DEBUG level, merging *context* into the context."""
        self.log(logging.DEBUG, data, context)

    def info(self, data, context=None):
        """Log *data* at INFO level, merging *context* into the context."""
        self.log(logging.INFO, data, context)

    def warning(self, data, context=None):
        """Log *data* at WARNING level, merging *context* into the context."""
        self.log(logging.WARNING, data, context)

    def error(self, data, error=None, context=None):
        """Log *data* at ERROR level; *error* is passed on as ``exc_info``."""
        self.log(logging.ERROR, data, context, error)

    def log(self, level, data, context=None, error=None):
        """Merge *context* into the current context and emit one record."""
        self.update_context(context=context)
        self.logger.log(
            level,
            msg=data,
            # Snapshot, so later context updates cannot alter this record.
            extra={"context": dict(self._get_context())},
            exc_info=error,
        )

    def _get_context(self):
        """Return the active context dictionary."""
        if self.use_context_var:
            # Fall back to the base context where the variable was never set.
            return context_var.get(self.base_context)
        return self.context

    def reset_context(self):
        """Drop everything added since construction; keep the base context."""
        if self.use_context_var:
            context_var.set(dict(self.base_context))
        else:
            self.context = dict(self.base_context)

    def update_context(self, context):
        """Merge *context* (a mapping or None) into the active context."""
        additions = context or {}
        if self.use_context_var:
            context_var.set({**context_var.get(self.base_context), **additions})
        else:
            self.context.update(additions)
main.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .app import demo
2
+
3
+
4
def launch():
    """Start the Gradio app, listening on all interfaces at port 7860."""
    server_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**server_options)


if __name__ == "__main__":
    launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==5.16.2
2
+ pydantic==2.8.2
3
+ pydantic-settings==2.0.3
4
+ librosa==0.10.2.post1
5
+ soundfile==0.12.1
6
+ datasets==2.21.0
7
+ gspread==6.2.0
8
+ oauth2client==4.1.3
9
+ pandas==2.2.3