Gyaneshere commited on
Commit
ba8a829
·
verified ·
1 Parent(s): 1783c76

Create retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +52 -0
retriever.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from langchain_community.retrievers import BM25Retriever
3
+ from langchain.docstore.document import Document
4
+ import datasets
5
+
6
+
7
+ class GuestInfoRetrieverTool(Tool):
8
+ name = "guest_info_retriever"
9
+ description = "Retrieves detailed information about gala guests based on their name or relation."
10
+ inputs = {
11
+ "query": {
12
+ "type": "string",
13
+ "description": "The name or relation of the guest you want information about."
14
+ }
15
+ }
16
+ output_type = "string"
17
+
18
+ def __init__(self, docs):
19
+ self.is_initialized = False
20
+ self.retriever = BM25Retriever.from_documents(docs)
21
+
22
+
23
+ def forward(self, query: str):
24
+ results = self.retriever.get_relevant_documents(query)
25
+ if results:
26
+ return "\n\n".join([doc.page_content for doc in results[:3]])
27
+ else:
28
+ return "No matching guest information found."
29
+
30
+
31
+ def load_guest_dataset():
32
+ # Load the dataset
33
+ guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
34
+
35
+ # Convert dataset entries into Document objects
36
+ docs = [
37
+ Document(
38
+ page_content="\n".join([
39
+ f"Name: {guest['name']}",
40
+ f"Relation: {guest['relation']}",
41
+ f"Description: {guest['description']}",
42
+ f"Email: {guest['email']}"
43
+ ]),
44
+ metadata={"name": guest["name"]}
45
+ )
46
+ for guest in guest_dataset
47
+ ]
48
+
49
+ # Return the tool
50
+ return GuestInfoRetrieverTool(docs)
51
+
52
+