import argparse import asyncio import random import textwrap import time import gradio as gr import xxhash from dotenv import load_dotenv from transformers import pipeline import talk_arena.streaming_helpers as sh from talk_arena.db_utils import TinyThreadSafeDB load_dotenv() def parse_args(): parser = argparse.ArgumentParser(description="Talk Arena Demo") parser.add_argument("--free_only", action="store_true", help="Only use free models") return parser.parse_args() args = parse_args() if gr.NO_RELOAD: # Prevents Re-init during hot reloading # Transcription Disabled for Public Interface # asr_pipe = pipeline( # task="automatic-speech-recognition", # model="openai/whisper-large-v3-turbo", # chunk_length_s=30, # device="cuda:1", # ) anonymous = True # Generation Setup diva_audio, diva = sh.api_streaming("WillHeld/DiVA-llama-3-v0-8b") qwen2_audio, qwen2 = sh.api_streaming("Qwen/Qwen2-Audio-7B-Instruct") pipelined_system, pipeline_model = sh.api_streaming("pipeline/meta-llama/Meta-Llama-3-8B-Instruct") if not args.free_only: gemini_audio, gemini_model = sh.gemini_streaming("models/gemini-1.5-flash") gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o") geminip_audio, geminip_model = sh.gemini_streaming("models/gemini-1.5-pro") gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp") typhoon_audio, typhoon_model = sh.api_streaming("scb10x/llama-3-typhoon-audio-8b-2411") competitor_info = [ (sh.gradio_gen_factory(diva_audio, "DiVA Llama 3 8B", anonymous), "diva_3_8b", "DiVA Llama 3 8B"), (sh.gradio_gen_factory(qwen2_audio, "Qwen 2", anonymous), "qwen2", "Qwen 2 Audio"), ( sh.gradio_gen_factory(pipelined_system, "Pipelined Llama 3 8B", anonymous), "pipe_l3.0", "Pipelined Llama 3 8B", ), (sh.gradio_gen_factory(typhoon_audio, "Typhoon Audio", anonymous), "typhoon_audio", "Typhoon Audio"), ] # Add paid models if flag is not set if not args.free_only: competitor_info += [ (sh.gradio_gen_factory(gemini_audio, "Gemini 1.5 Flash", anonymous), "gemini_1.5f", "Gemini 1.5 Flash"), (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"), (sh.gradio_gen_factory(geminip_audio, "Gemini 1.5 Pro", anonymous), "gemini_1.5p", "Gemini 1.5 Pro"), (sh.gradio_gen_factory(geminip_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"), ] resp_generators = [generator for generator, _, _ in competitor_info] model_shorthand = [shorthand for _, shorthand, _ in competitor_info] model_name = [full_name for _, _, full_name in competitor_info] all_models = list(range(len(model_shorthand))) async def pairwise_response_async(audio_input, state, model_order): if audio_input == None: raise StopAsyncIteration( "", "", gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), state, audio_input, None, None, None, ) spinner_id = 0 spinners = ["◐ ", "◓ ", "◑", "◒"] spinner = spinners[0] gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]] latencies = [{}, {}] # Store timing info for each model resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)] error_in_model = False for order, generator in enumerate(gen_pair): start_time = time.time() first_token = True total_length = 0 try: async for local_resp in generator(audio_input, order): total_length += 1 if first_token: latencies[order]["time_to_first_token"] = time.time() - start_time first_token = False resps[order] = local_resp spinner = spinners[spinner_id] spinner_id = (spinner_id + 1) % 4 yield ( gr.Button( value=spinner + " Generating Responses " + spinner, interactive=False, variant="primary", ), resps[0], resps[1], gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), state, audio_input, None, None, latencies, ) latencies[order]["total_time"] = time.time() - start_time latencies[order]["response_length"] = total_length except: error_in_model = True resps[order] = gr.Textbox( info=f"Error thrown by Model {order+1} API", value="" if first_token else resps[order]._constructor_args[0]["value"], visible=True, label=f"Model {order+1}", ) yield ( gr.Button( value=spinner + " Generating Responses " + spinner, interactive=False, variant="primary", ), resps[0], resps[1], gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), state, audio_input, None, None, latencies, ) latencies[order]["total_time"] = time.time() - start_time latencies[order]["response_length"] = total_length print(latencies) yield ( gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False), resps[0], resps[1], gr.Button(visible=not error_in_model), gr.Button(visible=not error_in_model), gr.Button(visible=not error_in_model), responses_complete(state), audio_input, gr.Textbox(visible=False), gr.Audio(visible=False), latencies, ) def on_page_load(state, model_order): if state == 0: # gr.Info( # "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant," # " or ChatGPT for." # ) state = 1 model_order = random.sample(all_models, 2) if anonymous else model_order return state, model_order def recording_complete(state): if state == 1: # gr.Info( # "Once you submit your recording, you'll receive responses from different models. This might take a second." # ) state = 2 return ( gr.Button(value="Starting Generation", interactive=False, variant="primary"), state, ) def responses_complete(state): if state == 2: gr.Info( "Give us your feedback! Mark which model gave you the best response so we can understand the quality of" " these different voice assistant models." ) state = 3 return state def clear_factory(button_id): async def clear(audio_input, model_order, pref_counter, reasoning, latency): textbox1 = gr.Textbox(visible=False) textbox2 = gr.Textbox(visible=False) if button_id != None: sr, y = audio_input x = xxhash.xxh32(bytes(y)).hexdigest() await db.insert( { "audio_hash": x, "outcome": button_id, "model_a": model_shorthand[model_order[0]], "model_b": model_shorthand[model_order[1]], "why": reasoning, "model_a_latency": latency[0], "model_b_latency": latency[1], } ) pref_counter += 1 model_a = model_name[model_order[0]] model_b = model_name[model_order[1]] textbox1 = gr.Textbox( visible=True, info=f"Response from {model_a}
Time-to-First-Character: {latency[0]['time_to_first_token']:.2f} ms, Time Per Character: {latency[0]['total_time']/latency[0]['response_length']:.2f} ms
", ) textbox2 = gr.Textbox( visible=True, info=f"Response from {model_b}Time-to-First-Character: {latency[1]['time_to_first_token']:.2f} ms, Time Per Character: {latency[1]['total_time']/latency[1]['response_length']:.2f} ms
", ) try: sr, y = audio_input x = xxhash.xxh32(bytes(y)).hexdigest() os.remove(f"{x}.wav") except: # file already deleted, this is just a failsafe to assure data is cleared pass counter_text = f"# {pref_counter}/10 Preferences Submitted" if pref_counter >= 10 and False: # Currently Disabled, Manages Prolific Completionx code = "PLACEHOLDER" counter_text = f"# Completed! Completion Code: {code}" counter_text = "" if anonymous: model_order = random.sample(all_models, 2) return ( model_order, gr.Button( value="Record Audio to Submit Again!", interactive=False, visible=True, ), gr.Button(visible=False), gr.Button(visible=False), gr.Button(visible=False), None, textbox1, textbox2, pref_counter, counter_text, gr.Textbox(visible=False), gr.Audio(visible=False), ) return clear def transcribe(transc, voice_reason): if transc is None: transc = "" transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"] return transc, gr.Audio(value=None) theme = gr.themes.Soft( primary_hue=gr.themes.Color( c100="#82000019", c200="#82000033", c300="#8200004c", c400="#82000066", c50="#8200007f", c500="#8200007f", c600="#82000099", c700="#820000b2", c800="#820000cc", c900="#820000e5", c950="#820000f2", ), secondary_hue="rose", neutral_hue="stone", ) with open("src/talk_arena/styles.css", "r") as css_file: custom_css = css_file.read() db = TinyThreadSafeDB("live_votes.json") with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo: submitted_preferences = gr.State(0) state = gr.State(0) model_order = gr.State([]) latency = gr.State([]) with gr.Row(): counter_text = gr.Markdown( "" ) # "# 0/10 Preferences Submitted.\n Follow the pop-up tips to submit your first preference.") with gr.Row(): audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input") with gr.Row(equal_height=True): with gr.Column(scale=1): out1 = gr.Textbox(visible=False, lines=5, autoscroll=True) with gr.Column(scale=1): out2 = gr.Textbox(visible=False, lines=5, autoscroll=True) with gr.Row(): btn = gr.Button(value="Record Audio to Submit!", interactive=False) with gr.Row(equal_height=True): reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4) reason_record = gr.Audio( sources=["microphone"], interactive=True, streaming=False, label="Speak to transcribe!", visible=False, type="filepath", # waveform_options={"show_recording_waveform": False}, scale=1, ) with gr.Row(): best1 = gr.Button(value="Model 1 is better", visible=False) tie = gr.Button(value="Tie", visible=False) best2 = gr.Button(value="Model 2 is better", visible=False) with gr.Row(): contact = gr.Markdown("") # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record]) audio_input.stop_recording( recording_complete, [state], [btn, state], ).then( fn=pairwise_response_async, inputs=[audio_input, state, model_order], outputs=[btn, out1, out2, best1, best2, tie, state, audio_input, reason, reason_record, latency], ) audio_input.start_recording( lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"), None, btn, ) best1.click( fn=clear_factory(0), inputs=[audio_input, model_order, submitted_preferences, reason, latency], outputs=[ model_order, btn, best1, best2, tie, audio_input, out1, out2, submitted_preferences, counter_text, reason, reason_record, ], ) tie.click( fn=clear_factory(0.5), inputs=[audio_input, model_order, submitted_preferences, reason, latency], outputs=[ model_order, btn, best1, best2, tie, audio_input, out1, out2, submitted_preferences, counter_text, reason, reason_record, ], ) best2.click( fn=clear_factory(1), inputs=[audio_input, model_order, submitted_preferences, reason, latency], outputs=[ model_order, btn, best1, best2, tie, audio_input, out1, out2, submitted_preferences, counter_text, reason, reason_record, ], ) audio_input.clear( clear_factory(None), [audio_input, model_order, submitted_preferences, reason, latency], [ model_order, btn, best1, best2, tie, audio_input, out1, out2, submitted_preferences, counter_text, reason, reason_record, ], ) demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order]) if __name__ == "__main__": demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)