import os
import random
import json
from pathlib import Path
import gradio as gr
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
# Runtime configuration. Model name and max sequence length are required;
# the rest fall back to vLLM's usual defaults when unset.
VLLM_MODEL_NAME = os.environ["VLLM_MODEL_NAME"]
VLLM_MAX_SEQ_LEN = int(os.environ["VLLM_MAX_SEQ_LEN"])
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.9"))
VLLM_DTYPE = os.getenv("VLLM_DTYPE", "auto")
HF_TOKEN = os.getenv("HF_TOKEN")
# -------------------------------- HELPERS -------------------------------------
def load_prompt(path: Path) -> str:
    """Read a prompt template from disk."""
    return path.read_text()
# -------------------------------- DATA MODELS --------------------------------
class StructuredQueryRewriteResponse(BaseModel):
    """One rewrite per facet of the user's query, produced by the LLM."""

    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None


class QueryRewrite(BaseModel):
    """Flat list of rewrites plus the structured response they were drawn from."""

    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None
# -------------------------------- VLLM --------------------------------------
local_llm = LLM(
model=VLLM_MODEL_NAME,
max_model_len=VLLM_MAX_SEQ_LEN,
gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
hf_token=HF_TOKEN,
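    # Eager mode skips CUDA graph capture: slower per token, faster startup.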
enforce_eager=True,
dtype=VLLM_DTYPE,
)
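# Constrain decoding to the response model's JSON schema, so the model output
# is guaranteed to parse with json.loads.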
json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
guided_decoding=guided_decoding_params_json,
temperature=0.7,
top_p=0.8,
repetition_penalty=1.05,
max_tokens=1024,
)
vllm_system_prompt = (
    "You are a search query optimization assistant built into"
    " a music genre search engine, helping users discover novel music genres."
)
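# User-prompt template; it must contain a {query} placeholder
# (see recommend_sadaimrec below).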
vllm_prompt = load_prompt(Path("./prompts/local.txt"))
# Recommendation pipelines: recommend_sadaimrec runs the local vLLM rewriter,
# recommend_chatgpt is a stub used for the side-by-side comparison.
def recommend_sadaimrec(query: str):
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    # Guided decoding guarantees the generated text parses against the schema.
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    rewrite = QueryRewrite(
        rewrites=[value for value in rewrite_json.values() if value is not None],
        structured=StructuredQueryRewriteResponse(**rewrite_json),
    )
    return f"SADAIMREC: response to '{rewrite.model_dump_json(indent=4)}'"
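# Placeholder baseline; a real deployment would presumably call the OpenAI API here.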
def recommend_chatgpt(query: str):
return f"CHATGPT: response to '{query}'"
# Mapping names to functions
pipelines = {
"sadaimrec": recommend_sadaimrec,
"chatgpt": recommend_chatgpt,
}
# Interface logic
def generate_responses(query):
# Randomize model order
pipeline_names = list(pipelines.keys())
random.shuffle(pipeline_names)
# Generate responses
resp1 = pipelines[pipeline_names[0]](query)
resp2 = pipelines[pipeline_names[1]](query)
# Return texts and hidden labels
return resp1, resp2, pipeline_names[0], pipeline_names[1]
# Callback to capture vote
def handle_vote(selected, label1, label2, resp1, resp2):
chosen_name = label1 if selected == "Option 1" else label2
chosen_resp = resp1 if selected == "Option 1" else resp2
print(f"User voted for {chosen_name}: '{chosen_resp}'")
return (
"Thank you for your vote! Restarting in 2 seconds...",
gr.update(active=True),
)
def reset_ui():
    return (
        gr.update(visible=False),  # hide row (gr.Row has no value to clear)
        gr.update(value=""),  # clear query
        gr.update(visible=False),  # hide radio
        gr.update(visible=False),  # hide vote button
        gr.update(value=""),  # clear Option 1 text
        gr.update(value=""),  # clear Option 2 text
        gr.update(value=""),  # clear result
        gr.update(active=False),  # deactivate the timer until the next vote
    )
with gr.Blocks() as demo:
gr.Markdown("# Music Genre Recommendation Side-By-Side Comparison")
query = gr.Textbox(label="Your Query")
submit_btn = gr.Button("Submit")
    # Timer that resets the UI after feedback is sent
reset_timer = gr.Timer(value=2.0, active=False)
# Hidden components to store model responses and names
with gr.Row(visible=False) as response_row:
response_1 = gr.Textbox(label="Option 1", interactive=False)
response_2 = gr.Textbox(label="Option 2", interactive=False)
model_label_1 = gr.Textbox(visible=False)
model_label_2 = gr.Textbox(visible=False)
# Feedback
vote = gr.Radio(
["Option 1", "Option 2"], label="Select Best Response", visible=False
)
vote_btn = gr.Button("Vote", visible=False)
result = gr.Textbox(label="Console", interactive=False)
# On submit
submit_btn.click( # generate
fn=generate_responses,
inputs=[query],
outputs=[response_1, response_2, model_label_1, model_label_2],
)
submit_btn.click( # update ui
fn=lambda: (
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=True),
),
inputs=None,
outputs=[response_row, vote, vote_btn],
)
# Feedback handling
vote_btn.click(
fn=handle_vote,
inputs=[vote, model_label_1, model_label_2, response_1, response_2],
outputs=[result, reset_timer],
)
reset_timer.tick(
fn=reset_ui,
inputs=None,
outputs=[
response_row,
query,
vote,
vote_btn,
response_1,
response_2,
result,
reset_timer,
],
trigger_mode="once",
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)