Spaces · Paused
Oleh Kuznetsov committed · commit bdaca7e · 1 parent: 6e1997a
feat(rec): Finalize recommendations (almost done)
Files changed:
- .gitignore +2 -1
- Dockerfile +1 -1
- app.py +307 -24
- ingest.py +6 -2
- prompts/api.txt +0 -7
- resources/description.md +33 -0
- resources/prompt_api.md +12 -0
- prompts/local.txt → resources/prompt_vllm.md +1 -1
.gitignore
CHANGED
```diff
@@ -1,4 +1,5 @@
 *__pycache__*
 .venv
 .env
-data
+data
+*sandbox*
```
Dockerfile
CHANGED
```diff
@@ -31,7 +31,7 @@ ENV HOME=/home/user \
 
 # Setup application directory
 WORKDIR $HOME/app
-ADD --chown=user ./
+ADD --chown=user ./resources $HOME/app/resources
 ADD --chown=user ./ingest.py $HOME/app/ingest.py
 ADD --chown=user ./app.py $HOME/app/app.py
 
```
app.py
CHANGED
```diff
@@ -1,15 +1,25 @@
 import json
 import os
 import random
+import urllib.parse
 from pathlib import Path
+from typing import Optional
 
 import gradio as gr
+import numpy as np
+import pandas as pd
+from dotenv import load_dotenv
+from fastembed import SparseEmbedding, SparseTextEmbedding
 from google import genai
 from google.genai import types
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from qdrant_client import QdrantClient
+from qdrant_client import models as qmodels
+from sentence_transformers import CrossEncoder, SentenceTransformer
 from vllm import LLM, SamplingParams
 from vllm.sampling_params import GuidedDecodingParams
 
+load_dotenv()
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 
@@ -20,11 +30,44 @@ VLLM_DTYPE = os.getenv("VLLM_DTYPE")
 
 GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 
+DATA_PATH = Path(os.getenv("DATA_PATH"))
+DB_PATH = DATA_PATH / "db"
+
+client = QdrantClient(path=str(DB_PATH))
+collection_name = "knowledge_cards"
+num_chunks_base = 500
+alpha = 0.5
+top_k = 5  # we only want top 5 genres
+
+youtube_url_template = "{genre} music playlist"
+
+
 # -------------------------------- HELPERS -------------------------------------
-def
+def load_text_resource(path: Path) -> str:
     with path.open("r") as file:
-
-    return
+        resource = file.read()
+    return resource
+
+
+def youtube_search_link_for_genre(genre: str) -> str:
+    base_url = "https://www.youtube.com/results"
+    params = {
+        "search_query": youtube_url_template.format(
+            genre=genre.replace("_", " ").lower()
+        )
+    }
+    return f"{base_url}?{urllib.parse.urlencode(params)}"
+
+
+def generate_recommendation_string(ranking: dict[str, float]) -> str:
+    recommendation_string = "## Recommendations for You\n\n"
+    for idx, (genre, score) in enumerate(ranking.items(), start=1):
+        youtube_link = youtube_search_link_for_genre(genre=genre)
+        recommendation_string += (
+            f"{idx}. **{genre.replace('_', ' ').capitalize()}** ({score:.2f}); "
+            f"[YouTube link]({youtube_link})\n"
+        )
+    return recommendation_string
 
 
 # -------------------------------- Data Models -------------------------------
@@ -41,8 +84,51 @@ class QueryRewrite(BaseModel):
     structured: StructuredQueryRewriteResponse | None = None
 
 
+class APIGenreRecommendation(BaseModel):
+    name: str = Field(description="Name of the music genre.")
+    score: float = Field(
+        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
+    )
+
+
 class APIGenreRecommendationResponse(BaseModel):
-    genres: list[
+    genres: list[APIGenreRecommendation]
+
+
+class RetrievalResult(BaseModel):
+    chunk: str
+    genre: str
+    score: float
+
+
+class RerankingResult(BaseModel):
+    query: str
+    genre: str
+    chunk: str
+    score: float
+
+
+class Recommendation(BaseModel):
+    name: str
+    rank: int
+    score: Optional[float] = None
+
+
+class PipelineResult(BaseModel):
+    query: str
+    rewrite: Optional[QueryRewrite] = None
+    retrieval_result: Optional[list[RetrievalResult]] = None
+    reranking_result: Optional[list[RerankingResult]] = None
+    recommendations: Optional[dict[str, Recommendation]] = None
+
+    def to_ranking(self) -> dict[str, float]:
+        if not self.recommendations:
+            return {}
+
+        return {
+            genre: recommendation.score
+            for genre, recommendation in self.recommendations.items()
+        }
 
 
 # -------------------------------- VLLM --------------------------------------
@@ -68,7 +154,7 @@ vllm_system_prompt = (
     "You are a search query optimization assistant built into"
     " music genre search engine, helping users discover novel music genres."
 )
-vllm_prompt =
+vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))
 
 # -------------------------------- GEMINI ------------------------------------
 gemini_config = types.GenerateContentConfig(
@@ -76,20 +162,35 @@ gemini_config = types.GenerateContentConfig(
     response_schema=APIGenreRecommendationResponse,
     temperature=0.7,
     max_output_tokens=1024,
-    system_instruction=(
+    system_instruction=(
+        "You are a helpful music genre recommendation assistant built into"
+        " music genre search engine, helping users discover novel music genres."
+    )
 )
 gemini_llm = genai.Client(
     api_key=GEMINI_API_KEY,
     http_options={"api_version": "v1alpha"},
 )
-gemini_prompt =
+gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))
 
-
-
+# ---------------------------- EMBEDDING MODELS --------------------------------
+dense_encoder = SentenceTransformer(
+    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+    device="cuda",
+    model_kwargs={"torch_dtype": VLLM_DTYPE},
+)
+sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
+reranker = CrossEncoder(
+    model_name_or_path="BAAI/bge-reranker-v2-m3",
+    max_length=1024,
+    device="cuda",
+    model_kwargs={"torch_dtype": VLLM_DTYPE},
+)
+reranker_batch_size = 128
 
 
-#
-def
+# ---------------------------- RETRIEVAL ---------------------------------------
+def run_query_rewrite(query: str) -> QueryRewrite:
     prompt = vllm_prompt.format(query=query)
     messages = [
         {"role": "system", "content": vllm_system_prompt},
@@ -104,10 +205,181 @@ def recommend_sadaimrec(query: str):
         rewrites=[x for x in list(rewrite_json.values()) if x is not None],
         structured=rewrite_json,
     )
-    return
+    return rewrite
+
+
+def prepare_queries_for_retrieval(
+    query: str, rewrite: QueryRewrite
+) -> list[dict[str, str | None]]:
+    queries_to_retrieve = [{"text": query, "topic": None}]
+    for cat, rewrite in rewrite.structured.model_dump().items():
+        if rewrite is None:
+            continue
+        topic = cat
+        if cat not in ["subjective", "purpose", "technical"]:
+            topic = None
+        queries_to_retrieve.append({"text": rewrite, "topic": topic})
+    return queries_to_retrieve
+
+
+def run_retrieval(
+    queries: list[dict[str, str]],
+) -> RetrievalResult:
+    queries_to_embed = [query["text"] for query in queries]
+    dense_queries = list(
+        dense_encoder.encode(
+            queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
+        )
+    )
+    sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))
+    prefetches: list[qmodels.Prefetch] = []
+
+    for query, dense_query, sparse_query in zip(queries, dense_queries, sparse_queries):
+        assert dense_query is not None and sparse_query is not None
+        assert isinstance(dense_query, np.ndarray) and isinstance(
+            sparse_query, SparseEmbedding
+        )
+        topic = query.get("topic", None)
+        prefetch = [
+            qmodels.Prefetch(
+                query=dense_query,
+                using="dense",
+                filter=qmodels.Filter(
+                    must=[
+                        qmodels.FieldCondition(
+                            key="topic", match=qmodels.MatchValue(value=topic)
+                        )
+                    ]
+                )
+                if topic is not None
+                else None,
+                limit=num_chunks_base,
+            ),
+            qmodels.Prefetch(
+                query=qmodels.SparseVector(**sparse_query.as_object()),
+                using="sparse",
+                filter=qmodels.Filter(
+                    must=[
+                        qmodels.FieldCondition(
+                            key="topic", match=qmodels.MatchValue(value=topic)
+                        )
+                    ]
+                )
+                if topic is not None
+                else None,
+                limit=num_chunks_base,
+            ),
+        ]
+        prefetches.extend(prefetch)
+
+    retrieval_results = client.query_points(
+        collection_name=collection_name,
+        prefetch=prefetches,
+        query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
+        limit=num_chunks_base,
+    )
+
+    final_hits: list[RetrievalResult] = [
+        RetrievalResult(
+            chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
+        )
+        for hit in retrieval_results.points
+    ]
+    return final_hits
+
+
+def run_reranking(
+    query: str, retrieval_result: list[RetrievalResult]
+) -> list[RerankingResult]:
+    hit_texts: list[str] = [result.chunk for result in retrieval_result]
+    hit_genres: list[str] = [result.genre for result in retrieval_result]
+    hit_rerank = reranker.rank(
+        query=query,
+        documents=hit_texts,
+        batch_size=reranker_batch_size,
+    )
+    ranking = [
+        RerankingResult(
+            query=query,
+            genre=hit_genres[hit["corpus_id"]],
+            chunk=hit_texts[hit["corpus_id"]],
+            score=hit["score"],
+        )
+        for hit in hit_rerank
+    ]
+    ranking.sort(key=lambda x: x.score, reverse=True)
+    return ranking
+
+
+def get_top_genres(
+    df: pd.DataFrame,
+    column: str,
+    alpha: float = 1.0,
+    # beta: float = 1.0,
+    top_k: int | None = None,
+) -> pd.Series:
+    assert 0 <= alpha <= 1.0
+
+    # Min-max normalization of re-ranker scores before aggregation
+    task_scores = df[column]
+    min_score = task_scores.min()
+    max_score = task_scores.max()
+    if max_score > min_score:  # Avoid division by zero
+        df.loc[:, column] = (task_scores - min_score) / (max_score - min_score)
+
+    tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum"))
+    tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + (
+        1 - alpha
+    ) * (tg_df["score"] / tg_df["score"].max())
+    tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"]
+
+    if top_k:
+        tg = tg.head(top_k)
+
+    return tg
+
+
+def get_recommendations(
+    reranking_result: list[RerankingResult],
+) -> dict[str, Recommendation]:
+    ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
+    top_genres_series = get_top_genres(
+        df=ranking_df, column="score", alpha=alpha, top_k=top_k
+    )
+    recommendations = {
+        genre: Recommendation(name=genre, rank=rank, score=score)
+        for rank, (genre, score) in enumerate(
+            top_genres_series.to_dict().items(), start=1
+        )
+    }
+    return recommendations
+
+
+# ----------------------- GENERATE RECOMMENDATIONS -----------------------------
+def recommend_sadaimrec(query: str):
+    result = PipelineResult(query=query)
+    print("Running query processing...", flush=True)
+    result.rewrite = run_query_rewrite(query=query)
+    queries_to_retrieve = prepare_queries_for_retrieval(
+        query=query, rewrite=result.rewrite
+    )
+
+    print("Running retrieval...", flush=True)
+    result.retrieval_result = run_retrieval(queries_to_retrieve)
+
+    print("Running re-ranking...", flush=True)
+    result.reranking_result = run_reranking(
+        query=query, retrieval_result=result.retrieval_result
+    )
+
+    print("Aggregating recommendations...", flush=True)
+    result.recommendations = get_recommendations(result.reranking_result)
+    recommendation_string = generate_recommendation_string(result.to_ranking())
+    return f"{recommendation_string}"
 
 
 def recommend_gemini(query: str):
+    print("Generating recommendations using Gemini...", flush=True)
     prompt = gemini_prompt.format(query=query)
     response = gemini_llm.models.generate_content(
         model="gemini-2.0-flash",
@@ -115,17 +387,19 @@ def recommend_gemini(query: str):
         config=gemini_config,
     )
     parsed_content: APIGenreRecommendationResponse = response.parsed
-
+    parsed_content.genres.sort(key=lambda x: x.score, reverse=True)
+    ranking = {x.name.lower(): x.score for x in parsed_content.genres}
+    recommendation_string = generate_recommendation_string(ranking)
+    return f"{recommendation_string}"
 
 
-#
+# -------------------------------------- INTERFACE -----------------------------
 pipelines = {
     "sadaimrec": recommend_sadaimrec,
     "chatgpt": recommend_gemini,
 }
 
 
-# -------------------------------------- INTERFACE -----------------------------
 def generate_responses(query):
     # Randomize model order
     pipeline_names = list(pipelines.keys())
@@ -156,30 +430,37 @@ def reset_ui():
         gr.update(value=""),  # clear query
         gr.update(visible=False),  # hide radio
         gr.update(visible=False),  # hide vote button
-        gr.update(value=""),  # clear Option 1 text
-        gr.update(value=""),  # clear Option 2 text
+        gr.update(value="**Generating...**"),  # clear Option 1 text
+        gr.update(value="**Generating...**"),  # clear Option 2 text
        gr.update(value=""),  # clear result
         gr.update(active=False),
     )
 
 
-
-
-
+app_description = load_text_resource(Path("./resources/description.md"))
+
+with gr.Blocks(title="SADAIMREC") as demo:
+    gr.Markdown(app_description)
+    query = gr.Textbox(
+        label="Your Query",
+        placeholder="Calming, music for deep relaxation with echoing sounds and deep bass",
+    )
     submit_btn = gr.Button("Submit")
     # timer that resets ui after feedback is sent
     reset_timer = gr.Timer(value=2.0, active=False)
 
     # Hidden components to store model responses and names
    with gr.Row(visible=False) as response_row:
-        response_1 = gr.
-        response_2 = gr.
+        response_1 = gr.Markdown(value="**Generating...**", label="Option 1")
+        response_2 = gr.Markdown(value="**Generating...**", label="Option 2")
        model_label_1 = gr.Textbox(visible=False)
        model_label_2 = gr.Textbox(visible=False)
 
    # Feedback
    vote = gr.Radio(
-        ["Option 1", "Option 2"],
+        ["Option 1 (left)", "Option 2 (right)"],
+        label="Select Best Response",
+        visible=False,
    )
    vote_btn = gr.Button("Vote", visible=False)
    result = gr.Textbox(label="Console", interactive=False)
@@ -189,6 +470,7 @@ with gr.Blocks() as demo:
         fn=generate_responses,
         inputs=[query],
         outputs=[response_1, response_2, model_label_1, model_label_2],
+        show_progress="full",
     )
     submit_btn.click(  # update ui
         fn=lambda: (
@@ -222,6 +504,7 @@ with gr.Blocks() as demo:
         trigger_mode="once",
     )
 
+
 if __name__ == "__main__":
     demo.queue(max_size=10, default_concurrency_limit=1).launch(
         server_name="0.0.0.0", server_port=7860
```
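A note on the retrieval step above: for every query (original plus each rewrite), `run_retrieval` builds one dense and one sparse `Prefetch`, then lets Qdrant fuse all candidate lists with Reciprocal Rank Fusion. Below is a minimal, self-contained sketch of that query shape against an in-memory collection; the vectors, payloads, and collection layout are toy values, and it assumes a recent `qdrant-client` whose local mode supports `query_points` with `FusionQuery`:

```python
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels

# In-memory Qdrant with the same named-vector layout the app's
# collection uses ("dense" + "sparse"); sizes and data are toy values.
client = QdrantClient(":memory:")
client.create_collection(
    collection_name="knowledge_cards",
    vectors_config={
        "dense": qmodels.VectorParams(size=4, distance=qmodels.Distance.COSINE)
    },
    sparse_vectors_config={"sparse": qmodels.SparseVectorParams()},
)
client.upsert(
    collection_name="knowledge_cards",
    points=[
        qmodels.PointStruct(
            id=1,
            vector={
                "dense": [0.1, 0.9, 0.1, 0.0],
                "sparse": qmodels.SparseVector(indices=[3, 17], values=[0.5, 1.2]),
            },
            payload={"text": "echoing pads, deep bass", "genre": "dub"},
        )
    ],
)

# One dense and one sparse Prefetch, fused with RRF -- the same shape
# run_retrieval builds for every (sub)query.
hits = client.query_points(
    collection_name="knowledge_cards",
    prefetch=[
        qmodels.Prefetch(query=[0.1, 0.8, 0.2, 0.0], using="dense", limit=10),
        qmodels.Prefetch(
            query=qmodels.SparseVector(indices=[17], values=[1.0]),
            using="sparse",
            limit=10,
        ),
    ],
    query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
    limit=5,
)
print([hit.payload["genre"] for hit in hits.points])  # ['dub']
```

The app does the same against the on-disk collection built by `ingest.py`, with BM25 sparse vectors from fastembed and dense vectors from `mxbai-embed-large-v1`.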
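`get_top_genres` carries the core aggregation logic: after min-max normalizing the re-ranker scores, each genre's final score blends how many chunks support it with how strongly those chunks score, via `weighted_score = alpha * size/size.max() + (1 - alpha) * score/score.max()`. A small worked example of just that computation; the chunks, genre names, and scores here are made up for illustration:

```python
import pandas as pd

# Toy re-ranked chunks: one row per chunk, with its genre and its
# (already min-max normalized) cross-encoder score.
df = pd.DataFrame(
    {
        "genre": ["ambient", "ambient", "dub", "ambient", "dub", "idm"],
        "chunk": list("abcdef"),
        "score": [1.0, 0.8, 0.9, 0.4, 0.7, 0.2],
    }
)

alpha = 0.5  # same default weight the app gives to chunk-count support

tg = df.groupby("genre").agg(size=("chunk", "size"), score=("score", "sum"))
tg["weighted_score"] = alpha * (tg["size"] / tg["size"].max()) + (
    1 - alpha
) * (tg["score"] / tg["score"].max())

print(tg.sort_values("weighted_score", ascending=False)["weighted_score"])
# ambient wins with 1.0: it has the most supporting chunks (size term = 1.0)
# and the largest summed score (score term = 1.0); dub lands around 0.70.
```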
ingest.py
CHANGED
```diff
@@ -9,11 +9,13 @@ from huggingface_hub import hf_hub_download
 from qdrant_client import QdrantClient
 from qdrant_client import models as qmodels
 
+VLLM_DTYPE = os.getenv("VLLM_DTYPE")
+
 DATA_PATH = Path(os.getenv("DATA_PATH"))
 DB_PATH = DATA_PATH / "db"
 HF_TOKEN = os.getenv("HF_TOKEN")
 
-RECREATE_DB = bool(os.getenv("RECREATE_DB", "False").lower == "true")
+RECREATE_DB = bool(os.getenv("RECREATE_DB", "False").lower() == "true")
 DATA_REPO = os.getenv("DATA_REPO")
 DATA_FILENAME = os.getenv("DATA_FILENAME")
 
@@ -24,7 +26,9 @@ dense_batch_size = 128
 sparse_batch_size = 256
 
 dense_encoder = SentenceTransformer(
-    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+    device="cuda",
+    model_kwargs={"torch_dtype": VLLM_DTYPE},
 )
 sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
 
```
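Besides threading `VLLM_DTYPE` through to the encoder, this hunk fixes a real bug: without the parentheses, `.lower` is a bound method object, so the old comparison `... .lower == "true"` was always `False` and `RECREATE_DB` could never be enabled. A quick illustration:

```python
import os

os.environ["RECREATE_DB"] = "True"

# Before the fix: a bound method is compared to "true" -> always False.
broken = bool(os.getenv("RECREATE_DB", "False").lower == "true")
# After the fix: the method is called, so the flag parses correctly.
fixed = os.getenv("RECREATE_DB", "False").lower() == "true"

print(broken, fixed)  # False True
```

After the fix, the `bool(...)` wrapper is redundant (the `==` comparison already yields a bool), but it is harmless.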
prompts/api.txt
DELETED
```diff
@@ -1,7 +0,0 @@
-# Purpose
-
-Recommend 5 genres based on the user query
-
-# Query
-
-{query}
```
resources/description.md
ADDED
```diff
@@ -0,0 +1,33 @@
+# Music Genre Recommendation Side-By-Side Comparison
+
+This simple application was developed and deployed as **complementary material for my thesis**.
+
+In case of any complications, questions, or suggestions, please reach out via [email](mailto:[email protected]).
+
+## Instructions
+
+1. Formulate a **search query** describing a music genre you would like to listen to. The expected format is described below.
+2. Explore **two generated recommendation rankings**: one is created by my system, one is generated using `gemini-2.0-flash`. Order is **randomized** each run.
+3. Determine which ranking you prefer.
+4. Vote for your choice.
+5. Wait for the refresh and repeat as many times as you want.
+
+## Expected Query Format
+
+- The system was designed to support **3 categories** of music genre descriptors:
+  - **Subjective**: Emotional & perceptual qualities, desired **inner feeling** (melancholic, energetic)
+  - **Purpose-Based**: Listening setting, context, suitable activities, scenario (party, workout)
+  - **Technical**: Musical & production attributes, **HOW the sound is made** (instrumentation, timbre, tempo, lo-fi)
+- **Other descriptors are out of scope of the current implementation**:
+  - I kindly ask you to only use the above 3 categories for your queries
+  - Usage of cultural, historical, etc. descriptors can lead to suboptimal results
+- You can make the descriptors **as complex and poetic as you want**, but I kindly ask you to **limit your query to a couple of sentences**
+
+## Query Examples
+
+- `Music for deep relaxation with echoing sounds and heavy bass, perfect for unwinding after a long day`
+- `Music that feels like the echo of a forgotten world—slow, sorrowful. Guitars and distant vocals create the sensation of a long, drifting sleep on the edge of melancholy and oblivion. A soundtrack to isolation, it slowly pulls you into the depths of existential despair.`
+- `Raw and filled with aggression, high-energy drums, mosh-pit vibes, high bpm, guitars`
+- `Music to study to, relaxing, chill with calm drums, some piano, and suitable for background`
+- `Creamy and cozy, suitable for evenings with loved ones`
+- `Dreamy instrumental music for midnight melancholia`
```
resources/prompt_api.md
ADDED
```diff
@@ -0,0 +1,12 @@
+# Purpose and Context
+
+Given a user-generated Search Query describing music they wish to explore, create a ranking of the most suitable music genres.
+
+# Instructions
+
+1. Create a music genre ranking, including the 5 most suitable music genres, ordered from the most to the least suitable.
+2. Respond in JSON.
+
+# Search Query
+
+{query}
```
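This prompt asks Gemini for JSON, and app.py additionally pins the shape with `response_schema=APIGenreRecommendationResponse`. A sketch of the payload that combination requests, validated through the same pydantic models; the sample response itself is hypothetical:

```python
from pydantic import BaseModel, Field


# Mirrors the models defined in app.py.
class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )


class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]


# Hypothetical model output in the requested shape; ge/le bounds on
# score are enforced at validation time.
sample = '{"genres": [{"name": "ambient", "score": 0.92}, {"name": "dub techno", "score": 0.81}]}'
parsed = APIGenreRecommendationResponse.model_validate_json(sample)
print([g.name for g in parsed.genres])  # ['ambient', 'dub techno']
```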
prompts/local.txt → resources/prompt_vllm.md
RENAMED
```diff
@@ -228,4 +228,4 @@ Given a user-generated Search Query describing music they wish to explore, you m
 
 # Search Query
 
-{query}
+{query}
```