alisartazkhan committed
Commit f35f09f · verified · Parent: cc8489d

Update talk_arena/audio_collection.py

Files changed (1): talk_arena/audio_collection.py (+179 −469)
talk_arena/audio_collection.py CHANGED
@@ -1,22 +1,13 @@
- import argparse
- import asyncio
  import os
- import random
- import textwrap
- import time
  import uuid
-
- import gradio as gr
  import numpy as np
  import soundfile as sf
  import xxhash
- from datasets import Audio
- from dotenv import load_dotenv
- from openai import OpenAI
  from huggingface_hub import upload_file, HfApi
-
- import talk_arena.streaming_helpers as sh
- from talk_arena.db_utils import TinyThreadSafeDB

  # Load environment variables
  load_dotenv()
@@ -27,44 +18,71 @@ os.makedirs("outputs", exist_ok=True)
  # Initialize Hugging Face API client
  hf_api = HfApi(token=os.getenv("HF_TOKEN"))
  DATASET_REPO = "alisartazkhan/audioLLM_judge"
- CATEGORY = "pilot_tempo_control_2"
- COUNTER = 3
- CODE = "C1BDJUET"
- CAT_DESC = "An interactive study that tests how well audio models follow voice prompts with changing tempo. Create your own prompts and compare model responses!"
  resampler = Audio(sampling_rate=16_000)

- def parse_args():
-     parser = argparse.ArgumentParser(description="Talk Arena Demo")
-     parser.add_argument("--free_only", action="store_true", help="Only use free models")
-     return parser.parse_args()
-
- args = parse_args()
-
- if gr.NO_RELOAD:  # Prevents Re-init during hot reloading
-     # Transcription Disabled for Public Interface
-     # asr_pipe = pipeline(
-     #     task="automatic-speech-recognition",
-     #     model="openai/whisper-large-v3-turbo",
-     #     chunk_length_s=30,
-     #     device="cuda:1",
-     # )
-
-     anonymous = True

-     gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o")
-     gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp")
-     competitor_info = [
-         (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"),
-         (sh.gradio_gen_factory(gemini2_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"),
-     ]

-     resp_generators = [generator for generator, _, _ in competitor_info]
-     model_shorthand = [shorthand for _, shorthand, _ in competitor_info]
-     model_name = [full_name for _, _, full_name in competitor_info]
-     all_models = list(range(len(model_shorthand)))

- # Function to upload file to HF dataset repository
  def upload_to_hf(local_path, repo_path):
      try:
          upload_file(
              path_or_fileobj=local_path,
@@ -73,456 +91,148 @@ def upload_to_hf(local_path, repo_path):
              repo_type="dataset",
              token=os.getenv("HF_TOKEN")
          )
-         print(f"Uploaded file: {local_path} to Hugging Face repository at {repo_path}")
          return True
      except Exception as e:
          print(f"Error uploading file to HF: {e}")
          return False

-
- async def pairwise_response_async(audio_input, state, model_order):
-     if audio_input == None:
-         raise StopAsyncIteration(
-             "",
-             "",
-             gr.Button(visible=False),
-             gr.Button(visible=False),
-             gr.Button(visible=False),
-             state,
-             audio_input,
-             None,
-             None,
-             None,
          )
-     spinner_id = 0
-     spinners = ["◐ ", "◓ ", "◑", "◒"]
-     spinner = spinners[0]
-     gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]]
-     latencies = [{}, {}]  # Store timing info for each model
-     resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)]
-     tts_resps = [gr.Audio(), gr.Audio()]
-     error_in_model = False

-     # Get a unique hash for this audio input
      sr, y = audio_input
-     x = xxhash.xxh32(bytes(y)).hexdigest()

-     for order, generator in enumerate(gen_pair):
-         start_time = time.time()
-         first_token = True
-         total_length = 0
-         try:
-             async for local_resp in generator(audio_input, order):
-                 total_length += 1
-                 if first_token:
-                     latencies[order]["time_to_first_token"] = time.time() - start_time
-                     first_token = False
-                 resps[order] = local_resp
-                 spinner = spinners[spinner_id]
-                 spinner_id = (spinner_id + 1) % 4
-                 yield (
-                     gr.Button(
-                         value=spinner + " Generating Responses " + spinner,
-                         interactive=False,
-                         variant="primary",
-                     ),
-                     resps[0],
-                     resps[1],
-                     tts_resps[0],
-                     tts_resps[1],
-                     gr.Button(visible=False),
-                     gr.Button(visible=False),
-                     gr.Button(visible=False),
-                     state,
-                     audio_input,
-                     None,
-                     None,
-                     latencies,
-                 )
-             latencies[order]["total_time"] = time.time() - start_time
-             latencies[order]["response_length"] = total_length
-         except Exception as e:
-             print(f"Error in model {order+1}: {e}")
-             error_in_model = True
-             resps[order] = gr.Textbox(
-                 info=f"<strong>Error thrown by Model {order+1} API</strong>",
-                 value="" if first_token else resps[order]._constructor_args[0]["value"],
-                 visible=True,
-                 label=f"Model {order+1}",
-             )
-             yield (
-                 gr.Button(
-                     value=spinner + " Generating Responses " + spinner,
-                     interactive=False,
-                     variant="primary",
-                 ),
-                 resps[0],
-                 resps[1],
-                 tts_resps[0],
-                 tts_resps[1],
-                 gr.Button(visible=False),
-                 gr.Button(visible=False),
-                 gr.Button(visible=False),
-                 state,
-                 audio_input,
-                 None,
-                 None,
-                 latencies,
-             )
-
-         # Process and save audio
-         y = y.astype(np.float32)
-         y /= np.max(np.abs(y))
-         a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
-
-         # Create a unique identifier
-         unique_id = str(uuid.uuid4())[:8]
-         local_filename = f"outputs/{x}_resp{order}_{unique_id}.wav"
-
-         # Save locally first
-         sf.write(local_filename, a["array"], a["sampling_rate"], format="wav")
-
-         # Upload to HF dataset
-         upload_to_hf(
-             local_filename,
-             f"{CATEGORY}/{x}_resp{order}_{unique_id}.wav"
-         )
-
-         # Generate TTS response
-         try:
-             tts_options = {
-                 "model": "gpt-4o-mini-tts",
-                 "voice": "alloy",
-                 "input": resps[order].__dict__["_constructor_args"][0]["value"],
-                 "response_format": "wav",
-             }
-             abytes = OpenAI(api_key=os.environ["OPENAI_API_KEY"]).audio.speech.create(**tts_options).content
-             tts_resps[order] = gr.Audio(
-                 value=abytes,
-                 visible=True,
-             )
-         except Exception as e:
-             print(f"Error generating TTS: {e}")
-             tts_resps[order] = gr.Audio(visible=False)
-
-         latencies[order]["total_time"] = time.time() - start_time
-         latencies[order]["response_length"] = total_length

-     print("Latency data:", latencies)
-     yield (
-         gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False),
-         resps[0],
-         resps[1],
-         tts_resps[0],
-         tts_resps[1],
-         gr.Button(visible=not error_in_model),
-         gr.Button(visible=not error_in_model),
-         gr.Button(visible=not error_in_model),
-         responses_complete(state),
-         audio_input,
-         gr.Textbox(visible=False),
-         gr.Audio(visible=False),
-         latencies,
-     )
-
-
- def on_page_load(state, model_order):
-     if state == 0:
-         # gr.Info(
-         #     "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant,"
-         #     " or ChatGPT for."
-         # )
-         state = 1
-         model_order = random.sample(all_models, 2) if anonymous else model_order
-     return state, model_order
-
-
- def recording_complete(state):
-     if state == 1:
-         # gr.Info(
-         #     "Once you submit your recording, you'll receive responses from different models. This might take a second."
-         # )
-         state = 2
      return (
-         gr.Button(value="Starting Generation", interactive=False, variant="primary"),
-         state,
      )

-
- def responses_complete(state):
-     if state == 2:
-         gr.Info(
-             "Give us your feedback! Mark which model gave you the best response so we can understand the quality of"
-             " these different voice assistant models."
-         )
-         state = 3
-     return state
-
-
- class UploadableDB(TinyThreadSafeDB):
-     def __init__(self, filename):
-         super().__init__(filename)
-         self.filename = filename

-     async def upload_db(self):
-         try:
-             # Upload the JSON database file to HF
-             upload_to_hf(
-                 self.filename,
-                 f"{CATEGORY}/{self.filename}"
-             )
-             print(f"Successfully uploaded DB file {self.filename} to HF dataset")
-             return True
-         except Exception as e:
-             print(f"Error uploading DB file to HF: {e}")
-             return False
-
-
- def clear_factory(button_id):
-     async def clear(audio_input, model_order, pref_counter, reasoning, latency):
-         textbox1 = gr.Textbox(visible=False)
-         textbox2 = gr.Textbox(visible=False)
-         if button_id != None:
-             sr, y = audio_input
-             x = xxhash.xxh32(bytes(y)).hexdigest()
-             await db.insert(
-                 {
-                     "audio_hash": x,
-                     "outcome": button_id,
-                     "model_a": model_shorthand[model_order[0]],
-                     "model_b": model_shorthand[model_order[1]],
-                     "why": reasoning,
-                     "model_a_latency": latency[0],
-                     "model_b_latency": latency[1],
-                 }
-             )
-             # Upload the updated database to HF after each insertion
-             await db.upload_db()
-
-         pref_counter += 1
-         model_a = model_name[model_order[0]]
-         model_b = model_name[model_order[1]]
-
-         counter_text = f"# {pref_counter}/{COUNTER} Preferences Submitted"
-         if pref_counter >= COUNTER:
-             counter_text = f"# Completed! Completion Code: {CODE}"
-         if anonymous:
-             model_order = random.sample(all_models, 2)
          return (
-             model_order,
-             gr.Button(
-                 value="Record Audio to Submit Again!",
-                 interactive=False,
-                 visible=True,
-             ),
-             gr.Button(visible=False),
              gr.Button(visible=False),
              gr.Button(visible=False),
-             None,
-             textbox1,
-             textbox2,
-             gr.Audio(visible=False),
-             gr.Audio(visible=False),
-             pref_counter,
-             counter_text,
-             gr.Textbox(visible=False),
-             gr.Audio(visible=False),
          )

-     return clear
-
-
- def transcribe(transc, voice_reason):
-     if transc is None:
-         transc = ""
-     transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
-     return transc, gr.Audio(value=None)
-
  theme = gr.themes.Soft(
-     primary_hue=gr.themes.Color(
-         c100="#82000019",
-         c200="#82000033",
-         c300="#8200004c",
-         c400="#82000066",
-         c50="#8200007f",
-         c500="#8200007f",
-         c600="#82000099",
-         c700="#820000b2",
-         c800="#820000cc",
-         c900="#820000e5",
-         c950="#820000f2",
-     ),
-     secondary_hue="rose",
-     neutral_hue="stone",
  )

- import os
-
- css_path = os.path.join(os.path.dirname(__file__), "styles.css")
- with open(css_path, "r") as css_file:
-     custom_css = css_file.read()
-
-
- # Initialize our custom database class instead of the original one
- db = UploadableDB("audio_out_votes.json")
-
- with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo:
-     submitted_preferences = gr.State(0)
-     state = gr.State(0)
-     model_order = gr.State([])
-     latency = gr.State([])
-     with gr.Row():
-         counter_text = gr.Markdown(
-             f"# 0/{COUNTER} Preferences Submitted.\n Follow the pop-up tips to submit your first preference."
-         )
-         category_description_text = gr.Markdown(CAT_DESC)
-     with gr.Row():
-         audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input")
-
-     with gr.Row(equal_height=True):
-         with gr.Column(scale=1):
-             out1 = gr.Textbox(visible=False, lines=5, autoscroll=True)
-             audio_out1 = gr.Audio(visible=False)
-         with gr.Column(scale=1):
-             out2 = gr.Textbox(visible=False, lines=5, autoscroll=True)
-             audio_out2 = gr.Audio(visible=False)
-
-     with gr.Row():
-         btn = gr.Button(value="Record Audio to Submit!", interactive=False)
-
-     with gr.Row(equal_height=True):
-         reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4)
-         reason_record = gr.Audio(
-             sources=["microphone"],
-             interactive=True,
-             streaming=False,
-             label="Speak to transcribe!",
-             visible=False,
-             type="filepath",
-             # waveform_options={"show_recording_waveform": False},
-             scale=1,
-         )
-
-     with gr.Row():
-         best1 = gr.Button(value="Model 1 is better", visible=False)
-         tie = gr.Button(value="Tie", visible=False)
-         best2 = gr.Button(value="Model 2 is better", visible=False)
-
-     with gr.Row():
-         contact = gr.Markdown("")
-
-     # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record])
-     audio_input.stop_recording(
-         recording_complete,
-         [state],
-         [btn, state],
-     ).then(
-         fn=pairwise_response_async,
-         inputs=[audio_input, state, model_order],
-         outputs=[
-             btn,
-             out1,
-             out2,
-             audio_out1,
-             audio_out2,
-             best1,
-             best2,
-             tie,
-             state,
-             audio_input,
-             reason,
-             reason_record,
-             latency,
-         ],
-     )
-     audio_input.start_recording(
-         lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"),
-         None,
-         btn,
-     )
-     best1.click(
-         fn=clear_factory(0),
-         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
-         outputs=[
-             model_order,
-             btn,
-             best1,
-             best2,
-             tie,
-             audio_input,
-             out1,
-             out2,
-             audio_out1,
-             audio_out2,
-             submitted_preferences,
-             counter_text,
-             reason,
-             reason_record,
-         ],
      )
-     tie.click(
-         fn=clear_factory(0.5),
-         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
-         outputs=[
-             model_order,
-             btn,
-             best1,
-             best2,
-             tie,
-             audio_input,
-             out1,
-             out2,
-             audio_out1,
-             audio_out2,
-             submitted_preferences,
-             counter_text,
-             reason,
-             reason_record,
-         ],
      )
-     best2.click(
-         fn=clear_factory(1),
-         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
-         outputs=[
-             model_order,
-             btn,
-             best1,
-             best2,
-             tie,
-             audio_input,
-             out1,
-             out2,
-             audio_out1,
-             audio_out2,
-             submitted_preferences,
-             counter_text,
-             reason,
-             reason_record,
-         ],
      )
-     audio_input.clear(
-         clear_factory(None),
-         [audio_input, model_order, submitted_preferences, reason, latency],
-         [
-             model_order,
-             btn,
-             best1,
-             best2,
-             tie,
-             audio_input,
-             out1,
-             out2,
-             audio_out1,
-             audio_out2,
-             submitted_preferences,
-             counter_text,
-             reason,
-             reason_record,
-         ],
      )
-     demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order])

  if __name__ == "__main__":
-     demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
 
  import os
  import uuid
+ import json
  import numpy as np
+ import gradio as gr
  import soundfile as sf
  import xxhash
  from huggingface_hub import upload_file, HfApi
+ from dotenv import load_dotenv
+ from datasets import Audio

  # Load environment variables
  load_dotenv()

  # Initialize Hugging Face API client
  hf_api = HfApi(token=os.getenv("HF_TOKEN"))
  DATASET_REPO = "alisartazkhan/audioLLM_judge"
+ CATEGORY = "pilot_tempo_control_3"
+ MAX_RECORDINGS = 10  # Number of prompts to record

  resampler = Audio(sampling_rate=16_000)

+ # Load the prompts from a JSON file
+ prompt_path = os.path.join(os.path.dirname(__file__), "prompts.json")
+ with open(prompt_path, "r") as f:
+     prompts_data = json.load(f)
+ PROMPTS = prompts_data["prompts"]
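+ # prompts.json is expected to have this shape (illustrative placeholders, not the real study prompts):
+ #   {"prompts": ["<prompt text 1>", "<prompt text 2>", "..."]}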

+ # Create a JSON database to track uploads
+ class UploadTracker:
+     def __init__(self, filename="recording_tracker.json"):
+         self.filename = filename
+         self.data = []
+
+         # Create file if it doesn't exist
+         if not os.path.exists(filename):
+             with open(filename, "w") as f:
+                 json.dump([], f)
+         else:
+             # Load existing data
+             with open(filename, "r") as f:
+                 self.data = json.load(f)
+
+     def add_recording(self, prompt_index, audio_hash, filename):
+         """Add a record of an uploaded recording"""
+         record = {
+             "prompt_index": prompt_index,
+             "audio_hash": audio_hash,
+             "filename": filename,
+             "record_id": str(uuid.uuid4()),  # unique id for this record
+         }
+         self.data.append(record)
+
+         # Save to file
+         with open(self.filename, "w") as f:
+             json.dump(self.data, f, indent=2)
+
+         # Upload tracker file to HF
+         self.upload_tracker()
+
+         return record
+
+     def upload_tracker(self):
+         """Upload the tracker JSON to Hugging Face"""
+         try:
+             upload_file(
+                 path_or_fileobj=self.filename,
+                 path_in_repo=f"{CATEGORY}/{self.filename}",
+                 repo_id=DATASET_REPO,
+                 repo_type="dataset",
+                 token=os.getenv("HF_TOKEN")
+             )
+             print("Uploaded tracker file to Hugging Face")
+             return True
+         except Exception as e:
+             print(f"Error uploading tracker file: {e}")
+             return False

+ # Initialize the tracker
+ tracker = UploadTracker()
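+ # Example entry appended by tracker.add_recording (illustrative values only):
+ #   {"prompt_index": 0, "audio_hash": "1a2b3c4d",
+ #    "filename": "pilot_tempo_control_3/prompt0_1a2b3c4d_9f8e7d6c.wav", "record_id": "..."}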
 
 
  def upload_to_hf(local_path, repo_path):
+     """Upload a file to the Hugging Face dataset repository"""
      try:
          upload_file(
              path_or_fileobj=local_path,
              path_in_repo=repo_path,
              repo_id=DATASET_REPO,
              repo_type="dataset",
              token=os.getenv("HF_TOKEN")
          )
+         print(f"Uploaded file: {local_path} to Hugging Face at {repo_path}")
          return True
      except Exception as e:
          print(f"Error uploading file to HF: {e}")
          return False
 
+ def on_submit(audio_input, prompt_index):
+     """Handle the submission of a recorded audio prompt"""
+     if audio_input is None:
+         return (
+             gr.Markdown(f"# Recording {prompt_index + 1}/{MAX_RECORDINGS}"),
+             gr.Markdown("## Please record the following prompt:"),
+             gr.Markdown(f"### {PROMPTS[prompt_index]}"),
+             gr.Audio(value=None, label="Record your response"),
+             gr.Button("Submit Recording", interactive=False),
+             gr.Button("Next Prompt", visible=False),
+             prompt_index,
          )

+     # Process the audio
      sr, y = audio_input

+     # Generate a hash for this audio
+     audio_hash = xxhash.xxh32(bytes(y)).hexdigest()

+     # Normalize audio
+     y = y.astype(np.float32)
+     y /= np.max(np.abs(y)) if np.max(np.abs(y)) > 0 else 1.0
+
+     # Resample to 16kHz
+     a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
+
+     # Create unique filename
+     unique_id = str(uuid.uuid4())[:8]
+     local_filename = f"outputs/prompt{prompt_index}_{audio_hash}_{unique_id}.wav"
+
+     # Save locally
+     sf.write(local_filename, a["array"], a["sampling_rate"], format="wav")
+
+     # Upload to HF dataset
+     hf_path = f"{CATEGORY}/prompt{prompt_index}_{audio_hash}_{unique_id}.wav"
+     upload_to_hf(local_filename, hf_path)
+
+     # Add to tracker
+     tracker.add_recording(prompt_index, audio_hash, hf_path)
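+     # After this call, the dataset repo holds (illustrative layout):
+     #   pilot_tempo_control_3/prompt<i>_<hash>_<id>.wav   -- one file per submission
+     #   pilot_tempo_control_3/recording_tracker.json      -- refreshed after every upload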
+
+     # Show success message
      return (
+         gr.Markdown(f"# Recording {prompt_index + 1}/{MAX_RECORDINGS}"),
+         gr.Markdown("## Recording successfully uploaded!"),
+         gr.Markdown(f"### {PROMPTS[prompt_index]}"),
+         gr.Audio(value=None, label="Record your response"),
+         gr.Button("Submit Recording", interactive=False),
+         gr.Button("Next Prompt", visible=True),
+         prompt_index,
      )
 
+ def next_prompt(prompt_index):
+     """Move to the next prompt"""
+     prompt_index += 1
+
+     # Check if we've gone through all prompts
+     if prompt_index >= min(len(PROMPTS), MAX_RECORDINGS):
          return (
+             gr.Markdown("# All recordings complete!"),
+             gr.Markdown("## Thank you for your participation."),
+             gr.Markdown("### You have completed all prompts."),
+             gr.Audio(visible=False),
              gr.Button(visible=False),
              gr.Button(visible=False),
+             prompt_index,
          )
+
+     # Display the next prompt
+     return (
+         gr.Markdown(f"# Recording {prompt_index + 1}/{MAX_RECORDINGS}"),
+         gr.Markdown("## Please record the following prompt:"),
+         gr.Markdown(f"### {PROMPTS[prompt_index]}"),
+         gr.Audio(value=None, label="Record your response", sources=["microphone"]),
+         gr.Button("Submit Recording", interactive=False),
+         gr.Button("Next Prompt", visible=False),
+         prompt_index,
+     )
 
+ def enable_submit_button(audio_input):
+     """Enable the submit button when audio is recorded"""
+     if audio_input is not None:
+         return gr.Button("Submit Recording", interactive=True)
+     return gr.Button("Submit Recording", interactive=False)

+ # Create a theme
  theme = gr.themes.Soft(
+     primary_hue="blue",
+     secondary_hue="indigo",
+     neutral_hue="slate",
  )
 
+ # Create Gradio interface
+ with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
+     prompt_index = gr.State(0)
+
+     title = gr.Markdown(f"# Recording 1/{MAX_RECORDINGS}")
+     instructions = gr.Markdown("## Please record the following prompt:")
+     prompt_text = gr.Markdown(f"### {PROMPTS[0]}")
+
+     audio_input = gr.Audio(
+         label="Record your response",
+         sources=["microphone"],
+         streaming=False,
      )
+
+     with gr.Row():
+         submit_btn = gr.Button("Submit Recording", interactive=False)
+         next_btn = gr.Button("Next Prompt", visible=False)
+
+     # Enable submit button when audio is recorded
+     audio_input.change(
+         fn=enable_submit_button,
+         inputs=[audio_input],
+         outputs=[submit_btn],
      )
+
+     # Handle submission
+     submit_btn.click(
+         fn=on_submit,
+         inputs=[audio_input, prompt_index],
+         outputs=[title, instructions, prompt_text, audio_input, submit_btn, next_btn, prompt_index],
      )
+
+     # Handle next button
+     next_btn.click(
+         fn=next_prompt,
+         inputs=[prompt_index],
+         outputs=[title, instructions, prompt_text, audio_input, submit_btn, next_btn, prompt_index],
      )
 

+ # Launch the app
  if __name__ == "__main__":
+     # Write the loaded prompts back out to prompts.json beside this module
+     with open(prompt_path, "w") as f:
+         json.dump({"prompts": PROMPTS}, f, indent=2)
+
+     demo.launch(share=True)
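+ # To run locally: python talk_arena/audio_collection.py
+ # Assumes HF_TOKEN is available (e.g. via a .env file) and prompts.json sits next to this module.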