Refactor retrieve function and enhance UI with Gradio components for improved query handling
Browse files
app.py
CHANGED
@@ -367,26 +367,34 @@ return_type = List[Hit]
|
|
367 |
|
368 |
|
369 |
## YOUR_CODE_STARTS_HERE
|
370 |
-
def retrieve(query: str, topk: int
|
371 |
-
ranking = bm25_retriever.retrieve(query=query, topk=
|
372 |
hits = []
|
373 |
for cid, score in ranking.items():
|
374 |
text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
|
375 |
hits.append({"cid": cid, "score": score, "text": text})
|
376 |
return hits
|
377 |
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
demo.launch()
|
|
|
|
367 |
|
368 |
|
369 |
## YOUR_CODE_STARTS_HERE
|
370 |
+
def retrieve(query: str, topk: int=10) -> return_type:
|
371 |
+
ranking = bm25_retriever.retrieve(query=query, topk=topk)
|
372 |
hits = []
|
373 |
for cid, score in ranking.items():
|
374 |
text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
|
375 |
hits.append({"cid": cid, "score": score, "text": text})
|
376 |
return hits
|
377 |
|
378 |
+
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
|
379 |
+
gr.Markdown("# BM25 Retriever")
|
380 |
+
gr.Markdown("Retrieve documents based on the query using BM25 Retriever")
|
381 |
+
query = gr.Textbox(lines=3, placeholder="Enter your query here...")
|
382 |
+
topk = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K")
|
383 |
+
# Search Button
|
384 |
+
examples = gr.Examples(
|
385 |
+
examples=[
|
386 |
+
["What are the differences between immunodeficiency and autoimmune diseases?"],
|
387 |
+
["What are the causes of immunodeficiency?"],
|
388 |
+
["What are the symptoms of immunodeficiency?"],
|
389 |
+
],
|
390 |
+
inputs=[query],
|
391 |
+
)
|
392 |
+
search_button = gr.Button("Search", elem_id="search_button")
|
393 |
+
results_section = gr.JSON(elem_id="results_section")
|
394 |
+
search_button.click(
|
395 |
+
retrieve,
|
396 |
+
inputs=[query, topk],
|
397 |
+
outputs=results_section,
|
398 |
+
)
|
399 |
demo.launch()
|
400 |
+
## YOUR_CODE_ENDS_HERE
|