Oleh Kuznetsov committed
Commit bdaca7e · 1 Parent(s): 6e1997a

feat(rec): Finalize recommendations (almost done)
.gitignore CHANGED
@@ -1,4 +1,5 @@
  *__pycache__*
  .venv
  .env
- data
+ data
+ *sandbox*
Dockerfile CHANGED
@@ -31,7 +31,7 @@ ENV HOME=/home/user \

  # Setup application directory
  WORKDIR $HOME/app
- ADD --chown=user ./prompts $HOME/app/prompts
+ ADD --chown=user ./resources $HOME/app/resources
  ADD --chown=user ./ingest.py $HOME/app/ingest.py
  ADD --chown=user ./app.py $HOME/app/app.py
app.py CHANGED
@@ -1,15 +1,25 @@
  import json
  import os
  import random
+ import urllib.parse
  from pathlib import Path
+ from typing import Optional

  import gradio as gr
+ import numpy as np
+ import pandas as pd
+ from dotenv import load_dotenv
+ from fastembed import SparseEmbedding, SparseTextEmbedding
  from google import genai
  from google.genai import types
- from pydantic import BaseModel
+ from pydantic import BaseModel, Field
+ from qdrant_client import QdrantClient
+ from qdrant_client import models as qmodels
+ from sentence_transformers import CrossEncoder, SentenceTransformer
  from vllm import LLM, SamplingParams
  from vllm.sampling_params import GuidedDecodingParams

+ load_dotenv()

  HF_TOKEN = os.getenv("HF_TOKEN")

@@ -20,11 +30,44 @@ VLLM_DTYPE = os.getenv("VLLM_DTYPE")

  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

+ DATA_PATH = Path(os.getenv("DATA_PATH"))
+ DB_PATH = DATA_PATH / "db"
+
+ client = QdrantClient(path=str(DB_PATH))
+ collection_name = "knowledge_cards"
+ num_chunks_base = 500
+ alpha = 0.5
+ top_k = 5  # we only want top 5 genres
+
+ youtube_url_template = "{genre} music playlist"
+
+
  # -------------------------------- HELPERS -------------------------------------
- def load_prompt(path: Path) -> str:
+ def load_text_resource(path: Path) -> str:
      with path.open("r") as file:
-         prompt = file.read()
-     return prompt
+         resource = file.read()
+     return resource
+
+
+ def youtube_search_link_for_genre(genre: str) -> str:
+     base_url = "https://www.youtube.com/results"
+     params = {
+         "search_query": youtube_url_template.format(
+             genre=genre.replace("_", " ").lower()
+         )
+     }
+     return f"{base_url}?{urllib.parse.urlencode(params)}"
+
+
+ def generate_recommendation_string(ranking: dict[str, float]) -> str:
+     recommendation_string = "## Recommendations for You\n\n"
+     for idx, (genre, score) in enumerate(ranking.items(), start=1):
+         youtube_link = youtube_search_link_for_genre(genre=genre)
+         recommendation_string += (
+             f"{idx}. **{genre.replace('_', ' ').capitalize()}** ({score:.2f}); "
+             f"[YouTube link]({youtube_link})\n"
+         )
+     return recommendation_string


  # -------------------------------- Data Models -------------------------------
@@ -41,8 +84,51 @@ class QueryRewrite(BaseModel):
      structured: StructuredQueryRewriteResponse | None = None


+ class APIGenreRecommendation(BaseModel):
+     name: str = Field(description="Name of the music genre.")
+     score: float = Field(
+         description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
+     )
+
+
  class APIGenreRecommendationResponse(BaseModel):
-     genres: list[str]
+     genres: list[APIGenreRecommendation]
+
+
+ class RetrievalResult(BaseModel):
+     chunk: str
+     genre: str
+     score: float
+
+
+ class RerankingResult(BaseModel):
+     query: str
+     genre: str
+     chunk: str
+     score: float
+
+
+ class Recommendation(BaseModel):
+     name: str
+     rank: int
+     score: Optional[float] = None
+
+
+ class PipelineResult(BaseModel):
+     query: str
+     rewrite: Optional[QueryRewrite] = None
+     retrieval_result: Optional[list[RetrievalResult]] = None
+     reranking_result: Optional[list[RerankingResult]] = None
+     recommendations: Optional[dict[str, Recommendation]] = None
+
+     def to_ranking(self) -> dict[str, float]:
+         if not self.recommendations:
+             return {}
+
+         return {
+             genre: recommendation.score
+             for genre, recommendation in self.recommendations.items()
+         }


  # -------------------------------- VLLM --------------------------------------
@@ -68,7 +154,7 @@ vllm_system_prompt = (
      "You are a search query optimization assistant built into"
      " music genre search engine, helping users discover novel music genres."
  )
- vllm_prompt = load_prompt(Path("./prompts/local.txt"))
+ vllm_prompt = load_text_resource(Path("./resources/prompt_vllm.md"))

  # -------------------------------- GEMINI ------------------------------------
  gemini_config = types.GenerateContentConfig(
@@ -76,20 +162,35 @@ gemini_config = types.GenerateContentConfig(
      response_schema=APIGenreRecommendationResponse,
      temperature=0.7,
      max_output_tokens=1024,
-     system_instruction=("You are a helpful music genre recommendation assistant."),
+     system_instruction=(
+         "You are a helpful music genre recommendation assistant built into"
+         " music genre search engine, helping users discover novel music genres."
+     )
  )
  gemini_llm = genai.Client(
      api_key=GEMINI_API_KEY,
      http_options={"api_version": "v1alpha"},
  )
- gemini_prompt = load_prompt(Path("./prompts/api.txt"))
+ gemini_prompt = load_text_resource(Path("./resources/prompt_api.md"))

-
- # ---------------------------- RETRIEVAL ---------------------------------------
+ # ---------------------------- EMBEDDING MODELS --------------------------------
+ dense_encoder = SentenceTransformer(
+     model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+     device="cuda",
+     model_kwargs={"torch_dtype": VLLM_DTYPE},
+ )
+ sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
+ reranker = CrossEncoder(
+     model_name_or_path="BAAI/bge-reranker-v2-m3",
+     max_length=1024,
+     device="cuda",
+     model_kwargs={"torch_dtype": VLLM_DTYPE},
+ )
+ reranker_batch_size = 128


- # ----------------------- GENERATE RECOMMENDATIONS -----------------------------
- def recommend_sadaimrec(query: str):
+ # ---------------------------- RETRIEVAL ---------------------------------------
+ def run_query_rewrite(query: str) -> QueryRewrite:
      prompt = vllm_prompt.format(query=query)
      messages = [
          {"role": "system", "content": vllm_system_prompt},
@@ -104,10 +205,181 @@ def recommend_sadaimrec(query: str):
          rewrites=[x for x in list(rewrite_json.values()) if x is not None],
          structured=rewrite_json,
      )
-     return f"SADAIMREC: response to '{rewrite.model_dump_json(indent=4)}'"
+     return rewrite
+
+
+ def prepare_queries_for_retrieval(
+     query: str, rewrite: QueryRewrite
+ ) -> list[dict[str, str | None]]:
+     queries_to_retrieve = [{"text": query, "topic": None}]
+     for cat, rewrite in rewrite.structured.model_dump().items():
+         if rewrite is None:
+             continue
+         topic = cat
+         if cat not in ["subjective", "purpose", "technical"]:
+             topic = None
+         queries_to_retrieve.append({"text": rewrite, "topic": topic})
+     return queries_to_retrieve
+
+
+ def run_retrieval(
+     queries: list[dict[str, str]],
+ ) -> RetrievalResult:
+     queries_to_embed = [query["text"] for query in queries]
+     dense_queries = list(
+         dense_encoder.encode(
+             queries_to_embed, convert_to_numpy=True, normalize_embeddings=True
+         )
+     )
+     sparse_queries = list(sparse_encoder.query_embed(queries_to_embed))
+     prefetches: list[qmodels.Prefetch] = []
+
+     for query, dense_query, sparse_query in zip(queries, dense_queries, sparse_queries):
+         assert dense_query is not None and sparse_query is not None
+         assert isinstance(dense_query, np.ndarray) and isinstance(
+             sparse_query, SparseEmbedding
+         )
+         topic = query.get("topic", None)
+         prefetch = [
+             qmodels.Prefetch(
+                 query=dense_query,
+                 using="dense",
+                 filter=qmodels.Filter(
+                     must=[
+                         qmodels.FieldCondition(
+                             key="topic", match=qmodels.MatchValue(value=topic)
+                         )
+                     ]
+                 )
+                 if topic is not None
+                 else None,
+                 limit=num_chunks_base,
+             ),
+             qmodels.Prefetch(
+                 query=qmodels.SparseVector(**sparse_query.as_object()),
+                 using="sparse",
+                 filter=qmodels.Filter(
+                     must=[
+                         qmodels.FieldCondition(
+                             key="topic", match=qmodels.MatchValue(value=topic)
+                         )
+                     ]
+                 )
+                 if topic is not None
+                 else None,
+                 limit=num_chunks_base,
+             ),
+         ]
+         prefetches.extend(prefetch)
+
+     retrieval_results = client.query_points(
+         collection_name=collection_name,
+         prefetch=prefetches,
+         query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
+         limit=num_chunks_base,
+     )
+
+     final_hits: list[RetrievalResult] = [
+         RetrievalResult(
+             chunk=hit.payload["text"], genre=hit.payload["genre"], score=hit.score
+         )
+         for hit in retrieval_results.points
+     ]
+     return final_hits
+
+
+ def run_reranking(
+     query: str, retrieval_result: list[RetrievalResult]
+ ) -> list[RerankingResult]:
+     hit_texts: list[str] = [result.chunk for result in retrieval_result]
+     hit_genres: list[str] = [result.genre for result in retrieval_result]
+     hit_rerank = reranker.rank(
+         query=query,
+         documents=hit_texts,
+         batch_size=reranker_batch_size,
+     )
+     ranking = [
+         RerankingResult(
+             query=query,
+             genre=hit_genres[hit["corpus_id"]],
+             chunk=hit_texts[hit["corpus_id"]],
+             score=hit["score"],
+         )
+         for hit in hit_rerank
+     ]
+     ranking.sort(key=lambda x: x.score, reverse=True)
+     return ranking
+
+
+ def get_top_genres(
+     df: pd.DataFrame,
+     column: str,
+     alpha: float = 1.0,
+     # beta: float = 1.0,
+     top_k: int | None = None,
+ ) -> pd.Series:
+     assert 0 <= alpha <= 1.0
+
+     # Min-max normalization of re-ranker scores before aggregation
+     task_scores = df[column]
+     min_score = task_scores.min()
+     max_score = task_scores.max()
+     if max_score > min_score:  # Avoid division by zero
+         df.loc[:, column] = (task_scores - min_score) / (max_score - min_score)
+
+     tg_df = df.groupby("genre").agg(size=("chunk", "size"), score=(column, "sum"))
+     tg_df["weighted_score"] = alpha * (tg_df["size"] / tg_df["size"].max()) + (
+         1 - alpha
+     ) * (tg_df["score"] / tg_df["score"].max())
+     tg = tg_df.sort_values("weighted_score", ascending=False)["weighted_score"]
+
+     if top_k:
+         tg = tg.head(top_k)
+
+     return tg
+
+
+ def get_recommendations(
+     reranking_result: list[RerankingResult],
+ ) -> dict[str, Recommendation]:
+     ranking_df = pd.DataFrame([x.model_dump(mode="python") for x in reranking_result])
+     top_genres_series = get_top_genres(
+         df=ranking_df, column="score", alpha=alpha, top_k=top_k
+     )
+     recommendations = {
+         genre: Recommendation(name=genre, rank=rank, score=score)
+         for rank, (genre, score) in enumerate(
+             top_genres_series.to_dict().items(), start=1
+         )
+     }
+     return recommendations
+
+
+ # ----------------------- GENERATE RECOMMENDATIONS -----------------------------
+ def recommend_sadaimrec(query: str):
+     result = PipelineResult(query=query)
+     print("Running query processing...", flush=True)
+     result.rewrite = run_query_rewrite(query=query)
+     queries_to_retrieve = prepare_queries_for_retrieval(
+         query=query, rewrite=result.rewrite
+     )
+
+     print("Running retrieval...", flush=True)
+     result.retrieval_result = run_retrieval(queries_to_retrieve)
+
+     print("Running re-ranking...", flush=True)
+     result.reranking_result = run_reranking(
+         query=query, retrieval_result=result.retrieval_result
+     )
+
+     print("Aggregating recommendations...", flush=True)
+     result.recommendations = get_recommendations(result.reranking_result)
+     recommendation_string = generate_recommendation_string(result.to_ranking())
+     return f"{recommendation_string}"


  def recommend_gemini(query: str):
+     print("Generating recommendations using Gemini...", flush=True)
      prompt = gemini_prompt.format(query=query)
      response = gemini_llm.models.generate_content(
          model="gemini-2.0-flash",
@@ -115,17 +387,19 @@ def recommend_gemini(query: str):
          config=gemini_config,
      )
      parsed_content: APIGenreRecommendationResponse = response.parsed
-     return f"CHATGPT: response to '{parsed_content.model_dump_json(indent=4)}'"
+     parsed_content.genres.sort(key=lambda x: x.score, reverse=True)
+     ranking = {x.name.lower(): x.score for x in parsed_content.genres}
+     recommendation_string = generate_recommendation_string(ranking)
+     return f"{recommendation_string}"


- # Mapping names to functions
+ # -------------------------------------- INTERFACE -----------------------------
  pipelines = {
      "sadaimrec": recommend_sadaimrec,
      "chatgpt": recommend_gemini,
  }


- # -------------------------------------- INTERFACE -----------------------------
  def generate_responses(query):
      # Randomize model order
      pipeline_names = list(pipelines.keys())
@@ -156,30 +430,37 @@ def reset_ui():
          gr.update(value=""),  # clear query
          gr.update(visible=False),  # hide radio
          gr.update(visible=False),  # hide vote button
-         gr.update(value=""),  # clear Option 1 text
-         gr.update(value=""),  # clear Option 2 text
+         gr.update(value="**Generating...**"),  # clear Option 1 text
+         gr.update(value="**Generating...**"),  # clear Option 2 text
          gr.update(value=""),  # clear result
          gr.update(active=False),
      )


- with gr.Blocks() as demo:
-     gr.Markdown("# Music Genre Recommendation Side-By-Side Comparison")
-     query = gr.Textbox(label="Your Query")
+ app_description = load_text_resource(Path("./resources/description.md"))
+
+ with gr.Blocks(title="SADAIMREC") as demo:
+     gr.Markdown(app_description)
+     query = gr.Textbox(
+         label="Your Query",
+         placeholder="Calming, music for deep relaxation with echoing sounds and deep bass",
+     )
      submit_btn = gr.Button("Submit")
      # timer that resets ui after feedback is sent
      reset_timer = gr.Timer(value=2.0, active=False)

      # Hidden components to store model responses and names
      with gr.Row(visible=False) as response_row:
-         response_1 = gr.Textbox(label="Option 1", interactive=False)
-         response_2 = gr.Textbox(label="Option 2", interactive=False)
+         response_1 = gr.Markdown(value="**Generating...**", label="Option 1")
+         response_2 = gr.Markdown(value="**Generating...**", label="Option 2")
          model_label_1 = gr.Textbox(visible=False)
          model_label_2 = gr.Textbox(visible=False)

      # Feedback
      vote = gr.Radio(
-         ["Option 1", "Option 2"], label="Select Best Response", visible=False
+         ["Option 1 (left)", "Option 2 (right)"],
+         label="Select Best Response",
+         visible=False,
      )
      vote_btn = gr.Button("Vote", visible=False)
      result = gr.Textbox(label="Console", interactive=False)
@@ -189,6 +470,7 @@ with gr.Blocks() as demo:
          fn=generate_responses,
          inputs=[query],
          outputs=[response_1, response_2, model_label_1, model_label_2],
+         show_progress="full",
      )
      submit_btn.click(  # update ui
          fn=lambda: (
@@ -222,6 +504,7 @@ with gr.Blocks() as demo:
          trigger_mode="once",
      )

+
  if __name__ == "__main__":
      demo.queue(max_size=10, default_concurrency_limit=1).launch(
          server_name="0.0.0.0", server_port=7860
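
Note on the new aggregation logic: `get_top_genres` ranks genres by blending two signals, the number of chunks a genre contributes to the re-ranked hits and its summed min-max-normalized re-ranker score, weighted by `alpha`. A minimal runnable sketch of the same arithmetic on a made-up toy DataFrame (the genres and scores are illustrative only, not from the real index):

import pandas as pd

# Toy re-ranked hits: genre, chunk id, cross-encoder score (fabricated values)
df = pd.DataFrame(
    {
        "genre": ["ambient", "ambient", "dub", "downtempo"],
        "chunk": ["c1", "c2", "c3", "c4"],
        "score": [0.9, 0.7, 0.8, 0.2],
    }
)
alpha = 0.5  # same default as the module-level constant in app.py

# Min-max normalize scores, as get_top_genres does before aggregating
s = df["score"]
df["score"] = (s - s.min()) / (s.max() - s.min())

# Blend per-genre chunk count with summed normalized score
tg = df.groupby("genre").agg(size=("chunk", "size"), score=("score", "sum"))
tg["weighted_score"] = alpha * (tg["size"] / tg["size"].max()) + (1 - alpha) * (
    tg["score"] / tg["score"].max()
)
print(tg.sort_values("weighted_score", ascending=False)["weighted_score"])
# ambient ranks first: it has both the most chunks and the largest score mass

With alpha = 0.5 the two signals count equally; alpha = 1.0 would rank purely by chunk count.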
ingest.py CHANGED
@@ -9,11 +9,13 @@ from huggingface_hub import hf_hub_download
  from qdrant_client import QdrantClient
  from qdrant_client import models as qmodels

+ VLLM_DTYPE = os.getenv("VLLM_DTYPE")
+
  DATA_PATH = Path(os.getenv("DATA_PATH"))
  DB_PATH = DATA_PATH / "db"
  HF_TOKEN = os.getenv("HF_TOKEN")

- RECREATE_DB = bool(os.getenv("RECREATE_DB", "False").lower == "true")
+ RECREATE_DB = bool(os.getenv("RECREATE_DB", "False").lower() == "true")
  DATA_REPO = os.getenv("DATA_REPO")
  DATA_FILENAME = os.getenv("DATA_FILENAME")
@@ -24,7 +26,9 @@ dense_batch_size = 128
  sparse_batch_size = 256

  dense_encoder = SentenceTransformer(
-     model_name_or_path="mixedbread-ai/mxbai-embed-large-v1", device="cuda"
+     model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
+     device="cuda",
+     model_kwargs={"torch_dtype": VLLM_DTYPE},
  )
  sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
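
The `RECREATE_DB` change above is a behavioral fix, not cosmetic: without the call parentheses, `.lower` is a bound method object, so the `== "true"` comparison was always False and the flag could never be switched on. A quick demonstration:

# Why .lower (no parentheses) never matched
flag = "True"
print(flag.lower == "true")    # False: compares a bound method to a string
print(flag.lower() == "true")  # True: compares the lowercased value

The surrounding `bool(...)` is redundant either way, since `==` already yields a bool.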
prompts/api.txt DELETED
@@ -1,7 +0,0 @@
- # Purpose
-
- Recommend 5 genres based on the user query
-
- # Query
-
- {query}
resources/description.md ADDED
@@ -0,0 +1,33 @@
+ # Music Genre Recommendation Side-By-Side Comparison
+
+ This simple application was developed and deployed as **supplementary material for my thesis**.
+
+ If you run into any problems, or have questions or suggestions, please reach out via [email](mailto:[email protected]).
+
+ ## Instructions
+
+ 1. Formulate a **search query** describing a music genre you would like to listen to. The expected format is described below.
+ 2. Explore the **two generated recommendation rankings**: one is created by my system, the other is generated with `gemini-2.0-flash`. Their order is **randomized** on each run.
+ 3. Decide which ranking you prefer.
+ 4. Vote for your choice.
+ 5. Wait for the page to refresh and repeat as many times as you want.
+
+ ## Expected Query Format
+
+ - The system was designed to support **3 categories** of music genre descriptors:
+   - **Subjective**: Emotional & perceptual qualities, the desired **inner feeling** (melancholic, energetic)
+   - **Purpose-Based**: Listening setting, context, suitable activities, scenario (party, workout)
+   - **Technical**: Musical & production attributes, **HOW the sound is made** (instrumentation, timbre, tempo, lo-fi)
+ - **Other descriptors are out of scope for the current implementation**:
+   - I kindly ask you to use only the above 3 categories in your queries
+   - Cultural, historical, and similar descriptors can lead to suboptimal results
+ - You can make the descriptors **as complex and poetic as you want**, but I kindly ask you to **limit your query to a couple of sentences**
+
+ ## Query Examples
+
+ - `Music for deep relaxation with echoing sounds and heavy bass, perfect for unwinding after a long day`
+ - `Music that feels like the echo of a forgotten world—slow, sorrowful. Guitars and distant vocals create the sensation of a long, drifting sleep on the edge of melancholy and oblivion. A soundtrack to isolation, it slowly pulls you into the depths of existential despair.`
+ - `Raw and filled with aggression, high-energy drums, mosh-pit vibes, high bpm, guitars`
+ - `Music to study to, relaxing, chill with calm drums, some piano, and suitable for background`
+ - `Creamy and cozy, suitable for evenings with loved ones`
+ - `Dreamy instrumental music for midnight melancholia`
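
The three descriptor categories above are the same `subjective` / `purpose` / `technical` labels that app.py's `prepare_queries_for_retrieval` turns into `topic` payload filters for Qdrant. A simplified sketch of that mapping; the rewrite dict keys other than the three category names are hypothetical stand-ins, since the `StructuredQueryRewriteResponse` fields are not shown in this diff:

# Mirrors the topic-assignment loop in prepare_queries_for_retrieval (app.py)
rewrite = {
    "subjective": "melancholic and dreamy",
    "purpose": "for late-night studying",
    "technical": None,             # user gave no technical descriptors
    "other_category": "anything",  # hypothetical key: searched without a topic filter
}
queries = [{"text": "original user query", "topic": None}]
for cat, text in rewrite.items():
    if text is None:
        continue
    topic = cat if cat in ["subjective", "purpose", "technical"] else None
    queries.append({"text": text, "topic": topic})
print(queries)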
resources/prompt_api.md ADDED
@@ -0,0 +1,12 @@
+ # Purpose and Context
+
+ Given a user-generated Search Query describing music they wish to explore, create a ranking of the most suitable music genres.
+
+ # Instructions
+
+ 1. Create a music genre ranking including the 5 most suitable music genres, ordered from most to least suitable.
+ 2. Respond in JSON.
+
+ # Search Query
+
+ {query}
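
In app.py this prompt is paired with `response_schema=APIGenreRecommendationResponse`, so "Respond in JSON" implies a fixed shape. A hypothetical response validated against the same pydantic models defined in this commit (the genre names and scores are made up):

from pydantic import BaseModel, Field

class APIGenreRecommendation(BaseModel):
    name: str = Field(description="Name of the music genre.")
    score: float = Field(
        description="Score you assign to the genre (from 0 to 1).", ge=0, le=1
    )

class APIGenreRecommendationResponse(BaseModel):
    genres: list[APIGenreRecommendation]

sample = (
    '{"genres": [{"name": "ambient", "score": 0.92},'
    ' {"name": "downtempo", "score": 0.81}]}'
)
print(APIGenreRecommendationResponse.model_validate_json(sample))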
prompts/local.txt → resources/prompt_vllm.md RENAMED
@@ -228,4 +228,4 @@ Given a user-generated Search Query describing music they wish to explore, you m

  # Search Query

- {query}
+ {query}