import os
import sys

import gradio as gr
from loguru import logger
from huggingface_hub import HfApi, whoami

from yourbench_space.config import generate_base_config, save_config
from yourbench_space.utils import (
    CONFIG_PATH,
    UPLOAD_DIRECTORY,
    BASE_API_URLS,
    AVAILABLE_MODELS,
    DEFAULT_MODEL,
    SubprocessManager,
    save_files,
)

UPLOAD_DIRECTORY.mkdir(parents=True, exist_ok=True)

logger.remove()
logger.add(sys.stderr, level="INFO")
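
# The YourBench pipeline runs as a managed subprocess so its output can be
# streamed into the UI and the run can be stopped or killed from the app.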
command = ["uv", "run", "yourbench", f"--config={CONFIG_PATH}"]
manager = SubprocessManager(command)
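

# Start the benchmark generation subprocess, injecting the user's Hugging Face
# and model API tokens into its environment.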
def prepare_task(oauth_token: gr.OAuthToken | None, model_token: str):
    new_env = os.environ.copy()

    # Override the env token when running in a Gradio Space
    if oauth_token:
        new_env["HF_TOKEN"] = oauth_token.token

    new_env["MODEL_API_KEY"] = model_token
    manager.start_process(custom_env=new_env)
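

# Populate the organization dropdown from the logged-in user's Hugging Face
# account; the user's own namespace is listed first.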
def update_hf_org_dropdown(oauth_token: gr.OAuthToken | None) -> gr.Dropdown | list:
    if oauth_token is None:
        logger.warning(
            "Please deploy this on Spaces and log in to view the list of available organizations"
        )
        return list()

    user_info = whoami(oauth_token.token)
    org_names = [org["name"] for org in user_info["orgs"]]
    user_name = user_info["name"]
    org_names.insert(0, user_name)
    return gr.Dropdown(org_names, value=user_name, label="Organization")
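

# Shared components are created up front and rendered later inside specific
# tabs and accordions via .render().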
config_output = gr.Code(label="Generated Config", language="yaml")

model_name = gr.Dropdown(
    label="Model Name",
    value=DEFAULT_MODEL,
    choices=AVAILABLE_MODELS,
    allow_custom_value=True,
)

base_url = gr.Textbox(
    label="Model API Base URL",
    value=BASE_API_URLS["huggingface"],
    info="Use a custom API base URL for Hugging Face Inference Endpoints",
)
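

# Build the six per-stage model dropdowns; they all mirror the top-level model
# selection and are rebuilt together whenever it changes.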
def make_models(model_name=None):
    if model_name is None:
        model_name = DEFAULT_MODEL

    ingestion_model = gr.Dropdown(
        label="Model for ingestion",
        choices=AVAILABLE_MODELS,
        value=model_name,
        interactive=False,
        allow_custom_value=True,
    )
    summarization_model = gr.Dropdown(
        label="Model for summarization",
        choices=AVAILABLE_MODELS,
        value=model_name,
        interactive=False,
        allow_custom_value=True,
    )
    single_shot_question_generation_model = gr.Dropdown(
        label="Model for single shot question generation",
        choices=AVAILABLE_MODELS,
        value=model_name,
        interactive=False,
        allow_custom_value=True,
    )
    multi_hop_question_generation_model = gr.Dropdown(
        label="Model for multi hop question generation",
        choices=AVAILABLE_MODELS,
        value=model_name,
        interactive=False,
        allow_custom_value=True,
    )
    answer_generation_model = gr.Dropdown(
        label="Model for answer generation",
        choices=AVAILABLE_MODELS,
        value=model_name,
        interactive=False,
        allow_custom_value=True,
    )
    judge_answers_model = gr.Dropdown(
        label="Model for answer judging",
        choices=AVAILABLE_MODELS,
        value=model_name,
        interactive=False,
        allow_custom_value=True,
    )

    return [
        ingestion_model,
        summarization_model,
        single_shot_question_generation_model,
        multi_hop_question_generation_model,
        answer_generation_model,
        judge_answers_model,
    ]


(
    ingestion_model,
    summarization_model,
    single_shot_question_generation_model,
    multi_hop_question_generation_model,
    answer_generation_model,
    judge_answers_model,
) = make_models()
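

# Gradio UI: configuration, raw config preview, file upload, and run control
# are split across tabs of a single Blocks app.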
with gr.Blocks() as app:
    gr.Markdown("## YourBench Configuration")

    with gr.Row():
        login_btn = gr.LoginButton()

    with gr.Tab("Configuration"):
        with gr.Accordion("Hugging Face"):
            hf_org_dropdown = gr.Dropdown(
                list(),
                label="Organization",
                allow_custom_value=True,
            )
            app.load(update_hf_org_dropdown, inputs=None, outputs=hf_org_dropdown)

            hf_dataset_prefix = gr.Textbox(
                label="Dataset Prefix",
                value="yourbench",
                info="Prefix applied to all datasets",
            )

            private_dataset = gr.Checkbox(
                label="Private Dataset",
                value=True,
                info="Create private datasets (recommended by default)",
            )
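
        # Inference provider, model, and API settings.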
with gr.Accordion("Model"):
model_name.render()
# TODO handle this better
model_name.change(
make_models,
inputs=[model_name],
outputs=[
ingestion_model,
summarization_model,
single_shot_question_generation_model,
multi_hop_question_generation_model,
answer_generation_model,
judge_answers_model,
],
)
provider = gr.Radio(
["huggingface", "openrouter", "openai"],
value="huggingface",
label="Inference Provider",
)
def set_base_url(provider):
return gr.Textbox(
label="Model API Base URL", value=BASE_API_URLS.get(provider, "")
)
provider.change(fn=set_base_url, inputs=provider, outputs=base_url)
model_api_key = gr.Textbox(label="Model API Key", type="password")
base_url.render()
max_concurrent_requests = gr.Radio(
[8, 16, 32], value=16, label="Max Concurrent Requests"
)
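
        # Per-stage model dropdowns created by make_models() are rendered here.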
with gr.Accordion("Stages"):
ingestion_model.render()
summarization_model.render()
single_shot_question_generation_model.render()
multi_hop_question_generation_model.render()
answer_generation_model.render()
judge_answers_model.render()
preview_button = gr.Button("Generate New Config")
preview_button.click(
generate_base_config,
inputs=[
hf_org_dropdown,
model_name,
provider,
base_url,
model_api_key,
max_concurrent_requests,
hf_dataset_prefix,
private_dataset,
ingestion_model,
summarization_model,
single_shot_question_generation_model,
multi_hop_question_generation_model,
answer_generation_model,
judge_answers_model,
],
outputs=config_output,
)
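
    # The generated YAML can be reviewed and edited here; changes are persisted
    # via save_config.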
with gr.Tab("Raw Configuration"):
config_output.render()
config_output.change(
fn=save_config,
inputs=[config_output],
outputs=[gr.Textbox(label="Save Status")],
)
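
    # Source documents for benchmark generation are uploaded and saved here.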
with gr.Tab("Files"):
file_input = gr.File(
label="Upload text files",
file_count="multiple",
file_types=[".txt", ".md", ".html"],
)
output = gr.Textbox(label="Log")
file_input.upload(save_files, file_input, output)
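
    # Run control: stream the subprocess log, report whether it is running, and
    # expose Start/Stop/Kill buttons.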
with gr.Tab("Run Generation"):
log_output = gr.Code(
label="Log Output", language=None, lines=20, interactive=False
)
log_timer = gr.Timer(0.05, active=True)
log_timer.tick(manager.read_and_get_output, outputs=log_output)
with gr.Row():
process_status = gr.Checkbox(label="Process Status", interactive=False)
status_timer = gr.Timer(0.05, active=True)
status_timer.tick(manager.is_running, outputs=process_status)
with gr.Row():
start_button = gr.Button("Start Task")
start_button.click(prepare_task, inputs=[model_api_key])
stop_button = gr.Button("Stop Task")
stop_button.click(manager.stop_process)
kill_button = gr.Button("Kill Task")
kill_button.click(manager.kill_process)
app.launch()