Ubik80 committed on
Commit
ac0dee2
·
verified ·
1 Parent(s): 19f6b0e

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +47 -0
retriever.py CHANGED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ from langchain.docstore.document import Document
3
+ from smolagents import Tool
4
+ from langchain_community.retrievers import BM25Retriever
5
+
6
# Fetch the invitee records from the "train" split of the guest dataset.
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")


def _guest_to_document(guest):
    """Render one guest record as a Document with one labelled field per line."""
    guest_name = guest["name"]
    field_lines = (
        f"Name: {guest_name}",
        f"Relation: {guest['relation']}",
        f"Description: {guest['description']}",
        f"Email: {guest['email']}",
    )
    # The name is duplicated into the metadata alongside the rendered text.
    return Document(page_content="\n".join(field_lines), metadata={"name": guest_name})


# One Document per guest, in dataset order.
docs = [_guest_to_document(guest) for guest in guest_dataset]
22
+
23
# Define the retriever tool
class GuestInfoRetrieverTool(Tool):
    """Tool that looks up gala guest records with a BM25 keyword retriever.

    The tool is built from a collection of documents (one per guest) and
    answers a free-text query with the text of the best-matching records.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs):
        """Index ``docs`` for retrieval.

        Args:
            docs: Documents to index; each one's ``page_content`` is what
                ``forward`` returns for a match.
        """
        # Let the base class set up its own state instead of hand-setting
        # internal flags such as ``is_initialized`` here.
        super().__init__()
        self.retriever = BM25Retriever.from_documents(docs)

    def forward(self, query: str) -> str:
        """Return the text of up to three guest records matching ``query``.

        Args:
            query: Name or relation of the guest to look up.

        Returns:
            The matching records' text separated by blank lines, or a fixed
            "not found" message when the retriever returns nothing.
        """
        # ``invoke`` is the supported retriever entry point; the older
        # ``get_relevant_documents`` is deprecated in LangChain.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results[:3])
45
+
46
# Build the shared tool instance over the guest documents prepared above.
guest_info_tool = GuestInfoRetrieverTool(docs=docs)