import json
import os
import random
import signal
import sys
import urllib.parse
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from uuid import uuid4
import gradio as gr
import numpy as np
import pandas as pd
# from dotenv import load_dotenv
from fastembed import SparseEmbedding, SparseTextEmbedding
from google import genai
from google.genai import types
from huggingface_hub import CommitScheduler
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
from sentence_transformers import CrossEncoder, SentenceTransformer
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
# load_dotenv()
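# Expected environment variables (values below are illustrative only):
#   HF_TOKEN=hf_...                    # needs write access to FEEDBACK_REPO
#   VLLM_MODEL_NAME=...                # HF model id served by vLLM
#   VLLM_GPU_MEMORY_UTILIZATION=0.5
#   VLLM_MAX_SEQ_LEN=4096
#   VLLM_DTYPE=bfloat16
#   GEMINI_API_KEY=...
#   DATA_PATH=/data
#   FEEDBACK_REPO=<user>/<feedback-dataset>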
HF_TOKEN = os.getenv("HF_TOKEN")
VLLM_MODEL_NAME = os.getenv("VLLM_MODEL_NAME")
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION"))
VLLM_MAX_SEQ_LEN = int(os.getenv("VLLM_MAX_SEQ_LEN"))
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
FEEDBACK_REPO = os.getenv("FEEDBACK_REPO")
FEEDBACK_DIR = DATA_PATH / "feedback"
FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / f"votes_{uuid4()}.jsonl"
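# Push newly written feedback files to a private HF dataset repo in the
# background (CommitScheduler's `every` interval is in minutes).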
scheduler = CommitScheduler(
    repo_id=FEEDBACK_REPO,
    repo_type="dataset",
    folder_path=FEEDBACK_DIR,
    path_in_repo="data",
    every=5,
    token=HF_TOKEN,
    private=True,
)
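# Embedded (on-disk) Qdrant instance. The collection is expected to provide
# named vectors "dense" and "sparse" and payload fields "text", "genre" and
# "topic", populated by a separate ingestion step not shown here.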
client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
num_chunks_base = 500
alpha = 0.5
top_k = 5  # we only want top 5 genres
youtube_search_query_template = "{genre} music playlist"
# -------------------------------- HELPERS -------------------------------------
def load_text_resource(path: Path) -> str:
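    """Read a text resource (e.g. a prompt template) from disk."""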
    with path.open("r") as file:
        resource = file.read()
    return resource
def youtube_search_link_for_genre(genre: str) -> str:
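    """Build a YouTube search URL for a genre, e.g. 'synth_pop' ->
    https://www.youtube.com/results?search_query=synth+pop+music+playlist
    """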
    base_url = "https://www.youtube.com/results"
    params = {
        "search_query": youtube_search_query_template.format(
            genre=genre.replace("_", " ").lower()
        )
    }
    return f"{base_url}?{urllib.parse.urlencode(params)}"
def generate_recommendation_string(ranking: dict[str, float]) -> str:
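    """Render a genre -> score ranking as a numbered Markdown list with
    YouTube search links."""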
    recommendation_string = "## Recommendations for You\n\n"
    for idx, (genre, score) in enumerate(ranking.items(), start=1):
        youtube_link = youtube_search_link_for_genre(genre=genre)
        recommendation_string += (
            f"{idx}. **{genre.replace('_', ' ').capitalize()}**; "
            f"[YouTube link]({youtube_link})\n"
        )
    return recommendation_string
def graceful_shutdown(signum, frame):
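    """Flush pending feedback to the Hub before the process exits."""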
print(f"{signum} received - flushing feedback …", flush=True) | |
scheduler.trigger().result() | |
sys.exit(0) | |
signal.signal(signal.SIGTERM, graceful_shutdown) | |
signal.signal(signal.SIGINT, graceful_shutdown) | |
# -------------------------------- Data Models -------------------------------
class StructuredQueryRewriteResponse(BaseModel):
    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None
class QueryRewrite(BaseModel):
    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None
class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )
class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]
class RetrievalResult(BaseModel):
    chunk: str
    genre: str
    score: float
class RerankingResult(BaseModel):
    query: str
    genre: str
    chunk: str
    score: float
class Recommendation(BaseModel):
    name: str
    rank: int
    score: Optional[float] = None
class PipelineResult(BaseModel):
    query: str
    rewrite: Optional[QueryRewrite] = None
    retrieval_result: Optional[list[RetrievalResult]] = None
    reranking_result: Optional[list[RerankingResult]] = None
    recommendations: Optional[dict[str, Recommendation]] = None
    def to_ranking(self) -> dict[str, float]:
        if not self.recommendations:
            return {}
        return {
            genre: recommendation.score
            for genre, recommendation in self.recommendations.items()
        }
# -------------------------------- VLLM --------------------------------------
local_llm = LLM(
    model=VLLM_MODEL_NAME,
    max_model_len=VLLM_MAX_SEQ_LEN,
    gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
    hf_token=HF_TOKEN,
    enforce_eager=True,
    dtype=VLLM_DTYPE,
)
json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=1024,
)
vllm_system_prompt = (
    "You are a search query optimization assistant built into a"
    " music genre search engine, helping users discover novel music genres."
)
vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))
# -------------------------------- GEMINI ------------------------------------
gemini_config = types.GenerateContentConfig(
    response_mime_type="application/json",
    response_schema=APIGenreRecommendationResponse,
    temperature=0.7,
    max_output_tokens=1024,
    system_instruction=(
        "You are a helpful music genre recommendation assistant built into a"
        " music genre search engine, helping users discover novel music genres."
    ),
)
gemini_llm = genai.Client(
    api_key=GEMINI_API_KEY,
    http_options={"api_version": "v1alpha"},
)
gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))
# ---------------------------- EMBEDDING MODELS --------------------------------
dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
reranker = CrossEncoder(
    model_name_or_path="BAAI/bge-reranker-v2-m3",
    max_length=1024,
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
reranker_batch_size = 128
# ---------------------------- RETRIEVAL ---------------------------------------
def run_query_rewrite(query: str) -> QueryRewrite:
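    """Rewrite the user query into several targeted variants with the local
    vLLM model, using guided decoding so the output matches the
    StructuredQueryRewriteResponse JSON schema."""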
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    rewrite = QueryRewrite(
        rewrites=[x for x in rewrite_json.values() if x is not None],
        structured=rewrite_json,
    )
    return rewrite
def prepare_queries_for_retrieval(
    query: str, rewrite: QueryRewrite
) -> list[dict[str, str | None]]:
    queries_to_retrieve = [{"text": query, "topic": None}]
    for cat, text in rewrite.structured.model_dump().items():
        if text is None:
            continue
        topic = cat
        if cat not in ["subjective", "purpose", "technical"]:
            topic = None
        queries_to_retrieve.append({"text": text, "topic": topic})
    return queries_to_retrieve
def run_retrieval(
    queries: list[dict[str, str | None]],
) -> list[RetrievalResult]:
    queries_to_embed = [query["text"] for query in queries]
    dense_queries = list(
        dense_encoder.encode(
            queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
        )
    )
    sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))
    prefetches: list[qmodels.Prefetch] = []
    for query, dense_query, sparse_query in zip(queries, dense_queries, sparse_queries):
        assert dense_query is not None and sparse_query is not None
        assert isinstance(dense_query, np.ndarray) and isinstance(
            sparse_query, SparseEmbedding
        )
        # Restrict topic-specific rewrites to chunks tagged with the same topic
        topic = query.get("topic", None)
        topic_filter = (
            qmodels.Filter(
                must=[
                    qmodels.FieldCondition(
                        key="topic", match=qmodels.MatchValue(value=topic)
                    )
                ]
            )
            if topic is not None
            else None
        )
        prefetches.extend(
            [
                qmodels.Prefetch(
                    query=dense_query,
                    using="dense",
                    filter=topic_filter,
                    limit=num_chunks_base,
                ),
                qmodels.Prefetch(
                    query=qmodels.SparseVector(**sparse_query.as_object()),
                    using="sparse",
                    filter=topic_filter,
                    limit=num_chunks_base,
                ),
            ]
        )
    retrieval_results = client.query_points(
        collection_name=collection_name,
        prefetch=prefetches,
        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
        limit=num_chunks_base,
    )
    final_hits: list[RetrievalResult] = [
        RetrievalResult(
            chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
        )
        for hit in retrieval_results.points
    ]
    return final_hits
def run_reranking(
    query: str, retrieval_result: list[RetrievalResult]
) -> list[RerankingResult]:
    hit_texts: list[str] = [result.chunk for result in retrieval_result]
    hit_genres: list[str] = [result.genre for result in retrieval_result]
    hit_rerank = reranker.rank(
        query=query,
        documents=hit_texts,
        batch_size=reranker_batch_size,
    )
    ranking = [
        RerankingResult(
            query=query,
            genre=hit_genres[hit["corpus_id"]],
            chunk=hit_texts[hit["corpus_id"]],
            score=hit["score"],
        )
        for hit in hit_rerank
    ]
    ranking.sort(key=lambda x: x.score, reverse=True)
    return ranking
def get_top_genres(
    df: pd.DataFrame,
    column: str,
    alpha: float = 1.0,
    # beta: float = 1.0,
    top_k: int | None = None,
) -> pd.Series:
assert 0 <= alpha <= 1.0 | |
# Min-max normalization of re-ranker scores before aggregation | |
task_scores = df[column] | |
min_score = task_scores.min() | |
max_score = task_scores.max() | |
if max_score > min_score: # Avoid division by zero | |
df.loc[:, column] = (task_scores - min_score) / (max_score - min_score) | |
tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum")) | |
tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + ( | |
1 - alpha | |
) * (tg_df["score"] / tg_df["score"].max()) | |
tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"] | |
if top_k: | |
tg = tg.head(top_k) | |
return tg | |
def get_recommendations(
    reranking_result: list[RerankingResult],
) -> dict[str, Recommendation]:
    ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
    top_genres_series = get_top_genres(
        df=ranking_df, column="score", alpha=alpha, top_k=top_k
    )
    recommendations = {
        genre: Recommendation(name=genre, rank=rank, score=score)
        for rank, (genre, score) in enumerate(
            top_genres_series.to_dict().items(), start=1
        )
    }
    return recommendations
# ----------------------- GENERATE RECOMMENDATIONS -----------------------------
def recommend_sadaimrec(query: str):
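    """Full local pipeline: query rewrite -> hybrid retrieval -> cross-encoder
    re-ranking -> genre aggregation, rendered as Markdown."""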
    result = PipelineResult(query=query)
    print("Running query processing...", flush=True)
    result.rewrite = run_query_rewrite(query=query)
    print(f"Rewrites:\n{result.rewrite.model_dump_json(indent=4)}", flush=True)
    queries_to_retrieve = prepare_queries_for_retrieval(
        query=query, rewrite=result.rewrite
    )
    print("Running retrieval...", flush=True)
    result.retrieval_result = run_retrieval(queries_to_retrieve)
    print("Running re-ranking...", flush=True)
    result.reranking_result = run_reranking(
        query=query, retrieval_result=result.retrieval_result
    )
    print("Aggregating recommendations...", flush=True)
    result.recommendations = get_recommendations(result.reranking_result)
    return generate_recommendation_string(result.to_ranking())
def recommend_gemini(query: str):
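    """Baseline pipeline: ask Gemini directly for scored genre recommendations
    via structured (JSON schema) output."""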
print("Generating recommendations using Gemini...", flush=True) | |
prompt = gemini_prompt.format(query=query) | |
response = gemini_llm.models.generate_content( | |
model="gemini-2.0-flash", | |
contents=prompt, | |
config=gemini_config, | |
) | |
parsed_content: APIGenreRecommendationResponse = response.parsed | |
parsed_content.genres.sort(key=lambda x: x.score, reverse=True) | |
ranking = {x.name.lower(): x.score for x in parsed_content.genres} | |
recommendation_string = generate_recommendation_string(ranking) | |
return f"{recommendation_string}" | |
# -------------------------------------- INTERFACE -----------------------------
pipelines = {
    "sadaimrec": recommend_sadaimrec,
    "gemini": recommend_gemini,
}
def generate_responses(query):
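    """Run both pipelines on the query in a random left/right order so that
    votes are blind to which pipeline produced which response."""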
    if not query.strip():
        raise gr.Error("Please enter a query before submitting.")
    # Randomize model order
    pipeline_names = list(pipelines.keys())
    random.shuffle(pipeline_names)
    # Generate responses
    resp1 = pipelines[pipeline_names[0]](query)
    resp2 = pipelines[pipeline_names[1]](query)
    # Return texts and hidden labels
    return resp1, resp2, pipeline_names[0], pipeline_names[1]
# Callback to capture vote
def handle_vote(nickname, query, selected, label1, label2, resp1, resp2):
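    """Record an A/B vote: log it, append a JSONL entry for the
    CommitScheduler to upload, and kick off the UI reset timer."""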
    nick = nickname.strip() or uuid4().hex[:8]
    winner_name, loser_name = (
        (label1, label2) if selected == "Option 1 (left)" else (label2, label1)
    )
    winner_resp, loser_resp = (
        (resp1, resp2) if selected == "Option 1 (left)" else (resp2, resp1)
    )
    print(
        (
            f"User voted:\nwinner = {winner_name}: {winner_resp};"
            f" loser = {loser_name}: {loser_resp}"
        ),
        flush=True,
    )
    # ---------- persist feedback locally ----------
    entry = {
        "ts": datetime.now(timezone.utc)
        .isoformat(timespec="seconds")
        .replace("+00:00", "Z"),
        "nickname": nick,
        "query": query,
        "winner": winner_name,
        "loser": loser_name,
        "winner_response": winner_resp,
        "loser_response": loser_resp,
    }
    with FEEDBACK_FILE.open("a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
    return (
        f"Thank you for your vote! Winner: {winner_name}. Restarting in 3 seconds...",
        gr.update(active=True),
        gr.update(value=nick),
    )
def reset_ui():
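    """Return the Gradio updates that restore the initial UI state after a
    vote has been recorded."""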
    return (
        gr.update(value="", visible=False),  # hide response row
        gr.update(value=""),  # clear query
        gr.update(visible=False),  # hide radio
        gr.update(visible=False),  # hide vote button
        gr.update(value="**Generating...**"),  # reset Option 1 placeholder
        gr.update(value="**Generating...**"),  # reset Option 2 placeholder
        gr.update(value=""),  # clear model label 1
        gr.update(value=""),  # clear model label 2
        gr.update(value=""),  # clear result
        gr.update(active=False),  # stop the reset timer
    )
app_description = load_text_resource(Path("./resources/description.md"))
app_instructions = load_text_resource(Path("./resources/instructions.md"))
with gr.Blocks(
    title="sadai-mrec", theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)
) as demo:
    gr.Markdown(app_description)
    with gr.Accordion("Detailed usage instructions", open=False):
        gr.Markdown(app_instructions)
    nickname = gr.Textbox(
        label="Your nickname",
        placeholder="Leave empty to generate a random nickname on your first vote",
    )
    query = gr.Textbox(
        label="Your Query",
        placeholder="Calming music for deep relaxation with echoing sounds and deep bass",
    )
    submit_btn = gr.Button("Submit")
    # Timer that resets the UI after feedback is sent
    reset_timer = gr.Timer(value=3.0, active=False)
    # Responses (hidden until generated) and hidden model-name labels
    with gr.Row(visible=False) as response_row:
        response_1 = gr.Markdown(value="**Generating...**", label="Option 1")
        response_2 = gr.Markdown(value="**Generating...**", label="Option 2")
    model_label_1 = gr.Textbox(visible=False)
    model_label_2 = gr.Textbox(visible=False)
    # Feedback
    vote = gr.Radio(
        ["Option 1 (left)", "Option 2 (right)"],
        label="Select Best Response",
        visible=False,
    )
    vote_btn = gr.Button("Vote", visible=False)
    result = gr.Textbox(label="Console", interactive=False)
    # On submit
    submit_btn.click(  # generate
        fn=generate_responses,
        inputs=[query],
        outputs=[response_1, response_2, model_label_1, model_label_2],
        show_progress="full",
    )
    submit_btn.click(  # update ui
        fn=lambda: (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        ),
        inputs=None,
        outputs=[response_row, vote, vote_btn],
    )
    # Feedback handling
    vote_btn.click(
        fn=handle_vote,
        inputs=[
            nickname,
            query,
            vote,
            model_label_1,
            model_label_2,
            response_1,
            response_2,
        ],
        outputs=[result, reset_timer, nickname],
    )
    reset_timer.tick(
        fn=reset_ui,
        inputs=None,
        outputs=[
            response_row,
            query,
            vote,
            vote_btn,
            response_1,
            response_2,
            model_label_1,
            model_label_2,
            result,
            reset_timer,
        ],
        trigger_mode="once",
    )
if __name__ == "__main__":
    demo.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860
    )