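"""Gradio arena app: shows two music-genre recommendation responses side by side
(one from a local vLLM rewrite pipeline, one from a ChatGPT stub) and records
which option the user prefers."""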
import os
import random
import json
from pathlib import Path
import gradio as gr
from pydantic import BaseModel
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
# Runtime configuration; the numeric fallbacks below are assumed defaults so the
# app can still start when the corresponding environment variables are unset.
VLLM_MODEL_NAME = os.getenv("VLLM_MODEL_NAME")  # required: Hugging Face model id
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.9"))
VLLM_MAX_SEQ_LEN = int(os.getenv("VLLM_MAX_SEQ_LEN", "4096"))
HF_TOKEN = os.getenv("HF_TOKEN")
VLLM_DTYPE = os.getenv("VLLM_DTYPE", "auto")
# -------------------------------- HELPERS -------------------------------------
def load_prompt(path: Path) -> str:
    """Read a prompt template from disk."""
    with path.open("r") as file:
        return file.read()
# -------------------------------- Data Models -------------------------------
class StructuredQueryRewriteResponse(BaseModel):
    """One rewrite of the user query per facet; a facet may be null."""

    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None


class QueryRewrite(BaseModel):
    """Flat list of non-null rewrites plus the structured per-facet response."""

    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None
# -------------------------------- VLLM --------------------------------------
local_llm = LLM(
model=VLLM_MODEL_NAME,
max_model_len=VLLM_MAX_SEQ_LEN,
gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
hf_token=HF_TOKEN,
    enforce_eager=True,  # skip CUDA graph capture: faster startup, lower memory
dtype=VLLM_DTYPE,
)
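# Constrain decoding to the Pydantic model's JSON schema so every completion can
# be parsed back into StructuredQueryRewriteResponse.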
json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
guided_decoding=guided_decoding_params_json,
temperature=0.7,
top_p=0.8,
repetition_penalty=1.05,
max_tokens=1024,
)
vllm_system_prompt = (
    "You are a search query optimization assistant built into"
    " a music genre search engine, helping users discover novel music genres."
)
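# The template file is expected to contain a "{query}" placeholder, filled in per request.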
vllm_prompt = load_prompt(Path("./prompts/local.txt"))
# Pipeline functions; recommend_chatgpt is a stub for demonstration
def recommend_sadaimrec(query: str):
    """Rewrite the query with the local vLLM model under schema-guided decoding."""
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    # Guided decoding guarantees the output is valid JSON matching the schema
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    rewrite = QueryRewrite(
        rewrites=[x for x in rewrite_json.values() if x is not None],
        structured=rewrite_json,
    )
    return f"SADAIMREC: response to '{rewrite.model_dump_json(indent=4)}'"
def recommend_chatgpt(query: str):
return f"CHATGPT: response to '{query}'"
# Mapping names to functions
pipelines = {
"sadaimrec": recommend_sadaimrec,
"chatgpt": recommend_chatgpt,
}
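# The dict keys double as the hidden labels recorded alongside each vote.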
# Interface logic
def generate_responses(query):
# Randomize model order
pipeline_names = list(pipelines.keys())
random.shuffle(pipeline_names)
# Generate responses
resp1 = pipelines[pipeline_names[0]](query)
resp2 = pipelines[pipeline_names[1]](query)
# Return texts and hidden labels
return resp1, resp2, pipeline_names[0], pipeline_names[1]
# Callback to capture vote
def handle_vote(selected, label1, label2, resp1, resp2):
    chosen_name = label1 if selected == "Option 1" else label2
    chosen_resp = resp1 if selected == "Option 1" else resp2
    # Votes are only written to stdout (the container logs); persist them
    # elsewhere if they need to survive restarts.
    print(f"User voted for {chosen_name}: '{chosen_resp}'")
return (
"Thank you for your vote! Restarting in 2 seconds...",
gr.update(active=True),
)
def reset_ui():
    return (
        gr.update(visible=False),  # hide row (gr.Row has no value to clear)
        gr.update(value=""),  # clear query
        gr.update(visible=False),  # hide radio
        gr.update(visible=False),  # hide vote button
        gr.update(value=""),  # clear Option 1 text
        gr.update(value=""),  # clear Option 2 text
        gr.update(value=""),  # clear result
        gr.update(active=False),  # stop the reset timer
    )
with gr.Blocks() as demo:
gr.Markdown("# Music Genre Recommendation Side-By-Side Comparison")
query = gr.Textbox(label="Your Query")
submit_btn = gr.Button("Submit")
    # Timer that resets the UI two seconds after feedback is sent
reset_timer = gr.Timer(value=2.0, active=False)
# Hidden components to store model responses and names
with gr.Row(visible=False) as response_row:
response_1 = gr.Textbox(label="Option 1", interactive=False)
response_2 = gr.Textbox(label="Option 2", interactive=False)
model_label_1 = gr.Textbox(visible=False)
model_label_2 = gr.Textbox(visible=False)
# Feedback
vote = gr.Radio(
["Option 1", "Option 2"], label="Select Best Response", visible=False
)
vote_btn = gr.Button("Vote", visible=False)
result = gr.Textbox(label="Console", interactive=False)
# On submit
submit_btn.click( # generate
fn=generate_responses,
inputs=[query],
outputs=[response_1, response_2, model_label_1, model_label_2],
)
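    # A second listener on the same click event: it reveals the comparison UI
    # while the generation listener above produces the responses.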
submit_btn.click( # update ui
fn=lambda: (
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=True),
),
inputs=None,
outputs=[response_row, vote, vote_btn],
)
# Feedback handling
vote_btn.click(
fn=handle_vote,
inputs=[vote, model_label_1, model_label_2, response_1, response_2],
outputs=[result, reset_timer],
)
reset_timer.tick(
fn=reset_ui,
inputs=None,
outputs=[
response_row,
query,
vote,
vote_btn,
response_1,
response_2,
result,
reset_timer,
],
trigger_mode="once",
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)