import json
import os
import random
import signal
import sys
import urllib.parse
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from uuid import uuid4

import gradio as gr
import numpy as np
import pandas as pd

# from dotenv import load_dotenv
from fastembed import SparseEmbedding, SparseTextEmbedding
from google import genai
from google.genai import types
from huggingface_hub import CommitScheduler
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
from sentence_transformers import CrossEncoder, SentenceTransformer
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
VLLM_MODEL_NAME = os.getenv("VLLM_MODEL_NAME")
VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION"))
VLLM_MAX_SEQ_LEN = int(os.getenv("VLLM_MAX_SEQ_LEN"))
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
FEEDBACK_REPO = os.getenv("FEEDBACK_REPO")
FEEDBACK_DIR = DATA_PATH / "feedback"
FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / f"votes_{uuid4()}.jsonl"

scheduler = CommitScheduler(
    repo_id=FEEDBACK_REPO,
    repo_type="dataset",
    folder_path=FEEDBACK_DIR,
    path_in_repo="data",
    every=5,  # push collected feedback every 5 minutes
    token=HF_TOKEN,
    private=True,
)

client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
num_chunks_base = 500
alpha = 0.5
top_k = 5  # we only want the top 5 genres
youtube_query_template = "{genre} music playlist"

# -------------------------------- HELPERS -------------------------------------


def load_text_resource(path: Path) -> str:
    with path.open("r") as file:
        resource = file.read()
    return resource


def youtube_search_link_for_genre(genre: str) -> str:
    base_url = "https://www.youtube.com/results"
    params = {
        "search_query": youtube_query_template.format(
            genre=genre.replace("_", " ").lower()
        )
    }
    return f"{base_url}?{urllib.parse.urlencode(params)}"


def generate_recommendation_string(ranking: dict[str, float]) -> str:
    recommendation_string = "## Recommendations for You\n\n"
    for idx, (genre, _score) in enumerate(ranking.items(), start=1):
        youtube_link = youtube_search_link_for_genre(genre=genre)
        recommendation_string += (
            f"{idx}. **{genre.replace('_', ' ').capitalize()}**; "
            f"[YouTube link]({youtube_link})\n"
        )
    return recommendation_string
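
# Illustrative output of generate_recommendation_string, traced from the helper
# above (the example genre keys are assumptions; real keys come from the Qdrant
# payloads):
#
#   generate_recommendation_string({"deep_house": 0.91, "ambient": 0.84})
#   ->
#   ## Recommendations for You
#
#   1. **Deep house**; [YouTube link](https://www.youtube.com/results?search_query=deep+house+music+playlist)
#   2. **Ambient**; [YouTube link](https://www.youtube.com/results?search_query=ambient+music+playlist)
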

def graceful_shutdown(signum, frame):
    print(f"{signum} received - flushing feedback …", flush=True)
    scheduler.trigger().result()
    sys.exit(0)


signal.signal(signal.SIGTERM, graceful_shutdown)
signal.signal(signal.SIGINT, graceful_shutdown)

# -------------------------------- Data Models -------------------------------


class StructuredQueryRewriteResponse(BaseModel):
    general: str | None
    subjective: str | None
    purpose: str | None
    technical: str | None
    curiosity: str | None


class QueryRewrite(BaseModel):
    rewrites: list[str] | None = None
    structured: StructuredQueryRewriteResponse | None = None


class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )


class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]


class RetrievalResult(BaseModel):
    chunk: str
    genre: str
    score: float


class RerankingResult(BaseModel):
    query: str
    genre: str
    chunk: str
    score: float


class Recommendation(BaseModel):
    name: str
    rank: int
    score: Optional[float] = None


class PipelineResult(BaseModel):
    query: str
    rewrite: Optional[QueryRewrite] = None
    retrieval_result: Optional[list[RetrievalResult]] = None
    reranking_result: Optional[list[RerankingResult]] = None
    recommendations: Optional[dict[str, Recommendation]] = None

    def to_ranking(self) -> dict[str, float]:
        if not self.recommendations:
            return {}
        return {
            genre: recommendation.score
            for genre, recommendation in self.recommendations.items()
        }


# -------------------------------- VLLM --------------------------------------

local_llm = LLM(
    model=VLLM_MODEL_NAME,
    max_model_len=VLLM_MAX_SEQ_LEN,
    gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
    hf_token=HF_TOKEN,
    enforce_eager=True,
    dtype=VLLM_DTYPE,
)

json_schema = StructuredQueryRewriteResponse.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=1024,
)

vllm_system_prompt = (
    "You are a search query optimization assistant built into"
    " a music genre search engine, helping users discover novel music genres."
)
vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))
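
# The guided decoding above constrains the local model to emit JSON valid
# against the StructuredQueryRewriteResponse schema. Illustrative shape only
# (the field values are assumptions, not real model output):
# {"general": "relaxing ambient electronic music",
#  "subjective": "calming, meditative mood",
#  "purpose": "deep relaxation",
#  "technical": "echoing sounds, deep bass",
#  "curiosity": null}
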
# -------------------------------- GEMINI ------------------------------------

gemini_config = types.GenerateContentConfig(
    response_mime_type="application/json",
    response_schema=APIGenreRecommendationResponse,
    temperature=0.7,
    max_output_tokens=1024,
    system_instruction=(
        "You are a helpful music genre recommendation assistant built into"
        " a music genre search engine, helping users discover novel music genres."
    ),
)
gemini_llm = genai.Client(
    api_key=GEMINI_API_KEY,
    http_options={"api_version": "v1alpha"},
)
gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))

# ---------------------------- EMBEDDING MODELS --------------------------------

dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
reranker = CrossEncoder(
    model_name_or_path="BAAI/bge-reranker-v2-m3",
    max_length=1024,
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
reranker_batch_size = 128

# ---------------------------- RETRIEVAL ---------------------------------------


def run_query_rewrite(query: str) -> QueryRewrite:
    prompt = vllm_prompt.format(query=query)
    messages = [
        {"role": "system", "content": vllm_system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = local_llm.chat(
        messages=messages,
        sampling_params=sampling_params_json,
    )
    rewrite_json = json.loads(outputs[0].outputs[0].text)
    rewrite = QueryRewrite(
        rewrites=[x for x in rewrite_json.values() if x is not None],
        structured=rewrite_json,
    )
    return rewrite


def prepare_queries_for_retrieval(
    query: str, rewrite: QueryRewrite
) -> list[dict[str, str | None]]:
    queries_to_retrieve = [{"text": query, "topic": None}]
    for cat, rewrite_text in rewrite.structured.model_dump().items():
        if rewrite_text is None:
            continue
        # Only these categories have matching topic filters in the collection.
        topic = cat if cat in ["subjective", "purpose", "technical"] else None
        queries_to_retrieve.append({"text": rewrite_text, "topic": topic})
    return queries_to_retrieve


def run_retrieval(
    queries: list[dict[str, str | None]],
) -> list[RetrievalResult]:
    queries_to_embed = [query["text"] for query in queries]
    dense_queries = list(
        dense_encoder.encode(
            queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
        )
    )
    sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))

    prefetches: list[qmodels.Prefetch] = []
    for query, dense_query, sparse_query in zip(
        queries, dense_queries, sparse_queries
    ):
        assert dense_query is not None and sparse_query is not None
        assert isinstance(dense_query, np.ndarray) and isinstance(
            sparse_query, SparseEmbedding
        )
        topic = query.get("topic", None)
        topic_filter = (
            qmodels.Filter(
                must=[
                    qmodels.FieldCondition(
                        key="topic", match=qmodels.MatchValue(value=topic)
                    )
                ]
            )
            if topic is not None
            else None
        )
        prefetch = [
            qmodels.Prefetch(
                query=dense_query,
                using="dense",
                filter=topic_filter,
                limit=num_chunks_base,
            ),
            qmodels.Prefetch(
                query=qmodels.SparseVector(**sparse_query.as_object()),
                using="sparse",
                filter=topic_filter,
                limit=num_chunks_base,
            ),
        ]
        prefetches.extend(prefetch)

    retrieval_results = client.query_points(
        collection_name=collection_name,
        prefetch=prefetches,
        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
        limit=num_chunks_base,
    )
    final_hits: list[RetrievalResult] = [
        RetrievalResult(
            chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
        )
        for hit in retrieval_results.points
    ]
    return final_hits


def run_reranking(
    query: str, retrieval_result: list[RetrievalResult]
) -> list[RerankingResult]:
    hit_texts: list[str] = [result.chunk for result in retrieval_result]
    hit_genres: list[str] = [result.genre for result in retrieval_result]
    hit_rerank = reranker.rank(
        query=query,
        documents=hit_texts,
        batch_size=reranker_batch_size,
    )
    ranking = [
        RerankingResult(
            query=query,
            genre=hit_genres[hit["corpus_id"]],
            chunk=hit_texts[hit["corpus_id"]],
            score=hit["score"],
        )
        for hit in hit_rerank
    ]
    ranking.sort(key=lambda x: x.score, reverse=True)
    return ranking
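
# Aggregation used below: for a genre g with n_g retrieved chunks and summed
# (min-max normalized) re-ranker score s_g,
#   weighted_score(g) = alpha * n_g / max(n) + (1 - alpha) * s_g / max(s)
# With the module-level alpha = 0.5, chunk frequency and re-ranker relevance
# contribute equally to the final ranking.
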
def get_top_genres(
    df: pd.DataFrame,
    column: str,
    alpha: float = 1.0,
    # beta: float = 1.0,
    top_k: int | None = None,
) -> pd.Series:
    assert 0 <= alpha <= 1.0
    # Min-max normalization of re-ranker scores before aggregation
    task_scores = df[column]
    min_score = task_scores.min()
    max_score = task_scores.max()
    if max_score > min_score:  # Avoid division by zero
        df.loc[:, column] = (task_scores - min_score) / (max_score - min_score)
    tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum"))
    tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + (
        1 - alpha
    ) * (tg_df["score"] / tg_df["score"].max())
    tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"]
    if top_k:
        tg = tg.head(top_k)
    return tg


def get_recommendations(
    reranking_result: list[RerankingResult],
) -> dict[str, Recommendation]:
    ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
    top_genres_series = get_top_genres(
        df=ranking_df, column="score", alpha=alpha, top_k=top_k
    )
    recommendations = {
        genre: Recommendation(name=genre, rank=rank, score=score)
        for rank, (genre, score) in enumerate(
            top_genres_series.to_dict().items(), start=1
        )
    }
    return recommendations


# ----------------------- GENERATE RECOMMENDATIONS -----------------------------


def recommend_sadaimrec(query: str):
    result = PipelineResult(query=query)

    print("Running query processing...", flush=True)
    result.rewrite = run_query_rewrite(query=query)
    print(f"Rewrites:\n{result.rewrite.model_dump_json(indent=4)}")
    queries_to_retrieve = prepare_queries_for_retrieval(
        query=query, rewrite=result.rewrite
    )

    print("Running retrieval...", flush=True)
    result.retrieval_result = run_retrieval(queries_to_retrieve)

    print("Running re-ranking...", flush=True)
    result.reranking_result = run_reranking(
        query=query, retrieval_result=result.retrieval_result
    )

    print("Aggregating recommendations...", flush=True)
    result.recommendations = get_recommendations(result.reranking_result)
    recommendation_string = generate_recommendation_string(result.to_ranking())
    return recommendation_string


def recommend_gemini(query: str):
    print("Generating recommendations using Gemini...", flush=True)
    prompt = gemini_prompt.format(query=query)
    response = gemini_llm.models.generate_content(
        model="gemini-2.0-flash",
        contents=prompt,
        config=gemini_config,
    )
    parsed_content: APIGenreRecommendationResponse = response.parsed
    parsed_content.genres.sort(key=lambda x: x.score, reverse=True)
    ranking = {x.name.lower(): x.score for x in parsed_content.genres}
    recommendation_string = generate_recommendation_string(ranking)
    return recommendation_string


# -------------------------------------- INTERFACE -----------------------------

pipelines = {
    "sadaimrec": recommend_sadaimrec,
    "gemini": recommend_gemini,
}


def generate_responses(query):
    if not query.strip():
        raise gr.Error("Please enter a query before submitting.")
    # Randomize model order so votes stay blind
    pipeline_names = list(pipelines.keys())
    random.shuffle(pipeline_names)
    # Generate responses
    resp1 = pipelines[pipeline_names[0]](query)
    resp2 = pipelines[pipeline_names[1]](query)
    # Return texts and hidden labels
    return resp1, resp2, pipeline_names[0], pipeline_names[1]
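
# Each vote is appended to FEEDBACK_FILE as one JSON line and pushed to
# FEEDBACK_REPO by the CommitScheduler every 5 minutes. Illustrative record
# written by handle_vote below (all field values are assumptions):
# {"ts": "2025-01-01T12:00:00Z", "nickname": "a1b2c3d4", "query": "...",
#  "winner": "sadaimrec", "loser": "gemini",
#  "winner_response": "## Recommendations for You\n\n1. ...",
#  "loser_response": "## Recommendations for You\n\n1. ..."}
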
selected == "Option 1 (left)" else (label2, label1) ) winner_resp, loser_resp = ( (resp1, resp2) if selected == "Option 1 (left)" else (resp2, resp1) ) print( ( f"User voted:\nwinner = {winner_name}: {winner_resp};" f" loser = {loser_name}: {loser_resp}" ), flush=True, ) # ---------- persist feedback locally ---------- entry = { "ts": datetime.now().isoformat(timespec="seconds") + "Z", "nickname": nick, "query": query, "winner": winner_name, "loser": loser_name, "winner_response": winner_resp, "loser_response": loser_resp, } with FEEDBACK_FILE.open("a", encoding="utf-8") as f: f.write(json.dumps(entry) + "\n") return ( f"Thank you for your vote! Winner: {winner_name}. Restarting in 3 seconds...", gr.update(active=True), gr.update(value=nick), ) def reset_ui(): return ( gr.update(value="", visible=False), # hide row gr.update(value=""), # clear query gr.update(visible=False), # hide radio gr.update(visible=False), # hide vote button gr.update(value="**Generating...**"), # clear Option 1 text gr.update(value="**Generating...**"), # clear Option 2 text gr.update(value=""), # clear Model Label 1 text gr.update(value=""), # clear Model Label 2 text gr.update(value=""), # clear result gr.update(active=False), ) app_description = load_text_resource(Path("./resources/description.md")) app_instructions = load_text_resource(Path("./resources/instructions.md")) with gr.Blocks( title="sadai-mrec", theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg) ) as demo: gr.Markdown(app_description) with gr.Accordion("Detailed usage instructions", open=False): gr.Markdown(app_instructions) nickname = gr.Textbox( label="Your nickname", placeholder="Leave empty to generate a random nickname on first vote within session", ) query = gr.Textbox( label="Your Query", placeholder="Calming, music for deep relaxation with echoing sounds and deep bass", ) submit_btn = gr.Button("Submit") # timer that resets ui after feedback is sent reset_timer = gr.Timer(value=3.0, active=False) # Hidden components to store model responses and names with gr.Row(visible=False) as response_row: response_1 = gr.Markdown(value="**Generating...**", label="Option 1") response_2 = gr.Markdown(value="**Generating...**", label="Option 2") model_label_1 = gr.Textbox(visible=False) model_label_2 = gr.Textbox(visible=False) # Feedback vote = gr.Radio( ["Option 1 (left)", "Option 2 (right)"], label="Select Best Response", visible=False, ) vote_btn = gr.Button("Vote", visible=False) result = gr.Textbox(label="Console", interactive=False) # On submit submit_btn.click( # generate fn=generate_responses, inputs=[query], outputs=[response_1, response_2, model_label_1, model_label_2], show_progress="full", ) submit_btn.click( # update ui fn=lambda: ( gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), ), inputs=None, outputs=[response_row, vote, vote_btn], ) # Feedback handling vote_btn.click( fn=handle_vote, inputs=[ nickname, query, vote, model_label_1, model_label_2, response_1, response_2, ], outputs=[result, reset_timer, nickname], ) reset_timer.tick( fn=reset_ui, inputs=None, outputs=[ response_row, query, vote, vote_btn, response_1, response_2, model_label_1, model_label_2, result, reset_timer, ], trigger_mode="once", ) if __name__ == "__main__": demo.queue(max_size=10, default_concurrency_limit=1).launch( server_name="0.0.0.0", server_port=7860 )