miki5799 commited on
Commit
0e6d8ff
·
1 Parent(s): c4b2de1

Refactor retrieve function and enhance UI with Gradio components for improved query handling

Browse files
Files changed (1) hide show
  1. app.py +24 -16
app.py CHANGED
@@ -367,26 +367,34 @@ return_type = List[Hit]
367
 
368
 
369
  ## YOUR_CODE_STARTS_HERE
370
- def retrieve(query: str, topk: int = 10) -> return_type:
371
- ranking = bm25_retriever.retrieve(query=query, topk=3)
372
  hits = []
373
  for cid, score in ranking.items():
374
  text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
375
  hits.append({"cid": cid, "score": score, "text": text})
376
  return hits
377
 
378
-
379
- demo = gr.Interface(
380
- fn=retrieve,
381
- inputs=gr.Textbox(lines=3, placeholder="Enter your query here..."),
382
- outputs="json",
383
- title="CSC BM25 Retriever",
384
- description="Retrieve documents based on the query using CSC BM25 Retriever",
385
- examples=[
386
- ["What are the differences between immunodeficiency and autoimmune diseases?"],
387
- ["What are the causes of immunodeficiency?"],
388
- ["What are the symptoms of immunodeficiency?"],
389
- ],
390
- )
391
- ## YOUR_CODE_ENDS_HERE
 
 
 
 
 
 
 
392
  demo.launch()
 
 
367
 
368
 
369
  ## YOUR_CODE_STARTS_HERE
370
+ def retrieve(query: str, topk: int=10) -> return_type:
371
+ ranking = bm25_retriever.retrieve(query=query, topk=topk)
372
  hits = []
373
  for cid, score in ranking.items():
374
  text = bm25_retriever.index.doc_texts[bm25_retriever.index.cid2docid[cid]]
375
  hits.append({"cid": cid, "score": score, "text": text})
376
  return hits
377
 
378
+ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
379
+ gr.Markdown("# BM25 Retriever")
380
+ gr.Markdown("Retrieve documents based on the query using BM25 Retriever")
381
+ query = gr.Textbox(lines=3, placeholder="Enter your query here...")
382
+ topk = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K")
383
+ # Search Button
384
+ examples = gr.Examples(
385
+ examples=[
386
+ ["What are the differences between immunodeficiency and autoimmune diseases?"],
387
+ ["What are the causes of immunodeficiency?"],
388
+ ["What are the symptoms of immunodeficiency?"],
389
+ ],
390
+ inputs=[query],
391
+ )
392
+ search_button = gr.Button("Search", elem_id="search_button")
393
+ results_section = gr.JSON(elem_id="results_section")
394
+ search_button.click(
395
+ retrieve,
396
+ inputs=[query, topk],
397
+ outputs=results_section,
398
+ )
399
  demo.launch()
400
+ ## YOUR_CODE_ENDS_HERE