import os
import random
import json
from pathlib import Path

import gradio as gr
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Runtime configuration, read from the environment. The numeric settings fall
# back to common vLLM defaults when unset (assumed fallback values).
VLLM_MODEL_NAME = os.getenv("VLLM_MODEL_NAME")
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.90"))
VLLM_MAX_SEQ_LEN = int(os.getenv("VLLM_MAX_SEQ_LEN", "4096"))
HF_TOKEN = os.getenv("HF_TOKEN")
VLLM_DTYPE = os.getenv("VLLM_DTYPE", "auto")
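
# An illustrative environment setup (example values only; the actual Space
# configures these via its settings, and the model name here is assumed):
#
#   export VLLM_MODEL_NAME="Qwen/Qwen2.5-1.5B-Instruct"
#   export VLLM_GPU_MEMORY_UTILIZATION="0.90"
#   export VLLM_MAX_SEQ_LEN="4096"
#   export VLLM_DTYPE="bfloat16"
#   export HF_TOKEN="hf_..."  # only needed for gated or private models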
# -------------------------------- HELPERS -------------------------------------
def load_prompt(path: Path) -> str:
    with path.open("r") as file:
        return file.read()
# -------------------------------- Data Models -------------------------------
class StructuredQueryRewriteResponse(BaseModel):
    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None


class QueryRewrite(BaseModel):
    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None
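
# Guided decoding (configured below) forces the model's raw output to match
# StructuredQueryRewriteResponse. An illustrative, made-up payload:
#
#   {
#       "general": "fusion of jazz and electronic music",
#       "subjective": "dreamy late-night listening",
#       "purpose": "background music for studying",
#       "technical": "downtempo, 80-100 BPM, heavy reverb",
#       "curiosity": "genres blending acoustic jazz with modular synths"
#   }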
# -------------------------------- VLLM --------------------------------------
local_llm = LLM(
    model=VLLM_MODEL_NAME,
    max_model_len=VLLM_MAX_SEQ_LEN,
    gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
    hf_token=HF_TOKEN,
    enforce_eager=True,
    dtype=VLLM_DTYPE,
)
# Constrain generation to the rewrite schema via guided (structured) decoding
json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=1024,
)
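
# GuidedDecodingParams(json=...) constrains token-by-token sampling so the
# generated text is always valid JSON for the schema; temperature and top_p
# still control variety within that constraint. A hypothetical smoke test:
#
#   out = local_llm.chat(
#       messages=[{"role": "user", "content": "Rewrite: mellow beats"}],
#       sampling_params=sampling_params_json,
#   )
#   json.loads(out[0].outputs[0].text)  # should parse without error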
vllm_system_prompt = (
    "You are a search query optimization assistant built into"
    " a music genre search engine, helping users discover novel music genres."
)
vllm_prompt = load_prompt(Path("./prompts/local.txt"))
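
# NOTE: ./prompts/local.txt must contain a "{query}" placeholder, since
# recommend_sadaimrec() fills it with str.format(). A minimal stand-in
# template (assumed content, for local testing only):
#
#   Rewrite the following music search query from several angles: {query}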
def recommend_sadaimrec(query: str):
    """Rewrite the query with the local vLLM pipeline (schema-constrained)."""
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    # Guided decoding guarantees the output parses as the schema's JSON
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    rewrite = QueryRewrite(
        rewrites=[x for x in rewrite_json.values() if x is not None],
        structured=rewrite_json,
    )
    return f"SADAIMREC: response to '{rewrite.model_dump_json(indent=4)}'"
# Dummy baseline pipeline for demonstration
def recommend_chatgpt(query: str):
    return f"CHATGPT: response to '{query}'"


# Mapping pipeline names to functions
pipelines = {
    "sadaimrec": recommend_sadaimrec,
    "chatgpt": recommend_chatgpt,
}
# Interface logic
def generate_responses(query):
    # Randomize model order
    pipeline_names = list(pipelines.keys())
    random.shuffle(pipeline_names)
    # Generate responses
    resp1 = pipelines[pipeline_names[0]](query)
    resp2 = pipelines[pipeline_names[1]](query)
    # Return texts and hidden labels
    return resp1, resp2, pipeline_names[0], pipeline_names[1]
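
# Shuffling before dispatch makes this a blind A/B test: the user only sees
# "Option 1" / "Option 2", while the hidden labels record which pipeline
# produced which response so the vote can be attributed afterwards.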
# Callback to capture vote
def handle_vote(selected, label1, label2, resp1, resp2):
    chosen_name = label1 if selected == "Option 1" else label2
    chosen_resp = resp1 if selected == "Option 1" else resp2
    print(f"User voted for {chosen_name}: '{chosen_resp}'")
    return (
        "Thank you for your vote! Restarting in 2 seconds...",
        gr.update(active=True),  # start the reset timer
    )
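
# Votes are currently only printed to the Space logs. To persist them across
# restarts, one could append to a file instead (sketch; "votes.jsonl" is an
# assumed path):
#
#   with open("votes.jsonl", "a") as f:
#       f.write(json.dumps({"winner": chosen_name, "response": chosen_resp}) + "\n")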
def reset_ui():
    return (
        gr.update(visible=False),  # hide response row (a Row has no value)
        gr.update(value=""),  # clear query
        gr.update(visible=False),  # hide radio
        gr.update(visible=False),  # hide vote button
        gr.update(value=""),  # clear Option 1 text
        gr.update(value=""),  # clear Option 2 text
        gr.update(value=""),  # clear result
        gr.update(active=False),  # stop the reset timer
    )
with gr.Blocks() as demo:
    gr.Markdown("# Music Genre Recommendation Side-By-Side Comparison")
    query = gr.Textbox(label="Your Query")
    submit_btn = gr.Button("Submit")
    # Timer that resets the UI after feedback is sent
    reset_timer = gr.Timer(value=2.0, active=False)
    # Hidden components to store model responses and names
    with gr.Row(visible=False) as response_row:
        response_1 = gr.Textbox(label="Option 1", interactive=False)
        response_2 = gr.Textbox(label="Option 2", interactive=False)
        model_label_1 = gr.Textbox(visible=False)
        model_label_2 = gr.Textbox(visible=False)
    # Feedback
    vote = gr.Radio(
        ["Option 1", "Option 2"], label="Select Best Response", visible=False
    )
    vote_btn = gr.Button("Vote", visible=False)
    result = gr.Textbox(label="Console", interactive=False)

    # On submit
    submit_btn.click(  # generate
        fn=generate_responses,
        inputs=[query],
        outputs=[response_1, response_2, model_label_1, model_label_2],
    )
    submit_btn.click(  # update UI
        fn=lambda: (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        ),
        inputs=None,
        outputs=[response_row, vote, vote_btn],
    )
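
    # The two .click() listeners above are registered independently; if strict
    # ordering ever matters, chaining the UI update with .then() after the
    # generation event would guarantee it runs only once responses are ready.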
    # Feedback handling
    vote_btn.click(
        fn=handle_vote,
        inputs=[vote, model_label_1, model_label_2, response_1, response_2],
        outputs=[result, reset_timer],
    )
    reset_timer.tick(
        fn=reset_ui,
        inputs=None,
        outputs=[
            response_row,
            query,
            vote,
            vote_btn,
            response_1,
            response_2,
            result,
            reset_timer,
        ],
        trigger_mode="once",
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)