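"""Talk Arena demo app.

Gradio interface for pairwise evaluation of audio language models: a user
records a spoken prompt, two (optionally anonymous) models stream responses
side by side, and the user's preference is logged with latency metadata.

Run directly (pass --free_only to skip the paid Gemini/GPT-4o models).
"""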
import argparse
import asyncio
import os  # needed by clear_factory to delete cached audio files
import random
import textwrap
import time

import gradio as gr
import xxhash
from dotenv import load_dotenv
from transformers import pipeline

import talk_arena.streaming_helpers as sh
from talk_arena.db_utils import TinyThreadSafeDB

load_dotenv()


def parse_args():
    parser = argparse.ArgumentParser(description="Talk Arena Demo")
    parser.add_argument("--free_only", action="store_true", help="Only use free models")
    return parser.parse_args()


args = parse_args()

if gr.NO_RELOAD:  # Prevents re-init during hot reloading
    # Transcription disabled for the public interface
    # asr_pipe = pipeline(
    #     task="automatic-speech-recognition",
    #     model="openai/whisper-large-v3-turbo",
    #     chunk_length_s=30,
    #     device="cuda:1",
    # )
    anonymous = True

    # Generation setup
    diva_audio, diva = sh.api_streaming("WillHeld/DiVA-llama-3-v0-8b")
    qwen2_audio, qwen2 = sh.api_streaming("Qwen/Qwen2-Audio-7B-Instruct")
    pipelined_system, pipeline_model = sh.api_streaming("pipeline/meta-llama/Meta-Llama-3-8B-Instruct")
    if not args.free_only:
        gemini_audio, gemini_model = sh.gemini_streaming("models/gemini-1.5-flash")
        gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o")
        geminip_audio, geminip_model = sh.gemini_streaming("models/gemini-1.5-pro")
        gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp")
    typhoon_audio, typhoon_model = sh.api_streaming("scb10x/llama-3-typhoon-audio-8b-2411")
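
    # Each *_audio handle above is assumed to be a streaming inference function;
    # sh.gradio_gen_factory wraps it into an async generator yielding gr.Textbox
    # updates (pairwise_response_async reads the partial value back out of those
    # updates via _constructor_args when a model errors mid-stream).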
    competitor_info = [
        (sh.gradio_gen_factory(diva_audio, "DiVA Llama 3 8B", anonymous), "diva_3_8b", "DiVA Llama 3 8B"),
        (sh.gradio_gen_factory(qwen2_audio, "Qwen 2", anonymous), "qwen2", "Qwen 2 Audio"),
        (
            sh.gradio_gen_factory(pipelined_system, "Pipelined Llama 3 8B", anonymous),
            "pipe_l3.0",
            "Pipelined Llama 3 8B",
        ),
        (sh.gradio_gen_factory(typhoon_audio, "Typhoon Audio", anonymous), "typhoon_audio", "Typhoon Audio"),
    ]
    # Add paid models if the --free_only flag is not set
    if not args.free_only:
        competitor_info += [
            (sh.gradio_gen_factory(gemini_audio, "Gemini 1.5 Flash", anonymous), "gemini_1.5f", "Gemini 1.5 Flash"),
            (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"),
            (sh.gradio_gen_factory(geminip_audio, "Gemini 1.5 Pro", anonymous), "gemini_1.5p", "Gemini 1.5 Pro"),
            (sh.gradio_gen_factory(gemini2_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"),
        ]

    resp_generators = [generator for generator, _, _ in competitor_info]
    model_shorthand = [shorthand for _, shorthand, _ in competitor_info]
    model_name = [full_name for _, _, full_name in competitor_info]
    all_models = list(range(len(model_shorthand)))
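

# Streams paired responses for one arena round. Every yield is an 11-tuple
# matching the outputs wired up in audio_input.stop_recording below:
# (submit button, model A textbox, model B textbox, best1/best2/tie buttons,
#  state, audio input, reason textbox, reason recorder, latency data).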
async def pairwise_response_async(audio_input, state, model_order):
    if audio_input is None:
        # Nothing was recorded: reset the UI and stop. Raising
        # StopAsyncIteration inside an async generator surfaces as a
        # RuntimeError (PEP 525), so yield the cleared outputs and return.
        yield (
            gr.Button(value="Record Audio to Submit!", interactive=False),
            "",
            "",
            gr.Button(visible=False),
            gr.Button(visible=False),
            gr.Button(visible=False),
            state,
            audio_input,
            None,
            None,
            None,
        )
        return
    spinner_id = 0
    spinners = ["◐", "◓", "◑", "◒"]  # spinner frames (reconstructed from mis-encoded glyphs)
    spinner = spinners[0]
    gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]]
    latencies = [{}, {}]  # Store timing info for each model
    resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)]
    error_in_model = False
    for order, generator in enumerate(gen_pair):
        start_time = time.time()
        first_token = True
        total_length = 0
        try:
            async for local_resp in generator(audio_input, order):
                total_length += 1
                if first_token:
                    latencies[order]["time_to_first_token"] = time.time() - start_time
                    first_token = False
                resps[order] = local_resp
                spinner = spinners[spinner_id]
                spinner_id = (spinner_id + 1) % 4
                yield (
                    gr.Button(
                        value=spinner + " Generating Responses " + spinner,
                        interactive=False,
                        variant="primary",
                    ),
                    resps[0],
                    resps[1],
                    gr.Button(visible=False),
                    gr.Button(visible=False),
                    gr.Button(visible=False),
                    state,
                    audio_input,
                    None,
                    None,
                    latencies,
                )
            latencies[order]["total_time"] = time.time() - start_time
            latencies[order]["response_length"] = total_length
        except Exception:
            error_in_model = True
            resps[order] = gr.Textbox(
                info=f"<strong>Error thrown by Model {order+1} API</strong>",
                value="" if first_token else resps[order]._constructor_args[0]["value"],
                visible=True,
                label=f"Model {order+1}",
            )
            yield (
                gr.Button(
                    value=spinner + " Generating Responses " + spinner,
                    interactive=False,
                    variant="primary",
                ),
                resps[0],
                resps[1],
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
                state,
                audio_input,
                None,
                None,
                latencies,
            )
            latencies[order]["total_time"] = time.time() - start_time
            latencies[order]["response_length"] = total_length
    print(latencies)
    yield (
        gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False),
        resps[0],
        resps[1],
        gr.Button(visible=not error_in_model),
        gr.Button(visible=not error_in_model),
        gr.Button(visible=not error_in_model),
        responses_complete(state),
        audio_input,
        gr.Textbox(visible=False),
        gr.Audio(visible=False),
        latencies,
    )
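

# Per-session state machine: 0 = fresh page load, 1 = ready to record,
# 2 = recording submitted / responses generating, 3 = responses complete
# and feedback requested.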
def on_page_load(state, model_order):
    if state == 0:
        # gr.Info(
        #     "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant,"
        #     " or ChatGPT for."
        # )
        state = 1
        model_order = random.sample(all_models, 2) if anonymous else model_order
    return state, model_order


def recording_complete(state):
    if state == 1:
        # gr.Info(
        #     "Once you submit your recording, you'll receive responses from different models. This might take a second."
        # )
        state = 2
    return (
        gr.Button(value="Starting Generation", interactive=False, variant="primary"),
        state,
    )


def responses_complete(state):
    if state == 2:
        gr.Info(
            "Give us your feedback! Mark which model gave you the best response so we can understand the quality of"
            " these different voice assistant models."
        )
        state = 3
    return state
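

# Builds the click handler for a vote button. button_id encodes the outcome
# that gets logged: 0 = Model A preferred, 0.5 = tie, 1 = Model B preferred,
# and None = the audio was cleared without a vote, so nothing is recorded.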
def clear_factory(button_id):
    async def clear(audio_input, model_order, pref_counter, reasoning, latency):
        textbox1 = gr.Textbox(visible=False)
        textbox2 = gr.Textbox(visible=False)
        if button_id is not None:
            sr, y = audio_input
            x = xxhash.xxh32(bytes(y)).hexdigest()
            await db.insert(
                {
                    "audio_hash": x,
                    "outcome": button_id,
                    "model_a": model_shorthand[model_order[0]],
                    "model_b": model_shorthand[model_order[1]],
                    "why": reasoning,
                    "model_a_latency": latency[0],
                    "model_b_latency": latency[1],
                }
            )
            pref_counter += 1
            model_a = model_name[model_order[0]]
            model_b = model_name[model_order[1]]
            textbox1 = gr.Textbox(
                visible=True,
                # time.time() deltas are in seconds; scale by 1000 for the ms display
                info=f"<strong style='color: #53565A'>Response from {model_a}</strong><p>Time-to-First-Character: {latency[0]['time_to_first_token'] * 1000:.2f} ms, Time Per Character: {latency[0]['total_time'] / latency[0]['response_length'] * 1000:.2f} ms</p>",
            )
            textbox2 = gr.Textbox(
                visible=True,
                info=f"<strong style='color: #53565A'>Response from {model_b}</strong><p>Time-to-First-Character: {latency[1]['time_to_first_token'] * 1000:.2f} ms, Time Per Character: {latency[1]['total_time'] / latency[1]['response_length'] * 1000:.2f} ms</p>",
            )
        try:
            sr, y = audio_input
            x = xxhash.xxh32(bytes(y)).hexdigest()
            os.remove(f"{x}.wav")
        except Exception:
            # File may already be deleted; this is just a failsafe to ensure data is cleared
            pass
        counter_text = f"# {pref_counter}/10 Preferences Submitted"
        if pref_counter >= 10 and False:  # Currently disabled; manages Prolific completion
            code = "PLACEHOLDER"
            counter_text = f"# Completed! Completion Code: {code}"
        counter_text = ""  # Counter display currently disabled
        if anonymous:
            model_order = random.sample(all_models, 2)
        return (
            model_order,
            gr.Button(
                value="Record Audio to Submit Again!",
                interactive=False,
                visible=True,
            ),
            gr.Button(visible=False),
            gr.Button(visible=False),
            gr.Button(visible=False),
            None,
            textbox1,
            textbox2,
            pref_counter,
            counter_text,
            gr.Textbox(visible=False),
            gr.Audio(visible=False),
        )

    return clear
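

# Appends an ASR transcription of the spoken rationale to the reason textbox.
# Depends on asr_pipe, which is commented out above for the public interface,
# so the reason_record.stop_recording wiring below is commented out as well.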
def transcribe(transc, voice_reason):
    if transc is None:
        transc = ""
    transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
    return transc, gr.Audio(value=None)


theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c50="#8200007f",
        c100="#82000019",
        c200="#82000033",
        c300="#8200004c",
        c400="#82000066",
        c500="#8200007f",
        c600="#82000099",
        c700="#820000b2",
        c800="#820000cc",
        c900="#820000e5",
        c950="#820000f2",
    ),
    secondary_hue="rose",
    neutral_hue="stone",
)

with open("src/talk_arena/styles.css", "r") as css_file:
    custom_css = css_file.read()
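
# Thread-safe JSON store for live votes. Each record (see clear_factory):
# audio_hash, outcome (0 / 0.5 / 1), model_a / model_b shorthands, the
# optional free-text "why", and per-model latency dicts.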
db = TinyThreadSafeDB("live_votes.json")

with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo:
    submitted_preferences = gr.State(0)
    state = gr.State(0)
    model_order = gr.State([])
    latency = gr.State([])
    with gr.Row():
        counter_text = gr.Markdown(
            ""
        )  # "# 0/10 Preferences Submitted.\n Follow the pop-up tips to submit your first preference.")
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            out1 = gr.Textbox(visible=False, lines=5, autoscroll=True)
        with gr.Column(scale=1):
            out2 = gr.Textbox(visible=False, lines=5, autoscroll=True)
    with gr.Row():
        btn = gr.Button(value="Record Audio to Submit!", interactive=False)
    with gr.Row(equal_height=True):
        reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4)
        reason_record = gr.Audio(
            sources=["microphone"],
            interactive=True,
            streaming=False,
            label="Speak to transcribe!",
            visible=False,
            type="filepath",
            # waveform_options={"show_recording_waveform": False},
            scale=1,
        )
    with gr.Row():
        best1 = gr.Button(value="Model 1 is better", visible=False)
        tie = gr.Button(value="Tie", visible=False)
        best2 = gr.Button(value="Model 2 is better", visible=False)
    with gr.Row():
        contact = gr.Markdown("")
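
    # Event wiring: start_recording swaps the button label, stop_recording
    # advances the state machine and then streams the paired responses, each
    # vote button logs its outcome via clear_factory, and clearing the audio
    # resets the round without logging a vote.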
    # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record])
    audio_input.stop_recording(
        recording_complete,
        [state],
        [btn, state],
    ).then(
        fn=pairwise_response_async,
        inputs=[audio_input, state, model_order],
        outputs=[btn, out1, out2, best1, best2, tie, state, audio_input, reason, reason_record, latency],
    )
    audio_input.start_recording(
        lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"),
        None,
        btn,
    )
    best1.click(
        fn=clear_factory(0),
        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
        outputs=[
            model_order,
            btn,
            best1,
            best2,
            tie,
            audio_input,
            out1,
            out2,
            submitted_preferences,
            counter_text,
            reason,
            reason_record,
        ],
    )
    tie.click(
        fn=clear_factory(0.5),
        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
        outputs=[
            model_order,
            btn,
            best1,
            best2,
            tie,
            audio_input,
            out1,
            out2,
            submitted_preferences,
            counter_text,
            reason,
            reason_record,
        ],
    )
    best2.click(
        fn=clear_factory(1),
        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
        outputs=[
            model_order,
            btn,
            best1,
            best2,
            tie,
            audio_input,
            out1,
            out2,
            submitted_preferences,
            counter_text,
            reason,
            reason_record,
        ],
    )
    audio_input.clear(
        clear_factory(None),
        [audio_input, model_order, submitted_preferences, reason, latency],
        [
            model_order,
            btn,
            best1,
            best2,
            tie,
            audio_input,
            out1,
            out2,
            submitted_preferences,
            counter_text,
            reason,
            reason_record,
        ],
    )
    demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order])
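

# High concurrency limit since most request time is spent awaiting model APIs;
# api_open=False disables programmatic access through the auto-generated API.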
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)